diff --git a/src/app/api/login/route.ts b/src/app/api/login/route.ts index a0e627d1..81f97eb6 100644 --- a/src/app/api/login/route.ts +++ b/src/app/api/login/route.ts @@ -3,7 +3,7 @@ import z from "zod"; import { setJWT } from "@/auth/jwt"; import { findUserByEmailAndPassword } from "@/server/repositories/User"; import logger from "@/server/services/logger"; -import { checkLoginRateLimit } from "@/server/utils/ratelimit"; +import { checkLoginRateLimit, getClientIp } from "@/server/utils/ratelimit"; const loginSchema = z.object({ email: z.string().email(), @@ -11,11 +11,7 @@ const loginSchema = z.object({ }); export async function POST(request: NextRequest) { - const forwarded = request.headers.get("x-forwarded-for"); - const ip = - (forwarded ? forwarded.split(",")[0].trim() : null) ?? - request.headers.get("x-real-ip") ?? - "unknown"; + const ip = getClientIp(request); logger.info(`Login request from ${ip}`); const allowed = await checkLoginRateLimit(ip); diff --git a/src/server/trpc/index.ts b/src/server/trpc/index.ts index 1874113d..99803c45 100644 --- a/src/server/trpc/index.ts +++ b/src/server/trpc/index.ts @@ -4,6 +4,7 @@ import z, { ZodError } from "zod"; import { getServerSession } from "@/auth"; import { ADMIN_USER_EMAIL } from "@/constants"; import { canReadDataSource } from "@/server/utils/auth"; +import { getClientIp } from "@/server/utils/ratelimit"; import { hasPasswordHashSerializer, serverDataSourceSerializer, @@ -14,13 +15,14 @@ import { findOrganisationForUser } from "../repositories/Organisation"; import { findPublishedPublicMapByMapId } from "../repositories/PublicMap"; import { findUserById } from "../repositories/User"; -export async function createContext() { +export async function createContext(opts?: { req?: Request }) { const session = await getServerSession(); let user = null; if (session.currentUser) { user = await findUserById(session.currentUser.id); } - return { user }; + const ip = opts?.req ? getClientIp(opts.req) : "unknown"; + return { user, ip }; } export type Context = Awaited>; diff --git a/src/server/trpc/routers/auth.ts b/src/server/trpc/routers/auth.ts index 4c09db62..daf31279 100644 --- a/src/server/trpc/routers/auth.ts +++ b/src/server/trpc/routers/auth.ts @@ -20,6 +20,7 @@ import { } from "@/server/repositories/User"; import logger from "@/server/services/logger"; import { sendEmail } from "@/server/services/mailer"; +import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; import { publicProcedure, router } from "../index"; export const authRouter = router({ @@ -82,7 +83,15 @@ export const authRouter = router({ }), forgotPassword: publicProcedure .input(z.object({ email: z.string() })) - .mutation(async ({ input }) => { + .mutation(async ({ ctx, input }) => { + const allowed = await checkForgotPasswordRateLimit(ctx.ip); + if (!allowed) { + throw new TRPCError({ + code: "TOO_MANY_REQUESTS", + message: "Too many requests, please try again later", + }); + } + const { email } = input; const user = await findUserByEmail(email); if (!user) return true; diff --git a/src/server/utils/ratelimit.ts b/src/server/utils/ratelimit.ts index 1cb9302e..6eb89e58 100644 --- a/src/server/utils/ratelimit.ts +++ b/src/server/utils/ratelimit.ts @@ -1,16 +1,41 @@ import { getClient } from "@/server/services/redis"; +export function getClientIp(req: Request): string { + const forwarded = req.headers.get("x-forwarded-for"); + return ( + (forwarded ? forwarded.split(",")[0].trim() : null) ?? + req.headers.get("x-real-ip") ?? + "unknown" + ); +} + const WINDOW_SECONDS = 15 * 60; // 15 minutes -const MAX_ATTEMPTS = 5; +const LOGIN_MAX_ATTEMPTS = 5; +const FORGOT_PASSWORD_MAX_ATTEMPTS = 5; -export async function checkLoginRateLimit(ip: string): Promise { +async function checkRateLimit( + key: string, + maxAttempts: number, +): Promise { const redis = getClient(); - const key = `rate_limit:login:${ip}`; const results = await redis .multi() .incr(key) .expire(key, WINDOW_SECONDS) .exec(); const count = results && results[0] ? (results[0][1] as number) : 0; - return count <= MAX_ATTEMPTS; + return count <= maxAttempts; +} + +export async function checkLoginRateLimit(ip: string): Promise { + return checkRateLimit(`rate_limit:login:${ip}`, LOGIN_MAX_ATTEMPTS); +} + +export async function checkForgotPasswordRateLimit( + ip: string, +): Promise { + return checkRateLimit( + `rate_limit:forgot_password:${ip}`, + FORGOT_PASSWORD_MAX_ATTEMPTS, + ); } diff --git a/tests/unit/app/api/login/route.test.ts b/tests/unit/app/api/login/route.test.ts index a3bd44b0..101697d6 100644 --- a/tests/unit/app/api/login/route.test.ts +++ b/tests/unit/app/api/login/route.test.ts @@ -4,9 +4,10 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; vi.mock("@/server/repositories/User", () => ({ findUserByEmailAndPassword: vi.fn(), })); -vi.mock("@/server/utils/ratelimit", () => ({ - checkLoginRateLimit: vi.fn(), -})); +vi.mock("@/server/utils/ratelimit", async (importOriginal) => { + const actual = await importOriginal(); + return { ...(actual as object), checkLoginRateLimit: vi.fn() }; +}); vi.mock("@/auth/jwt", () => ({ setJWT: vi.fn(), })); diff --git a/tests/unit/server/trpc/routers/auth.test.ts b/tests/unit/server/trpc/routers/auth.test.ts new file mode 100644 index 00000000..21a18a0c --- /dev/null +++ b/tests/unit/server/trpc/routers/auth.test.ts @@ -0,0 +1,67 @@ +import { beforeEach, describe, expect, test, vi } from "vitest"; + +vi.mock("@/server/utils/ratelimit", () => ({ + checkForgotPasswordRateLimit: vi.fn(), +})); +vi.mock("@/server/repositories/User", () => ({ + findUserByEmail: vi.fn(), +})); +vi.mock("@/server/services/mailer", () => ({ + sendEmail: vi.fn(), +})); +vi.mock("@/server/services/logger", () => ({ + default: { info: vi.fn(), warn: vi.fn(), error: vi.fn() }, +})); + +import { findUserByEmail } from "@/server/repositories/User"; +import { authRouter } from "@/server/trpc/routers/auth"; +import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; + +const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit); +const mockFindUserByEmail = vi.mocked(findUserByEmail); + +function makeCaller(ip = "1.2.3.4") { + return authRouter.createCaller({ user: null, ip }); +} + +describe("auth.forgotPassword", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockCheckRateLimit.mockResolvedValue(true); + mockFindUserByEmail.mockResolvedValue(undefined); + }); + + describe("rate limiting", () => { + test("allows request when rate limit is not exceeded", async () => { + mockCheckRateLimit.mockResolvedValue(true); + const caller = makeCaller("1.2.3.4"); + await expect( + caller.forgotPassword({ email: "user@example.com" }), + ).resolves.toBe(true); + expect(mockCheckRateLimit).toHaveBeenCalledWith("1.2.3.4"); + }); + + test("throws TOO_MANY_REQUESTS when rate limit is exceeded", async () => { + mockCheckRateLimit.mockResolvedValue(false); + const caller = makeCaller("1.2.3.4"); + await expect( + caller.forgotPassword({ email: "user@example.com" }), + ).rejects.toMatchObject({ code: "TOO_MANY_REQUESTS" }); + }); + + test("passes the caller IP to the rate limiter", async () => { + const caller = makeCaller("9.8.7.6"); + await caller.forgotPassword({ email: "user@example.com" }); + expect(mockCheckRateLimit).toHaveBeenCalledWith("9.8.7.6"); + }); + + test("does not look up user when rate limit is exceeded", async () => { + mockCheckRateLimit.mockResolvedValue(false); + const caller = makeCaller(); + await expect( + caller.forgotPassword({ email: "user@example.com" }), + ).rejects.toThrow(); + expect(mockFindUserByEmail).not.toHaveBeenCalled(); + }); + }); +});