From 61b26a88a4fb72bc12681e62eb5f4b91a4d6621b Mon Sep 17 00:00:00 2001 From: Joaquim d'Souza Date: Sun, 22 Mar 2026 15:58:42 +0100 Subject: [PATCH 1/4] feat: add forgot password rate limiting --- src/app/api/login/route.ts | 11 ++- src/server/trpc/index.ts | 6 +- src/server/trpc/routers/auth.ts | 11 ++- src/server/utils/ratelimit.ts | 34 ++++++++-- tests/unit/app/api/login/route.test.ts | 7 +- tests/unit/server/trpc/routers/auth.test.ts | 74 +++++++++++++++++++++ 6 files changed, 127 insertions(+), 16 deletions(-) create mode 100644 tests/unit/server/trpc/routers/auth.test.ts diff --git a/src/app/api/login/route.ts b/src/app/api/login/route.ts index a0e627d1..c6f7c16d 100644 --- a/src/app/api/login/route.ts +++ b/src/app/api/login/route.ts @@ -3,7 +3,10 @@ import z from "zod"; import { setJWT } from "@/auth/jwt"; import { findUserByEmailAndPassword } from "@/server/repositories/User"; import logger from "@/server/services/logger"; -import { checkLoginRateLimit } from "@/server/utils/ratelimit"; +import { + checkLoginRateLimit, + getClientIp, +} from "@/server/utils/ratelimit"; const loginSchema = z.object({ email: z.string().email(), @@ -11,11 +14,7 @@ const loginSchema = z.object({ }); export async function POST(request: NextRequest) { - const forwarded = request.headers.get("x-forwarded-for"); - const ip = - (forwarded ? forwarded.split(",")[0].trim() : null) ?? - request.headers.get("x-real-ip") ?? - "unknown"; + const ip = getClientIp(request); logger.info(`Login request from ${ip}`); const allowed = await checkLoginRateLimit(ip); diff --git a/src/server/trpc/index.ts b/src/server/trpc/index.ts index 1874113d..99803c45 100644 --- a/src/server/trpc/index.ts +++ b/src/server/trpc/index.ts @@ -4,6 +4,7 @@ import z, { ZodError } from "zod"; import { getServerSession } from "@/auth"; import { ADMIN_USER_EMAIL } from "@/constants"; import { canReadDataSource } from "@/server/utils/auth"; +import { getClientIp } from "@/server/utils/ratelimit"; import { hasPasswordHashSerializer, serverDataSourceSerializer, @@ -14,13 +15,14 @@ import { findOrganisationForUser } from "../repositories/Organisation"; import { findPublishedPublicMapByMapId } from "../repositories/PublicMap"; import { findUserById } from "../repositories/User"; -export async function createContext() { +export async function createContext(opts?: { req?: Request }) { const session = await getServerSession(); let user = null; if (session.currentUser) { user = await findUserById(session.currentUser.id); } - return { user }; + const ip = opts?.req ? getClientIp(opts.req) : "unknown"; + return { user, ip }; } export type Context = Awaited>; diff --git a/src/server/trpc/routers/auth.ts b/src/server/trpc/routers/auth.ts index 4c09db62..daf31279 100644 --- a/src/server/trpc/routers/auth.ts +++ b/src/server/trpc/routers/auth.ts @@ -20,6 +20,7 @@ import { } from "@/server/repositories/User"; import logger from "@/server/services/logger"; import { sendEmail } from "@/server/services/mailer"; +import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; import { publicProcedure, router } from "../index"; export const authRouter = router({ @@ -82,7 +83,15 @@ export const authRouter = router({ }), forgotPassword: publicProcedure .input(z.object({ email: z.string() })) - .mutation(async ({ input }) => { + .mutation(async ({ ctx, input }) => { + const allowed = await checkForgotPasswordRateLimit(ctx.ip); + if (!allowed) { + throw new TRPCError({ + code: "TOO_MANY_REQUESTS", + message: "Too many requests, please try again later", + }); + } + const { email } = input; const user = await findUserByEmail(email); if (!user) return true; diff --git a/src/server/utils/ratelimit.ts b/src/server/utils/ratelimit.ts index 1cb9302e..2a597d66 100644 --- a/src/server/utils/ratelimit.ts +++ b/src/server/utils/ratelimit.ts @@ -1,16 +1,42 @@ import { getClient } from "@/server/services/redis"; +import type { NextRequest } from "next/server"; + +export function getClientIp(req: Request | NextRequest): string { + const forwarded = req.headers.get("x-forwarded-for"); + return ( + (forwarded ? forwarded.split(",")[0].trim() : null) ?? + req.headers.get("x-real-ip") ?? + "unknown" + ); +} const WINDOW_SECONDS = 15 * 60; // 15 minutes -const MAX_ATTEMPTS = 5; +const LOGIN_MAX_ATTEMPTS = 5; +const FORGOT_PASSWORD_MAX_ATTEMPTS = 5; -export async function checkLoginRateLimit(ip: string): Promise { +async function checkRateLimit( + key: string, + maxAttempts: number, +): Promise { const redis = getClient(); - const key = `rate_limit:login:${ip}`; const results = await redis .multi() .incr(key) .expire(key, WINDOW_SECONDS) .exec(); const count = results && results[0] ? (results[0][1] as number) : 0; - return count <= MAX_ATTEMPTS; + return count <= maxAttempts; +} + +export async function checkLoginRateLimit(ip: string): Promise { + return checkRateLimit(`rate_limit:login:${ip}`, LOGIN_MAX_ATTEMPTS); +} + +export async function checkForgotPasswordRateLimit( + ip: string, +): Promise { + return checkRateLimit( + `rate_limit:forgot_password:${ip}`, + FORGOT_PASSWORD_MAX_ATTEMPTS, + ); } diff --git a/tests/unit/app/api/login/route.test.ts b/tests/unit/app/api/login/route.test.ts index a3bd44b0..101697d6 100644 --- a/tests/unit/app/api/login/route.test.ts +++ b/tests/unit/app/api/login/route.test.ts @@ -4,9 +4,10 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; vi.mock("@/server/repositories/User", () => ({ findUserByEmailAndPassword: vi.fn(), })); -vi.mock("@/server/utils/ratelimit", () => ({ - checkLoginRateLimit: vi.fn(), -})); +vi.mock("@/server/utils/ratelimit", async (importOriginal) => { + const actual = await importOriginal(); + return { ...(actual as object), checkLoginRateLimit: vi.fn() }; +}); vi.mock("@/auth/jwt", () => ({ setJWT: vi.fn(), })); diff --git a/tests/unit/server/trpc/routers/auth.test.ts b/tests/unit/server/trpc/routers/auth.test.ts new file mode 100644 index 00000000..380ce46e --- /dev/null +++ b/tests/unit/server/trpc/routers/auth.test.ts @@ -0,0 +1,74 @@ +import { TRPCError } from "@trpc/server"; +import { beforeEach, describe, expect, test, vi } from "vitest"; + +vi.mock("@/server/utils/ratelimit", () => ({ + checkForgotPasswordRateLimit: vi.fn(), +})); +vi.mock("@/server/repositories/User", () => ({ + findUserByEmail: vi.fn(), +})); +vi.mock("@/server/services/mailer", () => ({ + sendEmail: vi.fn(), +})); +vi.mock("@/server/services/logger", () => ({ + default: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; +import { findUserByEmail } from "@/server/repositories/User"; +import { authRouter } from "@/server/trpc/routers/auth"; + +const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit); +const mockFindUserByEmail = vi.mocked(findUserByEmail); + +function makeCaller(ip = "1.2.3.4") { + return authRouter.createCaller({ user: null, ip }); +} + +describe("auth.forgotPassword", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockCheckRateLimit.mockResolvedValue(true); + mockFindUserByEmail.mockResolvedValue(null); + }); + + describe("rate limiting", () => { + test("allows request when rate limit is not exceeded", async () => { + mockCheckRateLimit.mockResolvedValue(true); + const caller = makeCaller("1.2.3.4"); + await expect( + caller.forgotPassword({ email: "user@example.com" }), + ).resolves.toBe(true); + expect(mockCheckRateLimit).toHaveBeenCalledWith("1.2.3.4"); + }); + + test("throws TOO_MANY_REQUESTS when rate limit is exceeded", async () => { + mockCheckRateLimit.mockResolvedValue(false); + const caller = makeCaller("1.2.3.4"); + await expect( + caller.forgotPassword({ email: "user@example.com" }), + ).rejects.toThrow(TRPCError); + try { + await caller.forgotPassword({ email: "user@example.com" }); + } catch (err) { + expect(err).toBeInstanceOf(TRPCError); + expect((err as TRPCError).code).toBe("TOO_MANY_REQUESTS"); + } + }); + + test("passes the caller IP to the rate limiter", async () => { + const caller = makeCaller("9.8.7.6"); + await caller.forgotPassword({ email: "user@example.com" }); + expect(mockCheckRateLimit).toHaveBeenCalledWith("9.8.7.6"); + }); + + test("does not look up user when rate limit is exceeded", async () => { + mockCheckRateLimit.mockResolvedValue(false); + const caller = makeCaller(); + await expect( + caller.forgotPassword({ email: "user@example.com" }), + ).rejects.toThrow(); + expect(mockFindUserByEmail).not.toHaveBeenCalled(); + }); + }); +}); From 1cb733ab606bb7c066982fb63b0286abf309a279 Mon Sep 17 00:00:00 2001 From: joaquimds Date: Sun, 22 Mar 2026 16:08:58 +0100 Subject: [PATCH 2/4] Update tests/unit/server/trpc/routers/auth.test.ts Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/unit/server/trpc/routers/auth.test.ts | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/unit/server/trpc/routers/auth.test.ts b/tests/unit/server/trpc/routers/auth.test.ts index 380ce46e..eb1bb5b4 100644 --- a/tests/unit/server/trpc/routers/auth.test.ts +++ b/tests/unit/server/trpc/routers/auth.test.ts @@ -47,13 +47,7 @@ describe("auth.forgotPassword", () => { const caller = makeCaller("1.2.3.4"); await expect( caller.forgotPassword({ email: "user@example.com" }), - ).rejects.toThrow(TRPCError); - try { - await caller.forgotPassword({ email: "user@example.com" }); - } catch (err) { - expect(err).toBeInstanceOf(TRPCError); - expect((err as TRPCError).code).toBe("TOO_MANY_REQUESTS"); - } + ).rejects.toMatchObject({ code: "TOO_MANY_REQUESTS" }); }); test("passes the caller IP to the rate limiter", async () => { From 9333cea888535fc370a6c2dd7f160e379e468c77 Mon Sep 17 00:00:00 2001 From: joaquimds Date: Sun, 22 Mar 2026 16:09:14 +0100 Subject: [PATCH 3/4] Update src/server/utils/ratelimit.ts Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/server/utils/ratelimit.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/server/utils/ratelimit.ts b/src/server/utils/ratelimit.ts index 2a597d66..6eb89e58 100644 --- a/src/server/utils/ratelimit.ts +++ b/src/server/utils/ratelimit.ts @@ -1,7 +1,6 @@ import { getClient } from "@/server/services/redis"; -import type { NextRequest } from "next/server"; -export function getClientIp(req: Request | NextRequest): string { +export function getClientIp(req: Request): string { const forwarded = req.headers.get("x-forwarded-for"); return ( (forwarded ? forwarded.split(",")[0].trim() : null) ?? From 3b17d3ee76758b3a5b07fefbfbf0fb4d5550cb2c Mon Sep 17 00:00:00 2001 From: Joaquim d'Souza Date: Sun, 22 Mar 2026 16:11:17 +0100 Subject: [PATCH 4/4] fix: lint --- src/app/api/login/route.ts | 5 +---- tests/unit/server/trpc/routers/auth.test.ts | 5 ++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/app/api/login/route.ts b/src/app/api/login/route.ts index c6f7c16d..81f97eb6 100644 --- a/src/app/api/login/route.ts +++ b/src/app/api/login/route.ts @@ -3,10 +3,7 @@ import z from "zod"; import { setJWT } from "@/auth/jwt"; import { findUserByEmailAndPassword } from "@/server/repositories/User"; import logger from "@/server/services/logger"; -import { - checkLoginRateLimit, - getClientIp, -} from "@/server/utils/ratelimit"; +import { checkLoginRateLimit, getClientIp } from "@/server/utils/ratelimit"; const loginSchema = z.object({ email: z.string().email(), diff --git a/tests/unit/server/trpc/routers/auth.test.ts b/tests/unit/server/trpc/routers/auth.test.ts index eb1bb5b4..21a18a0c 100644 --- a/tests/unit/server/trpc/routers/auth.test.ts +++ b/tests/unit/server/trpc/routers/auth.test.ts @@ -1,4 +1,3 @@ -import { TRPCError } from "@trpc/server"; import { beforeEach, describe, expect, test, vi } from "vitest"; vi.mock("@/server/utils/ratelimit", () => ({ @@ -14,9 +13,9 @@ vi.mock("@/server/services/logger", () => ({ default: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, })); -import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; import { findUserByEmail } from "@/server/repositories/User"; import { authRouter } from "@/server/trpc/routers/auth"; +import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit); const mockFindUserByEmail = vi.mocked(findUserByEmail); @@ -29,7 +28,7 @@ describe("auth.forgotPassword", () => { beforeEach(() => { vi.clearAllMocks(); mockCheckRateLimit.mockResolvedValue(true); - mockFindUserByEmail.mockResolvedValue(null); + mockFindUserByEmail.mockResolvedValue(undefined); }); describe("rate limiting", () => {