Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/app/api/login/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ import z from "zod";
import { setJWT } from "@/auth/jwt";
import { findUserByEmailAndPassword } from "@/server/repositories/User";
import logger from "@/server/services/logger";
import { checkLoginRateLimit, getClientIp } from "@/server/utils/ratelimit";
import {
checkLoginAttempt,
getClientIp,
rollbackLoginAttempt,
} from "@/server/utils/ratelimit";

const loginSchema = z.object({
email: z.string().email(),
Expand All @@ -14,7 +18,7 @@ export async function POST(request: NextRequest) {
const ip = getClientIp(request);
logger.info(`Login request from ${ip}`);

const allowed = await checkLoginRateLimit(ip);
const allowed = await checkLoginAttempt(ip);
if (!allowed) {
return NextResponse.json(
{ error: "Too many login attempts, please try again later" },
Expand All @@ -27,6 +31,7 @@ export async function POST(request: NextRequest) {
const result = loginSchema.safeParse(body);

if (result.error) {
await rollbackLoginAttempt(ip);
return NextResponse.json(
{ error: "Invalid credentials" },
{ status: 400 },
Expand All @@ -41,10 +46,12 @@ export async function POST(request: NextRequest) {
);
}

await rollbackLoginAttempt(ip);
await setJWT(user.id, user.email);
return NextResponse.json({ success: true });
} catch (error) {
logger.warn(`Failed to log in user`, { error });
await rollbackLoginAttempt(ip);
return NextResponse.json({ error: "Failed to log in" }, { status: 500 });
}
}
4 changes: 2 additions & 2 deletions src/server/trpc/routers/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {
} from "@/server/repositories/User";
import logger from "@/server/services/logger";
import { sendEmail } from "@/server/services/mailer";
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";
import { checkForgotPasswordAttempt } from "@/server/utils/ratelimit";
import { publicProcedure, router } from "../index";

export const authRouter = router({
Expand Down Expand Up @@ -84,7 +84,7 @@ export const authRouter = router({
forgotPassword: publicProcedure
.input(z.object({ email: z.string() }))
.mutation(async ({ ctx, input }) => {
const allowed = await checkForgotPasswordRateLimit(ctx.ip);
const allowed = await checkForgotPasswordAttempt(ctx.ip);
if (!allowed) {
throw new TRPCError({
code: "TOO_MANY_REQUESTS",
Expand Down
43 changes: 28 additions & 15 deletions src/server/utils/ratelimit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,42 @@ const WINDOW_SECONDS = 15 * 60; // 15 minutes
const LOGIN_MAX_ATTEMPTS = 5;
const FORGOT_PASSWORD_MAX_ATTEMPTS = 5;

async function checkRateLimit(
key: string,
maxAttempts: number,
): Promise<boolean> {
async function recordAttempt(key: string): Promise<number> {
const redis = getClient();
const results = await redis
.multi()
.incr(key)
.expire(key, WINDOW_SECONDS)
.exec();
const count = results && results[0] ? (results[0][1] as number) : 0;
return count <= maxAttempts;
return results && results[0] ? (results[0][1] as number) : 0;
}

export async function checkLoginRateLimit(ip: string): Promise<boolean> {
return checkRateLimit(`rate_limit:login:${ip}`, LOGIN_MAX_ATTEMPTS);
export async function checkLoginAttempt(ip: string): Promise<boolean> {
const count = await recordAttempt(`rate_limit:login:${ip}`);
return count <= LOGIN_MAX_ATTEMPTS;
}

Comment on lines +26 to 30
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment about rollback says a post-expiry rollback results in a key at -1 and the next INCR producing 0, but the current implementation deletes the key when newCount <= 0. That means the next INCR would produce 1, so the comment is misleading and should be updated to match the actual behavior.

Copilot uses AI. Check for mistakes.
export async function checkForgotPasswordRateLimit(
ip: string,
): Promise<boolean> {
return checkRateLimit(
`rate_limit:forgot_password:${ip}`,
FORGOT_PASSWORD_MAX_ATTEMPTS,
);
export async function rollbackLoginAttempt(ip: string): Promise<void> {
const redis = getClient();
const key = `rate_limit:login:${ip}`;

const results = await redis
.multi()
.decr(key)
.expire(key, WINDOW_SECONDS)
.exec();

const newCount =
results && results[0] && Array.isArray(results[0])
? (results[0][1] as number)
: 0;

if (newCount <= 0) {
Comment on lines +40 to +46
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rollbackLoginAttempt unconditionally calls EXPIRE after DECR, which resets the TTL. This can unintentionally extend the lockout window for any remaining failed attempts (e.g., 4 failures followed by a successful login will keep those failures “fresh” for another 15 minutes). Rollback should generally preserve the existing TTL (or only set an expiry when the key is newly created / missing a TTL).

Copilot uses AI. Check for mistakes.
await redis.del(key);
}
}

export async function checkForgotPasswordAttempt(ip: string): Promise<boolean> {
const count = await recordAttempt(`rate_limit:forgot_password:${ip}`);
return count <= FORGOT_PASSWORD_MAX_ATTEMPTS;
Comment on lines +51 to +53
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DEL is executed outside the MULTI/EXEC transaction, so there’s a race where another request can INCR the same key after the DECR but before DEL, and then the DEL wipes out a non-zero count. That would undercount failed attempts and weaken rate limiting. Consider making the decrement+conditional-delete atomic (e.g., a small Lua script that does DECR and DEL-if-<=0, and manages TTL appropriately).

Copilot uses AI. Check for mistakes.
}
77 changes: 66 additions & 11 deletions tests/unit/app/api/login/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ vi.mock("@/server/repositories/User", () => ({
}));
vi.mock("@/server/utils/ratelimit", async (importOriginal) => {
const actual = await importOriginal();
return { ...(actual as object), checkLoginRateLimit: vi.fn() };
return {
...(actual as object),
checkLoginAttempt: vi.fn(),
rollbackLoginAttempt: vi.fn(),
};
});
vi.mock("@/auth/jwt", () => ({
setJWT: vi.fn(),
Expand All @@ -17,9 +21,13 @@ vi.mock("@/server/services/logger", () => ({

import { POST } from "@/app/api/login/route";
import { findUserByEmailAndPassword } from "@/server/repositories/User";
import { checkLoginRateLimit } from "@/server/utils/ratelimit";
import {
checkLoginAttempt,
rollbackLoginAttempt,
} from "@/server/utils/ratelimit";

const mockCheckLoginRateLimit = vi.mocked(checkLoginRateLimit);
const mockCheckLoginAttempt = vi.mocked(checkLoginAttempt);
const mockRollbackLoginAttempt = vi.mocked(rollbackLoginAttempt);
const mockFindUser = vi.mocked(findUserByEmailAndPassword);

function makeRequest(
Expand All @@ -36,24 +44,25 @@ function makeRequest(
describe("POST /api/login", () => {
beforeEach(() => {
vi.clearAllMocks();
mockCheckLoginRateLimit.mockResolvedValue(true);
mockCheckLoginAttempt.mockResolvedValue(true);
mockRollbackLoginAttempt.mockResolvedValue(undefined);
mockFindUser.mockResolvedValue(null);
});

describe("rate limiting", () => {
test("allows request when rate limit is not exceeded", async () => {
mockCheckLoginRateLimit.mockResolvedValue(true);
mockCheckLoginAttempt.mockResolvedValue(true);
const request = makeRequest(
{ email: "user@example.com", password: "pass" },
{ "x-forwarded-for": "1.2.3.4" },
);
const response = await POST(request);
expect(response.status).not.toBe(429);
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("1.2.3.4");
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("1.2.3.4");
});

test("returns 429 when rate limit is exceeded", async () => {
mockCheckLoginRateLimit.mockResolvedValue(false);
mockCheckLoginAttempt.mockResolvedValue(false);
const request = makeRequest(
{ email: "user@example.com", password: "pass" },
{ "x-forwarded-for": "1.2.3.4" },
Expand All @@ -70,7 +79,7 @@ describe("POST /api/login", () => {
{ "x-forwarded-for": "10.0.0.1, 10.0.0.2, 10.0.0.3" },
);
await POST(request);
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("10.0.0.1");
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("10.0.0.1");
});

test("trims whitespace from the first IP in x-forwarded-for", async () => {
Expand All @@ -79,7 +88,7 @@ describe("POST /api/login", () => {
{ "x-forwarded-for": " 192.168.1.1 , 10.0.0.1" },
);
await POST(request);
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("192.168.1.1");
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("192.168.1.1");
});

test("falls back to x-real-ip when x-forwarded-for is absent", async () => {
Expand All @@ -88,7 +97,7 @@ describe("POST /api/login", () => {
{ "x-real-ip": "5.6.7.8" },
);
await POST(request);
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("5.6.7.8");
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("5.6.7.8");
});

test('uses "unknown" when no IP header is present', async () => {
Expand All @@ -97,7 +106,53 @@ describe("POST /api/login", () => {
password: "pass",
});
await POST(request);
expect(mockCheckLoginRateLimit).toHaveBeenCalledWith("unknown");
expect(mockCheckLoginAttempt).toHaveBeenCalledWith("unknown");
});

test("rolls back the attempt on successful login", async () => {
mockFindUser.mockResolvedValue({
id: "1",
email: "user@example.com",
name: "User",
createdAt: new Date(),
passwordHash: "",
avatarUrl: undefined,
});
const request = makeRequest(
{ email: "user@example.com", password: "correctpassword" },
{ "x-forwarded-for": "1.2.3.4" },
);
await POST(request);
expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("1.2.3.4");
});

test("does not roll back on invalid credentials", async () => {
mockFindUser.mockResolvedValue(null);
const request = makeRequest(
{ email: "user@example.com", password: "wrongpassword" },
{ "x-forwarded-for": "1.2.3.4" },
);
const response = await POST(request);
expect(response.status).toBe(401);
expect(mockRollbackLoginAttempt).not.toHaveBeenCalled();
});

test("rolls back the attempt on invalid request body", async () => {
const request = makeRequest({ email: "not-an-email", password: "" });
const response = await POST(request);
expect(response.status).toBe(400);
expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("unknown");
});

test("rolls back the attempt on unexpected server error", async () => {
mockFindUser.mockRejectedValue(new Error("db failure"));
const request = makeRequest(
{ email: "user@example.com", password: "pass" },
{ "x-forwarded-for": "1.2.3.4" },
);
const response = await POST(request);
expect(response.status).toBe(500);
expect(mockRollbackLoginAttempt).toHaveBeenCalledWith("1.2.3.4");
});
});

Expand Down
18 changes: 9 additions & 9 deletions tests/unit/server/trpc/routers/auth.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { beforeEach, describe, expect, test, vi } from "vitest";

vi.mock("@/server/utils/ratelimit", () => ({
checkForgotPasswordRateLimit: vi.fn(),
checkForgotPasswordAttempt: vi.fn(),
}));
vi.mock("@/server/repositories/User", () => ({
findUserByEmail: vi.fn(),
Expand All @@ -15,9 +15,9 @@ vi.mock("@/server/services/logger", () => ({

import { findUserByEmail } from "@/server/repositories/User";
import { authRouter } from "@/server/trpc/routers/auth";
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";
import { checkForgotPasswordAttempt } from "@/server/utils/ratelimit";

const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit);
const mockCheckForgotPasswordAttempt = vi.mocked(checkForgotPasswordAttempt);
const mockFindUserByEmail = vi.mocked(findUserByEmail);

function makeCaller(ip = "1.2.3.4") {
Expand All @@ -27,22 +27,22 @@ function makeCaller(ip = "1.2.3.4") {
describe("auth.forgotPassword", () => {
beforeEach(() => {
vi.clearAllMocks();
mockCheckRateLimit.mockResolvedValue(true);
mockCheckForgotPasswordAttempt.mockResolvedValue(true);
mockFindUserByEmail.mockResolvedValue(undefined);
});

describe("rate limiting", () => {
test("allows request when rate limit is not exceeded", async () => {
mockCheckRateLimit.mockResolvedValue(true);
mockCheckForgotPasswordAttempt.mockResolvedValue(true);
const caller = makeCaller("1.2.3.4");
await expect(
caller.forgotPassword({ email: "user@example.com" }),
).resolves.toBe(true);
expect(mockCheckRateLimit).toHaveBeenCalledWith("1.2.3.4");
expect(mockCheckForgotPasswordAttempt).toHaveBeenCalledWith("1.2.3.4");
});

test("throws TOO_MANY_REQUESTS when rate limit is exceeded", async () => {
mockCheckRateLimit.mockResolvedValue(false);
mockCheckForgotPasswordAttempt.mockResolvedValue(false);
const caller = makeCaller("1.2.3.4");
await expect(
caller.forgotPassword({ email: "user@example.com" }),
Expand All @@ -52,11 +52,11 @@ describe("auth.forgotPassword", () => {
test("passes the caller IP to the rate limiter", async () => {
const caller = makeCaller("9.8.7.6");
await caller.forgotPassword({ email: "user@example.com" });
expect(mockCheckRateLimit).toHaveBeenCalledWith("9.8.7.6");
expect(mockCheckForgotPasswordAttempt).toHaveBeenCalledWith("9.8.7.6");
});

test("does not look up user when rate limit is exceeded", async () => {
mockCheckRateLimit.mockResolvedValue(false);
mockCheckForgotPasswordAttempt.mockResolvedValue(false);
const caller = makeCaller();
await expect(
caller.forgotPassword({ email: "user@example.com" }),
Expand Down
Loading