Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/commands/ai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ describe('askAi', () => {
stack: null,
doctor: null,
recentHistory: [],
recentShellHistory: [],
};

it('requires ai assistant to be enabled in config', async () => {
Expand Down Expand Up @@ -143,6 +144,7 @@ describe('askAi', () => {
thinkingLevel: 'high',
includeThoughts: true,
},
useSearchGrounding: true,
},
},
}),
Expand Down Expand Up @@ -250,6 +252,43 @@ describe('askAi', () => {
).rejects.toThrow('boom');
});

it('falls back gracefully when browsing options are unsupported', async () => {
await writeConfig({ aiAssistantEnabled: true }, dir);
process.env.DUBSTACK_GEMINI_API_KEY = 'gem-key';
delete process.env.DUBSTACK_AI_GATEWAY_API_KEY;

const streamText = vi
.fn()
.mockImplementationOnce(() => {
throw new Error('unsupported provider option useSearchGrounding');
})
.mockReturnValueOnce({
fullStream: streamFrom(['fallback answer']),
});
const googleModel = vi.fn().mockReturnValue('google-model');
const createGoogleGenerativeAI = vi.fn().mockReturnValue(googleModel);
const createGateway = vi.fn();
const collectAiContext = vi.fn().mockResolvedValue(fakeContext);
const { createBashTool } = createBashToolMock();
const output = createOutputCapture();

const result = await askAi('Explain this stack', dir, {
output: output.stream,
deps: {
streamText,
createGoogleGenerativeAI,
createGateway,
collectAiContext,
createBashTool,
},
});

expect(streamText).toHaveBeenCalledTimes(2);
expect(output.writes.join('')).toContain('Web browsing is unavailable');
expect(result.webBrowsingRequested).toBe(true);
expect(result.webBrowsingUsed).toBe(false);
});

it('requires at least one AI key environment variable', async () => {
await writeConfig({ aiAssistantEnabled: true }, dir);
delete process.env.DUBSTACK_GEMINI_API_KEY;
Expand Down
146 changes: 104 additions & 42 deletions src/commands/ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ interface AskAiOptions {
interface AskAiResult {
provider: 'google' | 'gateway';
modelId: string;
webBrowsingRequested: boolean;
webBrowsingUsed: boolean;
}

const DEFAULT_DEPS: AskAiDependencies = {
Expand All @@ -49,7 +51,7 @@ const THINKING_PROVIDER_OPTIONS = {
includeThoughts: true,
},
},
};
} as const;

const SPINNER_FRAMES = ['-', '\\', '|', '/'] as const;

Expand Down Expand Up @@ -82,50 +84,38 @@ export async function askAi(
'Safety: use bash only when command output is needed. Do not run destructive commands (for example, rm -rf, git reset --hard, git clean -fd), even if the user explicitly asks. This sandbox only allows read-only command families. If the user insists on blocked actions, explain the command is blocked here and provide a manual command they can run themselves at their own risk.',
});

const result = deps.streamText({
model: resolved.model,
system: buildAiSystemPrompt(),
prompt: contextPrompt,
stopWhen: stepCountIs(6),
tools: {
bash: bashToolkit.tools.bash,
},
providerOptions: THINKING_PROVIDER_OPTIONS,
});

const thinkingRenderer = createThinkingRenderer(output);
const webBrowsingRequested = config.ai.webBrowsing.mode === 'model-native';
let webBrowsingUsed = webBrowsingRequested;
let wroteOutput = false;
for await (const part of result.fullStream) {
switch (part.type) {
case 'reasoning-start': {
thinkingRenderer.start();
break;
}
case 'reasoning-delta': {
thinkingRenderer.update(part.text);
break;
}
case 'reasoning-end': {
thinkingRenderer.stop();
break;
}
case 'text-delta': {
thinkingRenderer.pauseForText();
output.write(part.text);
wroteOutput = true;
break;
}
case 'error': {
throw part.error instanceof Error
? part.error
: new DubError('AI assistant stream failed unexpectedly.');
}
default: {
break;
}
const runStream = async (withWebBrowsing: boolean): Promise<boolean> => {
const result = deps.streamText({
model: resolved.model,
system: buildAiSystemPrompt(),
prompt: contextPrompt,
stopWhen: stepCountIs(6),
tools: {
bash: bashToolkit.tools.bash,
},
providerOptions: buildProviderOptions({ withWebBrowsing }) as never,
});
return renderStream(result, output);
};

try {
wroteOutput = await runStream(webBrowsingRequested);
} catch (error) {
if (!isBrowsingUnsupportedError(error)) {
throw error;
}
if (config.ai.webBrowsing.fallback !== 'graceful') {
throw error;
}
webBrowsingUsed = false;
output.write(
'[note] Web browsing is unavailable for this provider/model right now. Continuing with local context and model knowledge.\n',
);
wroteOutput = await runStream(false);
}
thinkingRenderer.stop();

if (wroteOutput) {
output.write('\n');
Expand All @@ -134,9 +124,81 @@ export async function askAi(
return {
provider: resolved.provider,
modelId: resolved.modelId,
webBrowsingRequested,
webBrowsingUsed,
};
}

function buildProviderOptions(options: {
withWebBrowsing: boolean;
}): Record<string, unknown> {
const googleOptions: Record<string, unknown> = {
...(THINKING_PROVIDER_OPTIONS.google as unknown as Record<string, unknown>),
};
if (options.withWebBrowsing) {
googleOptions.useSearchGrounding = true;
}
return { google: googleOptions };
}

async function renderStream(
result: {
fullStream: AsyncIterable<{
type: string;
text?: string;
error?: unknown;
}>;
},
output: WritableLike,
): Promise<boolean> {
const thinkingRenderer = createThinkingRenderer(output);
let wroteOutput = false;
try {
for await (const part of result.fullStream) {
switch (part.type) {
case 'reasoning-start': {
thinkingRenderer.start();
break;
}
case 'reasoning-delta': {
thinkingRenderer.update(part.text ?? '');
break;
}
case 'reasoning-end': {
thinkingRenderer.stop();
break;
}
case 'text-delta': {
thinkingRenderer.pauseForText();
output.write(part.text ?? '');
wroteOutput = true;
break;
}
case 'error': {
throw part.error instanceof Error
? part.error
: new DubError('AI assistant stream failed unexpectedly.');
}
default: {
break;
}
}
}
} finally {
thinkingRenderer.stop();
}
return wroteOutput;
}

function isBrowsingUnsupportedError(error: unknown): boolean {
const text = error instanceof Error ? error.message : String(error);
const normalized = text.toLowerCase();
return (
normalized.includes('unsupported') &&
(normalized.includes('grounding') || normalized.includes('brows'))
);
}

function resolveModel(deps: AskAiDependencies): {
provider: 'google' | 'gateway';
model: LanguageModel;
Expand Down
65 changes: 60 additions & 5 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ import { track } from './commands/track';
import { trunk } from './commands/trunk';
import { undo } from './commands/undo';
import { untrack } from './commands/untrack';
import {
collectKnownTopLevelCommands,
preprocessCliArgs,
promptTypoResolution,
type ShortcutMetadata,
} from './lib/ai-shortcut';
import { readConfig } from './lib/config';
import { DubError } from './lib/errors';
import { getCurrentBranch } from './lib/git';
import {
Expand All @@ -69,7 +76,14 @@ const program = new Command();
program
.name('dub')
.description('Manage stacked diffs (dependent git branches) with ease')
.version(version);
.version(version)
.addHelpText(
'after',
`
Examples:
$ dub "what changed in this stack?" Ask AI directly
$ dub --ai "summarize terminal work" Force AI shortcut mode`,
);

program
.command('init')
Expand Down Expand Up @@ -912,14 +926,21 @@ program

program
.command('ai')
.description('Use DubStack AI assistant utilities')
.description(
'Use DubStack AI assistant utilities (or shortcut with: dub PROMPT)',
)
.addCommand(
new Command('ask')
.argument('<prompt...>', 'Prompt text to send to the AI assistant')
.description('Ask DubStack AI assistant a question')
.description('Ask DubStack AI assistant a question (explicit mode)')
.action(async (promptParts: string[]) => {
const { askAi } = await import('./commands/ai');
await askAi(promptParts.join(' '), process.cwd());
if (!invocationMetadata.invocationMode) {
invocationMetadata.invocationMode = 'explicit-ai';
}
const result = await askAi(promptParts.join(' '), process.cwd());
invocationMetadata.webBrowsingRequested = result.webBrowsingRequested;
invocationMetadata.webBrowsingUsed = result.webBrowsingUsed;
}),
)
.addCommand(
Expand Down Expand Up @@ -1120,6 +1141,11 @@ interface HistoryCaptureState {
const MAX_HISTORY_OUTPUT_LINES = 120;
const MAX_HISTORY_OUTPUT_LINE_LENGTH = 500;
let historyCapture: HistoryCaptureState | null = null;
let historyArgsForCapture: string[] | null = null;
let invocationMetadata: ShortcutMetadata & {
webBrowsingRequested?: boolean;
webBrowsingUsed?: boolean;
} = {};

program.hook('preAction', () => {
beginHistoryCapture();
Expand All @@ -1131,6 +1157,27 @@ program.hook('postAction', async () => {

async function main() {
try {
const rawArgs = process.argv.slice(2);
historyArgsForCapture = rawArgs;
const knownCommands = collectKnownTopLevelCommands(program.commands);
const config = await readConfig(process.cwd()).catch(() => null);
const shortcutEnabled = config?.ai.shortcutFallback.enabled ?? true;
const preprocessed =
shortcutEnabled || rawArgs[0] === '--ai'
? await preprocessCliArgs(
rawArgs,
knownCommands,
Boolean(process.stdin.isTTY && process.stdout.isTTY),
promptTypoResolution,
)
: { finalArgs: rawArgs, metadata: {} };
invocationMetadata = { ...preprocessed.metadata };
process.argv = [
process.argv[0],
process.argv[1],
...preprocessed.finalArgs,
];

await program.parseAsync(process.argv);
} catch (error) {
if (error instanceof DubError) {
Expand All @@ -1150,7 +1197,8 @@ async function main() {
function beginHistoryCapture(): void {
if (historyCapture) return;

const sanitizedArgs = sanitizeCommandArgs(process.argv.slice(2));
const captureArgs = historyArgsForCapture ?? process.argv.slice(2);
const sanitizedArgs = sanitizeCommandArgs(captureArgs);
if (sanitizedArgs.length === 0) return;

const output: string[] = [];
Expand Down Expand Up @@ -1243,13 +1291,20 @@ async function finalizeHistoryCapture(
durationMs: Date.now() - capture.startedAt,
output: capture.output,
errorMessage,
invocationMode: invocationMetadata.invocationMode,
typoGuardTriggered: invocationMetadata.typoGuardTriggered,
webBrowsingRequested: invocationMetadata.webBrowsingRequested,
webBrowsingUsed: invocationMetadata.webBrowsingUsed,
context: {
currentBranch,
operation,
},
}).catch(() => {
// Do not block command execution if history append fails.
});

historyArgsForCapture = null;
invocationMetadata = {};
}

function truncateHistoryLine(line: string): string {
Expand Down
Loading