Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 38 additions & 44 deletions frontend/src-tauri/src/commands/canvas.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::commands::sync_chunk_index_for_note;
use crate::commands::{run_retrieval, sync_chunk_index_for_note};
use crate::models::canvas::{
AddModelsRequest, AvailableModel, CanvasSession, CanvasStreamEvent, CanvasViewport,
ContextMode, Debate, DebateContinueRequest, DebateResponse, DebateRound, DebateStartRequest,
Expand Down Expand Up @@ -2269,15 +2269,10 @@ async fn resolve_prompt_context(
let pinned_ids = session.pinned_note_ids.clone();

// Quality gate: note-level retrieval to check if vault has relevant content
let retrieval_results = {
let search = state.search_service.read().await;
let graph = state.graph_index.read().await;
let priority = state.priority_service.read().await;
let retrieval = state.retrieval_service.read().await;
retrieval
.retrieve(&search, &graph, &priority, &request.prompt, 5, &pinned_ids)
.unwrap_or_default()
};
let retrieval_results =
run_retrieval(state, &request.prompt, 5, &pinned_ids)
.await
.unwrap_or_default();

let retrieval_decision = should_use_retrieved_notes(&request.prompt, &retrieval_results);
if retrieval_decision != RetrievalDecisionReason::UseRetrievedNotes {
Expand Down Expand Up @@ -2413,15 +2408,10 @@ async fn resolve_twin_prompt_context(
request: &PromptRequest,
) -> Result<ResolvedPromptContext, String> {
let pinned_ids = session.pinned_note_ids.clone();
let retrieval_results = {
let search = state.search_service.read().await;
let graph = state.graph_index.read().await;
let priority = state.priority_service.read().await;
let retrieval = state.retrieval_service.read().await;
retrieval
.retrieve(&search, &graph, &priority, &request.prompt, 5, &pinned_ids)
.unwrap_or_default()
};
let retrieval_results =
run_retrieval(state, &request.prompt, 5, &pinned_ids)
.await
.unwrap_or_default();

let should_use_notes = should_use_retrieved_notes(&request.prompt, &retrieval_results)
== RetrievalDecisionReason::UseRetrievedNotes;
Expand Down Expand Up @@ -2475,20 +2465,17 @@ async fn resolve_twin_prompt_context(
}

if note_contexts.is_empty() {
let store = state.knowledge_store.read().await;
for result in &retrieval_results {
if let Ok(note) = store.get_note(&result.note.id) {
note_contexts.push((
note.id.clone(),
note.title.clone(),
truncate_note_context_content(&note.content, 1500),
));
note_contexts = fetch_note_contexts(state, &retrieval_results).await;
let found_ids: HashSet<&str> =
note_contexts.iter().map(|(id, _, _)| id.as_str()).collect();
for r in &retrieval_results {
if found_ids.contains(r.note.id.as_str()) {
context_notes.push(TileContextNote {
id: result.note.id.clone(),
title: result.note.title.clone(),
snippet: result.snippet.clone(),
score: result.score,
pinned: pinned_ids.contains(&result.note.id),
id: r.note.id.clone(),
title: r.note.title.clone(),
snippet: r.snippet.clone(),
score: r.score,
pinned: pinned_ids.contains(&r.note.id),
});
}
}
Expand Down Expand Up @@ -2539,6 +2526,24 @@ async fn resolve_twin_prompt_context(
})
}

/// Fetch full note content for retrieval results, returning (id, title, truncated_content) tuples.
/// Skips notes that can't be read from the store. Called by both the semantic and twin note-level paths.
async fn fetch_note_contexts(
state: &AppState,
results: &[RetrievalResult],
) -> Vec<(String, String, String)> {
let store = state.knowledge_store.read().await;
results
.iter()
.filter_map(|r| {
store.get_note(&r.note.id).ok().map(|note| {
let truncated = truncate_note_context_content(&note.content, 1500);
(note.id.clone(), note.title.clone(), truncated)
})
})
.collect()
}

/// Fall back to note-level context when chunk retrieval is disabled or returns nothing.
async fn resolve_note_level_context(
state: &AppState,
Expand All @@ -2547,18 +2552,7 @@ async fn resolve_note_level_context(
pinned_ids: &[String],
user_system_prompt: &Option<String>,
) -> Result<ResolvedPromptContext, String> {
let note_contexts: Vec<(String, String, String)> = {
let store = state.knowledge_store.read().await;
retrieval_results
.iter()
.filter_map(|r| {
store.get_note(&r.note.id).ok().map(|note| {
let truncated = truncate_note_context_content(&note.content, 1500);
(note.id.clone(), note.title.clone(), truncated)
})
})
.collect()
};
let note_contexts = fetch_note_contexts(state, retrieval_results).await;

let context_notes: Vec<TileContextNote> = retrieval_results
.iter()
Expand Down
20 changes: 5 additions & 15 deletions frontend/src-tauri/src/commands/memory.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::commands::run_retrieval;
use crate::models::memory::{
Contradiction, ExtractRequest, ExtractedClaim, RecallRequest, RecallResult,
};
Expand All @@ -10,21 +11,10 @@ pub async fn recall_relevant(
request: RecallRequest,
state: State<'_, AppState>,
) -> Result<Vec<RecallResult>, String> {
let search = state.search_service.read().await;
let graph = state.graph_index.read().await;
let priority = state.priority_service.read().await;
let retrieval = state.retrieval_service.read().await;

let results = retrieval.retrieve(
&search,
&graph,
&priority,
&request.query,
request.limit,
&request.context_note_ids,
)?;
let results =
run_retrieval(state.inner(), &request.query, request.limit, &request.context_note_ids)
.await?;

// Convert RetrievalResult → RecallResult
Ok(results
.into_iter()
.map(|r| RecallResult {
Expand All @@ -33,7 +23,7 @@ pub async fn recall_relevant(
snippet: r.snippet,
score: r.score,
tags: r.note.tags,
graph_boost: 0.0, // now integrated into the composite score
graph_boost: 0.0,
total_score: r.score,
})
.collect())
Expand Down
18 changes: 18 additions & 0 deletions frontend/src-tauri/src/commands/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,27 @@ pub mod twin;
pub mod zettelkasten;

use crate::models::note::Note;
use crate::services::retrieval::RetrievalResult;
use crate::AppState;
use std::collections::HashSet;

/// Shared retrieval helper — acquires the 4 retrieval-pipeline read locks, calls
/// `retrieval.retrieve()`, and releases all locks before returning. Used by
/// `retrieve_relevant`, `recall_relevant`, and the canvas context resolvers to
/// avoid duplicating the same lock sequence in each caller.
pub(crate) async fn run_retrieval(
state: &AppState,
query: &str,
limit: usize,
context_ids: &[String],
) -> Result<Vec<RetrievalResult>, String> {
let search = state.search_service.read().await;
let graph = state.graph_index.read().await;
let priority = state.priority_service.read().await;
let retrieval = state.retrieval_service.read().await;
retrieval.retrieve(&search, &graph, &priority, query, limit, context_ids)
}

pub(crate) async fn sync_chunk_index_for_note(state: &AppState, note: &Note) {
sync_chunk_index_for_notes(state, std::slice::from_ref(note)).await;
}
Expand Down
9 changes: 2 additions & 7 deletions frontend/src-tauri/src/commands/retrieval.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::commands::run_retrieval;
use crate::services::retrieval::{RetrievalConfig, RetrievalConfigUpdate, RetrievalResult};
use crate::AppState;
use tauri::State;
Expand All @@ -12,13 +13,7 @@ pub async fn retrieve_relevant(
) -> Result<Vec<RetrievalResult>, String> {
let limit = limit.unwrap_or(10);
let context_ids = context_note_ids.unwrap_or_default();

let search = state.search_service.read().await;
let graph = state.graph_index.read().await;
let priority = state.priority_service.read().await;
let retrieval = state.retrieval_service.read().await;

retrieval.retrieve(&search, &graph, &priority, &query, limit, &context_ids)
run_retrieval(state.inner(), &query, limit, &context_ids).await
}

/// Get current retrieval configuration
Expand Down
44 changes: 3 additions & 41 deletions frontend/src-tauri/src/commands/zettelkasten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::models::note::{
ApplyLinksRequest, ApplyLinksResponse, CreateLinkResponse, DiscoverLinksResponse, NoteUpdate,
RelationType, ZettelLinkCandidate,
};
use crate::services::link_discovery::discover_for_note;
use crate::services::link_discovery::{discover_for_note, DiscoverMode};
use crate::AppState;
use std::collections::{HashMap, HashSet};
use tauri::State;
Expand Down Expand Up @@ -160,37 +160,6 @@ fn deduplicate_links(links: Vec<ZettelLinkCandidate>) -> Vec<ZettelLinkCandidate
seen.into_values().collect()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DiscoverMode {
Manual,
Algorithm,
Llm,
}

impl DiscoverMode {
fn parse(mode: Option<&str>) -> Self {
match mode.unwrap_or("suggested").to_ascii_lowercase().as_str() {
"manual" => Self::Manual,
"algorithm" => Self::Algorithm,
"llm" | "suggested" => Self::Llm,
_ => Self::Llm,
}
}

#[cfg(test)]
fn include_llm(self) -> bool {
matches!(self, Self::Llm)
}
}

fn to_service_mode(mode: DiscoverMode) -> crate::services::link_discovery::DiscoverMode {
match mode {
DiscoverMode::Manual => crate::services::link_discovery::DiscoverMode::Manual,
DiscoverMode::Algorithm => crate::services::link_discovery::DiscoverMode::Algorithm,
DiscoverMode::Llm => crate::services::link_discovery::DiscoverMode::Llm,
}
}

// ── Tauri commands ───────────────────────────────────────────────────────

/// Discover potential links for a note using multiple strategies
Expand All @@ -203,19 +172,12 @@ pub async fn discover_links(
) -> Result<DiscoverLinksResponse, String> {
let discover_mode = DiscoverMode::parse(mode.as_deref());
let max_links = maxLinks.unwrap_or(10);
discover_for_note(
state.inner(),
&noteId,
to_service_mode(discover_mode),
max_links,
true,
)
.await
discover_for_note(state.inner(), &noteId, discover_mode, max_links, true).await
}

#[cfg(test)]
mod tests {
use super::DiscoverMode;
use crate::services::link_discovery::DiscoverMode;

#[test]
fn parses_discover_modes() {
Expand Down
29 changes: 3 additions & 26 deletions frontend/src-tauri/src/services/link_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::models::note::{
};
use crate::models::settings::UserSettings;
use crate::services::retrieval::RetrievalResult;
use crate::services::similarity::{SimilarityProvider, TfIdfProvider};
use crate::services::similarity::{sparse_cosine, SimilarityProvider, TfIdfProvider};
use crate::services::yake::{self, YakeConfig, STOPWORDS};
use chrono::{DateTime, Utc};
use lazy_static::lazy_static;
Expand Down Expand Up @@ -883,7 +883,7 @@ pub fn sample_exploratory_candidates(
&& !excluded_ids.contains(&profile.note_id)
})
.filter_map(|profile| {
let term_cosine = cosine_similarity(&source_profile.term_vector, &profile.term_vector);
let term_cosine = sparse_cosine(&source_profile.term_vector, &profile.term_vector);
let candidate_tags = profile
.tags
.iter()
Expand Down Expand Up @@ -1069,29 +1069,6 @@ fn build_reference_index(notes: &[Note]) -> HashMap<String, String> {
refs
}

pub fn cosine_similarity(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
if a.is_empty() || b.is_empty() {
return 0.0;
}

let mut dot_product = 0.0;
let mut norm_a = 0.0;
for (term, weight_a) in a {
norm_a += weight_a * weight_a;
if let Some(weight_b) = b.get(term) {
dot_product += weight_a * weight_b;
}
}

let norm_b = b.values().map(|weight| weight * weight).sum::<f64>();
let magnitude = (norm_a * norm_b).sqrt();
if magnitude < 1e-10 {
0.0
} else {
dot_product / magnitude
}
}

pub fn title_token_overlap(left: &str, right: &str) -> usize {
let left_tokens = tokenize_simple(left).into_iter().collect::<HashSet<_>>();
let right_tokens = tokenize_simple(right).into_iter().collect::<HashSet<_>>();
Expand Down Expand Up @@ -1507,7 +1484,7 @@ fn build_local_ranked_candidates(
entry.tag_overlap = tag_overlap;
entry.tag_overlap_ratio = tag_overlap as f64 / tag_union as f64;
entry.key_term_cosine =
cosine_similarity(&source_profile.term_vector, &profile.term_vector);
sparse_cosine(&source_profile.term_vector, &profile.term_vector);
if !source_link_set.is_empty() {
let overlap_score =
(shared_neighbors as f64 / source_link_set.len() as f64).clamp(0.0, 1.0);
Expand Down
28 changes: 28 additions & 0 deletions frontend/src-tauri/src/services/similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,34 @@ impl SimilarityProvider for TfIdfProvider {
}
}

// ── Free functions ───────────────────────────────────────────────────────

/// Cosine similarity between two sparse term-weight maps.
/// Used by link discovery and any other code that works with HashMap term vectors
/// rather than SparseTfVector structs.
pub fn sparse_cosine(a: &HashMap<String, f64>, b: &HashMap<String, f64>) -> f64 {
if a.is_empty() || b.is_empty() {
return 0.0;
}

let mut dot_product = 0.0;
let mut norm_a = 0.0;
for (term, weight_a) in a {
norm_a += weight_a * weight_a;
if let Some(weight_b) = b.get(term) {
dot_product += weight_a * weight_b;
}
}

let norm_b = b.values().map(|w| w * w).sum::<f64>();
let magnitude = (norm_a * norm_b).sqrt();
if magnitude < 1e-10 {
0.0
} else {
dot_product / magnitude
}
}

// ── Tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
Expand Down
Loading