// upload-service/internal/middleware/ratelimit.go

package middleware
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// RateLimiter applies a token-bucket limit per client key (the client IP as
// returned by clientIP). Each key gets its own *rate.Limiter, created on first
// use with the configured rate and burst.
//
// Keys that have been idle long enough for their bucket to refill completely
// are evicted lazily, so the number of tracked keys does not grow without
// bound. A full bucket behaves exactly like a freshly created one, so eviction
// does not change which requests are allowed.
//
// A RateLimiter must be created with NewRateLimiter; the zero value has nil
// maps and is not usable.
type RateLimiter struct {
	limiters map[string]*rate.Limiter
	mu       sync.Mutex // guards limiters, lastSeen and lastSweep
	rate     rate.Limit
	burst    int

	// lastSeen holds the time of each key's most recent request. It always
	// has the same key set as limiters.
	lastSeen map[string]time.Time
	// lastSweep is when idle keys were last evicted.
	lastSweep time.Time
	// idleTTL is how long a key may go unused before it is evicted.
	// Zero disables eviction.
	idleTTL time.Duration
}

// NewRateLimiter returns a RateLimiter that gives every client a bucket
// refilling at r tokens per second with capacity burst.
func NewRateLimiter(r float64, burst int) *RateLimiter {
	return &RateLimiter{
		limiters:  make(map[string]*rate.Limiter),
		rate:      rate.Limit(r),
		burst:     burst,
		lastSeen:  make(map[string]time.Time),
		lastSweep: time.Now(),
		idleTTL:   evictionTTL(r, burst),
	}
}

// evictionTTL returns how long a key must be idle before its limiter can be
// discarded without changing limiting behavior. That is the time an empty
// bucket needs to refill to burst tokens at rate r, plus a safety margin.
// It returns 0 (never evict) when the bucket never refills (r is zero,
// negative or NaN) or when the refill time is too large to represent.
//
// NOTE(review): this assumes limiters are only consumed via Allow, as
// Middleware does. Reserve/Wait can drive a bucket below zero tokens, and
// that would require a longer TTL.
func evictionTTL(r float64, burst int) time.Duration {
	// The margin absorbs the small delay between recording lastSeen in
	// getLimiter and the caller's subsequent Allow call on the limiter.
	const margin = time.Minute
	const maxRefill = 100 * 365 * 24 * time.Hour

	if !(r > 0) { // also rejects NaN
		// A bucket that never refills must be kept. Dropping it would hand
		// the client a brand-new burst.
		return 0
	}
	refill := float64(burst) / r * float64(time.Second)
	switch {
	case refill < 0:
		refill = 0
	case refill > float64(maxRefill):
		return 0
	}
	return time.Duration(refill) + margin
}

// evictIdleLocked drops every key that has not been seen for at least
// rl.idleTTL. To keep the per-request cost low, it scans the map at most once
// per idleTTL. The caller must hold rl.mu.
func (rl *RateLimiter) evictIdleLocked(now time.Time) {
	if rl.idleTTL <= 0 || now.Sub(rl.lastSweep) < rl.idleTTL {
		return
	}
	rl.lastSweep = now
	for key, seen := range rl.lastSeen {
		if now.Sub(seen) >= rl.idleTTL {
			delete(rl.lastSeen, key)
			delete(rl.limiters, key)
		}
	}
}

// getLimiter returns the limiter for ip, creating it on first use. It records
// the request time used for idle eviction and runs an eviction sweep when one
// is due.
func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	rl.evictIdleLocked(now)

	l, ok := rl.limiters[ip]
	if !ok {
		l = rate.NewLimiter(rl.rate, rl.burst)
		rl.limiters[ip] = l
	}
	rl.lastSeen[ip] = now
	return l
}
// Middleware wraps next so that every request is charged against the token
// bucket of its client IP (as reported by clientIP). A request that finds its
// bucket empty is answered with 429 Too Many Requests and never reaches next.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		limiter := rl.getLimiter(clientIP(r))
		if limiter.Allow() {
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
	}
	return http.HandlerFunc(limited)
}
// clientIP returns the key used to rate-limit r: the client's IP address in
// canonical textual form.
//
// Proxy headers are consulted first, in this order: CF-Connecting-IP,
// X-Real-IP, then the first entry of X-Forwarded-For. A header is used only if
// its (whitespace-trimmed) value parses as an IP address; otherwise the next
// source is tried. If no header yields an address, the host part of
// r.RemoteAddr is used, or the whole RemoteAddr if it has no port.
//
// SECURITY NOTE(review): all three headers are set by the client unless a
// trusted proxy (e.g. Cloudflare) overwrites them. If this service can be
// reached other than through such a proxy, a client can send a different
// value on every request, get a fresh bucket each time, and bypass the
// limiter. Confirm the deployment topology, or only honor these headers when
// RemoteAddr belongs to a known proxy.
func clientIP(r *http.Request) string {
	xffFirst := strings.SplitN(r.Header.Get("X-Forwarded-For"), ",", 2)[0]
	for _, candidate := range [...]string{
		r.Header.Get("CF-Connecting-IP"),
		r.Header.Get("X-Real-IP"),
		xffFirst,
	} {
		// Reject non-IP values so arbitrary strings never become map keys,
		// and canonicalize so equivalent spellings share one bucket.
		if ip := net.ParseIP(strings.TrimSpace(candidate)); ip != nil {
			return ip.String()
		}
	}

	// SplitHostPort handles bracketed IPv6 ("[::1]:8080") and port-less
	// IPv6 correctly. A plain last-colon split gets both wrong.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		host = r.RemoteAddr
	}
	if ip := net.ParseIP(host); ip != nil {
		return ip.String()
	}
	return host
}