Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
373 changes: 373 additions & 0 deletions internal/db/duckdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,58 @@ func (s *Store) initialize() error {
return fmt.Errorf("failed to execute schema: %w", err)
}

// Create knowledge graph tables
graphSchema := `
CREATE TABLE IF NOT EXISTS entities (
id VARCHAR PRIMARY KEY,
canonical_name TEXT NOT NULL,
entity_type VARCHAR,
embedding FLOAT[768],
group_id VARCHAR DEFAULT 'default',
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
metadata JSON
);
CREATE INDEX IF NOT EXISTS idx_entities_group_id ON entities (group_id);
CREATE INDEX IF NOT EXISTS idx_entities_canonical_name ON entities (canonical_name);
CREATE UNIQUE INDEX IF NOT EXISTS idx_entities_name_group ON entities (LOWER(canonical_name), group_id);

CREATE TABLE IF NOT EXISTS knowledge (
id VARCHAR PRIMARY KEY,
subject_entity_id VARCHAR NOT NULL,
predicate VARCHAR NOT NULL,
object_entity_id VARCHAR NOT NULL,
source_episode_id VARCHAR,
source VARCHAR NOT NULL,
group_id VARCHAR DEFAULT 'default',
embedding FLOAT[768],
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
expired_at TIMESTAMPTZ,
confidence FLOAT DEFAULT 1.0,
verified BOOLEAN DEFAULT FALSE,
metadata JSON
);
CREATE INDEX IF NOT EXISTS idx_knowledge_subject ON knowledge (subject_entity_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_object ON knowledge (object_entity_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_source_episode ON knowledge (source_episode_id);
CREATE INDEX IF NOT EXISTS idx_knowledge_group_id ON knowledge (group_id);

CREATE TABLE IF NOT EXISTS episode_links (
id VARCHAR PRIMARY KEY,
source_episode_id VARCHAR NOT NULL,
target_episode_id VARCHAR NOT NULL,
relationship VARCHAR NOT NULL,
via_entity_id VARCHAR,
weight FLOAT DEFAULT 1.0,
created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
UNIQUE (source_episode_id, target_episode_id, relationship)
);
CREATE INDEX IF NOT EXISTS idx_episode_links_source ON episode_links (source_episode_id);
CREATE INDEX IF NOT EXISTS idx_episode_links_target ON episode_links (target_episode_id);
`
if _, err := s.db.Exec(graphSchema); err != nil {
return fmt.Errorf("failed to create graph schema: %w", err)
}

// Run migrations for existing databases
if err := s.migrate(); err != nil {
return fmt.Errorf("failed to run migrations: %w", err)
Expand Down Expand Up @@ -739,6 +791,314 @@ func (s *Store) ensureFTSIndex() error {
return nil
}

// InsertEntity adds a new entity to the store, or returns the already-stored
// entity when one with a matching name exists in the same group.
//
// Matching is purely string-based and runs in two passes:
//  1. Case-insensitive exact match: "Mike" = "mike".
//  2. Normalized match (see normalizeEntityName): "OscillateLabs" = "Oscillate Labs".
//
// similarityThreshold is retained for API compatibility but is currently
// ignored: embedding-based resolution is NOT used for entity names.
// nomic-embed-text produces degenerate embeddings for short texts — "Mike" vs
// "DuckDB" scores 0.9594, same as "Mike" vs "Oscillate Labs" — so the embedding
// space cannot distinguish entity names. Embeddings are still stored on
// entities for future use with longer-context models.
//
// On a match the returned entity carries only ID, CanonicalName, EntityType,
// GroupID and CreatedAt of the stored row, and the argument is left untouched.
// On insert, missing ID, CreatedAt and GroupID are filled in on the argument,
// which is then returned.
func (s *Store) InsertEntity(ctx context.Context, entity *models.Entity, similarityThreshold float64) (*models.Entity, error) {
	// Intentionally unused; see the doc comment.
	_ = similarityThreshold

	groupFilter := "default"
	if entity.GroupID != "" {
		groupFilter = entity.GroupID
	}

	// entity_type is a nullable column, so scan it through sql.NullString
	// rather than failing on rows that have no type.
	var (
		existingID, existingName string
		existingType             sql.NullString
		existingCreatedAt        time.Time
	)

	// First pass: case-insensitive exact name match.
	err := s.db.QueryRowContext(ctx, `
		SELECT id, canonical_name, entity_type, created_at
		FROM entities
		WHERE LOWER(canonical_name) = LOWER(?)
		AND group_id = ?
		LIMIT 1
	`, entity.CanonicalName, groupFilter).Scan(&existingID, &existingName, &existingType, &existingCreatedAt)

	if err == nil {
		return &models.Entity{
			ID:            existingID,
			CanonicalName: existingName,
			EntityType:    existingType.String,
			GroupID:       groupFilter,
			CreatedAt:     existingCreatedAt,
		}, nil
	}
	if err != sql.ErrNoRows {
		return nil, fmt.Errorf("failed to check entity name match: %w", err)
	}

	// Second pass: normalized match. Catches spacing/punctuation variants such
	// as "OscillateLabs" vs "Oscillate Labs".
	//
	// An empty normalized key carries no information (the name had no
	// characters the normalizer keeps), so it must not be used for matching:
	// otherwise every such name in the group would resolve to the same entity.
	normalizedName := normalizeEntityName(entity.CanonicalName)
	if normalizedName != "" {
		rows, err := s.db.QueryContext(ctx, `
			SELECT id, canonical_name, entity_type, created_at
			FROM entities
			WHERE group_id = ?
		`, groupFilter)
		if err != nil {
			return nil, fmt.Errorf("failed to query entities for normalized match: %w", err)
		}
		defer rows.Close()

		for rows.Next() {
			var (
				candID, candName string
				candType         sql.NullString
				candCreatedAt    time.Time
			)
			// A row we cannot read might be the duplicate we are looking for;
			// surface the failure instead of silently inserting a second copy.
			if err := rows.Scan(&candID, &candName, &candType, &candCreatedAt); err != nil {
				return nil, fmt.Errorf("failed to scan entity for normalized match: %w", err)
			}
			if normalizeEntityName(candName) == normalizedName {
				return &models.Entity{
					ID:            candID,
					CanonicalName: candName,
					EntityType:    candType.String,
					GroupID:       groupFilter,
					CreatedAt:     candCreatedAt,
				}, nil
			}
		}
		// Distinguish "no more rows" from "iteration aborted".
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("failed to iterate entities for normalized match: %w", err)
		}
	}

	// No match found — insert new entity.
	if entity.ID == "" {
		entity.ID = uuid.New().String()
	}
	if entity.CreatedAt.IsZero() {
		entity.CreatedAt = time.Now()
	}
	if entity.GroupID == "" {
		entity.GroupID = "default"
	}

	// A nil interface value is sent as SQL NULL; a present embedding is sent
	// as its JSON array text.
	var embeddingJSON interface{}
	if len(entity.Embedding) > 0 {
		data, err := json.Marshal(entity.Embedding)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal entity embedding: %w", err)
		}
		embeddingJSON = string(data)
	}

	var metadataJSON interface{}
	if entity.Metadata != "" {
		metadataJSON = entity.Metadata
	}

	if _, err := s.db.ExecContext(ctx, `
		INSERT INTO entities (id, canonical_name, entity_type, embedding, group_id, created_at, metadata)
		VALUES (?, ?, ?, ?, ?, ?, ?)
	`, entity.ID, entity.CanonicalName, entity.EntityType, embeddingJSON, entity.GroupID, entity.CreatedAt, metadataJSON); err != nil {
		return nil, fmt.Errorf("failed to insert entity: %w", err)
	}

	return entity, nil
}

// GetEntity looks up one entity by its ID.
//
// It returns an "entity not found" error when no row has that ID, and a wrapped
// error for any other query failure. The embedding column is not loaded.
// Metadata is returned as a string: a string value from the driver is used
// as-is, and a decoded JSON object is re-encoded to text.
func (s *Store) GetEntity(ctx context.Context, id string) (*models.Entity, error) {
	const query = `
SELECT id, canonical_name, entity_type, group_id, created_at, metadata
FROM entities WHERE id = ?
`
	var (
		found   models.Entity
		rawMeta interface{}
	)

	row := s.db.QueryRowContext(ctx, query, id)
	err := row.Scan(&found.ID, &found.CanonicalName, &found.EntityType, &found.GroupID, &found.CreatedAt, &rawMeta)
	switch {
	case err == sql.ErrNoRows:
		return nil, fmt.Errorf("entity not found: %s", id)
	case err != nil:
		return nil, fmt.Errorf("failed to get entity: %w", err)
	}

	// A nil rawMeta (SQL NULL) matches neither case and leaves Metadata empty.
	switch meta := rawMeta.(type) {
	case string:
		found.Metadata = meta
	case map[string]interface{}:
		// Best effort: if re-encoding fails, Metadata stays empty.
		if encoded, encErr := json.Marshal(meta); encErr == nil {
			found.Metadata = string(encoded)
		}
	}

	return &found, nil
}

// InsertKnowledgeTriple adds a knowledge triple (subject, predicate, object)
// to the store.
//
// Missing fields are defaulted on the argument before the insert: ID gets a
// fresh UUID, CreatedAt the current time, GroupID "default", and a zero
// Confidence becomes 1.0 (so an explicit confidence of 0 cannot be stored).
// An empty embedding or empty metadata is written as SQL NULL.
//
// It returns an error if the embedding cannot be JSON-encoded or the INSERT
// fails.
func (s *Store) InsertKnowledgeTriple(ctx context.Context, triple *models.KnowledgeTriple) error {
	if triple.ID == "" {
		triple.ID = uuid.New().String()
	}
	if triple.CreatedAt.IsZero() {
		triple.CreatedAt = time.Now()
	}
	if triple.GroupID == "" {
		triple.GroupID = "default"
	}
	if triple.Confidence == 0 {
		triple.Confidence = 1.0
	}

	// A nil interface value is sent as SQL NULL; a present embedding is sent
	// as its JSON array text.
	var embeddingJSON interface{}
	if len(triple.Embedding) > 0 {
		// Encoding can fail (encoding/json rejects NaN and ±Inf floats); do
		// not fall through and send an empty string as the embedding.
		data, err := json.Marshal(triple.Embedding)
		if err != nil {
			return fmt.Errorf("failed to marshal knowledge triple embedding: %w", err)
		}
		embeddingJSON = string(data)
	}

	var metadataJSON interface{}
	if triple.Metadata != "" {
		metadataJSON = triple.Metadata
	}

	_, err := s.db.ExecContext(ctx, `
		INSERT INTO knowledge (
			id, subject_entity_id, predicate, object_entity_id,
			source_episode_id, source, group_id, embedding,
			created_at, expired_at, confidence, verified, metadata
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`,
		triple.ID, triple.SubjectEntityID, triple.Predicate, triple.ObjectEntityID,
		triple.SourceEpisodeID, triple.Source, triple.GroupID, embeddingJSON,
		triple.CreatedAt, triple.ExpiredAt, triple.Confidence, triple.Verified, metadataJSON,
	)
	if err != nil {
		return fmt.Errorf("failed to insert knowledge triple: %w", err)
	}
	return nil
}

// SearchKnowledge finds knowledge triples whose stored embedding is similar to
// queryEmbedding, ordered from most to least similar.
//
// Only triples that have an embedding, are not expired, and whose cosine
// similarity to the query is at least minSimilarity are returned. Each result
// also carries the canonical names of its subject and object entities.
//
// Parameters:
//   - groupID: when non-empty, restricts results to that group.
//   - maxResults: result cap; values <= 0 default to 10.
//   - minSimilarity: similarity floor; values <= 0 default to 0.35.
//
// The similarity score is selected for ordering but is not returned to the
// caller.
func (s *Store) SearchKnowledge(ctx context.Context, queryEmbedding []float32, groupID string, maxResults int, minSimilarity float64) ([]models.KnowledgeTriple, error) {
	if maxResults <= 0 {
		maxResults = 10
	}
	if minSimilarity <= 0 {
		minSimilarity = 0.35
	}

	embJSON, err := json.Marshal(queryEmbedding)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal query embedding: %w", err)
	}

	// The embedding literal, the similarity floor and the limit are inlined.
	// They are produced here from numeric values only (json.Marshal of
	// []float32, %f, %d), so they cannot carry SQL syntax.
	query := fmt.Sprintf(`
		SELECT k.id, k.subject_entity_id, k.predicate, k.object_entity_id,
			k.source_episode_id, k.source, k.group_id, k.created_at,
			k.expired_at, k.confidence, k.verified, k.metadata,
			se.canonical_name AS subject_name, oe.canonical_name AS object_name,
			array_cosine_similarity(k.embedding, %s::FLOAT[768]) AS similarity
		FROM knowledge k
		JOIN entities se ON k.subject_entity_id = se.id
		JOIN entities oe ON k.object_entity_id = oe.id
		WHERE k.embedding IS NOT NULL
		AND (k.expired_at IS NULL OR k.expired_at > CURRENT_TIMESTAMP)
		AND array_cosine_similarity(k.embedding, %s::FLOAT[768]) >= %f
	`, string(embJSON), string(embJSON), minSimilarity)

	// groupID is caller-supplied text: bind it as a parameter rather than
	// splicing it into the statement, which would allow SQL injection.
	var args []interface{}
	if groupID != "" {
		query += " AND k.group_id = ?"
		args = append(args, groupID)
	}

	query += fmt.Sprintf(" ORDER BY similarity DESC LIMIT %d", maxResults)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to search knowledge: %w", err)
	}
	defer rows.Close()

	var triples []models.KnowledgeTriple
	for rows.Next() {
		var t models.KnowledgeTriple
		var metadataRaw interface{}
		// Scanned only because it is part of the SELECT list; not exposed.
		var similarity sql.NullFloat64
		err := rows.Scan(
			&t.ID, &t.SubjectEntityID, &t.Predicate, &t.ObjectEntityID,
			&t.SourceEpisodeID, &t.Source, &t.GroupID, &t.CreatedAt,
			&t.ExpiredAt, &t.Confidence, &t.Verified, &metadataRaw,
			&t.SubjectName, &t.ObjectName, &similarity,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan knowledge triple: %w", err)
		}
		// Metadata is exposed as a string: a string value is used as-is and a
		// decoded JSON object is re-encoded (best effort). NULL leaves it empty.
		switch v := metadataRaw.(type) {
		case map[string]interface{}:
			if data, err := json.Marshal(v); err == nil {
				t.Metadata = string(data)
			}
		case string:
			t.Metadata = v
		}
		triples = append(triples, t)
	}
	return triples, rows.Err()
}

// InsertEpisodeLink records a relationship between two episodes.
//
// Missing fields are defaulted on the argument first: a zero Weight becomes
// 1.0, a zero CreatedAt becomes the current time, and an empty ID gets a fresh
// UUID. An empty ViaEntityID is stored as SQL NULL.
//
// Duplicates are not an error: the statement is INSERT OR IGNORE, so a row
// that conflicts with an existing one (the table declares
// (source_episode_id, target_episode_id, relationship) UNIQUE) is silently
// skipped by the database, avoiding a race-prone check-then-insert.
func (s *Store) InsertEpisodeLink(ctx context.Context, link *models.EpisodeLink) error {
	if link.Weight == 0 {
		link.Weight = 1.0
	}
	if link.CreatedAt.IsZero() {
		link.CreatedAt = time.Now()
	}
	if link.ID == "" {
		link.ID = uuid.New().String()
	}

	// Left nil (SQL NULL) unless the link names an intermediate entity.
	var via interface{}
	if link.ViaEntityID != "" {
		via = link.ViaEntityID
	}

	const stmt = `
INSERT OR IGNORE INTO episode_links (id, source_episode_id, target_episode_id, relationship, via_entity_id, weight, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
`
	if _, err := s.db.ExecContext(ctx, stmt,
		link.ID, link.SourceEpisodeID, link.TargetEpisodeID, link.Relationship, via, link.Weight, link.CreatedAt,
	); err != nil {
		return fmt.Errorf("failed to insert episode link: %w", err)
	}
	return nil
}

// GetEpisodeLinks returns every link in which the given episode appears as
// either the source or the target, newest first.
//
// A NULL via_entity_id is reported as an empty ViaEntityID. The result is nil
// when the episode has no links.
func (s *Store) GetEpisodeLinks(ctx context.Context, episodeID string) ([]models.EpisodeLink, error) {
	const query = `
SELECT id, source_episode_id, target_episode_id, relationship, via_entity_id, weight, created_at
FROM episode_links
WHERE source_episode_id = ? OR target_episode_id = ?
ORDER BY created_at DESC
`
	// The ID is bound twice: once per side of the OR.
	rows, err := s.db.QueryContext(ctx, query, episodeID, episodeID)
	if err != nil {
		return nil, fmt.Errorf("failed to get episode links: %w", err)
	}
	defer rows.Close()

	var result []models.EpisodeLink
	for rows.Next() {
		var (
			item models.EpisodeLink
			via  sql.NullString
		)
		if scanErr := rows.Scan(
			&item.ID, &item.SourceEpisodeID, &item.TargetEpisodeID,
			&item.Relationship, &via, &item.Weight, &item.CreatedAt,
		); scanErr != nil {
			return nil, fmt.Errorf("failed to scan episode link: %w", scanErr)
		}
		if via.Valid {
			item.ViaEntityID = via.String
		}
		result = append(result, item)
	}
	return result, rows.Err()
}

// Close closes the database connection
func (s *Store) Close() error {
return s.db.Close()
Expand Down Expand Up @@ -772,6 +1132,19 @@ func sanitizeFTSQuery(s string) string {
return strings.ReplaceAll(s, "'", "''")
}

// normalizeEntityName builds a comparison key for fuzzy entity matching: it
// lowercases the name and strips ASCII whitespace, punctuation and symbols.
// "OscillateLabs" and "Oscillate Labs" both normalize to "oscillatelabs".
//
// Non-ASCII runes are kept (after lowercasing). Dropping them would reduce
// every name written in a non-Latin script to the empty string, making
// unrelated names such as "東京" and "大阪" compare equal. Non-ASCII
// punctuation and spacing are therefore not stripped.
//
// The result is empty only when the name contains no ASCII letter or digit
// and no non-ASCII rune; callers should not treat two empty keys as a match.
func normalizeEntityName(name string) string {
	var b strings.Builder
	b.Grow(len(name))
	for _, r := range strings.ToLower(name) {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
			b.WriteRune(r)
		case r >= 0x80: // outside ASCII: keep so non-Latin names stay distinct
			b.WriteRune(r)
		}
	}
	return b.String()
}

// Helper functions for scanning rows

func (s *Store) scanEpisode(row *sql.Row) (*models.Episode, error) {
Expand Down
Loading
Loading