diff --git a/token/rate.go b/token/rate.go index 23aa43c..1c716af 100644 --- a/token/rate.go +++ b/token/rate.go @@ -1,7 +1,6 @@ package token import ( - "errors" "sync" "time" ) @@ -27,29 +26,20 @@ type Limiter struct { // NewLimiter creates a new rate limiter with the given capacity and refill rate. // Capacity is the maximum burst size. Rate is tokens added per second. -// Returns an error if capacity or rate is negative. -func NewLimiter(capacity, rate float64) (*Limiter, error) { +func NewLimiter(capacity, rate uint32) *Limiter { return NewLimiterWithClock(capacity, rate, realClock{}) } // NewLimiterWithClock creates a new rate limiter with a custom clock. // Use this constructor for testing with a mock clock. -func NewLimiterWithClock(capacity, rate float64, clock clock) (*Limiter, error) { - if capacity < 0 { - return nil, errors.New("capacity must be greater than zero") - } - - if rate < 0 { - return nil, errors.New("rate must be greater than zero") - } - +func NewLimiterWithClock(capacity, rate uint32, clock clock) *Limiter { return &Limiter{ - capacity: capacity, - tokens: capacity, - rate: rate, + capacity: float64(capacity), + tokens: float64(capacity), + rate: float64(rate), clock: clock, lastRefillAt: clock.Now(), - }, nil + } } // Allow reports whether a request is allowed. It consumes one token if diff --git a/token/rate_test.go b/token/rate_test.go index f551420..d954cb1 100644 --- a/token/rate_test.go +++ b/token/rate_test.go @@ -22,28 +22,9 @@ func (c *testClock) advance(by time.Duration) { c.now = c.now.Add(by) } -func TestNewLimiter(t *testing.T) { - lim, err := token.NewLimiter(5, 2) - require.NoError(t, err) - require.True(t, lim.Allow()) -} - -func TestNewLimiterWithClock_NegativeCapacity(t *testing.T) { - clock := &testClock{now: time.Now()} - _, err := token.NewLimiterWithClock(-1, 2, clock) - require.Error(t, err) -} - -func TestNewLimiterWithClock_NegativeRate(t *testing.T) { - clock := &testClock{now: time.Now()} - _, err := token.NewLimiterWithClock(5, -1, clock) - require.Error(t, err) -} - func TestLimiter_Allow_ClockGoesBackwards(t *testing.T) { clock := &testClock{now: time.Now()} - lim, err := token.NewLimiterWithClock(1, 1, clock) - require.NoError(t, err) + lim := token.NewLimiterWithClock(1, 1, clock) // Drain the token require.True(t, lim.Allow()) @@ -56,8 +37,8 @@ func TestLimiter_Allow_ClockGoesBackwards(t *testing.T) { func TestLimiter_Allow(t *testing.T) { type fields struct { - capacity float64 - rate float64 + capacity uint32 + rate uint32 } clock := &testClock{now: time.Now()} @@ -93,8 +74,7 @@ func TestLimiter_Allow(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - lim, err := token.NewLimiterWithClock(tt.fields.capacity, tt.fields.rate, clock) - require.NoError(t, err) + lim := token.NewLimiterWithClock(tt.fields.capacity, tt.fields.rate, clock) for range tt.previousAttempts { lim.Allow() @@ -110,8 +90,7 @@ func TestLimiter_Allow(t *testing.T) { } func TestLimiter_Allow_Concurrent(t *testing.T) { - lim, err := token.NewLimiter(100, 0) - require.NoError(t, err) + lim := token.NewLimiter(100, 0) var ( allowed atomic.Int64 @@ -139,12 +118,16 @@ func TestLimiter_Allow_Concurrent(t *testing.T) { } func TestLimiter_Allow_ConcurrentWithRefill(t *testing.T) { - lim, err := token.NewLimiter(10, 1000) - require.NoError(t, err) + clock := &testClock{now: time.Now()} + lim := token.NewLimiterWithClock(10, 1000, clock) - var wg sync.WaitGroup + var ( + allowed atomic.Int64 + wg sync.WaitGroup + ) // Hammer the limiter from multiple goroutines + // Clock doesn't advance, so no refill happens - exactly 10 should be allowed for range 100 { wg.Add(1) @@ -152,11 +135,14 @@ func TestLimiter_Allow_ConcurrentWithRefill(t *testing.T) { defer wg.Done() for range 100 { - lim.Allow() + if lim.Allow() { + allowed.Add(1) + } } }() } wg.Wait() - // If we get here without race detector complaints, the test passes + + require.Equal(t, int64(10), allowed.Load()) } diff --git a/token/registry.go b/token/registry.go index 8cbc071..31b42b8 100644 --- a/token/registry.go +++ b/token/registry.go @@ -1,7 +1,6 @@ package token import ( - "fmt" "sync" ) @@ -10,19 +9,15 @@ type ( Registry struct { mu sync.Mutex limiters map[Identifier]*Limiter - capacity, rate float64 + capacity, rate uint32 } ) -func NewRegistry(capacity, rate float64, users ...Identifier) (*Registry, error) { +func NewRegistry(capacity, rate uint32, users ...Identifier) (*Registry, error) { limiters := make(map[Identifier]*Limiter) for _, user := range users { - limiter, err := NewLimiter(capacity, rate) - if err != nil { - return nil, fmt.Errorf("fail to create a new limiter %w", err) - } - + limiter := NewLimiter(capacity, rate) limiters[user] = limiter } @@ -39,7 +34,7 @@ func (r *Registry) Allow(key Identifier) bool { lim, ok := r.limiters[key] if !ok { - lim, _ = NewLimiter(r.capacity, r.rate) + lim = NewLimiter(r.capacity, r.rate) r.limiters[key] = lim } diff --git a/token/registry_test.go b/token/registry_test.go index 336e225..78ee4b9 100644 --- a/token/registry_test.go +++ b/token/registry_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/serroba/rate/token" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -21,16 +22,6 @@ func TestNewRegistry_WithUsers(t *testing.T) { require.NotNil(t, reg) } -func TestNewRegistry_InvalidCapacity(t *testing.T) { - _, err := token.NewRegistry(-1, 2, "alice") - require.Error(t, err) -} - -func TestNewRegistry_InvalidRate(t *testing.T) { - _, err := token.NewRegistry(10, -1, "alice") - require.Error(t, err) -} - func TestRegistry_Allow_ExistingUser(t *testing.T) { reg, err := token.NewRegistry(2, 0, "alice") require.NoError(t, err) @@ -94,6 +85,41 @@ func TestRegistry_Allow_Concurrent(t *testing.T) { require.Equal(t, int64(200), allowed.Load()) } +func TestRegistry_Deny_Concurrent(t *testing.T) { + reg, err := token.NewRegistry(100, 0) + require.NoError(t, err) + + var ( + allowed atomic.Int64 + deny atomic.Int64 + wg sync.WaitGroup + ) + + // 50 goroutines per user, 4 users = 200 goroutines + users := []token.Identifier{"alice", "bob", "charlie", "diana"} + for _, user := range users { + for range 110 { + wg.Add(1) + + go func(u token.Identifier) { + defer wg.Done() + + if reg.Allow(u) { + allowed.Add(1) + } else { + deny.Add(1) + } + }(user) + } + } + + wg.Wait() + + // Each user has capacity 100, only 50 requests each, so all should be allowed + assert.Equal(t, int64(400), allowed.Load()) + assert.Equal(t, int64(40), deny.Load()) +} + func TestRegistry_Allow_ConcurrentNewUsers(t *testing.T) { reg, err := token.NewRegistry(5, 0) require.NoError(t, err) @@ -107,7 +133,7 @@ func TestRegistry_Allow_ConcurrentNewUsers(t *testing.T) { go func(id int) { defer wg.Done() - user := token.Identifier(string(rune('a' + id%26))) + user := token.Identifier(rune('a' + id%26)) reg.Allow(user) }(i) }