Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ internal/
### Key Patterns

- **Config is a Service**: accessed via `config.Service`, not global state.
- **Provider updates**: Catwalk is the default provider seed/cache. Venice and
Copilot models are live-overlaid only when authenticated and use the existing
provider cache/syncer with a 60s warm-cache TTL; `crush update-providers`
refreshes authenticated live providers best-effort, while `--source=venice`
and `--source=copilot` refresh them directly.
- **Tools are self-documenting**: each tool has a `.go` implementation and a
`.md` description file in `internal/agent/tools/`.
- **System prompts are Go templates**: `internal/agent/templates/*.md.tpl`
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,7 @@ command:

```bash
# Update providers remotely from Catwalk.
# This also refreshes Venice and Copilot models when you're authenticated.
crush update-providers

# Update providers from a custom Catwalk base URL.
Expand Down
1 change: 1 addition & 0 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ func (app *App) setupEvents() {
setupSubscriberMustDeliver(ctx, app.serviceEventsWG, "run-completions", app.runCompletions.Subscribe, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
setupSubscriber(ctx, app.serviceEventsWG, "providers", config.SubscribeProviderEvents, app.events)
if app.Skills != nil {
setupSubscriber(ctx, app.serviceEventsWG, "skills", app.Skills.SubscribeEvents, app.events)
}
Expand Down
69 changes: 66 additions & 3 deletions internal/cmd/update_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"log/slog"
"os"

"charm.land/lipgloss/v2"
"github.com/charmbracelet/crush/internal/config"
Expand All @@ -17,7 +18,7 @@ var updateProvidersCmd = &cobra.Command{
Short: "Update providers",
Long: `Update provider information from a specified local path or remote URL.`,
Example: `
# Update Catwalk providers remotely (default)
# Update Catwalk providers remotely (default), plus authenticated live providers
crush update-providers

# Update Catwalk providers from a custom URL
Expand All @@ -34,6 +35,12 @@ crush update-providers --source=hyper

# Update Hyper from a custom URL
crush update-providers --source=hyper https://hyper.example.com

# Update Venice provider models
crush update-providers --source=venice

# Update Copilot provider models
crush update-providers --source=copilot
`,
RunE: func(cmd *cobra.Command, args []string) error {
// NOTE(@andreynering): We want to skip logging output do stdout here.
Expand All @@ -48,10 +55,23 @@ crush update-providers --source=hyper https://hyper.example.com
switch updateProvidersSource {
case "catwalk":
err = config.UpdateProviders(pathOrURL)
if err == nil && pathOrURL == "" {
updateAuthenticatedLiveProviders(cmd)
}
case "hyper":
err = config.UpdateHyper(pathOrURL)
case "venice", "copilot":
cfg, loadErr := loadUpdateProvidersConfig(cmd)
if loadErr != nil {
return loadErr
}
if updateProvidersSource == "venice" {
err = config.UpdateVenice(pathOrURL, cfg)
} else {
err = config.UpdateCopilot(pathOrURL, cfg)
}
default:
return fmt.Errorf("invalid source %q, must be 'catwalk' or 'hyper'", updateProvidersSource)
return fmt.Errorf("invalid source %q, must be 'catwalk', 'hyper', 'venice', or 'copilot'", updateProvidersSource)
}

if err != nil {
Expand All @@ -77,6 +97,49 @@ crush update-providers --source=hyper https://hyper.example.com
},
}

func updateAuthenticatedLiveProviders(cmd *cobra.Command) {
cfg, err := loadUpdateProvidersConfig(cmd)
if err != nil {
fmt.Fprintf(os.Stderr, "Note: skipping Venice and Copilot updates: %v\n", err)
return
}
if err := config.UpdateVenice("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) {
fmt.Fprintf(os.Stderr, "Note: skipping Venice update: %v\n", err)
}
if err := config.UpdateCopilot("", cfg); err != nil && !config.IsMissingLiveProviderCredentials(err) {
fmt.Fprintf(os.Stderr, "Note: skipping Copilot update: %v\n", err)
}
}

func loadUpdateProvidersConfig(cmd *cobra.Command) (*config.Config, error) {
cwd, err := ResolveCwd(cmd)
if err != nil {
return nil, err
}
dataDir, _ := cmd.Flags().GetString("data-dir")
debug, _ := cmd.Flags().GetBool("debug")

// The config is loaded only to resolve provider credentials;
// updateLiveProvider fetches explicitly afterwards. Disable provider
// auto-update during Load so it doesn't fetch the same endpoints
// first (avoiding a double fetch per provider).
previous, hadPrevious := os.LookupEnv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE")
_ = os.Setenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE", "1")
defer func() {
if hadPrevious {
_ = os.Setenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE", previous)
} else {
_ = os.Unsetenv("CRUSH_DISABLE_PROVIDER_AUTO_UPDATE")
}
}()

store, err := config.Load(cwd, dataDir, debug)
if err != nil {
return nil, err
}
return store.Config(), nil
}

func init() {
updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk or hyper)")
updateProvidersCmd.Flags().StringVar(&updateProvidersSource, "source", "catwalk", "Provider source to update (catwalk, hyper, venice, or copilot)")
}
27 changes: 27 additions & 0 deletions internal/cmd/update_providers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package cmd

import (
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestUpdateProvidersCmd_SourceFlagIncludesLiveProviders(t *testing.T) {
t.Parallel()

flag := updateProvidersCmd.Flags().Lookup("source")
require.NotNil(t, flag)
require.Equal(t, "catwalk", flag.DefValue)
require.Contains(t, flag.Usage, "venice")
require.Contains(t, flag.Usage, "copilot")
}

func TestUpdateProvidersCmd_ExamplesIncludeLiveProviders(t *testing.T) {
t.Parallel()

example := strings.ToLower(updateProvidersCmd.Example)
require.Contains(t, example, "--source=venice")
require.Contains(t, example, "--source=copilot")
require.Contains(t, example, "authenticated live providers")
}
176 changes: 176 additions & 0 deletions internal/config/copilot_models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package config

import (
"context"
"encoding/json"
"fmt"
"net/http"
"slices"
"strings"
"time"

"charm.land/catwalk/pkg/catwalk"
"github.com/charmbracelet/crush/internal/oauth"
"github.com/charmbracelet/crush/internal/oauth/copilot"
)

var _ liveProviderClient = (*realCopilotModelsClient)(nil)

type copilotTokenRefresher func(context.Context, string) (*oauth.Token, error)

type realCopilotModelsClient struct {
baseURL string
apiKey string
oauthToken *oauth.Token
refreshToken copilotTokenRefresher
}

func (r *realCopilotModelsClient) Get(ctx context.Context) (catwalk.Provider, error) {
var result catwalk.Provider
baseURL := strings.TrimRight(r.baseURL, "/")
accessToken, err := r.accessToken(ctx)
if err != nil {
return result, err
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil)
if err != nil {
return result, fmt.Errorf("could not create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+accessToken)
for key, value := range copilot.Headers() {
req.Header.Set(key, value)
}

client := &http.Client{Timeout: 30 * time.Second}
resp, err := client.Do(req)
if err != nil {
return result, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck

if resp.StatusCode != http.StatusOK {
return result, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

var models copilotModelsResponse
if err := json.NewDecoder(resp.Body).Decode(&models); err != nil {
return result, fmt.Errorf("failed to decode response: %w", err)
}

result = catwalk.Provider{
ID: catwalk.InferenceProviderCopilot,
APIEndpoint: baseURL,
Type: catwalk.TypeOpenAICompat,
Models: copilotModelsToCatwalkModels(models),
}
return result, nil
}

func (r *realCopilotModelsClient) accessToken(ctx context.Context) (string, error) {
if r.oauthToken != nil {
if token := strings.TrimSpace(r.oauthToken.AccessToken); token != "" && !r.oauthToken.IsExpired() {
r.apiKey = token
return token, nil
}

refreshToken := strings.TrimSpace(r.oauthToken.RefreshToken)
if refreshToken != "" {
refresh := r.refreshToken
if refresh == nil {
refresh = copilot.RefreshToken
}
refreshedToken, err := refresh(ctx, refreshToken)
if err != nil {
return "", fmt.Errorf("failed to refresh Copilot token: %w", err)
}
r.oauthToken = refreshedToken
r.apiKey = strings.TrimSpace(refreshedToken.AccessToken)
if r.apiKey != "" {
return r.apiKey, nil
}
}
}

if token := strings.TrimSpace(r.apiKey); token != "" {
return token, nil
}
return "", fmt.Errorf("missing Copilot access token")
}

type copilotModelsResponse struct {
Data []copilotModel `json:"data"`
}

type copilotModel struct {
ID string `json:"id"`
Name string `json:"name"`
Version string `json:"version"`
Capabilities copilotModelCapabilities `json:"capabilities"`
}

type copilotModelCapabilities struct {
Type string `json:"type"`
Limits copilotModelLimits `json:"limits"`
Supports copilotModelSupports `json:"supports"`
}

type copilotModelLimits struct {
MaxContextWindowTokens int64 `json:"max_context_window_tokens"`
MaxOutputTokens int64 `json:"max_output_tokens"`
}

type copilotModelSupports struct {
Vision bool `json:"vision"`
ReasoningEffort []string `json:"reasoning_effort"`
AdaptiveThinking bool `json:"adaptive_thinking"`
}

func copilotModelsToCatwalkModels(response copilotModelsResponse) []catwalk.Model {
aliasedVersions := make(map[string]bool, len(response.Data))
for _, model := range response.Data {
if model.Version != "" && model.ID != model.Version {
aliasedVersions[model.Version] = true
}
}

seen := make(map[string]bool, len(response.Data))
models := make([]catwalk.Model, 0, len(response.Data))
for _, model := range response.Data {
if shouldSkipCopilotModel(model, aliasedVersions) || seen[model.ID] {
continue
}
seen[model.ID] = true
reasoningLevels := model.Capabilities.Supports.ReasoningEffort
models = append(models, catwalk.Model{
ID: model.ID,
Name: model.Name,
ContextWindow: model.Capabilities.Limits.MaxContextWindowTokens,
DefaultMaxTokens: model.Capabilities.Limits.MaxOutputTokens,
CanReason: model.Capabilities.Supports.AdaptiveThinking || len(reasoningLevels) > 0,
ReasoningLevels: reasoningLevels,
SupportsImages: model.Capabilities.Supports.Vision,
})
}

slices.SortStableFunc(models, func(a, b catwalk.Model) int {
return strings.Compare(a.ID, b.ID)
})
return models
}

func shouldSkipCopilotModel(model copilotModel, aliasedVersions map[string]bool) bool {
if capType := model.Capabilities.Type; capType != "" && capType != "chat" {
return true
}
return model.ID == "" ||
aliasedVersions[model.ID] ||
strings.Contains(model.ID, "embedding") ||
strings.HasPrefix(model.ID, "accounts/msft/routers") ||
strings.HasPrefix(model.ID, "oswe-vscode") ||
strings.HasPrefix(model.ID, "lark") ||
strings.HasPrefix(model.ID, "mai-code") ||
model.ID == "gpt-4-o-preview" ||
model.ID == "trajectory-compaction"
}
Loading
Loading