diff --git a/Cargo.toml b/Cargo.toml index d912c66..da6ff79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } dialoguer = "0.12" console = "0.15" ctrlc = "3" +glob = "0.3.3" [dev-dependencies] tempfile = "3" diff --git a/src/cli/commands.rs b/src/cli/commands.rs index f128499..5a4dcdd 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -5,7 +5,9 @@ use std::path::{Path, PathBuf}; use super::interactive; use crate::config::{Config, Mode}; -use crate::core::{Database, EmbeddingEngine, Indexer, SearchEngine, ServerIndexer}; +use crate::core::{ + Database, EmbeddingEngine, FileFilter, Indexer, SearchEngine, ServerIndexer, +}; use crate::server::{self, Client}; use crate::ui::{self, SearchTui}; use crate::watcher::FileWatcher; @@ -95,6 +97,22 @@ enum Commands { /// Sync/index before searching #[arg(short = 's', long)] sync: bool, + + /// Filter by file extension (comma-separated, e.g., "rs,toml") + #[arg(long)] + ext: Option, + + /// Include only files matching these glob patterns (comma-separated) + #[arg(long)] + glob: Option, + + /// Exclude files matching these glob patterns (comma-separated) + #[arg(long)] + exclude: Option, + + /// Filter by file type category (e.g., "code", "config", "docs") + #[arg(long)] + r#type: Option, }, /// Watch directory for changes and auto-index @@ -268,6 +286,7 @@ impl Cli { show_content, false, false, + None, ); } @@ -289,11 +308,18 @@ impl Cli { content, interactive, sync, + ext, + glob, + exclude, + r#type, }) => { let max_results = max_results .or(self.max_results) .unwrap_or(config.max_results); let content = content || self.content || config.show_content; + + let filter = build_file_filter(ext, glob, exclude, r#type); + run_search_smart( &config, &query, @@ -302,6 +328,7 @@ impl Cli { content, interactive, sync, + filter.as_ref(), ) } Some(Commands::Watch { path, dry_run }) => { @@ -339,6 +366,66 @@ impl Cli { } } +fn build_file_filter( + ext: Option, + glob: Option, + exclude: Option, + type_category: Option, +) -> Option { + if ext.is_none() && glob.is_none() && exclude.is_none() && type_category.is_none() { + return None; + } + + let mut extensions = Vec::new(); + let mut include_globs = Vec::new(); + let mut exclude_globs = Vec::new(); + + if let Some(e) = ext { + extensions.extend(e.split(',').map(|s| s.trim().to_string())); + } + + if let Some(g) = glob { + include_globs.extend(g.split(',').map(|s| s.trim().to_string())); + } + + if let Some(e) = exclude { + exclude_globs.extend(e.split(',').map(|s| s.trim().to_string())); + } + + if let Some(t) = type_category { + match t.as_str() { + "code" => extensions.extend( + vec![ + "rs", "py", "js", "ts", "c", "cpp", "h", "hpp", "go", "java", "rb", "php", "sh", + ] + .into_iter() + .map(String::from), + ), + "config" => extensions.extend( + vec![ + "toml", "json", "yaml", "yml", "xml", "ini", "conf", "config", "properties", + ] + .into_iter() + .map(String::from), + ), + "docs" => extensions.extend( + vec!["md", "txt", "rst", "adoc", "pdf"] + .into_iter() + .map(String::from), + ), + _ => { + // Treat unknown types as single extensions for now, or just warn + // For now, we'll assume it might be a custom type which we don't support yet, + // so we just add it as an extension to be safe? No, that's confusing. + // Let's print a warning if possible, but we are in a helper function. + // We'll just ignore unknown types for now. + } + } + } + + Some(FileFilter::new(extensions, include_globs, exclude_globs)) +} + fn print_quick_help() { ui::print_banner(); @@ -547,6 +634,7 @@ fn run_search_smart( show_content: bool, interactive: bool, sync: bool, + filter: Option<&FileFilter>, ) -> Result<()> { if sync { run_index( @@ -576,13 +664,18 @@ fn run_search_smart( max_results, show_content, interactive, + filter, ); } + // Note: Filters not yet supported in server mode + if filter.is_some() { + ui::print_warning("Filters are currently only supported in local mode. Ignoring filters."); + } run_search_server(&client, query, path, max_results, show_content) } Mode::Local => { - run_search_local(config, query, path, max_results, show_content, interactive) + run_search_local(config, query, path, max_results, show_content, interactive, filter) } } } @@ -653,6 +746,7 @@ fn run_search_local( max_results: usize, show_content: bool, interactive: bool, + filter: Option<&FileFilter>, ) -> Result<()> { use std::time::Instant; @@ -684,7 +778,7 @@ fn run_search_local( println!(); let start = Instant::now(); - let results = search.search(query, path, max_results)?; + let results = search.search(query, path, filter, max_results)?; let elapsed = start.elapsed(); if results.is_empty() { diff --git a/src/core/db.rs b/src/core/db.rs index 080e586..c54fca6 100644 --- a/src/core/db.rs +++ b/src/core/db.rs @@ -3,6 +3,8 @@ use chrono::{DateTime, Utc}; use rusqlite::{params, Connection}; use std::path::{Path, PathBuf}; +use super::filter::FileFilter; + pub struct Database { conn: Connection, } @@ -156,6 +158,7 @@ impl Database { &self, query_embedding: &[f32], path_prefix: &Path, + filter: Option<&FileFilter>, limit: usize, ) -> Result> { let path_prefix_str = path_prefix.to_string_lossy(); @@ -185,6 +188,13 @@ impl Database { }) })? .filter_map(Result::ok) + .filter(|r| { + if let Some(f) = filter { + f.matches(&r.path) + } else { + true + } + }) .collect(); // Sort by similarity (highest first) diff --git a/src/core/filter.rs b/src/core/filter.rs new file mode 100644 index 0000000..2c21cb4 --- /dev/null +++ b/src/core/filter.rs @@ -0,0 +1,156 @@ +use glob::Pattern; +use std::path::Path; + +#[derive(Debug, Clone)] +pub struct FileFilter { + pub extensions: Vec, + pub include_globs: Vec, + pub exclude_globs: Vec, +} + +impl FileFilter { + pub fn new( + extensions: Vec, + include_globs: Vec, + exclude_globs: Vec, + ) -> Self { + let include_globs = include_globs + .into_iter() + .filter_map(|s| Pattern::new(&s).ok()) + .collect(); + let exclude_globs = exclude_globs + .into_iter() + .filter_map(|s| Pattern::new(&s).ok()) + .collect(); + + Self { + extensions, + include_globs, + exclude_globs, + } + } + + pub fn matches(&self, path: &Path) -> bool { + // 1. Check extension (whitelist) + if !self.extensions.is_empty() { + match path.extension().and_then(|e| e.to_str()) { + Some(ext) => { + if !self.extensions.iter().any(|e| e.eq_ignore_ascii_case(ext)) { + return false; + } + } + None => return false, + } + } + + // 2. Check exclude patterns (blacklist) + for pattern in &self.exclude_globs { + if pattern.matches_path(path) { + return false; + } + } + + // 3. Check include patterns (whitelist) + if !self.include_globs.is_empty() { + let mut matched = false; + for pattern in &self.include_globs { + if pattern.matches_path(path) { + matched = true; + break; + } + } + if !matched { + return false; + } + } + + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_filter_extension() { + let filter = FileFilter::new( + vec!["rs".to_string(), "toml".to_string()], + vec![], + vec![], + ); + + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + assert!(filter.matches(&PathBuf::from("Cargo.toml"))); + assert!(!filter.matches(&PathBuf::from("README.md"))); + assert!(!filter.matches(&PathBuf::from("src/main"))); + } + + #[test] + fn test_filter_include_glob() { + let filter = FileFilter::new( + vec![], + vec!["src/**/*.rs".to_string()], + vec![], + ); + + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + assert!(filter.matches(&PathBuf::from("src/core/mod.rs"))); + // Note: glob matching is relative to CWD usually, but Pattern matches absolute paths too if they match the string. + // glob::Pattern matches against the string representation. + // "src/**/*.rs" will match "src/main.rs" + assert!(!filter.matches(&PathBuf::from("tests/main.rs"))); + assert!(!filter.matches(&PathBuf::from("README.md"))); + } + + #[test] + fn test_filter_exclude_glob() { + let filter = FileFilter::new( + vec![], + vec![], + vec!["target/**".to_string(), "node_modules/**".to_string()], + ); + + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + assert!(!filter.matches(&PathBuf::from("target/debug/vgrep"))); + assert!(!filter.matches(&PathBuf::from("node_modules/react/index.js"))); + } + + #[test] + fn test_filter_combined() { + let filter = FileFilter::new( + vec!["rs".to_string()], + vec!["src/**".to_string()], + vec!["**/*_test.rs".to_string()], + ); + + // Must match extension AND include glob (if present) AND NOT exclude glob + // Wait, logic implementation: + // 1. Check extension (whitelist) - if present, MUST match + // 2. Check exclude (blacklist) - if matches, return FALSE + // 3. Check include (whitelist) - if present, MUST match + + // "src/main.rs" -> ext=rs (ok), exclude= (ok), include=src/** (ok) -> true + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + + // "src/main.py" -> ext=rs (fail) -> false + assert!(!filter.matches(&PathBuf::from("src/main.py"))); + + // "tests/main.rs" -> ext=rs (ok), exclude= (ok), include=src/** (fail) -> false + assert!(!filter.matches(&PathBuf::from("tests/main.rs"))); + + // "src/my_test.rs" -> ext=rs (ok), exclude=*_test.rs (fail) -> false + assert!(!filter.matches(&PathBuf::from("src/my_test.rs"))); + } +} + +impl Default for FileFilter { + fn default() -> Self { + Self { + extensions: Vec::new(), + include_globs: Vec::new(), + exclude_globs: Vec::new(), + } + } +} diff --git a/src/core/filter_test.rs b/src/core/filter_test.rs new file mode 100644 index 0000000..47301a5 --- /dev/null +++ b/src/core/filter_test.rs @@ -0,0 +1,70 @@ +#[cfg(test)] +mod tests { + use super::*; + use glob::Pattern; + use std::path::PathBuf; + + #[test] + fn test_filter_extension() { + let filter = FileFilter::new( + vec!["rs".to_string(), "toml".to_string()], + vec![], + vec![], + ); + + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + assert!(filter.matches(&PathBuf::from("Cargo.toml"))); + assert!(!filter.matches(&PathBuf::from("README.md"))); + assert!(!filter.matches(&PathBuf::from("src/main"))); + } + + #[test] + fn test_filter_include_glob() { + let filter = FileFilter::new( + vec![], + vec!["src/**/*.rs".to_string()], + vec![], + ); + + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + assert!(filter.matches(&PathBuf::from("src/core/mod.rs"))); + assert!(!filter.matches(&PathBuf::from("tests/main.rs"))); + assert!(!filter.matches(&PathBuf::from("README.md"))); + } + + #[test] + fn test_filter_exclude_glob() { + let filter = FileFilter::new( + vec![], + vec![], + vec!["target/**".to_string(), "node_modules/**".to_string()], + ); + + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + assert!(!filter.matches(&PathBuf::from("target/debug/vgrep"))); + assert!(!filter.matches(&PathBuf::from("node_modules/react/index.js"))); + } + + #[test] + fn test_filter_combined() { + let filter = FileFilter::new( + vec!["rs".to_string()], + vec!["src/**".to_string()], + vec!["**/*_test.rs".to_string()], + ); + + // Must match extension AND include glob AND NOT exclude glob + + // "src/main.rs" -> ext=rs (ok), include=src/** (ok), exclude= (ok) -> true + assert!(filter.matches(&PathBuf::from("src/main.rs"))); + + // "src/main.py" -> ext=rs (fail) -> false + assert!(!filter.matches(&PathBuf::from("src/main.py"))); + + // "tests/main.rs" -> ext=rs (ok), include=src/** (fail) -> false + assert!(!filter.matches(&PathBuf::from("tests/main.rs"))); + + // "src/unit_test.rs" -> ext=rs (ok), include=src/** (ok), exclude=*_test.rs (fail) -> false + assert!(!filter.matches(&PathBuf::from("src/unit_test.rs"))); + } +} diff --git a/src/core/mod.rs b/src/core/mod.rs index ba223a5..090c0ae 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -2,10 +2,12 @@ mod db; mod embeddings; +pub mod filter; mod indexer; mod search; pub use db::{ChunkEntry, Database, DatabaseStats, FileEntry, SearchResult as DbSearchResult}; pub use embeddings::EmbeddingEngine; +pub use filter::FileFilter; pub use indexer::{Indexer, ServerIndexer}; pub use search::{SearchEngine, SearchResult}; diff --git a/src/core/search.rs b/src/core/search.rs index 1690ba4..b36fb8f 100644 --- a/src/core/search.rs +++ b/src/core/search.rs @@ -4,6 +4,7 @@ use std::path::{Path, PathBuf}; use super::db::{Database, SearchResult as DbSearchResult}; use super::embeddings::EmbeddingEngine; +use super::filter::FileFilter; use crate::config::Config; pub struct SearchResult { @@ -38,6 +39,7 @@ impl SearchEngine { &self, query: &str, path: &Path, + filter: Option<&FileFilter>, max_results: usize, ) -> Result> { let abs_path = std::fs::canonicalize(path).unwrap_or_else(|_| path.to_path_buf()); @@ -48,7 +50,7 @@ impl SearchEngine { // Search for similar chunks let candidates = self .db - .search_similar(&query_embedding, &abs_path, max_results * 3)?; + .search_similar(&query_embedding, &abs_path, filter, max_results * 3)?; if candidates.is_empty() { return Ok(Vec::new()); @@ -92,7 +94,7 @@ impl SearchEngine { pub fn search_interactive(&self, query: &str, max_results: usize) -> Result> { let cwd = std::env::current_dir()?; - self.search(query, &cwd, max_results) + self.search(query, &cwd, None, max_results) } pub fn embed(&self, text: &str) -> Result> { diff --git a/src/server/api.rs b/src/server/api.rs index bd5cc8a..7efe624 100644 --- a/src/server/api.rs +++ b/src/server/api.rs @@ -241,7 +241,7 @@ async fn search( } }; - let candidates = match db.search_similar(&query_embedding, &abs_path, req.max_results * 3) { + let candidates = match db.search_similar(&query_embedding, &abs_path, None, req.max_results * 3) { Ok(c) => c, Err(e) => { return (