diff --git a/frontend/src-tauri/src/commands/canvas.rs b/frontend/src-tauri/src/commands/canvas.rs index 6fe79a4..48c55be 100644 --- a/frontend/src-tauri/src/commands/canvas.rs +++ b/frontend/src-tauri/src/commands/canvas.rs @@ -1,4 +1,4 @@ -use crate::commands::sync_chunk_index_for_note; +use crate::commands::{run_retrieval, sync_chunk_index_for_note}; use crate::models::canvas::{ AddModelsRequest, AvailableModel, CanvasSession, CanvasStreamEvent, CanvasViewport, ContextMode, Debate, DebateContinueRequest, DebateResponse, DebateRound, DebateStartRequest, @@ -2269,15 +2269,10 @@ async fn resolve_prompt_context( let pinned_ids = session.pinned_note_ids.clone(); // Quality gate: note-level retrieval to check if vault has relevant content - let retrieval_results = { - let search = state.search_service.read().await; - let graph = state.graph_index.read().await; - let priority = state.priority_service.read().await; - let retrieval = state.retrieval_service.read().await; - retrieval - .retrieve(&search, &graph, &priority, &request.prompt, 5, &pinned_ids) - .unwrap_or_default() - }; + let retrieval_results = + run_retrieval(state, &request.prompt, 5, &pinned_ids) + .await + .unwrap_or_default(); let retrieval_decision = should_use_retrieved_notes(&request.prompt, &retrieval_results); if retrieval_decision != RetrievalDecisionReason::UseRetrievedNotes { @@ -2413,15 +2408,10 @@ async fn resolve_twin_prompt_context( request: &PromptRequest, ) -> Result { let pinned_ids = session.pinned_note_ids.clone(); - let retrieval_results = { - let search = state.search_service.read().await; - let graph = state.graph_index.read().await; - let priority = state.priority_service.read().await; - let retrieval = state.retrieval_service.read().await; - retrieval - .retrieve(&search, &graph, &priority, &request.prompt, 5, &pinned_ids) - .unwrap_or_default() - }; + let retrieval_results = + run_retrieval(state, &request.prompt, 5, &pinned_ids) + .await + .unwrap_or_default(); let 
should_use_notes = should_use_retrieved_notes(&request.prompt, &retrieval_results) == RetrievalDecisionReason::UseRetrievedNotes; @@ -2475,20 +2465,17 @@ async fn resolve_twin_prompt_context( } if note_contexts.is_empty() { - let store = state.knowledge_store.read().await; - for result in &retrieval_results { - if let Ok(note) = store.get_note(&result.note.id) { - note_contexts.push(( - note.id.clone(), - note.title.clone(), - truncate_note_context_content(&note.content, 1500), - )); + note_contexts = fetch_note_contexts(state, &retrieval_results).await; + let found_ids: HashSet<&str> = + note_contexts.iter().map(|(id, _, _)| id.as_str()).collect(); + for r in &retrieval_results { + if found_ids.contains(r.note.id.as_str()) { context_notes.push(TileContextNote { - id: result.note.id.clone(), - title: result.note.title.clone(), - snippet: result.snippet.clone(), - score: result.score, - pinned: pinned_ids.contains(&result.note.id), + id: r.note.id.clone(), + title: r.note.title.clone(), + snippet: r.snippet.clone(), + score: r.score, + pinned: pinned_ids.contains(&r.note.id), }); } } @@ -2539,6 +2526,24 @@ async fn resolve_twin_prompt_context( }) } +/// Fetch full note content for retrieval results, returning (id, title, truncated_content) tuples. +/// Skips notes that can't be read from the store. Called by both the semantic and twin note-level paths. +async fn fetch_note_contexts( + state: &AppState, + results: &[RetrievalResult], +) -> Vec<(String, String, String)> { + let store = state.knowledge_store.read().await; + results + .iter() + .filter_map(|r| { + store.get_note(&r.note.id).ok().map(|note| { + let truncated = truncate_note_context_content(&note.content, 1500); + (note.id.clone(), note.title.clone(), truncated) + }) + }) + .collect() +} + /// Fall back to note-level context when chunk retrieval is disabled or returns nothing. 
async fn resolve_note_level_context( state: &AppState, @@ -2547,18 +2552,7 @@ async fn resolve_note_level_context( pinned_ids: &[String], user_system_prompt: &Option, ) -> Result { - let note_contexts: Vec<(String, String, String)> = { - let store = state.knowledge_store.read().await; - retrieval_results - .iter() - .filter_map(|r| { - store.get_note(&r.note.id).ok().map(|note| { - let truncated = truncate_note_context_content(&note.content, 1500); - (note.id.clone(), note.title.clone(), truncated) - }) - }) - .collect() - }; + let note_contexts = fetch_note_contexts(state, retrieval_results).await; let context_notes: Vec = retrieval_results .iter() diff --git a/frontend/src-tauri/src/commands/memory.rs b/frontend/src-tauri/src/commands/memory.rs index c20cb40..0fa884a 100644 --- a/frontend/src-tauri/src/commands/memory.rs +++ b/frontend/src-tauri/src/commands/memory.rs @@ -1,3 +1,4 @@ +use crate::commands::run_retrieval; use crate::models::memory::{ Contradiction, ExtractRequest, ExtractedClaim, RecallRequest, RecallResult, }; @@ -10,21 +11,10 @@ pub async fn recall_relevant( request: RecallRequest, state: State<'_, AppState>, ) -> Result, String> { - let search = state.search_service.read().await; - let graph = state.graph_index.read().await; - let priority = state.priority_service.read().await; - let retrieval = state.retrieval_service.read().await; - - let results = retrieval.retrieve( - &search, - &graph, - &priority, - &request.query, - request.limit, - &request.context_note_ids, - )?; + let results = + run_retrieval(state.inner(), &request.query, request.limit, &request.context_note_ids) + .await?; - // Convert RetrievalResult → RecallResult Ok(results .into_iter() .map(|r| RecallResult { @@ -33,7 +23,7 @@ pub async fn recall_relevant( snippet: r.snippet, score: r.score, tags: r.note.tags, - graph_boost: 0.0, // now integrated into the composite score + graph_boost: 0.0, total_score: r.score, }) .collect()) diff --git a/frontend/src-tauri/src/commands/mod.rs 
b/frontend/src-tauri/src/commands/mod.rs index 6d02232..4d666e4 100644 --- a/frontend/src-tauri/src/commands/mod.rs +++ b/frontend/src-tauri/src/commands/mod.rs @@ -16,9 +16,27 @@ pub mod twin; pub mod zettelkasten; use crate::models::note::Note; +use crate::services::retrieval::RetrievalResult; use crate::AppState; use std::collections::HashSet; +/// Shared retrieval helper — acquires the 4 retrieval-pipeline read locks, calls +/// `retrieval.retrieve()`, and releases all locks before returning. Used by +/// `retrieve_relevant`, `recall_relevant`, and the canvas context resolvers to +/// avoid duplicating the same lock sequence in each caller. +pub(crate) async fn run_retrieval( + state: &AppState, + query: &str, + limit: usize, + context_ids: &[String], +) -> Result, String> { + let search = state.search_service.read().await; + let graph = state.graph_index.read().await; + let priority = state.priority_service.read().await; + let retrieval = state.retrieval_service.read().await; + retrieval.retrieve(&search, &graph, &priority, query, limit, context_ids) +} + pub(crate) async fn sync_chunk_index_for_note(state: &AppState, note: &Note) { sync_chunk_index_for_notes(state, std::slice::from_ref(note)).await; } diff --git a/frontend/src-tauri/src/commands/retrieval.rs b/frontend/src-tauri/src/commands/retrieval.rs index 59250e5..16c43f3 100644 --- a/frontend/src-tauri/src/commands/retrieval.rs +++ b/frontend/src-tauri/src/commands/retrieval.rs @@ -1,3 +1,4 @@ +use crate::commands::run_retrieval; use crate::services::retrieval::{RetrievalConfig, RetrievalConfigUpdate, RetrievalResult}; use crate::AppState; use tauri::State; @@ -12,13 +13,7 @@ pub async fn retrieve_relevant( ) -> Result, String> { let limit = limit.unwrap_or(10); let context_ids = context_note_ids.unwrap_or_default(); - - let search = state.search_service.read().await; - let graph = state.graph_index.read().await; - let priority = state.priority_service.read().await; - let retrieval = 
state.retrieval_service.read().await; - - retrieval.retrieve(&search, &graph, &priority, &query, limit, &context_ids) + run_retrieval(state.inner(), &query, limit, &context_ids).await } /// Get current retrieval configuration diff --git a/frontend/src-tauri/src/commands/zettelkasten.rs b/frontend/src-tauri/src/commands/zettelkasten.rs index f28b8c6..d16c3a0 100644 --- a/frontend/src-tauri/src/commands/zettelkasten.rs +++ b/frontend/src-tauri/src/commands/zettelkasten.rs @@ -6,7 +6,7 @@ use crate::models::note::{ ApplyLinksRequest, ApplyLinksResponse, CreateLinkResponse, DiscoverLinksResponse, NoteUpdate, RelationType, ZettelLinkCandidate, }; -use crate::services::link_discovery::discover_for_note; +use crate::services::link_discovery::{discover_for_note, DiscoverMode}; use crate::AppState; use std::collections::{HashMap, HashSet}; use tauri::State; @@ -160,37 +160,6 @@ fn deduplicate_links(links: Vec) -> Vec) -> Self { - match mode.unwrap_or("suggested").to_ascii_lowercase().as_str() { - "manual" => Self::Manual, - "algorithm" => Self::Algorithm, - "llm" | "suggested" => Self::Llm, - _ => Self::Llm, - } - } - - #[cfg(test)] - fn include_llm(self) -> bool { - matches!(self, Self::Llm) - } -} - -fn to_service_mode(mode: DiscoverMode) -> crate::services::link_discovery::DiscoverMode { - match mode { - DiscoverMode::Manual => crate::services::link_discovery::DiscoverMode::Manual, - DiscoverMode::Algorithm => crate::services::link_discovery::DiscoverMode::Algorithm, - DiscoverMode::Llm => crate::services::link_discovery::DiscoverMode::Llm, - } -} - // ── Tauri commands ─────────────────────────────────────────────────────── /// Discover potential links for a note using multiple strategies @@ -203,19 +172,12 @@ pub async fn discover_links( ) -> Result { let discover_mode = DiscoverMode::parse(mode.as_deref()); let max_links = maxLinks.unwrap_or(10); - discover_for_note( - state.inner(), - &noteId, - to_service_mode(discover_mode), - max_links, - true, - ) - .await + 
discover_for_note(state.inner(), &noteId, discover_mode, max_links, true).await } #[cfg(test)] mod tests { - use super::DiscoverMode; + use crate::services::link_discovery::DiscoverMode; #[test] fn parses_discover_modes() { diff --git a/frontend/src-tauri/src/services/link_discovery.rs b/frontend/src-tauri/src/services/link_discovery.rs index 25fd5d5..1753e72 100644 --- a/frontend/src-tauri/src/services/link_discovery.rs +++ b/frontend/src-tauri/src/services/link_discovery.rs @@ -6,7 +6,7 @@ use crate::models::note::{ }; use crate::models::settings::UserSettings; use crate::services::retrieval::RetrievalResult; -use crate::services::similarity::{SimilarityProvider, TfIdfProvider}; +use crate::services::similarity::{sparse_cosine, SimilarityProvider, TfIdfProvider}; use crate::services::yake::{self, YakeConfig, STOPWORDS}; use chrono::{DateTime, Utc}; use lazy_static::lazy_static; @@ -883,7 +883,7 @@ pub fn sample_exploratory_candidates( && !excluded_ids.contains(&profile.note_id) }) .filter_map(|profile| { - let term_cosine = cosine_similarity(&source_profile.term_vector, &profile.term_vector); + let term_cosine = sparse_cosine(&source_profile.term_vector, &profile.term_vector); let candidate_tags = profile .tags .iter() @@ -1069,29 +1069,6 @@ fn build_reference_index(notes: &[Note]) -> HashMap { refs } -pub fn cosine_similarity(a: &HashMap, b: &HashMap) -> f64 { - if a.is_empty() || b.is_empty() { - return 0.0; - } - - let mut dot_product = 0.0; - let mut norm_a = 0.0; - for (term, weight_a) in a { - norm_a += weight_a * weight_a; - if let Some(weight_b) = b.get(term) { - dot_product += weight_a * weight_b; - } - } - - let norm_b = b.values().map(|weight| weight * weight).sum::(); - let magnitude = (norm_a * norm_b).sqrt(); - if magnitude < 1e-10 { - 0.0 - } else { - dot_product / magnitude - } -} - pub fn title_token_overlap(left: &str, right: &str) -> usize { let left_tokens = tokenize_simple(left).into_iter().collect::>(); let right_tokens = 
tokenize_simple(right).into_iter().collect::>(); @@ -1507,7 +1484,7 @@ fn build_local_ranked_candidates( entry.tag_overlap = tag_overlap; entry.tag_overlap_ratio = tag_overlap as f64 / tag_union as f64; entry.key_term_cosine = - cosine_similarity(&source_profile.term_vector, &profile.term_vector); + sparse_cosine(&source_profile.term_vector, &profile.term_vector); if !source_link_set.is_empty() { let overlap_score = (shared_neighbors as f64 / source_link_set.len() as f64).clamp(0.0, 1.0); diff --git a/frontend/src-tauri/src/services/similarity.rs b/frontend/src-tauri/src/services/similarity.rs index 7e38841..101a53a 100644 --- a/frontend/src-tauri/src/services/similarity.rs +++ b/frontend/src-tauri/src/services/similarity.rs @@ -134,6 +134,34 @@ impl SimilarityProvider for TfIdfProvider { } } +// ── Free functions ─────────────────────────────────────────────────────── + +/// Cosine similarity between two sparse term-weight maps. +/// Used by link discovery and any other code that works with HashMap term vectors +/// rather than SparseTfVector structs. +pub fn sparse_cosine(a: &HashMap, b: &HashMap) -> f64 { + if a.is_empty() || b.is_empty() { + return 0.0; + } + + let mut dot_product = 0.0; + let mut norm_a = 0.0; + for (term, weight_a) in a { + norm_a += weight_a * weight_a; + if let Some(weight_b) = b.get(term) { + dot_product += weight_a * weight_b; + } + } + + let norm_b = b.values().map(|w| w * w).sum::(); + let magnitude = (norm_a * norm_b).sqrt(); + if magnitude < 1e-10 { + 0.0 + } else { + dot_product / magnitude + } +} + // ── Tests ──────────────────────────────────────────────────────────────── #[cfg(test)]