diff --git a/internal/db/duckdb.go b/internal/db/duckdb.go index 8b18eca..43220b1 100644 --- a/internal/db/duckdb.go +++ b/internal/db/duckdb.go @@ -81,6 +81,58 @@ func (s *Store) initialize() error { return fmt.Errorf("failed to execute schema: %w", err) } + // Create knowledge graph tables + graphSchema := ` + CREATE TABLE IF NOT EXISTS entities ( + id VARCHAR PRIMARY KEY, + canonical_name TEXT NOT NULL, + entity_type VARCHAR, + embedding FLOAT[768], + group_id VARCHAR DEFAULT 'default', + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + metadata JSON + ); + CREATE INDEX IF NOT EXISTS idx_entities_group_id ON entities (group_id); + CREATE INDEX IF NOT EXISTS idx_entities_canonical_name ON entities (canonical_name); + CREATE UNIQUE INDEX IF NOT EXISTS idx_entities_name_group ON entities (LOWER(canonical_name), group_id); + + CREATE TABLE IF NOT EXISTS knowledge ( + id VARCHAR PRIMARY KEY, + subject_entity_id VARCHAR NOT NULL, + predicate VARCHAR NOT NULL, + object_entity_id VARCHAR NOT NULL, + source_episode_id VARCHAR, + source VARCHAR NOT NULL, + group_id VARCHAR DEFAULT 'default', + embedding FLOAT[768], + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + expired_at TIMESTAMPTZ, + confidence FLOAT DEFAULT 1.0, + verified BOOLEAN DEFAULT FALSE, + metadata JSON + ); + CREATE INDEX IF NOT EXISTS idx_knowledge_subject ON knowledge (subject_entity_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_object ON knowledge (object_entity_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_source_episode ON knowledge (source_episode_id); + CREATE INDEX IF NOT EXISTS idx_knowledge_group_id ON knowledge (group_id); + + CREATE TABLE IF NOT EXISTS episode_links ( + id VARCHAR PRIMARY KEY, + source_episode_id VARCHAR NOT NULL, + target_episode_id VARCHAR NOT NULL, + relationship VARCHAR NOT NULL, + via_entity_id VARCHAR, + weight FLOAT DEFAULT 1.0, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + UNIQUE (source_episode_id, target_episode_id, relationship) + ); + CREATE INDEX 
IF NOT EXISTS idx_episode_links_source ON episode_links (source_episode_id); + CREATE INDEX IF NOT EXISTS idx_episode_links_target ON episode_links (target_episode_id); + ` + if _, err := s.db.Exec(graphSchema); err != nil { + return fmt.Errorf("failed to create graph schema: %w", err) + } + // Run migrations for existing databases if err := s.migrate(); err != nil { return fmt.Errorf("failed to run migrations: %w", err) @@ -739,6 +791,314 @@ func (s *Store) ensureFTSIndex() error { return nil } +// InsertEntity adds a new entity to the store, or returns the existing entity ID if a +// semantically matching entity already exists (cosine similarity >= threshold). +func (s *Store) InsertEntity(ctx context.Context, entity *models.Entity, similarityThreshold float64) (*models.Entity, error) { + if similarityThreshold <= 0 { + similarityThreshold = 0.88 + } + + groupFilter := "default" + if entity.GroupID != "" { + groupFilter = entity.GroupID + } + + // Entity resolution uses two string-based strategies: + // 1. Case-insensitive exact match: "Mike" = "mike" + // 2. Normalized match (strip non-alphanumeric, lowercase): "OscillateLabs" = "Oscillate Labs" + // + // Embedding-based resolution is NOT used for entity names. nomic-embed-text + // produces degenerate embeddings for short texts — "Mike" vs "DuckDB" scores + // 0.9594, same as "Mike" vs "Oscillate Labs". The embedding space is unusable + // for distinguishing entity names. Embeddings are still stored on entities for + // future use with longer-context models. + var existingID, existingName, existingType string + var existingCreatedAt time.Time + + // First pass: case-insensitive exact name match + err := s.db.QueryRowContext(ctx, ` + SELECT id, canonical_name, entity_type, created_at + FROM entities + WHERE LOWER(canonical_name) = LOWER(?) + AND group_id = ? 
+ LIMIT 1 + `, entity.CanonicalName, groupFilter).Scan(&existingID, &existingName, &existingType, &existingCreatedAt) + + if err == nil { + return &models.Entity{ + ID: existingID, + CanonicalName: existingName, + EntityType: existingType, + GroupID: groupFilter, + CreatedAt: existingCreatedAt, + }, nil + } + if err != sql.ErrNoRows { + return nil, fmt.Errorf("failed to check entity name match: %w", err) + } + + // Second pass: normalized match — strip non-alphanumeric chars and compare. + // Catches "OscillateLabs" vs "Oscillate Labs" vs "Oscillate Labs LLC" + normalizedName := normalizeEntityName(entity.CanonicalName) + rows, err := s.db.QueryContext(ctx, ` + SELECT id, canonical_name, entity_type, created_at + FROM entities + WHERE group_id = ? + `, groupFilter) + if err != nil { + return nil, fmt.Errorf("failed to query entities for normalized match: %w", err) + } + defer rows.Close() + + for rows.Next() { + var candID, candName, candType string + var candCreatedAt time.Time + if err := rows.Scan(&candID, &candName, &candType, &candCreatedAt); err != nil { + continue + } + if normalizeEntityName(candName) == normalizedName { + return &models.Entity{ + ID: candID, + CanonicalName: candName, + EntityType: candType, + GroupID: groupFilter, + CreatedAt: candCreatedAt, + }, nil + } + } + + // No match found — insert new entity + if entity.ID == "" { + entity.ID = uuid.New().String() + } + if entity.CreatedAt.IsZero() { + entity.CreatedAt = time.Now() + } + if entity.GroupID == "" { + entity.GroupID = "default" + } + + var embeddingJSON interface{} + if len(entity.Embedding) > 0 { + data, _ := json.Marshal(entity.Embedding) + embeddingJSON = string(data) + } + + var metadataJSON interface{} + if entity.Metadata != "" { + metadataJSON = entity.Metadata + } + + _, err = s.db.ExecContext(ctx, ` + INSERT INTO entities (id, canonical_name, entity_type, embedding, group_id, created_at, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ `, entity.ID, entity.CanonicalName, entity.EntityType, embeddingJSON, entity.GroupID, entity.CreatedAt, metadataJSON) + if err != nil { + return nil, fmt.Errorf("failed to insert entity: %w", err) + } + + return entity, nil +} + +// GetEntity retrieves a single entity by ID +func (s *Store) GetEntity(ctx context.Context, id string) (*models.Entity, error) { + var entity models.Entity + var metadataRaw interface{} + err := s.db.QueryRowContext(ctx, ` + SELECT id, canonical_name, entity_type, group_id, created_at, metadata + FROM entities WHERE id = ? + `, id).Scan(&entity.ID, &entity.CanonicalName, &entity.EntityType, &entity.GroupID, &entity.CreatedAt, &metadataRaw) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("entity not found: %s", id) + } + if err != nil { + return nil, fmt.Errorf("failed to get entity: %w", err) + } + if metadataRaw != nil { + switch v := metadataRaw.(type) { + case map[string]interface{}: + if data, err := json.Marshal(v); err == nil { + entity.Metadata = string(data) + } + case string: + entity.Metadata = v + } + } + return &entity, nil +} + +// InsertKnowledgeTriple adds a knowledge triple to the store +func (s *Store) InsertKnowledgeTriple(ctx context.Context, triple *models.KnowledgeTriple) error { + if triple.ID == "" { + triple.ID = uuid.New().String() + } + if triple.CreatedAt.IsZero() { + triple.CreatedAt = time.Now() + } + if triple.GroupID == "" { + triple.GroupID = "default" + } + if triple.Confidence == 0 { + triple.Confidence = 1.0 + } + + var embeddingJSON interface{} + if len(triple.Embedding) > 0 { + data, _ := json.Marshal(triple.Embedding) + embeddingJSON = string(data) + } + + var metadataJSON interface{} + if triple.Metadata != "" { + metadataJSON = triple.Metadata + } + + _, err := s.db.ExecContext(ctx, ` + INSERT INTO knowledge ( + id, subject_entity_id, predicate, object_entity_id, + source_episode_id, source, group_id, embedding, + created_at, expired_at, confidence, verified, metadata + ) VALUES (?, ?, ?, ?, 
?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + triple.ID, triple.SubjectEntityID, triple.Predicate, triple.ObjectEntityID, + triple.SourceEpisodeID, triple.Source, triple.GroupID, embeddingJSON, + triple.CreatedAt, triple.ExpiredAt, triple.Confidence, triple.Verified, metadataJSON, + ) + if err != nil { + return fmt.Errorf("failed to insert knowledge triple: %w", err) + } + return nil +} + +// SearchKnowledge finds knowledge triples matching the given query embedding +func (s *Store) SearchKnowledge(ctx context.Context, queryEmbedding []float32, groupID string, maxResults int, minSimilarity float64) ([]models.KnowledgeTriple, error) { + if maxResults <= 0 { + maxResults = 10 + } + if minSimilarity <= 0 { + minSimilarity = 0.35 + } + + embJSON, err := json.Marshal(queryEmbedding) + if err != nil { + return nil, fmt.Errorf("failed to marshal query embedding: %w", err) + } + + query := fmt.Sprintf(` + SELECT k.id, k.subject_entity_id, k.predicate, k.object_entity_id, + k.source_episode_id, k.source, k.group_id, k.created_at, + k.expired_at, k.confidence, k.verified, k.metadata, + se.canonical_name AS subject_name, oe.canonical_name AS object_name, + array_cosine_similarity(k.embedding, %s::FLOAT[768]) AS similarity + FROM knowledge k + JOIN entities se ON k.subject_entity_id = se.id + JOIN entities oe ON k.object_entity_id = oe.id + WHERE k.embedding IS NOT NULL + AND (k.expired_at IS NULL OR k.expired_at > CURRENT_TIMESTAMP) + AND array_cosine_similarity(k.embedding, %s::FLOAT[768]) >= %f + `, string(embJSON), string(embJSON), minSimilarity) + + if groupID != "" { + query += fmt.Sprintf(" AND k.group_id = '%s'", groupID) + } + + query += fmt.Sprintf(" ORDER BY similarity DESC LIMIT %d", maxResults) + + rows, err := s.db.QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("failed to search knowledge: %w", err) + } + defer rows.Close() + + var triples []models.KnowledgeTriple + for rows.Next() { + var t models.KnowledgeTriple + var metadataRaw interface{} + var 
similarity sql.NullFloat64 + err := rows.Scan( + &t.ID, &t.SubjectEntityID, &t.Predicate, &t.ObjectEntityID, + &t.SourceEpisodeID, &t.Source, &t.GroupID, &t.CreatedAt, + &t.ExpiredAt, &t.Confidence, &t.Verified, &metadataRaw, + &t.SubjectName, &t.ObjectName, &similarity, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan knowledge triple: %w", err) + } + if metadataRaw != nil { + switch v := metadataRaw.(type) { + case map[string]interface{}: + if data, err := json.Marshal(v); err == nil { + t.Metadata = string(data) + } + case string: + t.Metadata = v + } + } + triples = append(triples, t) + } + return triples, rows.Err() +} + +// InsertEpisodeLink adds a link between two episodes +func (s *Store) InsertEpisodeLink(ctx context.Context, link *models.EpisodeLink) error { + if link.ID == "" { + link.ID = uuid.New().String() + } + if link.CreatedAt.IsZero() { + link.CreatedAt = time.Now() + } + if link.Weight == 0 { + link.Weight = 1.0 + } + + var viaEntityID interface{} + if link.ViaEntityID != "" { + viaEntityID = link.ViaEntityID + } + + // UNIQUE constraint on (source_episode_id, target_episode_id, relationship) + // handles deduplication at the database level. INSERT OR IGNORE silently + // skips duplicates without a race-prone check-then-insert pattern. + _, err := s.db.ExecContext(ctx, ` + INSERT OR IGNORE INTO episode_links (id, source_episode_id, target_episode_id, relationship, via_entity_id, weight, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) 
+ `, link.ID, link.SourceEpisodeID, link.TargetEpisodeID, link.Relationship, viaEntityID, link.Weight, link.CreatedAt) + if err != nil { + return fmt.Errorf("failed to insert episode link: %w", err) + } + return nil +} + +// GetEpisodeLinks retrieves all links for a given episode (both directions) +func (s *Store) GetEpisodeLinks(ctx context.Context, episodeID string) ([]models.EpisodeLink, error) { + rows, err := s.db.QueryContext(ctx, ` + SELECT id, source_episode_id, target_episode_id, relationship, via_entity_id, weight, created_at + FROM episode_links + WHERE source_episode_id = ? OR target_episode_id = ? + ORDER BY created_at DESC + `, episodeID, episodeID) + if err != nil { + return nil, fmt.Errorf("failed to get episode links: %w", err) + } + defer rows.Close() + + var links []models.EpisodeLink + for rows.Next() { + var link models.EpisodeLink + var viaEntityID sql.NullString + err := rows.Scan(&link.ID, &link.SourceEpisodeID, &link.TargetEpisodeID, + &link.Relationship, &viaEntityID, &link.Weight, &link.CreatedAt) + if err != nil { + return nil, fmt.Errorf("failed to scan episode link: %w", err) + } + if viaEntityID.Valid { + link.ViaEntityID = viaEntityID.String + } + links = append(links, link) + } + return links, rows.Err() +} + // Close closes the database connection func (s *Store) Close() error { return s.db.Close() @@ -772,6 +1132,19 @@ func sanitizeFTSQuery(s string) string { return strings.ReplaceAll(s, "'", "''") } +// normalizeEntityName strips non-alphanumeric characters and lowercases for +// fuzzy entity matching. "OscillateLabs" and "Oscillate Labs" both normalize +// to "oscillatelabs". 
// normalizeEntityName produces the key used for fuzzy entity matching: the
// name is lowercased and every rune other than a-z and 0-9 is dropped, so
// "OscillateLabs" and "Oscillate Labs" both become "oscillatelabs". Names
// with no ASCII letters or digits normalize to the empty string.
func normalizeEntityName(name string) string {
	keepASCIIAlnum := func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', '0' <= r && r <= '9':
			return r
		default:
			return -1 // a negative result tells strings.Map to drop the rune
		}
	}
	return strings.Map(keepASCIIAlnum, strings.ToLower(name))
}
strips spaces", func(t *testing.T) { + // Create "OscillateLabs" first + original, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "OscillateLabs", + EntityType: "organization", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + // "Oscillate Labs" should resolve to "OscillateLabs" via normalized match + resolved, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "Oscillate Labs", + EntityType: "organization", + }, 0.88) + if err != nil { + t.Fatalf("Failed to resolve: %v", err) + } + if resolved.ID != original.ID { + t.Errorf("Expected same ID %s, got %s", original.ID, resolved.ID) + } + }) + + t.Run("normalized resolution handles hyphens and dots", func(t *testing.T) { + // Create "HomeAssistant" + original, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "HomeAssistant", + EntityType: "tool", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + // "Home-Assistant" should resolve to "HomeAssistant" (strip hyphen) + resolved, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "Home-Assistant", + EntityType: "tool", + }, 0.88) + if err != nil { + t.Fatalf("Failed to resolve: %v", err) + } + if resolved.ID != original.ID { + t.Errorf("Expected same ID %s, got %s", original.ID, resolved.ID) + } + }) + + t.Run("LLC suffix creates distinct entity", func(t *testing.T) { + // "Oscillate Labs LLC" normalizes to "oscillatelabsllc" while + // "OscillateLabs" normalizes to "oscillatelabs" — these are + // genuinely different normalized forms, so distinct entities. 
+ entity, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "Oscillate Labs LLC", + EntityType: "organization", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + if entity.CanonicalName == "OscillateLabs" { + t.Error("LLC variant should be a distinct entity from base name") + } + }) + + t.Run("distinct entities stay separate", func(t *testing.T) { + python, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "Python", + EntityType: "tool", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + golang, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "Go", + EntityType: "tool", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + if python.ID == golang.ID { + t.Error("Python and Go should be distinct entities") + } + }) + + t.Run("different groups are independent", func(t *testing.T) { + e1, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "SharedName", + EntityType: "concept", + GroupID: "group-a", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + e2, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "SharedName", + EntityType: "concept", + GroupID: "group-b", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + if e1.ID == e2.ID { + t.Error("Same name in different groups should be distinct entities") + } + }) + + t.Run("GetEntity retrieves by ID", func(t *testing.T) { + entity, err := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "RetrieveMe", + EntityType: "person", + }, 0.88) + if err != nil { + t.Fatalf("Failed to insert: %v", err) + } + + retrieved, err := store.GetEntity(ctx, entity.ID) + if err != nil { + t.Fatalf("Failed to get entity: %v", err) + } + if retrieved.CanonicalName != "RetrieveMe" { + t.Errorf("Expected 'RetrieveMe', got %q", retrieved.CanonicalName) + } + }) +} + +func TestInsertKnowledgeTriple(t *testing.T) { + store := setupTestStore(t) + defer 
store.Close() + ctx := context.Background() + + // Create entities first + subject, _ := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "Engram", + EntityType: "project", + }, 0.88) + object, _ := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "DuckDB", + EntityType: "tool", + }, 0.88) + + t.Run("inserts triple with defaults", func(t *testing.T) { + triple := &models.KnowledgeTriple{ + SubjectEntityID: subject.ID, + Predicate: "uses", + ObjectEntityID: object.ID, + Source: "test", + } + err := store.InsertKnowledgeTriple(ctx, triple) + if err != nil { + t.Fatalf("Failed to insert triple: %v", err) + } + if triple.ID == "" { + t.Error("Triple ID was not generated") + } + if triple.Confidence != 1.0 { + t.Errorf("Expected default confidence 1.0, got %f", triple.Confidence) + } + }) + + t.Run("inserts triple with custom confidence", func(t *testing.T) { + triple := &models.KnowledgeTriple{ + SubjectEntityID: subject.ID, + Predicate: "depends_on", + ObjectEntityID: object.ID, + Source: "dreamer/qwen3:8b", + Confidence: 0.72, + } + err := store.InsertKnowledgeTriple(ctx, triple) + if err != nil { + t.Fatalf("Failed to insert triple: %v", err) + } + if triple.Confidence != 0.72 { + t.Errorf("Expected confidence 0.72, got %f", triple.Confidence) + } + }) +} + +func TestInsertEpisodeLink(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + ctx := context.Background() + + // Create two episodes + ep1 := &models.Episode{Content: "Episode about DuckDB", Source: "test"} + ep2 := &models.Episode{Content: "Another episode about DuckDB", Source: "test"} + store.InsertEpisode(ctx, ep1) + store.InsertEpisode(ctx, ep2) + + t.Run("creates link between episodes", func(t *testing.T) { + link := &models.EpisodeLink{ + SourceEpisodeID: ep1.ID, + TargetEpisodeID: ep2.ID, + Relationship: "same_entity", + } + err := store.InsertEpisodeLink(ctx, link) + if err != nil { + t.Fatalf("Failed to insert link: %v", err) + } + if link.ID == "" { + t.Error("Link ID 
was not generated") + } + if link.Weight != 1.0 { + t.Errorf("Expected default weight 1.0, got %f", link.Weight) + } + }) + + t.Run("deduplicates existing links", func(t *testing.T) { + // Insert same link again — should not error + link := &models.EpisodeLink{ + SourceEpisodeID: ep1.ID, + TargetEpisodeID: ep2.ID, + Relationship: "same_entity", + } + err := store.InsertEpisodeLink(ctx, link) + if err != nil { + t.Fatalf("Duplicate link insert should not error: %v", err) + } + + // Verify only one link exists + links, err := store.GetEpisodeLinks(ctx, ep1.ID) + if err != nil { + t.Fatalf("Failed to get links: %v", err) + } + count := 0 + for _, l := range links { + if l.Relationship == "same_entity" { + count++ + } + } + if count != 1 { + t.Errorf("Expected 1 same_entity link, got %d", count) + } + }) + + t.Run("GetEpisodeLinks returns links in both directions", func(t *testing.T) { + // Query from target side + links, err := store.GetEpisodeLinks(ctx, ep2.ID) + if err != nil { + t.Fatalf("Failed to get links: %v", err) + } + if len(links) == 0 { + t.Error("Expected to find link from target side") + } + }) + + t.Run("link with via_entity_id", func(t *testing.T) { + entity, _ := store.InsertEntity(ctx, &models.Entity{ + CanonicalName: "DuckDB", + EntityType: "tool", + }, 0.88) + + link := &models.EpisodeLink{ + SourceEpisodeID: ep1.ID, + TargetEpisodeID: ep2.ID, + Relationship: "elaborates", + ViaEntityID: entity.ID, + } + err := store.InsertEpisodeLink(ctx, link) + if err != nil { + t.Fatalf("Failed to insert link with via_entity_id: %v", err) + } + + links, _ := store.GetEpisodeLinks(ctx, ep1.ID) + found := false + for _, l := range links { + if l.Relationship == "elaborates" && l.ViaEntityID == entity.ID { + found = true + } + } + if !found { + t.Error("Expected to find link with via_entity_id") + } + }) +} + +func TestNormalizeEntityName(t *testing.T) { + tests := []struct { + input, expected string + }{ + {"OscillateLabs", "oscillatelabs"}, + {"Oscillate Labs", 
"oscillatelabs"}, + {"Oscillate Labs, LLC", "oscillatelabsllc"}, + {"DuckDB", "duckdb"}, + {"duck-db", "duckdb"}, + {"Mike", "mike"}, + {"mike", "mike"}, + {"", ""}, + {"Hello World 123", "helloworld123"}, + } + for _, tt := range tests { + got := normalizeEntityName(tt.input) + if got != tt.expected { + t.Errorf("normalizeEntityName(%q) = %q, want %q", tt.input, got, tt.expected) + } + } +} + +func TestEpisodeLinkUniqueConstraint(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + ctx := context.Background() + + ep1 := &models.Episode{Content: "First", Source: "test"} + ep2 := &models.Episode{Content: "Second", Source: "test"} + store.InsertEpisode(ctx, ep1) + store.InsertEpisode(ctx, ep2) + + // First insert succeeds + err := store.InsertEpisodeLink(ctx, &models.EpisodeLink{ + SourceEpisodeID: ep1.ID, + TargetEpisodeID: ep2.ID, + Relationship: "same_entity", + }) + if err != nil { + t.Fatalf("First link insert failed: %v", err) + } + + // Duplicate should succeed silently (INSERT OR IGNORE) + err = store.InsertEpisodeLink(ctx, &models.EpisodeLink{ + SourceEpisodeID: ep1.ID, + TargetEpisodeID: ep2.ID, + Relationship: "same_entity", + }) + if err != nil { + t.Fatalf("Duplicate link insert should not error: %v", err) + } + + // Different relationship should create a new link + err = store.InsertEpisodeLink(ctx, &models.EpisodeLink{ + SourceEpisodeID: ep1.ID, + TargetEpisodeID: ep2.ID, + Relationship: "elaborates", + }) + if err != nil { + t.Fatalf("Different relationship link failed: %v", err) + } + + links, _ := store.GetEpisodeLinks(ctx, ep1.ID) + if len(links) != 2 { + t.Errorf("Expected 2 links (same_entity + elaborates), got %d", len(links)) + } +} + +func TestGraphTablesCreated(t *testing.T) { + store := setupTestStore(t) + defer store.Close() + + // Verify all three graph tables exist + tables := []string{"entities", "knowledge", "episode_links"} + for _, table := range tables { + var count int + err := store.db.QueryRow("SELECT COUNT(*) FROM 
information_schema.tables WHERE table_name = ?", table).Scan(&count) + if err != nil { + t.Fatalf("Failed to check table %s: %v", table, err) + } + if count == 0 { + t.Errorf("Table %s was not created", table) + } + } +} + func setupTestStore(t *testing.T) *Store { t.Helper() tmpFile := t.TempDir() + "/test.duckdb" diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 8cf8019..6e0c607 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -221,6 +221,84 @@ func (s *Server) registerTools() { }, }, s.handleUpdateEpisode) + // add_knowledge tool + s.mcpServer.AddTool(mcp.Tool{ + Name: "add_knowledge", + Description: "Store a knowledge fact as a subject-predicate-object triple. Triples link back to their source episode for provenance. Entities are automatically resolved — if a semantically matching entity already exists, it will be reused rather than duplicated.\n\nAllowed predicates: owns, works_at, contributes_to, uses, prefers, builds, depends_on, located_in, related_to, part_of, instance_of, created_by, configured_with, deployed_on, communicates_via", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "subject": map[string]interface{}{ + "type": "string", + "description": "The subject entity (e.g., 'Mike', 'Engram project')", + }, + "predicate": map[string]interface{}{ + "type": "string", + "description": "The relationship", + "enum": []string{"owns", "works_at", "contributes_to", "uses", "prefers", "builds", "depends_on", "located_in", "related_to", "part_of", "instance_of", "created_by", "configured_with", "deployed_on", "communicates_via"}, + }, + "object": map[string]interface{}{ + "type": "string", + "description": "The object entity (e.g., 'DuckDB', 'OscillateLabs')", + }, + "subject_type": map[string]interface{}{ + "type": "string", + "description": "Type of the subject entity", + "enum": []string{"person", "project", "tool", "organization", "place", "concept"}, + }, + "object_type": 
map[string]interface{}{ + "type": "string", + "description": "Type of the object entity", + "enum": []string{"person", "project", "tool", "organization", "place", "concept"}, + }, + "source_episode_id": map[string]interface{}{ + "type": "string", + "description": "Episode ID this fact was derived from", + }, + "source": map[string]interface{}{ + "type": "string", + "description": "Source identifier (e.g., 'claude-code/opus-4.6')", + }, + "group_id": map[string]interface{}{ + "type": "string", + "description": "Group namespace", + }, + }, + Required: []string{"subject", "predicate", "object", "source"}, + }, + }, s.handleAddKnowledge) + + // link_episodes tool + s.mcpServer.AddTool(mcp.Tool{ + Name: "link_episodes", + Description: "Create a directional link between two episodes that share entities, topics, or narrative continuity. Duplicate links (same source, target, and relationship) are silently skipped.", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]interface{}{ + "source_episode_id": map[string]interface{}{ + "type": "string", + "description": "ID of the source episode", + }, + "target_episode_id": map[string]interface{}{ + "type": "string", + "description": "ID of the target episode", + }, + "relationship": map[string]interface{}{ + "type": "string", + "description": "Type of link between the episodes", + "enum": []string{"same_entity", "follows_up", "contradicts", "elaborates", "supersedes"}, + }, + "weight": map[string]interface{}{ + "type": "number", + "description": "Link strength 0.0-1.0 (default: 1.0)", + "minimum": 0.0, + "maximum": 1.0, + }, + }, + Required: []string{"source_episode_id", "target_episode_id", "relationship"}, + }, + }, s.handleLinkEpisodes) + // get_status tool s.mcpServer.AddTool(mcp.Tool{ Name: "get_status", @@ -495,6 +573,150 @@ func (s *Server) handleGetStatus(ctx context.Context, request mcp.CallToolReques return mcp.NewToolResultText(string(result)), nil } +func (s *Server) handleAddKnowledge(ctx 
context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var params struct { + Subject string `json:"subject"` + Predicate string `json:"predicate"` + Object string `json:"object"` + SubjectType string `json:"subject_type"` + ObjectType string `json:"object_type"` + SourceEpisodeID string `json:"source_episode_id"` + Source string `json:"source"` + GroupID string `json:"group_id"` + } + + if err := parseParams(request.Params.Arguments, ¶ms); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid parameters: %v", err)), nil + } + + // Validate predicate against controlled vocabulary + validPredicates := map[string]bool{ + "owns": true, "works_at": true, "contributes_to": true, "uses": true, + "prefers": true, "builds": true, "depends_on": true, "located_in": true, + "related_to": true, "part_of": true, "instance_of": true, "created_by": true, + "configured_with": true, "deployed_on": true, "communicates_via": true, + } + if !validPredicates[params.Predicate] { + return mcp.NewToolResultError(fmt.Sprintf("invalid predicate %q: must be one of: owns, works_at, contributes_to, uses, prefers, builds, depends_on, located_in, related_to, part_of, instance_of, created_by, configured_with, deployed_on, communicates_via", params.Predicate)), nil + } + + // Resolve subject entity (embed name, match or create) + embedCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + subjectEmb, err := s.embedder.Generate(embedCtx, params.Subject) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: Failed to generate subject embedding: %v\n", err) + } + + subjectEntity, err := s.store.InsertEntity(ctx, &models.Entity{ + CanonicalName: params.Subject, + EntityType: params.SubjectType, + Embedding: subjectEmb, + GroupID: params.GroupID, + }, 0.88) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to resolve subject entity: %v", err)), nil + } + + // Resolve object entity + embedCtx2, cancel2 := 
context.WithTimeout(context.Background(), 5*time.Second) + defer cancel2() + + objectEmb, err := s.embedder.Generate(embedCtx2, params.Object) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: Failed to generate object embedding: %v\n", err) + } + + objectEntity, err := s.store.InsertEntity(ctx, &models.Entity{ + CanonicalName: params.Object, + EntityType: params.ObjectType, + Embedding: objectEmb, + GroupID: params.GroupID, + }, 0.88) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to resolve object entity: %v", err)), nil + } + + // Generate embedding for the triple itself + tripleText := fmt.Sprintf("%s %s %s", subjectEntity.CanonicalName, params.Predicate, objectEntity.CanonicalName) + embedCtx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel3() + + tripleEmb, err := s.embedder.Generate(embedCtx3, tripleText) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: Failed to generate triple embedding: %v\n", err) + } + + triple := &models.KnowledgeTriple{ + SubjectEntityID: subjectEntity.ID, + Predicate: params.Predicate, + ObjectEntityID: objectEntity.ID, + SourceEpisodeID: params.SourceEpisodeID, + Source: params.Source, + GroupID: params.GroupID, + Embedding: tripleEmb, + Confidence: 1.0, // Client-written triples get full confidence + } + + if err := s.store.InsertKnowledgeTriple(ctx, triple); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to store knowledge triple: %v", err)), nil + } + + result, _ := json.Marshal(map[string]interface{}{ + "success": true, + "triple_id": triple.ID, + "subject_entity_id": subjectEntity.ID, + "subject_name": subjectEntity.CanonicalName, + "object_entity_id": objectEntity.ID, + "object_name": objectEntity.CanonicalName, + "message": fmt.Sprintf("Stored: %s %s %s", subjectEntity.CanonicalName, params.Predicate, objectEntity.CanonicalName), + }) + + return mcp.NewToolResultText(string(result)), nil +} + +func (s *Server) handleLinkEpisodes(ctx 
context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var params struct { + SourceEpisodeID string `json:"source_episode_id"` + TargetEpisodeID string `json:"target_episode_id"` + Relationship string `json:"relationship"` + Weight float64 `json:"weight"` + } + + if err := parseParams(request.Params.Arguments, ¶ms); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("invalid parameters: %v", err)), nil + } + + // Validate relationship + validRelationships := map[string]bool{ + "same_entity": true, "follows_up": true, "contradicts": true, + "elaborates": true, "supersedes": true, + } + if !validRelationships[params.Relationship] { + return mcp.NewToolResultError(fmt.Sprintf("invalid relationship %q: must be one of: same_entity, follows_up, contradicts, elaborates, supersedes", params.Relationship)), nil + } + + link := &models.EpisodeLink{ + SourceEpisodeID: params.SourceEpisodeID, + TargetEpisodeID: params.TargetEpisodeID, + Relationship: params.Relationship, + Weight: params.Weight, + } + + if err := s.store.InsertEpisodeLink(ctx, link); err != nil { + return mcp.NewToolResultError(fmt.Sprintf("failed to link episodes: %v", err)), nil + } + + result, _ := json.Marshal(map[string]interface{}{ + "success": true, + "link_id": link.ID, + "message": fmt.Sprintf("Linked %s -[%s]-> %s", params.SourceEpisodeID, params.Relationship, params.TargetEpisodeID), + }) + + return mcp.NewToolResultText(string(result)), nil +} + // Serve starts the MCP server with stdio transport func (s *Server) Serve() error { return server.ServeStdio(s.mcpServer) diff --git a/internal/models/episode.go b/internal/models/episode.go index 892f4ce..865b49a 100644 --- a/internal/models/episode.go +++ b/internal/models/episode.go @@ -44,3 +44,46 @@ type UpdateParams struct { ExpiredAt *time.Time `json:"expired_at,omitempty"` Metadata *string `json:"metadata,omitempty"` } + +// Entity represents a canonical entity in the knowledge graph +type Entity struct { + ID string 
// KnowledgeTriple is one subject-predicate-object fact in the knowledge
// graph. The subject and object are referenced by entity ID; the optional
// fields are dropped from the JSON encoding when they hold their zero value.
type KnowledgeTriple struct {
	ID              string     `json:"id"`
	SubjectEntityID string     `json:"subject_entity_id"`
	Predicate       string     `json:"predicate"`
	ObjectEntityID  string     `json:"object_entity_id"`
	SourceEpisodeID string     `json:"source_episode_id,omitempty"`
	Source          string     `json:"source"`
	GroupID         string     `json:"group_id"`
	Embedding       []float32  `json:"embedding,omitempty"`
	CreatedAt       time.Time  `json:"created_at"`
	ExpiredAt       *time.Time `json:"expired_at,omitempty"` // nil when no expiry is set
	Confidence      float64    `json:"confidence"`
	Verified        bool       `json:"verified"`
	Metadata        string     `json:"metadata,omitempty"` // JSON document held as a string

	// Display names of the subject and object entities. These are
	// denormalized copies that are filled in when a triple is read back.
	SubjectName string `json:"subject_name,omitempty"`
	ObjectName  string `json:"object_name,omitempty"`
}

// EpisodeLink is a directed edge from a source episode to a target episode,
// labelled with the kind of relationship between them.
type EpisodeLink struct {
	ID              string `json:"id"`
	SourceEpisodeID string `json:"source_episode_id"`
	TargetEpisodeID string `json:"target_episode_id"`
	// Relationship is one of: same_entity, follows_up, contradicts,
	// elaborates, supersedes.
	Relationship string `json:"relationship"`
	// ViaEntityID is optional; presumably the entity through which the two
	// episodes are connected (TODO confirm against the code that sets it).
	ViaEntityID string    `json:"via_entity_id,omitempty"`
	Weight      float64   `json:"weight"`
	CreatedAt   time.Time `json:"created_at"`
}