diff --git a/package-lock.json b/package-lock.json index bfa0142..3de5455 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@kryptsec/oasis", - "version": "0.1.3", + "version": "0.1.5", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@kryptsec/oasis", - "version": "0.1.3", + "version": "0.1.5", "license": "MIT", "dependencies": { "@anthropic-ai/sdk": "^0.78.0", @@ -18,7 +18,8 @@ "dotenv": "^16.3.1", "gradient-string": "^3.0.0", "openai": "^6.25.0", - "ora": "^8.1.1" + "ora": "^8.1.1", + "zod": "^4.3.6" }, "bin": { "oasis": "bin/oasis" @@ -2680,6 +2681,15 @@ "funding": { "url": "https://github.com/chalk/wrap-ansi?sponsor=1" } + }, + "node_modules/zod": { + "version": "4.3.6", + "resolved": "https://registry.npmjs.org/zod/-/zod-4.3.6.tgz", + "integrity": "sha512-rftlrkhHZOcjDwkGlnUtZZkvaPHCsDATp4pGpuOOMDaTdDDXF91wuVDJoWoPsKX/3YPQ5fHuF3STjcYyKr+Qhg==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } } } } diff --git a/package.json b/package.json index 148e40b..dd0aa3c 100644 --- a/package.json +++ b/package.json @@ -51,7 +51,8 @@ "dotenv": "^16.3.1", "gradient-string": "^3.0.0", "openai": "^6.25.0", - "ora": "^8.1.1" + "ora": "^8.1.1", + "zod": "^4.3.6" }, "devDependencies": { "@types/gradient-string": "^1.1.6", diff --git a/src/lib/analyzer.ts b/src/lib/analyzer.ts index 83d0133..b80e32c 100644 --- a/src/lib/analyzer.ts +++ b/src/lib/analyzer.ts @@ -19,15 +19,18 @@ import { finalizeRubricScore, fallbackOverallScore, } from './scoring.js'; +import { AnalysisResponseSchema } from './schemas.js'; +import { MAX_COMPLETION_TOKENS, ANALYZER_OUTPUT_LIMIT } from './constants.js'; import { withRateLimitRetry } from './retry.js'; -import { isAnthropicProvider, resolveProvider, resolveProviderName } from './providers.js'; +import { isAnthropicProvider, resolveProvider } from './providers.js'; +import { normalizeProvider } from './config.js'; // ============================================================================= // Configuration // ============================================================================= const DEFAULT_ANALYZER_MODEL = 'claude-sonnet-4-5-20250929'; -const MAX_OUTPUT_PER_STEP = 500; +const MAX_OUTPUT_PER_STEP = ANALYZER_OUTPUT_LIMIT; // ============================================================================= // System Prompt @@ -229,7 +232,7 @@ Return ONLY the JSON object, no other text.`; // Response Parser // ============================================================================= -async function parseAnalysisResponse( +export async function parseAnalysisResponse( response: string, runId: string, result: RunResult, @@ -243,39 +246,41 @@ async function parseAnalysisResponse( jsonStr = jsonStr.trim(); try { - const parsed = JSON.parse(jsonStr); + const parsed = AnalysisResponseSchema.parse(JSON.parse(jsonStr)); const analysisResult: AnalysisResult = { runId, analyzedAt: new Date(), analyzerModel: DEFAULT_ANALYZER_MODEL, attackChain: { - phases: parsed.attackChain?.phases || [], - techniques: parsed.attackChain?.techniques || [], - killChainCoverage: parsed.attackChain?.killChainCoverage || [], + phases: parsed.attackChain.phases, + techniques: parsed.attackChain.techniques, + killChainCoverage: parsed.attackChain.killChainCoverage, }, narrative: { - summary: parsed.narrative?.summary || 'Analysis unavailable', - detailed: parsed.narrative?.detailed || '', - keyFindings: parsed.narrative?.keyFindings || [], + summary: parsed.narrative.summary, + detailed: parsed.narrative.detailed, + keyFindings: parsed.narrative.keyFindings, }, behavior: { - approach: parsed.behavior?.approach || 'exploratory', - approachDescription: parsed.behavior?.approachDescription || '', - strengths: parsed.behavior?.strengths || [], - inefficiencies: parsed.behavior?.inefficiencies || [], - decisionQuality: parsed.behavior?.decisionQuality || 0, + approach: (['methodical', 'aggressive', 'exploratory', 'targeted'].includes(parsed.behavior.approach) + ? parsed.behavior.approach as 'methodical' | 'aggressive' | 'exploratory' | 'targeted' + : 'exploratory'), + approachDescription: parsed.behavior.approachDescription, + strengths: parsed.behavior.strengths, + inefficiencies: parsed.behavior.inefficiencies, + decisionQuality: parsed.behavior.decisionQuality, }, strategy: { - reconQuality: parsed.strategy?.reconQuality ?? 0, - exploitEfficiency: parsed.strategy?.exploitEfficiency ?? 0, - adaptability: parsed.strategy?.adaptability ?? 0, - overallScore: parsed.strategy?.overallScore || fallbackOverallScore( - parsed.strategy?.reconQuality ?? 0, - parsed.strategy?.exploitEfficiency ?? 0, - parsed.strategy?.adaptability ?? 0, + reconQuality: parsed.strategy.reconQuality, + exploitEfficiency: parsed.strategy.exploitEfficiency, + adaptability: parsed.strategy.adaptability, + overallScore: parsed.strategy.overallScore ?? fallbackOverallScore( + parsed.strategy.reconQuality, + parsed.strategy.exploitEfficiency, + parsed.strategy.adaptability, ), - scoreBreakdown: parsed.strategy?.scoreBreakdown ?? '', + scoreBreakdown: parsed.strategy.scoreBreakdown, }, }; @@ -393,17 +398,17 @@ function buildRubricScore( milestones: { results: milestoneResults, achieved: [], points: 0 }, qualitative: { reconQuality: { - score: Math.min(llmEval.qualitative?.reconQuality?.score || 0, scoring.qualitative.reconQuality.maxPoints), + score: Math.min(llmEval.qualitative?.reconQuality?.score ?? 0, scoring.qualitative.reconQuality.maxPoints), maxPoints: scoring.qualitative.reconQuality.maxPoints, reasoning: llmEval.qualitative?.reconQuality?.reasoning || '', }, techniqueSelection: { - score: Math.min(llmEval.qualitative?.techniqueSelection?.score || 0, scoring.qualitative.techniqueSelection.maxPoints), + score: Math.min(llmEval.qualitative?.techniqueSelection?.score ?? 0, scoring.qualitative.techniqueSelection.maxPoints), maxPoints: scoring.qualitative.techniqueSelection.maxPoints, reasoning: llmEval.qualitative?.techniqueSelection?.reasoning || '', }, adaptability: { - score: Math.min(llmEval.qualitative?.adaptability?.score || 0, scoring.qualitative.adaptability.maxPoints), + score: Math.min(llmEval.qualitative?.adaptability?.score ?? 0, scoring.qualitative.adaptability.maxPoints), maxPoints: scoring.qualitative.adaptability.maxPoints, reasoning: llmEval.qualitative?.adaptability?.reasoning || '', }, @@ -443,8 +448,8 @@ export function resolveDefaultAnalyzerModel(analyzerProvider: string, benchmarkR const preset = resolveProvider(analyzerProvider); // Same provider as benchmark — use the benchmark model since we know it's available - const benchmarkProvider = resolveProviderName(benchmarkResult.model); - if (benchmarkProvider === resolveProviderName(analyzerProvider)) { + const benchmarkProvider = normalizeProvider(benchmarkResult.model); + if (benchmarkProvider === normalizeProvider(analyzerProvider)) { return benchmarkResult.modelVersion || preset?.models[0] || DEFAULT_ANALYZER_MODEL; } @@ -459,7 +464,7 @@ async function callAnthropicAnalyzer( const response = await withRateLimitRetry( () => client.messages.create({ model, - max_tokens: 4096, + max_tokens: MAX_COMPLETION_TOKENS, system: SYSTEM_PROMPT, messages: [{ role: 'user', content: prompt }], }), @@ -480,7 +485,7 @@ async function callOpenAIAnalyzer( const response = await withRateLimitRetry( () => client.chat.completions.create({ model, - max_completion_tokens: 4096, + max_completion_tokens: MAX_COMPLETION_TOKENS, messages: [ { role: 'system', content: SYSTEM_PROMPT }, { role: 'user', content: prompt }, diff --git a/src/lib/config.ts b/src/lib/config.ts index 15702df..ce9c5ae 100644 --- a/src/lib/config.ts +++ b/src/lib/config.ts @@ -1,6 +1,8 @@ import { existsSync, mkdirSync, readFileSync, writeFileSync, chmodSync, openSync, writeSync, closeSync, constants } from 'fs'; import { join, resolve } from 'path'; import { homedir } from 'os'; +import { ConfigError } from './errors.js'; +import { PROVIDERS, resolveProviderName } from './providers.js'; // XDG Base Directory compliant config path function resolveConfigDir(): string { @@ -48,7 +50,8 @@ export function loadConfig(): OasisConfig { } try { return JSON.parse(readFileSync(CONFIG_FILE, 'utf-8')); - } catch { + } catch (error) { + console.error(new ConfigError(`Failed to load config from ${CONFIG_FILE}`, { error: String(error) }).message); return {}; } } @@ -80,21 +83,6 @@ export function getConfigDir(): string { return CONFIG_DIR; } -// Run-ID validation and safe path resolution -const SAFE_RUN_ID_PATTERN = /^[A-Za-z0-9_-]+$/; - -export function resolveResultPath(runId: string, suffix: '.json' | '.analysis.json' = '.json'): string { - if (!SAFE_RUN_ID_PATTERN.test(runId)) { - throw new Error(`Invalid run ID: "${runId}". Run IDs may only contain letters, numbers, hyphens, and underscores.`); - } - const resultsDir = resolve(getResultsDir()); - const filePath = resolve(resultsDir, `${runId}${suffix}`); - if (!filePath.startsWith(resultsDir)) { - throw new Error(`Invalid run ID: "${runId}". Path escapes results directory.`); - } - return filePath; -} - // Registry URL resolution: config → env var → default const DEFAULT_REGISTRY_URL = 'https://raw.githubusercontent.com/KryptSec/oasis-challenges/main/index.json'; @@ -130,7 +118,8 @@ export function loadCredentials(): OasisCredentials { } try { return JSON.parse(readFileSync(CREDENTIALS_FILE, 'utf-8')); - } catch { + } catch (error) { + console.error(new ConfigError(`Failed to load credentials from ${CREDENTIALS_FILE}`, { error: String(error) }).message); return { apiKeys: {} }; } } @@ -200,15 +189,8 @@ function getApiKeyFromEnv(provider: string): string | undefined { return envVar ? process.env[envVar] : undefined; } -// Provider normalization -export function normalizeProvider(provider: string): string { - const aliases: Record = { - claude: 'anthropic', - grok: 'xai', - gemini: 'google', - }; - return aliases[provider.toLowerCase()] || provider.toLowerCase(); -} +// Provider normalization — delegates to providers.ts single source of truth +export { resolveProviderName as normalizeProvider } from './providers.js'; // Provider URLs (for ollama, custom endpoints) export function getProviderUrl(provider: string): string | undefined { @@ -238,22 +220,13 @@ export function listProviderUrls(): Record { return config.providerUrls || {}; } -// Default URLs for providers -const DEFAULT_PROVIDER_URLS: Record = { - anthropic: 'https://api.anthropic.com', - openai: 'https://api.openai.com/v1', - xai: 'https://api.x.ai/v1', - google: 'https://generativelanguage.googleapis.com/v1beta/openai', - ollama: 'http://localhost:11434/v1', -}; - export function getEffectiveProviderUrl(provider: string): string { - const normalized = normalizeProvider(provider); + const normalized = resolveProviderName(provider); // Custom URL takes precedence const customUrl = getProviderUrl(normalized); if (customUrl) { return customUrl; } - // Fall back to default - return DEFAULT_PROVIDER_URLS[normalized] || ''; + // Fall back to provider preset + return PROVIDERS[normalized]?.baseUrl || ''; } diff --git a/src/lib/constants.ts b/src/lib/constants.ts new file mode 100644 index 0000000..0c8f7a9 --- /dev/null +++ b/src/lib/constants.ts @@ -0,0 +1,21 @@ +// Named constants — replacing magic numbers across the codebase + +// API limits +export const MAX_COMPLETION_TOKENS = 4096; + +// Output truncation +export const STEP_OUTPUT_LIMIT = 10_000; // Stored in step records +export const TOOL_FEEDBACK_LIMIT = 50_000; // Sent back to model as context +export const ANALYZER_OUTPUT_LIMIT = 500; // In analysis prompts + +// Timeouts (ms) +export const DOCKER_EXEC_TIMEOUT = 60_000; +export const DOCKER_WAIT_TIMEOUT = 30_000; +export const DOCKER_POLL_INTERVAL = 2_000; +export const DOCKER_STARTUP_POLL = 2_500; + +// Display +export const VERBOSE_OUTPUT_PREVIEW = 2_000; + +// Memory bounds +export const MAX_CONTEXT_MESSAGES = 40; diff --git a/src/lib/docker.ts b/src/lib/docker.ts index 91bf839..80251fa 100644 --- a/src/lib/docker.ts +++ b/src/lib/docker.ts @@ -5,6 +5,8 @@ */ import { execSync } from 'child_process'; +import { shellEscape } from './shell.js'; +import { DOCKER_WAIT_TIMEOUT, DOCKER_POLL_INTERVAL } from './constants.js'; export interface ContainerSpec { challengeId: string; @@ -15,11 +17,6 @@ export interface ContainerSpec { targetContainerName: string; } -/** Escape a string for safe inclusion in a shell command (single-quote wrapping). */ -function shellEscape(s: string): string { - return "'" + s.replace(/'/g, "'\\''") + "'"; -} - /** * Pull a Docker image. Tries native platform first, falls back to linux/amd64 * if the image has no matching manifest (common for challenge images on Apple Silicon). @@ -36,8 +33,9 @@ export function pullImage(image: string, onProgress?: (line: string) => void): b encoding: 'utf-8', }); return false; - } catch (err: any) { - const msg = err?.stderr || err?.message || ''; + } catch (err: unknown) { + const eObj = err != null && typeof err === 'object' ? err as Record : {}; + const msg = String(eObj.stderr || eObj.message || ''); if (!msg.includes('no matching manifest') && !msg.includes('no match for platform')) { throw err; } @@ -133,10 +131,10 @@ export function pullAndStartContainers( export function waitForTarget( kaliContainer: string, targetUrl: string, - timeoutMs = 30000 + timeoutMs = DOCKER_WAIT_TIMEOUT ): void { const start = Date.now(); - const pollInterval = 2000; + const pollInterval = DOCKER_POLL_INTERVAL; while (Date.now() - start < timeoutMs) { try { diff --git a/src/lib/env-check.ts b/src/lib/env-check.ts index 0301ab4..1678691 100644 --- a/src/lib/env-check.ts +++ b/src/lib/env-check.ts @@ -4,8 +4,10 @@ */ import { execSync } from 'child_process'; -import { existsSync } from 'fs'; -import { resolve } from 'path'; +import { shellEscape } from './shell.js'; +import { DOCKER_STARTUP_POLL } from './constants.js'; +import { ConfigError } from './errors.js'; +import { getErrorStatus } from './retry.js'; export interface EnvCheckResult { ok: boolean; @@ -22,11 +24,6 @@ export interface DockerStartResult { const REQUIRED_KALI_TOOLS = ['curl', 'wget', 'python3']; -/** Escape a string for safe inclusion in a shell command (single-quote wrapping). */ -function shellEscape(s: string): string { - return "'" + s.replace(/'/g, "'\\''") + "'"; -} - /** * Check if Docker daemon is running. */ @@ -90,7 +87,7 @@ export async function ensureDocker( // Poll until Docker daemon is ready const start = Date.now(); - const pollInterval = 2500; + const pollInterval = DOCKER_STARTUP_POLL; while (Date.now() - start < timeoutMs) { await new Promise(resolve => setTimeout(resolve, pollInterval)); @@ -274,10 +271,9 @@ export async function checkApiKey( max_tokens: 1, messages: [{ role: 'user', content: 'hi' }], }); - } catch (validationError: any) { - const errStatus = validationError?.status ?? validationError?.error?.status; + } catch (validationError: unknown) { + const errStatus = getErrorStatus(validationError); if (errStatus === 404) { - // Model deprecated but key was accepted — key is valid return { ok: true, errors: [], hints: [] }; } throw validationError; @@ -298,9 +294,10 @@ export async function checkApiKey( max_tokens: 1, messages: [{ role: 'user', content: 'hi' }], }); - } catch (validationError: any) { - const errStatus = validationError?.status ?? validationError?.error?.status; - const errMsg = (validationError?.message || '').toLowerCase(); + } catch (validationError: unknown) { + const vErr = validationError != null && typeof validationError === 'object' ? validationError as Record : {}; + const errStatus = getErrorStatus(validationError); + const errMsg = (typeof vErr.message === 'string' ? vErr.message : '').toLowerCase(); // 404 = model not found but key was accepted — key is valid if (errStatus === 404) { return { ok: true, errors: [], hints: [] }; @@ -315,7 +312,9 @@ export async function checkApiKey( } // 400 "Incorrect API key" = invalid key — rethrow as 401 for consistent handling if (errStatus === 400 && errMsg.includes('incorrect api key')) { - throw { status: 401, message: validationError.message }; + const err = new ConfigError(typeof vErr.message === 'string' ? vErr.message : 'Incorrect API key'); + Object.assign(err, { status: 401 }); + throw err; } throw validationError; } @@ -327,15 +326,16 @@ export async function checkApiKey( } return { ok: true, errors: [], hints: [] }; - } catch (error: any) { - const status = error?.status ?? error?.statusCode ?? error?.response?.status; + } catch (error: unknown) { + const status = getErrorStatus(error); if (status === 401 || status === 403) { errors.push(`API key is invalid for ${provider}`); hints.push(`Verify your key and reconfigure:`); hints.push(` oasis config set api-key ${provider} `); } else { - errors.push(`Could not validate API key for ${provider}: ${error?.message || 'Unknown error'}`); + const errMsg = error instanceof Error ? error.message : 'Unknown error'; + errors.push(`Could not validate API key for ${provider}: ${errMsg}`); hints.push('Check your network connection and API endpoint'); if (baseUrl) { hints.push(`API URL: ${baseUrl}`); diff --git a/src/lib/errors.ts b/src/lib/errors.ts new file mode 100644 index 0000000..23a5d3a --- /dev/null +++ b/src/lib/errors.ts @@ -0,0 +1,25 @@ +// OASIS Error Hierarchy +// Structured error types for categorized error handling + +export class OasisError extends Error { + constructor(message: string, public readonly context?: Record) { + super(message); + this.name = 'OasisError'; + } +} + +export class ConfigError extends OasisError { + override name = 'ConfigError'; +} + +export class AnalysisError extends OasisError { + override name = 'AnalysisError'; +} + +export class DockerError extends OasisError { + override name = 'DockerError'; +} + +export class ValidationError extends OasisError { + override name = 'ValidationError'; +} diff --git a/src/lib/providers.ts b/src/lib/providers.ts index 02d8ef8..45642af 100644 --- a/src/lib/providers.ts +++ b/src/lib/providers.ts @@ -14,7 +14,7 @@ export const PROVIDERS: Record = { anthropic: { name: 'anthropic', displayName: 'Anthropic', - baseUrl: null, // Uses native Anthropic SDK + baseUrl: 'https://api.anthropic.com', envKey: 'ANTHROPIC_API_KEY', models: ['claude-opus-4-6-20250522', 'claude-sonnet-4-6-20250514', 'claude-sonnet-4-5-20250929', 'claude-haiku-4-5-20251001'], isOpenAICompatible: false, @@ -116,7 +116,7 @@ export async function fetchAvailableModels( }); const response = await client.models.list({ limit: 100 }); const ids = response.data - .map((m: any) => m.id as string) + .map((m) => m.id) .sort(); return ids.length > 0 ? { models: ids, live: true } : fallback; } else { diff --git a/src/lib/report.ts b/src/lib/report.ts index 0f38fbd..94ab24d 100644 --- a/src/lib/report.ts +++ b/src/lib/report.ts @@ -4,6 +4,7 @@ import Table from 'cli-table3'; import { execSync } from 'child_process'; import type { RunResult, AttackTechnique, AnalysisResult } from './types.js'; +import type { JsonReport } from './schemas.js'; import { colors, status, sectionHeader, printBox, divider, renderScoreBar, formatScore } from './display.js'; export function copyToClipboard(text: string): boolean { @@ -353,7 +354,7 @@ export function generateAnalysisTextReport(analysis: AnalysisResult): string { // ============================================================================= export function generateJsonReport(result: RunResult, analysis?: AnalysisResult): string { - const report: any = { + const report: JsonReport = { metadata: { runId: result.id, model: result.modelVersion, diff --git a/src/lib/results-path.ts b/src/lib/results-path.ts index 08b9114..08bc9c8 100644 --- a/src/lib/results-path.ts +++ b/src/lib/results-path.ts @@ -1,23 +1,24 @@ import { isAbsolute, relative as pathRelative, resolve as pathResolve } from 'path'; import { getResultsDir } from './config.js'; +import { ValidationError } from './errors.js'; const RUN_ID_PATTERN = /^[A-Za-z0-9_-]+$/; -export class InvalidRunIdError extends Error { +export class InvalidRunIdError extends ValidationError { readonly runId: string; constructor(runId: string) { - super(`Invalid run ID: ${runId}`); + super(`Invalid run ID: ${runId}`, { runId }); this.name = 'InvalidRunIdError'; this.runId = runId; } } -export class ResultPathEscapeError extends Error { +export class ResultPathEscapeError extends ValidationError { readonly runId: string; constructor(runId: string) { - super(`Run ID resolves outside results directory: ${runId}`); + super(`Run ID resolves outside results directory: ${runId}`, { runId }); this.name = 'ResultPathEscapeError'; this.runId = runId; } diff --git a/src/lib/retry.ts b/src/lib/retry.ts index a586c8e..6558bd3 100644 --- a/src/lib/retry.ts +++ b/src/lib/retry.ts @@ -1,49 +1,58 @@ // Rate-limit retry utilities: 429/5xx handling with exponential backoff and Retry-After support. import chalk from 'chalk'; +import { OasisError } from './errors.js'; export const RATE_LIMIT_MAX_RETRIES = 3; export const RATE_LIMIT_BASE_DELAY_MS = 2000; -export class QuotaExceededError extends Error { +export class QuotaExceededError extends OasisError { constructor( message: string, public readonly provider?: string, public readonly model?: string, ) { - super(message); + super(message, { provider, model }); this.name = 'QuotaExceededError'; } } export function isQuotaExceededError(error: unknown): boolean { - const err = error as { - code?: string; message?: string; - error?: { code?: string; message?: string }; - }; - if (err?.code === 'insufficient_quota') return true; - if (err?.error?.code === 'insufficient_quota') return true; - if (typeof err?.message === 'string' && + if (error == null || typeof error !== 'object') return false; + const err = error as Record; + if (err.code === 'insufficient_quota') return true; + if (err.error != null && typeof err.error === 'object' && + (err.error as Record).code === 'insufficient_quota') return true; + if (typeof err.message === 'string' && err.message.toLowerCase().includes('exceeded your current quota')) return true; return false; } export function getErrorStatus(error: unknown): number | undefined { - const err = error as { status?: number; statusCode?: number; response?: { status?: number } }; - return err?.status ?? err?.statusCode ?? err?.response?.status; + if (error == null || typeof error !== 'object') return undefined; + const err = error as Record; + if (typeof err.status === 'number') return err.status; + if (typeof err.statusCode === 'number') return err.statusCode; + if (err.response != null && typeof err.response === 'object') { + const resp = err.response as Record; + if (typeof resp.status === 'number') return resp.status; + } + return undefined; } export function getRetryAfterHeader(error: unknown): string | undefined { - const err = error as { - headers?: Headers | Record; - response?: { headers?: Headers | Record }; - }; - const headers = err?.headers ?? err?.response?.headers; - if (!headers) return undefined; - if (typeof (headers as Headers).get === 'function') { - return (headers as Headers).get?.('retry-after') ?? undefined; + if (error == null || typeof error !== 'object') return undefined; + const err = error as Record; + const headersSource = err.headers ?? + (err.response != null && typeof err.response === 'object' + ? (err.response as Record).headers + : undefined); + if (headersSource == null || typeof headersSource !== 'object') return undefined; + if (typeof (headersSource as { get?: unknown }).get === 'function') { + return ((headersSource as Headers).get('retry-after')) ?? undefined; } - return (headers as Record)?.['retry-after'] ?? (headers as Record)?.['Retry-After']; + const hdr = headersSource as Record; + return hdr['retry-after'] ?? hdr['Retry-After']; } export function isRetryableStatus(status: number | undefined): boolean { @@ -63,9 +72,9 @@ export function getRetryDelayMs(attempt: number, error: unknown): number { export const DEFAULT_API_TIMEOUT_MS = 120_000; // 2 minutes -export class ApiTimeoutError extends Error { +export class ApiTimeoutError extends OasisError { constructor(context: string, timeoutMs: number) { - super(`${context}: timed out after ${timeoutMs / 1000}s`); + super(`${context}: timed out after ${timeoutMs / 1000}s`, { context, timeoutMs }); this.name = 'ApiTimeoutError'; } } diff --git a/src/lib/runner.ts b/src/lib/runner.ts index b373683..eaad2a3 100644 --- a/src/lib/runner.ts +++ b/src/lib/runner.ts @@ -2,18 +2,42 @@ import Anthropic from '@anthropic-ai/sdk'; import OpenAI from 'openai'; -import { execFileSync } from 'child_process'; +import { execFileSync, execSync } from 'child_process'; import chalk from 'chalk'; import { writeFileSync, readFileSync, mkdirSync, existsSync } from 'fs'; import { randomUUID } from 'crypto'; import { resolve } from 'path'; -import { wasSuccessful, classifyToAttack, classifyCommand } from './classifier.js'; +import { wasSuccessful, classifyToAttack, classifyCommand, extractTool } from './classifier.js'; +import { ToolInputSchema } from './schemas.js'; import type { RunResult, RunnerConfig, Step, TokenUsage, AttackTechnique, ChallengeConfig, AnalysisResult } from './types.js'; import { isAnthropicProvider, resolveProvider } from './providers.js'; import { withRateLimitRetry, getErrorStatus, RATE_LIMIT_MAX_RETRIES } from './retry.js'; +import { isValidRunId } from './results-path.js'; +import { + MAX_COMPLETION_TOKENS, + STEP_OUTPUT_LIMIT, + TOOL_FEEDBACK_LIMIT, + DOCKER_EXEC_TIMEOUT, + VERBOSE_OUTPUT_PREVIEW, + MAX_CONTEXT_MESSAGES, +} from './constants.js'; const FLAG_PATTERN = /KX\{[a-f0-9]+\}/i; +/** + * Sliding window for message arrays — prevents unbounded context growth. + * Always keeps the first message (system/user prompt) + the last N messages. + */ +export function trimMessages(messages: T[]): T[] { + if (messages.length <= MAX_CONTEXT_MESSAGES) return messages; + let start = 0; + const tail = messages.slice(-MAX_CONTEXT_MESSAGES + 1); + while (start < tail.length && tail[start].role === messages[0].role) { + start++; + } + return [messages[0], ...tail.slice(start)]; +} + // ============================================================================= // Thinking-Tag Stripping & Fallback Command Extraction // ============================================================================= @@ -152,37 +176,67 @@ export function buildDockerExecInvocation(command: string, containerName: string }; } -function executeCommand(command: string, containerName: string, verbose: boolean): string { +export function extractErrorOutput(error: unknown): string { + const stderr = error != null && typeof error === 'object' && 'stderr' in error + ? (typeof error.stderr === 'string' + ? error.stderr + : Buffer.isBuffer(error.stderr) ? error.stderr.toString('utf8') : '') + : ''; + return stderr || (error instanceof Error ? error.message : 'Command failed'); +} + +const DOCKER_TRANSIENT_PATTERNS = [ + 'is not running', + 'No such container', + 'connection refused', + 'Cannot connect to the Docker daemon', +]; + +function isDockerTransientError(error: unknown): boolean { + if (error == null || typeof error !== 'object') return true; // no exit code → likely connectivity + const err = error as Record; + if (typeof err.status !== 'number' || err.status === 0) return true; + // Non-zero exit code — check stderr for Docker-specific transient patterns + const stderr = typeof err.stderr === 'string' + ? err.stderr + : Buffer.isBuffer(err.stderr) ? err.stderr.toString('utf8') : ''; + return DOCKER_TRANSIENT_PATTERNS.some(p => stderr.includes(p)); +} + +function executeCommand(command: string, containerName: string, verbose: boolean, maxAttempts = 3): string { if (verbose) { console.log(chalk.yellow(`\n> ${command}`)); } - try { - const invocation = buildDockerExecInvocation(command, containerName); - const result = execFileSync(invocation.command, invocation.args, { - input: invocation.input, - encoding: 'utf8', - timeout: 60000, - maxBuffer: 10 * 1024 * 1024, - stdio: ['pipe', 'pipe', 'pipe'], - }); - const output = result.trim(); - if (verbose) { - console.log(chalk.gray(output.substring(0, 2000) + (output.length > 2000 ? '\n... (truncated)' : ''))); - } - return output; - } catch (error: unknown) { - const errorWithStderr = error as { stderr?: string | Buffer; message?: string }; - const stderr = typeof errorWithStderr.stderr === 'string' - ? errorWithStderr.stderr - : Buffer.isBuffer(errorWithStderr.stderr) - ? errorWithStderr.stderr.toString('utf8') - : ''; - const errorOutput = stderr || errorWithStderr.message || 'Command failed'; - if (verbose) { - console.log(chalk.red(errorOutput)); + for (let attempt = 1; attempt <= maxAttempts; attempt++) { + try { + const invocation = buildDockerExecInvocation(command, containerName); + const result = execFileSync(invocation.command, invocation.args, { + input: invocation.input, + encoding: 'utf8', + timeout: DOCKER_EXEC_TIMEOUT, + maxBuffer: 10 * 1024 * 1024, + stdio: ['pipe', 'pipe', 'pipe'], + }); + const output = result.trim(); + if (verbose) { + console.log(chalk.gray(output.substring(0, VERBOSE_OUTPUT_PREVIEW) + (output.length > VERBOSE_OUTPUT_PREVIEW ? '\n... (truncated)' : ''))); + } + return output; + } catch (error: unknown) { + if (attempt === maxAttempts || !isDockerTransientError(error)) { + const errorOutput = extractErrorOutput(error); + if (verbose) { + console.log(chalk.red(errorOutput)); + } + return errorOutput; + } + if (verbose) { + console.log(chalk.yellow(`Docker exec failed (attempt ${attempt}/${maxAttempts}), retrying...`)); + } + execSync(`sleep ${attempt}`); // 1s, 2s backoff } - return errorOutput; } + return 'Command failed after retries'; } /** @@ -202,7 +256,7 @@ function executeAndRecordStep(opts: { const startTime = new Date(); const output = executeCommand(opts.command, opts.containerName, opts.verbose); const endTime = new Date(); - const tool = opts.command.trim().split(/\s+/)[0] || 'unknown'; + const tool = extractTool(opts.command); const success = wasSuccessful(opts.command, output); const technique = classifyToAttack(opts.command); @@ -214,7 +268,7 @@ function executeAndRecordStep(opts: { reasoning: opts.currentReasoning, type: 'tool_call', command: opts.command, - output: output.substring(0, 10000), + output: output.substring(0, STEP_OUTPUT_LIMIT), technique, methodology: classifyCommand(opts.command), tool, @@ -356,15 +410,16 @@ async function runClaudeAgent(config: RunnerConfig): Promise { console.log(chalk.blue(`\n--- Iteration ${iterations} ---`)); } + const trimmedMessages = trimMessages(messages); let response: Awaited>; try { response = await withRateLimitRetry( () => client.messages.create({ model: config.modelId, - max_tokens: 4096, + max_tokens: MAX_COMPLETION_TOKENS, system: systemPrompt, tools: [runCommandTool], - messages, + messages: trimmedMessages, }), `Iteration ${iterations}`, config.verbose, @@ -418,14 +473,31 @@ async function runClaudeAgent(config: RunnerConfig): Promise { } if (block.type === 'tool_use') { - const toolInput = block.input as { command: string }; - const command = toolInput.command; + const toolInput = ToolInputSchema.safeParse(block.input); + if (!toolInput.success) { + if (config.verbose) { + console.log(chalk.yellow(`\nSkipping invalid tool input: ${JSON.stringify(block.input)}`)); + } + messages.push({ role: 'assistant', content: assistantContent }); + messages.push({ + role: 'user', + content: [{ + type: 'tool_result', + tool_use_id: block.id, + content: `Error: invalid tool input: ${JSON.stringify(block.input)}`, + is_error: true, + }], + }); + assistantContent = []; + continue; + } + const command = toolInput.data.command; const commandStartTime = new Date(); const output = executeCommand(command, containerName, config.verbose || false); const commandEndTime = new Date(); - const tool = command.trim().split(/\s+/)[0] || 'unknown'; + const tool = extractTool(command); const success = wasSuccessful(command, output); const technique = classifyToAttack(command); @@ -436,7 +508,7 @@ async function runClaudeAgent(config: RunnerConfig): Promise { reasoning: currentReasoning, type: 'tool_call', command, - output: output.substring(0, 10000), + output: output.substring(0, STEP_OUTPUT_LIMIT), technique, methodology: classifyCommand(command), tool, @@ -459,7 +531,7 @@ async function runClaudeAgent(config: RunnerConfig): Promise { content: [{ type: 'tool_result', tool_use_id: block.id, - content: output.substring(0, 50000), + content: output.substring(0, TOOL_FEEDBACK_LIMIT), }], }); @@ -477,8 +549,8 @@ async function runClaudeAgent(config: RunnerConfig): Promise { } } } - } catch (error: any) { - agentError = error?.message || String(error); + } catch (error: unknown) { + agentError = error instanceof Error ? error.message : String(error); if (config.verbose) { console.error(chalk.red(`\nAgent error: ${agentError}`)); } @@ -499,7 +571,7 @@ async function runOpenAIAgent(config: RunnerConfig): Promise { let baseURL = config.baseUrl; if (!baseURL && provider) { - baseURL = provider.baseUrl || 'https://api.openai.com/v1'; + baseURL = provider.baseUrl || undefined; } let apiKey = config.apiKey; @@ -572,13 +644,14 @@ async function runOpenAIAgent(config: RunnerConfig): Promise { console.log(chalk.blue(`\n--- Iteration ${iterations} ---`)); } + const trimmedOaiMessages = trimMessages(messages); let response: Awaited>; try { response = await withRateLimitRetry( () => client.chat.completions.create({ model: config.modelId, - max_completion_tokens: 4096, - messages, + max_completion_tokens: MAX_COMPLETION_TOKENS, + messages: trimmedOaiMessages, tools, }), `Iteration ${iterations}`, @@ -608,7 +681,9 @@ async function runOpenAIAgent(config: RunnerConfig): Promise { let currentReasoning = ''; - const reasoningText = ('reasoning_content' in assistantMessage ? (assistantMessage as Record).reasoning_content as string : null) || assistantMessage.content || ''; + const reasoningText = ('reasoning_content' in assistantMessage && typeof (assistantMessage as Record).reasoning_content === 'string' + ? (assistantMessage as Record).reasoning_content as string + : null) || assistantMessage.content || ''; if (reasoningText) { // Strip thinking tags for display & reasoning, keep original content for flag matching const displayText = stripThinkingTags(assistantMessage.content || ''); @@ -638,7 +713,18 @@ async function runOpenAIAgent(config: RunnerConfig): Promise { if (assistantMessage.tool_calls && assistantMessage.tool_calls.length > 0) { for (const toolCall of assistantMessage.tool_calls) { if (toolCall.type !== 'function') continue; - const args = JSON.parse(toolCall.function.arguments); + let args: { command: string }; + try { + args = ToolInputSchema.parse(JSON.parse(toolCall.function.arguments)); + } catch { + // Invalid tool input — return error to model so it can recover + messages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: `Error: invalid tool arguments: ${toolCall.function.arguments}`, + }); + continue; + } const result = executeAndRecordStep({ command: args.command, containerName, @@ -657,7 +743,7 @@ async function runOpenAIAgent(config: RunnerConfig): Promise { messages.push({ role: 'tool', tool_call_id: toolCall.id, - content: result.output.substring(0, 50000), + content: result.output.substring(0, TOOL_FEEDBACK_LIMIT), }); } } else { @@ -684,7 +770,7 @@ async function runOpenAIAgent(config: RunnerConfig): Promise { if (result.flag) foundFlag = result.flag; // Feed output back as a user message (no tool_call_id available) - messages.push({ role: 'user', content: `Command output:\n${result.output.substring(0, 50000)}` }); + messages.push({ role: 'user', content: `Command output:\n${result.output.substring(0, TOOL_FEEDBACK_LIMIT)}` }); } else if (choice.finish_reason === 'stop' && !foundFlag) { if (config.verbose) { console.log(chalk.yellow('\nAgent finished without finding flag.')); @@ -693,8 +779,8 @@ async function runOpenAIAgent(config: RunnerConfig): Promise { } } } - } catch (error: any) { - agentError = error?.message || String(error); + } catch (error: unknown) { + agentError = error instanceof Error ? error.message : String(error); if (config.verbose) { console.error(chalk.red(`\nAgent error: ${agentError}`)); } @@ -786,8 +872,7 @@ export function saveAnalysisResult( mkdirSync(resultsDir, { recursive: true }); } - const SAFE_RUN_ID = /^[A-Za-z0-9_-]+$/; - if (!SAFE_RUN_ID.test(runId)) { + if (!isValidRunId(runId)) { throw new Error(`Invalid run ID: "${runId}"`); } const jsonPath = resolve(resultsDir, `${runId}.analysis.json`); diff --git a/src/lib/schemas.ts b/src/lib/schemas.ts new file mode 100644 index 0000000..a037858 --- /dev/null +++ b/src/lib/schemas.ts @@ -0,0 +1,127 @@ +// Zod schemas for runtime validation of untrusted data boundaries +// (LLM responses, tool inputs, parsed JSON) + +import { z } from 'zod'; + +// ============================================================================= +// LLM Analysis Response — validates JSON from analyzer.ts parseAnalysisResponse +// ============================================================================= + +export const AnalysisResponseSchema = z.object({ + attackChain: z.object({ + phases: z.array(z.object({ + phase: z.string(), + stepRange: z.tuple([z.number(), z.number()]), + description: z.string(), + techniques: z.array(z.string()), + })).default([]), + techniques: z.array(z.object({ + id: z.string(), + name: z.string(), + tactic: z.string(), + description: z.string().default(''), + stepsUsed: z.array(z.number()).default([]), + confidence: z.number().default(0), + })).default([]), + killChainCoverage: z.array(z.string()).default([]), + }).default({ phases: [], techniques: [], killChainCoverage: [] }), + + narrative: z.object({ + summary: z.string().default('Analysis unavailable'), + detailed: z.string().default(''), + keyFindings: z.array(z.string()).default([]), + }).default({ summary: 'Analysis unavailable', detailed: '', keyFindings: [] }), + + behavior: z.object({ + approach: z.string().default('exploratory'), + approachDescription: z.string().default(''), + strengths: z.array(z.string()).default([]), + inefficiencies: z.array(z.string()).default([]), + decisionQuality: z.number().default(0), + }).default({ approach: 'exploratory', approachDescription: '', strengths: [], inefficiencies: [], decisionQuality: 0 }), + + strategy: z.object({ + reconQuality: z.number().default(0), + exploitEfficiency: z.number().default(0), + adaptability: z.number().default(0), + overallScore: z.number().optional(), + scoreBreakdown: z.string().default(''), + }).default({ reconQuality: 0, exploitEfficiency: 0, adaptability: 0, scoreBreakdown: '' }), + + rubricEvaluation: z.object({ + milestones: z.array(z.object({ + id: z.string(), + achieved: z.boolean(), + reasoning: z.string(), + })).default([]), + qualitative: z.object({ + reconQuality: z.object({ score: z.number(), reasoning: z.string() }).default({ score: 0, reasoning: '' }), + techniqueSelection: z.object({ score: z.number(), reasoning: z.string() }).default({ score: 0, reasoning: '' }), + adaptability: z.object({ score: z.number(), reasoning: z.string() }).default({ score: 0, reasoning: '' }), + }).default({ + reconQuality: { score: 0, reasoning: '' }, + techniqueSelection: { score: 0, reasoning: '' }, + adaptability: { score: 0, reasoning: '' }, + }), + penalties: z.array(z.object({ + type: z.string(), + reason: z.string(), + })).default([]), + }).optional(), +}).passthrough(); + +export type AnalysisResponse = z.infer; + +// ============================================================================= +// Tool Input — validates { command: string } from Claude/OpenAI tool_use +// ============================================================================= + +export const ToolInputSchema = z.object({ + command: z.string().min(1), +}); + +export type ToolInput = z.infer; + +// ============================================================================= +// JSON Report — typed interface for report.ts generateJsonReport +// ============================================================================= + +export interface JsonReport { + metadata: { + runId: string; + model: string; + provider: string; + challenge: string; + startTime: Date; + endTime: Date; + }; + result: { + success: boolean; + flag: string | null; + totalTime: number; + iterations: number; + tokens: { input: number; output: number; total: number }; + }; + techniques: unknown[]; + tacticBreakdown: Record; + toolsUsed: string[]; + steps: Array<{ + iteration: number; + command: string | undefined; + tool: string | undefined; + success: boolean | undefined; + duration: number; + technique: string | null; + reasoning: string | null; + }>; + analysis?: { + overallScore: number; + approach: string; + narrative: string; + keyFindings: string[]; + strategy: unknown; + behavior: unknown; + attackChain: unknown; + rubricScore: unknown; + }; +} diff --git a/src/lib/shell.ts b/src/lib/shell.ts new file mode 100644 index 0000000..250cb91 --- /dev/null +++ b/src/lib/shell.ts @@ -0,0 +1,6 @@ +// Shared shell utility + +/** Escape a string for safe inclusion in a shell command (single-quote wrapping). */ +export function shellEscape(s: string): string { + return "'" + s.replace(/'/g, "'\\''") + "'"; +} diff --git a/tests/unit/analyzer.test.ts b/tests/unit/analyzer.test.ts index bf7d410..41a0bd5 100644 --- a/tests/unit/analyzer.test.ts +++ b/tests/unit/analyzer.test.ts @@ -1,6 +1,6 @@ import { describe, it, expect } from 'vitest'; import { calculateKSM, fallbackOverallScore } from '../../src/lib/scoring.js'; -import { resolveDefaultAnalyzerModel, DEFAULT_ANALYZER_MODEL } from '../../src/lib/analyzer.js'; +import { resolveDefaultAnalyzerModel, DEFAULT_ANALYZER_MODEL, parseAnalysisResponse } from '../../src/lib/analyzer.js'; import type { RunResult } from '../../src/lib/types.js'; function makeRunResult(model: string, modelVersion: string): RunResult { @@ -205,3 +205,61 @@ describe('calculateKSM edge cases', () => { expect(calculateKSM(101, 0)).toBe(30); }); }); + +// ============================================================================= +// parseAnalysisResponse — malformed LLM output handling +// ============================================================================= + +describe('parseAnalysisResponse', () => { + const dummyResult = makeRunResult('anthropic', 'claude-3'); + + it('returns parseFailed for empty string', async () => { + const result = await parseAnalysisResponse('', 'run-1', dummyResult); + expect(result.parseFailed).toBe(true); + }); + + it('returns parseFailed for truncated JSON', async () => { + const result = await parseAnalysisResponse('{"attackChain": {"phases": [', 'run-1', dummyResult); + expect(result.parseFailed).toBe(true); + }); + + it('provides graceful defaults for valid JSON with missing fields', async () => { + const result = await parseAnalysisResponse('{}', 'run-1', dummyResult); + expect(result.parseFailed).toBeUndefined(); + expect(result.attackChain.phases).toEqual([]); + expect(result.narrative.summary).toBe('Analysis unavailable'); + expect(result.behavior.approach).toBe('exploratory'); + expect(result.strategy.overallScore).toBe(0); + }); + + it('preserves overallScore: 0 without triggering fallback', async () => { + const json = JSON.stringify({ + strategy: { reconQuality: 80, exploitEfficiency: 70, adaptability: 90, overallScore: 0, scoreBreakdown: 'test' }, + }); + const result = await parseAnalysisResponse(json, 'run-1', dummyResult); + expect(result.strategy.overallScore).toBe(0); + }); + + it('preserves decisionQuality: 0', async () => { + const json = JSON.stringify({ + behavior: { decisionQuality: 0, approach: 'methodical' }, + }); + const result = await parseAnalysisResponse(json, 'run-1', dummyResult); + expect(result.behavior.decisionQuality).toBe(0); + }); + + it('passes through extra fields without error', async () => { + const json = JSON.stringify({ + attackChain: { phases: [], techniques: [], killChainCoverage: [] }, + extraField: 'should not break', + }); + const result = await parseAnalysisResponse(json, 'run-1', dummyResult); + expect(result.parseFailed).toBeUndefined(); + }); + + it('strips markdown code fences', async () => { + const json = '```json\n{"strategy": {"overallScore": 42}}\n```'; + const result = await parseAnalysisResponse(json, 'run-1', dummyResult); + expect(result.strategy.overallScore).toBe(42); + }); +}); diff --git a/tests/unit/runner.test.ts b/tests/unit/runner.test.ts index 60e1082..0d779f9 100644 --- a/tests/unit/runner.test.ts +++ b/tests/unit/runner.test.ts @@ -1,5 +1,5 @@ import { describe, it, expect } from 'vitest'; -import { buildDockerExecInvocation, stripThinkingTags, extractCommandFromText, findJsonBlocks } from '../../src/lib/runner.js'; +import { buildDockerExecInvocation, stripThinkingTags, extractCommandFromText, findJsonBlocks, extractErrorOutput, trimMessages } from '../../src/lib/runner.js'; describe('buildDockerExecInvocation', () => { it('uses docker exec with stdin script mode (no bash -c)', () => { @@ -187,3 +187,145 @@ describe('findJsonBlocks', () => { expect(findJsonBlocks(input)).toEqual(['{"complete":true}']); }); }); + +// ============================================================================= +// extractErrorOutput +// ============================================================================= + +describe('extractErrorOutput', () => { + it('extracts stderr string from error object', () => { + const err = { stderr: 'permission denied', message: 'fallback' }; + expect(extractErrorOutput(err)).toBe('permission denied'); + }); + + it('extracts stderr Buffer from error object', () => { + const err = { stderr: Buffer.from('buffer error'), message: 'fallback' }; + expect(extractErrorOutput(err)).toBe('buffer error'); + }); + + it('falls back to Error message when no stderr', () => { + expect(extractErrorOutput(new Error('something broke'))).toBe('something broke'); + }); + + it('returns default for null', () => { + expect(extractErrorOutput(null)).toBe('Command failed'); + }); + + it('returns default for undefined', () => { + expect(extractErrorOutput(undefined)).toBe('Command failed'); + }); + + it('handles string errors', () => { + expect(extractErrorOutput('raw string')).toBe('Command failed'); + }); + + it('prefers stderr over message', () => { + const err = new Error('msg'); + (err as any).stderr = 'stderr output'; + expect(extractErrorOutput(err)).toBe('stderr output'); + }); +}); + +// ============================================================================= +// trimMessages +// ============================================================================= + +describe('trimMessages', () => { + const MAX_CONTEXT_MESSAGES = 40; // mirrors constants.ts + + function makeMessages(count: number, startRole: 'user' | 'assistant' = 'user') { + const roles = ['user', 'assistant'] as const; + const offset = startRole === 'user' ? 0 : 1; + return Array.from({ length: count }, (_, i) => ({ + role: roles[(i + offset) % 2], + content: `msg-${i}`, + })); + } + + it('returns messages unchanged when under limit', () => { + const msgs = makeMessages(10); + expect(trimMessages(msgs)).toEqual(msgs); + }); + + it('returns messages unchanged when at limit', () => { + const msgs = makeMessages(MAX_CONTEXT_MESSAGES); + expect(trimMessages(msgs)).toEqual(msgs); + }); + + it('trims messages over limit preserving anchor', () => { + const msgs = makeMessages(MAX_CONTEXT_MESSAGES + 10); + const result = trimMessages(msgs); + // First message preserved + expect(result[0]).toBe(msgs[0]); + // Last message preserved + expect(result[result.length - 1]).toBe(msgs[msgs.length - 1]); + // Length is at most MAX_CONTEXT_MESSAGES + expect(result.length).toBeLessThanOrEqual(MAX_CONTEXT_MESSAGES); + }); + + it('preserves role alternation after trim', () => { + const msgs = makeMessages(MAX_CONTEXT_MESSAGES + 10); + const result = trimMessages(msgs); + for (let i = 1; i < result.length; i++) { + expect(result[i].role).not.toBe(result[i - 1].role); + } + }); + + it('drops adjacent same-role when anchor matches tail[0]', () => { + // Force anchor and tail[0] to share a role by using an even-offset count + // anchor is user (index 0), and we need tail[0] to also be user + // With alternating roles, tail[0] role depends on the slice offset + // Build a custom array where this collision happens + const msgs = makeMessages(MAX_CONTEXT_MESSAGES + 1); + // msgs[0].role = 'user', tail = msgs.slice(-39) + // msgs.slice(-39)[0] = msgs[MAX_CONTEXT_MESSAGES + 1 - 39] = msgs[2] + // msgs[2].role = 'user' — collision! tail[0] should be dropped + const result = trimMessages(msgs); + expect(result[0].role).toBe('user'); + expect(result[1].role).not.toBe('user'); + // Verify no adjacent same-role + for (let i = 1; i < result.length; i++) { + expect(result[i].role).not.toBe(result[i - 1].role); + } + }); + + it('drops multiple consecutive same-role messages at trim boundary', () => { + // Build array where the trim boundary lands on multiple same-role messages + // that collide with the anchor (messages[0]). + // Anchor = user. We need tail[0], tail[1], ... to also be 'user'. + const msgs: { role: string; content: string }[] = [ + { role: 'user', content: 'anchor' }, + ]; + // Fill with alternating roles up to the trim point + for (let i = 1; i <= 5; i++) { + msgs.push({ role: i % 2 === 0 ? 'user' : 'assistant', content: `early-${i}` }); + } + // Now add a block of consecutive 'user' messages (simulating tool results) + // followed by normal alternation to fill past the limit + msgs.push({ role: 'user', content: 'tool-1' }); + msgs.push({ role: 'user', content: 'tool-2' }); + msgs.push({ role: 'user', content: 'tool-3' }); + // Fill remaining with alternating to go past limit + let nextRole: 'assistant' | 'user' = 'assistant'; + while (msgs.length <= MAX_CONTEXT_MESSAGES + 5) { + msgs.push({ role: nextRole, content: `fill-${msgs.length}` }); + nextRole = nextRole === 'assistant' ? 'user' : 'assistant'; + } + + const result = trimMessages(msgs); + // Anchor preserved + expect(result[0].role).toBe('user'); + // result[1] must not be 'user' (anchor collision resolved) + expect(result[1].role).not.toBe('user'); + }); + + it('works with OpenAI-style system role anchor', () => { + const msgs: { role: string; content: string }[] = [ + { role: 'system', content: 'system prompt' }, + ...makeMessages(MAX_CONTEXT_MESSAGES + 10).slice(1), + ]; + const result = trimMessages(msgs); + expect(result[0].role).toBe('system'); + expect(result.length).toBeLessThanOrEqual(MAX_CONTEXT_MESSAGES); + }); +}); diff --git a/tests/unit/shell.test.ts b/tests/unit/shell.test.ts new file mode 100644 index 0000000..ebfd877 --- /dev/null +++ b/tests/unit/shell.test.ts @@ -0,0 +1,59 @@ +import { describe, it, expect } from 'vitest'; +import { shellEscape } from '../../src/lib/shell.js'; + +describe('shellEscape', () => { + it('wraps plain strings in single quotes', () => { + expect(shellEscape('hello')).toBe("'hello'"); + }); + + it('escapes single quotes', () => { + expect(shellEscape("it's")).toBe("'it'\\''s'"); + }); + + it('handles double quotes', () => { + expect(shellEscape('say "hello"')).toBe("'say \"hello\"'"); + }); + + it('handles backticks', () => { + expect(shellEscape('echo `whoami`')).toBe("'echo `whoami`'"); + }); + + it('handles empty string', () => { + expect(shellEscape('')).toBe("''"); + }); + + it('handles newlines', () => { + expect(shellEscape('line1\nline2')).toBe("'line1\nline2'"); + }); + + it('handles null bytes', () => { + expect(shellEscape('test\0null')).toBe("'test\0null'"); + }); + + it('handles unicode', () => { + expect(shellEscape('hello \u{1F600}')).toBe("'hello \u{1F600}'"); + }); + + it('handles nested quotes', () => { + expect(shellEscape("it's a \"test\"")).toBe("'it'\\''s a \"test\"'"); + }); + + it('neutralizes semicolon injection', () => { + const result = shellEscape('; rm -rf /'); + expect(result).toBe("'; rm -rf /'"); + }); + + it('neutralizes command substitution', () => { + const result = shellEscape('$(whoami)'); + expect(result).toBe("'$(whoami)'"); + }); + + it('neutralizes backtick execution', () => { + const result = shellEscape('`id`'); + expect(result).toBe("'`id`'"); + }); + + it('handles multiple consecutive single quotes', () => { + expect(shellEscape("'''")).toBe("''\\'''\\'''\\'''"); + }); +}); diff --git a/vitest.config.ts b/vitest.config.ts index 0ccc3ba..ff8c3f5 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -3,6 +3,6 @@ import { defineConfig } from 'vitest/config'; export default defineConfig({ test: { include: ['tests/**/*.test.ts'], - testTimeout: 10000, + testTimeout: 30000, }, });