Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/agent/auto_review.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ func ClassifyTool(toolName string, args any) (RiskTier, string) {

case "todo", "task_create", "task_update", "task_list", "task_get",
"check_background", "schedule_create", "schedule_list", "schedule_delete",
"subagent_run",
// Teams
"spawn_teammate", "list_teammates", "send_message", "read_inbox", "broadcast",
// Protocols
Expand Down
172 changes: 172 additions & 0 deletions pkg/agent/compaction_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package agent

import (
"context"
"fmt"
"iter"
"strings"
"testing"

"google.golang.org/adk/model"
"google.golang.org/genai"
)

type mockLLM struct {
generateFunc func(ctx context.Context, req *model.LLMRequest, stream bool) iter.Seq2[*model.LLMResponse, error]
}

func (m *mockLLM) Name() string {
return "mock-llm"
}

func (m *mockLLM) GenerateContent(ctx context.Context, req *model.LLMRequest, stream bool) iter.Seq2[*model.LLMResponse, error] {
if m.generateFunc != nil {
return m.generateFunc(ctx, req, stream)
}
return func(yield func(*model.LLMResponse, error) bool) {
yield(&model.LLMResponse{
Content: &genai.Content{
Role: "model",
Parts: []*genai.Part{
{Text: "LLM-generated summary"},
},
},
TurnComplete: true,
}, nil)
}
}

func TestCompaction_LLMAndCircuitBreaker(t *testing.T) {
// Reset circuit breaker state before testing
compactionCircuitBreaker.mu.Lock()
compactionCircuitBreaker.failures = 0
compactionCircuitBreaker.open = false
compactionCircuitBreaker.mu.Unlock()

// 1. Create a history of 14 rounds (>12)
contents := make([]*genai.Content, 14)
for i := 0; i < 14; i++ {
role := "user"
if i%2 == 1 {
role = "model"
}
contents[i] = &genai.Content{
Role: role,
Parts: []*genai.Part{
{Text: fmt.Sprintf("Round message %d", i)},
},
}
}

// 2. Test successful LLM summarization
mock := &mockLLM{}
compacted := CompactContents(contents, "session-test-llm", mock)
if len(compacted) != 6 {
t.Fatalf("Expected compacted length 6, got %d", len(compacted))
}
// Verify that the second item is the system prompt with LLM summary
sysText := compacted[1].Parts[0].Text
if !strings.Contains(sysText, "LLM-generated summary") {
t.Errorf("Expected summary to contain LLM output, got: %q", sysText)
}

// Verify circuit breaker did not trip
compactionCircuitBreaker.mu.Lock()
failures := compactionCircuitBreaker.failures
isOpen := compactionCircuitBreaker.open
compactionCircuitBreaker.mu.Unlock()
if failures != 0 || isOpen {
t.Errorf("Expected 0 failures and closed circuit breaker, got failures=%d, open=%v", failures, isOpen)
}

// 3. Test failing LLM calls to trip circuit breaker
failMock := &mockLLM{
generateFunc: func(ctx context.Context, req *model.LLMRequest, stream bool) iter.Seq2[*model.LLMResponse, error] {
panic("simulated LLM panic")
},
}

// First failure
_ = CompactContents(contents, "session-test-fail-1", failMock)
compactionCircuitBreaker.mu.Lock()
failures = compactionCircuitBreaker.failures
isOpen = compactionCircuitBreaker.open
compactionCircuitBreaker.mu.Unlock()
if failures != 1 || isOpen {
t.Errorf("Expected 1 failure, got %d, open=%v", failures, isOpen)
}

// Second failure
_ = CompactContents(contents, "session-test-fail-2", failMock)
compactionCircuitBreaker.mu.Lock()
failures = compactionCircuitBreaker.failures
isOpen = compactionCircuitBreaker.open
compactionCircuitBreaker.mu.Unlock()
if failures != 2 || isOpen {
t.Errorf("Expected 2 failures, got %d, open=%v", failures, isOpen)
}

// Third failure -> should trip circuit breaker
_ = CompactContents(contents, "session-test-fail-3", failMock)
compactionCircuitBreaker.mu.Lock()
failures = compactionCircuitBreaker.failures
isOpen = compactionCircuitBreaker.open
compactionCircuitBreaker.mu.Unlock()
if failures != 3 || !isOpen {
t.Errorf("Expected 3 failures and open circuit breaker, got failures=%d, open=%v", failures, isOpen)
}

// 4. Verification that subsequent compaction bypasses LLM and falls back to truncation-only summary
compactedFallback := CompactContents(contents, "session-test-fallback", mock)
sysTextFallback := compactedFallback[1].Parts[0].Text
if !strings.Contains(sysTextFallback, "truncation-only mode") {
t.Errorf("Expected truncation-only fallback message, got: %q", sysTextFallback)
}
}

func TestCompaction_StickyLatch(t *testing.T) {
// Reset circuit breaker state
compactionCircuitBreaker.mu.Lock()
compactionCircuitBreaker.failures = 0
compactionCircuitBreaker.open = false
compactionCircuitBreaker.mu.Unlock()

// 1. Create a history of 14 rounds with one sticky block
contents := make([]*genai.Content, 14)
for i := 0; i < 14; i++ {
role := "user"
if i%2 == 1 {
role = "model"
}
text := fmt.Sprintf("Round message %d", i)
if i == 5 {
text = "This is a [STICKY] instruction that must persist"
}
contents[i] = &genai.Content{
Role: role,
Parts: []*genai.Part{
{Text: text},
},
}
}

// 2. Perform compaction
compacted := CompactContents(contents, "session-test-sticky")
// Expected: preserved prompt (1), summary (1), sticky (1), and last 4 rounds (4) = 7
if len(compacted) != 7 {
t.Fatalf("Expected compacted length of 7, got %d", len(compacted))
}

// Check that the sticky block exists in the output
foundSticky := false
for _, c := range compacted {
for _, p := range c.Parts {
if strings.Contains(p.Text, "[STICKY]") {
foundSticky = true
}
}
}
if !foundSticky {
t.Error("Expected sticky content to be preserved, but it was not found")
}
}
36 changes: 23 additions & 13 deletions pkg/agent/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/firebase/genkit/go/genkit"
"google.golang.org/adk/agent"
"google.golang.org/adk/agent/llmagent"
"google.golang.org/adk/model"
"google.golang.org/adk/runner"
"google.golang.org/adk/session"
"google.golang.org/adk/tool"
Expand Down Expand Up @@ -91,15 +92,23 @@ func TypePromptPrefix(typeName string) string {
}

func (ap *AgentPool) ExecuteMessage(teammate *Teammate, msg TeamMessage) (string, error) {
// Determine worktree directory
worktreePath := filepath.Join(GlobalWorktreeManager.worktreesDir, teammate.Name)
if _, err := os.Stat(worktreePath); os.IsNotExist(err) {
worktreePath, _ = os.Getwd()
}
return ap.ExecuteMessageInDir(teammate, msg, worktreePath)
}

func (ap *AgentPool) ExecuteMessageInDir(teammate *Teammate, msg TeamMessage, dir string) (string, error) {
ap.mu.Lock()
subRunner, exists := ap.runners[teammate.Name]
ap.mu.Unlock()

if !exists {
// 1. Ensure worktree directory exists dynamically if it doesn't yet
worktreePath := filepath.Join(GlobalWorktreeManager.worktreesDir, teammate.Name)
if _, err := os.Stat(worktreePath); os.IsNotExist(err) {
_, _ = GlobalWorktreeManager.Create(teammate.Name, "")
// 1. Ensure directory exists dynamically if it doesn't yet
if dir != "" {
_ = os.MkdirAll(dir, 0755)
}

// 2. Setup subagent ADK Runner
Expand Down Expand Up @@ -127,8 +136,14 @@ func (ap *AgentPool) ExecuteMessage(teammate *Teammate, msg TeamMessage) (string
genkitReg := ap.GenkitRegistry
ap.mu.RUnlock()

// Retrieve or build subagent adapter
subAdapter, err := llm.NewAdapter(genkitReg, prov, mod, key, base, systemPrompt, fmtFormat, runnerHooks{})
// Retrieve or build subagent adapter with dynamic resolve hooks
var subAdapter model.LLM
hooks := runnerHooks{
modelGetter: func() model.LLM {
return subAdapter
},
}
subAdapter, err = llm.NewAdapter(genkitReg, prov, mod, key, base, systemPrompt, fmtFormat, hooks)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -159,16 +174,11 @@ func (ap *AgentPool) ExecuteMessage(teammate *Teammate, msg TeamMessage) (string
ap.mu.Unlock()
}

// Determine worktree directory
worktreePath := filepath.Join(GlobalWorktreeManager.worktreesDir, teammate.Name)
if _, err := os.Stat(worktreePath); os.IsNotExist(err) {
worktreePath, _ = os.Getwd()
}

// 3. Execute prompt on the subRunner
// Setup context with subagent name and workdir path
ctx := context.WithValue(context.Background(), WorkdirKey, worktreePath)
ctx := context.WithValue(context.Background(), WorkdirKey, dir)
ctx = context.WithValue(ctx, AgentNameKey, teammate.Name)
ctx = context.WithValue(ctx, "session_id", teammate.Name+"-session") // Inject session id for history compaction

userMsg := &genai.Content{
Role: "user",
Expand Down
83 changes: 79 additions & 4 deletions pkg/agent/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ import (
)

// runnerHooks implements llm.AdapterHooks using the agent package's global managers.
type runnerHooks struct{}
type runnerHooks struct {
modelGetter func() model.LLM
}

func (runnerHooks) NagReminder() string {
if GlobalTodoManager.RoundsSinceUpdate() >= 3 {
Expand All @@ -38,6 +40,18 @@ func (runnerHooks) NoteRound() {
GlobalTodoManager.NoteRoundWithoutUpdate()
}

func (r runnerHooks) CompactHistory(contents []*genai.Content) []*genai.Content {
sessID := GlobalLogger.sessionID
if sessID == "" {
sessID = "session-default"
}
var m model.LLM
if r.modelGetter != nil {
m = r.modelGetter()
}
return CompactContents(contents, sessID, m)
}

func buildSystemPrompt() string {
builder := NewSystemPromptBuilder()
return builder.Build()
Expand Down Expand Up @@ -171,7 +185,14 @@ func NewCustomRunner(provider llm.ProviderType, modelName string, apiKey string,

// 2. Create our abstract model adapter
systemPrompt := buildSystemPrompt()
modelAdapter, err := llm.NewAdapter(g, provider, modelName, apiKey, baseURL, systemPrompt, apiFormat, runnerHooks{})
var modelAdapter model.LLM
hooks := runnerHooks{
modelGetter: func() model.LLM {
return modelAdapter
},
}
var err error
modelAdapter, err = llm.NewAdapter(g, provider, modelName, apiKey, baseURL, systemPrompt, apiFormat, hooks)
if err != nil {
return nil, fmt.Errorf("failed to create model adapter: %w", err)
}
Expand Down Expand Up @@ -278,6 +299,10 @@ func (cr *CustomRunner) ModelName() string {
return cr.llmModel.Name()
}

func (cr *CustomRunner) GetModel() model.LLM {
return cr.llmModel
}

func (cr *CustomRunner) GetTokenUsage() int {
if cr.llmModel == nil {
return 0
Expand Down Expand Up @@ -383,8 +408,9 @@ func (cr *CustomRunner) Execute(ctx context.Context, userID, sessionID, prompt s
},
}

runCtx := context.WithValue(ctx, "session_id", sessionID)
runConfig := runner.WithStateDelta(nil)
events := cr.adkRunner.Run(ctx, userID, sessionID, userMsg, agent.RunConfig{
events := cr.adkRunner.Run(runCtx, userID, sessionID, userMsg, agent.RunConfig{
StreamingMode: agent.StreamingModeSSE,
}, runConfig)

Expand Down Expand Up @@ -1018,9 +1044,58 @@ func rollbackPendingEdits() {
pendingEditSnapshots.snapshots = make(map[string]string)
}

// commitPendingEdits clears all snapshots after a successful turn.
// UndoGroup tracks pre-edit file states for a single conversational turn.
type UndoGroup struct {
Snapshots map[string]string
}

// UndoHistoryManager maintains a session-wide history of UndoGroups.
type UndoHistoryManager struct {
mu sync.Mutex
history []UndoGroup
}

func (u *UndoHistoryManager) Push(group UndoGroup) {
u.mu.Lock()
defer u.mu.Unlock()
u.history = append(u.history, group)
}

func (u *UndoHistoryManager) PopAndUndo() (int, error) {
u.mu.Lock()
defer u.mu.Unlock()

if len(u.history) == 0 {
return 0, fmt.Errorf("no changes to undo")
}

last := u.history[len(u.history)-1]
u.history = u.history[:len(u.history)-1]

count := 0
for path, content := range last.Snapshots {
if content == "" {
_ = os.Remove(path)
} else {
_ = os.WriteFile(path, []byte(content), 0644)
}
count++
}

return count, nil
}

var GlobalUndoManager = &UndoHistoryManager{}

// commitPendingEdits pushes the current turn's snapshots onto the GlobalUndoManager and clears active pending snapshots.
func commitPendingEdits() {
pendingEditSnapshots.mu.Lock()
defer pendingEditSnapshots.mu.Unlock()

if len(pendingEditSnapshots.snapshots) > 0 {
GlobalUndoManager.Push(UndoGroup{
Snapshots: pendingEditSnapshots.snapshots,
})
}
pendingEditSnapshots.snapshots = make(map[string]string)
}
Loading