package middleware

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)

// RateLimiter hands out one token-bucket limiter per client key and exposes
// an HTTP middleware that rejects requests once a client's bucket is empty.
//
// All fields are guarded by mu. A RateLimiter must not be copied after first
// use because it contains a sync.Mutex.
type RateLimiter struct {
	limiters map[string]*rate.Limiter
	mu       sync.Mutex
	rate     rate.Limit
	burst    int

	// lastSeen records when each key was last looked up; it drives eviction
	// of idle entries so that limiters does not grow without bound.
	lastSeen map[string]time.Time
	// lastSweep is the time of the most recent eviction pass.
	lastSweep time.Time
}

// NewRateLimiter returns a RateLimiter that allows each client r events per
// second with bursts of at most burst events.
func NewRateLimiter(r float64, burst int) *RateLimiter {
	return &RateLimiter{
		limiters: make(map[string]*rate.Limiter),
		lastSeen: make(map[string]time.Time),
		rate:     rate.Limit(r),
		burst:    burst,
	}
}

// idleTTL reports how long a key must be idle before its limiter may be
// dropped, and whether eviction is safe at all for the configured rate.
//
// A limiter that has been idle for burst/rate seconds has a full bucket again,
// so replacing it with a fresh limiter is indistinguishable to Allow callers.
// A fixed margin is added on top so entries are not churned needlessly.
func (rl *RateLimiter) idleTTL() (time.Duration, bool) {
	const (
		margin    = time.Minute
		maxRefill = 24 * time.Hour
	)

	switch {
	case rl.rate >= rate.Inf:
		// An infinite limit allows every event; limiter state carries no
		// information worth keeping.
		return margin, true
	case rl.rate <= 0:
		// Tokens never refill, so dropping an entry would hand the client a
		// fresh burst. Keep entries forever in this configuration.
		return 0, false
	}

	refill := 0.0 // seconds needed to refill an empty bucket
	if rl.burst > 0 {
		refill = float64(rl.burst) / float64(rl.rate)
	}
	// Written as a negated <= so that a NaN rate also disables eviction, and
	// so that the float-to-Duration conversion below cannot overflow.
	if !(refill <= maxRefill.Seconds()) {
		return 0, false
	}
	return margin + time.Duration(refill*float64(time.Second)), true
}

// evictIdle removes limiters whose keys have not been looked up for at least
// idleTTL. It runs at most once per idleTTL and must be called with mu held.
func (rl *RateLimiter) evictIdle(now time.Time) {
	ttl, ok := rl.idleTTL()
	if !ok || now.Sub(rl.lastSweep) < ttl {
		return
	}
	rl.lastSweep = now

	for key := range rl.limiters {
		seen, tracked := rl.lastSeen[key]
		if !tracked {
			// Entry has no timestamp (e.g. inserted before tracking was
			// initialised); start its idle clock now instead of evicting.
			rl.lastSeen[key] = now
			continue
		}
		if now.Sub(seen) >= ttl {
			delete(rl.limiters, key)
			delete(rl.lastSeen, key)
		}
	}
}

// getLimiter returns the limiter for ip, creating it on first use. Creating a
// new entry also gives evictIdle a chance to prune idle ones, which bounds
// the map to the keys active within roughly the last two TTL windows.
func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter {
	now := time.Now()

	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Tolerate a RateLimiter that was built without NewRateLimiter.
	if rl.lastSeen == nil {
		rl.lastSeen = make(map[string]time.Time)
	}

	l, exists := rl.limiters[ip]
	if !exists {
		rl.evictIdle(now)
		l = rate.NewLimiter(rl.rate, rl.burst)
		rl.limiters[ip] = l
	}
	rl.lastSeen[ip] = now
	return l
}

// Middleware wraps next so that each request is charged against the limiter
// of the client identified by clientIP. Requests over the limit receive
// 429 Too Many Requests and next is not invoked.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip := clientIP(r)
		if !rl.getLimiter(ip).Allow() {
			http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// clientIP returns the key used to identify the client behind r. It prefers,
// in order, the CF-Connecting-IP, X-Real-IP and X-Forwarded-For (first entry)
// headers, and falls back to the host part of r.RemoteAddr.
//
// NOTE(review): these headers are taken at face value. Unless every request
// passes through a trusted proxy that overwrites them, a client can set them
// itself and so choose (or rotate) its own rate-limit key.
func clientIP(r *http.Request) string {
	// Cloudflare / proxy headers
	if ip := strings.TrimSpace(r.Header.Get("CF-Connecting-IP")); ip != "" {
		return ip
	}
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if i := strings.IndexByte(xff, ','); i >= 0 {
			first = xff[:i]
		}
		// Entries are conventionally separated by ", "; trim so the same
		// address always maps to the same key. A blank first entry falls
		// through to RemoteAddr instead of bucketing everyone under "".
		if first = strings.TrimSpace(first); first != "" {
			return first
		}
	}

	// Strip port. SplitHostPort also removes the brackets around IPv6
	// literals; if RemoteAddr has no port, use it unchanged.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}