Skip to content

Commit 6ee24df

Browse files
authored
Merge pull request #386 from commonknowledge/fix/login-rate-limit-only-failed
fix: only increment login rate limit on failed attempts
2 parents 19101a2 + 2eb9893 commit 6ee24df

5 files changed

Lines changed: 114 additions & 39 deletions

File tree

src/app/api/login/route.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ import z from "zod";
33
import { setJWT } from "@/auth/jwt";
44
import { findUserByEmailAndPassword } from "@/server/repositories/User";
55
import logger from "@/server/services/logger";
6-
import { checkLoginRateLimit, getClientIp } from "@/server/utils/ratelimit";
6+
import {
7+
checkLoginAttempt,
8+
getClientIp,
9+
rollbackLoginAttempt,
10+
} from "@/server/utils/ratelimit";
711

812
const loginSchema = z.object({
913
email: z.string().email(),
@@ -14,7 +18,7 @@ export async function POST(request: NextRequest) {
1418
const ip = getClientIp(request);
1519
logger.info(`Login request from ${ip}`);
1620

17-
const allowed = await checkLoginRateLimit(ip);
21+
const allowed = await checkLoginAttempt(ip);
1822
if (!allowed) {
1923
return NextResponse.json(
2024
{ error: "Too many login attempts, please try again later" },
@@ -27,6 +31,7 @@ export async function POST(request: NextRequest) {
2731
const result = loginSchema.safeParse(body);
2832

2933
if (result.error) {
34+
await rollbackLoginAttempt(ip);
3035
return NextResponse.json(
3136
{ error: "Invalid credentials" },
3237
{ status: 400 },
@@ -41,10 +46,12 @@ export async function POST(request: NextRequest) {
4146
);
4247
}
4348

49+
await rollbackLoginAttempt(ip);
4450
await setJWT(user.id, user.email);
4551
return NextResponse.json({ success: true });
4652
} catch (error) {
4753
logger.warn(`Failed to log in user`, { error });
54+
await rollbackLoginAttempt(ip);
4855
return NextResponse.json({ error: "Failed to log in" }, { status: 500 });
4956
}
5057
}

src/server/trpc/routers/auth.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {
2020
} from "@/server/repositories/User";
2121
import logger from "@/server/services/logger";
2222
import { sendEmail } from "@/server/services/mailer";
23-
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";
23+
import { checkForgotPasswordAttempt } from "@/server/utils/ratelimit";
2424
import { publicProcedure, router } from "../index";
2525

2626
export const authRouter = router({
@@ -84,7 +84,7 @@ export const authRouter = router({
8484
forgotPassword: publicProcedure
8585
.input(z.object({ email: z.string() }))
8686
.mutation(async ({ ctx, input }) => {
87-
const allowed = await checkForgotPasswordRateLimit(ctx.ip);
87+
const allowed = await checkForgotPasswordAttempt(ctx.ip);
8888
if (!allowed) {
8989
throw new TRPCError({
9090
code: "TOO_MANY_REQUESTS",

src/server/utils/ratelimit.ts

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,42 @@ const WINDOW_SECONDS = 15 * 60; // 15 minutes
1313
const LOGIN_MAX_ATTEMPTS = 5;
1414
const FORGOT_PASSWORD_MAX_ATTEMPTS = 5;
1515

16-
async function checkRateLimit(
17-
key: string,
18-
maxAttempts: number,
19-
): Promise<boolean> {
16+
async function recordAttempt(key: string): Promise<number> {
2017
const redis = getClient();
2118
const results = await redis
2219
.multi()
2320
.incr(key)
2421
.expire(key, WINDOW_SECONDS)
2522
.exec();
26-
const count = results && results[0] ? (results[0][1] as number) : 0;
27-
return count <= maxAttempts;
23+
return results && results[0] ? (results[0][1] as number) : 0;
2824
}
2925

30-
export async function checkLoginRateLimit(ip: string): Promise<boolean> {
31-
return checkRateLimit(`rate_limit:login:${ip}`, LOGIN_MAX_ATTEMPTS);
26+
export async function checkLoginAttempt(ip: string): Promise<boolean> {
27+
const count = await recordAttempt(`rate_limit:login:${ip}`);
28+
return count <= LOGIN_MAX_ATTEMPTS;
3229
}
3330

34-
export async function checkForgotPasswordRateLimit(
35-
ip: string,
36-
): Promise<boolean> {
37-
return checkRateLimit(
38-
`rate_limit:forgot_password:${ip}`,
39-
FORGOT_PASSWORD_MAX_ATTEMPTS,
40-
);
31+
export async function rollbackLoginAttempt(ip: string): Promise<void> {
32+
const redis = getClient();
33+
const key = `rate_limit:login:${ip}`;
34+
35+
const results = await redis
36+
.multi()
37+
.decr(key)
38+
.expire(key, WINDOW_SECONDS)
39+
.exec();
40+
41+
const newCount =
42+
results && results[0] && Array.isArray(results[0])
43+
? (results[0][1] as number)
44+
: 0;
45+
46+
if (newCount <= 0) {
47+
await redis.del(key);
48+
}
49+
}
50+
51+
export async function checkForgotPasswordAttempt(ip: string): Promise<boolean> {
52+
const count = await recordAttempt(`rate_limit:forgot_password:${ip}`);
53+
return count <= FORGOT_PASSWORD_MAX_ATTEMPTS;
4154
}

tests/unit/app/api/login/route.test.ts

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ vi.mock("@/server/repositories/User", () => ({
66
}));
77
vi.mock("@/server/utils/ratelimit", async (importOriginal) => {
88
const actual = await importOriginal();
9-
return { ...(actual as object), checkLoginRateLimit: vi.fn() };
9+
return {
10+
...(actual as object),
11+
checkLoginAttempt: vi.fn(),
12+
rollbackLoginAttempt: vi.fn(),
13+
};
1014
});
1115
vi.mock("@/auth/jwt", () => ({
1216
setJWT: vi.fn(),
@@ -17,9 +21,13 @@ vi.mock("@/server/services/logger", () => ({
1721

1822
import { POST } from "@/app/api/login/route";
1923
import { findUserByEmailAndPassword } from "@/server/repositories/User";
20-
import { checkLoginRateLimit } from "@/server/utils/ratelimit";
24+
import {
25+
checkLoginAttempt,
26+
rollbackLoginAttempt,
27+
} from "@/server/utils/ratelimit";
2128

22-
const mockCheckLoginRateLimit = vi.mocked(checkLoginRateLimit);
29+
const mockCheckLoginAttempt = vi.mocked(checkLoginAttempt);
30+
const mockRollbackLoginAttempt = vi.mocked(rollbackLoginAttempt);
2331
const mockFindUser = vi.mocked(findUserByEmailAndPassword);
2432

2533
function makeRequest(
@@ -36,24 +44,25 @@ function makeRequest(
3644
describe("POST /api/login", () => {
3745
beforeEach(() => {
3846
vi.clearAllMocks();
39-
mockCheckLoginRateLimit.mockResolvedValue(true);
47+
mockCheckLoginAttempt.mockResolvedValue(true);
48+
mockRollbackLoginAttempt.mockResolvedValue(undefined);
4049
mockFindUser.mockResolvedValue(null);
4150
});
4251

4352
describe("rate limiting", () => {
4453
test("allows request when rate limit is not exceeded", async () => {
45-
mockCheckLoginRateLimit.mockResolvedValue(true);
54+
mockCheckLoginAttempt.mockResolvedValue(true);
4655
const request = makeRequest(
4756
{ email: "user@example.com", password: "pass" },
4857
{ "x-forwarded-for": "1.2.3.4" },
4958
);
5059
const response = await POST(request);
5160
expect(response.status).not.toBe(429);
52-
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("1.2.3.4");
61+
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("1.2.3.4");
5362
});
5463

5564
test("returns 429 when rate limit is exceeded", async () => {
56-
mockCheckLoginRateLimit.mockResolvedValue(false);
65+
mockCheckLoginAttempt.mockResolvedValue(false);
5766
const request = makeRequest(
5867
{ email: "user@example.com", password: "pass" },
5968
{ "x-forwarded-for": "1.2.3.4" },
@@ -70,7 +79,7 @@ describe("POST /api/login", () => {
7079
{ "x-forwarded-for": "10.0.0.1, 10.0.0.2, 10.0.0.3" },
7180
);
7281
await POST(request);
73-
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("10.0.0.1");
82+
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("10.0.0.1");
7483
});
7584

7685
test("trims whitespace from the first IP in x-forwarded-for", async () => {
@@ -79,7 +88,7 @@ describe("POST /api/login", () => {
7988
{ "x-forwarded-for": " 192.168.1.1 , 10.0.0.1" },
8089
);
8190
await POST(request);
82-
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("192.168.1.1");
91+
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("192.168.1.1");
8392
});
8493

8594
test("falls back to x-real-ip when x-forwarded-for is absent", async () => {
@@ -88,7 +97,7 @@ describe("POST /api/login", () => {
8897
{ "x-real-ip": "5.6.7.8" },
8998
);
9099
await POST(request);
91-
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("5.6.7.8");
100+
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("5.6.7.8");
92101
});
93102

94103
test('uses "unknown" when no IP header is present', async () => {
@@ -97,7 +106,53 @@ describe("POST /api/login", () => {
97106
password: "pass",
98107
});
99108
await POST(request);
100-
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("unknown");
109+
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("unknown");
110+
});
111+
112+
test("rolls back the attempt on successful login", async () => {
113+
mockFindUser.mockResolvedValue({
114+
id: "1",
115+
email: "user@example.com",
116+
name: "User",
117+
createdAt: new Date(),
118+
passwordHash: "",
119+
avatarUrl: undefined,
120+
});
121+
const request = makeRequest(
122+
{ email: "user@example.com", password: "correctpassword" },
123+
{ "x-forwarded-for": "1.2.3.4" },
124+
);
125+
await POST(request);
126+
expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("1.2.3.4");
127+
});
128+
129+
test("does not roll back on invalid credentials", async () => {
130+
mockFindUser.mockResolvedValue(null);
131+
const request = makeRequest(
132+
{ email: "user@example.com", password: "wrongpassword" },
133+
{ "x-forwarded-for": "1.2.3.4" },
134+
);
135+
const response = await POST(request);
136+
expect(response.status).toBe(401);
137+
expect(mockRollbackLoginAttempt).not.toHaveBeenCalled();
138+
});
139+
140+
test("rolls back the attempt on invalid request body", async () => {
141+
const request = makeRequest({ email: "not-an-email", password: "" });
142+
const response = await POST(request);
143+
expect(response.status).toBe(400);
144+
expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("unknown");
145+
});
146+
147+
test("rolls back the attempt on unexpected server error", async () => {
148+
mockFindUser.mockRejectedValue(new Error("db failure"));
149+
const request = makeRequest(
150+
{ email: "user@example.com", password: "pass" },
151+
{ "x-forwarded-for": "1.2.3.4" },
152+
);
153+
const response = await POST(request);
154+
expect(response.status).toBe(500);
155+
expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("1.2.3.4");
101156
});
102157
});
103158

tests/unit/server/trpc/routers/auth.test.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { beforeEach, describe, expect, test, vi } from "vitest";
22

33
vi.mock("@/server/utils/ratelimit", () => ({
4-
checkForgotPasswordRateLimit: vi.fn(),
4+
checkForgotPasswordAttempt: vi.fn(),
55
}));
66
vi.mock("@/server/repositories/User", () => ({
77
findUserByEmail: vi.fn(),
@@ -15,9 +15,9 @@ vi.mock("@/server/services/logger", () => ({
1515

1616
import { findUserByEmail } from "@/server/repositories/User";
1717
import { authRouter } from "@/server/trpc/routers/auth";
18-
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";
18+
import { checkForgotPasswordAttempt } from "@/server/utils/ratelimit";
1919

20-
const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit);
20+
const mockCheckForgotPasswordAttempt = vi.mocked(checkForgotPasswordAttempt);
2121
const mockFindUserByEmail = vi.mocked(findUserByEmail);
2222

2323
function makeCaller(ip = "1.2.3.4") {
@@ -27,22 +27,22 @@ function makeCaller(ip = "1.2.3.4") {
2727
describe("auth.forgotPassword", () => {
2828
beforeEach(() => {
2929
vi.clearAllMocks();
30-
mockCheckRateLimit.mockResolvedValue(true);
30+
mockCheckForgotPasswordAttempt.mockResolvedValue(true);
3131
mockFindUserByEmail.mockResolvedValue(undefined);
3232
});
3333

3434
describe("rate limiting", () => {
3535
test("allows request when rate limit is not exceeded", async () => {
36-
mockCheckRateLimit.mockResolvedValue(true);
36+
mockCheckForgotPasswordAttempt.mockResolvedValue(true);
3737
const caller = makeCaller("1.2.3.4");
3838
await expect(
3939
caller.forgotPassword({ email: "user@example.com" }),
4040
).resolves.toBe(true);
41-
expect(mockCheckRateLimit).toHaveBeenCalledWith("1.2.3.4");
41+
expect(mockCheckForgotPasswordAttempt).toHaveBeenCalledWith("1.2.3.4");
4242
});
4343

4444
test("throws TOO_MANY_REQUESTS when rate limit is exceeded", async () => {
45-
mockCheckRateLimit.mockResolvedValue(false);
45+
mockCheckForgotPasswordAttempt.mockResolvedValue(false);
4646
const caller = makeCaller("1.2.3.4");
4747
await expect(
4848
caller.forgotPassword({ email: "user@example.com" }),
@@ -52,11 +52,11 @@ describe("auth.forgotPassword", () => {
5252
test("passes the caller IP to the rate limiter", async () => {
5353
const caller = makeCaller("9.8.7.6");
5454
await caller.forgotPassword({ email: "user@example.com" });
55-
expect(mockCheckRateLimit).toHaveBeenCalledWith("9.8.7.6");
55+
expect(mockCheckForgotPasswordAttempt).toHaveBeenCalledWith("9.8.7.6");
5656
});
5757

5858
test("does not look up user when rate limit is exceeded", async () => {
59-
mockCheckRateLimit.mockResolvedValue(false);
59+
mockCheckForgotPasswordAttempt.mockResolvedValue(false);
6060
const caller = makeCaller();
6161
await expect(
6262
caller.forgotPassword({ email: "user@example.com" }),

0 commit comments

Comments
 (0)