Skip to content

Commit b9d9d07

Browse files
authored
Merge pull request #8 from initializ/core/memory
Add long-term memory system with hybrid search
2 parents 1533a15 + dace691 commit b9d9d07

35 files changed

Lines changed: 4454 additions & 79 deletions

forge-cli/channels/router.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,15 @@ func (r *Router) Handler() channels.EventHandler {
3737
// forwardToA2A sends a tasks/send JSON-RPC request to the A2A server and
3838
// extracts the agent's response message from the returned task.
3939
func (r *Router) forwardToA2A(ctx context.Context, event *channels.ChannelEvent) (*a2a.Message, error) {
40-
taskID := fmt.Sprintf("%s-%s-%d", event.Channel, event.WorkspaceID, time.Now().UnixMilli())
40+
// Build a stable task ID so all messages in the same conversation share
41+
// one session. Use thread ID when available (threaded replies), otherwise
42+
// fall back to channel + workspace + user for DM-style conversations.
43+
var taskID string
44+
if event.ThreadID != "" {
45+
taskID = fmt.Sprintf("%s-%s-%s", event.Channel, event.WorkspaceID, event.ThreadID)
46+
} else {
47+
taskID = fmt.Sprintf("%s-%s-%s", event.Channel, event.WorkspaceID, event.UserID)
48+
}
4149

4250
params := a2a.SendTaskParams{
4351
ID: taskID,

forge-cli/runtime/runner.go

Lines changed: 225 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"path/filepath"
1010
"strings"
11+
"time"
1112

1213
"github.com/initializ/forge/forge-cli/server"
1314
cliskills "github.com/initializ/forge/forge-cli/skills"
@@ -17,6 +18,7 @@ import (
1718
"github.com/initializ/forge/forge-core/llm"
1819
"github.com/initializ/forge/forge-core/llm/oauth"
1920
"github.com/initializ/forge/forge-core/llm/providers"
21+
"github.com/initializ/forge/forge-core/memory"
2022
coreruntime "github.com/initializ/forge/forge-core/runtime"
2123
"github.com/initializ/forge/forge-core/tools"
2224
"github.com/initializ/forge/forge-core/tools/builtins"
@@ -165,12 +167,71 @@ func (r *Runner) Run(ctx context.Context) error {
165167
hooks := coreruntime.NewHookRegistry()
166168
r.registerLoggingHooks(hooks)
167169

168-
executor = coreruntime.NewLLMExecutor(coreruntime.LLMExecutorConfig{
170+
// Compute model-aware character budget.
171+
charBudget := r.cfg.Config.Memory.CharBudget
172+
if charBudget == 0 {
173+
charBudget = coreruntime.ContextBudgetForModel(mc.Client.Model)
174+
}
175+
176+
execCfg := coreruntime.LLMExecutorConfig{
169177
Client: llmClient,
170178
Tools: reg,
171179
Hooks: hooks,
172180
SystemPrompt: fmt.Sprintf("You are %s, an AI agent.", r.cfg.Config.AgentID),
173-
})
181+
Logger: r.logger,
182+
ModelName: mc.Client.Model,
183+
CharBudget: charBudget,
184+
}
185+
186+
// Initialize memory persistence (enabled by default).
187+
// Disable via FORGE_MEMORY_PERSISTENCE=false or memory.persistence: false in forge.yaml.
188+
memPersistence := true
189+
if r.cfg.Config.Memory.Persistence != nil {
190+
memPersistence = *r.cfg.Config.Memory.Persistence
191+
}
192+
if os.Getenv("FORGE_MEMORY_PERSISTENCE") == "false" {
193+
memPersistence = false
194+
}
195+
if memPersistence {
196+
sessDir := r.cfg.Config.Memory.SessionsDir
197+
if sessDir == "" {
198+
sessDir = filepath.Join(r.cfg.WorkDir, ".forge", "sessions")
199+
}
200+
memStore, storeErr := coreruntime.NewMemoryStore(sessDir)
201+
if storeErr != nil {
202+
r.logger.Warn("failed to create memory store, persistence disabled", map[string]any{
203+
"error": storeErr.Error(),
204+
})
205+
} else {
206+
// Clean up old sessions on startup (7-day TTL).
207+
deleted, _ := memStore.Cleanup(7 * 24 * time.Hour)
208+
if deleted > 0 {
209+
r.logger.Info("cleaned up old sessions", map[string]any{"deleted": deleted})
210+
}
211+
212+
compactor := coreruntime.NewCompactor(coreruntime.CompactorConfig{
213+
Client: llmClient,
214+
Store: memStore,
215+
Logger: r.logger,
216+
CharBudget: charBudget,
217+
TriggerRatio: r.cfg.Config.Memory.TriggerRatio,
218+
})
219+
220+
execCfg.Store = memStore
221+
execCfg.Compactor = compactor
222+
r.logger.Info("memory persistence enabled", map[string]any{
223+
"sessions_dir": sessDir,
224+
})
225+
}
226+
}
227+
228+
// Initialize long-term memory if enabled.
229+
memMgr := r.initLongTermMemory(ctx, mc, reg, execCfg.Compactor)
230+
if memMgr != nil {
231+
defer memMgr.Close() //nolint:errcheck
232+
}
233+
234+
executor = coreruntime.NewLLMExecutor(execCfg)
174235

175236
r.logger.Info("using LLM executor", map[string]any{
176237
"provider": mc.Provider,
@@ -248,11 +309,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
248309

249310
r.logger.Info("tasks/send", map[string]any{"task_id": params.ID})
250311

251-
// Create task in submitted state
252-
task := &a2a.Task{
253-
ID: params.ID,
254-
Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted},
312+
// Load existing task to preserve conversation history, or create new.
313+
task := store.Get(params.ID)
314+
if task == nil {
315+
task = &a2a.Task{ID: params.ID}
255316
}
317+
task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted}
256318
store.Put(task)
257319

258320
// Guardrail check inbound
@@ -268,9 +330,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
268330
return a2a.NewResponse(id, task)
269331
}
270332

333+
// Append inbound user message to task history.
334+
task.History = append(task.History, params.Message)
335+
271336
// Update to working
272-
store.UpdateStatus(params.ID, a2a.TaskStatus{State: a2a.TaskStateWorking})
273337
task.Status = a2a.TaskStatus{State: a2a.TaskStateWorking}
338+
store.Put(task)
274339

275340
// Execute via executor
276341
respMsg, err := executor.Execute(ctx, task, &params.Message)
@@ -302,6 +367,11 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
302367
}
303368
}
304369

370+
// Append agent response to task history.
371+
if respMsg != nil {
372+
task.History = append(task.History, *respMsg)
373+
}
374+
305375
// Build completed task
306376
task.Status = a2a.TaskStatus{
307377
State: a2a.TaskStateCompleted,
@@ -330,11 +400,12 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
330400

331401
r.logger.Info("tasks/sendSubscribe", map[string]any{"task_id": params.ID})
332402

333-
// Create task
334-
task := &a2a.Task{
335-
ID: params.ID,
336-
Status: a2a.TaskStatus{State: a2a.TaskStateSubmitted},
403+
// Load existing task to preserve conversation history, or create new.
404+
task := store.Get(params.ID)
405+
if task == nil {
406+
task = &a2a.Task{ID: params.ID}
337407
}
408+
task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted}
338409
store.Put(task)
339410
server.WriteSSEEvent(w, flusher, "status", task) //nolint:errcheck
340411

@@ -352,6 +423,9 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
352423
return
353424
}
354425

426+
// Append inbound user message to task history.
427+
task.History = append(task.History, params.Message)
428+
355429
// Update to working
356430
task.Status = a2a.TaskStatus{State: a2a.TaskStateWorking}
357431
store.Put(task)
@@ -387,6 +461,9 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent
387461
return
388462
}
389463

464+
// Append agent response to task history.
465+
task.History = append(task.History, *respMsg)
466+
390467
// Build completed result
391468
task.Status = a2a.TaskStatus{
392469
State: a2a.TaskStateCompleted,
@@ -721,6 +798,143 @@ func envFromOS() map[string]string {
721798
return env
722799
}
723800

801+
// initLongTermMemory sets up the long-term memory system if enabled.
802+
// It resolves the embedder, creates a memory.Manager, registers memory tools,
803+
// and starts background indexing. Returns the Manager (caller must Close) or nil.
804+
func (r *Runner) initLongTermMemory(ctx context.Context, mc *coreruntime.ModelConfig, reg *tools.Registry, compactor *coreruntime.Compactor) *memory.Manager {
805+
// Check if long-term memory is enabled.
806+
enabled := false
807+
if r.cfg.Config.Memory.LongTerm != nil {
808+
enabled = *r.cfg.Config.Memory.LongTerm
809+
}
810+
if os.Getenv("FORGE_MEMORY_LONG_TERM") == "true" {
811+
enabled = true
812+
}
813+
if !enabled {
814+
return nil
815+
}
816+
817+
memDir := r.cfg.Config.Memory.MemoryDir
818+
if memDir == "" {
819+
memDir = filepath.Join(r.cfg.WorkDir, ".forge", "memory")
820+
}
821+
822+
// Resolve embedder.
823+
embedder := r.resolveEmbedder(mc)
824+
825+
// Build search config from forge.yaml.
826+
searchCfg := memory.DefaultSearchConfig()
827+
if r.cfg.Config.Memory.VectorWeight > 0 {
828+
searchCfg.VectorWeight = r.cfg.Config.Memory.VectorWeight
829+
}
830+
if r.cfg.Config.Memory.KeywordWeight > 0 {
831+
searchCfg.KeywordWeight = r.cfg.Config.Memory.KeywordWeight
832+
}
833+
if r.cfg.Config.Memory.DecayHalfLifeDays > 0 {
834+
searchCfg.DecayHalfLife = time.Duration(r.cfg.Config.Memory.DecayHalfLifeDays) * 24 * time.Hour
835+
}
836+
837+
mgr, err := memory.NewManager(memory.ManagerConfig{
838+
MemoryDir: memDir,
839+
Embedder: embedder,
840+
Logger: r.logger,
841+
SearchConfig: searchCfg,
842+
})
843+
if err != nil {
844+
r.logger.Warn("failed to create memory manager, long-term memory disabled", map[string]any{
845+
"error": err.Error(),
846+
})
847+
return nil
848+
}
849+
850+
// Register memory tools.
851+
if regErr := reg.Register(builtins.NewMemorySearchTool(mgr)); regErr != nil {
852+
r.logger.Warn("failed to register memory_search tool", map[string]any{"error": regErr.Error()})
853+
}
854+
if regErr := reg.Register(builtins.NewMemoryGetTool(mgr)); regErr != nil {
855+
r.logger.Warn("failed to register memory_get tool", map[string]any{"error": regErr.Error()})
856+
}
857+
858+
// Wire memory flusher into compactor (if compactor exists).
859+
if compactor != nil {
860+
compactor.SetMemoryFlusher(mgr)
861+
}
862+
863+
// Index memory files at startup in background.
864+
go func() {
865+
if idxErr := mgr.IndexAll(ctx); idxErr != nil {
866+
r.logger.Warn("background memory indexing failed", map[string]any{"error": idxErr.Error()})
867+
}
868+
}()
869+
870+
mode := "keyword-only"
871+
if embedder != nil {
872+
mode = "vector+keyword"
873+
}
874+
r.logger.Info("long-term memory enabled", map[string]any{
875+
"memory_dir": memDir,
876+
"mode": mode,
877+
})
878+
879+
return mgr
880+
}
881+
882+
// resolveEmbedder creates an embedder from config or auto-detection.
883+
// Returns nil if no embedder can be created (keyword-only mode).
884+
func (r *Runner) resolveEmbedder(mc *coreruntime.ModelConfig) llm.Embedder {
885+
// Resolution order: config override → env → primary LLM provider.
886+
embProvider := r.cfg.Config.Memory.EmbeddingProvider
887+
if embProvider == "" {
888+
embProvider = os.Getenv("FORGE_EMBEDDING_PROVIDER")
889+
}
890+
if embProvider == "" {
891+
embProvider = mc.Provider
892+
}
893+
894+
// Anthropic has no embedding API — skip.
895+
if embProvider == "anthropic" {
896+
r.logger.Info("primary provider is anthropic (no embedding API), trying fallbacks for embeddings", nil)
897+
// Try fallback providers.
898+
for _, fb := range mc.Fallbacks {
899+
if fb.Provider != "anthropic" {
900+
embProvider = fb.Provider
901+
break
902+
}
903+
}
904+
if embProvider == "anthropic" {
905+
r.logger.Info("no embedding-capable provider found, using keyword-only search", nil)
906+
return nil
907+
}
908+
}
909+
910+
cfg := providers.OpenAIEmbedderConfig{
911+
APIKey: mc.Client.APIKey,
912+
Model: r.cfg.Config.Memory.EmbeddingModel,
913+
}
914+
915+
// Use the correct API key for the embedding provider if it differs from primary.
916+
if embProvider != mc.Provider {
917+
for _, fb := range mc.Fallbacks {
918+
if fb.Provider == embProvider {
919+
cfg.APIKey = fb.Client.APIKey
920+
cfg.BaseURL = fb.Client.BaseURL
921+
break
922+
}
923+
}
924+
}
925+
926+
embedder, err := providers.NewEmbedder(embProvider, cfg)
927+
if err != nil {
928+
r.logger.Warn("failed to create embedder, using keyword-only search", map[string]any{
929+
"provider": embProvider,
930+
"error": err.Error(),
931+
})
932+
return nil
933+
}
934+
935+
return embedder
936+
}
937+
724938
func defaultStr(s, def string) string {
725939
if s != "" {
726940
return s

forge-core/channels/plugin.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type ChannelEvent struct {
4646
WorkspaceID string `json:"workspace_id"`
4747
UserID string `json:"user_id"`
4848
ThreadID string `json:"thread_id,omitempty"`
49+
MessageID string `json:"message_id,omitempty"` // per-message ID for reply targeting
4950
Message string `json:"message"`
5051
Attachments []Attachment `json:"attachments,omitempty"`
5152
Raw json.RawMessage `json:"raw,omitempty"`

forge-core/llm/embedder.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package llm

import "context"

// EmbeddingRequest is a provider-agnostic request to generate embeddings
// for one or more texts in a single call.
type EmbeddingRequest struct {
	Texts []string // texts to embed
	Model string   // optional model override; empty means the provider's default
}

// EmbeddingResponse is a provider-agnostic embedding response.
// Embeddings presumably holds one vector per entry in the request's Texts,
// in the same order — confirm against provider implementations.
type EmbeddingResponse struct {
	Embeddings [][]float32
	Model      string    // model actually used by the provider
	Usage      UsageInfo // token accounting reported by the provider
}

// Embedder generates vector embeddings from text.
type Embedder interface {
	// Embed produces embeddings for the given texts.
	Embed(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error)
	// Dimensions returns the dimensionality of the embedding vectors.
	Dimensions() int
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package providers

import (
	"fmt"
	"strings"

	"github.com/initializ/forge/forge-core/llm"
)

// NewEmbedder creates an Embedder for the specified provider.
// Supported providers: "openai", "gemini", "ollama" (matched
// case-insensitively, with surrounding whitespace ignored, since the
// value may come from env vars or hand-edited YAML).
// Returns an error for "anthropic" (no embedding API) and for any
// unrecognized provider name.
func NewEmbedder(provider string, cfg OpenAIEmbedderConfig) (llm.Embedder, error) {
	switch strings.ToLower(strings.TrimSpace(provider)) {
	case "openai":
		return NewOpenAIEmbedder(cfg), nil
	case "gemini":
		// Gemini exposes an OpenAI-compatible surface; point the OpenAI
		// embedder at it unless the caller supplied an explicit base URL.
		if cfg.BaseURL == "" {
			cfg.BaseURL = "https://generativelanguage.googleapis.com/v1beta/openai"
		}
		return NewOpenAIEmbedder(cfg), nil
	case "ollama":
		return NewOllamaEmbedder(cfg), nil
	case "anthropic":
		return nil, fmt.Errorf("anthropic does not provide an embedding API; configure an alternative embedding provider")
	default:
		return nil, fmt.Errorf("unknown embedding provider: %q", provider)
	}
}

0 commit comments

Comments
 (0)