Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/app/api/login/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@ import z from "zod";
import { setJWT } from "@/auth/jwt";
import { findUserByEmailAndPassword } from "@/server/repositories/User";
import logger from "@/server/services/logger";
import { checkLoginRateLimit } from "@/server/utils/ratelimit";
import { checkLoginRateLimit, getClientIp } from "@/server/utils/ratelimit";

const loginSchema = z.object({
email: z.string().email(),
password: z.string().min(1, "Password is required"),
});

export async function POST(request: NextRequest) {
const forwarded = request.headers.get("x-forwarded-for");
const ip =
(forwarded ? forwarded.split(",")[0].trim() : null) ??
request.headers.get("x-real-ip") ??
"unknown";
const ip = getClientIp(request);
logger.info(`Login request from ${ip}`);

const allowed = await checkLoginRateLimit(ip);
Expand Down
6 changes: 4 additions & 2 deletions src/server/trpc/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import z, { ZodError } from "zod";
import { getServerSession } from "@/auth";
import { ADMIN_USER_EMAIL } from "@/constants";
import { canReadDataSource } from "@/server/utils/auth";
import { getClientIp } from "@/server/utils/ratelimit";
import {
hasPasswordHashSerializer,
serverDataSourceSerializer,
Expand All @@ -14,13 +15,14 @@ import { findOrganisationForUser } from "../repositories/Organisation";
import { findPublishedPublicMapByMapId } from "../repositories/PublicMap";
import { findUserById } from "../repositories/User";

export async function createContext() {
export async function createContext(opts?: { req?: Request }) {
const session = await getServerSession();
let user = null;
if (session.currentUser) {
user = await findUserById(session.currentUser.id);
}
return { user };
const ip = opts?.req ? getClientIp(opts.req) : "unknown";
return { user, ip };
}

export type Context = Awaited<ReturnType<typeof createContext>>;
Expand Down
11 changes: 10 additions & 1 deletion src/server/trpc/routers/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
} from "@/server/repositories/User";
import logger from "@/server/services/logger";
import { sendEmail } from "@/server/services/mailer";
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";
import { publicProcedure, router } from "../index";

export const authRouter = router({
Expand Down Expand Up @@ -82,7 +83,15 @@ export const authRouter = router({
}),
forgotPassword: publicProcedure
.input(z.object({ email: z.string() }))
.mutation(async ({ input }) => {
.mutation(async ({ ctx, input }) => {
const allowed = await checkForgotPasswordRateLimit(ctx.ip);
if (!allowed) {
throw new TRPCError({
code: "TOO_MANY_REQUESTS",
message: "Too many requests, please try again later",
});
}
Comment on lines +89 to +93
Copy link

Copilot AI Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Throwing TRPCError with code: "TOO_MANY_REQUESTS" here will be treated like any other tRPC error by the API handler and (currently) gets logged at error level and sent to Sentry (see src/app/api/trpc/[trpc]/route.ts:16-51, where ACCEPTED_ERROR_CODES is empty). For rate-limiting, this can generate high-volume noise during normal throttling or attacks. Consider handling this case so it’s not captured/logged as an error (e.g., add TOO_MANY_REQUESTS to an accepted/ignored list, or adjust the handler/logging strategy for this error code).

Copilot uses AI. Check for mistakes.

const { email } = input;
const user = await findUserByEmail(email);
if (!user) return true;
Expand Down
33 changes: 29 additions & 4 deletions src/server/utils/ratelimit.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,41 @@
import { getClient } from "@/server/services/redis";

export function getClientIp(req: Request): string {
const forwarded = req.headers.get("x-forwarded-for");
return (
(forwarded ? forwarded.split(",")[0].trim() : null) ??
req.headers.get("x-real-ip") ??
"unknown"
);
}

const WINDOW_SECONDS = 15 * 60; // 15 minutes
const MAX_ATTEMPTS = 5;
const LOGIN_MAX_ATTEMPTS = 5;
const FORGOT_PASSWORD_MAX_ATTEMPTS = 5;

export async function checkLoginRateLimit(ip: string): Promise<boolean> {
async function checkRateLimit(
key: string,
maxAttempts: number,
): Promise<boolean> {
const redis = getClient();
const key = `rate_limit:login:${ip}`;
const results = await redis
.multi()
.incr(key)
.expire(key, WINDOW_SECONDS)
.exec();
const count = results && results[0] ? (results[0][1] as number) : 0;
return count <= MAX_ATTEMPTS;
return count <= maxAttempts;
}

export async function checkLoginRateLimit(ip: string): Promise<boolean> {
return checkRateLimit(`rate_limit:login:${ip}`, LOGIN_MAX_ATTEMPTS);
}

export async function checkForgotPasswordRateLimit(
ip: string,
): Promise<boolean> {
return checkRateLimit(
`rate_limit:forgot_password:${ip}`,
FORGOT_PASSWORD_MAX_ATTEMPTS,
);
}
7 changes: 4 additions & 3 deletions tests/unit/app/api/login/route.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import { beforeEach, describe, expect, test, vi } from "vitest";
vi.mock("@/server/repositories/User", () => ({
findUserByEmailAndPassword: vi.fn(),
}));
vi.mock("@/server/utils/ratelimit", () => ({
checkLoginRateLimit: vi.fn(),
}));
vi.mock("@/server/utils/ratelimit", async (importOriginal) => {
const actual = await importOriginal();
return { ...(actual as object), checkLoginRateLimit: vi.fn() };
});
vi.mock("@/auth/jwt", () => ({
setJWT: vi.fn(),
}));
Expand Down
67 changes: 67 additions & 0 deletions tests/unit/server/trpc/routers/auth.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { beforeEach, describe, expect, test, vi } from "vitest";

vi.mock("@/server/utils/ratelimit", () => ({
checkForgotPasswordRateLimit: vi.fn(),
}));
vi.mock("@/server/repositories/User", () => ({
findUserByEmail: vi.fn(),
}));
vi.mock("@/server/services/mailer", () => ({
sendEmail: vi.fn(),
}));
vi.mock("@/server/services/logger", () => ({
default: { info: vi.fn(), warn: vi.fn(), error: vi.fn() },
}));

import { findUserByEmail } from "@/server/repositories/User";
import { authRouter } from "@/server/trpc/routers/auth";
import { checkForgotPasswordRateLimit } from "@/server/utils/ratelimit";

const mockCheckRateLimit = vi.mocked(checkForgotPasswordRateLimit);
const mockFindUserByEmail = vi.mocked(findUserByEmail);

function makeCaller(ip = "1.2.3.4") {
return authRouter.createCaller({ user: null, ip });
}

describe("auth.forgotPassword", () => {
beforeEach(() => {
vi.clearAllMocks();
mockCheckRateLimit.mockResolvedValue(true);
mockFindUserByEmail.mockResolvedValue(undefined);
});

describe("rate limiting", () => {
test("allows request when rate limit is not exceeded", async () => {
mockCheckRateLimit.mockResolvedValue(true);
const caller = makeCaller("1.2.3.4");
await expect(
caller.forgotPassword({ email: "user@example.com" }),
).resolves.toBe(true);
expect(mockCheckRateLimit).toHaveBeenCalledWith("1.2.3.4");
});

test("throws TOO_MANY_REQUESTS when rate limit is exceeded", async () => {
mockCheckRateLimit.mockResolvedValue(false);
const caller = makeCaller("1.2.3.4");
await expect(
caller.forgotPassword({ email: "user@example.com" }),
).rejects.toMatchObject({ code: "TOO_MANY_REQUESTS" });
});

test("passes the caller IP to the rate limiter", async () => {
const caller = makeCaller("9.8.7.6");
await caller.forgotPassword({ email: "user@example.com" });
expect(mockCheckRateLimit).toHaveBeenCalledWith("9.8.7.6");
});

test("does not look up user when rate limit is exceeded", async () => {
mockCheckRateLimit.mockResolvedValue(false);
const caller = makeCaller();
await expect(
caller.forgotPassword({ email: "user@example.com" }),
).rejects.toThrow();
expect(mockFindUserByEmail).not.toHaveBeenCalled();
});
});
});
Loading