Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 175 additions & 29 deletions apps/backend/src/__tests__/connect.test.ts
Original file line number Diff line number Diff line change
@@ -1,39 +1,185 @@
import { describe, it, expect } from 'vitest';
import Fastify, { type FastifyInstance } from 'fastify';
import { describe, it, expect, beforeEach, vi } from 'vitest';

// Mock test for GitHub OAuth callback state validation
// Note: This test verifies the expected behavior of the
// /api/connect/github/callback endpoint when invalid or
// malformed OAuth state values are received.
//
// The implementation in connect.ts now:
// - safely parses OAuth state via parseGoogleState()
// - validates required fields (userId + nonce)
// - redirects invalid callbacks safely
//
// Security note:
// OAuth state validation helps prevent tampered callback
// requests and malformed state payload attacks.
import { connectRoutes } from '../routes/connect.js';

describe('GET /api/connect/github/callback - Invalid OAuth State', () => {
import type { PrismaClient } from '@prisma/client';

it('should redirect with connect_failed when state is invalid', async () => {
// Expected behavior:
// parseGoogleState('invalid_state') -> null
// reply.redirect(`${PUBLIC_APP_URL}/settings?error=connect_failed`)
const USER_ID = 'user-abc';
const VALID_NONCE = 'a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2';
const VALID_STATE = Buffer.from(JSON.stringify({ userId: USER_ID, nonce: VALID_NONCE })).toString('base64');
const ATTACKER_USER_ID = 'user-victim';
const CRAFTED_STATE = Buffer.from(JSON.stringify({ userId: ATTACKER_USER_ID, nonce: 'nonce-never-issued' })).toString('base64');

expect(true).toBe(true);
// Mock encrypt so the token-storage path does not throw in tests
vi.mock('../utils/encryption.js', () => ({
encrypt: vi.fn().mockReturnValue('encrypted_token'),
}));

const mockPrisma = {
oAuthToken: {
findMany: vi.fn(),
upsert: vi.fn(),
delete: vi.fn(),
},
};

// Redis mock that reports as connected and ready
const mockRedis = {
status: 'ready',
set: vi.fn(),
get: vi.fn(),
del: vi.fn(),
};

// Redis mock that simulates connection failure
const mockRedisDown = {
status: 'end',
set: vi.fn(),
get: vi.fn(),
del: vi.fn(),
};

async function buildApp(redisOverride?: object | null): Promise<FastifyInstance> {
const app = Fastify({ logger: false });
app.decorate('prisma', mockPrisma as unknown as PrismaClient);
app.decorate('redis', (redisOverride === undefined ? mockRedis : redisOverride) as any);
app.decorate('authenticate', async (request: any) => {
request.user = { id: USER_ID };
});
vi.stubGlobal('fetch', vi.fn().mockResolvedValue({
json: async () => ({ access_token: 'gh_token_abc', scope: 'user:follow' }),
}));
app.register(connectRoutes, { prefix: '/api/connect' });
await app.ready();
return app;
}

// ─────────────────────────────────────────────────────────────────────────────
// GET /api/connect/github/callback — CSRF nonce enforcement
// ─────────────────────────────────────────────────────────────────────────────

describe('GET /api/connect/github/callback — CSRF nonce enforcement', () => {
beforeEach(() => {
vi.clearAllMocks();
process.env.PUBLIC_APP_URL = 'https://app.devcard.test';
process.env.BACKEND_URL = 'https://api.devcard.test';
process.env.GITHUB_CLIENT_ID = 'gh_client_id';
process.env.GITHUB_CLIENT_SECRET = 'gh_client_secret';
});

it('returns 503 when Redis is unavailable (status !== ready)', async () => {
const app = await buildApp(mockRedisDown);
const res = await app.inject({
method: 'GET',
url: `/api/connect/github/callback?code=gh_code&state=${VALID_STATE}`,
});

expect(res.statusCode).toBe(503);
expect(mockPrisma.oAuthToken.upsert).not.toHaveBeenCalled();
expect(mockRedisDown.get).not.toHaveBeenCalled();
});

it('returns 503 when app.redis is null/falsy', async () => {
const app = await buildApp(null);
const res = await app.inject({
method: 'GET',
url: `/api/connect/github/callback?code=gh_code&state=${VALID_STATE}`,
});

expect(res.statusCode).toBe(503);
expect(mockPrisma.oAuthToken.upsert).not.toHaveBeenCalled();
});

it('redirects to invalid_state when nonce was never issued (crafted state)', async () => {
mockRedis.get.mockResolvedValue(null);

const app = await buildApp();
const res = await app.inject({
method: 'GET',
url: `/api/connect/github/callback?code=gh_code&state=${CRAFTED_STATE}`,
});

expect(res.statusCode).toBe(302);
expect(res.headers.location).toContain('error=invalid_state');
expect(mockPrisma.oAuthToken.upsert).not.toHaveBeenCalled();
});

it('redirects to invalid_state when nonce is present but userId does not match', async () => {
mockRedis.get.mockResolvedValue('user-different');

const app = await buildApp();
const res = await app.inject({
method: 'GET',
url: `/api/connect/github/callback?code=gh_code&state=${VALID_STATE}`,
});

expect(res.statusCode).toBe(302);
expect(res.headers.location).toContain('error=invalid_state');
expect(mockPrisma.oAuthToken.upsert).not.toHaveBeenCalled();
});

it('completes the OAuth flow and stores the token when nonce is valid', async () => {
mockRedis.get.mockResolvedValue(USER_ID);
mockRedis.del.mockResolvedValue(1);
mockPrisma.oAuthToken.upsert.mockResolvedValue({});

const app = await buildApp();
const res = await app.inject({
method: 'GET',
url: `/api/connect/github/callback?code=gh_code&state=${VALID_STATE}`,
});

it('should reject malformed oauth state payloads', async () => {
// Example malformed payload:
// { invalid: true }
//
// Expected:
// - missing userId
// - missing nonce
// - redirect to connect_failed
expect(res.statusCode).toBe(302);
expect(res.headers.location).toContain('connected=github');
expect(mockRedis.del).toHaveBeenCalledWith(`oauth:nonce:${VALID_NONCE}`);
expect(mockPrisma.oAuthToken.upsert).toHaveBeenCalledOnce();
});

it('consumes the nonce exactly once — replay of the same state is rejected', async () => {
mockRedis.get.mockResolvedValueOnce(USER_ID);
mockRedis.del.mockResolvedValue(1);
mockPrisma.oAuthToken.upsert.mockResolvedValue({});

const app = await buildApp();
const first = await app.inject({
method: 'GET',
url: `/api/connect/github/callback?code=gh_code&state=${VALID_STATE}`,
});
expect(first.statusCode).toBe(302);
expect(first.headers.location).toContain('connected=github');

mockRedis.get.mockResolvedValueOnce(null);
const second = await app.inject({
method: 'GET',
url: `/api/connect/github/callback?code=gh_code&state=${VALID_STATE}`,
});
expect(second.statusCode).toBe(302);
expect(second.headers.location).toContain('error=invalid_state');
expect(mockPrisma.oAuthToken.upsert).toHaveBeenCalledOnce();
});

it('redirects to connect_failed when code or state is missing', async () => {
const app = await buildApp();

const noCode = await app.inject({ method: 'GET', url: '/api/connect/github/callback?state=abc' });
expect(noCode.statusCode).toBe(302);
expect(noCode.headers.location).toContain('error=missing_params');

const noState = await app.inject({ method: 'GET', url: '/api/connect/github/callback?code=abc' });
expect(noState.statusCode).toBe(302);
expect(noState.headers.location).toContain('error=missing_params');
});

expect(true).toBe(true);
it('redirects to connect_failed when state is not valid base64 JSON', async () => {
const app = await buildApp();
const res = await app.inject({
method: 'GET',
url: '/api/connect/github/callback?code=gh_code&state=not_valid_base64!!!',
});

expect(res.statusCode).toBe(302);
expect(res.headers.location).toContain('error=connect_failed');
expect(mockPrisma.oAuthToken.upsert).not.toHaveBeenCalled();
});
});
37 changes: 23 additions & 14 deletions apps/backend/src/routes/connect.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import type { FastifyInstance, FastifyRequest, FastifyReply } from 'fastify';
import { randomBytes } from 'crypto';
import { randomBytes } from 'node:crypto';

import { encrypt } from '../utils/encryption.js';

import type { FastifyInstance, FastifyRequest, FastifyReply } from 'fastify';

const GITHUB_AUTH_URL = 'https://github.com/login/oauth/authorize';
const GITHUB_TOKEN_URL = 'https://github.com/login/oauth/access_token';

Expand All @@ -22,7 +24,7 @@
nonce: string;
}

export async function connectRoutes(app: FastifyInstance) {

Check warning on line 27 in apps/backend/src/routes/connect.ts

View workflow job for this annotation

GitHub Actions / backend-ci

Missing return type on function
// ─── Status ───

app.get('/status', {
Expand All @@ -30,9 +32,9 @@
const server = request.server as any;
if (typeof server?.authenticate === 'function') { await server.authenticate(request, reply); return }
if (typeof (app as any).authenticate === 'function') { await (app as any).authenticate(request, reply); return }
try { await request.jwtVerify() } catch (e) { reply.status(401).send({ error: 'Unauthorized' }) }
try { await request.jwtVerify() } catch (_e) { reply.status(401).send({ error: 'Unauthorized' }) }
}],
}, async (request: FastifyRequest, reply: FastifyReply) => {
}, async (request: FastifyRequest, _reply: FastifyReply) => {
const userId = (request.user as any).id;

const tokens = await app.prisma.oAuthToken.findMany({
Expand All @@ -50,7 +52,7 @@
const server = request.server as any;
if (typeof server?.authenticate === 'function') { await server.authenticate(request, reply); return }
if (typeof (app as any).authenticate === 'function') { await (app as any).authenticate(request, reply); return }
try { await request.jwtVerify() } catch (e) { reply.status(401).send({ error: 'Unauthorized' }) }
try { await request.jwtVerify() } catch (_e) { reply.status(401).send({ error: 'Unauthorized' }) }
}],
}, async (request: FastifyRequest, reply: FastifyReply) => {
const userId = (request.user as any).id;
Expand Down Expand Up @@ -93,16 +95,23 @@
return reply.redirect(`${process.env.PUBLIC_APP_URL}/settings?error=connect_failed`);
}

// Verify nonce was issued by this server -- prevents CSRF
const storedUserId = app.redis ? await app.redis.get(`oauth:nonce:${decodedState.nonce}`) : null;
// Hard-fail when Redis is unavailable: proceeding without nonce
// verification would allow CSRF — an attacker could craft a valid-looking
// state for any userId and have a token stored under their target's account.
if (!app.redis || app.redis.status !== 'ready') {
app.log.error('OAuth CSRF check skipped: Redis unavailable — aborting callback');
return reply.status(503).send({ error: 'Service temporarily unavailable. Please try again.' });
}

const storedUserId = await app.redis.get(`oauth:nonce:${decodedState.nonce}`);

if (app.redis && (!storedUserId || storedUserId !== decodedState.userId)) {
app.log.warn({ nonce: decodedState.nonce }, 'OAuth CSRF check failed: nonce mismatch');
if (!storedUserId || storedUserId !== decodedState.userId) {
app.log.warn({ nonce: decodedState.nonce }, 'OAuth CSRF check failed: nonce mismatch or nonce not found');
return reply.redirect(`${process.env.PUBLIC_APP_URL}/settings?error=invalid_state`);
}

// Consume the nonce -- one-time use only (if redis configured)
if (app.redis) await app.redis.del(`oauth:nonce:${decodedState.nonce}`);
// Consume the nonce one-time use, prevents replay attacks
await app.redis.del(`oauth:nonce:${decodedState.nonce}`);

const userId = decodedState.userId;

Expand Down Expand Up @@ -175,7 +184,7 @@
const server = request.server as any;
if (typeof server?.authenticate === 'function') { await server.authenticate(request, reply); return }
if (typeof (app as any).authenticate === 'function') { await (app as any).authenticate(request, reply); return }
try { await request.jwtVerify() } catch (e) { reply.status(401).send({ error: 'Unauthorized' }) }
try { await request.jwtVerify() } catch (_e) { reply.status(401).send({ error: 'Unauthorized' }) }
}],
}, async (request: FastifyRequest<{ Params: { platform: string } }>, reply: FastifyReply) => {
const userId = (request.user as any).id;
Expand All @@ -196,7 +205,7 @@
},
});
return { success: true };
} catch (error) {
} catch (_error) {
return reply.status(404).send({ error: 'Connection not found' });
}
});
Expand All @@ -218,4 +227,4 @@

function generateState(): string {
return randomBytes(32).toString('hex');
}
}
Loading