Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions apps/api/src/handlers/chat/chat.integration.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/**
* RAG chat integration test against a real Postgres+pgvector. Gated like the
* others. MockAiProvider (no OPENAI_API_KEY) gives deterministic embeddings +
* answer, so retrieval and citations are assertable offline.
*/
import { afterAll, beforeAll, beforeEach, describe, expect, it } from 'vitest';
import type { Context, APIGatewayProxyEventV2 } from 'aws-lambda';

/** Integration tests only run when explicitly opted in with RUN_INTEGRATION=1. */
const RUN = process.env['RUN_INTEGRATION'] === '1';
/** Database the suite talks to; defaults to the local docker-style URL. */
const INTEGRATION_URL =
  process.env['INTEGRATION_DATABASE_URL'] ??
  'postgres://clouddocs:clouddocs@localhost:5434/clouddocs';
/** Fallback matcher for connection strings that have no parseable host component. */
const LOCAL_HOST_RE = /(?:^|@|\/\/)(?:localhost|127\.0\.0\.1|0\.0\.0\.0)(?::|\/|$)/;
/** Hostnames accepted as "local" — the same three names LOCAL_HOST_RE matches. */
const LOCAL_HOSTS: ReadonlySet<string> = new Set(['localhost', '127.0.0.1', '0.0.0.0']);

/**
 * Returns true when `url` points at a local database host.
 *
 * The suite truncates tables, so this guard must not be fooled by a remote URL
 * that merely *contains* a local-looking fragment (for example inside the
 * password or the query string). When the string parses as a URL with a host,
 * only that host is considered; the regex is kept solely as a fallback for
 * strings without a host component (e.g. `localhost:5434/db`).
 *
 * @param url - Connection string to inspect.
 * @returns Whether the connection string targets localhost/127.0.0.1/0.0.0.0.
 */
function isLocalDatabaseUrl(url: string): boolean {
  try {
    const { hostname } = new URL(url);
    if (hostname !== '') return LOCAL_HOSTS.has(hostname);
  } catch {
    // Not a parseable URL — fall back to the textual match below.
  }
  return LOCAL_HOST_RE.test(url);
}

const URL_LOCAL = isLocalDatabaseUrl(INTEGRATION_URL);

// NOTE(review): static imports are hoisted, so the modules imported below are
// evaluated before this assignment runs. This presumably relies on them reading
// DATABASE_URL / OPENAI_API_KEY lazily — confirm.
if (RUN && URL_LOCAL) {
  process.env['DATABASE_URL'] = INTEGRATION_URL;
  // Without an API key the mock AI provider is expected to be selected.
  delete process.env['OPENAI_API_KEY'];
}

import { closePool, query } from '../../lib/db/client';
import { getAiProvider } from '../../lib/ai';
import { signAccessToken } from '../../lib/auth/jwt';
import { EmbeddingsRepo } from '../../repositories/embeddings-repo';
import { handler as chatHandler } from './handler';

/** Placeholder Lambda context; an empty object cast to `Context`. */
const lambdaCtx = {} as Context;

/**
 * Builds a `POST /v1/chat` API Gateway v2 event.
 *
 * @param body - Payload, JSON-stringified into the event body.
 * @param headers - Extra headers; they override the default user-agent / content-type.
 */
function makeEvent(body: unknown, headers: Record<string, string>): APIGatewayProxyEventV2 {
  const http = {
    method: 'POST',
    path: '/v1/chat',
    protocol: 'HTTP/2',
    sourceIp: '127.0.0.1',
    userAgent: 'vitest',
  };
  // Caller-supplied headers are spread last so they win over the defaults.
  const mergedHeaders = {
    'user-agent': 'vitest',
    'content-type': 'application/json',
    ...headers,
  };
  return {
    routeKey: 'POST /v1/chat',
    rawPath: '/v1/chat',
    requestContext: { http },
    headers: mergedHeaders,
    body: JSON.stringify(body),
    isBase64Encoded: false,
  } as APIGatewayProxyEventV2;
}

/**
 * Inserts the fixture rows for one test: an organization, a user who owns it,
 * one `ready` document and a single embedded chunk for that document. Then
 * signs an access token for the user.
 *
 * @returns The organization id, the signed token and the document id.
 */
async function seed(): Promise<{ orgId: string; token: string; docId: string }> {
  const orgRows = await query<{ id: string }>(
    `INSERT INTO organizations (name, slug) VALUES ('Acme','acme') RETURNING id`,
  );
  const orgId = orgRows[0]!.id;

  const userRows = await query<{ id: string }>(
    `INSERT INTO users (email, password_hash) VALUES ('d@test.com','x') RETURNING id`,
  );
  const userId = userRows[0]!.id;

  const membershipParams = [userId, orgId];
  await query(
    `INSERT INTO memberships (user_id, org_id, role) VALUES ($1,$2,'owner')`,
    membershipParams,
  );

  const docRows = await query<{ id: string }>(
    `INSERT INTO documents (org_id, uploaded_by, filename, mime_type, size_bytes, s3_key, status)
VALUES ($1,$2,'acme-invoice.pdf','application/pdf',10,'raw-uploads/x/1/a.pdf','ready') RETURNING id`,
    [orgId, userId],
  );
  const docId = docRows[0]!.id;

  // A single billing-related chunk, embedded with whichever provider is active.
  const embedded = await getAiProvider().embed([
    'invoice billing payment amount due 6000 net 30',
  ]);
  const chunkVector = embedded.vectors[0];
  const embeddingsRepo = new EmbeddingsRepo(orgId);
  await embeddingsRepo.replaceForDocument(docId, [
    {
      chunkIndex: 0,
      chunkText: 'Invoice total amount due is 6000 USD, net 30 days.',
      embedding: chunkVector!,
    },
  ]);

  const signed = await signAccessToken({
    userId,
    email: 'd@test.com',
    memberships: [{ orgId, role: 'owner' }],
  });
  return { orgId, token: signed.token, docId };
}

// Skipped unless integration runs are enabled AND the database URL is local.
describe.skipIf(!(RUN && URL_LOCAL))('RAG chat', () => {
  let auth: { orgId: string; token: string; docId: string };

  /** Calls the chat handler with the seeded token and org header; returns the raw response. */
  async function postChat(payload: unknown): Promise<{ statusCode: number; body: string }> {
    const event = makeEvent(payload, {
      authorization: `Bearer ${auth.token}`,
      'x-org-id': auth.orgId,
    });
    return (await chatHandler(event, lambdaCtx)) as { statusCode: number; body: string };
  }

  beforeAll(() => {
    const hasPrivateKey = Boolean(process.env['JWT_PRIVATE_KEY']);
    const hasPublicKey = Boolean(process.env['JWT_PUBLIC_KEY']);
    if (!(hasPrivateKey && hasPublicKey)) {
      throw new Error('JWT keys must be set for integration tests.');
    }
  });

  // Every test starts from empty tables plus a freshly seeded fixture.
  beforeEach(async () => {
    await query(
      'TRUNCATE embeddings, ai_analyses, documents, refresh_tokens, invitations, memberships, organizations, users RESTART IDENTITY CASCADE',
    );
    auth = await seed();
  });

  afterAll(async () => {
    await closePool();
  });

  it('answers with citations to the relevant document', async () => {
    const res = await postChat({
      messages: [{ role: 'user', content: 'How much do I owe on the invoice?' }],
    });

    expect(res.statusCode).toBe(200);
    const body = JSON.parse(res.body);
    expect(typeof body.answer).toBe('string');
    expect(body.answer.length).toBeGreaterThan(0);
    expect(body.citations.length).toBeGreaterThan(0);
    const firstCitation = body.citations[0];
    expect(firstCitation.documentId).toBe(auth.docId);
    expect(firstCitation.filename).toBe('acme-invoice.pdf');
  });

  it('scopes retrieval to a documentId when provided', async () => {
    const res = await postChat({
      messages: [{ role: 'user', content: 'summary?' }],
      documentId: auth.docId,
    });

    expect(res.statusCode).toBe(200);
    const { citations } = JSON.parse(res.body);
    const allFromSeededDoc = citations.every(
      (c: { documentId: string }) => c.documentId === auth.docId,
    );
    expect(allFromSeededDoc).toBe(true);
  });
});
39 changes: 39 additions & 0 deletions apps/api/src/handlers/chat/handler.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { ChatRequestSchema } from '@clouddocs/shared-types';

import {
compose,
jsonResponse,
withActiveOrg,
withAuth,
withErrorHandler,
withJsonBody,
withRequestLogger,
withSecrets,
withValidation,
type LambdaHandler,
} from '../../middlewares';
import { chatUseCase } from './usecase';

/**
* POST /v1/chat — RAG chat over the active org's documents (optionally scoped to
* one document). Stateless: the client sends the recent conversation.
*
* Wrapper nesting, outermost first: withSecrets → withRequestLogger →
* compose(withErrorHandler, withJsonBody) → withAuth → withActiveOrg →
* withValidation(ChatRequestSchema). The innermost callback hands `ctx.orgId`
* and `ctx.body` to `chatUseCase` and returns its result as a 200 JSON response.
*/
export const handler: LambdaHandler = withSecrets(
withRequestLogger(
compose(
withErrorHandler,
withJsonBody,
)(
withAuth(
withActiveOrg(
withValidation(ChatRequestSchema, async (ctx) => {
const orgId = ctx.orgId;
// Presumably populated by withActiveOrg — confirm. A missing value is treated
// as a wiring bug (plain Error), not as a client error.
if (!orgId) throw new Error('Org context missing — middleware bug.');
const result = await chatUseCase(orgId, ctx.body);
return jsonResponse(200, result);
}),
),
),
),
),
);
41 changes: 41 additions & 0 deletions apps/api/src/handlers/chat/usecase.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/**
* RAG chat use case: embed the latest user message, retrieve the most relevant
* chunks (optionally scoped to one document), ground gpt-4o-mini on them, and
* return the answer plus the chunks it was given as citations.
*/
import type { ChatRequest, ChatResponse, Citation } from '@clouddocs/shared-types';

import { getAiProvider, type ChatTurn } from '../../lib/ai';
import { buildChatSystem } from '../../lib/ai/prompts/chat/v1';
import { EmbeddingsRepo } from '../../repositories/embeddings-repo';

/** Maximum number of chunks retrieved per question. */
const TOP_K = 6;
/** Maximum length of the chunk excerpt returned with each citation. */
const SNIPPET_CHARS = 240;

/**
 * Answers the conversation in `req` using chunks retrieved for `orgId`.
 *
 * The most recent user message is embedded and used as the retrieval query
 * (optionally limited to `req.documentId`). The retrieved chunks are placed in
 * the system prompt, the whole conversation is sent to the AI provider, and
 * every retrieved chunk is returned as a citation.
 *
 * @param orgId - Organization whose embeddings are searched.
 * @param req - Conversation so far and an optional document scope.
 * @returns The provider's answer plus one citation per retrieved chunk; a fixed
 *   prompt with no citations when the conversation has no user message.
 */
export async function chatUseCase(orgId: string, req: ChatRequest): Promise<ChatResponse> {
  // Scan from the end for the newest user turn.
  let latestUser: ChatRequest['messages'][number] | undefined;
  for (let i = req.messages.length - 1; i >= 0; i -= 1) {
    const candidate = req.messages[i];
    if (candidate?.role === 'user') {
      latestUser = candidate;
      break;
    }
  }
  if (!latestUser) {
    return { answer: 'Ask a question to get started.', citations: [] };
  }

  const ai = getAiProvider();
  const embedded = await ai.embed([latestUser.content]);
  const queryVector = embedded.vectors[0] ?? [];
  const repo = new EmbeddingsRepo(orgId);
  const retrieved = await repo.search(queryVector, TOP_K, req.documentId);

  const systemPrompt = buildChatSystem(
    retrieved.map(({ filename, chunkText }) => ({ filename, chunkText })),
  );

  // System prompt first, then the client-supplied conversation in order.
  const turns: ChatTurn[] = [{ role: 'system', content: systemPrompt }];
  for (const message of req.messages) {
    turns.push({ role: message.role, content: message.content });
  }

  const completion = await ai.chat(turns);

  const citations: Citation[] = retrieved.map((chunk) => ({
    documentId: chunk.documentId,
    filename: chunk.filename,
    chunkIndex: chunk.chunkIndex,
    snippet: chunk.chunkText.slice(0, SNIPPET_CHARS),
  }));

  return { answer: completion.data, citations };
}
2 changes: 1 addition & 1 deletion apps/api/src/lib/ai/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export type { AiProvider, AiResult, AiUsage, EmbedResult } from './provider';
export type { AiProvider, AiResult, AiUsage, ChatTurn, EmbedResult } from './provider';
export { EMBEDDING_DIMENSIONS } from './provider';
export { OpenAiProvider } from './openai-provider';
export { MockAiProvider } from './mock-provider';
Expand Down
17 changes: 16 additions & 1 deletion apps/api/src/lib/ai/mock-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
*/
import type { ClassifyResult, SummaryResult } from '@clouddocs/shared-types';

import { EMBEDDING_DIMENSIONS, type AiProvider, type AiResult, type EmbedResult } from './provider';
import {
EMBEDDING_DIMENSIONS,
type AiProvider,
type AiResult,
type ChatTurn,
type EmbedResult,
} from './provider';

export class MockAiProvider implements AiProvider {
async summarize(text: string): Promise<AiResult<SummaryResult>> {
Expand All @@ -32,6 +38,15 @@ export class MockAiProvider implements AiProvider {
usage: { model: 'mock', promptVersion: 'embed.mock', costUsd: 0 },
};
}

/**
 * Canned chat answer. Echoes the most recent user turn and picks the "found"
 * wording when any system turn contains the literal text `Context:`, the
 * "nothing relevant" wording otherwise. Usage always reports zero cost.
 */
async chat(turns: ChatTurn[]): Promise<AiResult<string>> {
  // Newest user turn, if any; an absent turn is echoed as an empty string.
  let question = '';
  for (let i = turns.length - 1; i >= 0; i -= 1) {
    const turn = turns[i];
    if (turn?.role === 'user') {
      question = turn.content ?? '';
      break;
    }
  }

  const contextTurn = turns.find(
    (turn) => turn.role === 'system' && turn.content.includes('Context:'),
  );

  let answer: string;
  if (contextTurn !== undefined) {
    answer = `Based on your documents, here is the answer to "${question}". [1]`;
  } else {
    answer = `I could not find anything relevant in your documents for "${question}".`;
  }

  const usage = { model: 'mock', promptVersion: 'chat.mock', costUsd: 0 };
  return { data: answer, usage };
}
}

/**
Expand Down
24 changes: 23 additions & 1 deletion apps/api/src/lib/ai/openai-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import {
} from '@clouddocs/shared-types';

import { AppError } from '../errors';
import type { AiProvider, AiResult, AiUsage, EmbedResult } from './provider';
import type { AiProvider, AiResult, AiUsage, ChatTurn, EmbedResult } from './provider';
import { CHAT_PROMPT_VERSION } from './prompts/chat/v1';
import { SUMMARY_PROMPT_VERSION, SUMMARY_SYSTEM, summaryUserPrompt } from './prompts/summary/v1';
import {
CLASSIFY_PROMPT_VERSION,
Expand Down Expand Up @@ -111,6 +112,27 @@ export class OpenAiProvider implements AiProvider {
};
}

/**
 * Sends the turns to the chat completions endpoint and returns the first
 * choice's text together with usage and cost metadata.
 *
 * @throws AppError `ai_error` (502) when the first choice has no content.
 */
async chat(turns: ChatTurn[]): Promise<AiResult<string>> {
  const completion = await this.client.chat.completions.create({
    model: MODEL,
    messages: turns.map(({ role, content }) => ({ role, content })),
  });

  const [firstChoice] = completion.choices;
  const text = firstChoice?.message.content;
  if (!text) {
    throw new AppError('ai_error', 'OpenAI returned no content.', 502);
  }

  const inputTokens = completion.usage?.prompt_tokens;
  const outputTokens = completion.usage?.completion_tokens;
  return {
    data: text,
    usage: {
      model: MODEL,
      promptVersion: CHAT_PROMPT_VERSION,
      // Token counts are included only when the API reported them.
      ...(inputTokens == null ? {} : { inputTokens }),
      ...(outputTokens == null ? {} : { outputTokens }),
      costUsd: estimateCost(inputTokens, outputTokens),
    },
  };
}

private async complete(
system: string,
user: string,
Expand Down
31 changes: 31 additions & 0 deletions apps/api/src/lib/ai/prompts/chat/v1.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* RAG chat prompt v1. The system message carries the retrieved context; the
* model must answer only from it and admit when the answer isn't there.
*/
/** Version tag for this prompt. */
export const CHAT_PROMPT_VERSION = 'chat.v1';

/** A retrieved passage as it is presented to the model. */
export interface ContextChunk {
  /** Name of the file the passage came from; shown next to its number. */
  filename: string;
  /** Passage text placed into the context section. */
  chunkText: string;
}

/**
 * Builds the system message for a chat turn.
 *
 * With no chunks, the message tells the model that nothing matched. Otherwise
 * it instructs the model to answer only from the numbered passages that follow
 * the `Context:` header, numbered from 1 in the order given.
 *
 * @param chunks - Retrieved passages, in the order they should be numbered.
 * @returns The complete system prompt text.
 */
export function buildChatSystem(chunks: ContextChunk[]): string {
  if (chunks.length > 0) {
    const passages: string[] = [];
    chunks.forEach((chunk, index) => {
      passages.push(`[${index + 1}] (from "${chunk.filename}")\n${chunk.chunkText}`);
    });
    const instructions = [
      "You are CloudDocs AI, answering questions about the user's documents.",
      'Use ONLY the context below to answer. If the answer is not in the context,',
      'say you could not find it in their documents — do not invent facts.',
      'Be concise. When you use a passage, reference it like [1], [2].',
    ].join(' ');
    // The single spaces on either side of the header are part of the v1 prompt text.
    return `${instructions} \n\nContext:\n ${passages.join('\n\n')}`;
  }

  return (
    'You are CloudDocs AI, a helpful assistant for a document-management product.' +
    ' ' +
    'The user has no documents matching this question. Tell them you could not find' +
    ' ' +
    'anything relevant in their documents, and suggest uploading or rephrasing.'
  );
}
7 changes: 7 additions & 0 deletions apps/api/src/lib/ai/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,16 @@ export interface EmbedResult {
usage: AiUsage;
}

/** One message in the transcript passed to `AiProvider.chat`. */
export interface ChatTurn {
/** Who the message is attributed to. */
role: 'system' | 'user' | 'assistant';
/** Text of the message. */
content: string;
}

/** Operations every AI backend must provide; all are asynchronous. */
export interface AiProvider {
/** Resolves to a `SummaryResult` for the given text, wrapped in `AiResult`. */
summarize(text: string): Promise<AiResult<SummaryResult>>;
/** Resolves to a `ClassifyResult` for the given text, wrapped in `AiResult`. */
classify(text: string): Promise<AiResult<ClassifyResult>>;
/** Embed a batch of texts (document chunks or a search query). */
embed(texts: string[]): Promise<EmbedResult>;
/** Single-shot chat completion (non-streaming). Returns the answer text. */
chat(turns: ChatTurn[]): Promise<AiResult<string>>;
}
41 changes: 41 additions & 0 deletions apps/api/src/repositories/embeddings-repo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,45 @@ export class EmbeddingsRepo extends OrgScopedRepository {
);
return Number(rows[0]?.n ?? 0);
}

/**
 * Nearest-chunk lookup for RAG: returns up to `limit` chunks ordered by
 * ascending `<=>` distance to `queryVector` (cosine), joined with `documents`
 * for the filename. When `documentId` is given, only that document's chunks
 * are considered.
 * Placeholders: `$1`=org, `$2`=vector, `$3`=limit, `$4`=documentId (when filtering).
 */
async search(
  queryVector: number[],
  limit: number,
  documentId?: string,
): Promise<RetrievedChunk[]> {
  const vectorLiteral = toVectorLiteral(queryVector);
  const scopeToDocument = Boolean(documentId);
  // Only vector and limit are bound here; `$1` (org) is not part of this array.
  const bindings: unknown[] = scopeToDocument
    ? [vectorLiteral, limit, documentId]
    : [vectorLiteral, limit];
  const filter = scopeToDocument ? 'AND e.document_id = $4' : '';
  const sql = `SELECT e.document_id AS "documentId",
d.filename AS filename,
e.chunk_index AS "chunkIndex",
e.chunk_text AS "chunkText",
(e.embedding <=> $2::vector) AS distance
FROM embeddings e
JOIN documents d ON d.id = e.document_id AND d.org_id = $1
WHERE e.org_id = $1 ${filter}
ORDER BY e.embedding <=> $2::vector
LIMIT $3`;
  return this.scopedQuery<RetrievedChunk>(sql, bindings);
}
}

// type alias (not interface) so it satisfies the scopedQuery `Record` constraint.
/** One row returned by `EmbeddingsRepo.search`. */
export type RetrievedChunk = {
/** `embeddings.document_id` of the chunk. */
documentId: string;
/** `documents.filename` of the joined document. */
filename: string;
/** `embeddings.chunk_index` of the chunk. */
chunkIndex: number;
/** `embeddings.chunk_text` of the chunk. */
chunkText: string;
/** Value of `e.embedding <=> $2::vector` for this row; results are ordered by it ascending. */
distance: number;
};
Loading
Loading