From 43cddf93665f4db37e587dfc895d2f7b6666fd4b Mon Sep 17 00:00:00 2001 From: Maximilian Fellner Date: Thu, 23 Apr 2026 22:23:15 +0200 Subject: [PATCH 1/4] Add OpenAI-compatible custom endpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Register Ollama, vLLM, or other OpenAI-compatible servers from Settings → Providers. Endpoints are stored under ~/.pi/agent/models.json via a new CustomProviderStore in pi-sdk-driver (atomic writes, placeholder API key when auth is not required). Driver rejects IDs that collide with built-in API-key or OAuth providers; URL validation parses with the URL constructor. Main-process probe handler uses Electron net.fetch to list /models with a 5s timeout. Co-Authored-By: Claude Opus 4.7 (1M context) --- apps/desktop/electron/app-store.ts | 35 ++ apps/desktop/electron/main.ts | 65 ++- apps/desktop/electron/preload.ts | 17 +- apps/desktop/src/App.tsx | 23 + apps/desktop/src/ipc.ts | 29 ++ .../src/settings-custom-endpoints-section.tsx | 448 ++++++++++++++++++ .../src/settings-providers-section.tsx | 15 +- apps/desktop/src/settings-view.tsx | 8 +- apps/desktop/src/styles/main.css | 7 + apps/desktop/tsconfig.paths.json | 1 + packages/pi-sdk-driver/package.json | 4 + .../src/custom-provider-store.ts | 194 ++++++++ .../src/custom-provider-types.ts | 32 ++ packages/pi-sdk-driver/src/index.ts | 7 +- packages/pi-sdk-driver/src/runtime-deps.ts | 7 +- .../pi-sdk-driver/src/runtime-supervisor.ts | 38 ++ 16 files changed, 924 insertions(+), 6 deletions(-) create mode 100644 apps/desktop/src/settings-custom-endpoints-section.tsx create mode 100644 packages/pi-sdk-driver/src/custom-provider-store.ts create mode 100644 packages/pi-sdk-driver/src/custom-provider-types.ts diff --git a/apps/desktop/electron/app-store.ts b/apps/desktop/electron/app-store.ts index 67fb85a8..5b8c970c 100644 --- a/apps/desktop/electron/app-store.ts +++ b/apps/desktop/electron/app-store.ts @@ -86,6 +86,7 @@ import { toSessionQueuedMessages, toSessionRef, } from "./app-store-utils"; +import type { CustomProviderConfig } from "../src/ipc"; import { resolveRepoWorkspaceId } from "../src/workspace-roots"; import { SessionStateMap, type QueuedComposerEditState } from "./session-state-map"; import { createEmptyExtensionUiState, serializeExtensionUiState } from "./session-state-map"; @@ -586,6 +587,40 @@ export class DesktopAppStore implements AppStoreInternals { ); } + async listCustomProviders(): Promise { + await this.initialize(); + const entries = await this.driver.runtimeSupervisor.listCustomProviders(); + return entries.map((entry) => ({ + providerId: entry.providerId, + baseUrl: entry.baseUrl, + ...(entry.apiKey !== undefined ? { apiKey: entry.apiKey } : {}), + models: entry.models.map((model) => ({ + id: model.id, + ...(model.contextWindow !== undefined ? { contextWindow: model.contextWindow } : {}), + })), + })); + } + + async setCustomProvider(workspaceId: string, config: CustomProviderConfig): Promise { + return this.withRuntimeUpdate(workspaceId, (ws) => + this.driver.runtimeSupervisor.setCustomProvider(ws, { + providerId: config.providerId, + baseUrl: config.baseUrl, + ...(config.apiKey !== undefined ? { apiKey: config.apiKey } : {}), + models: config.models.map((model) => ({ + id: model.id, + ...(model.contextWindow !== undefined ? { contextWindow: model.contextWindow } : {}), + })), + }), + ); + } + + async deleteCustomProvider(workspaceId: string, providerId: string): Promise { + return this.withRuntimeUpdate(workspaceId, (ws) => + this.driver.runtimeSupervisor.deleteCustomProvider(ws, providerId), + ); + } + async setEnableSkillCommands(workspaceId: string, enabled: boolean): Promise { return this.withRuntimeUpdate(workspaceId, (ws) => this.driver.runtimeSupervisor.setEnableSkillCommands(ws, enabled), diff --git a/apps/desktop/electron/main.ts b/apps/desktop/electron/main.ts index e5352c17..6651e724 100644 --- a/apps/desktop/electron/main.ts +++ b/apps/desktop/electron/main.ts @@ -6,10 +6,12 @@ import { ipcMain, Menu, nativeImage, + net, shell, type MenuItemConstructorOptions, type MessageBoxOptions, } from "electron"; +import { isValidHttpBaseUrl } from "@pi-gui/pi-sdk-driver"; import { randomUUID } from "node:crypto"; import { readFile, stat } from "node:fs/promises"; import path from "node:path"; @@ -25,7 +27,13 @@ import { import { checkForUpdate, initUpdateChecker } from "./update-checker"; import { ThemeManager } from "./theme-manager"; import type { DesktopAppState, ThemeMode } from "../src/desktop-state"; -import { desktopIpc, getDesktopCommandFromShortcut } from "../src/ipc"; +import { + desktopIpc, + getDesktopCommandFromShortcut, + type CustomProviderConfig, + type CustomProviderProbeInput, + type CustomProviderProbeResult, +} from "../src/ipc"; import { SUPPORTED_COMPOSER_IMAGE_TYPES } from "../src/composer-attachments"; import type { ComposerAttachment, @@ -466,6 +474,16 @@ app.whenReady().then(async () => { ipcMain.handle(desktopIpc.setProviderApiKey, (_event, workspaceId: string, providerId: string, apiKey: string) => store.setProviderApiKey(workspaceId, providerId, apiKey), ); + ipcMain.handle(desktopIpc.listCustomProviders, () => store.listCustomProviders()); + ipcMain.handle(desktopIpc.setCustomProvider, (_event, workspaceId: string, config: CustomProviderConfig) => + store.setCustomProvider(workspaceId, config), + ); + ipcMain.handle(desktopIpc.deleteCustomProvider, (_event, workspaceId: string, providerId: string) => + store.deleteCustomProvider(workspaceId, providerId), + ); + ipcMain.handle(desktopIpc.probeCustomProviderModels, (_event, input: CustomProviderProbeInput) => + probeCustomProviderModels(input), + ); ipcMain.handle(desktopIpc.setEnableSkillCommands, (_event, workspaceId: string, enabled: boolean) => store.setEnableSkillCommands(workspaceId, enabled), ); @@ -769,3 +787,48 @@ async function promptForText(message: string, placeholder = ""): Promise } return result.trim(); } + +async function probeCustomProviderModels(input: CustomProviderProbeInput): Promise { + const baseUrl = input.baseUrl?.trim(); + if (!baseUrl || !isValidHttpBaseUrl(baseUrl)) { + return { ok: false, error: "Base URL must start with http:// or https://" }; + } + const target = `${baseUrl.replace(/\/+$/, "")}/models`; + const apiKey = input.apiKey?.trim(); + try { + const response = await net.fetch(target, { + method: "GET", + headers: apiKey ? { Authorization: `Bearer ${apiKey}` } : undefined, + signal: AbortSignal.timeout(5000), + }); + if (!response.ok) { + return { ok: false, error: `${response.status} ${response.statusText} from ${target}` }; + } + const payload = (await response.json()) as unknown; + const data = (payload as { data?: unknown }).data; + if (!Array.isArray(data)) { + return { ok: false, error: `Response from ${target} is missing a "data" array` }; + } + const models = data + .map((entry) => { + if (entry && typeof entry === "object" && typeof (entry as { id?: unknown }).id === "string") { + return (entry as { id: string }).id; + } + return undefined; + }) + .filter((id): id is string => Boolean(id && id.length > 0)); + return { ok: true, models }; + } catch (error) { + return { ok: false, error: describeProbeError(error, target) }; + } +} + +function describeProbeError(error: unknown, target: string): string { + if (error instanceof Error && error.name === "TimeoutError") { + return `Timed out after 5s contacting ${target}`; + } + if (error instanceof Error) { + return error.message; + } + return String(error); +} diff --git a/apps/desktop/electron/preload.ts b/apps/desktop/electron/preload.ts index acb73485..edad4b17 100644 --- a/apps/desktop/electron/preload.ts +++ b/apps/desktop/electron/preload.ts @@ -1,6 +1,13 @@ import { contextBridge, ipcRenderer, webUtils } from "electron"; import { PRELOAD_DEV_RELOAD_MARKER } from "./dev-reload-preload-probe"; -import { desktopIpc, type DesktopNotificationPermissionStatus, type PiDesktopCommand } from "../src/ipc"; +import { + desktopIpc, + type CustomProviderConfig, + type CustomProviderProbeInput, + type CustomProviderProbeResult, + type DesktopNotificationPermissionStatus, + type PiDesktopCommand, +} from "../src/ipc"; import type { NavigateSessionTreeOptions, NavigateSessionTreeResult, @@ -149,6 +156,14 @@ contextBridge.exposeInMainWorld("piApp", { ipcRenderer.invoke(desktopIpc.logoutProvider, workspaceId, providerId) as Promise, setProviderApiKey: (workspaceId: string, providerId: string, apiKey: string) => ipcRenderer.invoke(desktopIpc.setProviderApiKey, workspaceId, providerId, apiKey) as Promise, + listCustomProviders: () => + ipcRenderer.invoke(desktopIpc.listCustomProviders) as Promise, + setCustomProvider: (workspaceId: string, config: CustomProviderConfig) => + ipcRenderer.invoke(desktopIpc.setCustomProvider, workspaceId, config) as Promise, + deleteCustomProvider: (workspaceId: string, providerId: string) => + ipcRenderer.invoke(desktopIpc.deleteCustomProvider, workspaceId, providerId) as Promise, + probeCustomProviderModels: (input: CustomProviderProbeInput) => + ipcRenderer.invoke(desktopIpc.probeCustomProviderModels, input) as Promise, setEnableSkillCommands: (workspaceId: string, enabled: boolean) => ipcRenderer.invoke(desktopIpc.setEnableSkillCommands, workspaceId, enabled) as Promise, setScopedModelPatterns: (workspaceId: string, patterns: readonly string[]) => diff --git a/apps/desktop/src/App.tsx b/apps/desktop/src/App.tsx index 9b51c728..4fd31371 100644 --- a/apps/desktop/src/App.tsx +++ b/apps/desktop/src/App.tsx @@ -22,6 +22,7 @@ import { parseTreeComposerCommand } from "./composer-commands"; import { desktopCommands, getDesktopCommandFromShortcut, + type CustomProviderConfig, type DesktopNotificationPermissionStatus, type PiDesktopCommand, } from "./ipc"; @@ -1383,6 +1384,26 @@ export default function App() { return state.lastError; }; + const handleSaveCustomProvider = async (config: CustomProviderConfig): Promise => { + if (!api || !settingsWorkspace) { + return "Select a workspace first."; + } + const state = await updateSnapshot(api, setSnapshot, () => + api.setCustomProvider(settingsWorkspace.id, config), + ); + return state.lastError; + }; + + const handleDeleteCustomProvider = async (providerId: string): Promise => { + if (!api || !settingsWorkspace) { + return "Select a workspace first."; + } + const state = await updateSnapshot(api, setSnapshot, () => + api.deleteCustomProvider(settingsWorkspace.id, providerId), + ); + return state.lastError; + }; + const handleToggleSkill = (filePath: string, enabled: boolean) => { if (!skillsWorkspace) { return; @@ -1673,6 +1694,8 @@ export default function App() { onLogoutProvider={handleLogoutProvider} onSetProviderApiKey={handleSetProviderApiKey} onRemoveProviderApiKey={handleRemoveProviderApiKey} + onSaveCustomProvider={handleSaveCustomProvider} + onDeleteCustomProvider={handleDeleteCustomProvider} onSetModelSettingsScopeMode={handleSetModelSettingsScopeMode} onSetDefaultModel={handleSetDefaultModel} onSetNotificationPreferences={handleSetNotificationPreferences} diff --git a/apps/desktop/src/ipc.ts b/apps/desktop/src/ipc.ts index bb4716ed..a42ba728 100644 --- a/apps/desktop/src/ipc.ts +++ b/apps/desktop/src/ipc.ts @@ -26,6 +26,27 @@ export type DesktopNotificationPermissionStatus = | "unsupported" | "unknown"; +export interface CustomProviderModelConfig { + readonly id: string; + readonly contextWindow?: number; +} + +export interface CustomProviderConfig { + readonly providerId: string; + readonly baseUrl: string; + readonly apiKey?: string; + readonly models: readonly CustomProviderModelConfig[]; +} + +export interface CustomProviderProbeInput { + readonly baseUrl: string; + readonly apiKey?: string; +} + +export type CustomProviderProbeResult = + | { readonly ok: true; readonly models: readonly string[] } + | { readonly ok: false; readonly error: string }; + export const desktopIpc = { stateRequest: "pi-gui:state-request", stateChanged: "pi-gui:state-changed", @@ -62,6 +83,10 @@ export const desktopIpc = { loginProvider: "pi-gui:login-provider", logoutProvider: "pi-gui:logout-provider", setProviderApiKey: "pi-gui:set-provider-api-key", + listCustomProviders: "pi-gui:list-custom-providers", + setCustomProvider: "pi-gui:set-custom-provider", + deleteCustomProvider: "pi-gui:delete-custom-provider", + probeCustomProviderModels: "pi-gui:probe-custom-provider-models", setEnableSkillCommands: "pi-gui:set-enable-skill-commands", setScopedModelPatterns: "pi-gui:set-scoped-model-patterns", setSkillEnabled: "pi-gui:set-skill-enabled", @@ -185,6 +210,10 @@ export interface PiDesktopApi { loginProvider(workspaceId: string, providerId: string): Promise; logoutProvider(workspaceId: string, providerId: string): Promise; setProviderApiKey(workspaceId: string, providerId: string, apiKey: string): Promise; + listCustomProviders(): Promise; + setCustomProvider(workspaceId: string, config: CustomProviderConfig): Promise; + deleteCustomProvider(workspaceId: string, providerId: string): Promise; + probeCustomProviderModels(input: CustomProviderProbeInput): Promise; setEnableSkillCommands(workspaceId: string, enabled: boolean): Promise; setScopedModelPatterns(workspaceId: string, patterns: readonly string[]): Promise; setSkillEnabled(workspaceId: string, filePath: string, enabled: boolean): Promise; diff --git a/apps/desktop/src/settings-custom-endpoints-section.tsx b/apps/desktop/src/settings-custom-endpoints-section.tsx new file mode 100644 index 00000000..979efaca --- /dev/null +++ b/apps/desktop/src/settings-custom-endpoints-section.tsx @@ -0,0 +1,448 @@ +import { useCallback, useEffect, useMemo, useState } from "react"; +import { CUSTOM_PROVIDER_ID_PATTERN, isValidHttpBaseUrl } from "@pi-gui/pi-sdk-driver/custom-provider-types"; +import type { CustomProviderConfig, CustomProviderModelConfig } from "./ipc"; +import { SettingsGroup } from "./settings-utils"; + +interface SettingsCustomEndpointsSectionProps { + readonly existingProviderIds: readonly string[]; + readonly onSaveCustomProvider: (config: CustomProviderConfig) => Promise; + readonly onDeleteCustomProvider: (providerId: string) => Promise; +} + +type DialogMode = { kind: "closed" } | { kind: "create" } | { kind: "edit"; original: CustomProviderConfig }; + +export function SettingsCustomEndpointsSection({ + existingProviderIds, + onSaveCustomProvider, + onDeleteCustomProvider, +}: SettingsCustomEndpointsSectionProps) { + const [entries, setEntries] = useState([]); + const [loadError, setLoadError] = useState(); + const [dialog, setDialog] = useState({ kind: "closed" }); + const [reloadKey, setReloadKey] = useState(0); + + useEffect(() => { + const api = window.piApp; + if (!api) { + return; + } + let cancelled = false; + void api + .listCustomProviders() + .then((list) => { + if (!cancelled) { + setEntries(list); + setLoadError(undefined); + } + }) + .catch((error) => { + if (!cancelled) { + setLoadError(error instanceof Error ? error.message : String(error)); + } + }); + return () => { + cancelled = true; + }; + }, [reloadKey]); + + const reload = useCallback(() => setReloadKey((key) => key + 1), []); + + const handleSave = useCallback( + async (config: CustomProviderConfig): Promise => { + const error = await onSaveCustomProvider(config); + if (!error) { + reload(); + } + return error; + }, + [onSaveCustomProvider, reload], + ); + + const handleDelete = useCallback( + async (providerId: string) => { + const error = await onDeleteCustomProvider(providerId); + if (error) { + setLoadError(error); + return; + } + reload(); + }, + [onDeleteCustomProvider, reload], + ); + + return ( + <> + + {loadError ? ( +
+ {loadError} +
+ ) : null} + {entries.length === 0 ? ( +
+ No custom endpoints yet. +
+ ) : ( + entries.map((entry) => ( +
+
+
{entry.providerId}
+
+ {entry.baseUrl} · {entry.models.length} model{entry.models.length === 1 ? "" : "s"} +
+
+
+ + +
+
+ )) + )} +
+
+
Add endpoint
+
+ Register a local or custom OpenAI-compatible server. +
+
+
+ +
+
+
+ + {dialog.kind !== "closed" ? ( + setDialog({ kind: "closed" })} + onSave={handleSave} + /> + ) : null} + + ); +} + +interface CustomEndpointDialogProps { + readonly mode: Exclude; + readonly existingProviderIds: readonly string[]; + readonly onClose: () => void; + readonly onSave: (config: CustomProviderConfig) => Promise; +} + +function CustomEndpointDialog({ mode, existingProviderIds, onClose, onSave }: CustomEndpointDialogProps) { + const initial = mode.kind === "edit" ? mode.original : undefined; + const [providerId, setProviderId] = useState(initial?.providerId ?? ""); + const [baseUrl, setBaseUrl] = useState(initial?.baseUrl ?? ""); + const [apiKey, setApiKey] = useState(initial?.apiKey ?? ""); + const [models, setModels] = useState( + initial ? [...initial.models] : [], + ); + const [probeCandidates, setProbeCandidates] = useState([]); + const [probeError, setProbeError] = useState(); + const [probePending, setProbePending] = useState(false); + const [formError, setFormError] = useState(); + const [savePending, setSavePending] = useState(false); + + const selectedModelIds = useMemo(() => new Set(models.map((model) => model.id)), [models]); + const isEdit = mode.kind === "edit"; + + const idValidationError = useMemo(() => validateProviderId(providerId, existingProviderIds, initial?.providerId), [ + providerId, + existingProviderIds, + initial?.providerId, + ]); + + const handleProbe = async () => { + const api = window.piApp; + if (!api) { + setProbeError("Desktop bridge is not available."); + return; + } + if (!isValidHttpBaseUrl(baseUrl)) { + setProbeError("Base URL must start with http:// or https://"); + return; + } + setProbePending(true); + setProbeError(undefined); + const result = await api.probeCustomProviderModels({ + baseUrl: baseUrl.trim(), + apiKey: apiKey.trim() ? apiKey.trim() : undefined, + }); + setProbePending(false); + if (!result.ok) { + setProbeError(result.error); + setProbeCandidates([]); + return; + } + setProbeCandidates(result.models); + }; + + const toggleModel = (id: string, contextWindow?: number) => { + setModels((current) => { + const existing = current.find((model) => model.id === id); + if (existing) { + return current.filter((model) => model.id !== id); + } + return [...current, contextWindow !== undefined ? { id, contextWindow } : { id }]; + }); + }; + + const handleManualAdd = (id: string) => { + const trimmed = id.trim(); + if (!trimmed) { + return; + } + if (selectedModelIds.has(trimmed)) { + return; + } + setModels((current) => [...current, { id: trimmed }]); + }; + + const handleSave = async () => { + if (idValidationError) { + setFormError(idValidationError); + return; + } + if (!isValidHttpBaseUrl(baseUrl)) { + setFormError("Base URL must start with http:// or https://"); + return; + } + if (models.length === 0) { + setFormError("Select at least one model."); + return; + } + setSavePending(true); + setFormError(undefined); + const error = await onSave({ + providerId: providerId.trim(), + baseUrl: baseUrl.trim(), + ...(apiKey.trim() ? { apiKey: apiKey.trim() } : {}), + models, + }); + if (error) { + setSavePending(false); + setFormError(error); + return; + } + onClose(); + }; + + return ( +
+
{ + if (event.key === "Escape" && !savePending) { + event.preventDefault(); + onClose(); + } + }} + > +
{isEdit ? "Edit custom endpoint" : "Add custom endpoint"}
+

+ Configure an OpenAI-compatible server. The endpoint and API key are stored in plaintext at + ~/.pi/agent/models.json. +

+ + + + +
+
+ Models + +
+ {probeError ? ( +

{probeError}

+ ) : null} + +

+ Tool calling is required. Smaller models (< 7B) often do not emit OpenAI-style function calls cleanly. +

+
+ + {formError ?

{formError}

: null} +
+ + +
+
+
+ ); +} + +interface ModelChecklistProps { + readonly probed: readonly string[]; + readonly selected: readonly CustomProviderModelConfig[]; + readonly onToggle: (id: string, contextWindow?: number) => void; + readonly onManualAdd: (id: string) => void; + readonly disabled: boolean; +} + +function ModelChecklist({ probed, selected, onToggle, onManualAdd, disabled }: ModelChecklistProps) { + const [manualDraft, setManualDraft] = useState(""); + const selectedIds = useMemo(() => new Set(selected.map((model) => model.id)), [selected]); + const knownIds = useMemo(() => new Set([...probed, ...selected.map((model) => model.id)]), [probed, selected]); + + const submitManual = () => { + onManualAdd(manualDraft); + setManualDraft(""); + }; + + return ( +
+ {knownIds.size === 0 ? ( +

+ Click “Detect models” or type a model ID below to add one manually. +

+ ) : ( +
    + {[...knownIds].sort((a, b) => a.localeCompare(b)).map((id) => ( +
  • + +
  • + ))} +
+ )} +
+ setManualDraft(event.target.value)} + onKeyDown={(event) => { + if (event.key === "Enter") { + event.preventDefault(); + submitManual(); + } + }} + /> + +
+
+ ); +} + +function validateProviderId( + candidate: string, + existing: readonly string[], + editing?: string, +): string | undefined { + const trimmed = candidate.trim(); + if (!trimmed) { + return "Provider ID is required."; + } + if (!CUSTOM_PROVIDER_ID_PATTERN.test(trimmed)) { + return "Use lowercase letters, digits, and dashes (max 64 chars)."; + } + if (trimmed !== editing && existing.includes(trimmed)) { + return `Provider ID "${trimmed}" is already in use.`; + } + return undefined; +} diff --git a/apps/desktop/src/settings-providers-section.tsx b/apps/desktop/src/settings-providers-section.tsx index 2e6a8d26..f7f2e4ec 100644 --- a/apps/desktop/src/settings-providers-section.tsx +++ b/apps/desktop/src/settings-providers-section.tsx @@ -1,5 +1,7 @@ -import { useEffect, useState } from "react"; +import { useEffect, useMemo, useState } from "react"; import type { RuntimeSnapshot } from "@pi-gui/session-driver/runtime-types"; +import type { CustomProviderConfig } from "./ipc"; +import { SettingsCustomEndpointsSection } from "./settings-custom-endpoints-section"; import { filterProviders, ProviderRow, SettingsGroup } from "./settings-utils"; interface SettingsProvidersSectionProps { @@ -8,6 +10,8 @@ interface SettingsProvidersSectionProps { readonly onLogoutProvider: (providerId: string) => void; readonly onSetProviderApiKey: (providerId: string, apiKey: string) => Promise; readonly onRemoveProviderApiKey: (providerId: string) => Promise; + readonly onSaveCustomProvider: (config: CustomProviderConfig) => Promise; + readonly onDeleteCustomProvider: (providerId: string) => Promise; } export function SettingsProvidersSection({ @@ -16,6 +20,8 @@ export function SettingsProvidersSection({ onLogoutProvider, onSetProviderApiKey, onRemoveProviderApiKey, + onSaveCustomProvider, + onDeleteCustomProvider, }: SettingsProvidersSectionProps) { const [providerQuery, setProviderQuery] = useState(""); const [apiKeyProviderId, setApiKeyProviderId] = useState(); @@ -28,6 +34,7 @@ export function SettingsProvidersSection({ const oauthProviders = providers.filter((p) => p.oauthSupported); const filteredProviders = filterProviders(providers, providerQuery); const apiKeyProvider = apiKeyProviderId ? providers.find((provider) => provider.id === apiKeyProviderId) : undefined; + const existingProviderIds = useMemo(() => providers.map((provider) => provider.id), [providers]); useEffect(() => { setApiKeyDraft(""); @@ -104,6 +111,12 @@ export function SettingsProvidersSection({ ))} + +
diff --git a/apps/desktop/src/settings-view.tsx b/apps/desktop/src/settings-view.tsx index 903c0cf7..2139232a 100644 --- a/apps/desktop/src/settings-view.tsx +++ b/apps/desktop/src/settings-view.tsx @@ -1,6 +1,6 @@ import type { RuntimeSettingsSnapshot, RuntimeSnapshot } from "@pi-gui/session-driver/runtime-types"; import type { ModelSettingsScopeMode, NotificationPreferences, WorkspaceRecord } from "./desktop-state"; -import type { DesktopNotificationPermissionStatus } from "./ipc"; +import type { CustomProviderConfig, DesktopNotificationPermissionStatus } from "./ipc"; import { SettingsAppearanceSection } from "./settings-appearance-section"; import { SettingsGeneralSection } from "./settings-general-section"; import { SettingsModelsSection } from "./settings-models-section"; @@ -28,6 +28,8 @@ interface SettingsViewProps { readonly onLogoutProvider: (providerId: string) => void; readonly onSetProviderApiKey: (providerId: string, apiKey: string) => Promise; readonly onRemoveProviderApiKey: (providerId: string) => Promise; + readonly onSaveCustomProvider: (config: CustomProviderConfig) => Promise; + readonly onDeleteCustomProvider: (providerId: string) => Promise; readonly onSetNotificationPreferences: (preferences: Partial) => void; readonly onRequestNotificationPermission: () => void; readonly onOpenSystemNotificationSettings: () => void; @@ -52,6 +54,8 @@ export function SettingsView({ onLogoutProvider, onSetProviderApiKey, onRemoveProviderApiKey, + onSaveCustomProvider, + onDeleteCustomProvider, onSetNotificationPreferences, onRequestNotificationPermission, onOpenSystemNotificationSettings, @@ -106,6 +110,8 @@ export function SettingsView({ onLogoutProvider={onLogoutProvider} onSetProviderApiKey={onSetProviderApiKey} onRemoveProviderApiKey={onRemoveProviderApiKey} + onSaveCustomProvider={onSaveCustomProvider} + onDeleteCustomProvider={onDeleteCustomProvider} /> ) : null} diff --git a/apps/desktop/src/styles/main.css b/apps/desktop/src/styles/main.css index a1f50cb4..2052cb43 100644 --- a/apps/desktop/src/styles/main.css +++ b/apps/desktop/src/styles/main.css @@ -1958,6 +1958,13 @@ font-weight: 560; } +.settings-field__header { + display: flex; + align-items: center; + justify-content: space-between; + gap: 12px; +} + .settings-select, .settings-search { width: min(420px, 100%); diff --git a/apps/desktop/tsconfig.paths.json b/apps/desktop/tsconfig.paths.json index 8f8ad9ce..62a7c62d 100644 --- a/apps/desktop/tsconfig.paths.json +++ b/apps/desktop/tsconfig.paths.json @@ -3,6 +3,7 @@ "baseUrl": ".", "paths": { "@pi-gui/pi-sdk-driver": ["../../packages/pi-sdk-driver/src/index.ts"], + "@pi-gui/pi-sdk-driver/custom-provider-types": ["../../packages/pi-sdk-driver/src/custom-provider-types.ts"], "@pi-gui/pi-sdk-driver/dev-reload-probe": ["../../packages/pi-sdk-driver/src/dev-reload-probe.ts"], "@pi-gui/catalogs": ["../../packages/catalogs/src/index.ts"], "@pi-gui/catalogs/dev-reload-probe": ["../../packages/catalogs/src/dev-reload-probe.ts"], diff --git a/packages/pi-sdk-driver/package.json b/packages/pi-sdk-driver/package.json index 3885f8e2..eb447a19 100644 --- a/packages/pi-sdk-driver/package.json +++ b/packages/pi-sdk-driver/package.json @@ -9,6 +9,10 @@ ".": { "types": "./dist/index.d.ts", "default": "./dist/index.js" + }, + "./custom-provider-types": { + "types": "./dist/custom-provider-types.d.ts", + "default": "./dist/custom-provider-types.js" } }, "files": [ diff --git a/packages/pi-sdk-driver/src/custom-provider-store.ts b/packages/pi-sdk-driver/src/custom-provider-store.ts new file mode 100644 index 00000000..ae55ad5c --- /dev/null +++ b/packages/pi-sdk-driver/src/custom-provider-store.ts @@ -0,0 +1,194 @@ +import { mkdir, readFile, rename, writeFile } from "node:fs/promises"; +import { dirname } from "node:path"; +import { + CUSTOM_PROVIDER_ID_PATTERN, + CUSTOM_PROVIDER_PLACEHOLDER_API_KEY, + isValidHttpBaseUrl, + OPENAI_COMPLETIONS_API, + type CustomProviderEntry, + type CustomProviderInput, + type CustomProviderModelInput, +} from "./custom-provider-types.js"; + +export type { CustomProviderEntry, CustomProviderInput, CustomProviderModelInput } from "./custom-provider-types.js"; +export { + CUSTOM_PROVIDER_ID_PATTERN, + CUSTOM_PROVIDER_PLACEHOLDER_API_KEY, + isValidHttpBaseUrl, + OPENAI_COMPLETIONS_API, +} from "./custom-provider-types.js"; + +export class CustomProviderStore { + private queue: Promise = Promise.resolve(); + + constructor(private readonly modelsJsonPath: string) {} + + async list(): Promise { + return this.enqueue(async () => { + const data = await readModelsJson(this.modelsJsonPath); + return readCustomProviders(data); + }); + } + + async set(input: CustomProviderInput): Promise { + validateInput(input); + await this.enqueue(async () => { + const data = await readModelsJson(this.modelsJsonPath); + const providers = ensureProvidersRecord(data); + providers[input.providerId] = toProviderConfig(input); + await atomicWriteJson(this.modelsJsonPath, data); + }); + } + + async delete(providerId: string): Promise { + return this.enqueue(async () => { + const data = await readModelsJson(this.modelsJsonPath); + const providers = data.providers; + if (!providers || typeof providers !== "object" || !(providerId in providers)) { + return false; + } + delete (providers as Record)[providerId]; + await atomicWriteJson(this.modelsJsonPath, data); + return true; + }); + } + + private enqueue(task: () => Promise): Promise { + const next = this.queue.then(task, task); + this.queue = next.catch(() => undefined); + return next; + } +} + +function validateInput(input: CustomProviderInput): void { + if (!CUSTOM_PROVIDER_ID_PATTERN.test(input.providerId)) { + throw new Error( + `Provider ID must be lowercase alphanumerics or dashes (max 64 chars): ${JSON.stringify(input.providerId)}`, + ); + } + if (!isValidHttpBaseUrl(input.baseUrl)) { + throw new Error(`Base URL must start with http:// or https://: ${JSON.stringify(input.baseUrl)}`); + } + if (input.models.length === 0) { + throw new Error("At least one model is required."); + } + for (const model of input.models) { + if (!model.id || typeof model.id !== "string") { + throw new Error("Model id is required."); + } + if (model.contextWindow !== undefined && !Number.isFinite(model.contextWindow)) { + throw new Error(`Model ${model.id} has non-numeric contextWindow.`); + } + } +} + +function toProviderConfig(input: CustomProviderInput): Record { + const trimmedKey = input.apiKey?.trim(); + return { + baseUrl: input.baseUrl, + api: OPENAI_COMPLETIONS_API, + apiKey: trimmedKey ? trimmedKey : CUSTOM_PROVIDER_PLACEHOLDER_API_KEY, + models: input.models.map((model) => { + const entry: Record = { id: model.id }; + if (model.contextWindow !== undefined) { + entry.contextWindow = model.contextWindow; + } + return entry; + }), + }; +} + +function readCustomProviders(data: Record): readonly CustomProviderEntry[] { + const providers = data.providers; + if (!providers || typeof providers !== "object") { + return []; + } + const entries: CustomProviderEntry[] = []; + for (const [providerId, rawConfig] of Object.entries(providers as Record)) { + if (!rawConfig || typeof rawConfig !== "object") { + continue; + } + const config = rawConfig as Record; + const baseUrl = typeof config.baseUrl === "string" ? config.baseUrl : undefined; + if (!baseUrl) { + continue; + } + const models = Array.isArray(config.models) + ? (config.models as unknown[]) + .map((raw): CustomProviderModelInput | undefined => { + if (!raw || typeof raw !== "object") { + return undefined; + } + const modelConfig = raw as Record; + if (typeof modelConfig.id !== "string" || !modelConfig.id) { + return undefined; + } + const contextWindow = + typeof modelConfig.contextWindow === "number" ? modelConfig.contextWindow : undefined; + return contextWindow !== undefined + ? { id: modelConfig.id, contextWindow } + : { id: modelConfig.id }; + }) + .filter((entry): entry is CustomProviderModelInput => entry !== undefined) + : []; + const rawApiKey = typeof config.apiKey === "string" ? config.apiKey : undefined; + const apiKey = rawApiKey === CUSTOM_PROVIDER_PLACEHOLDER_API_KEY ? undefined : rawApiKey; + entries.push({ + providerId, + baseUrl, + ...(apiKey !== undefined ? { apiKey } : {}), + models, + }); + } + entries.sort((left, right) => left.providerId.localeCompare(right.providerId)); + return entries; +} + +function ensureProvidersRecord(data: Record): Record { + if (!data.providers || typeof data.providers !== "object") { + data.providers = {}; + } + return data.providers as Record; +} + +function isMissingFileError(error: unknown): boolean { + const code = (error as NodeJS.ErrnoException | null)?.code; + return code === "ENOENT" || code === "ENOTDIR"; +} + +async function readModelsJson(path: string): Promise> { + let text: string; + try { + text = await readFile(path, "utf8"); + } catch (error) { + if (isMissingFileError(error)) { + return {}; + } + throw error; + } + if (text.trim().length === 0) { + return {}; + } + let parsed: unknown; + try { + parsed = JSON.parse(text); + } catch (error) { + throw new Error( + `${path} is not valid JSON. Fix or remove the file before editing custom endpoints from the app. (${ + error instanceof Error ? error.message : String(error) + })`, + ); + } + if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) { + throw new Error(`${path} must contain a JSON object at the top level.`); + } + return parsed as Record; +} + +async function atomicWriteJson(path: string, data: Record): Promise { + await mkdir(dirname(path), { recursive: true }); + const payload = `${JSON.stringify(data, null, 2)}\n`; + const tempPath = `${path}.${process.pid}.${Date.now()}.tmp`; + await writeFile(tempPath, payload, "utf8"); + await rename(tempPath, path); +} diff --git a/packages/pi-sdk-driver/src/custom-provider-types.ts b/packages/pi-sdk-driver/src/custom-provider-types.ts new file mode 100644 index 00000000..fe13245b --- /dev/null +++ b/packages/pi-sdk-driver/src/custom-provider-types.ts @@ -0,0 +1,32 @@ +export interface CustomProviderModelInput { + readonly id: string; + readonly contextWindow?: number; +} + +export interface CustomProviderInput { + readonly providerId: string; + readonly baseUrl: string; + readonly apiKey?: string; + readonly models: readonly CustomProviderModelInput[]; +} + +export interface CustomProviderEntry { + readonly providerId: string; + readonly baseUrl: string; + readonly apiKey?: string; + readonly models: readonly CustomProviderModelInput[]; +} + +export const CUSTOM_PROVIDER_ID_PATTERN = /^[a-z0-9][a-z0-9-]{0,63}$/; +export const OPENAI_COMPLETIONS_API = "openai-completions"; + +export const CUSTOM_PROVIDER_PLACEHOLDER_API_KEY = "unused"; + +export function isValidHttpBaseUrl(value: string): boolean { + try { + const url = new URL(value.trim()); + return (url.protocol === "http:" || url.protocol === "https:") && url.host.length > 0; + } catch { + return false; + } +} diff --git a/packages/pi-sdk-driver/src/index.ts b/packages/pi-sdk-driver/src/index.ts index a414753f..869980bb 100644 --- a/packages/pi-sdk-driver/src/index.ts +++ b/packages/pi-sdk-driver/src/index.ts @@ -7,7 +7,12 @@ export { export type { ExtensionUiDialogRequest, ExtensionUiState, ExtensionUiWidgetState } from "./extension-ui-state.js"; export type { PiSdkDriverConfig } from "./pi-sdk-driver.js"; export { createPiSdkDriver, PiSdkDriver } from "./pi-sdk-driver.js"; -export { RuntimeSupervisor } from "./runtime-supervisor.js"; +export { + CUSTOM_PROVIDER_ID_PATTERN, + isValidHttpBaseUrl, + OPENAI_COMPLETIONS_API, + RuntimeSupervisor, +} from "./runtime-supervisor.js"; export type { PiSdkDriverOptions, SyncWorkspaceResult } from "./session-supervisor.js"; export { SessionSupervisor } from "./session-supervisor.js"; export { sessionKey } from "./session-supervisor-utils.js"; diff --git a/packages/pi-sdk-driver/src/runtime-deps.ts b/packages/pi-sdk-driver/src/runtime-deps.ts index e3aa53b8..300a1351 100644 --- a/packages/pi-sdk-driver/src/runtime-deps.ts +++ b/packages/pi-sdk-driver/src/runtime-deps.ts @@ -1,20 +1,25 @@ import { join, resolve } from "node:path"; import { AuthStorage, ModelRegistry, getAgentDir } from "@mariozechner/pi-coding-agent"; +import { CustomProviderStore } from "./custom-provider-store.js"; import type { RuntimeSupervisorOptions } from "./runtime-supervisor.js"; export interface RuntimeDependencies { readonly agentDir: string; readonly authStorage: AuthStorage; readonly modelRegistry: ModelRegistry; + readonly customProviderStore: CustomProviderStore; } export function createRuntimeDependencies(options: RuntimeSupervisorOptions = {}): RuntimeDependencies { const agentDir = resolve(options.agentDir ?? getAgentDir()); + const modelsJsonPath = join(agentDir, "models.json"); const authStorage = options.authStorage ?? AuthStorage.create(join(agentDir, "auth.json")); - const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, join(agentDir, "models.json")); + const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsJsonPath); + const customProviderStore = options.customProviderStore ?? new CustomProviderStore(modelsJsonPath); return { agentDir, authStorage, modelRegistry, + customProviderStore, }; } diff --git a/packages/pi-sdk-driver/src/runtime-supervisor.ts b/packages/pi-sdk-driver/src/runtime-supervisor.ts index f41cfc5d..75f91869 100644 --- a/packages/pi-sdk-driver/src/runtime-supervisor.ts +++ b/packages/pi-sdk-driver/src/runtime-supervisor.ts @@ -28,6 +28,14 @@ import { createRuntimeDependencies } from "./runtime-deps.js"; import { createSettingsManagerWithoutNpmPackages, isGlobalNpmLookupError } from "./npm-package-fallback.js"; import { skillSlashCommand } from "./runtime-command-utils.js"; import type { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent"; +import { CustomProviderStore, type CustomProviderEntry, type CustomProviderInput } from "./custom-provider-store.js"; + +export { + CUSTOM_PROVIDER_ID_PATTERN, + isValidHttpBaseUrl, + OPENAI_COMPLETIONS_API, +} from "./custom-provider-store.js"; +export type { CustomProviderEntry, CustomProviderInput, CustomProviderModelInput } from "./custom-provider-store.js"; interface ModelSettingsSnapshot { readonly defaultProvider?: string; @@ -52,6 +60,7 @@ export interface RuntimeSupervisorOptions { readonly agentDir?: string; readonly authStorage?: AuthStorage; readonly modelRegistry?: ModelRegistry; + readonly customProviderStore?: CustomProviderStore; } type ResourceScope = "user" | "project"; @@ -61,6 +70,7 @@ export class RuntimeSupervisor implements RuntimeResourceDriver { private readonly agentDir: string; private readonly authStorage: AuthStorage; private readonly modelRegistry: ModelRegistry; + private readonly customProviderStore: CustomProviderStore; private readonly contexts = new Map(); constructor(options: RuntimeSupervisorOptions = {}) { @@ -68,6 +78,7 @@ export class RuntimeSupervisor implements RuntimeResourceDriver { this.agentDir = deps.agentDir; this.authStorage = deps.authStorage; this.modelRegistry = deps.modelRegistry; + this.customProviderStore = deps.customProviderStore; } async getRuntimeSnapshot(workspace: WorkspaceRef): Promise { @@ -118,6 +129,33 @@ export class RuntimeSupervisor implements RuntimeResourceDriver { return this.buildSnapshot(context); } + async listCustomProviders(): Promise { + return this.customProviderStore.list(); + } + + async setCustomProvider(workspace: WorkspaceRef, input: CustomProviderInput): Promise { + const oauthProviderIds = new Set(this.authStorage.getOAuthProviders().map((provider) => provider.id)); + if (providerSupportsDesktopApiKeySetup(input.providerId) || oauthProviderIds.has(input.providerId)) { + throw new Error( + `Provider ID "${input.providerId}" conflicts with a built-in provider. Pick a unique ID.`, + ); + } + const context = await this.ensureContext(workspace); + await this.customProviderStore.set(input); + this.modelRegistry.refresh(); + await context.resourceLoader.reload(); + await this.autoEnableModelsForAuthenticatedProviders(context, [input.providerId]); + return this.buildSnapshot(context); + } + + async deleteCustomProvider(workspace: WorkspaceRef, providerId: string): Promise { + const context = await this.ensureContext(workspace); + await this.customProviderStore.delete(providerId); + this.modelRegistry.refresh(); + await context.resourceLoader.reload(); + return this.buildSnapshot(context); + } + async setDefaultModel( workspace: WorkspaceRef, selection: { From 9aaabe359eaf7d30727c43a4f138bd39b7793461 Mon Sep 17 00:00:00 2001 From: Maximilian Fellner Date: Thu, 23 Apr 2026 22:23:29 +0200 Subject: [PATCH 2/4] Drop action cell for env/external provider rows A disabled "Managed externally" button on environment-variable and models.json-override provider rows was visual noise. Render no action cell when the provider can't be managed from the app; keep the description text that already explains the state. Update the two provider-settings specs that asserted on the disabled button to check that the control cell is absent instead. Co-Authored-By: Claude Opus 4.7 (1M context) --- apps/desktop/src/settings-utils.tsx | 40 +++++++++++-------- .../tests/core/provider-settings.spec.ts | 4 +- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/apps/desktop/src/settings-utils.tsx b/apps/desktop/src/settings-utils.tsx index 1a26a781..930d560c 100644 --- a/apps/desktop/src/settings-utils.tsx +++ b/apps/desktop/src/settings-utils.tsx @@ -150,16 +150,18 @@ export function ProviderRow({
{provider.name}
{describeProviderStatus(provider)}
-
- -
+ {action ? ( +
+ +
+ ) : null} ); } @@ -190,11 +192,13 @@ function resolveProviderAction( onLoginProvider: (providerId: string) => void, onLogoutProvider: (providerId: string) => void, onConfigureApiKey: (provider: RuntimeSnapshot["providers"][number]) => void, -): { - readonly disabled: boolean; - readonly label: string; - readonly onClick?: () => void; -} { +): + | { + readonly disabled: boolean; + readonly label: string; + readonly onClick?: () => void; + } + | undefined { if (provider.authSource === "oauth") { return { disabled: false, @@ -219,8 +223,12 @@ function resolveProviderAction( }; } + if (provider.authSource === "env" || provider.authSource === "external") { + return undefined; + } + return { disabled: true, - label: provider.authSource === "env" || provider.authSource === "external" ? "Managed externally" : "Configure externally", + label: "Configure externally", }; } diff --git a/apps/desktop/tests/core/provider-settings.spec.ts b/apps/desktop/tests/core/provider-settings.spec.ts index b2e214e2..b460531c 100644 --- a/apps/desktop/tests/core/provider-settings.spec.ts +++ b/apps/desktop/tests/core/provider-settings.spec.ts @@ -103,7 +103,7 @@ test("settings shows environment-configured providers as managed externally", as has: window.locator(".settings-row__title", { hasText: /^openai$/ }), }); await expect(openAiRow).toContainText("Environment variable"); - await expect(openAiRow.getByRole("button", { name: "Managed externally" })).toBeDisabled(); + await expect(openAiRow.locator(".settings-row__control")).toHaveCount(0); } finally { await harness.close(); if (previousOpenAiKey === undefined) { @@ -162,7 +162,7 @@ test("settings keeps models.json provider overrides in the external-config state has: window.locator(".settings-row__title", { hasText: /^openai$/ }), }); await expect(openAiRow).toContainText("Configured externally"); - await expect(openAiRow.getByRole("button", { name: "Managed externally" })).toBeDisabled(); + await expect(openAiRow.locator(".settings-row__control")).toHaveCount(0); } finally { await harness.close(); } From 5e74b3e5f5cb23c8b20da61b0e9418e4ad019c87 Mon Sep 17 00:00:00 2001 From: Maximilian Fellner Date: Thu, 23 Apr 2026 22:23:45 +0200 Subject: [PATCH 3/4] Cover custom endpoints add/edit/delete in core lane MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Verify the feature end-to-end on the Electron surface: seeded workspace with no existing custom providers, drive Settings → Providers → Custom endpoints → Add, fill the form with a manual model ID, assert the entry renders and ~/.pi/agent/models.json round-trips the expected provider config. Also exercise edit and delete. A second spec covers collision-with-existing-provider validation, invalid base URL rejection, and ESC dismiss. No probe coverage here — the Detect-models flow makes a real HTTP request and belongs in a live/native spec; add later. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../tests/core/custom-endpoints.spec.ts | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 apps/desktop/tests/core/custom-endpoints.spec.ts diff --git a/apps/desktop/tests/core/custom-endpoints.spec.ts b/apps/desktop/tests/core/custom-endpoints.spec.ts new file mode 100644 index 00000000..d9c497fd --- /dev/null +++ b/apps/desktop/tests/core/custom-endpoints.spec.ts @@ -0,0 +1,152 @@ +import { readFile } from "node:fs/promises"; +import { join } from "node:path"; +import { expect, test } from "@playwright/test"; +import { + desktopShortcut, + launchDesktop, + makeUserDataDir, + makeWorkspace, + seedAgentDir, +} from "../helpers/electron-app"; + +async function readModelsJson(agentDir: string): Promise> { + const raw = await readFile(join(agentDir, "models.json"), "utf8"); + return JSON.parse(raw) as Record; +} + +async function openProvidersSettings(window: Awaited>["firstWindow"]>>) { + await window.keyboard.press(desktopShortcut(",")); + await expect(window.getByTestId("settings-surface")).toBeVisible(); + await window.getByRole("button", { name: "Providers", exact: true }).click(); + await expect(window.locator(".view-header__title")).toHaveText("Providers"); +} + +test("settings lets the user add, edit, and delete an OpenAI-compatible custom endpoint", async () => { + test.setTimeout(60_000); + const userDataDir = await makeUserDataDir(); + const agentDir = join(userDataDir, "agent"); + const workspacePath = await makeWorkspace("custom-endpoints-add-workspace"); + await seedAgentDir(agentDir, { enabledModels: [] }); + + const harness = await launchDesktop(userDataDir, { + agentDir, + initialWorkspaces: [workspacePath], + scrubProviderEnv: true, + testMode: "background", + }); + + try { + const window = await harness.firstWindow(); + await openProvidersSettings(window); + + const customEndpoints = window.locator(".settings-section", { + has: window.locator(".settings-section__title", { hasText: "Custom endpoints" }), + }); + await expect(customEndpoints).toContainText("No custom endpoints yet."); + await customEndpoints.getByRole("button", { name: "Add endpoint", exact: true }).click(); + + const dialog = window.getByTestId("custom-endpoint-dialog"); + await expect(dialog).toBeVisible(); + + await dialog.getByLabel("Provider ID").fill("ollama-local"); + await dialog.getByLabel("Base URL").fill("http://localhost:11434/v1"); + await dialog.getByLabel("Add model ID manually").fill("llama3.1"); + await dialog.getByRole("button", { name: "Add", exact: true }).click(); + + await dialog.getByRole("button", { name: "Add endpoint", exact: true }).click(); + await expect(dialog).toHaveCount(0); + + const entryRow = customEndpoints.locator(".settings-row", { + has: window.locator(".settings-row__title", { hasText: /^ollama-local$/ }), + }); + await expect(entryRow).toContainText("http://localhost:11434/v1"); + await expect(entryRow).toContainText("1 model"); + + const savedModels = await readModelsJson(agentDir); + const savedProviders = savedModels.providers as Record>; + expect(savedProviders["ollama-local"]).toMatchObject({ + baseUrl: "http://localhost:11434/v1", + api: "openai-completions", + apiKey: "unused", + models: [{ id: "llama3.1" }], + }); + + // Edit flow: change base URL. + await entryRow.getByRole("button", { name: "Edit", exact: true }).click(); + const editDialog = window.getByTestId("custom-endpoint-dialog"); + await expect(editDialog).toBeVisible(); + await expect(editDialog.getByLabel("Provider ID")).toBeDisabled(); + const baseUrlInput = editDialog.getByLabel("Base URL"); + await baseUrlInput.fill("http://localhost:8000/v1"); + await editDialog.getByRole("button", { name: "Save changes", exact: true }).click(); + await expect(editDialog).toHaveCount(0); + await expect(entryRow).toContainText("http://localhost:8000/v1"); + + const editedModels = await readModelsJson(agentDir); + const editedProviders = editedModels.providers as Record>; + expect(editedProviders["ollama-local"]).toMatchObject({ + baseUrl: "http://localhost:8000/v1", + }); + + // Delete flow. + await entryRow.getByRole("button", { name: "Remove", exact: true }).click(); + await expect(customEndpoints).toContainText("No custom endpoints yet."); + + const afterDelete = await readModelsJson(agentDir); + const afterDeleteProviders = (afterDelete.providers as Record) ?? {}; + expect(afterDeleteProviders["ollama-local"]).toBeUndefined(); + } finally { + await harness.close(); + } +}); + +test("custom endpoint dialog blocks colliding provider IDs and invalid base URLs", async () => { + test.setTimeout(60_000); + const userDataDir = await makeUserDataDir(); + const agentDir = join(userDataDir, "agent"); + const workspacePath = await makeWorkspace("custom-endpoints-validation-workspace"); + await seedAgentDir(agentDir, { enabledModels: [] }); + + const harness = await launchDesktop(userDataDir, { + agentDir, + initialWorkspaces: [workspacePath], + scrubProviderEnv: true, + testMode: "background", + }); + + try { + const window = await harness.firstWindow(); + await openProvidersSettings(window); + + const customEndpoints = window.locator(".settings-section", { + has: window.locator(".settings-section__title", { hasText: "Custom endpoints" }), + }); + await customEndpoints.getByRole("button", { name: "Add endpoint", exact: true }).click(); + + const dialog = window.getByTestId("custom-endpoint-dialog"); + await expect(dialog).toBeVisible(); + + // Collides with the seeded openai provider. + await dialog.getByLabel("Provider ID").fill("openai"); + await expect(dialog).toContainText("already in use"); + const saveButton = dialog.getByRole("button", { name: "Add endpoint", exact: true }); + await expect(saveButton).toBeDisabled(); + + // Switch to a unique ID so ID validation no longer blocks save. + await dialog.getByLabel("Provider ID").fill("my-endpoint"); + await dialog.getByLabel("Base URL").fill("ftp://not-allowed"); + await dialog.getByLabel("Add model ID manually").fill("test-model"); + await dialog.getByRole("button", { name: "Add", exact: true }).click(); + + await saveButton.click(); + await expect(dialog).toContainText("Base URL must start with http:// or https://"); + await expect(dialog).toBeVisible(); + + // ESC closes the dialog without saving. + await dialog.press("Escape"); + await expect(dialog).toHaveCount(0); + await expect(customEndpoints).toContainText("No custom endpoints yet."); + } finally { + await harness.close(); + } +}); From f5730da9145a922c10715e21703549b4f2066d61 Mon Sep 17 00:00:00 2001 From: Matthew Lam Date: Thu, 30 Apr 2026 22:39:28 -0400 Subject: [PATCH 4/4] Harden custom provider ownership --- apps/desktop/electron/app-store.ts | 64 ++++++++- .../tests/core/custom-endpoints.spec.ts | 126 +++++++++++++++++- .../src/custom-provider-store.ts | 23 +++- .../src/custom-provider-types.ts | 27 ++++ .../pi-sdk-driver/src/runtime-supervisor.ts | 10 +- 5 files changed, 238 insertions(+), 12 deletions(-) diff --git a/apps/desktop/electron/app-store.ts b/apps/desktop/electron/app-store.ts index 6b5e03a4..94bf833b 100644 --- a/apps/desktop/electron/app-store.ts +++ b/apps/desktop/electron/app-store.ts @@ -482,7 +482,7 @@ export class DesktopAppStore implements AppStoreInternals { await this.initialize(); const normalizedShell = integratedTerminalShell.trim(); if (this.state.integratedTerminalShell === normalizedShell) { - return this.emit(); + return structuredClone(this.state); } this.state = { ...this.state, @@ -643,12 +643,14 @@ export class DesktopAppStore implements AppStoreInternals { ...(model.contextWindow !== undefined ? { contextWindow: model.contextWindow } : {}), })), }), + { refreshAllWorkspaces: true }, ); } async deleteCustomProvider(workspaceId: string, providerId: string): Promise { return this.withRuntimeUpdate(workspaceId, (ws) => this.driver.runtimeSupervisor.deleteCustomProvider(ws, providerId), + { refreshAllWorkspaces: true }, ); } @@ -743,6 +745,7 @@ export class DesktopAppStore implements AppStoreInternals { action: (ws: WorkspaceRef) => Promise, options?: { readonly reloadSessions?: boolean; + readonly refreshAllWorkspaces?: boolean; }, ): Promise { await this.initialize(); @@ -753,16 +756,71 @@ export class DesktopAppStore implements AppStoreInternals { return this.withErrorHandling(async () => { const snapshot = await action(ws); - this.runtimeByWorkspace.set(workspaceId, snapshot); + if (options?.refreshAllWorkspaces) { + await this.refreshRuntimeForAllWorkspaces(workspaceId, snapshot); + } else { + this.runtimeByWorkspace.set(workspaceId, snapshot); + } if (options?.reloadSessions) { this.clearExtensionUiForWorkspace(workspaceId); await this.reloadSessionsForWorkspace(workspaceId); } - await this.refreshSessionCommandsForWorkspace(workspaceId); + if (options?.refreshAllWorkspaces) { + await this.refreshSessionCommandsForAllWorkspaces(); + } else { + await this.refreshSessionCommandsForWorkspace(workspaceId); + } return this.refreshState({ clearLastError: true }); }); } + private async refreshRuntimeForAllWorkspaces( + updatedWorkspaceId: string, + updatedSnapshot: RuntimeSnapshot, + ): Promise { + this.runtimeByWorkspace.set(updatedWorkspaceId, updatedSnapshot); + const workspacesToRefresh = this.state.workspaces.filter((workspace) => workspace.id !== updatedWorkspaceId); + const snapshots = await Promise.allSettled( + workspacesToRefresh.map(async (workspace) => { + const runtime = await this.driver.runtimeSupervisor.refreshRuntime({ + workspaceId: workspace.id, + path: workspace.path, + displayName: workspace.name, + }); + return [workspace, runtime] as const; + }), + ); + snapshots.forEach((result, index) => { + const workspace = workspacesToRefresh[index]; + if (result.status === "fulfilled") { + this.runtimeByWorkspace.set(result.value[0].id, result.value[1]); + return; + } + console.warn( + `[pi-gui] Failed to refresh runtime for ${workspace?.path ?? "unknown workspace"} after custom provider update: ${ + result.reason instanceof Error ? result.reason.message : String(result.reason) + }`, + ); + }); + } + + private async refreshSessionCommandsForAllWorkspaces(): Promise { + const results = await Promise.allSettled( + this.state.workspaces.map((workspace) => this.refreshSessionCommandsForWorkspace(workspace.id)), + ); + results.forEach((result, index) => { + if (result.status === "fulfilled") { + return; + } + const workspace = this.state.workspaces[index]; + console.warn( + `[pi-gui] Failed to refresh session commands for ${workspace?.path ?? "unknown workspace"} after custom provider update: ${ + result.reason instanceof Error ? result.reason.message : String(result.reason) + }`, + ); + }); + } + /* ── Internal infrastructure (AppStoreInternals) ───────── */ private async initializeInternal(): Promise { diff --git a/apps/desktop/tests/core/custom-endpoints.spec.ts b/apps/desktop/tests/core/custom-endpoints.spec.ts index bdfbd066..688ea981 100644 --- a/apps/desktop/tests/core/custom-endpoints.spec.ts +++ b/apps/desktop/tests/core/custom-endpoints.spec.ts @@ -1,12 +1,15 @@ -import { readFile } from "node:fs/promises"; +import { readFile, writeFile } from "node:fs/promises"; import { join } from "node:path"; import { expect, test } from "@playwright/test"; import { desktopShortcut, + getDesktopState, launchDesktop, makeUserDataDir, makeWorkspace, + type PiAppWindow, seedAgentDir, + waitForWorkspaceByPath, } from "../helpers/electron-app"; async function readModelsJson(agentDir: string): Promise> { @@ -26,17 +29,19 @@ test("settings lets the user add, edit, and delete an OpenAI-compatible custom e const userDataDir = await makeUserDataDir(); const agentDir = join(userDataDir, "agent"); const workspacePath = await makeWorkspace("custom-endpoints-add-workspace"); + const otherWorkspacePath = await makeWorkspace("custom-endpoints-other-workspace"); await seedAgentDir(agentDir, { enabledModels: [] }); const harness = await launchDesktop(userDataDir, { agentDir, - initialWorkspaces: [workspacePath], + initialWorkspaces: [workspacePath, otherWorkspacePath], scrubProviderEnv: true, testMode: "background", }); try { const window = await harness.firstWindow(); + const otherWorkspace = await waitForWorkspaceByPath(window, otherWorkspacePath); await openProvidersSettings(window); const customEndpoints = window.locator(".settings-section", { @@ -71,6 +76,13 @@ test("settings lets the user add, edit, and delete an OpenAI-compatible custom e piGuiCustomEndpoint: true, models: [{ id: "llama3.1" }], }); + await expect.poll(async () => { + const state = await getDesktopState(window); + return ( + state.runtimeByWorkspace[otherWorkspace.id]?.providers.some((provider) => provider.id === "ollama-local") + ?? false + ); + }).toBe(true); // Edit flow: change base URL. await entryRow.getByRole("button", { name: "Edit", exact: true }).click(); @@ -96,6 +108,116 @@ test("settings lets the user add, edit, and delete an OpenAI-compatible custom e const afterDelete = await readModelsJson(agentDir); const afterDeleteProviders = (afterDelete.providers as Record) ?? {}; expect(afterDeleteProviders["ollama-local"]).toBeUndefined(); + await expect.poll(async () => { + const state = await getDesktopState(window); + return ( + state.runtimeByWorkspace[otherWorkspace.id]?.providers.some((provider) => provider.id === "ollama-local") + ?? false + ); + }).toBe(false); + } finally { + await harness.close(); + } +}); + +test("custom endpoints keep legacy managed entries separate from built-in overrides", async () => { + test.setTimeout(60_000); + const userDataDir = await makeUserDataDir(); + const agentDir = join(userDataDir, "agent"); + const workspacePath = await makeWorkspace("custom-endpoints-ownership-workspace"); + await seedAgentDir(agentDir, { enabledModels: [] }); + await writeFile( + join(agentDir, "models.json"), + `${JSON.stringify( + { + providers: { + openai: { + baseUrl: "https://proxy.example.test/v1", + api: "openai-completions", + apiKey: "test-openai-key", + models: [{ id: "proxy-model" }], + }, + deepseek: { + baseUrl: "https://deepseek-proxy.example.test/v1", + api: "openai-completions", + apiKey: "test-deepseek-key", + models: [{ id: "deepseek-chat" }], + }, + "legacy-local": { + baseUrl: "http://localhost:11434/v1", + api: "openai-completions", + apiKey: "unused", + models: [{ id: "llama3.1" }], + }, + }, + }, + null, + 2, + )}\n`, + "utf8", + ); + + const harness = await launchDesktop(userDataDir, { + agentDir, + initialWorkspaces: [workspacePath], + scrubProviderEnv: true, + testMode: "background", + }); + + try { + const window = await harness.firstWindow(); + const workspace = await waitForWorkspaceByPath(window, workspacePath); + await openProvidersSettings(window); + + const customEndpoints = window.locator(".settings-section", { + has: window.locator(".settings-section__title", { hasText: "Custom endpoints" }), + }); + await expect(customEndpoints).toContainText("legacy-local"); + await expect( + customEndpoints.locator(".settings-row", { + has: window.locator(".settings-row__title", { hasText: /^openai$/ }), + }), + ).toHaveCount(0); + await expect( + customEndpoints.locator(".settings-row", { + has: window.locator(".settings-row__title", { hasText: /^deepseek$/ }), + }), + ).toHaveCount(0); + + const blockedState = await window.evaluate(async ({ workspaceId }) => { + const app = (window as PiAppWindow).piApp; + if (!app) { + throw new Error("piApp IPC bridge is unavailable"); + } + return app.setCustomProvider(workspaceId, { + providerId: "openai", + baseUrl: "http://localhost:11434/v1", + models: [{ id: "should-not-save" }], + }); + }, { workspaceId: workspace.id }); + expect(blockedState.lastError).toContain("conflicts with a built-in provider"); + + await window.evaluate(async ({ workspaceId }) => { + const app = (window as PiAppWindow).piApp; + if (!app) { + throw new Error("piApp IPC bridge is unavailable"); + } + await app.deleteCustomProvider(workspaceId, "openai"); + }, { workspaceId: workspace.id }); + const afterBlockedDelete = await readModelsJson(agentDir); + expect((afterBlockedDelete.providers as Record).openai).toBeDefined(); + expect((afterBlockedDelete.providers as Record).deepseek).toBeDefined(); + + const legacyRow = customEndpoints.locator(".settings-row", { + has: window.locator(".settings-row__title", { hasText: /^legacy-local$/ }), + }); + await legacyRow.getByRole("button", { name: "Remove", exact: true }).click(); + await expect(customEndpoints).toContainText("No custom endpoints yet."); + + const afterLegacyDelete = await readModelsJson(agentDir); + const providers = afterLegacyDelete.providers as Record; + expect(providers.openai).toBeDefined(); + expect(providers["legacy-local"]).toBeUndefined(); } finally { await harness.close(); } diff --git a/packages/pi-sdk-driver/src/custom-provider-store.ts b/packages/pi-sdk-driver/src/custom-provider-store.ts index e4cf8f52..3bbb03c9 100644 --- a/packages/pi-sdk-driver/src/custom-provider-store.ts +++ b/packages/pi-sdk-driver/src/custom-provider-store.ts @@ -1,6 +1,7 @@ import { mkdir, readFile, rename, writeFile } from "node:fs/promises"; import { dirname } from "node:path"; import { + BUILT_IN_PROVIDER_IDS, CUSTOM_PROVIDER_ID_PATTERN, CUSTOM_PROVIDER_PLACEHOLDER_API_KEY, isValidHttpBaseUrl, @@ -13,6 +14,7 @@ import { export type { CustomProviderEntry, CustomProviderInput, CustomProviderModelInput } from "./custom-provider-types.js"; export { + BUILT_IN_PROVIDER_IDS, CUSTOM_PROVIDER_ID_PATTERN, CUSTOM_PROVIDER_PLACEHOLDER_API_KEY, isValidHttpBaseUrl, @@ -38,7 +40,7 @@ export class CustomProviderStore { const data = await readModelsJson(this.modelsJsonPath); const providers = ensureProvidersRecord(data); const existing = providers[input.providerId]; - if (existing && typeof existing === "object" && !isPiGuiCustomProviderConfig(existing as Record)) { + if (existing && typeof existing === "object" && !isPiGuiCustomProviderConfig(input.providerId, existing as Record)) { throw new Error( `Provider ID "${input.providerId}" already exists in models.json and is not managed by pi-gui.`, ); @@ -56,7 +58,7 @@ export class CustomProviderStore { return false; } const existing = (providers as Record)[providerId]; - if (!existing || typeof existing !== "object" || !isPiGuiCustomProviderConfig(existing as Record)) { + if (!existing || typeof existing !== "object" || !isPiGuiCustomProviderConfig(providerId, existing as Record)) { return false; } delete (providers as Record)[providerId]; @@ -122,7 +124,7 @@ function readCustomProviders(data: Record): readonly CustomProv continue; } const config = rawConfig as Record; - if (!isPiGuiCustomProviderConfig(config)) { + if (!isPiGuiCustomProviderConfig(providerId, config)) { continue; } const baseUrl = typeof config.baseUrl === "string" ? config.baseUrl : undefined; @@ -160,8 +162,19 @@ function readCustomProviders(data: Record): readonly CustomProv return entries; } -function isPiGuiCustomProviderConfig(config: Record): boolean { - return config[PI_GUI_CUSTOM_PROVIDER_MARKER] === true; +function isPiGuiCustomProviderConfig(providerId: string, config: Record): boolean { + if (config[PI_GUI_CUSTOM_PROVIDER_MARKER] === true) { + return true; + } + if (BUILT_IN_PROVIDER_IDS.has(providerId)) { + return false; + } + return ( + config.api === OPENAI_COMPLETIONS_API && + typeof config.baseUrl === "string" && + Array.isArray(config.models) && + config.models.length > 0 + ); } function ensureProvidersRecord(data: Record): Record { diff --git a/packages/pi-sdk-driver/src/custom-provider-types.ts b/packages/pi-sdk-driver/src/custom-provider-types.ts index bc6d8695..cd8d8c28 100644 --- a/packages/pi-sdk-driver/src/custom-provider-types.ts +++ b/packages/pi-sdk-driver/src/custom-provider-types.ts @@ -22,6 +22,33 @@ export const OPENAI_COMPLETIONS_API = "openai-completions"; export const CUSTOM_PROVIDER_PLACEHOLDER_API_KEY = "unused"; export const PI_GUI_CUSTOM_PROVIDER_MARKER = "piGuiCustomEndpoint"; +export const BUILT_IN_PROVIDER_IDS: ReadonlySet = new Set([ + "amazon-bedrock", + "anthropic", + "azure-openai-responses", + "cerebras", + "deepseek", + "fireworks", + "github-copilot", + "google", + "google-antigravity", + "google-gemini-cli", + "google-vertex", + "groq", + "huggingface", + "kimi-coding", + "minimax", + "minimax-cn", + "mistral", + "openai", + "openai-codex", + "opencode", + "opencode-go", + "openrouter", + "vercel-ai-gateway", + "xai", + "zai", +]); export function isValidHttpBaseUrl(value: string): boolean { try { diff --git a/packages/pi-sdk-driver/src/runtime-supervisor.ts b/packages/pi-sdk-driver/src/runtime-supervisor.ts index b5262a46..0a09dd60 100644 --- a/packages/pi-sdk-driver/src/runtime-supervisor.ts +++ b/packages/pi-sdk-driver/src/runtime-supervisor.ts @@ -28,9 +28,15 @@ import { createRuntimeDependencies } from "./runtime-deps.js"; import { createSettingsManagerWithoutNpmPackages, isGlobalNpmLookupError } from "./npm-package-fallback.js"; import { skillSlashCommand } from "./runtime-command-utils.js"; import type { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent"; -import { CustomProviderStore, type CustomProviderEntry, type CustomProviderInput } from "./custom-provider-store.js"; +import { + BUILT_IN_PROVIDER_IDS, + CustomProviderStore, + type CustomProviderEntry, + type CustomProviderInput, +} from "./custom-provider-store.js"; export { + BUILT_IN_PROVIDER_IDS, CUSTOM_PROVIDER_ID_PATTERN, isValidHttpBaseUrl, OPENAI_COMPLETIONS_API, @@ -135,7 +141,7 @@ export class RuntimeSupervisor implements RuntimeResourceDriver { async setCustomProvider(workspace: WorkspaceRef, input: CustomProviderInput): Promise { const oauthProviderIds = new Set(this.authStorage.getOAuthProviders().map((provider) => provider.id)); - if (providerSupportsDesktopApiKeySetup(input.providerId) || oauthProviderIds.has(input.providerId)) { + if (BUILT_IN_PROVIDER_IDS.has(input.providerId) || oauthProviderIds.has(input.providerId)) { throw new Error( `Provider ID "${input.providerId}" conflicts with a built-in provider. Pick a unique ID.`, );