Skip to content

Commit 19101a2

Browse files
authored
Merge pull request #385 from commonknowledge/feat/protect-forgot-password
feat: add forgot password rate limiting
2 parents 36a2fda + 3b17d3e commit 19101a2

6 files changed

Lines changed: 116 additions & 16 deletions

File tree

src/app/api/login/route.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,15 @@ import z from "zod";
33
import { setJWT } from "@/auth/jwt";
44
import { findUserByEmailAndPassword } from "@/server/repositories/User";
55
import logger from "@/server/services/logger";
6-
import { checkLoginRateLimit } from "@/server/utils/ratelimit";
6+
import { checkLoginRateLimit, getClientIp } from "@/server/utils/ratelimit";
77

88
const loginSchema = z.object({
99
email: z.string().email(),
1010
password: z.string().min(1, "Password is required"),
1111
});
1212

1313
export async function POST(request: NextRequest) {
14-
const forwarded = request.headers.get("x-forwarded-for");
15-
const ip =
16-
(forwarded ? forwarded.split(",")[0].trim() : null) ??
17-
request.headers.get("x-real-ip") ??
18-
"unknown";
14+
const ip = getClientIp(request);
1915
logger.info(`Login request from ${ip}`);
2016

2117
const allowed = await checkLoginRateLimit(ip);

src/server/trpc/index.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import z, { ZodError } from "zod";
44
import { getServerSession } from "@/auth";
55
import { ADMIN_USER_EMAIL } from "@/constants";
66
import { canReadDataSource } from "@/server/utils/auth";
7+
import { getClientIp } from "@/server/utils/ratelimit";
78
import {
89
hasPasswordHashSerializer,
910
serverDataSourceSerializer,
@@ -14,13 +15,14 @@ import { findOrganisationForUser } from "../repositories/Organisation";
1415
import { findPublishedPublicMapByMapId } from "../repositories/PublicMap";
1516
import { findUserById } from "../repositories/User";
1617

17-
export async function createContext() {
18+
export async function createContext(opts?: { req?: Request }) {
1819
const session = await getServerSession();
1920
let user = null;
2021
if (session.currentUser) {
2122
user = await findUserById(session.currentUser.id);
2223
}
23-
return { user };
24+
const ip = opts?.req ? getClientIp(opts.req) : "unknown";
25+
return { user, ip };
2426
}
2527

2628
export type Context = Awaited<ReturnType<typeof createContext>>;

src/server/trpc/routers/auth.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
} from "@/server/repositories/User";
2121
import logger from "@/server/services/logger";
2222
import { sendEmail } from "@/server/services/mailer";
23+
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";
2324
import { publicProcedure, router } from "../index";
2425

2526
export const authRouter = router({
@@ -82,7 +83,15 @@ export const authRouter = router({
8283
}),
8384
forgotPassword: publicProcedure
8485
.input(z.object({ email: z.string() }))
85-
.mutation(async ({ input }) => {
86+
.mutation(async ({ ctx, input }) => {
87+
const allowed = await checkForgotPasswordRateLimit(ctx.ip);
88+
if (!allowed) {
89+
throw new TRPCError({
90+
code: "TOO_MANY_REQUESTS",
91+
message: "Too many requests, please try again later",
92+
});
93+
}
94+
8695
const { email } = input;
8796
const user = await findUserByEmail(email);
8897
if (!user) return true;

src/server/utils/ratelimit.ts

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,41 @@
11
import { getClient } from "@/server/services/redis";
22

3+
export function getClientIp(req: Request): string {
4+
const forwarded = req.headers.get("x-forwarded-for");
5+
return (
6+
(forwarded ? forwarded.split(",")[0].trim() : null) ??
7+
req.headers.get("x-real-ip") ??
8+
"unknown"
9+
);
10+
}
11+
312
const WINDOW_SECONDS = 15 * 60; // 15 minutes
4-
const MAX_ATTEMPTS = 5;
13+
const LOGIN_MAX_ATTEMPTS = 5;
14+
const FORGOT_PASSWORD_MAX_ATTEMPTS = 5;
515

6-
export async function checkLoginRateLimit(ip: string): Promise<boolean> {
16+
async function checkRateLimit(
17+
key: string,
18+
maxAttempts: number,
19+
): Promise<boolean> {
720
const redis = getClient();
8-
const key = `rate_limit:login:${ip}`;
921
const results = await redis
1022
.multi()
1123
.incr(key)
1224
.expire(key, WINDOW_SECONDS)
1325
.exec();
1426
const count = results && results[0] ? (results[0][1] as number) : 0;
15-
return count <= MAX_ATTEMPTS;
27+
return count <= maxAttempts;
28+
}
29+
30+
export async function checkLoginRateLimit(ip: string): Promise<boolean> {
31+
return checkRateLimit(`rate_limit:login:${ip}`, LOGIN_MAX_ATTEMPTS);
32+
}
33+
34+
export async function checkForgotPasswordRateLimit(
35+
ip: string,
36+
): Promise<boolean> {
37+
return checkRateLimit(
38+
`rate_limit:forgot_password:${ip}`,
39+
FORGOT_PASSWORD_MAX_ATTEMPTS,
40+
);
1641
}

tests/unit/app/api/login/route.test.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ import { beforeEach, describe, expect, test, vi } from "vitest";
44
vi.mock("@/server/repositories/User", () => ({
55
findUserByEmailAndPassword: vi.fn(),
66
}));
7-
vi.mock("@/server/utils/ratelimit", () => ({
8-
checkLoginRateLimit: vi.fn(),
9-
}));
7+
vi.mock("@/server/utils/ratelimit", async (importOriginal) => {
8+
const actual = await importOriginal();
9+
return { ...(actual as object), checkLoginRateLimit: vi.fn() };
10+
});
1011
vi.mock("@/auth/jwt", () => ({
1112
setJWT: vi.fn(),
1213
}));
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import { beforeEach, describe, expect, test, vi } from "vitest";
2+
3+
vi.mock("@/server/utils/ratelimit", () => ({
4+
checkForgotPasswordRateLimit: vi.fn(),
5+
}));
6+
vi.mock("@/server/repositories/User", () => ({
7+
findUserByEmail: vi.fn(),
8+
}));
9+
vi.mock("@/server/services/mailer", () => ({
10+
sendEmail: vi.fn(),
11+
}));
12+
vi.mock("@/server/services/logger", () => ({
13+
default: { info: vi.fn(), warn: vi.fn(), error: vi.fn() },
14+
}));
15+
16+
import { findUserByEmail } from "@/server/repositories/User";
17+
import { authRouter } from "@/server/trpc/routers/auth";
18+
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";
19+
20+
const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit);
21+
const mockFindUserByEmail = vi.mocked(findUserByEmail);
22+
23+
function makeCaller(ip = "1.2.3.4") {
24+
return authRouter.createCaller({ user: null, ip });
25+
}
26+
27+
describe("auth.forgotPassword", () => {
28+
beforeEach(() => {
29+
vi.clearAllMocks();
30+
mockCheckRateLimit.mockResolvedValue(true);
31+
mockFindUserByEmail.mockResolvedValue(undefined);
32+
});
33+
34+
describe("rate limiting", () => {
35+
test("allows request when rate limit is not exceeded", async () => {
36+
mockCheckRateLimit.mockResolvedValue(true);
37+
const caller = makeCaller("1.2.3.4");
38+
await expect(
39+
caller.forgotPassword({ email: "user@example.com" }),
40+
).resolves.toBe(true);
41+
expect(mockCheckRateLimit).toHaveBeenCalledWith("1.2.3.4");
42+
});
43+
44+
test("throws TOO_MANY_REQUESTS when rate limit is exceeded", async () => {
45+
mockCheckRateLimit.mockResolvedValue(false);
46+
const caller = makeCaller("1.2.3.4");
47+
await expect(
48+
caller.forgotPassword({ email: "user@example.com" }),
49+
).rejects.toMatchObject({ code: "TOO_MANY_REQUESTS" });
50+
});
51+
52+
test("passes the caller IP to the rate limiter", async () => {
53+
const caller = makeCaller("9.8.7.6");
54+
await caller.forgotPassword({ email: "user@example.com" });
55+
expect(mockCheckRateLimit).toHaveBeenCalledWith("9.8.7.6");
56+
});
57+
58+
test("does not look up user when rate limit is exceeded", async () => {
59+
mockCheckRateLimit.mockResolvedValue(false);
60+
const caller = makeCaller();
61+
await expect(
62+
caller.forgotPassword({ email: "user@example.com" }),
63+
).rejects.toThrow();
64+
expect(mockFindUserByEmail).not.toHaveBeenCalled();
65+
});
66+
});
67+
});

0 commit comments

Comments
 (0)