diff --git a/src/app/api/login/route.ts b/src/app/api/login/route.ts index 81f97eb6..3a6b9f7d 100644 --- a/src/app/api/login/route.ts +++ b/src/app/api/login/route.ts @@ -3,7 +3,11 @@ import z from "zod"; import { setJWT } from "@/auth/jwt"; import { findUserByEmailAndPassword } from "@/server/repositories/User"; import logger from "@/server/services/logger"; -import { checkLoginRateLimit, getClientIp } from "@/server/utils/ratelimit"; +import { + checkLoginAttempt, + getClientIp, + rollbackLoginAttempt, +} from "@/server/utils/ratelimit"; const loginSchema = z.object({ email: z.string().email(), @@ -14,7 +18,7 @@ export async function POST(request: NextRequest) { const ip = getClientIp(request); logger.info(`Login request from ${ip}`); - const allowed = await checkLoginRateLimit(ip); + const allowed = await checkLoginAttempt(ip); if (!allowed) { return NextResponse.json( { error: "Too many login attempts, please try again later" }, @@ -27,6 +31,7 @@ export async function POST(request: NextRequest) { const result = loginSchema.safeParse(body); if (result.error) { + await rollbackLoginAttempt(ip); return NextResponse.json( { error: "Invalid credentials" }, { status: 400 }, @@ -41,10 +46,12 @@ export async function POST(request: NextRequest) { ); } + await rollbackLoginAttempt(ip); await setJWT(user.id, user.email); return NextResponse.json({ success: true }); } catch (error) { logger.warn(`Failed to log in user`, { error }); + await rollbackLoginAttempt(ip); return NextResponse.json({ error: "Failed to log in" }, { status: 500 }); } } diff --git a/src/server/trpc/routers/auth.ts b/src/server/trpc/routers/auth.ts index daf31279..26d638ad 100644 --- a/src/server/trpc/routers/auth.ts +++ b/src/server/trpc/routers/auth.ts @@ -20,7 +20,7 @@ import { } from "@/server/repositories/User"; import logger from "@/server/services/logger"; import { sendEmail } from "@/server/services/mailer"; -import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; +import { checkForgotPasswordAttempt } from "@/server/utils/ratelimit"; import { publicProcedure, router } from "../index"; export const authRouter = router({ @@ -84,7 +84,7 @@ export const authRouter = router({ forgotPassword: publicProcedure .input(z.object({ email: z.string() })) .mutation(async ({ ctx, input }) => { - const allowed = await checkForgotPasswordRateLimit(ctx.ip); + const allowed = await checkForgotPasswordAttempt(ctx.ip); if (!allowed) { throw new TRPCError({ code: "TOO_MANY_REQUESTS", diff --git a/src/server/utils/ratelimit.ts b/src/server/utils/ratelimit.ts index 6eb89e58..728787ee 100644 --- a/src/server/utils/ratelimit.ts +++ b/src/server/utils/ratelimit.ts @@ -13,29 +13,42 @@ const WINDOW_SECONDS = 15 * 60; // 15 minutes const LOGIN_MAX_ATTEMPTS = 5; const FORGOT_PASSWORD_MAX_ATTEMPTS = 5; -async function checkRateLimit( - key: string, - maxAttempts: number, -): Promise { +async function recordAttempt(key: string): Promise { const redis = getClient(); const results = await redis .multi() .incr(key) .expire(key, WINDOW_SECONDS) .exec(); - const count = results && results[0] ? (results[0][1] as number) : 0; - return count <= maxAttempts; + return results && results[0] ? (results[0][1] as number) : 0; } -export async function checkLoginRateLimit(ip: string): Promise { - return checkRateLimit(`rate_limit:login:${ip}`, LOGIN_MAX_ATTEMPTS); +export async function checkLoginAttempt(ip: string): Promise { + const count = await recordAttempt(`rate_limit:login:${ip}`); + return count <= LOGIN_MAX_ATTEMPTS; } -export async function checkForgotPasswordRateLimit( - ip: string, -): Promise { - return checkRateLimit( - `rate_limit:forgot_password:${ip}`, - FORGOT_PASSWORD_MAX_ATTEMPTS, - ); +export async function rollbackLoginAttempt(ip: string): Promise { + const redis = getClient(); + const key = `rate_limit:login:${ip}`; + + const results = await redis + .multi() + .decr(key) + .expire(key, WINDOW_SECONDS) + .exec(); + + const newCount = + results && results[0] && Array.isArray(results[0]) + ? (results[0][1] as number) + : 0; + + if (newCount <= 0) { + await redis.del(key); + } +} + +export async function checkForgotPasswordAttempt(ip: string): Promise { + const count = await recordAttempt(`rate_limit:forgot_password:${ip}`); + return count <= FORGOT_PASSWORD_MAX_ATTEMPTS; } diff --git a/tests/unit/app/api/login/route.test.ts b/tests/unit/app/api/login/route.test.ts index 101697d6..8ba4b2d5 100644 --- a/tests/unit/app/api/login/route.test.ts +++ b/tests/unit/app/api/login/route.test.ts @@ -6,7 +6,11 @@ vi.mock("@/server/repositories/User", () => ({ })); vi.mock("@/server/utils/ratelimit", async (importOriginal) => { const actual = await importOriginal(); - return { ...(actual as object), checkLoginRateLimit: vi.fn() }; + return { + ...(actual as object), + checkLoginAttempt: vi.fn(), + rollbackLoginAttempt: vi.fn(), + }; }); vi.mock("@/auth/jwt", () => ({ setJWT: vi.fn(), @@ -17,9 +21,13 @@ vi.mock("@/server/services/logger", () => ({ import { POST } from "@/app/api/login/route"; import { findUserByEmailAndPassword } from "@/server/repositories/User"; -import { checkLoginRateLimit } from "@/server/utils/ratelimit"; +import { + checkLoginAttempt, + rollbackLoginAttempt, +} from "@/server/utils/ratelimit"; -const mockCheckLoginRateLimit = vi.mocked(checkLoginRateLimit); +const mockCheckLoginAttempt = vi.mocked(checkLoginAttempt); +const mockRollbackLoginAttempt = vi.mocked(rollbackLoginAttempt); const mockFindUser = vi.mocked(findUserByEmailAndPassword); function makeRequest( @@ -36,24 +44,25 @@ function makeRequest( describe("POST /api/login", () => { beforeEach(() => { vi.clearAllMocks(); - mockCheckLoginRateLimit.mockResolvedValue(true); + mockCheckLoginAttempt.mockResolvedValue(true); + mockRollbackLoginAttempt.mockResolvedValue(undefined); mockFindUser.mockResolvedValue(null); }); describe("rate limiting", () => { test("allows request when rate limit is not exceeded", async () => { - mockCheckLoginRateLimit.mockResolvedValue(true); + mockCheckLoginAttempt.mockResolvedValue(true); const request = makeRequest( { email: "user@example.com", password: "pass" }, { "x-forwarded-for": "1.2.3.4" }, ); const response = await POST(request); expect(response.status).not.toBe(429); - expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("1.2.3.4"); + expect(mockCheckLoginAttempt).toHaveBeenCalledWith("1.2.3.4"); }); test("returns 429 when rate limit is exceeded", async () => { - mockCheckLoginRateLimit.mockResolvedValue(false); + mockCheckLoginAttempt.mockResolvedValue(false); const request = makeRequest( { email: "user@example.com", password: "pass" }, { "x-forwarded-for": "1.2.3.4" }, @@ -70,7 +79,7 @@ describe("POST /api/login", () => { { "x-forwarded-for": "10.0.0.1, 10.0.0.2, 10.0.0.3" }, ); await POST(request); - expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("10.0.0.1"); + expect(mockCheckLoginAttempt).toHaveBeenCalledWith("10.0.0.1"); }); test("trims whitespace from the first IP in x-forwarded-for", async () => { @@ -79,7 +88,7 @@ describe("POST /api/login", () => { { "x-forwarded-for": " 192.168.1.1 , 10.0.0.1" }, ); await POST(request); - expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("192.168.1.1"); + expect(mockCheckLoginAttempt).toHaveBeenCalledWith("192.168.1.1"); }); test("falls back to x-real-ip when x-forwarded-for is absent", async () => { @@ -88,7 +97,7 @@ describe("POST /api/login", () => { { "x-real-ip": "5.6.7.8" }, ); await POST(request); - expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("5.6.7.8"); + expect(mockCheckLoginAttempt).toHaveBeenCalledWith("5.6.7.8"); }); test('uses "unknown" when no IP header is present', async () => { @@ -97,7 +106,53 @@ describe("POST /api/login", () => { password: "pass", }); await POST(request); - expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("unknown"); + expect(mockCheckLoginAttempt).toHaveBeenCalledWith("unknown"); + }); + + test("rolls back the attempt on successful login", async () => { + mockFindUser.mockResolvedValue({ + id: "1", + email: "user@example.com", + name: "User", + createdAt: new Date(), + passwordHash: "", + avatarUrl: undefined, + }); + const request = makeRequest( + { email: "user@example.com", password: "correctpassword" }, + { "x-forwarded-for": "1.2.3.4" }, + ); + await POST(request); + expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("1.2.3.4"); + }); + + test("does not roll back on invalid credentials", async () => { + mockFindUser.mockResolvedValue(null); + const request = makeRequest( + { email: "user@example.com", password: "wrongpassword" }, + { "x-forwarded-for": "1.2.3.4" }, + ); + const response = await POST(request); + expect(response.status).toBe(401); + expect(mockRollbackLoginAttempt).not.toHaveBeenCalled(); + }); + + test("rolls back the attempt on invalid request body", async () => { + const request = makeRequest({ email: "not-an-email", password: "" }); + const response = await POST(request); + expect(response.status).toBe(400); + expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("unknown"); + }); + + test("rolls back the attempt on unexpected server error", async () => { + mockFindUser.mockRejectedValue(new Error("db failure")); + const request = makeRequest( + { email: "user@example.com", password: "pass" }, + { "x-forwarded-for": "1.2.3.4" }, + ); + const response = await POST(request); + expect(response.status).toBe(500); + expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("1.2.3.4"); }); }); diff --git a/tests/unit/server/trpc/routers/auth.test.ts b/tests/unit/server/trpc/routers/auth.test.ts index 21a18a0c..32e26c58 100644 --- a/tests/unit/server/trpc/routers/auth.test.ts +++ b/tests/unit/server/trpc/routers/auth.test.ts @@ -1,7 +1,7 @@ import { beforeEach, describe, expect, test, vi } from "vitest"; vi.mock("@/server/utils/ratelimit", () => ({ - checkForgotPasswordRateLimit: vi.fn(), + checkForgotPasswordAttempt: vi.fn(), })); vi.mock("@/server/repositories/User", () => ({ findUserByEmail: vi.fn(), @@ -15,9 +15,9 @@ vi.mock("@/server/services/logger", () => ({ import { findUserByEmail } from "@/server/repositories/User"; import { authRouter } from "@/server/trpc/routers/auth"; -import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit"; +import { checkForgotPasswordAttempt } from "@/server/utils/ratelimit"; -const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit); +const mockCheckForgotPasswordAttempt = vi.mocked(checkForgotPasswordAttempt); const mockFindUserByEmail = vi.mocked(findUserByEmail); function makeCaller(ip = "1.2.3.4") { @@ -27,22 +27,22 @@ function makeCaller(ip = "1.2.3.4") { describe("auth.forgotPassword", () => { beforeEach(() => { vi.clearAllMocks(); - mockCheckRateLimit.mockResolvedValue(true); + mockCheckForgotPasswordAttempt.mockResolvedValue(true); mockFindUserByEmail.mockResolvedValue(undefined); }); describe("rate limiting", () => { test("allows request when rate limit is not exceeded", async () => { - mockCheckRateLimit.mockResolvedValue(true); + mockCheckForgotPasswordAttempt.mockResolvedValue(true); const caller = makeCaller("1.2.3.4"); await expect( caller.forgotPassword({ email: "user@example.com" }), ).resolves.toBe(true); - expect(mockCheckRateLimit).toHaveBeenCalledWith("1.2.3.4"); + expect(mockCheckForgotPasswordAttempt).toHaveBeenCalledWith("1.2.3.4"); }); test("throws TOO_MANY_REQUESTS when rate limit is exceeded", async () => { - mockCheckRateLimit.mockResolvedValue(false); + mockCheckForgotPasswordAttempt.mockResolvedValue(false); const caller = makeCaller("1.2.3.4"); await expect( caller.forgotPassword({ email: "user@example.com" }), @@ -52,11 +52,11 @@ describe("auth.forgotPassword", () => { test("passes the caller IP to the rate limiter", async () => { const caller = makeCaller("9.8.7.6"); await caller.forgotPassword({ email: "user@example.com" }); - expect(mockCheckRateLimit).toHaveBeenCalledWith("9.8.7.6"); + expect(mockCheckForgotPasswordAttempt).toHaveBeenCalledWith("9.8.7.6"); }); test("does not look up user when rate limit is exceeded", async () => { - mockCheckRateLimit.mockResolvedValue(false); + mockCheckForgotPasswordAttempt.mockResolvedValue(false); const caller = makeCaller(); await expect( caller.forgotPassword({ email: "user@example.com" }),