From 61f3e2e8159c7d37121a9f2abc00079f99c69bd7 Mon Sep 17 00:00:00 2001 From: AN Long Date: Fri, 5 Jun 2026 22:27:29 +0900 Subject: [PATCH] feat: set OpenAI prompt cache key per session --- internal/agent/coordinator.go | 16 ++++--- internal/agent/coordinator_test.go | 73 +++++++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 86ca09e3bf..031c9ec450 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -233,7 +233,7 @@ func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID st return nil, errModelProviderNotConfigured } - mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg) + mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg, sessionID) if err := c.refreshTokenIfExpired(ctx, providerCfg); err != nil { // NOTE(@andreynering): We don't return here because the event handling to ask the user to reauthenticate @@ -311,7 +311,7 @@ func (c *coordinator) run(ctx context.Context, accept *AcceptedRun, sessionID st return result, originalErr } -func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions { +func getProviderOptions(model Model, providerCfg config.ProviderConfig, sessionID string) fantasy.ProviderOptions { options := fantasy.ProviderOptions{} cfgOpts := []byte("{}") @@ -368,6 +368,9 @@ func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy. 
if !hasReasoningEffort && shouldSetEffort { mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort } + if _, hasCacheKey := mergedOptions["prompt_cache_key"]; !hasCacheKey && sessionID != "" { + mergedOptions["prompt_cache_key"] = sessionID + } if openai.IsResponsesModel(model.CatwalkCfg.ID) { if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) { mergedOptions["reasoning_summary"] = "auto" @@ -516,8 +519,8 @@ func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy. return options } -func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) { - modelOptions := getProviderOptions(model, cfg) +func mergeCallOptions(model Model, cfg config.ProviderConfig, sessionID string) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) { + modelOptions := getProviderOptions(model, cfg, sessionID) temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature) topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP) topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK) @@ -1118,7 +1121,8 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error { } summarize := func() error { - return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg)) + model := c.currentAgent.Model() + return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(model, providerCfg, sessionID)) } return c.runWithUnauthorizedRetry(ctx, providerCfg, summarize) @@ -1242,7 +1246,7 @@ func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (f SessionID: session.ID, Prompt: params.Prompt, MaxOutputTokens: maxTokens, - ProviderOptions: getProviderOptions(model, providerCfg), + ProviderOptions: getProviderOptions(model, providerCfg, session.ID), Temperature: model.ModelCfg.Temperature, TopP: model.ModelCfg.TopP, TopK: model.ModelCfg.TopK, diff --git 
a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go
index c522ef5de1..f4583f95ca 100644
--- a/internal/agent/coordinator_test.go
+++ b/internal/agent/coordinator_test.go
@@ -9,6 +9,7 @@ import (
 	"charm.land/fantasy"
 	"charm.land/fantasy/providers/anthropic"
 	"charm.land/fantasy/providers/bedrock"
+	"charm.land/fantasy/providers/openai"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -415,7 +416,7 @@ func TestGetProviderOptionsReasoningEffort(t *testing.T) {
 			}
 			providerCfg := config.ProviderConfig{ID: "test", Type: tc.providerType}
 
-			opts := getProviderOptions(model, providerCfg)
+			opts := getProviderOptions(model, providerCfg, "")
 
 			raw, ok := opts[anthropic.Name]
 			require.True(t, ok, "options should be keyed under anthropic.Name for type %q", tc.providerType)
@@ -426,3 +427,73 @@ func TestGetProviderOptionsReasoningEffort(t *testing.T) {
 		})
 	}
 }
+
+// TestGetProviderOptionsPromptCacheKey checks when getProviderOptions uses the session ID as the OpenAI prompt cache key.
+func TestGetProviderOptionsPromptCacheKey(t *testing.T) {
+	t.Run("responses model uses session id", func(t *testing.T) {
+		model := Model{
+			CatwalkCfg: catwalk.Model{ID: "gpt-5"},
+		}
+		providerCfg := config.ProviderConfig{Type: catwalk.Type(openai.Name)}
+
+		opts := getProviderOptions(model, providerCfg, "session-123")
+
+		raw, ok := opts[openai.Name]
+		require.True(t, ok)
+		parsed, ok := raw.(*openai.ResponsesProviderOptions)
+		require.True(t, ok)
+		require.NotNil(t, parsed.PromptCacheKey)
+		assert.Equal(t, "session-123", *parsed.PromptCacheKey)
+	})
+
+	t.Run("chat completions model uses session id", func(t *testing.T) {
+		model := Model{
+			CatwalkCfg: catwalk.Model{ID: "legacy-chat-model"},
+		}
+		providerCfg := config.ProviderConfig{Type: catwalk.Type(openai.Name)}
+
+		opts := getProviderOptions(model, providerCfg, "session-123")
+
+		raw, ok := opts[openai.Name]
+		require.True(t, ok)
+		parsed, ok := raw.(*openai.ProviderOptions)
+		require.True(t, ok)
+		require.NotNil(t, parsed.PromptCacheKey)
+		assert.Equal(t, "session-123", *parsed.PromptCacheKey)
+	})
+
+	t.Run("preserves explicit prompt cache key", func(t *testing.T) {
+		model := Model{
+			CatwalkCfg: catwalk.Model{ID: "gpt-5"},
+			ModelCfg: config.SelectedModel{
+				ProviderOptions: map[string]any{
+					"prompt_cache_key": "configured-cache-key",
+				},
+			},
+		}
+		providerCfg := config.ProviderConfig{Type: catwalk.Type(openai.Name)}
+
+		opts := getProviderOptions(model, providerCfg, "session-123")
+
+		parsed, ok := opts[openai.Name].(*openai.ResponsesProviderOptions)
+		require.True(t, ok)
+		require.NotNil(t, parsed.PromptCacheKey)
+		assert.Equal(t, "configured-cache-key", *parsed.PromptCacheKey)
+	})
+
+	t.Run("skips when session id is empty", func(t *testing.T) {
+		model := Model{
+			CatwalkCfg: catwalk.Model{ID: "gpt-5"},
+		}
+		providerCfg := config.ProviderConfig{Type: catwalk.Type(openai.Name)}
+
+		opts := getProviderOptions(model, providerCfg, "")
+
+		// Require the entry and its type: guarding with `if ok` would let this subtest pass without asserting anything.
+		raw, ok := opts[openai.Name]
+		require.True(t, ok)
+		parsed, ok := raw.(*openai.ResponsesProviderOptions)
+		require.True(t, ok)
+		assert.Nil(t, parsed.PromptCacheKey)
+	})
+}