diff --git a/src/cli/commands.rs b/src/cli/commands.rs index f128499..3e8ad7b 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -8,6 +8,7 @@ use crate::config::{Config, Mode}; use crate::core::{Database, EmbeddingEngine, Indexer, SearchEngine, ServerIndexer}; use crate::server::{self, Client}; use crate::ui::{self, SearchTui}; +use crate::utils::validation; use crate::watcher::FileWatcher; #[derive(Parser)] @@ -867,9 +868,21 @@ fn run_models(action: ModelsAction, config: &mut Config) -> Result<()> { .get("Qwen3-Embedding-0.6B-Q8_0.gguf")?; pb.finish_and_clear(); - ui::print_success(&format!("Downloaded: {}", embedding_path.display())); - - config.set_embedding_model(embedding_path.to_string_lossy().to_string())?; + + // Validate the downloaded file + println!(" Validating embedding model..."); + match validation::validate_model_file(&embedding_path, 100 * 1024 * 1024) { + Ok(_) => { + ui::print_success(&format!("Downloaded: {}", embedding_path.display())); + config.set_embedding_model(embedding_path.to_string_lossy().to_string())?; + }, + Err(e) => { + ui::print_error(&format!("Validation failed: {}", e)); + println!(" Deleting corrupted file..."); + let _ = std::fs::remove_file(&embedding_path); + return Err(e); + } + } } if !embedding_only { @@ -894,9 +907,21 @@ fn run_models(action: ModelsAction, config: &mut Config) -> Result<()> { .get("Qwen3-Reranker-0.6B-Q4_K_M.gguf")?; pb.finish_and_clear(); - ui::print_success(&format!("Downloaded: {}", reranker_path.display())); - - config.set_reranker_model(reranker_path.to_string_lossy().to_string())?; + + // Validate the downloaded file + println!(" Validating reranker model..."); + match validation::validate_model_file(&reranker_path, 100 * 1024 * 1024) { + Ok(_) => { + ui::print_success(&format!("Downloaded: {}", reranker_path.display())); + config.set_reranker_model(reranker_path.to_string_lossy().to_string())?; + }, + Err(e) => { + ui::print_error(&format!("Validation failed: {}", e)); + println!(" Deleting corrupted file..."); + let _ = std::fs::remove_file(&reranker_path); + return Err(e); + } + } } println!(); diff --git a/src/lib.rs b/src/lib.rs index 6f688f4..9df78e3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ pub mod config; pub mod core; pub mod server; pub mod ui; +pub mod utils; pub mod watcher; pub use config::Config; diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..8695201 --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod validation; diff --git a/src/utils/validation.rs b/src/utils/validation.rs new file mode 100644 index 0000000..10de938 --- /dev/null +++ b/src/utils/validation.rs @@ -0,0 +1,33 @@ +use anyhow::{Result, Context}; +use std::fs; +use std::path::Path; + +// Helper function to validate a model file +pub fn validate_model_file(path: &Path, expected_size_min: u64) -> Result<()> { + // 1. Check if file exists + if !path.exists() { + anyhow::bail!("Model file does not exist at {}", path.display()); + } + + // 2. Check file size + let metadata = fs::metadata(path)?; + if metadata.len() < expected_size_min { + anyhow::bail!( + "Model file size too small: {} bytes (expected at least {})", + metadata.len(), + expected_size_min + ); + } + + // 3. Try to verify GGUF header (first 4 bytes should be 'GGUF') + let mut file = fs::File::open(path)?; + use std::io::Read; + let mut buffer = [0u8; 4]; + file.read_exact(&mut buffer).context("Failed to read file header")?; + + if &buffer != b"GGUF" { + anyhow::bail!("Invalid file format: Header is not 'GGUF'"); + } + + Ok(()) +}