From f65bc5f1c0fb30282f9f55b9f1804098c48c9c8c Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 9 Jun 2026 19:30:24 -0700 Subject: [PATCH 01/16] feat(config): add live provider syncer --- internal/config/live_provider.go | 169 +++++++++++++++ internal/config/live_provider_test.go | 290 ++++++++++++++++++++++++++ 2 files changed, 459 insertions(+) create mode 100644 internal/config/live_provider.go create mode 100644 internal/config/live_provider_test.go diff --git a/internal/config/live_provider.go b/internal/config/live_provider.go new file mode 100644 index 0000000000..ce9481380e --- /dev/null +++ b/internal/config/live_provider.go @@ -0,0 +1,169 @@ +package config + +import ( + "context" + "errors" + "log/slog" + "os" + "slices" + "sync" + "sync/atomic" + "time" + + "charm.land/catwalk/pkg/catwalk" +) + +const liveModelsTTL = time.Minute + +type liveProviderClient interface { + Get(context.Context, string) (catwalk.Provider, error) +} + +var _ syncer[catwalk.Provider] = (*liveProviderSync)(nil) + +type liveProviderSync struct { + once sync.Once + result catwalk.Provider + cache cache[catwalk.Provider] + client liveProviderClient + seed catwalk.Provider + autoupdate bool + credentialed bool + force bool + ttl time.Duration + init atomic.Bool +} + +func (s *liveProviderSync) Init(client liveProviderClient, path string, autoupdate bool, seed catwalk.Provider, credentialed bool) { + s.client = client + s.cache = newCache[catwalk.Provider](path) + s.autoupdate = autoupdate + s.seed = seed + s.credentialed = credentialed + s.ttl = liveModelsTTL + s.init.Store(true) +} + +func (s *liveProviderSync) Force() { + s.force = true +} + +func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { + if !s.init.Load() { + panic("called Get before Init") + } + + var throwErr error + s.once.Do(func() { + if !s.autoupdate { + slog.Info("Using provider seed", "provider", s.seed.ID) + s.result = s.seed + return + } + if !s.credentialed { + slog.Info("Skipping live provider sync without credentials", "provider", s.seed.ID) + s.result = s.seed + return + } + + cached, etag, cachedErr := s.cache.Get() + cachedAvailable := cachedErr == nil && len(cached.Models) > 0 + fallback := s.seed + if cachedAvailable { + fallback = cached + } + + if cachedAvailable && !s.force { + if age, ok := cacheAge(s.cache.path); ok && age < s.ttl { + slog.Info("Using cached live provider models", "provider", fallback.ID, "age", age) + s.result = cached + return + } + } + + slog.Info("Fetching live provider models", "provider", s.seed.ID) + result, err := s.client.Get(ctx, etag) + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("Live provider models not updated in time", "provider", s.seed.ID) + s.result = fallback + return + } + if errors.Is(err, catwalk.ErrNotModified) { + slog.Info("Live provider models not modified", "provider", s.seed.ID) + s.result = fallback + return + } + if err != nil { + slog.Warn("Live provider models not updated", "provider", s.seed.ID, "err", err) + s.result = fallback + return + } + if len(result.Models) == 0 { + slog.Warn("Live provider did not return any models", "provider", s.seed.ID) + s.result = fallback + return + } + + merged := mergeLiveProvider(s.seed, result) + s.result = merged + throwErr = s.cache.Store(merged) + }) + return s.result, throwErr +} + +func cacheAge(path string) (time.Duration, bool) { + info, err := os.Stat(path) + if err != nil || info.IsDir() { + return 0, false + } + + age := max(time.Since(info.ModTime()), 0) + return age, true +} + +func mergeLiveProvider(seed, live catwalk.Provider) catwalk.Provider { + merged := seed + if live.ID != "" { + merged.ID = live.ID + } + if live.Name != "" { + merged.Name = live.Name + } + if live.APIKey != "" { + merged.APIKey = live.APIKey + } + if live.APIEndpoint != "" { + merged.APIEndpoint = live.APIEndpoint + } + if live.Type != "" { + merged.Type = live.Type + } + if live.DefaultLargeModelID != "" { + merged.DefaultLargeModelID = live.DefaultLargeModelID + } + if live.DefaultSmallModelID != "" { + merged.DefaultSmallModelID = live.DefaultSmallModelID + } + if len(live.DefaultHeaders) > 0 { + merged.DefaultHeaders = live.DefaultHeaders + } + merged.Models = live.Models + + if len(merged.Models) == 0 { + return merged + } + + if merged.DefaultLargeModelID == "" || !modelExists(merged.Models, merged.DefaultLargeModelID) { + merged.DefaultLargeModelID = merged.Models[0].ID + } + if merged.DefaultSmallModelID == "" || !modelExists(merged.Models, merged.DefaultSmallModelID) { + merged.DefaultSmallModelID = merged.Models[0].ID + } + return merged +} + +func modelExists(models []catwalk.Model, id string) bool { + return slices.ContainsFunc(models, func(model catwalk.Model) bool { + return model.ID == id + }) +} diff --git a/internal/config/live_provider_test.go b/internal/config/live_provider_test.go new file mode 100644 index 0000000000..a47a9b3685 --- /dev/null +++ b/internal/config/live_provider_test.go @@ -0,0 +1,290 @@ +package config + +import ( + "context" + "encoding/json" + "errors" + "os" + "testing" + "time" + + "charm.land/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +type mockLiveProviderClient struct { + provider catwalk.Provider + err error + callCount int + etags []string +} + +func (m *mockLiveProviderClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { + m.callCount++ + m.etags = append(m.etags, etag) + return m.provider, m.err +} + +func TestLiveProviderSync_GetPanicIfNotInit(t *testing.T) { + t.Parallel() + + syncer := &liveProviderSync{} + require.Panics(t, func() { + _, _ = syncer.Get(t.Context()) + }) +} + +func TestLiveProviderSync_GetAutoUpdateDisabledReturnsSeed(t *testing.T) { + t.Parallel() + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", false, seed, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Equal(t, 0, client.callCount) +} + +func TestLiveProviderSync_GetWithoutCredentialsReturnsSeed(t *testing.T) { + t.Parallel() + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, seed, false) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Equal(t, 0, client.callCount) +} + +func TestLiveProviderSync_GetWarmCacheReturnsCachedWithoutFetch(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Equal(t, 0, client.callCount) +} + +func TestLiveProviderSync_GetStaleCacheFetchesMergesAndStores(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, staleTime)) + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{ + Models: []catwalk.Model{ + {ID: "live-model", Name: "Live Model"}, + {ID: "other-live-model", Name: "Other Live Model"}, + }, + }, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, seed, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, 1, client.callCount) + require.NotEmpty(t, client.etags[0]) + require.Equal(t, seed.ID, provider.ID) + require.Equal(t, seed.Name, provider.Name) + require.Equal(t, seed.APIEndpoint, provider.APIEndpoint) + require.Equal(t, seed.Type, provider.Type) + require.Equal(t, "live-model", provider.DefaultLargeModelID) + require.Equal(t, "live-model", provider.DefaultSmallModelID) + require.Equal(t, []catwalk.Model{ + {ID: "live-model", Name: "Live Model"}, + {ID: "other-live-model", Name: "Other Live Model"}, + }, provider.Models) + + stored, _, err := newCache[catwalk.Provider](path).Get() + require.NoError(t, err) + require.Equal(t, provider, stored) +} + +func TestLiveProviderSync_GetNotModifiedUsesCached(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, staleTime)) + + client := &mockLiveProviderClient{err: catwalk.ErrNotModified} + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Equal(t, 1, client.callCount) +} + +func TestLiveProviderSync_GetDeadlineExceededUsesCached(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, staleTime)) + + client := &mockLiveProviderClient{err: context.DeadlineExceeded} + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Equal(t, 1, client.callCount) +} + +func TestLiveProviderSync_GetFetchErrorUsesSeedWithoutCache(t *testing.T) { + t.Parallel() + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{err: errors.New("network error")} + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, seed, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Equal(t, 1, client.callCount) +} + +func TestLiveProviderSync_GetEmptyModelsUsesFallback(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Name = "Cached Provider" + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + staleTime := time.Now().Add(-2 * liveModelsTTL) + require.NoError(t, os.Chtimes(path, staleTime, staleTime)) + + client := &mockLiveProviderClient{provider: catwalk.Provider{ID: "test-live"}} + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Equal(t, 1, client.callCount) +} + +func TestLiveProviderSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { + t.Parallel() + + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, testLiveSeedProvider(), true) + + provider1, err1 := syncer.Get(t.Context()) + require.NoError(t, err1) + provider2, err2 := syncer.Get(t.Context()) + require.NoError(t, err2) + require.Equal(t, provider1, provider2) + require.Equal(t, 1, client.callCount) +} + +func TestLiveProviderSync_ForceBypassesWarmCache(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + syncer.Force() + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, 1, client.callCount) + require.Equal(t, []catwalk.Model{{ID: "live-model", Name: "Live Model"}}, provider.Models) +} + +func TestCacheAge(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + _, ok := cacheAge(path) + require.False(t, ok) + + writeLiveProviderCache(t, path, testLiveSeedProvider()) + age, ok := cacheAge(path) + require.True(t, ok) + require.Less(t, age, liveModelsTTL) + + future := time.Now().Add(liveModelsTTL) + require.NoError(t, os.Chtimes(path, future, future)) + age, ok = cacheAge(path) + require.True(t, ok) + require.Zero(t, age) +} + +func testLiveSeedProvider() catwalk.Provider { + return catwalk.Provider{ + ID: "test-live", + Name: "Test Live", + APIEndpoint: "https://example.com/v1", + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "seed-large-model", + DefaultSmallModelID: "seed-small-model", + DefaultHeaders: map[string]string{ + "X-Test": "seed", + }, + Models: []catwalk.Model{ + {ID: "seed-large-model", Name: "Seed Large Model"}, + {ID: "seed-small-model", Name: "Seed Small Model"}, + }, + } +} + +func writeLiveProviderCache(t *testing.T, path string, provider catwalk.Provider) { + t.Helper() + + data, err := json.Marshal(provider) + require.NoError(t, err) + require.NoError(t, os.WriteFile(path, data, 0o644)) +} From 848f092be93be5ae45c419181f12591559c258fe Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 9 Jun 2026 19:37:00 -0700 Subject: [PATCH 02/16] feat(config): add Venice live models client --- internal/config/venice_models.go | 152 ++++++++++++++++++++++++++ internal/config/venice_models_test.go | 150 +++++++++++++++++++++++++ 2 files changed, 302 insertions(+) create mode 100644 internal/config/venice_models.go create mode 100644 internal/config/venice_models_test.go diff --git a/internal/config/venice_models.go b/internal/config/venice_models.go new file mode 100644 index 0000000000..7832dc42b2 --- /dev/null +++ b/internal/config/venice_models.go @@ -0,0 +1,152 @@ +package config + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "charm.land/catwalk/pkg/catwalk" + xetag "github.com/charmbracelet/x/etag" +) + +var _ liveProviderClient = realVeniceModelsClient{} + +type realVeniceModelsClient struct { + baseURL string + apiKey string +} + +func (r realVeniceModelsClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { + var result catwalk.Provider + baseURL := strings.TrimRight(r.baseURL, "/") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) + if err != nil { + return result, fmt.Errorf("could not create request: %w", err) + } + xetag.Request(req, etag) + if apiKey := strings.TrimSpace(r.apiKey); apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode == http.StatusNotModified { + return result, catwalk.ErrNotModified + } + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var models veniceModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&models); err != nil { + return result, fmt.Errorf("failed to decode response: %w", err) + } + + result = catwalk.Provider{ + ID: catwalk.InferenceProviderVenice, + APIEndpoint: baseURL, + Type: catwalk.TypeOpenAICompat, + Models: veniceModelsToCatwalkModels(models), + } + return result, nil +} + +type veniceModelsResponse struct { + Data []veniceModel `json:"data"` +} + +type veniceModel struct { + ID string `json:"id"` + ModelSpec veniceModelSpec `json:"model_spec"` + Type string `json:"type"` +} + +type veniceModelSpec struct { + AvailableContextTokens int64 `json:"availableContextTokens"` + MaxCompletionTokens int64 `json:"maxCompletionTokens"` + Capabilities veniceModelCapabilities `json:"capabilities"` + Name string `json:"name"` + Offline bool `json:"offline"` + Pricing veniceModelPricing `json:"pricing"` +} + +type veniceModelCapabilities struct { + SupportsReasoning bool `json:"supportsReasoning"` + ReasoningEffortOptions []string `json:"reasoningEffortOptions"` + DefaultReasoningEffort string `json:"defaultReasoningEffort"` + SupportsVision bool `json:"supportsVision"` +} + +type veniceModelPricing struct { + Input veniceModelPricingValue `json:"input"` + Output veniceModelPricingValue `json:"output"` + CacheInput veniceModelPricingValue `json:"cache_input"` +} + +type veniceModelPricingValue struct { + USD veniceUSD `json:"usd"` +} + +type veniceUSD float64 + +func (v *veniceUSD) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + *v = 0 + return nil + } + + var number float64 + if err := json.Unmarshal(data, &number); err == nil { + *v = veniceUSD(number) + return nil + } + + var value string + if err := json.Unmarshal(data, &value); err != nil { + return fmt.Errorf("failed to decode usd value: %w", err) + } + value = strings.TrimSpace(value) + if value == "" { + *v = 0 + return nil + } + parsed, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("failed to parse usd value: %w", err) + } + *v = veniceUSD(parsed) + return nil +} + +func veniceModelsToCatwalkModels(response veniceModelsResponse) []catwalk.Model { + models := make([]catwalk.Model, 0, len(response.Data)) + for _, model := range response.Data { + if !strings.EqualFold(model.Type, "text") || model.ModelSpec.Offline { + continue + } + + models = append(models, catwalk.Model{ + ID: model.ID, + Name: model.ModelSpec.Name, + CostPer1MIn: float64(model.ModelSpec.Pricing.Input.USD), + CostPer1MOut: float64(model.ModelSpec.Pricing.Output.USD), + CostPer1MInCached: float64(model.ModelSpec.Pricing.CacheInput.USD), + ContextWindow: model.ModelSpec.AvailableContextTokens, + DefaultMaxTokens: model.ModelSpec.MaxCompletionTokens, + CanReason: model.ModelSpec.Capabilities.SupportsReasoning, + ReasoningLevels: model.ModelSpec.Capabilities.ReasoningEffortOptions, + DefaultReasoningEffort: model.ModelSpec.Capabilities.DefaultReasoningEffort, + SupportsImages: model.ModelSpec.Capabilities.SupportsVision, + }) + } + return models +} diff --git a/internal/config/venice_models_test.go b/internal/config/venice_models_test.go new file mode 100644 index 0000000000..ad31e3cf7b --- /dev/null +++ b/internal/config/venice_models_test.go @@ -0,0 +1,150 @@ +package config + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/catwalk/pkg/catwalk" + "github.com/stretchr/testify/require" +) + +func TestRealVeniceModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + require.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{ + "data": [ + { + "id": "venice-reasoning", + "type": "text", + "model_spec": { + "name": "Venice Reasoning", + "availableContextTokens": 128000, + "maxCompletionTokens": 4096, + "offline": false, + "pricing": { + "input": { "usd": "0.15" }, + "output": { "usd": 0.6 }, + "cache_input": { "usd": "0.03" } + }, + "capabilities": { + "supportsReasoning": true, + "reasoningEffortOptions": ["low", "medium", "high"], + "defaultReasoningEffort": "medium", + "supportsVision": true + } + } + }, + { + "id": "venice-image", + "type": "image", + "model_spec": { "name": "Image Model" } + }, + { + "id": "venice-offline", + "type": "text", + "model_spec": { "name": "Offline Model", "offline": true } + } + ] + }`)) + require.NoError(t, err) + })) + defer server.Close() + + client := realVeniceModelsClient{baseURL: server.URL + "/", apiKey: " test-key "} + provider, err := client.Get(t.Context(), "cached-etag") + + require.NoError(t, err) + require.Equal(t, catwalk.InferenceProviderVenice, provider.ID) + require.Equal(t, server.URL, provider.APIEndpoint) + require.Equal(t, catwalk.TypeOpenAICompat, provider.Type) + require.Equal(t, []catwalk.Model{ + { + ID: "venice-reasoning", + Name: "Venice Reasoning", + CostPer1MIn: 0.15, + CostPer1MOut: 0.6, + CostPer1MInCached: 0.03, + ContextWindow: 128000, + DefaultMaxTokens: 4096, + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, + DefaultReasoningEffort: "medium", + SupportsImages: true, + }, + }, provider.Models) +} + +func TestRealVeniceModelsClientGetNotModified(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) + w.WriteHeader(http.StatusNotModified) + })) + defer server.Close() + + client := realVeniceModelsClient{baseURL: server.URL} + provider, err := client.Get(t.Context(), "cached-etag") + + require.True(t, errors.Is(err, catwalk.ErrNotModified)) + require.Empty(t, provider.Models) +} + +func TestVeniceModelsToCatwalkModelsFiltersNonTextAndOfflineModels(t *testing.T) { + t.Parallel() + + models := veniceModelsToCatwalkModels(veniceModelsResponse{Data: []veniceModel{ + { + ID: "text-model", + Type: "TEXT", + ModelSpec: veniceModelSpec{ + Name: "Text Model", + AvailableContextTokens: 32000, + MaxCompletionTokens: 2048, + Pricing: veniceModelPricing{ + Input: veniceModelPricingValue{USD: veniceUSD(0.1)}, + Output: veniceModelPricingValue{USD: veniceUSD(0.2)}, + CacheInput: veniceModelPricingValue{USD: veniceUSD(0.01)}, + }, + Capabilities: veniceModelCapabilities{ + SupportsReasoning: true, + ReasoningEffortOptions: []string{"low"}, + DefaultReasoningEffort: "low", + }, + }, + }, + { + ID: "image-model", + Type: "image", + ModelSpec: veniceModelSpec{Name: "Image Model"}, + }, + { + ID: "offline-model", + Type: "text", + ModelSpec: veniceModelSpec{Name: "Offline Model", Offline: true}, + }, + }}) + + require.Equal(t, []catwalk.Model{ + { + ID: "text-model", + Name: "Text Model", + CostPer1MIn: 0.1, + CostPer1MOut: 0.2, + CostPer1MInCached: 0.01, + ContextWindow: 32000, + DefaultMaxTokens: 2048, + CanReason: true, + ReasoningLevels: []string{"low"}, + DefaultReasoningEffort: "low", + }, + }, models) +} From ba89e4d6abf405a6f5fb8e12526b10d13dc01d8a Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 9 Jun 2026 19:43:41 -0700 Subject: [PATCH 03/16] feat(config): add Copilot live models client --- internal/config/copilot_models.go | 176 ++++++++++++++++++++++++ internal/config/copilot_models_test.go | 183 +++++++++++++++++++++++++ 2 files changed, 359 insertions(+) create mode 100644 internal/config/copilot_models.go create mode 100644 internal/config/copilot_models_test.go diff --git a/internal/config/copilot_models.go b/internal/config/copilot_models.go new file mode 100644 index 0000000000..e6fef6e092 --- /dev/null +++ b/internal/config/copilot_models.go @@ -0,0 +1,176 @@ +package config + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "regexp" + "slices" + "strings" + "time" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + xetag "github.com/charmbracelet/x/etag" +) + +var _ liveProviderClient = (*realCopilotModelsClient)(nil) + +type copilotTokenRefresher func(context.Context, string) (*oauth.Token, error) + +type realCopilotModelsClient struct { + baseURL string + apiKey string + oauthToken *oauth.Token + refreshToken copilotTokenRefresher +} + +func (r *realCopilotModelsClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { + var result catwalk.Provider + baseURL := strings.TrimRight(r.baseURL, "/") + accessToken, err := r.accessToken(ctx) + if err != nil { + return result, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) + if err != nil { + return result, fmt.Errorf("could not create request: %w", err) + } + xetag.Request(req, etag) + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+accessToken) + for key, value := range copilot.Headers() { + req.Header.Set(key, value) + } + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return result, fmt.Errorf("failed to make request: %w", err) + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode == http.StatusNotModified { + return result, catwalk.ErrNotModified + } + if resp.StatusCode != http.StatusOK { + return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var models copilotModelsResponse + if err := json.NewDecoder(resp.Body).Decode(&models); err != nil { + return result, fmt.Errorf("failed to decode response: %w", err) + } + + result = catwalk.Provider{ + ID: catwalk.InferenceProviderCopilot, + APIEndpoint: baseURL, + Type: catwalk.TypeOpenAICompat, + Models: copilotModelsToCatwalkModels(models), + } + return result, nil +} + +func (r *realCopilotModelsClient) accessToken(ctx context.Context) (string, error) { + if r.oauthToken != nil { + if token := strings.TrimSpace(r.oauthToken.AccessToken); token != "" && !r.oauthToken.IsExpired() { + r.apiKey = token + return token, nil + } + + refreshToken := strings.TrimSpace(r.oauthToken.RefreshToken) + if refreshToken != "" { + refresh := r.refreshToken + if refresh == nil { + refresh = copilot.RefreshToken + } + refreshedToken, err := refresh(ctx, refreshToken) + if err != nil { + return "", fmt.Errorf("failed to refresh Copilot token: %w", err) + } + r.oauthToken = refreshedToken + r.apiKey = strings.TrimSpace(refreshedToken.AccessToken) + if r.apiKey != "" { + return r.apiKey, nil + } + } + } + + if token := strings.TrimSpace(r.apiKey); token != "" { + return token, nil + } + return "", fmt.Errorf("missing Copilot access token") +} + +type copilotModelsResponse struct { + Data []copilotModel `json:"data"` +} + +type copilotModel struct { + ID string `json:"id"` + Name string `json:"name"` + Version string `json:"version"` + Capabilities copilotModelCapabilities `json:"capabilities"` +} + +type copilotModelCapabilities struct { + Limits copilotModelLimits `json:"limits"` + Supports copilotModelSupports `json:"supports"` +} + +type copilotModelLimits struct { + MaxContextWindowTokens int64 `json:"max_context_window_tokens"` + MaxOutputTokens int64 `json:"max_output_tokens"` +} + +type copilotModelSupports struct { + Vision bool `json:"vision"` +} + +var copilotVersionedModelRegexp = regexp.MustCompile(`-\d{4}-\d{2}-\d{2}$`) + +func copilotModelsToCatwalkModels(response copilotModelsResponse) []catwalk.Model { + aliasedVersions := make(map[string]bool, len(response.Data)) + for _, model := range response.Data { + if model.Version != "" && model.ID != model.Version { + aliasedVersions[model.Version] = true + } + } + + seen := make(map[string]bool, len(response.Data)) + models := make([]catwalk.Model, 0, len(response.Data)) + for _, model := range response.Data { + if shouldSkipCopilotModel(model, aliasedVersions) || seen[model.ID] { + continue + } + seen[model.ID] = true + models = append(models, catwalk.Model{ + ID: model.ID, + Name: model.Name, + ContextWindow: model.Capabilities.Limits.MaxContextWindowTokens, + DefaultMaxTokens: model.Capabilities.Limits.MaxOutputTokens, + SupportsImages: model.Capabilities.Supports.Vision, + }) + } + + slices.SortStableFunc(models, func(a, b catwalk.Model) int { + return strings.Compare(a.ID, b.ID) + }) + return models +} + +func shouldSkipCopilotModel(model copilotModel, aliasedVersions map[string]bool) bool { + return model.ID == "" || + aliasedVersions[model.ID] || + copilotVersionedModelRegexp.MatchString(model.ID) || + strings.Contains(model.ID, "embedding") || + strings.HasPrefix(model.ID, "accounts/msft/routers") || + strings.HasPrefix(model.ID, "oswe-vscode") || + strings.HasPrefix(model.ID, "lark") || + strings.HasPrefix(model.ID, "mai-code") || + model.ID == "gpt-4-o-preview" || + model.ID == "trajectory-compaction" +} diff --git a/internal/config/copilot_models_test.go b/internal/config/copilot_models_test.go new file mode 100644 index 0000000000..12b7633eb4 --- /dev/null +++ b/internal/config/copilot_models_test.go @@ -0,0 +1,183 @@ +package config + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/oauth/copilot" + "github.com/stretchr/testify/require" +) + +func TestRealCopilotModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + require.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) + require.Equal(t, "application/json", r.Header.Get("Accept")) + for key, value := range copilot.Headers() { + require.Equal(t, value, r.Header.Get(key)) + } + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{ + "data": [ + { + "id": "claude-sonnet-4.6", + "name": "Claude Sonnet 4.6", + "version": "claude-sonnet-4.6", + "capabilities": { + "limits": { + "max_context_window_tokens": 264000, + "max_output_tokens": 64000 + }, + "supports": { "vision": true } + } + } + ] + }`)) + require.NoError(t, err) + })) + defer server.Close() + + client := &realCopilotModelsClient{baseURL: server.URL + "/", apiKey: " test-token "} + provider, err := client.Get(t.Context(), "cached-etag") + + require.NoError(t, err) + require.Equal(t, catwalk.InferenceProviderCopilot, provider.ID) + require.Equal(t, server.URL, provider.APIEndpoint) + require.Equal(t, catwalk.TypeOpenAICompat, provider.Type) + require.Equal(t, []catwalk.Model{ + { + ID: "claude-sonnet-4.6", + Name: "Claude Sonnet 4.6", + ContextWindow: 264000, + DefaultMaxTokens: 64000, + SupportsImages: true, + }, + }, provider.Models) +} + +func TestRealCopilotModelsClientGetRefreshesExpiredOAuthToken(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Bearer refreshed-token", r.Header.Get("Authorization")) + _, err := w.Write([]byte(`{ + "data": [ + { + "id": "gpt-5.1", + "name": "GPT 5.1", + "capabilities": { + "limits": { + "max_context_window_tokens": 128000, + "max_output_tokens": 16000 + } + } + } + ] + }`)) + require.NoError(t, err) + })) + defer server.Close() + + client := &realCopilotModelsClient{ + baseURL: server.URL, + oauthToken: &oauth.Token{ + AccessToken: "expired-token", + RefreshToken: "github-token", + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(-time.Hour).Unix(), + }, + refreshToken: func(_ctx context.Context, githubToken string) (*oauth.Token, error) { + require.Equal(t, "github-token", githubToken) + return &oauth.Token{ + AccessToken: "refreshed-token", + RefreshToken: githubToken, + ExpiresIn: 3600, + ExpiresAt: time.Now().Add(time.Hour).Unix(), + }, nil + }, + } + + provider, err := client.Get(t.Context(), "") + + require.NoError(t, err) + require.Equal(t, []catwalk.Model{{ID: "gpt-5.1", Name: "GPT 5.1", ContextWindow: 128000, DefaultMaxTokens: 16000}}, provider.Models) + require.Equal(t, "refreshed-token", client.apiKey) +} + +func TestRealCopilotModelsClientGetNotModified(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) + w.WriteHeader(http.StatusNotModified) + })) + defer server.Close() + + client := &realCopilotModelsClient{baseURL: server.URL, apiKey: "test-token"} + provider, err := client.Get(t.Context(), "cached-etag") + + require.True(t, errors.Is(err, catwalk.ErrNotModified)) + require.Empty(t, provider.Models) +} + +func TestCopilotModelsToCatwalkModelsFiltersAndDeduplicates(t *testing.T) { + t.Parallel() + + models := copilotModelsToCatwalkModels(copilotModelsResponse{Data: []copilotModel{ + { + ID: "gpt-5.1", + Name: "GPT 5.1", + Version: "gpt-5.1", + Capabilities: copilotModelCapabilities{ + Limits: copilotModelLimits{MaxContextWindowTokens: 128000, MaxOutputTokens: 16000}, + Supports: copilotModelSupports{Vision: true}, + }, + }, + { + ID: "gpt-5.1", + Name: "GPT 5.1 Duplicate", + Version: "gpt-5.1", + }, + { + ID: "aliased-model", + Name: "Aliased Model", + Version: "versioned-model-2026-01-01", + }, + { + ID: "versioned-model-2026-01-01", + Name: "Versioned Model", + Version: "versioned-model-2026-01-01", + }, + {ID: "text-embedding-3", Name: "Embedding"}, + {ID: "accounts/msft/routers/test", Name: "Router"}, + {ID: "oswe-vscode-test", Name: "OSWE"}, + {ID: "lark-test", Name: "Lark"}, + {ID: "mai-code-test", Name: "MAI"}, + {ID: "gpt-4-o-preview", Name: "Preview"}, + {ID: "trajectory-compaction", Name: "Compaction"}, + }}) + + require.Equal(t, []catwalk.Model{ + { + ID: "aliased-model", + Name: "Aliased Model", + }, + { + ID: "gpt-5.1", + Name: "GPT 5.1", + ContextWindow: 128000, + DefaultMaxTokens: 16000, + SupportsImages: true, + }, + }, models) +} From 3451b76bd99a3115939af1bf1b659fb40a099faa Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 9 Jun 2026 20:06:44 -0700 Subject: [PATCH 04/16] feat(config): wire live provider updates --- internal/cmd/update_providers.go | 53 +++++- internal/cmd/update_providers_test.go | 27 +++ internal/config/provider.go | 236 +++++++++++++++++++++++++- internal/config/provider_test.go | 228 +++++++++++++++++++++++++ 4 files changed, 536 insertions(+), 8 deletions(-) create mode 100644 internal/cmd/update_providers_test.go diff --git a/internal/cmd/update_providers.go b/internal/cmd/update_providers.go index 3b4b35b681..3839a5e346 100644 --- a/internal/cmd/update_providers.go +++ b/internal/cmd/update_providers.go @@ -17,7 +17,7 @@ var updateProvidersCmd = &cobra.Command{ Short: "Update providers", Long: `Update provider information from a specified local path or remote URL.`, Example: ` -# Update Catwalk providers remotely (default) +# Update Catwalk providers remotely (default), plus authenticated live providers crush update-providers # Update Catwalk providers from a custom URL @@ -34,6 +34,12 @@ crush update-providers --source=hyper # Update Hyper from a custom URL crush update-providers --source=hyper https://hyper.example.com + +# Update Venice provider models +crush update-providers --source=venice + +# Update Copilot provider models +crush update-providers --source=copilot `, RunE: func(cmd *cobra.Command, args []string) error { // NOTE(@andreynering): We want to skip logging output do stdout here. @@ -48,10 +54,23 @@ crush update-providers --source=hyper https://hyper.example.com switch updateProvidersSource { case "catwalk": err = config.UpdateProviders(pathOrURL) + if err == nil && pathOrURL == "" { + updateAuthenticatedLiveProviders(cmd) + } case "hyper": err = config.UpdateHyper(pathOrURL) + case "venice", "copilot": + cfg, loadErr := loadUpdateProvidersConfig(cmd) + if loadErr != nil { + return loadErr + } + if updateProvidersSource == "venice" { + err = config.UpdateVenice(pathOrURL, cfg) + } else { + err = config.UpdateCopilot(pathOrURL, cfg) + } default: - return fmt.Errorf("invalid source %q, must be 'catwalk' or 'hyper'", updateProvidersSource) + return fmt.Errorf("invalid source %q, must be 'catwalk', 'hyper', 'venice', or 'copilot'", updateProvidersSource) } if err != nil { @@ -77,6 +96,34 @@ crush update-providers --source=hyper https://hyper.example.com }, } +func updateAuthenticatedLiveProviders(cmd *cobra.Command) { + cfg, err := loadUpdateProvidersConfig(cmd) + if err != nil { + slog.Debug("Skipping live provider updates", "error", err) + return + } + if err := config.UpdateVenice("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) { + slog.Debug("Skipping Venice provider update", "error", err) + } + if err := config.UpdateCopilot("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) { + slog.Debug("Skipping Copilot provider update", "error", err) + } +} + +func loadUpdateProvidersConfig(cmd *cobra.Command) (*config.Config, error) { + cwd, err := ResolveCwd(cmd) + if err != nil { + return nil, err + } + dataDir, _ := cmd.Flags().GetString("data-dir") + debug, _ := cmd.Flags().GetBool("debug") + store, err := config.Load(cwd, dataDir, debug) + if err != nil { + return nil, err + } + return store.Config(), nil +} + func init() { - updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk or hyper)") + updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk, hyper, venice, or copilot)") } diff --git a/internal/cmd/update_providers_test.go b/internal/cmd/update_providers_test.go new file mode 100644 index 0000000000..d1d8d6dfd1 --- /dev/null +++ b/internal/cmd/update_providers_test.go @@ -0,0 +1,27 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUpdateProvidersCmd_SourceFlagIncludesLiveProviders(t *testing.T) { + t.Parallel() + + flag := updateProvidersCmd.Flags().Lookup("source") + require.NotNil(t, flag) + require.Equal(t, "catwalk", flag.DefValue) + require.Contains(t, flag.Usage, "venice") + require.Contains(t, flag.Usage, "copilot") +} + +func TestUpdateProvidersCmd_ExamplesIncludeLiveProviders(t *testing.T) { + t.Parallel() + + example := strings.ToLower(updateProvidersCmd.Example) + require.Contains(t, example, "--source=venice") + require.Contains(t, example, "--source=copilot") + require.Contains(t, example, "authenticated live providers") +} diff --git a/internal/config/provider.go b/internal/config/provider.go index 32a3894358..111fdf9c69 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -19,7 +19,9 @@ import ( "charm.land/catwalk/pkg/embedded" "github.com/charmbracelet/crush/internal/agent/hyper" "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/x/etag" ) @@ -33,6 +35,14 @@ var ( providerErr error ) +var errMissingLiveProviderCredentials = errors.New("missing live provider credentials") + +// IsMissingLiveProviderCredentials reports whether err means a live provider +// cannot be updated because no credentials are configured. +func IsMissingLiveProviderCredentials(err error) bool { + return errors.Is(err, errMissingLiveProviderCredentials) +} + // file to cache provider data func cachePathFor(name string) string { xdgDataHome := os.Getenv("XDG_DATA_HOME") @@ -85,7 +95,7 @@ func UpdateProviders(pathOrURL string) error { return fmt.Errorf("failed to save providers to cache: %w", err) } - slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor) + slog.Info("Providers updated successfully", "count", len(providers), "from", pathOrURL, "to", cachePathFor("providers")) return nil } @@ -122,9 +132,23 @@ func UpdateHyper(pathOrURL string) error { return nil } +// UpdateVenice updates the cached Venice provider from a live source, embedded +// seed, or local provider file. +func UpdateVenice(pathOrURL string, cfg *Config) error { + return updateLiveProvider("Venice", "venice", catwalk.InferenceProviderVenice, pathOrURL, cfg, newVeniceLiveProviderClient) +} + +// UpdateCopilot updates the cached Copilot provider from a live source, +// embedded seed, or local provider file. +func UpdateCopilot(pathOrURL string, cfg *Config) error { + return updateLiveProvider("Copilot", "copilot", catwalk.InferenceProviderCopilot, pathOrURL, cfg, newCopilotLiveProviderClient) +} + var ( catwalkSyncer = &catwalkSync{} hyperSyncer = &hyperSync{} + veniceSyncer = &liveProviderSync{} + copilotSyncer = &liveProviderSync{} ) // Providers returns the list of providers, taking into account cached results @@ -141,8 +165,12 @@ func Providers(cfg *Config) ([]catwalk.Provider, error) { var wg sync.WaitGroup var errs []error providers := csync.NewSlice[catwalk.Provider]() - autoupdate := !cfg.Options.DisableProviderAutoUpdate - customProvidersOnly := cfg.Options.DisableDefaultProviders + options := &Options{} + if cfg != nil && cfg.Options != nil { + options = cfg.Options + } + autoupdate := !options.DisableProviderAutoUpdate + customProvidersOnly := options.DisableDefaultProviders ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) defer cancel() @@ -186,16 +214,214 @@ func Providers(cfg *Config) ([]catwalk.Provider, error) { wg.Wait() + items := slices.Collect(providers.Seq()) + if !customProvidersOnly { + var liveErrs []error + items, liveErrs = overlayLiveProviderModels(ctx, cfg, items, autoupdate) + errs = append(errs, liveErrs...) + } + if hyperFound { - providerList = append([]catwalk.Provider{hyperProvider}, slices.Collect(providers.Seq())...) + providerList = append([]catwalk.Provider{hyperProvider}, items...) } else { - providerList = slices.Collect(providers.Seq()) + providerList = items } providerErr = errors.Join(errs...) }) return providerList, providerErr } +func overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []catwalk.Provider, autoupdate bool) ([]catwalk.Provider, []error) { + if cfg == nil || len(providers) == 0 { + return providers, nil + } + + environment := env.New() + resolver := NewShellVariableResolver(environment) + errs := make([]error, 0, 2) + + syncProvider := func(providerID catwalk.InferenceProvider, cacheName string, syncer *liveProviderSync, newClient liveProviderClientFunc) { + index := slices.IndexFunc(providers, func(provider catwalk.Provider) bool { + return provider.ID == providerID + }) + if index < 0 { + return + } + + seed := providers[index] + client, credentialed, err := newClient(seed, cfg, resolver, "") + if err != nil { + slog.Warn("Skipping live provider sync", "provider", providerID, "error", err) + return + } + + syncer.Init(client, cachePathFor(cacheName), autoupdate, seed, credentialed) + provider, err := syncer.Get(ctx) + if err != nil { + errs = append(errs, fmt.Errorf("Crush was unable to cache updated models from %s: %w", seed.Name, err)) + return + } + providers[index] = provider + } + + syncProvider(catwalk.InferenceProviderVenice, "venice", veniceSyncer, newVeniceLiveProviderClient) + syncProvider(catwalk.InferenceProviderCopilot, "copilot", copilotSyncer, newCopilotLiveProviderClient) + + return providers, errs +} + +type liveProviderClientFunc func(catwalk.Provider, *Config, VariableResolver, string) (liveProviderClient, bool, error) + +func newVeniceLiveProviderClient(seed catwalk.Provider, cfg *Config, resolver VariableResolver, baseURLOverride string) (liveProviderClient, bool, error) { + var providerConfig ProviderConfig + configExists := cfg != nil && cfg.Providers != nil + if configExists { + providerConfig, configExists = cfg.Providers.Get(string(catwalk.InferenceProviderVenice)) + } + if configExists && providerConfig.Disable { + return realVeniceModelsClient{baseURL: seed.APIEndpoint}, false, nil + } + + baseURL := cmp.Or(baseURLOverride, seed.APIEndpoint) + if configExists && providerConfig.BaseURL != "" && baseURLOverride == "" { + resolved, err := resolver.ResolveValue(providerConfig.BaseURL) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Venice base URL: %w", err) + } + baseURL = resolved + } + + apiKey := "" + if configExists && providerConfig.APIKey != "" { + resolved, err := resolver.ResolveValue(providerConfig.APIKey) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Venice API key: %w", err) + } + apiKey = strings.TrimSpace(resolved) + } + if apiKey == "" { + apiKey = strings.TrimSpace(os.Getenv("VENICE_API_KEY")) + } + + return realVeniceModelsClient{baseURL: baseURL, apiKey: apiKey}, apiKey != "", nil +} + +func newCopilotLiveProviderClient(seed catwalk.Provider, cfg *Config, resolver VariableResolver, baseURLOverride string) (liveProviderClient, bool, error) { + var providerConfig ProviderConfig + configExists := cfg != nil && cfg.Providers != nil + if configExists { + providerConfig, configExists = cfg.Providers.Get(string(catwalk.InferenceProviderCopilot)) + } + if configExists && providerConfig.Disable { + return &realCopilotModelsClient{baseURL: seed.APIEndpoint}, false, nil + } + + baseURL := cmp.Or(baseURLOverride, seed.APIEndpoint) + if configExists && providerConfig.BaseURL != "" && baseURLOverride == "" { + resolved, err := resolver.ResolveValue(providerConfig.BaseURL) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Copilot base URL: %w", err) + } + baseURL = resolved + } + + apiKey := "" + if configExists && providerConfig.APIKey != "" { + resolved, err := resolver.ResolveValue(providerConfig.APIKey) + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Copilot API key: %w", err) + } + apiKey = strings.TrimSpace(resolved) + } + oauthToken := providerConfig.OAuthToken + credentialed := apiKey != "" || usableOAuthToken(oauthToken) + + return &realCopilotModelsClient{baseURL: baseURL, apiKey: apiKey, oauthToken: oauthToken}, credentialed, nil +} + +func usableOAuthToken(token *oauth.Token) bool { + if token == nil { + return false + } + return strings.TrimSpace(token.AccessToken) != "" || strings.TrimSpace(token.RefreshToken) != "" +} + +func updateLiveProvider(name, cacheName string, providerID catwalk.InferenceProvider, pathOrURL string, cfg *Config, newClient liveProviderClientFunc) error { + seed, err := liveProviderSeed(providerID) + if err != nil { + return err + } + + var provider catwalk.Provider + switch { + case pathOrURL == "embedded": + provider = seed + case pathOrURL != "" && !strings.HasPrefix(pathOrURL, "http://") && !strings.HasPrefix(pathOrURL, "https://"): + content, err := os.ReadFile(pathOrURL) + if err != nil { + return fmt.Errorf("failed to read file: %w", err) + } + if err := json.Unmarshal(content, &provider); err != nil { + return fmt.Errorf("failed to unmarshal provider data: %w", err) + } + default: + if cfg == nil { + return fmt.Errorf("failed to fetch provider from %s: %w", name, errMissingLiveProviderCredentials) + } + environment := env.New() + resolver := NewShellVariableResolver(environment) + client, credentialed, err := newClient(seed, cfg, resolver, pathOrURL) + if err != nil { + return fmt.Errorf("failed to prepare %s provider update: %w", name, err) + } + if !credentialed { + return fmt.Errorf("failed to fetch provider from %s: %w", name, errMissingLiveProviderCredentials) + } + + ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) + defer cancel() + + provider, err = client.Get(ctx, "") + if err != nil { + return fmt.Errorf("failed to fetch provider from %s: %w", name, err) + } + if len(provider.Models) == 0 { + return fmt.Errorf("failed to fetch provider from %s: no models returned", name) + } + provider = mergeLiveProvider(seed, provider) + } + + if err := newCache[catwalk.Provider](cachePathFor(cacheName)).Store(provider); err != nil { + return fmt.Errorf("failed to save %s provider to cache: %w", name, err) + } + + slog.Info(name+" provider updated successfully", "from", cmp.Or(pathOrURL, "live"), "to", cachePathFor(cacheName)) + return nil +} + +func liveProviderSeed(providerID catwalk.InferenceProvider) (catwalk.Provider, error) { + providers, _, err := newCache[[]catwalk.Provider](cachePathFor("providers")).Get() + if err != nil || len(providers) == 0 { + providers = embedded.GetAll() + } + if provider, ok := findProvider(providers, providerID); ok { + return provider, nil + } + if provider, ok := findProvider(embedded.GetAll(), providerID); ok { + return provider, nil + } + return catwalk.Provider{}, fmt.Errorf("provider %s not found", providerID) +} + +func findProvider(providers []catwalk.Provider, providerID catwalk.InferenceProvider) (catwalk.Provider, bool) { + for _, provider := range providers { + if provider.ID == providerID { + return provider, true + } + } + return catwalk.Provider{}, false +} + type cache[T any] struct { path string } diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 283c18c8ab..2d28aa13ea 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -2,12 +2,15 @@ package config import ( "encoding/json" + "net/http" + "net/http/httptest" "os" "path/filepath" "sync" "testing" "charm.land/catwalk/pkg/catwalk" + "github.com/charmbracelet/crush/internal/csync" "github.com/stretchr/testify/require" ) @@ -17,6 +20,8 @@ func resetProviderState() { providerErr = nil catwalkSyncer = &catwalkSync{} hyperSyncer = &hyperSync{} + veniceSyncer = &liveProviderSync{} + copilotSyncer = &liveProviderSync{} } func TestProviders_Integration_AutoUpdateDisabled(t *testing.T) { @@ -217,6 +222,229 @@ func TestProviders_Integration_BothFail(t *testing.T) { require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models. } +func TestProviders_Integration_LiveOverlayFetchesWithCredentials(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + + switch r.Header.Get("Authorization") { + case "Bearer venice-token": + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "venice-live", + "type": "text", + "model_spec": { + "name": "Venice Live", + "availableContextTokens": 4096, + "maxCompletionTokens": 1024, + "pricing": {"input": {"usd": "0.1"}, "output": {"usd": "0.2"}}, + "capabilities": {"supportsVision": true} + } + } + ] + }`)) + case "Bearer copilot-token": + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "copilot-live", + "name": "Copilot Live", + "capabilities": { + "limits": {"max_context_window_tokens": 8192, "max_output_tokens": 2048}, + "supports": {"vision": true} + } + } + ] + }`)) + default: + w.WriteHeader(http.StatusUnauthorized) + } + })) + defer server.Close() + + providers := []catwalk.Provider{ + { + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "venice-live", + DefaultSmallModelID: "venice-live", + Models: []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, + }, + { + Name: "Copilot", + ID: catwalk.InferenceProviderCopilot, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "copilot-live", + DefaultSmallModelID: "copilot-live", + Models: []catwalk.Model{{ID: "copilot-seed", Name: "Copilot Seed"}}, + }, + } + cfg := &Config{ + Options: &Options{}, + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + string(catwalk.InferenceProviderVenice): {APIKey: "venice-token"}, + string(catwalk.InferenceProviderCopilot): {APIKey: "copilot-token"}, + }), + } + + result, errs := overlayLiveProviderModels(t.Context(), cfg, providers, true) + require.Empty(t, errs) + + venice, ok := findProvider(result, catwalk.InferenceProviderVenice) + require.True(t, ok) + require.Equal(t, "Venice", venice.Name) + require.Equal(t, "venice-live", venice.DefaultLargeModelID) + require.Equal(t, []catwalk.Model{{ + ID: "venice-live", + Name: "Venice Live", + CostPer1MIn: 0.1, + CostPer1MOut: 0.2, + ContextWindow: 4096, + DefaultMaxTokens: 1024, + SupportsImages: true, + }}, venice.Models) + + copilot, ok := findProvider(result, catwalk.InferenceProviderCopilot) + require.True(t, ok) + require.Equal(t, "Copilot", copilot.Name) + require.Equal(t, "copilot-live", copilot.DefaultSmallModelID) + require.Equal(t, []catwalk.Model{{ + ID: "copilot-live", + Name: "Copilot Live", + ContextWindow: 8192, + DefaultMaxTokens: 2048, + SupportsImages: true, + }}, copilot.Models) +} + +func TestProviders_Integration_LiveOverlaySkipsWithoutCredentials(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + seed := catwalk.Provider{ + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: "http://127.0.0.1:1", + Models: []catwalk.Model{{ID: "seed-model", Name: "Seed Model"}}, + } + cached := seed + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + require.NoError(t, newCache[catwalk.Provider](cachePathFor("venice")).Store(cached)) + + cfg := &Config{Options: &Options{}, Providers: csync.NewMap[string, ProviderConfig]()} + result, errs := overlayLiveProviderModels(t.Context(), cfg, []catwalk.Provider{seed}, true) + require.Empty(t, errs) + require.Equal(t, []catwalk.Provider{seed}, result) +} + +func TestUpdateVenice_LiveSourceStoresCache(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + require.Equal(t, "Bearer venice-token", r.Header.Get("Authorization")) + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "venice-live", + "type": "text", + "model_spec": { + "name": "Venice Live", + "availableContextTokens": 4096, + "maxCompletionTokens": 1024, + "pricing": {"input": {"usd": "0.1"}, "output": {"usd": "0.2"}}, + "capabilities": {"supportsVision": true} + } + } + ] + }`)) + })) + defer server.Close() + + seed := catwalk.Provider{ + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "venice-live", + DefaultSmallModelID: "venice-live", + } + require.NoError(t, newCache[[]catwalk.Provider](cachePathFor("providers")).Store([]catwalk.Provider{seed})) + cfg := &Config{ + Options: &Options{}, + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + string(catwalk.InferenceProviderVenice): {APIKey: "venice-token"}, + }), + } + + require.NoError(t, UpdateVenice(server.URL, cfg)) + + cached, _, err := newCache[catwalk.Provider](cachePathFor("venice")).Get() + require.NoError(t, err) + require.Equal(t, "Venice", cached.Name) + require.Equal(t, "venice-live", cached.DefaultLargeModelID) + require.Len(t, cached.Models, 1) + require.Equal(t, "venice-live", cached.Models[0].ID) +} + +func TestUpdateVenice_LiveSourceRequiresCredentials(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + seed := catwalk.Provider{ID: catwalk.InferenceProviderVenice, Name: "Venice"} + require.NoError(t, newCache[[]catwalk.Provider](cachePathFor("providers")).Store([]catwalk.Provider{seed})) + cfg := &Config{Options: &Options{}, Providers: csync.NewMap[string, ProviderConfig]()} + + err := UpdateVenice("", cfg) + require.Error(t, err) + require.True(t, IsMissingLiveProviderCredentials(err)) +} + +func TestUpdateCopilot_FileSourceStoresCache(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + provider := catwalk.Provider{ + Name: "Copilot", + ID: catwalk.InferenceProviderCopilot, + Models: []catwalk.Model{ + {ID: "copilot-file", Name: "Copilot File"}, + }, + } + data, err := json.Marshal(provider) + require.NoError(t, err) + path := filepath.Join(tmpDir, "copilot.json") + require.NoError(t, os.WriteFile(path, data, 0o644)) + + require.NoError(t, UpdateCopilot(path, nil)) + + cached, _, err := newCache[catwalk.Provider](cachePathFor("copilot")).Get() + require.NoError(t, err) + require.Equal(t, provider, cached) +} + func TestCache_StoreAndGet(t *testing.T) { t.Parallel() From f62f9f3ccd7191d9accb358c0606266a2a1d1447 Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 9 Jun 2026 20:19:40 -0700 Subject: [PATCH 05/16] docs(config): document live provider refresh --- AGENTS.md | 5 +++++ README.md | 1 + internal/config/provider.go | 2 +- internal/server/recover_test.go | 8 ++++---- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f25de08542..e3fe3d3cdd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -64,6 +64,11 @@ internal/ ### Key Patterns - **Config is a Service**: accessed via `config.Service`, not global state. +- **Provider updates**: Catwalk is the default provider seed/cache. Venice and + Copilot models are live-overlaid only when authenticated and use the existing + provider cache/syncer with a 60s warm-cache TTL; `crush update-providers` + refreshes authenticated live providers best-effort, while `--source=venice` + and `--source=copilot` refresh them directly. - **Tools are self-documenting**: each tool has a `.go` implementation and a `.md` description file in `internal/agent/tools/`. - **System prompts are Go templates**: `internal/agent/templates/*.md.tpl` diff --git a/README.md b/README.md index f3ddaf7afa..2ab556614d 100644 --- a/README.md +++ b/README.md @@ -859,6 +859,7 @@ command: ```bash # Update providers remotely from Catwalk. +# This also refreshes Venice and Copilot models when you're authenticated. crush update-providers # Update providers from a custom Catwalk base URL. diff --git a/internal/config/provider.go b/internal/config/provider.go index 111fdf9c69..3a79e108ec 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -258,7 +258,7 @@ func overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []cat syncer.Init(client, cachePathFor(cacheName), autoupdate, seed, credentialed) provider, err := syncer.Get(ctx) if err != nil { - errs = append(errs, fmt.Errorf("Crush was unable to cache updated models from %s: %w", seed.Name, err)) + errs = append(errs, fmt.Errorf("crush was unable to cache updated models from %s: %w", seed.Name, err)) return } providers[index] = provider diff --git a/internal/server/recover_test.go b/internal/server/recover_test.go index 2bbd61efd5..45b616677b 100644 --- a/internal/server/recover_test.go +++ b/internal/server/recover_test.go @@ -23,7 +23,7 @@ func TestRecoverHandler_PanicReturns500(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) h.ServeHTTP(rec, req) require.Equal(t, http.StatusInternalServerError, rec.Code) @@ -48,7 +48,7 @@ func TestRecoverHandler_NoPanicPassthrough(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) h.ServeHTTP(rec, req) require.Equal(t, http.StatusTeapot, rec.Code) @@ -71,7 +71,7 @@ func TestRecoverHandler_PanicAfterWriteHeader(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) require.NotPanics(t, func() { h.ServeHTTP(rec, req) }) require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, "partial", rec.Body.String()) @@ -89,6 +89,6 @@ func TestRecoverHandler_AbortHandlerPropagates(t *testing.T) { })) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/test", nil) + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil) require.PanicsWithValue(t, http.ErrAbortHandler, func() { h.ServeHTTP(rec, req) }) } From 4bd9d365d8c572797a2ee77f5d96e60083a74715 Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 9 Jun 2026 20:36:31 -0700 Subject: [PATCH 06/16] fix(config): preserve live model metadata --- internal/config/copilot_models.go | 7 ++++++- internal/config/copilot_models_test.go | 17 ++++++++++++++--- internal/config/venice_models.go | 14 ++++++++++---- internal/config/venice_models_test.go | 16 ++++++++++++++++ 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/internal/config/copilot_models.go b/internal/config/copilot_models.go index e6fef6e092..ea9ff37b18 100644 --- a/internal/config/copilot_models.go +++ b/internal/config/copilot_models.go @@ -127,7 +127,9 @@ type copilotModelLimits struct { } type copilotModelSupports struct { - Vision bool `json:"vision"` + Vision bool `json:"vision"` + ReasoningEffort []string `json:"reasoning_effort"` + AdaptiveThinking bool `json:"adaptive_thinking"` } var copilotVersionedModelRegexp = regexp.MustCompile(`-\d{4}-\d{2}-\d{2}$`) @@ -147,11 +149,14 @@ func copilotModelsToCatwalkModels(response copilotModelsResponse) []catwalk.Mode continue } seen[model.ID] = true + reasoningLevels := model.Capabilities.Supports.ReasoningEffort models = append(models, catwalk.Model{ ID: model.ID, Name: model.Name, ContextWindow: model.Capabilities.Limits.MaxContextWindowTokens, DefaultMaxTokens: model.Capabilities.Limits.MaxOutputTokens, + CanReason: model.Capabilities.Supports.AdaptiveThinking || len(reasoningLevels) > 0, + ReasoningLevels: reasoningLevels, SupportsImages: model.Capabilities.Supports.Vision, }) } diff --git a/internal/config/copilot_models_test.go b/internal/config/copilot_models_test.go index 12b7633eb4..11143f7c2f 100644 --- a/internal/config/copilot_models_test.go +++ b/internal/config/copilot_models_test.go @@ -38,7 +38,10 @@ func TestRealCopilotModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { "max_context_window_tokens": 264000, "max_output_tokens": 64000 }, - "supports": { "vision": true } + "supports": { + "vision": true, + "reasoning_effort": ["low", "medium", "high"] + } } } ] @@ -60,6 +63,8 @@ func TestRealCopilotModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { Name: "Claude Sonnet 4.6", ContextWindow: 264000, DefaultMaxTokens: 64000, + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, SupportsImages: true, }, }, provider.Models) @@ -139,8 +144,12 @@ func TestCopilotModelsToCatwalkModelsFiltersAndDeduplicates(t *testing.T) { Name: "GPT 5.1", Version: "gpt-5.1", Capabilities: copilotModelCapabilities{ - Limits: copilotModelLimits{MaxContextWindowTokens: 128000, MaxOutputTokens: 16000}, - Supports: copilotModelSupports{Vision: true}, + Limits: copilotModelLimits{MaxContextWindowTokens: 128000, MaxOutputTokens: 16000}, + Supports: copilotModelSupports{ + Vision: true, + ReasoningEffort: []string{"low", "medium", "high"}, + AdaptiveThinking: true, + }, }, }, { @@ -177,6 +186,8 @@ func TestCopilotModelsToCatwalkModelsFiltersAndDeduplicates(t *testing.T) { Name: "GPT 5.1", ContextWindow: 128000, DefaultMaxTokens: 16000, + CanReason: true, + ReasoningLevels: []string{"low", "medium", "high"}, SupportsImages: true, }, }, models) diff --git a/internal/config/venice_models.go b/internal/config/venice_models.go index 7832dc42b2..de99fe2b98 100644 --- a/internal/config/venice_models.go +++ b/internal/config/venice_models.go @@ -65,9 +65,10 @@ type veniceModelsResponse struct { } type veniceModel struct { - ID string `json:"id"` - ModelSpec veniceModelSpec `json:"model_spec"` - Type string `json:"type"` + ID string `json:"id"` + ContextLength int64 `json:"context_length"` + ModelSpec veniceModelSpec `json:"model_spec"` + Type string `json:"type"` } type veniceModelSpec struct { @@ -134,13 +135,18 @@ func veniceModelsToCatwalkModels(response veniceModelsResponse) []catwalk.Model continue } + contextWindow := model.ModelSpec.AvailableContextTokens + if contextWindow == 0 { + contextWindow = model.ContextLength + } + models = append(models, catwalk.Model{ ID: model.ID, Name: model.ModelSpec.Name, CostPer1MIn: float64(model.ModelSpec.Pricing.Input.USD), CostPer1MOut: float64(model.ModelSpec.Pricing.Output.USD), CostPer1MInCached: float64(model.ModelSpec.Pricing.CacheInput.USD), - ContextWindow: model.ModelSpec.AvailableContextTokens, + ContextWindow: contextWindow, DefaultMaxTokens: model.ModelSpec.MaxCompletionTokens, CanReason: model.ModelSpec.Capabilities.SupportsReasoning, ReasoningLevels: model.ModelSpec.Capabilities.ReasoningEffortOptions, diff --git a/internal/config/venice_models_test.go b/internal/config/venice_models_test.go index ad31e3cf7b..c2fca83f95 100644 --- a/internal/config/venice_models_test.go +++ b/internal/config/venice_models_test.go @@ -42,6 +42,16 @@ func TestRealVeniceModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { } } }, + { + "id": "venice-fallback", + "type": "text", + "context_length": 64000, + "model_spec": { + "name": "Venice Fallback", + "maxCompletionTokens": 2048, + "offline": false + } + }, { "id": "venice-image", "type": "image", @@ -79,6 +89,12 @@ func TestRealVeniceModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { DefaultReasoningEffort: "medium", SupportsImages: true, }, + { + ID: "venice-fallback", + Name: "Venice Fallback", + ContextWindow: 64000, + DefaultMaxTokens: 2048, + }, }, provider.Models) } From 9f6c85b8d87d46ad1b9a67855012d7cc02aab60a Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Tue, 9 Jun 2026 21:31:06 -0700 Subject: [PATCH 07/16] fix(ui): persist model selections to workspace scope --- internal/ui/model/ui.go | 4 +- internal/ui/model/ui_test.go | 91 +++++++++++++++++++++++++++++++++++- 2 files changed, 92 insertions(+), 3 deletions(-) diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 9dfe722796..436c105269 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1735,7 +1735,7 @@ func (m *UI) handleSelectModel(msg dialog.ActionSelectModel) tea.Cmd { return tea.Batch(cmds...) } - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, msg.ModelType, msg.Model); err != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, msg.ModelType, msg.Model); err != nil { cmds = append(cmds, util.ReportError(err)) } else { if msg.ModelType == config.SelectedModelTypeLarge { @@ -1746,7 +1746,7 @@ func (m *UI) handleSelectModel(msg dialog.ActionSelectModel) tea.Cmd { if _, ok := cfg.Models[config.SelectedModelTypeSmall]; !ok { // Ensure small model is set is unset. smallModel := m.com.Workspace.GetDefaultSmallModel(providerID) - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, config.SelectedModelTypeSmall, smallModel); err != nil { + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, config.SelectedModelTypeSmall, smallModel); err != nil { cmds = append(cmds, util.ReportError(err)) } } diff --git a/internal/ui/model/ui_test.go b/internal/ui/model/ui_test.go index 4032c80a05..378b422df5 100644 --- a/internal/ui/model/ui_test.go +++ b/internal/ui/model/ui_test.go @@ -1,12 +1,15 @@ package model import ( + "context" "testing" "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/ui/common" + "github.com/charmbracelet/crush/internal/ui/dialog" + "github.com/charmbracelet/crush/internal/ui/styles" "github.com/charmbracelet/crush/internal/workspace" "github.com/stretchr/testify/require" ) @@ -74,6 +77,47 @@ func TestCurrentModelSupportsImages(t *testing.T) { }) } +func TestHandleSelectModelPersistsWorkspaceScope(t *testing.T) { + t.Parallel() + + providers := csync.NewMap[string, config.ProviderConfig]() + providers.Set("test-provider", config.ProviderConfig{ + ID: "test-provider", + APIKey: "test-key", + Models: []catwalk.Model{{ID: "large-model", Name: "Large Model"}}, + }) + cfg := &config.Config{ + Models: map[config.SelectedModelType]config.SelectedModel{}, + Providers: providers, + Options: &config.Options{TUI: &config.TUIOptions{}}, + } + ws := &testWorkspace{ + cfg: cfg, + defaultSmallModel: config.SelectedModel{ + Provider: "test-provider", + Model: "small-model", + }, + } + ui := New(&common.Common{Workspace: ws, Styles: ptr(styles.CharmtonePantera())}, "", false) + + ui.handleSelectModel(dialog.ActionSelectModel{ + Provider: catwalk.Provider{ID: "test-provider"}, + Model: config.SelectedModel{ + Provider: "test-provider", + Model: "large-model", + }, + ModelType: config.SelectedModelTypeLarge, + }) + + require.Len(t, ws.preferredModelUpdates, 2) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[0].scope) + require.Equal(t, config.SelectedModelTypeLarge, ws.preferredModelUpdates[0].modelType) + require.Equal(t, "large-model", ws.preferredModelUpdates[0].model.Model) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[1].scope) + require.Equal(t, config.SelectedModelTypeSmall, ws.preferredModelUpdates[1].modelType) + require.Equal(t, "small-model", ws.preferredModelUpdates[1].model.Model) +} + func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI { t.Helper() @@ -87,9 +131,54 @@ func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI { // testWorkspace is a minimal [workspace.Workspace] stub for unit tests. type testWorkspace struct { workspace.Workspace - cfg *config.Config + cfg *config.Config + defaultSmallModel config.SelectedModel + preferredModelUpdates []preferredModelUpdate +} + +type preferredModelUpdate struct { + scope config.Scope + modelType config.SelectedModelType + model config.SelectedModel } func (w *testWorkspace) Config() *config.Config { return w.cfg } + +func (w *testWorkspace) PermissionSkipRequests() bool { + return false +} + +func (w *testWorkspace) ProjectNeedsInitialization() (bool, error) { + return false, nil +} + +func (w *testWorkspace) AgentIsReady() bool { + return false +} + +func (w *testWorkspace) AgentIsBusy() bool { + return false +} + +func (w *testWorkspace) UpdatePreferredModel(scope config.Scope, modelType config.SelectedModelType, model config.SelectedModel) error { + w.preferredModelUpdates = append(w.preferredModelUpdates, preferredModelUpdate{ + scope: scope, + modelType: modelType, + model: model, + }) + return nil +} + +func (w *testWorkspace) GetDefaultSmallModel(string) config.SelectedModel { + return w.defaultSmallModel +} + +func (w *testWorkspace) UpdateAgentModel(context.Context) error { + return nil +} + +func ptr[T any](v T) *T { + return &v +} From 00d4e3184371c853b0f9685554adf09f5d8c1e6d Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 08:38:38 -0700 Subject: [PATCH 08/16] fix(ui): persist thinking and reasoning-effort toggles to workspace scope --- internal/ui/model/ui.go | 101 +++++++++++++++++++---------------- internal/ui/model/ui_test.go | 54 +++++++++++++++++++ 2 files changed, 110 insertions(+), 45 deletions(-) diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 436c105269..9ed1e62c09 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1471,29 +1471,7 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { } m.dialog.CloseDialog(dialog.CommandsID) case dialog.ActionToggleThinking: - cmds = append(cmds, func() tea.Msg { - cfg := m.com.Config() - if cfg == nil { - return util.ReportError(errors.New("configuration not found"))() - } - - agentCfg, ok := cfg.Agents[config.AgentCoder] - if !ok { - return util.ReportError(errors.New("agent configuration not found"))() - } - - currentModel := cfg.Models[agentCfg.Model] - currentModel.Think = !currentModel.Think - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { - return util.ReportError(err)() - } - m.com.Workspace.UpdateAgentModel(context.TODO()) - status := "disabled" - if currentModel.Think { - status = "enabled" - } - return util.NewInfoMsg("Thinking mode " + status) - }) + cmds = append(cmds, m.toggleThinking) m.dialog.CloseDialog(dialog.CommandsID) case dialog.ActionToggleTransparentBackground: cmds = append(cmds, func() tea.Msg { @@ -1542,29 +1520,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd { break } - cfg := m.com.Config() - if cfg == nil { - cmds = append(cmds, util.ReportError(errors.New("configuration not found"))) - break - } - - agentCfg, ok := cfg.Agents[config.AgentCoder] - if !ok { - cmds = append(cmds, util.ReportError(errors.New("agent configuration not found"))) - break - } - - currentModel := cfg.Models[agentCfg.Model] - currentModel.ReasoningEffort = msg.Effort - if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, agentCfg.Model, currentModel); err != nil { - cmds = append(cmds, util.ReportError(err)) - break + if cmd := m.selectReasoningEffort(msg.Effort); cmd != nil { + cmds = append(cmds, cmd) } - - cmds = append(cmds, func() tea.Msg { - m.com.Workspace.UpdateAgentModel(context.TODO()) - return util.NewInfoMsg("Reasoning effort set to " + msg.Effort) - }) m.dialog.CloseDialog(dialog.ReasoningID) case dialog.ActionPermissionResponse: m.dialog.CloseDialog(dialog.PermissionsID) @@ -1690,6 +1648,59 @@ func (m *UI) fetchHyperCredits() tea.Cmd { } } +// toggleThinking flips the thinking mode of the coder agent's current +// model and persists it at workspace scope so it isn't shadowed by +// workspace-scoped model selections. +func (m *UI) toggleThinking() tea.Msg { + cfg := m.com.Config() + if cfg == nil { + return util.ReportError(errors.New("configuration not found"))() + } + + agentCfg, ok := cfg.Agents[config.AgentCoder] + if !ok { + return util.ReportError(errors.New("agent configuration not found"))() + } + + currentModel := cfg.Models[agentCfg.Model] + currentModel.Think = !currentModel.Think + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, agentCfg.Model, currentModel); err != nil { + return util.ReportError(err)() + } + m.com.Workspace.UpdateAgentModel(context.TODO()) + status := "disabled" + if currentModel.Think { + status = "enabled" + } + return util.NewInfoMsg("Thinking mode " + status) +} + +// selectReasoningEffort sets the reasoning effort of the coder agent's +// current model and persists it at workspace scope so it isn't shadowed +// by workspace-scoped model selections. +func (m *UI) selectReasoningEffort(effort string) tea.Cmd { + cfg := m.com.Config() + if cfg == nil { + return util.ReportError(errors.New("configuration not found")) + } + + agentCfg, ok := cfg.Agents[config.AgentCoder] + if !ok { + return util.ReportError(errors.New("agent configuration not found")) + } + + currentModel := cfg.Models[agentCfg.Model] + currentModel.ReasoningEffort = effort + if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, agentCfg.Model, currentModel); err != nil { + return util.ReportError(err) + } + + return func() tea.Msg { + m.com.Workspace.UpdateAgentModel(context.TODO()) + return util.NewInfoMsg("Reasoning effort set to " + effort) + } +} + // handleSelectModel performs the model selection after any provider // pre-checks (such as a silent Hyper OAuth refresh) have completed. func (m *UI) handleSelectModel(msg dialog.ActionSelectModel) tea.Cmd { diff --git a/internal/ui/model/ui_test.go b/internal/ui/model/ui_test.go index 378b422df5..a0bc4d19db 100644 --- a/internal/ui/model/ui_test.go +++ b/internal/ui/model/ui_test.go @@ -118,6 +118,60 @@ func TestHandleSelectModelPersistsWorkspaceScope(t *testing.T) { require.Equal(t, "small-model", ws.preferredModelUpdates[1].model.Model) } +func TestToggleThinkingPersistsWorkspaceScope(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Models: map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Provider: "test-provider", + Model: "large-model", + }, + }, + Providers: csync.NewMap[string, config.ProviderConfig](), + Agents: map[string]config.Agent{ + config.AgentCoder: {Model: config.SelectedModelTypeLarge}, + }, + } + ws := &testWorkspace{cfg: cfg} + ui := newTestUIWithConfig(t, cfg) + ui.com.Workspace = ws + + ui.toggleThinking() + + require.Len(t, ws.preferredModelUpdates, 1) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[0].scope) + require.Equal(t, config.SelectedModelTypeLarge, ws.preferredModelUpdates[0].modelType) + require.True(t, ws.preferredModelUpdates[0].model.Think) +} + +func TestSelectReasoningEffortPersistsWorkspaceScope(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Models: map[config.SelectedModelType]config.SelectedModel{ + config.SelectedModelTypeLarge: { + Provider: "test-provider", + Model: "large-model", + }, + }, + Providers: csync.NewMap[string, config.ProviderConfig](), + Agents: map[string]config.Agent{ + config.AgentCoder: {Model: config.SelectedModelTypeLarge}, + }, + } + ws := &testWorkspace{cfg: cfg} + ui := newTestUIWithConfig(t, cfg) + ui.com.Workspace = ws + + ui.selectReasoningEffort("high") + + require.Len(t, ws.preferredModelUpdates, 1) + require.Equal(t, config.ScopeWorkspace, ws.preferredModelUpdates[0].scope) + require.Equal(t, config.SelectedModelTypeLarge, ws.preferredModelUpdates[0].modelType) + require.Equal(t, "high", ws.preferredModelUpdates[0].model.ReasoningEffort) +} + func newTestUIWithConfig(t *testing.T, cfg *config.Config) *UI { t.Helper() From 009b0f7d6121db36822c5e81898f5e28d18ad455 Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 08:40:47 -0700 Subject: [PATCH 09/16] fix(config): filter Copilot models by capability type instead of dated-ID regex --- internal/config/copilot_models.go | 8 ++++---- internal/config/copilot_models_test.go | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/internal/config/copilot_models.go b/internal/config/copilot_models.go index ea9ff37b18..2425b3fc1f 100644 --- a/internal/config/copilot_models.go +++ b/internal/config/copilot_models.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "net/http" - "regexp" "slices" "strings" "time" @@ -117,6 +116,7 @@ type copilotModel struct { } type copilotModelCapabilities struct { + Type string `json:"type"` Limits copilotModelLimits `json:"limits"` Supports copilotModelSupports `json:"supports"` } @@ -132,8 +132,6 @@ type copilotModelSupports struct { AdaptiveThinking bool `json:"adaptive_thinking"` } -var copilotVersionedModelRegexp = regexp.MustCompile(`-\d{4}-\d{2}-\d{2}$`) - func copilotModelsToCatwalkModels(response copilotModelsResponse) []catwalk.Model { aliasedVersions := make(map[string]bool, len(response.Data)) for _, model := range response.Data { @@ -168,9 +166,11 @@ func copilotModelsToCatwalkModels(response copilotModelsResponse) []catwalk.Mode } func shouldSkipCopilotModel(model copilotModel, aliasedVersions map[string]bool) bool { + if capType := model.Capabilities.Type; capType != "" && capType != "chat" { + return true + } return model.ID == "" || aliasedVersions[model.ID] || - copilotVersionedModelRegexp.MatchString(model.ID) || strings.Contains(model.ID, "embedding") || strings.HasPrefix(model.ID, "accounts/msft/routers") || strings.HasPrefix(model.ID, "oswe-vscode") || diff --git a/internal/config/copilot_models_test.go b/internal/config/copilot_models_test.go index 11143f7c2f..617ecddfc4 100644 --- a/internal/config/copilot_models_test.go +++ b/internal/config/copilot_models_test.go @@ -167,6 +167,21 @@ func TestCopilotModelsToCatwalkModelsFiltersAndDeduplicates(t *testing.T) { Name: "Versioned Model", Version: "versioned-model-2026-01-01", }, + { + ID: "unaliased-model-2026-01-01", + Name: "Unaliased Dated Model", + Version: "unaliased-model-2026-01-01", + Capabilities: copilotModelCapabilities{ + Type: "chat", + }, + }, + { + ID: "some-embeddings-model", + Name: "Embeddings via Type", + Capabilities: copilotModelCapabilities{ + Type: "embeddings", + }, + }, {ID: "text-embedding-3", Name: "Embedding"}, {ID: "accounts/msft/routers/test", Name: "Router"}, {ID: "oswe-vscode-test", Name: "OSWE"}, @@ -190,5 +205,9 @@ func TestCopilotModelsToCatwalkModelsFiltersAndDeduplicates(t *testing.T) { ReasoningLevels: []string{"low", "medium", "high"}, SupportsImages: true, }, + { + ID: "unaliased-model-2026-01-01", + Name: "Unaliased Dated Model", + }, }, models) } From b77cfdc51acd14104466e5a8e38e85fe2c0692ae Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 08:44:24 -0700 Subject: [PATCH 10/16] fix(config): make live provider cache write failures non-fatal --- internal/config/live_provider.go | 7 ++++--- internal/config/live_provider_test.go | 23 +++++++++++++++++++++++ internal/config/provider.go | 18 ++++++++---------- internal/config/provider_test.go | 6 ++---- 4 files changed, 37 insertions(+), 17 deletions(-) diff --git a/internal/config/live_provider.go b/internal/config/live_provider.go index ce9481380e..0418009630 100644 --- a/internal/config/live_provider.go +++ b/internal/config/live_provider.go @@ -53,7 +53,6 @@ func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { panic("called Get before Init") } - var throwErr error s.once.Do(func() { if !s.autoupdate { slog.Info("Using provider seed", "provider", s.seed.ID) @@ -106,9 +105,11 @@ func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { merged := mergeLiveProvider(s.seed, result) s.result = merged - throwErr = s.cache.Store(merged) + if err := s.cache.Store(merged); err != nil { + slog.Warn("Failed to store live provider cache", "provider", s.seed.ID, "err", err) + } }) - return s.result, throwErr + return s.result, nil } func cacheAge(path string) (time.Duration, bool) { diff --git a/internal/config/live_provider_test.go b/internal/config/live_provider_test.go index a47a9b3685..fccad01ccc 100644 --- a/internal/config/live_provider_test.go +++ b/internal/config/live_provider_test.go @@ -129,6 +129,29 @@ func TestLiveProviderSync_GetStaleCacheFetchesMergesAndStores(t *testing.T) { require.Equal(t, provider, stored) } +func TestLiveProviderSync_GetStoreFailureStillReturnsMerged(t *testing.T) { + t.Parallel() + + // Point the cache at a path whose parent is a regular file so that + // Store fails; the merged provider must still be returned without + // error. + blocker := t.TempDir() + "/blocker" + require.NoError(t, os.WriteFile(blocker, []byte("not a dir"), 0o644)) + path := blocker + "/provider.json" + + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, seed, true) + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, 1, client.callCount) + require.Equal(t, []catwalk.Model{{ID: "live-model", Name: "Live Model"}}, provider.Models) +} + func TestLiveProviderSync_GetNotModifiedUsesCached(t *testing.T) { t.Parallel() diff --git a/internal/config/provider.go b/internal/config/provider.go index 3a79e108ec..3c3aa61237 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -216,9 +216,7 @@ func Providers(cfg *Config) ([]catwalk.Provider, error) { items := slices.Collect(providers.Seq()) if !customProvidersOnly { - var liveErrs []error - items, liveErrs = overlayLiveProviderModels(ctx, cfg, items, autoupdate) - errs = append(errs, liveErrs...) + items = overlayLiveProviderModels(ctx, cfg, items, autoupdate) } if hyperFound { @@ -231,14 +229,13 @@ func Providers(cfg *Config) ([]catwalk.Provider, error) { return providerList, providerErr } -func overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []catwalk.Provider, autoupdate bool) ([]catwalk.Provider, []error) { +func overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []catwalk.Provider, autoupdate bool) []catwalk.Provider { if cfg == nil || len(providers) == 0 { - return providers, nil + return providers } environment := env.New() resolver := NewShellVariableResolver(environment) - errs := make([]error, 0, 2) syncProvider := func(providerID catwalk.InferenceProvider, cacheName string, syncer *liveProviderSync, newClient liveProviderClientFunc) { index := slices.IndexFunc(providers, func(provider catwalk.Provider) bool { @@ -258,16 +255,17 @@ func overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []cat syncer.Init(client, cachePathFor(cacheName), autoupdate, seed, credentialed) provider, err := syncer.Get(ctx) if err != nil { - errs = append(errs, fmt.Errorf("crush was unable to cache updated models from %s: %w", seed.Name, err)) - return + slog.Warn("Live provider sync failed", "provider", providerID, "error", err) + } + if len(provider.Models) > 0 { + providers[index] = provider } - providers[index] = provider } syncProvider(catwalk.InferenceProviderVenice, "venice", veniceSyncer, newVeniceLiveProviderClient) syncProvider(catwalk.InferenceProviderCopilot, "copilot", copilotSyncer, newCopilotLiveProviderClient) - return providers, errs + return providers } type liveProviderClientFunc func(catwalk.Provider, *Config, VariableResolver, string) (liveProviderClient, bool, error) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 2d28aa13ea..43d9f06993 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -296,8 +296,7 @@ func TestProviders_Integration_LiveOverlayFetchesWithCredentials(t *testing.T) { }), } - result, errs := overlayLiveProviderModels(t.Context(), cfg, providers, true) - require.Empty(t, errs) + result := overlayLiveProviderModels(t.Context(), cfg, providers, true) venice, ok := findProvider(result, catwalk.InferenceProviderVenice) require.True(t, ok) @@ -344,8 +343,7 @@ func TestProviders_Integration_LiveOverlaySkipsWithoutCredentials(t *testing.T) require.NoError(t, newCache[catwalk.Provider](cachePathFor("venice")).Store(cached)) cfg := &Config{Options: &Options{}, Providers: csync.NewMap[string, ProviderConfig]()} - result, errs := overlayLiveProviderModels(t.Context(), cfg, []catwalk.Provider{seed}, true) - require.Empty(t, errs) + result := overlayLiveProviderModels(t.Context(), cfg, []catwalk.Provider{seed}, true) require.Equal(t, []catwalk.Provider{seed}, result) } From f0c80bf02c3f55196634f7db56d97f91d16d3ec7 Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 08:45:40 -0700 Subject: [PATCH 11/16] refactor(config): remove unused liveProviderSync.Force API --- internal/config/live_provider.go | 7 +------ internal/config/live_provider_test.go | 21 --------------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/internal/config/live_provider.go b/internal/config/live_provider.go index 0418009630..e9ab127018 100644 --- a/internal/config/live_provider.go +++ b/internal/config/live_provider.go @@ -29,7 +29,6 @@ type liveProviderSync struct { seed catwalk.Provider autoupdate bool credentialed bool - force bool ttl time.Duration init atomic.Bool } @@ -44,10 +43,6 @@ func (s *liveProviderSync) Init(client liveProviderClient, path string, autoupda s.init.Store(true) } -func (s *liveProviderSync) Force() { - s.force = true -} - func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { if !s.init.Load() { panic("called Get before Init") @@ -72,7 +67,7 @@ func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { fallback = cached } - if cachedAvailable && !s.force { + if cachedAvailable { if age, ok := cacheAge(s.cache.path); ok && age < s.ttl { slog.Info("Using cached live provider models", "provider", fallback.ID, "age", age) s.result = cached diff --git a/internal/config/live_provider_test.go b/internal/config/live_provider_test.go index fccad01ccc..9d699438fe 100644 --- a/internal/config/live_provider_test.go +++ b/internal/config/live_provider_test.go @@ -246,27 +246,6 @@ func TestLiveProviderSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { require.Equal(t, 1, client.callCount) } -func TestLiveProviderSync_ForceBypassesWarmCache(t *testing.T) { - t.Parallel() - - path := t.TempDir() + "/provider.json" - cached := testLiveSeedProvider() - cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} - writeLiveProviderCache(t, path, cached) - - client := &mockLiveProviderClient{ - provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, - } - syncer := &liveProviderSync{} - syncer.Init(client, path, true, testLiveSeedProvider(), true) - syncer.Force() - - provider, err := syncer.Get(t.Context()) - require.NoError(t, err) - require.Equal(t, 1, client.callCount) - require.Equal(t, []catwalk.Model{{ID: "live-model", Name: "Live Model"}}, provider.Models) -} - func TestCacheAge(t *testing.T) { t.Parallel() From 4845b6af5ef647f87b4b5d27d51563a8f4f51978 Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 08:49:37 -0700 Subject: [PATCH 12/16] refactor(config): drop unreachable etag/304 plumbing for Venice and Copilot clients --- internal/config/copilot_models.go | 7 +------ internal/config/copilot_models_test.go | 22 ++-------------------- internal/config/live_provider.go | 11 +++-------- internal/config/live_provider_test.go | 9 +++------ internal/config/provider.go | 2 +- internal/config/venice_models.go | 7 +------ internal/config/venice_models_test.go | 20 +------------------- 7 files changed, 12 insertions(+), 66 deletions(-) diff --git a/internal/config/copilot_models.go b/internal/config/copilot_models.go index 2425b3fc1f..5416d4889c 100644 --- a/internal/config/copilot_models.go +++ b/internal/config/copilot_models.go @@ -12,7 +12,6 @@ import ( "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/oauth/copilot" - xetag "github.com/charmbracelet/x/etag" ) var _ liveProviderClient = (*realCopilotModelsClient)(nil) @@ -26,7 +25,7 @@ type realCopilotModelsClient struct { refreshToken copilotTokenRefresher } -func (r *realCopilotModelsClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { +func (r *realCopilotModelsClient) Get(ctx context.Context) (catwalk.Provider, error) { var result catwalk.Provider baseURL := strings.TrimRight(r.baseURL, "/") accessToken, err := r.accessToken(ctx) @@ -38,7 +37,6 @@ func (r *realCopilotModelsClient) Get(ctx context.Context, etag string) (catwalk if err != nil { return result, fmt.Errorf("could not create request: %w", err) } - xetag.Request(req, etag) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) for key, value := range copilot.Headers() { @@ -52,9 +50,6 @@ func (r *realCopilotModelsClient) Get(ctx context.Context, etag string) (catwalk } defer resp.Body.Close() //nolint:errcheck - if resp.StatusCode == http.StatusNotModified { - return result, catwalk.ErrNotModified - } if resp.StatusCode != http.StatusOK { return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } diff --git a/internal/config/copilot_models_test.go b/internal/config/copilot_models_test.go index 617ecddfc4..2837290302 100644 --- a/internal/config/copilot_models_test.go +++ b/internal/config/copilot_models_test.go @@ -2,7 +2,6 @@ package config import ( "context" - "errors" "net/http" "net/http/httptest" "testing" @@ -20,7 +19,6 @@ func TestRealCopilotModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "/models", r.URL.Path) require.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) - require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) require.Equal(t, "application/json", r.Header.Get("Accept")) for key, value := range copilot.Headers() { require.Equal(t, value, r.Header.Get(key)) @@ -51,7 +49,7 @@ func TestRealCopilotModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { defer server.Close() client := &realCopilotModelsClient{baseURL: server.URL + "/", apiKey: " test-token "} - provider, err := client.Get(t.Context(), "cached-etag") + provider, err := client.Get(t.Context()) require.NoError(t, err) require.Equal(t, catwalk.InferenceProviderCopilot, provider.ID) @@ -112,29 +110,13 @@ func TestRealCopilotModelsClientGetRefreshesExpiredOAuthToken(t *testing.T) { }, } - provider, err := client.Get(t.Context(), "") + provider, err := client.Get(t.Context()) require.NoError(t, err) require.Equal(t, []catwalk.Model{{ID: "gpt-5.1", Name: "GPT 5.1", ContextWindow: 128000, DefaultMaxTokens: 16000}}, provider.Models) require.Equal(t, "refreshed-token", client.apiKey) } -func TestRealCopilotModelsClientGetNotModified(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) - w.WriteHeader(http.StatusNotModified) - })) - defer server.Close() - - client := &realCopilotModelsClient{baseURL: server.URL, apiKey: "test-token"} - provider, err := client.Get(t.Context(), "cached-etag") - - require.True(t, errors.Is(err, catwalk.ErrNotModified)) - require.Empty(t, provider.Models) -} - func TestCopilotModelsToCatwalkModelsFiltersAndDeduplicates(t *testing.T) { t.Parallel() diff --git a/internal/config/live_provider.go b/internal/config/live_provider.go index e9ab127018..897a428025 100644 --- a/internal/config/live_provider.go +++ b/internal/config/live_provider.go @@ -16,7 +16,7 @@ import ( const liveModelsTTL = time.Minute type liveProviderClient interface { - Get(context.Context, string) (catwalk.Provider, error) + Get(context.Context) (catwalk.Provider, error) } var _ syncer[catwalk.Provider] = (*liveProviderSync)(nil) @@ -60,7 +60,7 @@ func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { return } - cached, etag, cachedErr := s.cache.Get() + cached, _, cachedErr := s.cache.Get() cachedAvailable := cachedErr == nil && len(cached.Models) > 0 fallback := s.seed if cachedAvailable { @@ -76,17 +76,12 @@ func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { } slog.Info("Fetching live provider models", "provider", s.seed.ID) - result, err := s.client.Get(ctx, etag) + result, err := s.client.Get(ctx) if errors.Is(err, context.DeadlineExceeded) { slog.Warn("Live provider models not updated in time", "provider", s.seed.ID) s.result = fallback return } - if errors.Is(err, catwalk.ErrNotModified) { - slog.Info("Live provider models not modified", "provider", s.seed.ID) - s.result = fallback - return - } if err != nil { slog.Warn("Live provider models not updated", "provider", s.seed.ID, "err", err) s.result = fallback diff --git a/internal/config/live_provider_test.go b/internal/config/live_provider_test.go index 9d699438fe..2d8488a396 100644 --- a/internal/config/live_provider_test.go +++ b/internal/config/live_provider_test.go @@ -16,12 +16,10 @@ type mockLiveProviderClient struct { provider catwalk.Provider err error callCount int - etags []string } -func (m *mockLiveProviderClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { +func (m *mockLiveProviderClient) Get(ctx context.Context) (catwalk.Provider, error) { m.callCount++ - m.etags = append(m.etags, etag) return m.provider, m.err } @@ -112,7 +110,6 @@ func TestLiveProviderSync_GetStaleCacheFetchesMergesAndStores(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, 1, client.callCount) - require.NotEmpty(t, client.etags[0]) require.Equal(t, seed.ID, provider.ID) require.Equal(t, seed.Name, provider.Name) require.Equal(t, seed.APIEndpoint, provider.APIEndpoint) @@ -152,7 +149,7 @@ func TestLiveProviderSync_GetStoreFailureStillReturnsMerged(t *testing.T) { require.Equal(t, []catwalk.Model{{ID: "live-model", Name: "Live Model"}}, provider.Models) } -func TestLiveProviderSync_GetNotModifiedUsesCached(t *testing.T) { +func TestLiveProviderSync_GetFetchErrorUsesCached(t *testing.T) { t.Parallel() path := t.TempDir() + "/provider.json" @@ -163,7 +160,7 @@ func TestLiveProviderSync_GetNotModifiedUsesCached(t *testing.T) { staleTime := time.Now().Add(-2 * liveModelsTTL) require.NoError(t, os.Chtimes(path, staleTime, staleTime)) - client := &mockLiveProviderClient{err: catwalk.ErrNotModified} + client := &mockLiveProviderClient{err: errors.New("network error")} syncer := &liveProviderSync{} syncer.Init(client, path, true, testLiveSeedProvider(), true) diff --git a/internal/config/provider.go b/internal/config/provider.go index 3c3aa61237..c07a0726ae 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -379,7 +379,7 @@ func updateLiveProvider(name, cacheName string, providerID catwalk.InferenceProv ctx, cancel := context.WithTimeout(context.Background(), 45*time.Second) defer cancel() - provider, err = client.Get(ctx, "") + provider, err = client.Get(ctx) if err != nil { return fmt.Errorf("failed to fetch provider from %s: %w", name, err) } diff --git a/internal/config/venice_models.go b/internal/config/venice_models.go index de99fe2b98..9c76cf7fc8 100644 --- a/internal/config/venice_models.go +++ b/internal/config/venice_models.go @@ -10,7 +10,6 @@ import ( "time" "charm.land/catwalk/pkg/catwalk" - xetag "github.com/charmbracelet/x/etag" ) var _ liveProviderClient = realVeniceModelsClient{} @@ -20,14 +19,13 @@ type realVeniceModelsClient struct { apiKey string } -func (r realVeniceModelsClient) Get(ctx context.Context, etag string) (catwalk.Provider, error) { +func (r realVeniceModelsClient) Get(ctx context.Context) (catwalk.Provider, error) { var result catwalk.Provider baseURL := strings.TrimRight(r.baseURL, "/") req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil) if err != nil { return result, fmt.Errorf("could not create request: %w", err) } - xetag.Request(req, etag) if apiKey := strings.TrimSpace(r.apiKey); apiKey != "" { req.Header.Set("Authorization", "Bearer "+apiKey) } @@ -39,9 +37,6 @@ func (r realVeniceModelsClient) Get(ctx context.Context, etag string) (catwalk.P } defer resp.Body.Close() //nolint:errcheck - if resp.StatusCode == http.StatusNotModified { - return result, catwalk.ErrNotModified - } if resp.StatusCode != http.StatusOK { return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode) } diff --git a/internal/config/venice_models_test.go b/internal/config/venice_models_test.go index c2fca83f95..9129f66341 100644 --- a/internal/config/venice_models_test.go +++ b/internal/config/venice_models_test.go @@ -1,7 +1,6 @@ package config import ( - "errors" "net/http" "net/http/httptest" "testing" @@ -16,7 +15,6 @@ func TestRealVeniceModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "/models", r.URL.Path) require.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) - require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) w.Header().Set("Content-Type", "application/json") _, err := w.Write([]byte(`{ @@ -69,7 +67,7 @@ func TestRealVeniceModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { defer server.Close() client := realVeniceModelsClient{baseURL: server.URL + "/", apiKey: " test-key "} - provider, err := client.Get(t.Context(), "cached-etag") + provider, err := client.Get(t.Context()) require.NoError(t, err) require.Equal(t, catwalk.InferenceProviderVenice, provider.ID) @@ -98,22 +96,6 @@ func TestRealVeniceModelsClientGetMapsModelsAndSendsHeaders(t *testing.T) { }, provider.Models) } -func TestRealVeniceModelsClientGetNotModified(t *testing.T) { - t.Parallel() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, `"cached-etag"`, r.Header.Get("If-None-Match")) - w.WriteHeader(http.StatusNotModified) - })) - defer server.Close() - - client := realVeniceModelsClient{baseURL: server.URL} - provider, err := client.Get(t.Context(), "cached-etag") - - require.True(t, errors.Is(err, catwalk.ErrNotModified)) - require.Empty(t, provider.Models) -} - func TestVeniceModelsToCatwalkModelsFiltersNonTextAndOfflineModels(t *testing.T) { t.Parallel() From 79e53158c558a2bdc73f5145d25c5f7a309cb8da Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 08:52:13 -0700 Subject: [PATCH 13/16] fix(cmd): avoid double-fetching live providers and surface skipped updates in update-providers --- internal/cmd/update_providers.go | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/internal/cmd/update_providers.go b/internal/cmd/update_providers.go index 3839a5e346..ae9c017c45 100644 --- a/internal/cmd/update_providers.go +++ b/internal/cmd/update_providers.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" "log/slog" + "os" "charm.land/lipgloss/v2" "github.com/charmbracelet/crush/internal/config" @@ -99,14 +100,14 @@ crush update-providers --source=copilot func updateAuthenticatedLiveProviders(cmd *cobra.Command) { cfg, err := loadUpdateProvidersConfig(cmd) if err != nil { - slog.Debug("Skipping live provider updates", "error", err) + fmt.Fprintf(os.Stderr, "Note: skipping Venice and Copilot updates: %v\n", err) return } if err := config.UpdateVenice("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) { - slog.Debug("Skipping Venice provider update", "error", err) + fmt.Fprintf(os.Stderr, "Note: skipping Venice update: %v\n", err) } if err := config.UpdateCopilot("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) { - slog.Debug("Skipping Copilot provider update", "error", err) + fmt.Fprintf(os.Stderr, "Note: skipping Copilot update: %v\n", err) } } @@ -117,6 +118,21 @@ func loadUpdateProvidersConfig(cmd *cobra.Command) (*config.Config, error) { } dataDir, _ := cmd.Flags().GetString("data-dir") debug, _ := cmd.Flags().GetBool("debug") + + // The config is loaded only to resolve provider credentials; + // updateLiveProvider fetches explicitly afterwards. Disable provider + // auto-update during Load so it doesn't fetch the same endpoints + // first (avoiding a double fetch per provider). + previous, hadPrevious := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE") + _ = os.Setenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE", "1") + defer func() { + if hadPrevious { + _ = os.Setenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE", previous) + } else { + _ = os.Unsetenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE") + } + }() + store, err := config.Load(cwd, dataDir, debug) if err != nil { return nil, err From bdb62622702b8fb4288e45d07629b5ddf453273e Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 08:53:58 -0700 Subject: [PATCH 14/16] refactor(config): resolve VENICE_API_KEY fallback through the variable resolver --- internal/config/provider.go | 6 +++++- internal/config/provider_test.go | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/internal/config/provider.go b/internal/config/provider.go index c07a0726ae..09e371fd29 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -298,7 +298,11 @@ func newVeniceLiveProviderClient(seed catwalk.Provider, cfg *Config, resolver Va apiKey = strings.TrimSpace(resolved) } if apiKey == "" { - apiKey = strings.TrimSpace(os.Getenv("VENICE_API_KEY")) + resolved, err := resolver.ResolveValue("$VENICE_API_KEY") + if err != nil { + return nil, false, fmt.Errorf("failed to resolve Venice API key from environment: %w", err) + } + apiKey = strings.TrimSpace(resolved) } return realVeniceModelsClient{baseURL: baseURL, apiKey: apiKey}, apiKey != "", nil diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 43d9f06993..461b90b2e3 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -11,6 +11,7 @@ import ( "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/env" "github.com/stretchr/testify/require" ) @@ -347,6 +348,26 @@ func TestProviders_Integration_LiveOverlaySkipsWithoutCredentials(t *testing.T) require.Equal(t, []catwalk.Provider{seed}, result) } +func TestNewVeniceLiveProviderClient_EnvFallbackUsesResolver(t *testing.T) { + t.Parallel() + + seed := catwalk.Provider{ + ID: catwalk.InferenceProviderVenice, + APIEndpoint: "https://api.venice.ai/api/v1", + } + cfg := &Config{Options: &Options{}, Providers: csync.NewMap[string, ProviderConfig]()} + resolver := NewShellVariableResolver(env.NewFromMap(map[string]string{ + "VENICE_API_KEY": " env-venice-key ", + })) + + client, credentialed, err := newVeniceLiveProviderClient(seed, cfg, resolver, "") + require.NoError(t, err) + require.True(t, credentialed) + veniceClient, ok := client.(realVeniceModelsClient) + require.True(t, ok) + require.Equal(t, "env-venice-key", veniceClient.apiKey) +} + func TestUpdateVenice_LiveSourceStoresCache(t *testing.T) { resetProviderState() defer resetProviderState() From a956eb0b646f6cca23f5f64e5dcac556436a3c15 Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Wed, 10 Jun 2026 09:22:54 -0700 Subject: [PATCH 15/16] fix(ui): report UpdateAgentModel failures in thinking and reasoning-effort toggles Assisted-by: Claude Opus 4.6 (via Crush) --- internal/ui/model/ui.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 9ed1e62c09..ca30d20fbd 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -1667,7 +1667,9 @@ func (m *UI) toggleThinking() tea.Msg { if err := m.com.Workspace.UpdatePreferredModel(config.ScopeWorkspace, agentCfg.Model, currentModel); err != nil { return util.ReportError(err)() } - m.com.Workspace.UpdateAgentModel(context.TODO()) + if err := m.com.Workspace.UpdateAgentModel(context.TODO()); err != nil { + return util.ReportError(err)() + } status := "disabled" if currentModel.Think { status = "enabled" @@ -1696,7 +1698,9 @@ func (m *UI) selectReasoningEffort(effort string) tea.Cmd { } return func() tea.Msg { - m.com.Workspace.UpdateAgentModel(context.TODO()) + if err := m.com.Workspace.UpdateAgentModel(context.TODO()); err != nil { + return util.ReportError(err)() + } return util.NewInfoMsg("Reasoning effort set to " + effort) } } From b24ee6d985304529f1f2ad55d4a80e59de940b1e Mon Sep 17 00:00:00 2001 From: Greg Slepak Date: Sat, 13 Jun 2026 20:28:15 -0700 Subject: [PATCH 16/16] feat(config): refresh Copilot/Venice live models in the background Make live provider model refresh non-blocking on startup and propagate completed refreshes into the running session. The TUI now starts with the catwalk seed or warm cache immediately while stale/missing live caches refresh in the background. A syncer completion callback updates a lock-guarded provider list and publishes a ProvidersUpdatedEvent over a config pubsub broker, which the UI uses to repopulate an open model dialog without restarting Crush. --- internal/app/app.go | 1 + internal/config/live_provider.go | 50 ++++++-- internal/config/live_provider_test.go | 177 +++++++++++++++++++++++--- internal/config/provider.go | 100 ++++++++++++++- internal/config/provider_test.go | 133 ++++++++++++++++--- internal/ui/dialog/models.go | 13 ++ internal/ui/model/ui.go | 9 ++ 7 files changed, 429 insertions(+), 54 deletions(-) diff --git a/internal/app/app.go b/internal/app/app.go index d8a3abc63b..a5e6666cde 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -504,6 +504,7 @@ func (app *App) setupEvents() { setupSubscriberMustDeliver(ctx, app.serviceEventsWG, "run-completions", app.runCompletions.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events) setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events) + setupSubscriber(ctx, app.serviceEventsWG, "providers", config.SubscribeProviderEvents, app.events) if app.Skills != nil { setupSubscriber(ctx, app.serviceEventsWG, "skills", app.Skills.SubscribeEvents, app.events) } diff --git a/internal/config/live_provider.go b/internal/config/live_provider.go index 897a428025..e04562a742 100644 --- a/internal/config/live_provider.go +++ b/internal/config/live_provider.go @@ -13,7 +13,10 @@ import ( "charm.land/catwalk/pkg/catwalk" ) -const liveModelsTTL = time.Minute +const ( + liveModelsTTL = time.Minute + liveProviderFetchTimeout = 45 * time.Second +) type liveProviderClient interface { Get(context.Context) (catwalk.Provider, error) @@ -23,9 +26,11 @@ var _ syncer[catwalk.Provider] = (*liveProviderSync)(nil) type liveProviderSync struct { once sync.Once + resultMu sync.RWMutex result catwalk.Provider cache cache[catwalk.Provider] client liveProviderClient + onRefresh func(catwalk.Provider) seed catwalk.Provider autoupdate bool credentialed bool @@ -51,12 +56,12 @@ func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { s.once.Do(func() { if !s.autoupdate { slog.Info("Using provider seed", "provider", s.seed.ID) - s.result = s.seed + s.setResult(s.seed) return } if !s.credentialed { slog.Info("Skipping live provider sync without credentials", "provider", s.seed.ID) - s.result = s.seed + s.setResult(s.seed) return } @@ -70,36 +75,59 @@ func (s *liveProviderSync) Get(ctx context.Context) (catwalk.Provider, error) { if cachedAvailable { if age, ok := cacheAge(s.cache.path); ok && age < s.ttl { slog.Info("Using cached live provider models", "provider", fallback.ID, "age", age) - s.result = cached + s.setResult(cached) return } } - slog.Info("Fetching live provider models", "provider", s.seed.ID) + s.setResult(fallback) + s.refreshInBackground() + }) + return s.getResult(), nil +} + +func (s *liveProviderSync) refreshInBackground() { + slog.Info("Refreshing live provider models in background", "provider", s.seed.ID) + go func() { + ctx, cancel := context.WithTimeout(context.Background(), liveProviderFetchTimeout) + defer cancel() + result, err := s.client.Get(ctx) if errors.Is(err, context.DeadlineExceeded) { slog.Warn("Live provider models not updated in time", "provider", s.seed.ID) - s.result = fallback return } if err != nil { slog.Warn("Live provider models not updated", "provider", s.seed.ID, "err", err) - s.result = fallback return } if len(result.Models) == 0 { slog.Warn("Live provider did not return any models", "provider", s.seed.ID) - s.result = fallback return } merged := mergeLiveProvider(s.seed, result) - s.result = merged + s.setResult(merged) if err := s.cache.Store(merged); err != nil { slog.Warn("Failed to store live provider cache", "provider", s.seed.ID, "err", err) + return } - }) - return s.result, nil + if s.onRefresh != nil { + s.onRefresh(merged) + } + }() +} + +func (s *liveProviderSync) setResult(provider catwalk.Provider) { + s.resultMu.Lock() + defer s.resultMu.Unlock() + s.result = provider +} + +func (s *liveProviderSync) getResult() catwalk.Provider { + s.resultMu.RLock() + defer s.resultMu.RUnlock() + return s.result } func cacheAge(path string) (time.Duration, bool) { diff --git a/internal/config/live_provider_test.go b/internal/config/live_provider_test.go index 2d8488a396..a95a6f5ead 100644 --- a/internal/config/live_provider_test.go +++ b/internal/config/live_provider_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "os" + "sync" "testing" "time" @@ -13,16 +14,40 @@ import ( ) type mockLiveProviderClient struct { - provider catwalk.Provider - err error + provider catwalk.Provider + err error + started chan struct{} + release chan struct{} + + mu sync.Mutex callCount int + startOnce sync.Once } func (m *mockLiveProviderClient) Get(ctx context.Context) (catwalk.Provider, error) { + m.mu.Lock() m.callCount++ + m.mu.Unlock() + + if m.started != nil { + m.startOnce.Do(func() { close(m.started) }) + } + if m.release != nil { + select { + case <-m.release: + case <-ctx.Done(): + return catwalk.Provider{}, ctx.Err() + } + } return m.provider, m.err } +func (m *mockLiveProviderClient) calls() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.callCount +} + func TestLiveProviderSync_GetPanicIfNotInit(t *testing.T) { t.Parallel() @@ -45,7 +70,7 @@ func TestLiveProviderSync_GetAutoUpdateDisabledReturnsSeed(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, seed, provider) - require.Equal(t, 0, client.callCount) + require.Equal(t, 0, client.calls()) } func TestLiveProviderSync_GetWithoutCredentialsReturnsSeed(t *testing.T) { @@ -61,7 +86,7 @@ func TestLiveProviderSync_GetWithoutCredentialsReturnsSeed(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, seed, provider) - require.Equal(t, 0, client.callCount) + require.Equal(t, 0, client.calls()) } func TestLiveProviderSync_GetWarmCacheReturnsCachedWithoutFetch(t *testing.T) { @@ -82,10 +107,10 @@ func TestLiveProviderSync_GetWarmCacheReturnsCachedWithoutFetch(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, cached, provider) - require.Equal(t, 0, client.callCount) + require.Equal(t, 0, client.calls()) } -func TestLiveProviderSync_GetStaleCacheFetchesMergesAndStores(t *testing.T) { +func TestLiveProviderSync_GetStaleCacheReturnsCachedAndRefreshesInBackground(t *testing.T) { t.Parallel() path := t.TempDir() + "/provider.json" @@ -96,6 +121,8 @@ func TestLiveProviderSync_GetStaleCacheFetchesMergesAndStores(t *testing.T) { require.NoError(t, os.Chtimes(path, staleTime, staleTime)) seed := testLiveSeedProvider() + started := make(chan struct{}) + release := make(chan struct{}) client := &mockLiveProviderClient{ provider: catwalk.Provider{ Models: []catwalk.Model{ @@ -103,13 +130,26 @@ func TestLiveProviderSync_GetStaleCacheFetchesMergesAndStores(t *testing.T) { {ID: "other-live-model", Name: "Other Live Model"}, }, }, + started: started, + release: release, } syncer := &liveProviderSync{} syncer.Init(client, path, true, seed, true) provider, err := syncer.Get(t.Context()) require.NoError(t, err) - require.Equal(t, 1, client.callCount) + require.Equal(t, cached, provider) + <-started + require.Equal(t, 1, client.calls()) + + close(release) + require.Eventually(t, func() bool { + stored, _, err := newCache[catwalk.Provider](path).Get() + return err == nil && len(stored.Models) == 2 && stored.Models[0].ID == "live-model" + }, time.Second, 10*time.Millisecond) + + provider, err = syncer.Get(t.Context()) + require.NoError(t, err) require.Equal(t, seed.ID, provider.ID) require.Equal(t, seed.Name, provider.Name) require.Equal(t, seed.APIEndpoint, provider.APIEndpoint) @@ -120,33 +160,122 @@ func TestLiveProviderSync_GetStaleCacheFetchesMergesAndStores(t *testing.T) { {ID: "live-model", Name: "Live Model"}, {ID: "other-live-model", Name: "Other Live Model"}, }, provider.Models) +} + +func TestLiveProviderSync_GetBackgroundRefreshInvokesCallbackAfterStore(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + seed := testLiveSeedProvider() + started := make(chan struct{}) + release := make(chan struct{}) + callbackCh := make(chan catwalk.Provider, 1) + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + started: started, + release: release, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, seed, true) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } - stored, _, err := newCache[catwalk.Provider](path).Get() + provider, err := syncer.Get(t.Context()) require.NoError(t, err) - require.Equal(t, provider, stored) + require.Equal(t, seed, provider) + <-started + close(release) + + var callbackProvider catwalk.Provider + require.Eventually(t, func() bool { + select { + case callbackProvider = <-callbackCh: + return true + default: + return false + } + }, time.Second, 10*time.Millisecond) + require.Equal(t, seed.ID, callbackProvider.ID) + require.Equal(t, []catwalk.Model{{ID: "live-model", Name: "Live Model"}}, callbackProvider.Models) + require.Empty(t, callbackCh) } -func TestLiveProviderSync_GetStoreFailureStillReturnsMerged(t *testing.T) { +func TestLiveProviderSync_GetWarmCacheDoesNotInvokeCallback(t *testing.T) { + t.Parallel() + + path := t.TempDir() + "/provider.json" + cached := testLiveSeedProvider() + cached.Models = []catwalk.Model{{ID: "cached-model", Name: "Cached Model"}} + writeLiveProviderCache(t, path, cached) + callbackCh := make(chan catwalk.Provider, 1) + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, path, true, testLiveSeedProvider(), true) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, cached, provider) + require.Equal(t, 0, client.calls()) + require.Empty(t, callbackCh) +} + +func TestLiveProviderSync_GetWithoutCredentialsDoesNotInvokeCallback(t *testing.T) { + t.Parallel() + + callbackCh := make(chan catwalk.Provider, 1) + seed := testLiveSeedProvider() + client := &mockLiveProviderClient{ + provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model"}}}, + } + syncer := &liveProviderSync{} + syncer.Init(client, t.TempDir()+"/provider.json", true, seed, false) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } + + provider, err := syncer.Get(t.Context()) + require.NoError(t, err) + require.Equal(t, seed, provider) + require.Equal(t, 0, client.calls()) + require.Empty(t, callbackCh) +} + +func TestLiveProviderSync_GetStoreFailureStillUsesMergedResult(t *testing.T) { t.Parallel() - // Point the cache at a path whose parent is a regular file so that - // Store fails; the merged provider must still be returned without - // error. blocker := t.TempDir() + "/blocker" require.NoError(t, os.WriteFile(blocker, []byte("not a dir"), 0o644)) path := blocker + "/provider.json" seed := testLiveSeedProvider() + started := make(chan struct{}) + callbackCh := make(chan catwalk.Provider, 1) client := &mockLiveProviderClient{ provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + started: started, } syncer := &liveProviderSync{} syncer.Init(client, path, true, seed, true) + syncer.onRefresh = func(provider catwalk.Provider) { + callbackCh <- provider + } provider, err := syncer.Get(t.Context()) require.NoError(t, err) - require.Equal(t, 1, client.callCount) - require.Equal(t, []catwalk.Model{{ID: "live-model", Name: "Live Model"}}, provider.Models) + require.Equal(t, seed, provider) + <-started + require.Eventually(t, func() bool { + provider, err := syncer.Get(t.Context()) + return err == nil && len(provider.Models) == 1 && provider.Models[0].ID == "live-model" + }, time.Second, 10*time.Millisecond) + require.Equal(t, 1, client.calls()) + require.Empty(t, callbackCh) } func TestLiveProviderSync_GetFetchErrorUsesCached(t *testing.T) { @@ -167,7 +296,7 @@ func TestLiveProviderSync_GetFetchErrorUsesCached(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, cached, provider) - require.Equal(t, 1, client.callCount) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) } func TestLiveProviderSync_GetDeadlineExceededUsesCached(t *testing.T) { @@ -188,7 +317,7 @@ func TestLiveProviderSync_GetDeadlineExceededUsesCached(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, cached, provider) - require.Equal(t, 1, client.callCount) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) } func TestLiveProviderSync_GetFetchErrorUsesSeedWithoutCache(t *testing.T) { @@ -202,7 +331,7 @@ func TestLiveProviderSync_GetFetchErrorUsesSeedWithoutCache(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, seed, provider) - require.Equal(t, 1, client.callCount) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) } func TestLiveProviderSync_GetEmptyModelsUsesFallback(t *testing.T) { @@ -223,14 +352,18 @@ func TestLiveProviderSync_GetEmptyModelsUsesFallback(t *testing.T) { provider, err := syncer.Get(t.Context()) require.NoError(t, err) require.Equal(t, cached, provider) - require.Equal(t, 1, client.callCount) + require.Eventually(t, func() bool { return client.calls() == 1 }, time.Second, 10*time.Millisecond) } -func TestLiveProviderSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { +func TestLiveProviderSync_GetCalledMultipleTimesSchedulesOneRefresh(t *testing.T) { t.Parallel() + started := make(chan struct{}) + release := make(chan struct{}) client := &mockLiveProviderClient{ provider: catwalk.Provider{Models: []catwalk.Model{{ID: "live-model", Name: "Live Model"}}}, + started: started, + release: release, } syncer := &liveProviderSync{} syncer.Init(client, t.TempDir()+"/provider.json", true, testLiveSeedProvider(), true) @@ -240,7 +373,9 @@ func TestLiveProviderSync_GetCalledMultipleTimesUsesOnce(t *testing.T) { provider2, err2 := syncer.Get(t.Context()) require.NoError(t, err2) require.Equal(t, provider1, provider2) - require.Equal(t, 1, client.callCount) + <-started + require.Equal(t, 1, client.calls()) + close(release) } func TestCacheAge(t *testing.T) { diff --git a/internal/config/provider.go b/internal/config/provider.go index 09e371fd29..82ac581fe7 100644 --- a/internal/config/provider.go +++ b/internal/config/provider.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "log/slog" + "maps" "os" "path/filepath" "runtime" @@ -22,6 +23,7 @@ import ( "github.com/charmbracelet/crush/internal/env" "github.com/charmbracelet/crush/internal/home" "github.com/charmbracelet/crush/internal/oauth" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/x/etag" ) @@ -30,11 +32,25 @@ type syncer[T any] interface { } var ( - providerOnce sync.Once - providerList []catwalk.Provider - providerErr error + providerOnce sync.Once + providerMu sync.RWMutex + providerList []catwalk.Provider + providerErr error + pendingProvider map[catwalk.InferenceProvider]catwalk.Provider + + providerEvents = pubsub.NewBroker[ProvidersUpdatedEvent]() ) +// ProvidersUpdatedEvent reports that a provider's model list changed. +type ProvidersUpdatedEvent struct { + ProviderID string +} + +// SubscribeProviderEvents returns provider update events. +func SubscribeProviderEvents(ctx context.Context) <-chan pubsub.Event[ProvidersUpdatedEvent] { + return providerEvents.Subscribe(ctx) +} + var errMissingLiveProviderCredentials = errors.New("missing live provider credentials") // IsMissingLiveProviderCredentials reports whether err means a live provider @@ -220,13 +236,84 @@ func Providers(cfg *Config) ([]catwalk.Provider, error) { } if hyperFound { - providerList = append([]catwalk.Provider{hyperProvider}, items...) + setProviderList(append([]catwalk.Provider{hyperProvider}, items...)) } else { - providerList = items + setProviderList(items) } providerErr = errors.Join(errs...) }) - return providerList, providerErr + return providerSnapshot(), providerErr +} + +func setProviderList(providers []catwalk.Provider) { + providerMu.Lock() + defer providerMu.Unlock() + + providers = cloneProviders(providers) + for providerID, provider := range pendingProvider { + index := slices.IndexFunc(providers, func(item catwalk.Provider) bool { + return item.ID == providerID + }) + if index >= 0 { + providers[index] = provider + } + } + pendingProvider = nil + providerList = providers +} + +func providerSnapshot() []catwalk.Provider { + providerMu.RLock() + defer providerMu.RUnlock() + return cloneProviders(providerList) +} + +func updateProviderList(provider catwalk.Provider) { + providerMu.Lock() + defer providerMu.Unlock() + + provider = cloneProvider(provider) + index := slices.IndexFunc(providerList, func(item catwalk.Provider) bool { + return item.ID == provider.ID + }) + if index < 0 { + if pendingProvider == nil { + pendingProvider = make(map[catwalk.InferenceProvider]catwalk.Provider) + } + pendingProvider[provider.ID] = provider + return + } + providerList = slices.Clone(providerList) + providerList[index] = provider +} + +func publishProviderUpdated(provider catwalk.Provider) { + updateProviderList(provider) + providerEvents.Publish(pubsub.UpdatedEvent, ProvidersUpdatedEvent{ProviderID: string(provider.ID)}) +} + +func cloneProviders(providers []catwalk.Provider) []catwalk.Provider { + if providers == nil { + return nil + } + cloned := make([]catwalk.Provider, len(providers)) + for i, provider := range providers { + cloned[i] = cloneProvider(provider) + } + return cloned +} + +func cloneProvider(provider catwalk.Provider) catwalk.Provider { + provider.Models = slices.Clone(provider.Models) + provider.DefaultHeaders = cloneStringMap(provider.DefaultHeaders) + return provider +} + +func cloneStringMap(values map[string]string) map[string]string { + if values == nil { + return nil + } + return maps.Clone(values) } func overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []catwalk.Provider, autoupdate bool) []catwalk.Provider { @@ -253,6 +340,7 @@ func overlayLiveProviderModels(ctx context.Context, cfg *Config, providers []cat } syncer.Init(client, cachePathFor(cacheName), autoupdate, seed, credentialed) + syncer.onRefresh = publishProviderUpdated provider, err := syncer.Get(ctx) if err != nil { slog.Warn("Live provider sync failed", "provider", providerID, "error", err) diff --git a/internal/config/provider_test.go b/internal/config/provider_test.go index 461b90b2e3..fb4cdab98c 100644 --- a/internal/config/provider_test.go +++ b/internal/config/provider_test.go @@ -1,24 +1,31 @@ package config import ( + "context" "encoding/json" "net/http" "net/http/httptest" "os" "path/filepath" "sync" + "sync/atomic" "testing" + "time" "charm.land/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/env" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/stretchr/testify/require" ) func resetProviderState() { providerOnce = sync.Once{} + providerMu.Lock() providerList = nil providerErr = nil + pendingProvider = nil + providerMu.Unlock() catwalkSyncer = &catwalkSync{} hyperSyncer = &hyperSync{} veniceSyncer = &liveProviderSync{} @@ -223,15 +230,21 @@ func TestProviders_Integration_BothFail(t *testing.T) { require.Equal(t, "Charm Hyper", hyperResult.Name) // Falls back to embedded when no models. } -func TestProviders_Integration_LiveOverlayFetchesWithCredentials(t *testing.T) { +func TestProviders_Integration_LiveOverlayRefreshesWithCredentialsInBackground(t *testing.T) { resetProviderState() defer resetProviderState() tmpDir := t.TempDir() t.Setenv("XDG_DATA_HOME", tmpDir) + started := make(chan struct{}, 2) + release := make(chan struct{}) + var requestCount atomic.Int32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, "/models", r.URL.Path) + requestCount.Add(1) + started <- struct{}{} + <-release switch r.Header.Get("Authorization") { case "Bearer venice-token": @@ -297,12 +310,112 @@ func TestProviders_Integration_LiveOverlayFetchesWithCredentials(t *testing.T) { }), } - result := overlayLiveProviderModels(t.Context(), cfg, providers, true) + resultCh := make(chan []catwalk.Provider, 1) + go func() { + resultCh <- overlayLiveProviderModels(t.Context(), cfg, providers, true) + }() + + var result []catwalk.Provider + select { + case result = <-resultCh: + case <-time.After(100 * time.Millisecond): + close(release) + require.FailNow(t, "Live overlay blocked on background refresh") + } venice, ok := findProvider(result, catwalk.InferenceProviderVenice) require.True(t, ok) - require.Equal(t, "Venice", venice.Name) - require.Equal(t, "venice-live", venice.DefaultLargeModelID) + require.Equal(t, []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, venice.Models) + + copilot, ok := findProvider(result, catwalk.InferenceProviderCopilot) + require.True(t, ok) + require.Equal(t, []catwalk.Model{{ID: "copilot-seed", Name: "Copilot Seed"}}, copilot.Models) + + require.Eventually(t, func() bool { return requestCount.Load() == 2 }, time.Second, 10*time.Millisecond) + <-started + <-started + close(release) + + require.Eventually(t, func() bool { + venice, _, err := newCache[catwalk.Provider](cachePathFor("venice")).Get() + return err == nil && len(venice.Models) == 1 && venice.Models[0].ID == "venice-live" + }, time.Second, 10*time.Millisecond) + require.Eventually(t, func() bool { + copilot, _, err := newCache[catwalk.Provider](cachePathFor("copilot")).Get() + return err == nil && len(copilot.Models) == 1 && copilot.Models[0].ID == "copilot-live" + }, time.Second, 10*time.Millisecond) +} + +func TestProviders_Integration_LiveOverlayUpdatesProviderListAndPublishesEvent(t *testing.T) { + resetProviderState() + defer resetProviderState() + + tmpDir := t.TempDir() + t.Setenv("XDG_DATA_HOME", tmpDir) + + started := make(chan struct{}, 1) + release := make(chan struct{}) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + started <- struct{}{} + <-release + require.Equal(t, "Bearer venice-token", r.Header.Get("Authorization")) + _, _ = w.Write([]byte(`{ + "data": [ + { + "id": "venice-live", + "type": "text", + "model_spec": { + "name": "Venice Live", + "availableContextTokens": 4096, + "maxCompletionTokens": 1024, + "pricing": {"input": {"usd": "0.1"}, "output": {"usd": "0.2"}}, + "capabilities": {"supportsVision": true} + } + } + ] + }`)) + })) + defer server.Close() + + seed := catwalk.Provider{ + Name: "Venice", + ID: catwalk.InferenceProviderVenice, + APIEndpoint: server.URL, + Type: catwalk.TypeOpenAICompat, + DefaultLargeModelID: "venice-seed", + DefaultSmallModelID: "venice-seed", + Models: []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, + } + cfg := &Config{ + Options: &Options{}, + Providers: csync.NewMapFrom(map[string]ProviderConfig{ + string(catwalk.InferenceProviderVenice): {APIKey: "venice-token"}, + }), + } + providerOnce.Do(func() {}) + setProviderList([]catwalk.Provider{seed}) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + events := SubscribeProviderEvents(ctx) + + result := overlayLiveProviderModels(t.Context(), cfg, []catwalk.Provider{seed}, true) + require.Equal(t, []catwalk.Model{{ID: "venice-seed", Name: "Venice Seed"}}, result[0].Models) + <-started + close(release) + + select { + case event := <-events: + require.Equal(t, pubsub.UpdatedEvent, event.Type) + require.Equal(t, string(catwalk.InferenceProviderVenice), event.Payload.ProviderID) + case <-time.After(time.Second): + require.FailNow(t, "Timed out waiting for providers updated event") + } + + providers, err := Providers(cfg) + require.NoError(t, err) + venice, ok := findProvider(providers, catwalk.InferenceProviderVenice) + require.True(t, ok) require.Equal(t, []catwalk.Model{{ ID: "venice-live", Name: "Venice Live", @@ -312,18 +425,6 @@ func TestProviders_Integration_LiveOverlayFetchesWithCredentials(t *testing.T) { DefaultMaxTokens: 1024, SupportsImages: true, }}, venice.Models) - - copilot, ok := findProvider(result, catwalk.InferenceProviderCopilot) - require.True(t, ok) - require.Equal(t, "Copilot", copilot.Name) - require.Equal(t, "copilot-live", copilot.DefaultSmallModelID) - require.Equal(t, []catwalk.Model{{ - ID: "copilot-live", - Name: "Copilot Live", - ContextWindow: 8192, - DefaultMaxTokens: 2048, - SupportsImages: true, - }}, copilot.Models) } func TestProviders_Integration_LiveOverlaySkipsWithoutCredentials(t *testing.T) { diff --git a/internal/ui/dialog/models.go b/internal/ui/dialog/models.go index 76ec4dbbf2..2725157630 100644 --- a/internal/ui/dialog/models.go +++ b/internal/ui/dialog/models.go @@ -161,6 +161,19 @@ func (m *Models) ID() string { return ModelsID } +// RefreshProviders reloads provider models and repopulates the list. +func (m *Models) RefreshProviders() error { + var err error + m.providers, err = config.Providers(m.com.Config()) + if err != nil { + return fmt.Errorf("failed to get providers: %w", err) + } + if err := m.setProviderItems(); err != nil { + return fmt.Errorf("failed to set provider items: %w", err) + } + return nil +} + // HandleMsg implements Dialog. func (m *Models) HandleMsg(msg tea.Msg) Action { switch msg := msg.(type) { diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index ca30d20fbd..6ef9b4f30b 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -712,6 +712,15 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) { cmds = append(cmds, m.handleFileEvent(msg.Payload)) case pubsub.Event[app.LSPEvent]: m.lspStates = app.GetLSPStates() + case pubsub.Event[config.ProvidersUpdatedEvent]: + if dia := m.dialog.Dialog(dialog.ModelsID); dia != nil { + models, ok := dia.(*dialog.Models) + if ok { + if err := models.RefreshProviders(); err != nil { + cmds = append(cmds, util.ReportError(err)) + } + } + } case pubsub.Event[skills.Event]: m.skillStates = msg.Payload.States case pubsub.Event[mcp.Event]: