diff --git a/README.md b/README.md index 6069e86..f9df4d6 100644 --- a/README.md +++ b/README.md @@ -91,3 +91,32 @@ func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, }) } ``` + + +## handle concurrency problem with sync.Map + +```golang +var ipLimiterMap sync.Map + +// RateLimiterMiddleware - 建立 ratelimiter middleware +func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, burst int) http.Handler { + + // var mu sync.Mutex + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Fetch IP + ip := r.getIP(req) + // Create limiter if not present for IP + limiterAny, _ := ipLimiterMap.LoadOrStore(ip, rate.NewLimiter(limit, burst)) + limiter := limiterAny.(*rate.Limiter) + // return error if the limit has been reached + if !limiter.Allow() { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{"error": "Too many requests"}) + return + } + next.ServeHTTP(w, req) + }) +} +``` + diff --git a/internal/service/rate_limiter/rate-limiter.go b/internal/service/rate_limiter/rate-limiter.go index e562916..69add5e 100644 --- a/internal/service/rate_limiter/rate-limiter.go +++ b/internal/service/rate_limiter/rate-limiter.go @@ -22,21 +22,18 @@ func (r *RateLimiter) getIP(req *http.Request) string { return host } +var ipLimiterMap sync.Map + // RateLimiterMiddleware - 建立 ratelimiter middleware func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, burst int) http.Handler { - ipLimiterMap := make(map[string]*rate.Limiter) - var mu sync.Mutex + + // var mu sync.Mutex return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // Fetch IP ip := r.getIP(req) // Create limiter if not present for IP - mu.Lock() - limiter, exists := ipLimiterMap[ip] - if !exists { - limiter = rate.NewLimiter(limit, burst) - ipLimiterMap[ip] = limiter - } - mu.Unlock() + limiterAny, _ := ipLimiterMap.LoadOrStore(ip, rate.NewLimiter(limit, burst)) + limiter := limiterAny.(*rate.Limiter) // return error if the limit has been reached if !limiter.Allow() { w.Header().Set("Content-Type", "application/json")