diff --git a/AGENTS.md b/AGENTS.md index f25de08542..e3fe3d3cdd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -64,6 +64,11 @@ internal/ ### Key Patterns - **Config is a Service**: accessed via `config.Service`, not global state. +- **Provider updates**: Catwalk is the default provider seed/cache. Venice and + Copilot models are live-overlaid only when authenticated and use the existing + provider cache/syncer with a 60s warm-cache TTL; `crush update-providers` + refreshes authenticated live providers best-effort, while `--source=venice` + and `--source=copilot` refresh them directly. - **Tools are self-documenting**: each tool has a `.go` implementation and a `.md` description file in `internal/agent/tools/`. - **System prompts are Go templates**: `internal/agent/templates/*.md.tpl` diff --git a/README.md b/README.md index f3ddaf7afa..2ab556614d 100644 --- a/README.md +++ b/README.md @@ -859,6 +859,7 @@ command: ```bash # Update providers remotely from Catwalk. +# This also refreshes Venice and Copilot models when you're authenticated. crush update-providers # Update providers from a custom Catwalk base URL. 
diff --git a/internal/app/app.go b/internal/app/app.go index d8a3abc63b..a5e6666cde 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -504,6 +504,7 @@ func (app *App) setupEvents() { setupSubscriberMustDeliver(ctx, app.serviceEventsWG, "run-completions", app.runCompletions.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events) setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events) + setupSubscriber(ctx, app.serviceEventsWG, "providers", config.SubscribeProviderEvents, app.events) if app.Skills != nil { setupSubscriber(ctx, app.serviceEventsWG, "skills", app.Skills.SubscribeEvents, app.events) } diff --git a/internal/cmd/update_providers.go b/internal/cmd/update_providers.go index 3b4b35b681..ae9c017c45 100644 --- a/internal/cmd/update_providers.go +++ b/internal/cmd/update_providers.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "log/slog" + "os" "charm.land/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" @@ -17,7 +18,7 @@ var updateProvidersCmd = &cobra.Command{ Short: "Update providers", Long: `Update provider information from a specified local path or remote URL.`, Example: ` -# Update Catwalk providers remotely (default) +# Update Catwalk providers remotely (default), plus authenticated live providers crush update-providers # Update Catwalk providers from a custom URL @@ -34,6 +35,12 @@ crush update-providers --source=hyper # Update Hyper from a custom URL crush update-providers --source=hyper https://hyper.example.com + +# Update Venice provider models +crush update-providers --source=venice + +# Update Copilot provider models +crush update-providers --source=copilot `, RunE: func(cmd *cobra.Command, args []string) error { // NOTE(@andreynering): We want to skip logging output do stdout here. 
@@ -48,10 +55,23 @@ crush update-providers --source=hyper https://hyper.example.com switch updateProvidersSource { case "catwalk": err = config.UpdateProviders(pathOrURL) + if err == nil && pathOrURL == "" { + updateAuthenticatedLiveProviders(cmd) + } case "hyper": err = config.UpdateHyper(pathOrURL) + case "venice", "copilot": + cfg, loadErr := loadUpdateProvidersConfig(cmd) + if loadErr != nil { + return loadErr + } + if updateProvidersSource == "venice" { + err = config.UpdateVenice(pathOrURL, cfg) + } else { + err = config.UpdateCopilot(pathOrURL, cfg) + } default: - return fmt.Errorf("invalid source %q, must be 'catwalk' or 'hyper'", updateProvidersSource) + return fmt.Errorf("invalid source %q, must be 'catwalk', 'hyper', 'venice', or 'copilot'", updateProvidersSource) } if err != nil { @@ -77,6 +97,49 @@ crush update-providers --source=hyper https://hyper.example.com }, } +func updateAuthenticatedLiveProviders(cmd *cobra.Command) { + cfg, err := loadUpdateProvidersConfig(cmd) + if err != nil { + fmt.Fprintf(os.Stderr, "Note: skipping Venice and Copilot updates: %v\n", err) + return + } + if err := config.UpdateVenice("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) { + fmt.Fprintf(os.Stderr, "Note: skipping Venice update: %v\n", err) + } + if err := config.UpdateCopilot("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) { + fmt.Fprintf(os.Stderr, "Note: skipping Copilot update: %v\n", err) + } +} + +func loadUpdateProvidersConfig(cmd *cobra.Command) (*config.Config, error) { + cwd, err := ResolveCwd(cmd) + if err != nil { + return nil, err + } + dataDir, _ := cmd.Flags().GetString("data-dir") + debug, _ := cmd.Flags().GetBool("debug") + + // The config is loaded only to resolve provider credentials; + // updateLiveProvider fetches explicitly afterwards. Disable provider + // auto-update during Load so it doesn't fetch the same endpoints + // first (avoiding a double fetch per provider). 
+ previous, hadPrevious := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE") + _ = os.Setenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE", "1") + defer func() { + if hadPrevious { + _ = os.Setenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE", previous) + } else { + _ = os.Unsetenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE") + } + }() + + store, err := config.Load(cwd, dataDir, debug) + if err != nil { + return nil, err + } + return store.Config(), nil +} + func init() { - updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk or hyper)") + updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk, hyper, venice, or copilot)") } diff --git a/internal/cmd/update_providers_test.go b/internal/cmd/update_providers_test.go new file mode 100644 index 0000000000..d1d8d6dfd1 --- /dev/null +++ b/internal/cmd/update_providers_test.go @@ -0,0 +1,27 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpdateProvidersCmd_SourceFlagIncludesLiveProviders(t *testing.T) { + t.Parallel() + + flag := updateProvidersCmd.Flags().Lookup("source") + require.NotNil(t, flag) + require.Equal(t, "catwalk", flag.DefValue) + require.Contains(t, flag.Usage, "venice") + require.Contains(t, flag.Usage, "copilot") +} + +func TestUpdateProvidersCmd_ExamplesIncludeLiveProviders(t *testing.T) { + t.Parallel() + + example := strings.ToLower(updateProvidersCmd.Example) + require.Contains(t, example, "--source=venice") + require.Contains(t, example, "--source=copilot") + require.Contains(t, example, "authenticated live providers") +} diff --git a/internal/config/copilot_models.go b/internal/config/copilot_models.go new file mode 100644 index 0000000000..5416d4889c --- /dev/null +++ b/internal/config/copilot_models.go @@ -0,0 +1,176 @@ +package config + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "slices" + "strings" + "time" + + 
"charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" +) + +var _ liveProviderClient = (*realCopilotModelsClient)(nil) + +type copilotTokenRefresher func(context.Context, string) (*oauth.Token, error) + +type realCopilotModelsClient struct { + baseURL string + apiKey string + oauthToken *oauth.Token + refreshToken copilotTokenRefresher +} + +func (r *realCopilotModelsClient) Get(ctx context.Context) (catwalk.Provider, error) { + var result catwalk.Provider + baseURL := strings.TrimRight(r.baseURL, "/") + accessToken, err := r.accessToken(ctx) + if err != nil { + return result, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) + if err != nil { + return result, fmt.Errorf("could not create request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + for key, value := range copilot.Headers() { + req.Header.Set(key, value) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var models copilotModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&models); err != nil { + return result, fmt.Errorf("failed to decode response: %w", err) + } + + result = catwalk.Provider{ + ID: catwalk.InferenceProviderCopilot, + APIEndpoint: baseURL, + Type: catwalk.TypeOpenAICompat, + Models: copilotModelsToCatwalkModels(models), + } + return result, nil +} + +func (r *realCopilotModelsClient) accessToken(ctx context.Context) (string, error) { + if r.oauthToken != nil { + if token := strings.TrimSpace(r.oauthToken.AccessToken); token != "" && !r.oauthToken.IsExpired() { + r.apiKey = token + return 
token, nil + } + + refreshToken := strings.TrimSpace(r.oauthToken.RefreshToken) + if refreshToken != "" { + refresh := r.refreshToken + if refresh == nil { + refresh = copilot.RefreshToken + } + refreshedToken, err := refresh(ctx, refreshToken) + if err != nil { + return "", fmt.Errorf("failed to refresh Copilot token: %w", err) + } + r.oauthToken = refreshedToken + r.apiKey = strings.TrimSpace(refreshedToken.AccessToken) + if r.apiKey != "" { + return r.apiKey, nil + } + } + } + + if token := strings.TrimSpace(r.apiKey); token != "" { + return token, nil + } + return "", fmt.Errorf("missing Copilot access token") +} + +type copilotModelsResponse struct { + Data []copilotModel `json:"data"` +} + +type copilotModel struct { + ID string `json:"id"` + Name string `json:"name"` + Version string `json:"version"` + Capabilities copilotModelCapabilities `json:"capabilities"` +} + +type copilotModelCapabilities struct { + Type string `json:"type"` + Limits copilotModelLimits `json:"limits"` + Supports copilotModelSupports `json:"supports"` +} + +type copilotModelLimits struct { + MaxContextWindowTokens int64 `json:"max_context_window_tokens"` + MaxOutputTokens int64 `json:"max_output_tokens"` +} + +type copilotModelSupports struct { + Vision bool `json:"vision"` + ReasoningEffort []string `json:"reasoning_effort"` + AdaptiveThinking bool `json:"adaptive_thinking"` +} + +func copilotModelsToCatwalkModels(response copilotModelsResponse) []catwalk.Model { + aliasedVersions := make(map[string]bool, len(response.Data)) + for _, model := range response.Data { + if model.Version != "" && model.ID != model.Version { + aliasedVersions[model.Version] = true + } + } + + seen := make(map[string]bool, len(response.Data)) + models := make([]catwalk.Model, 0, len(response.Data)) + for _, model := range response.Data { + if shouldSkipCopilotModel(model, aliasedVersions) || seen[model.ID] { + continue + } + seen[model.ID] = true + reasoningLevels := 
model.Capabilities.Supports.ReasoningEffort + models = append(models, catwalk.Model{ + ID: model.ID, + Name: model.Name, + ContextWindow: model.Capabilities.Limits.MaxContextWindowTokens, + DefaultMaxTokens: model.Capabilities.Limits.MaxOutputTokens, + CanReason: model.Capabilities.Supports.AdaptiveThinking || len(reasoningLevels) > 0, + ReasoningLevels: reasoningLevels, + SupportsImages: model.Capabilities.Supports.Vision, + }) + } + + slices.SortStableFunc(models, func(a, b catwalk.Model) int { + return strings.Compare(a.ID, b.ID) + }) + return models +} + +func shouldSkipCopilotModel(model copilotModel, aliasedVersions map[string]bool) bool { + if capType := model.Capabilities.Type; capType != "" && capType != "chat" { + return true + } + return model.ID == "" || + aliasedVersions[model.ID] || + strings.Contains(model.ID, "embedding") || + strings.HasPrefix(model.ID, "accounts/msft/routers") || + strings.HasPrefix(model.ID, "oswe-vscode") || + strings.HasPrefix(model.ID, "lark") || + strings.HasPrefix(model.ID, "mai-code") || + model.ID == "gpt-4-o-preview" || + model.ID == "trajectory-compaction" +} diff --git a/internal/config/copilot_models_test.go b/internal/config/copilot_models_test.go new file mode 100644 index 0000000000..2837290302 --- /dev/null +++ b/internal/config/copilot_models_test.go @@ -0,0 +1,195 @@ +package config + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/stretchr/testify/require" +) + +func TestRealCopilotModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + require.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + require.Equal(t, "application/json", r.Header.Get("Accept")) + for 
key, value := range copilot.Headers() { + require.Equal(t, value, r.Header.Get(key)) + } + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{ + "data": [ + { + "id": "claude-sonnet-4.6", + "name": "Claude Sonnet 4.6", + "version": "claude-sonnet-4.6", + "capabilities": { + "limits": { + "max_context_window_tokens": 264000, + "max_output_tokens": 64000 + }, + "supports": { + "vision": true, + "reasoning_effort": ["low", "medium", "high"] + } + } + } + ] + }`)) + require.NoError(t, err) + })) + defer server.Close() + + client := &realCopilotModelsClient{baseURL: server.URL + "/", apiKey: " test-token "} + provider, err := client.Get(t.Context()) + + require.NoError(t, err) + require.Equal(t, catwalk.InferenceProviderCopilot, provider.ID) + require.Equal(t, server.URL, provider.APIEndpoint) + require.Equal(t, catwalk.TypeOpenAICompat, provider.Type) + require.Equal(t, []catwalk.Model{ + { + ID: "claude-sonnet-4.6", + Name: "Claude Sonnet 4.6", + ContextWindow: 264000, + DefaultMaxTokens: 64000, + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + SupportsImages: true, + }, + }, provider.Models) +} + +func TestRealCopilotModelsClientGetRefreshesExpiredOAuthToken(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Bearer refreshed-token", r.Header.Get("Authorization")) + _, err := w.Write([]byte(`{ + "data": [ + { + "id": "gpt-5.1", + "name": "GPT 5.1", + "capabilities": { + "limits": { + "max_context_window_tokens": 128000, + "max_output_tokens": 16000 + } + } + } + ] + }`)) + require.NoError(t, err) + })) + defer server.Close() + + client := &realCopilotModelsClient{ + baseURL: server.URL, + oauthToken: &oauth.Token{ + AccessToken: "expired-token", + RefreshToken: "github-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(-time.Hour).Unix(), + }, + refreshToken: func(_ctx context.Context, githubToken string) 
(*oauth.Token, error) { + require.Equal(t, "github-token", githubToken) + return &oauth.Token{ + AccessToken: "refreshed-token", + RefreshToken: githubToken, + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(time.Hour).Unix(), + }, nil + }, + } + + provider, err := client.Get(t.Context()) + + require.NoError(t, err) + require.Equal(t, []catwalk.Model{{ID: "gpt-5.1", Name: "GPT 5.1", ContextWindow: 128000, DefaultMaxTokens: 16000}}, provider.Models) + require.Equal(t, "refreshed-token", client.apiKey) +} + +func TestCopilotModelsToCatwalkModelsFiltersAndDeduplicates(t *testing.T) { + t.Parallel() + + models := copilotModelsToCatwalkModels(copilotModelsResponse{Data: []copilotModel{ + { + ID: "gpt-5.1", + Name: "GPT 5.1", + Version: "gpt-5.1", + Capabilities: copilotModelCapabilities{ + Limits: copilotModelLimits{MaxContextWindowTokens: 128000, MaxOutputTokens: 16000}, + Supports: copilotModelSupports{ + Vision: true, + ReasoningEffort: []string{"low", "medium", "high"}, + AdaptiveThinking: true, + }, + }, + }, + { + ID: "gpt-5.1", + Name: "GPT 5.1 Duplicate", + Version: "gpt-5.1", + }, + { + ID: "aliased-model", + Name: "Aliased Model", + Version: "versioned-model-2026-01-01", + }, + { + ID: "versioned-model-2026-01-01", + Name: "Versioned Model", + Version: "versioned-model-2026-01-01", + }, + { + ID: "unaliased-model-2026-01-01", + Name: "Unaliased Dated Model", + Version: "unaliased-model-2026-01-01", + Capabilities: copilotModelCapabilities{ + Type: "chat", + }, + }, + { + ID: "some-embeddings-model", + Name: "Embeddings via Type", + Capabilities: copilotModelCapabilities{ + Type: "embeddings", + }, + }, + {ID: "text-embedding-3", Name: "Embedding"}, + {ID: "accounts/msft/routers/test", Name: "Router"}, + {ID: "oswe-vscode-test", Name: "OSWE"}, + {ID: "lark-test", Name: "Lark"}, + {ID: "mai-code-test", Name: "MAI"}, + {ID: "gpt-4-o-preview", Name: "Preview"}, + {ID: "trajectory-compaction", Name: "Compaction"}, + }}) + + require.Equal(t, []catwalk.Model{ + { + ID: 
"aliased-model", + Name: "Aliased Model", + }, + { + ID: "gpt-5.1", + Name: "GPT 5.1", + ContextWindow: 128000, + DefaultMaxTokens: 16000, + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + SupportsImages: true, + }, + { + ID: "unaliased-model-2026-01-01", + Name: "Unaliased Dated Model", + }, + }, models) +} diff --git a/internal/config/live_provider.go b/internal/config/live_provider.go new file mode 100644 index 0000000000..e04562a742 --- /dev/null +++ b/internal/config/live_provider.go @@ -0,0 +1,188 @@ +package config + +import ( + "context" + "errors" + "log/slog" + "os" + "slices" + "sync" + "sync/atomic" + "time" + + "charm.land/catwalk/pkg/catwalk" +) + +const ( + liveModelsTTL = time.Minute + liveProviderFetchTimeout = 45 * time.Second +) + +type liveProviderClient interface { + Get(context.Context) (catwalk.Provider, error) +} + +var _ syncer[catwalk.Provider] = (*liveProviderSync)(nil) + +type liveProviderSync struct { + once sync.Once + resultMu sync.RWMutex + result catwalk.Provider + cache cache[catwalk.Provider] + client liveProviderClient + onRefresh func(catwalk.Provider) + seed catwalk.Provider + autoupdate bool + credentialed bool + ttl time.Duration + init atomic.Bool +} + +func (s *liveProviderSync) Init(client liveProviderClient, path string, autoupdate bool, seed catwalk.Provider, credentialed bool) { + s.client = client + s.cache = newCache[catwalk.Provider](path) + s.autoupdate = autoupdate + s.seed = seed + s.credentialed = credentialed + s.ttl = liveModelsTTL + s.init.Store(true) +} + +func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { + if !s.init.Load() { + panic("called Get before Init") + } + + s.once.Do(func() { + if !s.autoupdate { + slog.Info("Using provider seed", "provider", s.seed.ID) + s.setResult(s.seed) + return + } + if !s.credentialed { + slog.Info("Skipping live provider sync without credentials", "provider", s.seed.ID) + s.setResult(s.seed) + return + } + + cached, _, 
cachedErr := s.cache.Get() + cachedAvailable := cachedErr == nil && len(cached.Models) > 0 + fallback := s.seed + if cachedAvailable { + fallback = cached + } + + if cachedAvailable { + if age, ok := cacheAge(s.cache.path); ok && age < s.ttl { + slog.Info("Using cached live provider models", "provider", fallback.ID, "age", age) + s.setResult(cached) + return + } + } + + s.setResult(fallback) + s.refreshInBackground() + }) + return s.getResult(), nil +} + +func (s *liveProviderSync) refreshInBackground() { + slog.Info("Refreshing live provider models in background", "provider", s.seed.ID) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), liveProviderFetchTimeout) + defer cancel() + + result, err := s.client.Get(ctx) + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("Live provider models not updated in time", "provider", s.seed.ID) + return + } + if err != nil { + slog.Warn("Live provider models not updated", "provider", s.seed.ID, "err", err) + return + } + if len(result.Models) == 0 { + slog.Warn("Live provider did not return any models", "provider", s.seed.ID) + return + } + + merged := mergeLiveProvider(s.seed, result) + s.setResult(merged) + if err := s.cache.Store(merged); err != nil { + slog.Warn("Failed to store live provider cache", "provider", s.seed.ID, "err", err) + return + } + if s.onRefresh != nil { + s.onRefresh(merged) + } + }() +} + +func (s *liveProviderSync) setResult(provider catwalk.Provider) { + s.resultMu.Lock() + defer s.resultMu.Unlock() + s.result = provider +} + +func (s *liveProviderSync) getResult() catwalk.Provider { + s.resultMu.RLock() + defer s.resultMu.RUnlock() + return s.result +} + +func cacheAge(path string) (time.Duration, bool) { + info, err := os.Stat(path) + if err != nil || info.IsDir() { + return 0, false + } + + age := max(time.Since(info.ModTime()), 0) + return age, true +} + +func mergeLiveProvider(seed, live catwalk.Provider) catwalk.Provider { + merged := seed + if live.ID != "" { + 
merged.ID = live.ID + } + if live.Name != "" { + merged.Name = live.Name + } + if live.APIKey != "" { + merged.APIKey = live.APIKey + } + if live.APIEndpoint != "" { + merged.APIEndpoint = live.APIEndpoint + } + if live.Type != "" { + merged.Type = live.Type + } + if live.DefaultLargeModelID != "" { + merged.DefaultLargeModelID = live.DefaultLargeModelID + } + if live.DefaultSmallModelID != "" { + merged.DefaultSmallModelID = live.DefaultSmallModelID + } + if len(live.DefaultHeaders) > 0 { + merged.DefaultHeaders = live.DefaultHeaders + } + merged.Models = live.Models + + if len(merged.Models) == 0 { + return merged + } + + if merged.DefaultLargeModelID == "" || !modelExists(merged.Models, merged.DefaultLargeModelID) { + merged.DefaultLargeModelID = merged.Models[0].ID + } + if merged.DefaultSmallModelID == "" || !modelExists(merged.Models, merged.DefaultSmallModelID) { + merged.DefaultSmallModelID = merged.Models[0].ID + } + return merged +} + +func modelExists(models []catwalk.Model, id string) bool { + return slices.ContainsFunc(models, func(model catwalk.Model) bool { + return model.ID == id + }) +} diff --git a/internal/config/live_provider_test.go b/internal/config/live_provider_test.go new file mode 100644 index 0000000000..a95a6f5ead --- /dev/null +++ b/internal/config/live_provider_test.go @@ -0,0 +1,424 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "os" + "sync" + "testing" + "time" + + "charm.land/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type mockLiveProviderClient struct { + provider catwalk.Provider + err error + started chan struct{} + release chan struct{} + + mu sync.Mutex + callCount int + startOnce sync.Once +} + +func (m *mockLiveProviderClient) Get(ctx context.Context) (catwalk.Provider, error) { + m.mu.Lock() + m.callCount++ + m.mu.Unlock() + + if m.started != nil { + m.startOnce.Do(func() { close(m.started) }) + } + if m.release != nil { + select { + case <-m.release: + case <-ctx.Done(): + 
return catwalk.Provider{}, ctx.Err() + } + } + return m.provider, m.err +} + +func (m *mockLiveProviderClient) calls() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.callCount +} + +func TestLiveProviderSync_GetPanicIfNotInit(t *testing.T) { + t.Parallel() + + syncer := &liveProviderSync{} + require.Panics(t, func() { + _, _ = syncer.Get(t.Context()) + }) +} + +func TestLiveProviderSync_GetAutoUpdateDisabledReturnsSeed(t *testing.T) { + t.Parallel() + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", false, seed, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Equal(t, 0, client.calls()) +} + +func TestLiveProviderSync_GetWithoutCredentialsReturnsSeed(t *testing.T) { + t.Parallel() + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, seed, false) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Equal(t, 0, client.calls()) +} + +func TestLiveProviderSync_GetWarmCacheReturnsCachedWithoutFetch(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, 
provider) + require.Equal(t, 0, client.calls()) +} + +func TestLiveProviderSync_GetStaleCacheReturnsCachedAndRefreshesInBackground(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, staleTime)) + + seed := testLiveSeedProvider() + started := make(chan struct{}) + release := make(chan struct{}) + client := &mockLiveProviderClient{ + provider: catwalk.Provider{ + Models: []catwalk.Model{ + {ID: "live-model", Name: "Live Model"}, + {ID: "other-live-model", Name: "Other Live Model"}, + }, + }, + started: started, + release: release, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, seed, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + <-started + require.Equal(t, 1, client.calls()) + + close(release) + require.Eventually(t, func() bool { + stored, _, err := newCache[catwalk.Provider](path).Get() + return err == nil && len(stored.Models) == 2 && stored.Models[0].ID == "live-model" + }, time.Second, 10*time.Millisecond) + + provider, err = syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed.ID, provider.ID) + require.Equal(t, seed.Name, provider.Name) + require.Equal(t, seed.APIEndpoint, provider.APIEndpoint) + require.Equal(t, seed.Type, provider.Type) + require.Equal(t, "live-model", provider.DefaultLargeModelID) + require.Equal(t, "live-model", provider.DefaultSmallModelID) + require.Equal(t, []catwalk.Model{ + {ID: "live-model", Name: "Live Model"}, + {ID: "other-live-model", Name: "Other Live Model"}, + }, provider.Models) +} + +func TestLiveProviderSync_GetBackgroundRefreshInvokesCallbackAfterStore(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + seed := 
testLiveSeedProvider() + started := make(chan struct{}) + release := make(chan struct{}) + callbackCh := make(chan catwalk.Provider, 1) + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + started: started, + release: release, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, seed, true) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + <-started + close(release) + + var callbackProvider catwalk.Provider + require.Eventually(t, func() bool { + select { + case callbackProvider = <-callbackCh: + return true + default: + return false + } + }, time.Second, 10*time.Millisecond) + require.Equal(t, seed.ID, callbackProvider.ID) + require.Equal(t, []catwalk.Model{{ID: "live-model", Name: "Live Model"}}, callbackProvider.Models) + require.Empty(t, callbackCh) +} + +func TestLiveProviderSync_GetWarmCacheDoesNotInvokeCallback(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + callbackCh := make(chan catwalk.Provider, 1) + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Equal(t, 0, client.calls()) + require.Empty(t, callbackCh) +} + +func TestLiveProviderSync_GetWithoutCredentialsDoesNotInvokeCallback(t *testing.T) { + t.Parallel() + + callbackCh := make(chan catwalk.Provider, 1) + seed := 
testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, seed, false) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Equal(t, 0, client.calls()) + require.Empty(t, callbackCh) +} + +func TestLiveProviderSync_GetStoreFailureStillUsesMergedResult(t *testing.T) { + t.Parallel() + + blocker := t.TempDir() + "/blocker" + require.NoError(t, os.WriteFile(blocker, []byte("not a dir"), 0o644)) + path := blocker + "/provider.json" + + seed := testLiveSeedProvider() + started := make(chan struct{}) + callbackCh := make(chan catwalk.Provider, 1) + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + started: started, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, seed, true) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + <-started + require.Eventually(t, func() bool { + provider, err := syncer.Get(t.Context()) + return err == nil && len(provider.Models) == 1 && provider.Models[0].ID == "live-model" + }, time.Second, 10*time.Millisecond) + require.Equal(t, 1, client.calls()) + require.Empty(t, callbackCh) +} + +func TestLiveProviderSync_GetFetchErrorUsesCached(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, 
staleTime)) + + client := &mockLiveProviderClient{err: errors.New("network error")} + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) +} + +func TestLiveProviderSync_GetDeadlineExceededUsesCached(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, staleTime)) + + client := &mockLiveProviderClient{err: context.DeadlineExceeded} + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) +} + +func TestLiveProviderSync_GetFetchErrorUsesSeedWithoutCache(t *testing.T) { + t.Parallel() + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{err: errors.New("network error")} + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, seed, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) +} + +func TestLiveProviderSync_GetEmptyModelsUsesFallback(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + 
writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, staleTime)) + + client := &mockLiveProviderClient{provider: catwalk.Provider{ID: "test-live"}} + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) +} + +func TestLiveProviderSync_GetCalledMultipleTimesSchedulesOneRefresh(t *testing.T) { + t.Parallel() + + started := make(chan struct{}) + release := make(chan struct{}) + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + started: started, + release: release, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, testLiveSeedProvider(), true) + + provider1, err1 := syncer.Get(t.Context()) + require.NoError(t, err1) + provider2, err2 := syncer.Get(t.Context()) + require.NoError(t, err2) + require.Equal(t, provider1, provider2) + <-started + require.Equal(t, 1, client.calls()) + close(release) +} + +func TestCacheAge(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + _, ok := cacheAge(path) + require.False(t, ok) + + writeLiveProviderCache(t, path, testLiveSeedProvider()) + age, ok := cacheAge(path) + require.True(t, ok) + require.Less(t, age, liveModelsTTL) + + future := time.Now().Add(liveModelsTTL) + require.NoError(t, os.Chtimes(path, future, future)) + age, ok = cacheAge(path) + require.True(t, ok) + require.Zero(t, age) +} + +func testLiveSeedProvider() catwalk.Provider { + return catwalk.Provider{ + ID: "test-live", + Name: "Test Live", + APIEndpoint: "https://example.com/v1", + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "seed-large-model", + DefaultSmallModelID: 
"seed-small-model", + DefaultHeaders: map[string]string{ + "X-Test": "seed", + }, + Models: []catwalk.Model{ + {ID: "seed-large-model", Name: "Seed Large Model"}, + {ID: "seed-small-model", Name: "Seed Small Model"}, + }, + } +} + +func writeLiveProviderCache(t *testing.T, path string, provider catwalk.Provider) { + t.Helper() + + data, err := json.Marshal(provider) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) +} diff --git a/internal/config/provider.go b/internal/config/provider.go index 32a3894358..82ac581fe7 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log/slog" + "maps" "os" "path/filepath" "runtime" @@ -19,7 +20,10 @@ import ( "charm.land/catwalk/pkg/embedded" "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/x/etag" ) @@ -28,11 +32,33 @@ type syncer[T any] interface { } var ( - providerOnce sync.Once - providerList []catwalk.Provider - providerErr error + providerOnce sync.Once + providerMu sync.RWMutex + providerList []catwalk.Provider + providerErr error + pendingProvider map[catwalk.InferenceProvider]catwalk.Provider + + providerEvents = pubsub.NewBroker[ProvidersUpdatedEvent]() ) +// ProvidersUpdatedEvent reports that a provider's model list changed. +type ProvidersUpdatedEvent struct { + ProviderID string +} + +// SubscribeProviderEvents returns provider update events. 
+func SubscribeProviderEvents(ctx context.Context) <-chan pubsub.Event[ProvidersUpdatedEvent] { + return providerEvents.Subscribe(ctx) +} + +var errMissingLiveProviderCredentials = errors.New("missing live provider credentials") + +// IsMissingLiveProviderCredentials reports whether err means a live provider +// cannot be updated because no credentials are configured. +func IsMissingLiveProviderCredentials(err error) bool { + return errors.Is(err, errMissingLiveProviderCredentials) +} + // file to cache provider data func cachePathFor(name string) string { xdgDataHome := os.Getenv("XDG_DATA_HOME") @@ -85,7 +111,7 @@ func UpdateProviders(pathOrURL string) error { return fmt.Errorf("failed to save providers to cache: %w", err) } - slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor) + slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor("providers")) return nil } @@ -122,9 +148,23 @@ func UpdateHyper(pathOrURL string) error { return nil } +// UpdateVenice updates the cached Venice provider from a live source, embedded +// seed, or local provider file. +func UpdateVenice(pathOrURL string, cfg *Config) error { + return updateLiveProvider("Venice", "venice", catwalk.InferenceProviderVenice, pathOrURL, cfg, newVeniceLiveProviderClient) +} + +// UpdateCopilot updates the cached Copilot provider from a live source, +// embedded seed, or local provider file. 
+func UpdateCopilot(pathOrURL string, cfg *Config) error { + return updateLiveProvider("Copilot", "copilot", catwalk.InferenceProviderCopilot, pathOrURL, cfg, newCopilotLiveProviderClient) +} + var ( catwalkSyncer = &catwalkSync{} hyperSyncer = &hyperSync{} + veniceSyncer = &liveProviderSync{} + copilotSyncer = &liveProviderSync{} ) // Providers returns the list of providers, taking into account cached results @@ -141,8 +181,12 @@ func Providers(cfg *Config) ([]catwalk.Provider, error) { var wg sync.WaitGroup var errs []error providers := csync.NewSlice[catwalk.Provider]() - autoupdate := !cfg.Options.DisableProviderAutoUpdate - customProvidersOnly := cfg.Options.DisableDefaultProviders + options := &Options{} + if cfg != nil && cfg.Options != nil { + options = cfg.Options + } + autoupdate := !options.DisableProviderAutoUpdate + customProvidersOnly := options.DisableDefaultProviders ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) defer cancel() @@ -186,14 +230,286 @@ func Providers(cfg *Config) ([]catwalk.Provider, error) { wg.Wait() + items := slices.Collect(providers.Seq()) + if !customProvidersOnly { + items = overlayLiveProviderModels(ctx, cfg, items, autoupdate) + } + if hyperFound { - providerList = append([]catwalk.Provider{hyperProvider}, slices.Collect(providers.Seq())...) + setProviderList(append([]catwalk.Provider{hyperProvider}, items...)) } else { - providerList = slices.Collect(providers.Seq()) + setProviderList(items) } providerErr = errors.Join(errs...) 
}) - return providerList, providerErr + return providerSnapshot(), providerErr +} + +func setProviderList(providers []catwalk.Provider) { + providerMu.Lock() + defer providerMu.Unlock() + + providers = cloneProviders(providers) + for providerID, provider := range pendingProvider { + index := slices.IndexFunc(providers, func(item catwalk.Provider) bool { + return item.ID == providerID + }) + if index >= 0 { + providers[index] = provider + } + } + pendingProvider = nil + providerList = providers +} + +func providerSnapshot() []catwalk.Provider { + providerMu.RLock() + defer providerMu.RUnlock() + return cloneProviders(providerList) +} + +func updateProviderList(provider catwalk.Provider) { + providerMu.Lock() + defer providerMu.Unlock() + + provider = cloneProvider(provider) + index := slices.IndexFunc(providerList, func(item catwalk.Provider) bool { + return item.ID == provider.ID + }) + if index < 0 { + if pendingProvider == nil { + pendingProvider = make(map[catwalk.InferenceProvider]catwalk.Provider) + } + pendingProvider[provider.ID] = provider + return + } + providerList = slices.Clone(providerList) + providerList[index] = provider +} + +func publishProviderUpdated(provider catwalk.Provider) { + updateProviderList(provider) + providerEvents.Publish(pubsub.UpdatedEvent, ProvidersUpdatedEvent{ProviderID: string(provider.ID)}) +} + +func cloneProviders(providers []catwalk.Provider) []catwalk.Provider { + if providers == nil { + return nil + } + cloned := make([]catwalk.Provider, len(providers)) + for i, provider := range providers { + cloned[i] = cloneProvider(provider) + } + return cloned +} + +func cloneProvider(provider catwalk.Provider) catwalk.Provider { + provider.Models = slices.Clone(provider.Models) + provider.DefaultHeaders = cloneStringMap(provider.DefaultHeaders) + return provider +} + +func cloneStringMap(values map[string]string) map[string]string { + if values == nil { + return nil + } + return maps.Clone(values) +} + +func 
overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []catwalk.Provider, autoupdate bool) []catwalk.Provider { + if cfg == nil || len(providers) == 0 { + return providers + } + + environment := env.New() + resolver := NewShellVariableResolver(environment) + + syncProvider := func(providerID catwalk.InferenceProvider, cacheName string, syncer *liveProviderSync, newClient liveProviderClientFunc) { + index := slices.IndexFunc(providers, func(provider catwalk.Provider) bool { + return provider.ID == providerID + }) + if index < 0 { + return + } + + seed := providers[index] + client, credentialed, err := newClient(seed, cfg, resolver, "") + if err != nil { + slog.Warn("Skipping live provider sync", "provider", providerID, "error", err) + return + } + + syncer.Init(client, cachePathFor(cacheName), autoupdate, seed, credentialed) + syncer.onRefresh = publishProviderUpdated + provider, err := syncer.Get(ctx) + if err != nil { + slog.Warn("Live provider sync failed", "provider", providerID, "error", err) + } + if len(provider.Models) > 0 { + providers[index] = provider + } + } + + syncProvider(catwalk.InferenceProviderVenice, "venice", veniceSyncer, newVeniceLiveProviderClient) + syncProvider(catwalk.InferenceProviderCopilot, "copilot", copilotSyncer, newCopilotLiveProviderClient) + + return providers +} + +type liveProviderClientFunc func(catwalk.Provider, *Config, VariableResolver, string) (liveProviderClient, bool, error) + +func newVeniceLiveProviderClient(seed catwalk.Provider, cfg *Config, resolver VariableResolver, baseURLOverride string) (liveProviderClient, bool, error) { + var providerConfig ProviderConfig + configExists := cfg != nil && cfg.Providers != nil + if configExists { + providerConfig, configExists = cfg.Providers.Get(string(catwalk.InferenceProviderVenice)) + } + if configExists && providerConfig.Disable { + return realVeniceModelsClient{baseURL: seed.APIEndpoint}, false, nil + } + + baseURL := cmp.Or(baseURLOverride, seed.APIEndpoint) 
+ if configExists && providerConfig.BaseURL != "" && baseURLOverride == "" { + resolved, err := resolver.ResolveValue(providerConfig.BaseURL) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Venice base URL: %w", err) + } + baseURL = resolved + } + + apiKey := "" + if configExists && providerConfig.APIKey != "" { + resolved, err := resolver.ResolveValue(providerConfig.APIKey) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Venice API key: %w", err) + } + apiKey = strings.TrimSpace(resolved) + } + if apiKey == "" { + resolved, err := resolver.ResolveValue("$VENICE_API_KEY") + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Venice API key from environment: %w", err) + } + apiKey = strings.TrimSpace(resolved) + } + + return realVeniceModelsClient{baseURL: baseURL, apiKey: apiKey}, apiKey != "", nil +} + +func newCopilotLiveProviderClient(seed catwalk.Provider, cfg *Config, resolver VariableResolver, baseURLOverride string) (liveProviderClient, bool, error) { + var providerConfig ProviderConfig + configExists := cfg != nil && cfg.Providers != nil + if configExists { + providerConfig, configExists = cfg.Providers.Get(string(catwalk.InferenceProviderCopilot)) + } + if configExists && providerConfig.Disable { + return &realCopilotModelsClient{baseURL: seed.APIEndpoint}, false, nil + } + + baseURL := cmp.Or(baseURLOverride, seed.APIEndpoint) + if configExists && providerConfig.BaseURL != "" && baseURLOverride == "" { + resolved, err := resolver.ResolveValue(providerConfig.BaseURL) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Copilot base URL: %w", err) + } + baseURL = resolved + } + + apiKey := "" + if configExists && providerConfig.APIKey != "" { + resolved, err := resolver.ResolveValue(providerConfig.APIKey) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Copilot API key: %w", err) + } + apiKey = strings.TrimSpace(resolved) + } + oauthToken := 
providerConfig.OAuthToken + credentialed := apiKey != "" || usableOAuthToken(oauthToken) + + return &realCopilotModelsClient{baseURL: baseURL, apiKey: apiKey, oauthToken: oauthToken}, credentialed, nil +} + +func usableOAuthToken(token *oauth.Token) bool { + if token == nil { + return false + } + return strings.TrimSpace(token.AccessToken) != "" || strings.TrimSpace(token.RefreshToken) != "" +} + +func updateLiveProvider(name, cacheName string, providerID catwalk.InferenceProvider, pathOrURL string, cfg *Config, newClient liveProviderClientFunc) error { + seed, err := liveProviderSeed(providerID) + if err != nil { + return err + } + + var provider catwalk.Provider + switch { + case pathOrURL == "embedded": + provider = seed + case pathOrURL != "" && !strings.HasPrefix(pathOrURL, "http://") && !strings.HasPrefix(pathOrURL, "https://"): + content, err := os.ReadFile(pathOrURL) + if err != nil { + return fmt.Errorf("failed to read file: %w", err) + } + if err := json.Unmarshal(content, &provider); err != nil { + return fmt.Errorf("failed to unmarshal provider data: %w", err) + } + default: + if cfg == nil { + return fmt.Errorf("failed to fetch provider from %s: %w", name, errMissingLiveProviderCredentials) + } + environment := env.New() + resolver := NewShellVariableResolver(environment) + client, credentialed, err := newClient(seed, cfg, resolver, pathOrURL) + if err != nil { + return fmt.Errorf("failed to prepare %s provider update: %w", name, err) + } + if !credentialed { + return fmt.Errorf("failed to fetch provider from %s: %w", name, errMissingLiveProviderCredentials) + } + + ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) + defer cancel() + + provider, err = client.Get(ctx) + if err != nil { + return fmt.Errorf("failed to fetch provider from %s: %w", name, err) + } + if len(provider.Models) == 0 { + return fmt.Errorf("failed to fetch provider from %s: no models returned", name) + } + provider = mergeLiveProvider(seed, provider) + } + + 
if err := newCache[catwalk.Provider](cachePathFor(cacheName)).Store(provider); err != nil { + return fmt.Errorf("failed to save %s provider to cache: %w", name, err) + } + + slog.Info(name+" provider updated successfully", "from", cmp.Or(pathOrURL, "live"), "to", cachePathFor(cacheName)) + return nil +} + +func liveProviderSeed(providerID catwalk.InferenceProvider) (catwalk.Provider, error) { + providers, _, err := newCache[[]catwalk.Provider](cachePathFor("providers")).Get() + if err != nil || len(providers) == 0 { + providers = embedded.GetAll() + } + if provider, ok := findProvider(providers, providerID); ok { + return provider, nil + } + if provider, ok := findProvider(embedded.GetAll(), providerID); ok { + return provider, nil + } + return catwalk.Provider{}, fmt.Errorf("provider %s not found", providerID) +} + +func findProvider(providers []catwalk.Provider, providerID catwalk.InferenceProvider) (catwalk.Provider, bool) { + for _, provider := range providers { + if provider.ID == providerID { + return provider, true + } + } + return catwalk.Provider{}, false } type cache[T any] struct { diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 283c18c8ab..fb4cdab98c 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,22 +1,35 @@ package config import ( + "context" "encoding/json" + "net/http" + "net/http/httptest" "os" "path/filepath" "sync" + "sync/atomic" "testing" + "time" "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/stretchr/testify/require" ) func resetProviderState() { providerOnce = sync.Once{} + providerMu.Lock() providerList = nil providerErr = nil + pendingProvider = nil + providerMu.Unlock() catwalkSyncer = &catwalkSync{} hyperSyncer = &hyperSync{} + veniceSyncer = &liveProviderSync{} + copilotSyncer = &liveProviderSync{} } func 
TestProviders_Integration_AutoUpdateDisabled(t *testing.T) { @@ -217,6 +230,341 @@ func TestProviders_Integration_BothFail(t *testing.T) { require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models. } +func TestProviders_Integration_LiveOverlayRefreshesWithCredentialsInBackground(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + started := make(chan struct{}, 2) + release := make(chan struct{}) + var requestCount atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + requestCount.Add(1) + started <- struct{}{} + <-release + + switch r.Header.Get("Authorization") { + case "Bearer venice-token": + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "venice-live", + "type": "text", + "model_spec": { + "name": "Venice Live", + "availableContextTokens": 4096, + "maxCompletionTokens": 1024, + "pricing": {"input": {"usd": "0.1"}, "output": {"usd": "0.2"}}, + "capabilities": {"supportsVision": true} + } + } + ] + }`)) + case "Bearer copilot-token": + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "copilot-live", + "name": "Copilot Live", + "capabilities": { + "limits": {"max_context_window_tokens": 8192, "max_output_tokens": 2048}, + "supports": {"vision": true} + } + } + ] + }`)) + default: + w.WriteHeader(http.StatusUnauthorized) + } + })) + defer server.Close() + + providers := []catwalk.Provider{ + { + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "venice-live", + DefaultSmallModelID: "venice-live", + Models: []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, + }, + { + Name: "Copilot", + ID: catwalk.InferenceProviderCopilot, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "copilot-live", + DefaultSmallModelID: "copilot-live", 
+ Models: []catwalk.Model{{ID: "copilot-seed", Name: "Copilot Seed"}}, + }, + } + cfg := &Config{ + Options: &Options{}, + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + string(catwalk.InferenceProviderVenice): {APIKey: "venice-token"}, + string(catwalk.InferenceProviderCopilot): {APIKey: "copilot-token"}, + }), + } + + resultCh := make(chan []catwalk.Provider, 1) + go func() { + resultCh <- overlayLiveProviderModels(t.Context(), cfg, providers, true) + }() + + var result []catwalk.Provider + select { + case result = <-resultCh: + case <-time.After(100 * time.Millisecond): + close(release) + require.FailNow(t, "Live overlay blocked on background refresh") + } + + venice, ok := findProvider(result, catwalk.InferenceProviderVenice) + require.True(t, ok) + require.Equal(t, []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, venice.Models) + + copilot, ok := findProvider(result, catwalk.InferenceProviderCopilot) + require.True(t, ok) + require.Equal(t, []catwalk.Model{{ID: "copilot-seed", Name: "Copilot Seed"}}, copilot.Models) + + require.Eventually(t, func() bool { return requestCount.Load() == 2 }, time.Second, 10*time.Millisecond) + <-started + <-started + close(release) + + require.Eventually(t, func() bool { + venice, _, err := newCache[catwalk.Provider](cachePathFor("venice")).Get() + return err == nil && len(venice.Models) == 1 && venice.Models[0].ID == "venice-live" + }, time.Second, 10*time.Millisecond) + require.Eventually(t, func() bool { + copilot, _, err := newCache[catwalk.Provider](cachePathFor("copilot")).Get() + return err == nil && len(copilot.Models) == 1 && copilot.Models[0].ID == "copilot-live" + }, time.Second, 10*time.Millisecond) +} + +func TestProviders_Integration_LiveOverlayUpdatesProviderListAndPublishesEvent(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + started := make(chan struct{}, 1) + release := make(chan struct{}) + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + started <- struct{}{} + <-release + require.Equal(t, "Bearer venice-token", r.Header.Get("Authorization")) + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "venice-live", + "type": "text", + "model_spec": { + "name": "Venice Live", + "availableContextTokens": 4096, + "maxCompletionTokens": 1024, + "pricing": {"input": {"usd": "0.1"}, "output": {"usd": "0.2"}}, + "capabilities": {"supportsVision": true} + } + } + ] + }`)) + })) + defer server.Close() + + seed := catwalk.Provider{ + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "venice-seed", + DefaultSmallModelID: "venice-seed", + Models: []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, + } + cfg := &Config{ + Options: &Options{}, + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + string(catwalk.InferenceProviderVenice): {APIKey: "venice-token"}, + }), + } + providerOnce.Do(func() {}) + setProviderList([]catwalk.Provider{seed}) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + events := SubscribeProviderEvents(ctx) + + result := overlayLiveProviderModels(t.Context(), cfg, []catwalk.Provider{seed}, true) + require.Equal(t, []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, result[0].Models) + <-started + close(release) + + select { + case event := <-events: + require.Equal(t, pubsub.UpdatedEvent, event.Type) + require.Equal(t, string(catwalk.InferenceProviderVenice), event.Payload.ProviderID) + case <-time.After(time.Second): + require.FailNow(t, "Timed out waiting for providers updated event") + } + + providers, err := Providers(cfg) + require.NoError(t, err) + venice, ok := findProvider(providers, catwalk.InferenceProviderVenice) + require.True(t, ok) + require.Equal(t, []catwalk.Model{{ + ID: "venice-live", + Name: "Venice Live", + CostPer1MIn: 0.1, + 
CostPer1MOut: 0.2, + ContextWindow: 4096, + DefaultMaxTokens: 1024, + SupportsImages: true, + }}, venice.Models) +} + +func TestProviders_Integration_LiveOverlaySkipsWithoutCredentials(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + seed := catwalk.Provider{ + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: "http://127.0.0.1:1", + Models: []catwalk.Model{{ID: "seed-model", Name: "Seed Model"}}, + } + cached := seed + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + require.NoError(t, newCache[catwalk.Provider](cachePathFor("venice")).Store(cached)) + + cfg := &Config{Options: &Options{}, Providers: csync.NewMap[string, ProviderConfig]()} + result := overlayLiveProviderModels(t.Context(), cfg, []catwalk.Provider{seed}, true) + require.Equal(t, []catwalk.Provider{seed}, result) +} + +func TestNewVeniceLiveProviderClient_EnvFallbackUsesResolver(t *testing.T) { + t.Parallel() + + seed := catwalk.Provider{ + ID: catwalk.InferenceProviderVenice, + APIEndpoint: "https://api.venice.ai/api/v1", + } + cfg := &Config{Options: &Options{}, Providers: csync.NewMap[string, ProviderConfig]()} + resolver := NewShellVariableResolver(env.NewFromMap(map[string]string{ + "VENICE_API_KEY": " env-venice-key ", + })) + + client, credentialed, err := newVeniceLiveProviderClient(seed, cfg, resolver, "") + require.NoError(t, err) + require.True(t, credentialed) + veniceClient, ok := client.(realVeniceModelsClient) + require.True(t, ok) + require.Equal(t, "env-venice-key", veniceClient.apiKey) +} + +func TestUpdateVenice_LiveSourceStoresCache(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + require.Equal(t, "Bearer venice-token", 
r.Header.Get("Authorization")) + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "venice-live", + "type": "text", + "model_spec": { + "name": "Venice Live", + "availableContextTokens": 4096, + "maxCompletionTokens": 1024, + "pricing": {"input": {"usd": "0.1"}, "output": {"usd": "0.2"}}, + "capabilities": {"supportsVision": true} + } + } + ] + }`)) + })) + defer server.Close() + + seed := catwalk.Provider{ + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "venice-live", + DefaultSmallModelID: "venice-live", + } + require.NoError(t, newCache[[]catwalk.Provider](cachePathFor("providers")).Store([]catwalk.Provider{seed})) + cfg := &Config{ + Options: &Options{}, + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + string(catwalk.InferenceProviderVenice): {APIKey: "venice-token"}, + }), + } + + require.NoError(t, UpdateVenice(server.URL, cfg)) + + cached, _, err := newCache[catwalk.Provider](cachePathFor("venice")).Get() + require.NoError(t, err) + require.Equal(t, "Venice", cached.Name) + require.Equal(t, "venice-live", cached.DefaultLargeModelID) + require.Len(t, cached.Models, 1) + require.Equal(t, "venice-live", cached.Models[0].ID) +} + +func TestUpdateVenice_LiveSourceRequiresCredentials(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + seed := catwalk.Provider{ID: catwalk.InferenceProviderVenice, Name: "Venice"} + require.NoError(t, newCache[[]catwalk.Provider](cachePathFor("providers")).Store([]catwalk.Provider{seed})) + cfg := &Config{Options: &Options{}, Providers: csync.NewMap[string, ProviderConfig]()} + + err := UpdateVenice("", cfg) + require.Error(t, err) + require.True(t, IsMissingLiveProviderCredentials(err)) +} + +func TestUpdateCopilot_FileSourceStoresCache(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + 
t.Setenv("XDG_DATA_HOME", tmpDir) + + provider := catwalk.Provider{ + Name: "Copilot", + ID: catwalk.InferenceProviderCopilot, + Models: []catwalk.Model{ + {ID: "copilot-file", Name: "Copilot File"}, + }, + } + data, err := json.Marshal(provider) + require.NoError(t, err) + path := filepath.Join(tmpDir, "copilot.json") + require.NoError(t, os.WriteFile(path, data, 0o644)) + + require.NoError(t, UpdateCopilot(path, nil)) + + cached, _, err := newCache[catwalk.Provider](cachePathFor("copilot")).Get() + require.NoError(t, err) + require.Equal(t, provider, cached) +} + func TestCache_StoreAndGet(t *testing.T) { t.Parallel() diff --git a/internal/config/venice_models.go b/internal/config/venice_models.go new file mode 100644 index 0000000000..9c76cf7fc8 --- /dev/null +++ b/internal/config/venice_models.go @@ -0,0 +1,153 @@ +package config + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "charm.land/catwalk/pkg/catwalk" +) + +var _ liveProviderClient = realVeniceModelsClient{} + +type realVeniceModelsClient struct { + baseURL string + apiKey string +} + +func (r realVeniceModelsClient) Get(ctx context.Context) (catwalk.Provider, error) { + var result catwalk.Provider + baseURL := strings.TrimRight(r.baseURL, "/") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) + if err != nil { + return result, fmt.Errorf("could not create request: %w", err) + } + if apiKey := strings.TrimSpace(r.apiKey); apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var models veniceModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&models); err != 
nil { + return result, fmt.Errorf("failed to decode response: %w", err) + } + + result = catwalk.Provider{ + ID: catwalk.InferenceProviderVenice, + APIEndpoint: baseURL, + Type: catwalk.TypeOpenAICompat, + Models: veniceModelsToCatwalkModels(models), + } + return result, nil +} + +type veniceModelsResponse struct { + Data []veniceModel `json:"data"` +} + +type veniceModel struct { + ID string `json:"id"` + ContextLength int64 `json:"context_length"` + ModelSpec veniceModelSpec `json:"model_spec"` + Type string `json:"type"` +} + +type veniceModelSpec struct { + AvailableContextTokens int64 `json:"availableContextTokens"` + MaxCompletionTokens int64 `json:"maxCompletionTokens"` + Capabilities veniceModelCapabilities `json:"capabilities"` + Name string `json:"name"` + Offline bool `json:"offline"` + Pricing veniceModelPricing `json:"pricing"` +} + +type veniceModelCapabilities struct { + SupportsReasoning bool `json:"supportsReasoning"` + ReasoningEffortOptions []string `json:"reasoningEffortOptions"` + DefaultReasoningEffort string `json:"defaultReasoningEffort"` + SupportsVision bool `json:"supportsVision"` +} + +type veniceModelPricing struct { + Input veniceModelPricingValue `json:"input"` + Output veniceModelPricingValue `json:"output"` + CacheInput veniceModelPricingValue `json:"cache_input"` +} + +type veniceModelPricingValue struct { + USD veniceUSD `json:"usd"` +} + +type veniceUSD float64 + +func (v *veniceUSD) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + *v = 0 + return nil + } + + var number float64 + if err := json.Unmarshal(data, &number); err == nil { + *v = veniceUSD(number) + return nil + } + + var value string + if err := json.Unmarshal(data, &value); err != nil { + return fmt.Errorf("failed to decode usd value: %w", err) + } + value = strings.TrimSpace(value) + if value == "" { + *v = 0 + return nil + } + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("failed to parse usd value: %w", err) + 
} + *v = veniceUSD(parsed) + return nil +} + +func veniceModelsToCatwalkModels(response veniceModelsResponse) []catwalk.Model { + models := make([]catwalk.Model, 0, len(response.Data)) + for _, model := range response.Data { + if !strings.EqualFold(model.Type, "text") || model.ModelSpec.Offline { + continue + } + + contextWindow := model.ModelSpec.AvailableContextTokens + if contextWindow == 0 { + contextWindow = model.ContextLength + } + + models = append(models, catwalk.Model{ + ID: model.ID, + Name: model.ModelSpec.Name, + CostPer1MIn: float64(model.ModelSpec.Pricing.Input.USD), + CostPer1MOut: float64(model.ModelSpec.Pricing.Output.USD), + CostPer1MInCached: float64(model.ModelSpec.Pricing.CacheInput.USD), + ContextWindow: contextWindow, + DefaultMaxTokens: model.ModelSpec.MaxCompletionTokens, + CanReason: model.ModelSpec.Capabilities.SupportsReasoning, + ReasoningLevels: model.ModelSpec.Capabilities.ReasoningEffortOptions, + DefaultReasoningEffort: model.ModelSpec.Capabilities.DefaultReasoningEffort, + SupportsImages: model.ModelSpec.Capabilities.SupportsVision, + }) + } + return models +} diff --git a/internal/config/venice_models_test.go b/internal/config/venice_models_test.go new file mode 100644 index 0000000000..9129f66341 --- /dev/null +++ b/internal/config/venice_models_test.go @@ -0,0 +1,148 @@ +package config + +import ( + "net/http" + "net/http/httptest" + "testing" + + "charm.land/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +func TestRealVeniceModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + require.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{ + "data": [ + { + "id": "venice-reasoning", + "type": "text", + "model_spec": { + "name": "Venice Reasoning", + 
"availableContextTokens": 128000, + "maxCompletionTokens": 4096, + "offline": false, + "pricing": { + "input": { "usd": "0.15" }, + "output": { "usd": 0.6 }, + "cache_input": { "usd": "0.03" } + }, + "capabilities": { + "supportsReasoning": true, + "reasoningEffortOptions": ["low", "medium", "high"], + "defaultReasoningEffort": "medium", + "supportsVision": true + } + } + }, + { + "id": "venice-fallback", + "type": "text", + "context_length": 64000, + "model_spec": { + "name": "Venice Fallback", + "maxCompletionTokens": 2048, + "offline": false + } + }, + { + "id": "venice-image", + "type": "image", + "model_spec": { "name": "Image Model" } + }, + { + "id": "venice-offline", + "type": "text", + "model_spec": { "name": "Offline Model", "offline": true } + } + ] + }`)) + require.NoError(t, err) + })) + defer server.Close() + + client := realVeniceModelsClient{baseURL: server.URL + "/", apiKey: " test-key "} + provider, err := client.Get(t.Context()) + + require.NoError(t, err) + require.Equal(t, catwalk.InferenceProviderVenice, provider.ID) + require.Equal(t, server.URL, provider.APIEndpoint) + require.Equal(t, catwalk.TypeOpenAICompat, provider.Type) + require.Equal(t, []catwalk.Model{ + { + ID: "venice-reasoning", + Name: "Venice Reasoning", + CostPer1MIn: 0.15, + CostPer1MOut: 0.6, + CostPer1MInCached: 0.03, + ContextWindow: 128000, + DefaultMaxTokens: 4096, + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, + { + ID: "venice-fallback", + Name: "Venice Fallback", + ContextWindow: 64000, + DefaultMaxTokens: 2048, + }, + }, provider.Models) +} + +func TestVeniceModelsToCatwalkModelsFiltersNonTextAndOfflineModels(t *testing.T) { + t.Parallel() + + models := veniceModelsToCatwalkModels(veniceModelsResponse{Data: []veniceModel{ + { + ID: "text-model", + Type: "TEXT", + ModelSpec: veniceModelSpec{ + Name: "Text Model", + AvailableContextTokens: 32000, + MaxCompletionTokens: 2048, + 
Pricing: veniceModelPricing{ + Input: veniceModelPricingValue{USD: veniceUSD(0.1)}, + Output: veniceModelPricingValue{USD: veniceUSD(0.2)}, + CacheInput: veniceModelPricingValue{USD: veniceUSD(0.01)}, + }, + Capabilities: veniceModelCapabilities{ + SupportsReasoning: true, + ReasoningEffortOptions: []string{"low"}, + DefaultReasoningEffort: "low", + }, + }, + }, + { + ID: "image-model", + Type: "image", + ModelSpec: veniceModelSpec{Name: "Image Model"}, + }, + { + ID: "offline-model", + Type: "text", + ModelSpec: veniceModelSpec{Name: "Offline Model", Offline: true}, + }, + }}) + + require.Equal(t, []catwalk.Model{ + { + ID: "text-model", + Name: "Text Model", + CostPer1MIn: 0.1, + CostPer1MOut: 0.2, + CostPer1MInCached: 0.01, + ContextWindow: 32000, + DefaultMaxTokens: 2048, + CanReason: true, + ReasoningLevels: []string{"low"}, + DefaultReasoningEffort: "low", + }, + }, models) +} diff --git a/internal/server/recover_test.go b/internal/server/recover_test.go index 2bbd61efd5..45b616677b 100644 --- a/internal/server/recover_test.go +++ b/internal/server/recover_test.go @@ -23,7 +23,7 @@ func TestRecoverHandler_PanicReturns500(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) h.ServeHTTP(rec, req) require.Equal(t, http.StatusInternalServerError, rec.Code) @@ -48,7 +48,7 @@ func TestRecoverHandler_NoPanicPassthrough(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) h.ServeHTTP(rec, req) require.Equal(t, http.StatusTeapot, rec.Code) @@ -71,7 +71,7 @@ func TestRecoverHandler_PanicAfterWriteHeader(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) 
require.NotPanics(t, func() { h.ServeHTTP(rec, req) }) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "partial", rec.Body.String()) @@ -89,6 +89,6 @@ func TestRecoverHandler_AbortHandlerPropagates(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) require.PanicsWithValue(t, http.ErrAbortHandler, func() { h.ServeHTTP(rec, req) }) }
diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go
index 76ec4dbbf2..2725157630 100644
--- a/internal/ui/dialog/models.go
+++ b/internal/ui/dialog/models.go
@@ -161,6 +161,25 @@ func (m *Models) ID() string {
 	return ModelsID
 }
 
+// RefreshProviders reloads the provider list through config.Providers and
+// rebuilds the dialog's provider items from it via setProviderItems.
+//
+// The fetched list is stored on the dialog only after the fetch succeeds,
+// so a failed refresh leaves the previously loaded providers in place. It
+// returns a wrapped error when either the fetch or the item rebuild fails,
+// and nil otherwise.
+func (m *Models) RefreshProviders() error {
+	providers, err := config.Providers(m.com.Config())
+	if err != nil {
+		return fmt.Errorf("failed to get providers: %w", err)
+	}
+	m.providers = providers
+	if err := m.setProviderItems(); err != nil {
+		return fmt.Errorf("failed to set provider items: %w", err)
+	}
+	return nil
+}
+
 // HandleMsg implements Dialog.
func (m *Models) HandleMsg(msg tea.Msg) Action { switch msg := msg.(type) { diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 9dfe722796..6ef9b4f30b 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -712,6 +712,15 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, m.handleFileEvent(msg.Payload)) case pubsub.Event[app.LSPEvent]: m.lspStates = app.GetLSPStates() + case pubsub.Event[config.ProvidersUpdatedEvent]: + if dia := m.dialog.Dialog(dialog.ModelsID); dia != nil { + models, ok := dia.(*dialog.Models) + if ok { + if err := models.RefreshProviders(); err != nil { + cmds = append(cmds, util.ReportError(err)) + } + } + } case pubsub.Event[skills.Event]: m.skillStates = msg.Payload.States case pubsub.Event[mcp.Event]: @@ -1471,29 +1480,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { } m.dialog.CloseDialog(dialog.CommandsID) case dialog.ActionToggleThinking: - cmds = append(cmds, func() tea.Msg { - cfg := m.com.Config() - if cfg == nil { - return util.ReportError(errors.New("configuration not found"))() - } - - agentCfg, ok := cfg.Agents[config.AgentCoder] - if !ok { - return util.ReportError(errors.New("agent configuration not found"))() - } - - currentModel := cfg.Models[agentCfg.Model] - currentModel.Think = !currentModel.Think - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { - return util.ReportError(err)() - } - m.com.Workspace.UpdateAgentModel(context.TODO()) - status := "disabled" - if currentModel.Think { - status = "enabled" - } - return util.NewInfoMsg("Thinking mode " + status) - }) + cmds = append(cmds, m.toggleThinking) m.dialog.CloseDialog(dialog.CommandsID) case dialog.ActionToggleTransparentBackground: cmds = append(cmds, func() tea.Msg { @@ -1542,29 +1529,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } - cfg := m.com.Config() - if cfg == nil { - cmds = append(cmds, 
util.ReportError(errors.New("configuration not found"))) - break - } - - agentCfg, ok := cfg.Agents[config.AgentCoder] - if !ok { - cmds = append(cmds, util.ReportError(errors.New("agent configuration not found"))) - break - } - - currentModel := cfg.Models[agentCfg.Model] - currentModel.ReasoningEffort = msg.Effort - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { - cmds = append(cmds, util.ReportError(err)) - break + if cmd := m.selectReasoningEffort(msg.Effort); cmd != nil { + cmds = append(cmds, cmd) } - - cmds = append(cmds, func() tea.Msg { - m.com.Workspace.UpdateAgentModel(context.TODO()) - return util.NewInfoMsg("Reasoning effort set to " + msg.Effort) - }) m.dialog.CloseDialog(dialog.ReasoningID) case dialog.ActionPermissionResponse: m.dialog.CloseDialog(dialog.PermissionsID) @@ -1690,6 +1657,63 @@ func (m *UI) fetchHyperCredits() tea.Cmd { } } +// toggleThinking flips the thinking mode of the coder agent's current +// model and persists it at workspace scope so it isn't shadowed by +// workspace-scoped model selections. 
+func (m *UI) toggleThinking() tea.Msg {
+	conf := m.com.Config()
+	if conf == nil {
+		return util.ReportError(errors.New("configuration not found"))()
+	}
+
+	coder, found := conf.Agents[config.AgentCoder]
+	if !found {
+		return util.ReportError(errors.New("agent configuration not found"))()
+	}
+
+	// Invert the flag on the looked-up model value and persist that value.
+	selected := conf.Models[coder.Model]
+	selected.Think = !selected.Think
+	err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, coder.Model, selected)
+	if err == nil {
+		// UpdateAgentModel runs only when the preference update succeeded.
+		err = m.com.Workspace.UpdateAgentModel(context.TODO())
+	}
+	if err != nil {
+		return util.ReportError(err)()
+	}
+
+	status := "enabled"
+	if !selected.Think {
+		status = "disabled"
+	}
+	return util.NewInfoMsg("Thinking mode " + status)
+}
+
+// selectReasoningEffort stores effort as the reasoning effort of the model
+// selected for the coder agent. The preference is saved with
+// config.ScopeWorkspace so it isn't shadowed by workspace-scoped model
+// selections.
+//
+// The preference is written before this method returns. On success the
+// returned command calls Workspace.UpdateAgentModel and yields either its
+// error or an info message naming the new effort; if the configuration,
+// the coder agent entry, or the preference write is unavailable or fails,
+// the returned command reports that error instead.
+func (m *UI) selectReasoningEffort(effort string) tea.Cmd {
+	conf := m.com.Config()
+	if conf == nil {
+		return util.ReportError(errors.New("configuration not found"))
+	}
+
+	coder, found := conf.Agents[config.AgentCoder]
+	if !found {
+		return util.ReportError(errors.New("agent configuration not found"))
+	}
+
+	selected := conf.Models[coder.Model]
+	selected.ReasoningEffort = effort
+	if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, coder.Model, selected); err != nil {
+		return util.ReportError(err)
+	}
+
+	applyToAgent := func() tea.Msg {
+		err := m.com.Workspace.UpdateAgentModel(context.TODO())
+		if err != nil {
+			return util.ReportError(err)()
+		}
+		return util.NewInfoMsg("Reasoning effort set to " + effort)
+	}
+	return applyToAgent
+}
+
 // handleSelectModel performs the model selection after any provider
 // pre-checks (such as a silent Hyper OAuth refresh) have completed.
func (m *UI) handleSelectModel(msg dialog.ActionSelectModel) tea.Cmd { @@ -1735,7 +1759,7 @@ func (m *UI) handleSelectModel(msg dialog.ActionSelectModel) tea.Cmd { return tea.Batch(cmds...) } - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, msg.ModelType, msg.Model); err != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, msg.ModelType, msg.Model); err != nil { cmds = append(cmds, util.ReportError(err)) } else { if msg.ModelType == config.SelectedModelTypeLarge { @@ -1746,7 +1770,7 @@ func (m *UI) handleSelectModel(msg dialog.ActionSelectModel) tea.Cmd { if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok { // Ensure small model is set is unset. smallModel := m.com.Workspace.GetDefaultSmallModel(providerID) - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, config.SelectedModelTypeSmall, smallModel); err != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, config.SelectedModelTypeSmall, smallModel); err != nil { cmds = append(cmds, util.ReportError(err)) } } diff --git a/internal/ui/model/ui_test.go b/internal/ui/model/ui_test.go index 4032c80a05..a0bc4d19db 100644 --- a/internal/ui/model/ui_test.go +++ b/internal/ui/model/ui_test.go @@ -1,12 +1,15 @@ package model import ( + "context" "testing" "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/ui/dialog" + "github.com/charmbracelet/crush/internal/ui/styles" "github.com/charmbracelet/crush/internal/workspace" "github.com/stretchr/testify/require" ) @@ -74,6 +77,101 @@ func TestCurrentModelSupportsImages(t *testing.T) { }) } +func TestHandleSelectModelPersistsWorkspaceScope(t *testing.T) { + t.Parallel() + + providers := csync.NewMap[string, config.ProviderConfig]() + providers.Set("test-provider", config.ProviderConfig{ + ID: "test-provider", 
+ APIKey: "test-key", + Models: []catwalk.Model{{ID: "large-model", Name: "Large Model"}}, + }) + cfg := &config.Config{ + Models: map[config.SelectedModelType]config.SelectedModel{}, + Providers: providers, + Options: &config.Options{TUI: &config.TUIOptions{}}, + } + ws := &testWorkspace{ + cfg: cfg, + defaultSmallModel: config.SelectedModel{ + Provider: "test-provider", + Model: "small-model", + }, + } + ui := New(&common.Common{Workspace: ws, Styles: ptr(styles.CharmtonePantera())}, "", false) + + ui.handleSelectModel(dialog.ActionSelectModel{ + Provider: catwalk.Provider{ID: "test-provider"}, + Model: config.SelectedModel{ + Provider: "test-provider", + Model: "large-model", + }, + ModelType: config.SelectedModelTypeLarge, + }) + + require.Len(t, ws.preferredModelUpdates, 2) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[0].scope) + require.Equal(t, config.SelectedModelTypeLarge, ws.preferredModelUpdates[0].modelType) + require.Equal(t, "large-model", ws.preferredModelUpdates[0].model.Model) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[1].scope) + require.Equal(t, config.SelectedModelTypeSmall, ws.preferredModelUpdates[1].modelType) + require.Equal(t, "small-model", ws.preferredModelUpdates[1].model.Model) +} + +func TestToggleThinkingPersistsWorkspaceScope(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Models: map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Provider: "test-provider", + Model: "large-model", + }, + }, + Providers: csync.NewMap[string, config.ProviderConfig](), + Agents: map[string]config.Agent{ + config.AgentCoder: {Model: config.SelectedModelTypeLarge}, + }, + } + ws := &testWorkspace{cfg: cfg} + ui := newTestUIWithConfig(t, cfg) + ui.com.Workspace = ws + + ui.toggleThinking() + + require.Len(t, ws.preferredModelUpdates, 1) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[0].scope) + require.Equal(t, config.SelectedModelTypeLarge, 
ws.preferredModelUpdates[0].modelType) + require.True(t, ws.preferredModelUpdates[0].model.Think) +} + +func TestSelectReasoningEffortPersistsWorkspaceScope(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Models: map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Provider: "test-provider", + Model: "large-model", + }, + }, + Providers: csync.NewMap[string, config.ProviderConfig](), + Agents: map[string]config.Agent{ + config.AgentCoder: {Model: config.SelectedModelTypeLarge}, + }, + } + ws := &testWorkspace{cfg: cfg} + ui := newTestUIWithConfig(t, cfg) + ui.com.Workspace = ws + + ui.selectReasoningEffort("high") + + require.Len(t, ws.preferredModelUpdates, 1) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[0].scope) + require.Equal(t, config.SelectedModelTypeLarge, ws.preferredModelUpdates[0].modelType) + require.Equal(t, "high", ws.preferredModelUpdates[0].model.ReasoningEffort) +} + func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI { t.Helper() @@ -87,9 +185,54 @@ func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI { // testWorkspace is a minimal [workspace.Workspace] stub for unit tests. 
type testWorkspace struct { workspace.Workspace - cfg *config.Config + cfg *config.Config + defaultSmallModel config.SelectedModel + preferredModelUpdates []preferredModelUpdate +} + +type preferredModelUpdate struct { + scope config.Scope + modelType config.SelectedModelType + model config.SelectedModel } func (w *testWorkspace) Config() *config.Config { return w.cfg } + +func (w *testWorkspace) PermissionSkipRequests() bool { + return false +} + +func (w *testWorkspace) ProjectNeedsInitialization() (bool, error) { + return false, nil +} + +func (w *testWorkspace) AgentIsReady() bool { + return false +} + +func (w *testWorkspace) AgentIsBusy() bool { + return false +} + +func (w *testWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { + w.preferredModelUpdates = append(w.preferredModelUpdates, preferredModelUpdate{ + scope: scope, + modelType: modelType, + model: model, + }) + return nil +} + +func (w *testWorkspace) GetDefaultSmallModel(string) config.SelectedModel { + return w.defaultSmallModel +} + +func (w *testWorkspace) UpdateAgentModel(context.Context) error { + return nil +} + +func ptr[T any](v T) *T { + return &v +}