diff --git a/token/rate.go b/token/rate.go index d9fc558..23aa43c 100644 --- a/token/rate.go +++ b/token/rate.go @@ -2,6 +2,7 @@ package token import ( "errors" + "sync" "time" ) @@ -18,6 +19,7 @@ func (c realClock) Now() time.Time { // Limiter implements a token bucket rate limiter. It allows a burst of // requests up to capacity, then refills tokens at the specified rate per second. type Limiter struct { + mu sync.Mutex capacity, tokens, rate float64 lastRefillAt time.Time clock clock @@ -54,6 +56,9 @@ func NewLimiterWithClock(capacity, rate float64, clock clock) (*Limiter, error) // available and returns true. If no tokens are available, it returns false // without blocking. func (lim *Limiter) Allow() bool { + lim.mu.Lock() + defer lim.mu.Unlock() + lim.refill() if lim.tokens >= 1 { diff --git a/token/rate_test.go b/token/rate_test.go index b2bac05..f551420 100644 --- a/token/rate_test.go +++ b/token/rate_test.go @@ -1,6 +1,8 @@ package token_test import ( + "sync" + "sync/atomic" "testing" "time" @@ -106,3 +108,55 @@ func TestLimiter_Allow(t *testing.T) { }) } } + +func TestLimiter_Allow_Concurrent(t *testing.T) { + lim, err := token.NewLimiter(100, 0) + require.NoError(t, err) + + var ( + allowed atomic.Int64 + wg sync.WaitGroup + ) + + // Launch 200 goroutines, but only 100 should be allowed + + for range 200 { + wg.Add(1) + + go func() { + defer wg.Done() + + if lim.Allow() { + allowed.Add(1) + } + }() + } + + wg.Wait() + + // With capacity 100 and rate 0, exactly 100 should be allowed + require.Equal(t, int64(100), allowed.Load(), "expected exactly 100 requests to be allowed") +} + +func TestLimiter_Allow_ConcurrentWithRefill(t *testing.T) { + lim, err := token.NewLimiter(10, 1000) + require.NoError(t, err) + + var wg sync.WaitGroup + + // Hammer the limiter from multiple goroutines + for range 100 { + wg.Add(1) + + go func() { + defer wg.Done() + + for range 100 { + lim.Allow() + } + }() + } + + wg.Wait() + // If we get here without race detector complaints, the test passes +}