Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 44 additions & 36 deletions internal/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1637,15 +1637,23 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
return
}

// Ensure the session always gets a title even if every path below
// fails or the context is cancelled before we finish.
var titleSaved bool
defer func() {
if !titleSaved {
fallbackCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel()
if err := a.sessions.Rename(fallbackCtx, sessionID, DefaultSessionName); err != nil {
slog.Error("Failed to save fallback session title", "error", err)
}
}
}()

smallModel := a.smallModel.Get()
largeModel := a.largeModel.Get()
systemPromptPrefix := a.systemPromptPrefix.Get()

var maxOutputTokens int64 = 40
if smallModel.CatwalkCfg.CanReason {
maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
}

newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
return fantasy.NewAgent(
m,
Expand All @@ -1668,41 +1676,40 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
},
}

// Use the small model to generate the title.
model := smallModel
agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
resp, err := agent.Stream(ctx, streamCall)
if err == nil {
// We successfully generated a title with the small model.
slog.Debug("Generated title with small model")
} else {
// It didn't work. Let's try with the big model.
slog.Error("Error generating title with small model; trying big model", "err", err)
model = largeModel
agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
type modelAttempt struct {
name string
model Model
}
attempts := []modelAttempt{
{"small", smallModel},
{"large", largeModel},
}

var resp *fantasy.AgentResult
var err error
var model Model
var success bool
for _, attempt := range attempts {
tok := int64(40)
if attempt.model.CatwalkCfg.CanReason {
tok = attempt.model.CatwalkCfg.DefaultMaxTokens
}
agent := newAgent(attempt.model.Model, titlePrompt, tok)
resp, err = agent.Stream(ctx, streamCall)
if err == nil {
slog.Debug("Generated title with large model")
if err == nil && resp.Response.FinishReason != fantasy.FinishReasonLength {
model = attempt.model
slog.Debug("Generated title with " + attempt.name + " model")
success = true
break
}
if err != nil {
slog.Error("Error generating title with "+attempt.name+" model; trying next", "err", err)
} else {
// Welp, the large model didn't work either. Use the default
// session name and return.
slog.Error("Error generating title with large model", "err", err)
saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
if saveErr != nil {
slog.Error("Failed to save session title", "error", saveErr)
}
return
slog.Error("Title generation hit token limit with " + attempt.name + " model; trying next")
}
}

if resp == nil {
// Actually, we didn't get a response so we can't. Use the default
// session name and return.
slog.Error("Response is nil; can't generate title")
saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
if saveErr != nil {
slog.Error("Failed to save session title", "error", saveErr)
}
if !success {
// The deferred fallback will save the default session name.
return
}

Expand Down Expand Up @@ -1756,6 +1763,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, user
slog.Error("Failed to save session title and usage", "error", saveErr)
return
}
titleSaved = true
}

func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
Expand Down
8 changes: 4 additions & 4 deletions internal/server/recover_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestRecoverHandler_PanicReturns500(t *testing.T) {
}))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
h.ServeHTTP(rec, req)

require.Equal(t, http.StatusInternalServerError, rec.Code)
Expand All @@ -48,7 +48,7 @@ func TestRecoverHandler_NoPanicPassthrough(t *testing.T) {
}))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
h.ServeHTTP(rec, req)

require.Equal(t, http.StatusTeapot, rec.Code)
Expand All @@ -71,7 +71,7 @@ func TestRecoverHandler_PanicAfterWriteHeader(t *testing.T) {
}))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
require.NotPanics(t, func() { h.ServeHTTP(rec, req) })
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "partial", rec.Body.String())
Expand All @@ -89,6 +89,6 @@ func TestRecoverHandler_AbortHandlerPropagates(t *testing.T) {
}))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)
require.PanicsWithValue(t, http.ErrAbortHandler, func() { h.ServeHTTP(rec, req) })
}
Loading