67 lines
1.3 KiB
Go
67 lines
1.3 KiB
Go
package middleware
|
|
|
|
import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
|
|
|
|
type RateLimiter struct {
|
|
limiters map[string]*rate.Limiter
|
|
mu sync.Mutex
|
|
rate rate.Limit
|
|
burst int
|
|
}
|
|
|
|
func NewRateLimiter(r float64, burst int) *RateLimiter {
|
|
return &RateLimiter{
|
|
limiters: make(map[string]*rate.Limiter),
|
|
rate: rate.Limit(r),
|
|
burst: burst,
|
|
}
|
|
}
|
|
|
|
func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
l, exists := rl.limiters[ip]
|
|
if !exists {
|
|
l = rate.NewLimiter(rl.rate, rl.burst)
|
|
rl.limiters[ip] = l
|
|
}
|
|
return l
|
|
}
|
|
|
|
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := clientIP(r)
|
|
if !rl.getLimiter(ip).Allow() {
|
|
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// clientIP returns the key used to rate-limit r.
//
// Proxy headers are consulted first, in this order:
//   - CF-Connecting-IP
//   - X-Real-IP
//   - the first entry of X-Forwarded-For
//
// Header values are whitespace-trimmed, and empty values are skipped. If no
// header supplies a value, the host part of r.RemoteAddr is used, with the
// port and any IPv6 brackets removed.
//
// NOTE(security): these headers come from the client unless a trusted proxy
// in front of this server overwrites them. If the server is reachable
// directly, a caller can send a different value on every request and bypass
// the limit. Only rely on this behind such a proxy.
func clientIP(r *http.Request) string {
	// Cloudflare / proxy headers.
	for _, name := range [...]string{"CF-Connecting-IP", "X-Real-IP"} {
		if ip := strings.TrimSpace(r.Header.Get(name)); ip != "" {
			return ip
		}
	}

	// X-Forwarded-For is "client, proxy1, proxy2"; take the first entry.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := strings.TrimSpace(first); ip != "" {
		return ip
	}

	// RemoteAddr is normally "host:port" or "[host]:port" for IPv6.
	// SplitHostPort removes both the port and the brackets. When there is no
	// port it fails, and the address is returned whole rather than being cut
	// at one of an IPv6 address's own colons.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|