Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions token/rate.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package token

import (
"errors"
"sync"
"time"
)
Expand All @@ -27,29 +26,20 @@ type Limiter struct {

// NewLimiter creates a new rate limiter with the given capacity and refill rate.
// Capacity is the maximum burst size. Rate is tokens added per second.
// Returns an error if capacity or rate is negative.
func NewLimiter(capacity, rate float64) (*Limiter, error) {
func NewLimiter(capacity, rate uint32) *Limiter {
return NewLimiterWithClock(capacity, rate, realClock{})
}

// NewLimiterWithClock creates a new rate limiter with a custom clock.
// Use this constructor for testing with a mock clock.
func NewLimiterWithClock(capacity, rate float64, clock clock) (*Limiter, error) {
if capacity < 0 {
return nil, errors.New("capacity must be greater than zero")
}

if rate < 0 {
return nil, errors.New("rate must be greater than zero")
}

func NewLimiterWithClock(capacity, rate uint32, clock clock) *Limiter {
return &Limiter{
capacity: capacity,
tokens: capacity,
rate: rate,
capacity: float64(capacity),
tokens: float64(capacity),
rate: float64(rate),
clock: clock,
lastRefillAt: clock.Now(),
}, nil
}
}

// Allow reports whether a request is allowed. It consumes one token if
Expand Down
48 changes: 17 additions & 31 deletions token/rate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,9 @@ func (c *testClock) advance(by time.Duration) {
c.now = c.now.Add(by)
}

func TestNewLimiter(t *testing.T) {
lim, err := token.NewLimiter(5, 2)
require.NoError(t, err)
require.True(t, lim.Allow())
}

func TestNewLimiterWithClock_NegativeCapacity(t *testing.T) {
clock := &testClock{now: time.Now()}
_, err := token.NewLimiterWithClock(-1, 2, clock)
require.Error(t, err)
}

func TestNewLimiterWithClock_NegativeRate(t *testing.T) {
clock := &testClock{now: time.Now()}
_, err := token.NewLimiterWithClock(5, -1, clock)
require.Error(t, err)
}

func TestLimiter_Allow_ClockGoesBackwards(t *testing.T) {
clock := &testClock{now: time.Now()}
lim, err := token.NewLimiterWithClock(1, 1, clock)
require.NoError(t, err)
lim := token.NewLimiterWithClock(1, 1, clock)

// Drain the token
require.True(t, lim.Allow())
Expand All @@ -56,8 +37,8 @@ func TestLimiter_Allow_ClockGoesBackwards(t *testing.T) {

func TestLimiter_Allow(t *testing.T) {
type fields struct {
capacity float64
rate float64
capacity uint32
rate uint32
}

clock := &testClock{now: time.Now()}
Expand Down Expand Up @@ -93,8 +74,7 @@ func TestLimiter_Allow(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lim, err := token.NewLimiterWithClock(tt.fields.capacity, tt.fields.rate, clock)
require.NoError(t, err)
lim := token.NewLimiterWithClock(tt.fields.capacity, tt.fields.rate, clock)

for range tt.previousAttempts {
lim.Allow()
Expand All @@ -110,8 +90,7 @@ func TestLimiter_Allow(t *testing.T) {
}

func TestLimiter_Allow_Concurrent(t *testing.T) {
lim, err := token.NewLimiter(100, 0)
require.NoError(t, err)
lim := token.NewLimiter(100, 0)

var (
allowed atomic.Int64
Expand Down Expand Up @@ -139,24 +118,31 @@ func TestLimiter_Allow_Concurrent(t *testing.T) {
}

func TestLimiter_Allow_ConcurrentWithRefill(t *testing.T) {
lim, err := token.NewLimiter(10, 1000)
require.NoError(t, err)
clock := &testClock{now: time.Now()}
lim := token.NewLimiterWithClock(10, 1000, clock)

var wg sync.WaitGroup
var (
allowed atomic.Int64
wg sync.WaitGroup
)

// Hammer the limiter from multiple goroutines
// Clock doesn't advance, so no refill happens - exactly 10 should be allowed
for range 100 {
wg.Add(1)

go func() {
defer wg.Done()

for range 100 {
lim.Allow()
if lim.Allow() {
allowed.Add(1)
}
}
}()
}

wg.Wait()
// If we get here without race detector complaints, the test passes

require.Equal(t, int64(10), allowed.Load())
}
13 changes: 4 additions & 9 deletions token/registry.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package token

import (
"fmt"
"sync"
)

Expand All @@ -10,19 +9,15 @@ type (
Registry struct {
mu sync.Mutex
limiters map[Identifier]*Limiter
capacity, rate float64
capacity, rate uint32
}
)

func NewRegistry(capacity, rate float64, users ...Identifier) (*Registry, error) {
func NewRegistry(capacity, rate uint32, users ...Identifier) (*Registry, error) {
limiters := make(map[Identifier]*Limiter)

for _, user := range users {
limiter, err := NewLimiter(capacity, rate)
if err != nil {
return nil, fmt.Errorf("fail to create a new limiter %w", err)
}

limiter := NewLimiter(capacity, rate)
limiters[user] = limiter
}

Expand All @@ -39,7 +34,7 @@ func (r *Registry) Allow(key Identifier) bool {

lim, ok := r.limiters[key]
if !ok {
lim, _ = NewLimiter(r.capacity, r.rate)
lim = NewLimiter(r.capacity, r.rate)
r.limiters[key] = lim
}

Expand Down
48 changes: 37 additions & 11 deletions token/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/serroba/rate/token"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -21,16 +22,6 @@ func TestNewRegistry_WithUsers(t *testing.T) {
require.NotNil(t, reg)
}

func TestNewRegistry_InvalidCapacity(t *testing.T) {
_, err := token.NewRegistry(-1, 2, "alice")
require.Error(t, err)
}

func TestNewRegistry_InvalidRate(t *testing.T) {
_, err := token.NewRegistry(10, -1, "alice")
require.Error(t, err)
}

func TestRegistry_Allow_ExistingUser(t *testing.T) {
reg, err := token.NewRegistry(2, 0, "alice")
require.NoError(t, err)
Expand Down Expand Up @@ -94,6 +85,41 @@ func TestRegistry_Allow_Concurrent(t *testing.T) {
require.Equal(t, int64(200), allowed.Load())
}

func TestRegistry_Deny_Concurrent(t *testing.T) {
reg, err := token.NewRegistry(100, 0)
require.NoError(t, err)

var (
allowed atomic.Int64
deny atomic.Int64
wg sync.WaitGroup
)

// 50 goroutines per user, 4 users = 200 goroutines
Copy link

Copilot AI Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states "50 goroutines per user, 4 users = 200 goroutines" but the code actually creates 110 goroutines per user (4 users × 110 = 440 goroutines). Update the comment to accurately reflect the actual number of goroutines being created.

Copilot uses AI. Check for mistakes.
users := []token.Identifier{"alice", "bob", "charlie", "diana"}
for _, user := range users {
for range 110 {
wg.Add(1)

go func(u token.Identifier) {
defer wg.Done()

if reg.Allow(u) {
allowed.Add(1)
} else {
deny.Add(1)
}
}(user)
}
}

wg.Wait()

// Each user has capacity 100, only 50 requests each, so all should be allowed
Copy link

Copilot AI Dec 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states "only 50 requests each" but the code actually makes 110 requests per user. Update the comment to match the actual test implementation (110 requests per user with capacity 100, resulting in 400 allowed and 40 denied).

Suggested change
// Each user has capacity 100, only 50 requests each, so all should be allowed
// Each user has capacity 100 and makes 110 requests; in total 400 allowed and 40 denied

Copilot uses AI. Check for mistakes.
assert.Equal(t, int64(400), allowed.Load())
assert.Equal(t, int64(40), deny.Load())
}

func TestRegistry_Allow_ConcurrentNewUsers(t *testing.T) {
reg, err := token.NewRegistry(5, 0)
require.NoError(t, err)
Expand All @@ -107,7 +133,7 @@ func TestRegistry_Allow_ConcurrentNewUsers(t *testing.T) {
go func(id int) {
defer wg.Done()

user := token.Identifier(string(rune('a' + id%26)))
user := token.Identifier(rune('a' + id%26))
reg.Allow(user)
}(i)
}
Expand Down