From a35764873afb0d61f2829eef9badcc688022b86f Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Fri, 15 May 2026 09:16:18 +0000 Subject: [PATCH 1/4] feat: auto-download ONNX models from ModelScope --- Cargo.toml | 4 + README.md | 8 + docs/models.md | 64 +-- examples/auto_download.rs | 68 ++++ oar-ocr-core/Cargo.toml | 8 + oar-ocr-core/src/core/download/mod.rs | 442 +++++++++++++++++++++ oar-ocr-core/src/core/download/registry.rs | 149 +++++++ oar-ocr-core/src/core/mod.rs | 2 + src/lib.rs | 12 + src/oarocr/builder_utils.rs | 27 +- src/oarocr/ocr.rs | 21 +- src/oarocr/structure.rs | 34 +- 12 files changed, 797 insertions(+), 42 deletions(-) create mode 100644 examples/auto_download.rs create mode 100644 oar-ocr-core/src/core/download/mod.rs create mode 100644 oar-ocr-core/src/core/download/registry.rs diff --git a/Cargo.toml b/Cargo.toml index 26b4df7..4cf6c62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,10 @@ openvino = ["oar-ocr-core/openvino"] # Set ORT_LIB_LOCATION to use a custom installation (skips download). # Use --no-default-features for offline/enterprise environments. download-binaries = ["oar-ocr-core/download-binaries"] +# Auto-download OCR model files at runtime from ModelScope into `$OAR_HOME` +# (default `~/.oar`). When enabled, builder methods that take a model path +# accept either a filesystem path or a bare registered file name. +auto-download = ["oar-ocr-core/auto-download"] [dependencies] # Core library with all types and predictors diff --git a/README.md b/README.md index 79f7823..8947daa 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,14 @@ With GPU support: cargo add oar-ocr --features cuda ``` +With auto-download of model files from ModelScope: + +```bash +cargo add oar-ocr --features auto-download +``` + +Bare file names passed to the builders are then fetched from [`greatv/oar-ocr` on ModelScope](https://www.modelscope.cn/models/greatv/oar-ocr) into `$OAR_HOME` (default `~/.oar`) and verified against their expected SHA-256. See [docs/models.md](docs/models.md#auto-download-via-the-auto-download-feature) for the exact path resolution rules. + ### Basic Usage ```rust diff --git a/docs/models.md b/docs/models.md index ec70c67..e6eb28a 100644 --- a/docs/models.md +++ b/docs/models.md @@ -1,6 +1,6 @@ # Pre-trained Models -OAROCR provides pre-trained models for OCR and document understanding tasks. Download them from the [GitHub Releases](https://github.com/GreatV/oar-ocr/releases) page. +OAROCR provides pre-trained models for OCR and document understanding tasks. Download them manually from the [GitHub Releases](https://github.com/GreatV/oar-ocr/releases) page (linked in the tables below), or have the library fetch them on demand from ModelScope — see [§ Auto-download](#auto-download-via-the-auto-download-feature) at the bottom. ## Text Detection Models @@ -155,41 +155,45 @@ Models for document structure analysis with `OARStructureBuilder`: | Seal Detection (Mobile) | [`pp-ocrv4_mobile_seal_det.onnx`](https://github.com/GreatV/oar-ocr/releases/download/v0.3.0/pp-ocrv4_mobile_seal_det.onnx) | 4.6MB | Fast seal detection | | Seal Detection (Server) | [`pp-ocrv4_server_seal_det.onnx`](https://github.com/GreatV/oar-ocr/releases/download/v0.3.0/pp-ocrv4_server_seal_det.onnx) | 108.2MB | Accurate seal detection | -## Recommended Configurations - -### Fast Processing (Real-time) +## Auto-download (via the `auto-download` feature) +```bash +cargo add oar-ocr --features auto-download ``` -Detection: pp-ocrv5_mobile_det.onnx -Recognition: pp-ocrv5_mobile_rec.onnx -Dictionary: ppocrv5_dict.txt + +```rust,no_run +use oar_ocr::prelude::*; +let ocr = OAROCRBuilder::new( + "pp-ocrv5_mobile_det.onnx", // bare name → resolved via registry + "pp-ocrv5_mobile_rec.onnx", + "ppocrv5_dict.txt", +).build()?; +# Ok::<(), Box>(()) ``` -### High Accuracy +When the feature is enabled, registered file names are fetched from [`greatv/oar-ocr` on ModelScope](https://www.modelscope.cn/models/greatv/oar-ocr) into `$OAR_HOME` (default `~/.oar`) and verified against the expected SHA-256 before use. Subsequent runs reuse the cached copy. The bundled registry lives at [`oar_ocr::download::REGISTRY`](../oar-ocr-core/src/core/download/registry.rs). -``` -Detection: pp-ocrv5_server_det.onnx -Recognition: pp-ocrv5_server_rec.onnx -Dictionary: ppocrv5_dict.txt -``` +### Path resolution rules -### Document Processing +For each model path argument the builder applies these checks in order: -``` -Detection: pp-ocrv4_server_det.onnx -Recognition: pp-ocrv4_server_rec_doc.onnx -Dictionary: ppocrv4_doc_dict.txt -Orientation: pp-lcnet_x1_0_doc_ori.onnx -Rectification: uvdoc.onnx -``` +1. **Existing file wins.** If the path refers to a real file on disk it is used as-is — no registry lookup, no hash check, no network. A `./pp-ocrv5_mobile_det.onnx` next to the binary always shadows the registry. +2. **Only bare names or `$OAR_HOME`-rooted paths are eligible for auto-download.** A path is considered for registry resolution only when it has no parent component (e.g. `"pp-ocrv5_mobile_det.onnx"`) or when its parent equals the cache directory. Explicit paths like `./models/foo.onnx` or `/data/foo.onnx` are returned verbatim even if their file name is registered — the library never silently overrides an explicit path. +3. **Registry hit → cache or download.** If the file name appears in `REGISTRY`: + - `$OAR_HOME/` exists with matching size + SHA-256 → cached copy is used (no network). + - Cached copy is missing or its hash mismatches → download from ModelScope, verify SHA-256, atomically replace. +4. **Unregistered + missing → returned verbatim** so the builder produces its normal "model not found" error. -### Document Structure Analysis +| Input | On disk | Behaviour | +|---|---|---| +| `"pp-ocrv5_mobile_det.onnx"` | `./pp-ocrv5_mobile_det.onnx` exists | Use the local CWD file | +| `"pp-ocrv5_mobile_det.onnx"` | `$OAR_HOME/...` exists, hash OK | Use cached copy, no network | +| `"pp-ocrv5_mobile_det.onnx"` | absent or hash mismatch | Download to `$OAR_HOME`, verify, use | +| `"./models/det.onnx"` | absent | Returned as-is → "model not found" | +| `"~/.oar/pp-ocrv5_mobile_det.onnx"` | (any) | Treated as a `$OAR_HOME` cache path; same as bare name | -``` -Layout: pp-doclayout_plus-l.onnx -Table Classification: pp-lcnet_x1_0_table_cls.onnx -Table Structure (Wired): slanext_wired.onnx -Table Structure (Wireless): slanet_plus.onnx -Table Structure Dict: table_structure_dict_ch.txt -Formula: pp-formulanet_plus-l.onnx -``` +### Cache layout + +- Override the cache root with the `OAR_HOME` environment variable. Defaults to `~/.oar`. +- Files land at `$OAR_HOME/`, flat (no per-revision subdirectories). +- Downloads stream into `$OAR_HOME/..part` and are renamed atomically once the SHA-256 matches, so a crash mid-download won't poison the cache. diff --git a/examples/auto_download.rs b/examples/auto_download.rs new file mode 100644 index 0000000..57113b7 --- /dev/null +++ b/examples/auto_download.rs @@ -0,0 +1,68 @@ +//! Minimal OCR pipeline using the `auto-download` feature. +//! +//! Demonstrates how the high-level builder transparently fetches missing model +//! files from ModelScope. Any model path that isn't an on-disk file but +//! matches an entry in [`oar_ocr::download::REGISTRY`] is downloaded into +//! `$OAR_HOME` (default `~/.oar`) and verified against its SHA-256. +//! +//! # Build +//! +//! ```bash +//! cargo run --features auto-download --example auto_download -- +//! ``` +//! +//! Without the `auto-download` feature this example refuses to compile so the +//! intent is explicit. + +#[cfg(not(feature = "auto-download"))] +fn main() { + eprintln!( + "This example requires the `auto-download` feature.\n\ + Re-run with: cargo run --features auto-download --example auto_download -- " + ); + std::process::exit(2); +} + +#[cfg(feature = "auto-download")] +fn main() -> Result<(), Box> { + use oar_ocr::oarocr::OAROCRBuilder; + use oar_ocr::utils::load_image; + use std::path::PathBuf; + use tracing_subscriber::EnvFilter; + + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) + .init(); + + let image: PathBuf = std::env::args() + .nth(1) + .map(PathBuf::from) + .ok_or("usage: auto_download ")?; + + println!("OAR cache: {}", oar_ocr::download::cache_dir().display()); + + // Bare file names → resolved through the registry on `build()`. + // The first run downloads to ~/.oar (or $OAR_HOME); subsequent runs reuse + // the cached copies after verifying their SHA-256. + let ocr = OAROCRBuilder::new( + "pp-ocrv5_mobile_det.onnx", + "pp-ocrv5_mobile_rec.onnx", + "ppocrv5_dict.txt", + ) + .with_text_line_orientation_classification("pp-lcnet_x1_0_textline_ori.onnx") + .build()?; + + let img = load_image(&image)?; + let results = ocr.predict(vec![img])?; + for (i, page) in results.iter().enumerate() { + println!("--- page {i} ---"); + for region in &page.text_regions { + if let Some(ref text) = region.text { + println!("{text}"); + } + } + } + Ok(()) +} diff --git a/oar-ocr-core/Cargo.toml b/oar-ocr-core/Cargo.toml index f722ce7..ca7c52d 100644 --- a/oar-ocr-core/Cargo.toml +++ b/oar-ocr-core/Cargo.toml @@ -29,6 +29,11 @@ webgpu = ["ort/webgpu"] openvino = ["ort/openvino"] # Auto-download ONNX Runtime binaries during build (enabled by default). download-binaries = ["ort/download-binaries", "ort/tls-native"] +# Auto-download OCR model files at runtime from ModelScope into the on-disk +# cache (`$OAR_HOME` or `~/.oar`). When enabled, model paths handed to the +# OCR builders that do not exist on disk but match the bundled registry are +# resolved by fetching the file over HTTPS and verifying its SHA-256. +auto-download = ["dep:ureq", "dep:sha2", "dep:dirs"] [dependencies] oar-ocr-derive.workspace = true @@ -47,6 +52,9 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokenizers = { version = "0.23", default-features = false, features = ["progressbar", "onig"] } clipper2-rust = "1.0.3" +ureq = { version = "3.0", default-features = false, features = ["rustls", "platform-verifier"], optional = true } +sha2 = { version = "0.10", optional = true } +dirs = { version = "5.0", optional = true } [dev-dependencies] tempfile = "3.19" diff --git a/oar-ocr-core/src/core/download/mod.rs b/oar-ocr-core/src/core/download/mod.rs new file mode 100644 index 0000000..3ce2492 --- /dev/null +++ b/oar-ocr-core/src/core/download/mod.rs @@ -0,0 +1,442 @@ +//! Auto-download of OCR model files from ModelScope (feature `auto-download`). +//! +//! When this feature is enabled, callers can pass *just a filename* (such as +//! `"pp-ocrv5_mobile_det.onnx"`) anywhere the OCR builders accept a model +//! path. If the file is not already on disk it is fetched from +//! [`greatv/oar-ocr`](https://www.modelscope.cn/models/greatv/oar-ocr), +//! its SHA-256 verified against [`registry::REGISTRY`], and cached under +//! `$OAR_HOME` (default `~/.oar`). +//! +//! ## Lookup rules +//! +//! [`resolve_path`] is the single entry point used by the builders: +//! +//! 1. If the input path exists on disk and refers to a file, it is returned +//! as-is (no hash check) — users keep full control over local files. +//! 2. Otherwise, if the file *name* matches an entry in the registry, the +//! cached copy under `$OAR_HOME` is reused (verified) or downloaded and +//! verified. +//! 3. Otherwise, the path is returned unchanged so the caller produces its +//! usual "model not found" error. +//! +//! ## Cache location +//! +//! - `$OAR_HOME` environment variable (if set), else +//! - `/.oar` +//! +//! The directory is created lazily on first download. + +use std::env; +use std::fs::{self, File}; +use std::io::{self, Read, Write}; +use std::path::{Path, PathBuf}; +use std::time::Duration; + +use crate::core::errors::OCRError; + +pub mod registry; + +pub use registry::{DEFAULT_REVISION, MODELSCOPE_REPO, REGISTRY}; + +/// A registered file mirrored to ModelScope. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Entry { + /// File name as stored in the repo and in the cache directory. + pub name: &'static str, + /// Lowercase hex-encoded SHA-256 of the file contents. + pub sha256: &'static str, + /// Size in bytes. + pub size: u64, +} + +/// Environment variable used to override the cache directory. +pub const OAR_HOME_ENV: &str = "OAR_HOME"; + +/// Default cache directory name placed under the user's home dir. +const DEFAULT_CACHE_SUBDIR: &str = ".oar"; + +const DOWNLOAD_RETRIES: u32 = 3; +const READ_BUFFER_BYTES: usize = 64 * 1024; +/// Whole-request timeout for a single download attempt. Sized to allow the +/// largest registered file (~1.8 GB) to complete on slow links (~1 MB/s); +/// the retry loop applies this per attempt. +const REQUEST_TIMEOUT_SECS: u64 = 30 * 60; +const CONNECT_TIMEOUT_SECS: u64 = 30; + +/// Look up an entry by file name. Returns `None` if the file isn't registered. +pub fn find(name: &str) -> Option<&'static Entry> { + REGISTRY + .binary_search_by_key(&name, |e| e.name) + .ok() + .map(|idx| ®ISTRY[idx]) +} + +/// Returns the cache directory used to store auto-downloaded models. +/// +/// Resolution order: +/// 1. `$OAR_HOME`, if set and non-empty. +/// 2. `/.oar` where `` is the platform user home directory. +/// +/// Falls back to `./.oar` (relative to the current working directory) when +/// the user home cannot be determined — this matches what the repository's +/// existing examples already use. +pub fn cache_dir() -> PathBuf { + if let Some(dir) = env::var_os(OAR_HOME_ENV) { + let dir = PathBuf::from(dir); + if !dir.as_os_str().is_empty() { + return dir; + } + } + dirs::home_dir() + .map(|h| h.join(DEFAULT_CACHE_SUBDIR)) + .unwrap_or_else(|| PathBuf::from(DEFAULT_CACHE_SUBDIR)) +} + +/// Resolve a user-supplied path through the auto-download cache. +/// +/// See module docs for the lookup rules. Returns the resolved path (either +/// the original `input`, an existing cache entry, or a freshly downloaded +/// file). +pub fn resolve_path(input: impl AsRef) -> Result { + let input = input.as_ref(); + + // Rule 1: trust paths that already exist on disk. + if input.is_file() { + return Ok(input.to_path_buf()); + } + + // Rule 2: registry lookup keyed on the file name component. + let name = match input.file_name().and_then(|s| s.to_str()) { + Some(n) => n, + None => return Ok(input.to_path_buf()), + }; + + // Only match when the user gave just a filename (no parent component) or + // when the parent is the configured cache directory. This avoids quietly + // overriding a user's explicit path like `./models/pp-ocrv4_mobile_det.onnx`. + let parent_is_bare = input.parent().is_none_or(|p| p.as_os_str().is_empty()); + let cache = cache_dir(); + let parent_is_cache = input.parent() == Some(cache.as_path()); + + if !(parent_is_bare || parent_is_cache) { + return Ok(input.to_path_buf()); + } + + if let Some(entry) = find(name) { + return fetch_entry(entry); + } + + Ok(input.to_path_buf()) +} + +/// Fetch a registered file by name, returning its path in the cache. +/// +/// If the file is already present and its SHA-256 matches, no network access +/// occurs; otherwise it is downloaded from ModelScope and verified. +pub fn fetch(name: &str) -> Result { + let entry = find(name).ok_or_else(|| OCRError::ConfigError { + message: format!( + "model file `{}` is not registered for auto-download. \ + Pass an explicit path or add an entry to oar_ocr_core::core::download::registry::REGISTRY.", + name + ), + })?; + fetch_entry(entry) +} + +fn fetch_entry(entry: &Entry) -> Result { + let dir = cache_dir(); + if let Err(e) = fs::create_dir_all(&dir) { + return Err(OCRError::Io(io::Error::new( + e.kind(), + format!( + "failed to create OAR cache directory `{}`: {}", + dir.display(), + e + ), + ))); + } + let target = dir.join(entry.name); + + if cached_file_matches(&target, entry)? { + return Ok(target); + } + + download_and_verify(entry, &target)?; + Ok(target) +} + +fn cached_file_matches(path: &Path, entry: &Entry) -> Result { + let meta = match fs::metadata(path) { + Ok(m) => m, + Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(false), + Err(e) => return Err(io_with_context(e, format!("stat `{}`", path.display()))), + }; + if !meta.is_file() { + return Ok(false); + } + if meta.len() != entry.size { + tracing::warn!( + path = %path.display(), + expected_size = entry.size, + actual_size = meta.len(), + "cached model has wrong size; redownloading" + ); + return Ok(false); + } + match sha256_file(path) { + Ok(hash) if hash == entry.sha256 => Ok(true), + Ok(hash) => { + tracing::warn!( + path = %path.display(), + expected = entry.sha256, + actual = %hash, + "cached model sha256 mismatch; redownloading" + ); + Ok(false) + } + Err(e) => { + tracing::warn!( + path = %path.display(), + error = %e, + "failed to hash cached model; redownloading" + ); + Ok(false) + } + } +} + +fn download_and_verify(entry: &Entry, target: &Path) -> Result<(), OCRError> { + let url = format!( + "https://www.modelscope.cn/api/v1/models/{}/repo?Revision={}&FilePath={}", + MODELSCOPE_REPO, DEFAULT_REVISION, entry.name, + ); + + let mut last_err: Option = None; + for attempt in 1..=DOWNLOAD_RETRIES { + tracing::info!( + file = entry.name, + size = entry.size, + attempt, + "downloading from ModelScope" + ); + match download_attempt(&url, entry, target) { + Ok(()) => return Ok(()), + Err(e) => { + tracing::warn!( + file = entry.name, + attempt, + error = %e, + "download attempt failed", + ); + last_err = Some(e); + } + } + } + Err(last_err.unwrap_or_else(|| OCRError::ConfigError { + message: format!("download of `{}` failed after retries", entry.name), + })) +} + +fn download_attempt(url: &str, entry: &Entry, target: &Path) -> Result<(), OCRError> { + let agent = ureq::Agent::config_builder() + .timeout_global(Some(Duration::from_secs(REQUEST_TIMEOUT_SECS))) + .timeout_connect(Some(Duration::from_secs(CONNECT_TIMEOUT_SECS))) + .build() + .new_agent(); + + let response = agent + .get(url) + .call() + .map_err(|e| network_error(format!("GET {}", url), e))?; + + let mut body = response.into_body().into_reader(); + + let parent = target.parent().unwrap_or_else(|| Path::new(".")); + let tmp = target.with_file_name(format!(".{}.part", entry.name)); + let mut file = File::create(&tmp).map_err(|e| { + io_with_context( + e, + format!( + "create temp download `{}` in `{}`", + tmp.display(), + parent.display() + ), + ) + })?; + + let mut hasher = ::new(); + let mut buf = vec![0u8; READ_BUFFER_BYTES]; + let mut written: u64 = 0; + loop { + let n = body + .read(&mut buf) + .map_err(|e| io_with_context(e, format!("read body for `{}`", entry.name)))?; + if n == 0 { + break; + } + sha2::Digest::update(&mut hasher, &buf[..n]); + file.write_all(&buf[..n]) + .map_err(|e| io_with_context(e, format!("write `{}`", tmp.display())))?; + written += n as u64; + } + file.sync_all() + .map_err(|e| io_with_context(e, format!("sync `{}`", tmp.display())))?; + drop(file); + + if written != entry.size { + let _ = fs::remove_file(&tmp); + return Err(OCRError::ConfigError { + message: format!( + "downloaded `{}` is {} bytes but the registry expects {}", + entry.name, written, entry.size + ), + }); + } + + let actual_hash = encode_hex(&sha2::Digest::finalize(hasher)); + if actual_hash != entry.sha256 { + let _ = fs::remove_file(&tmp); + return Err(OCRError::ConfigError { + message: format!( + "sha256 mismatch for `{}`: expected {}, got {}", + entry.name, entry.sha256, actual_hash + ), + }); + } + + fs::rename(&tmp, target).map_err(|e| { + io_with_context( + e, + format!("move `{}` -> `{}`", tmp.display(), target.display()), + ) + })?; + Ok(()) +} + +fn sha256_file(path: &Path) -> io::Result { + let mut file = File::open(path)?; + let mut hasher = ::new(); + let mut buf = vec![0u8; READ_BUFFER_BYTES]; + loop { + let n = file.read(&mut buf)?; + if n == 0 { + break; + } + sha2::Digest::update(&mut hasher, &buf[..n]); + } + Ok(encode_hex(&sha2::Digest::finalize(hasher))) +} + +fn encode_hex(bytes: &[u8]) -> String { + const HEX: &[u8; 16] = b"0123456789abcdef"; + let mut out = String::with_capacity(bytes.len() * 2); + for &b in bytes { + out.push(HEX[(b >> 4) as usize] as char); + out.push(HEX[(b & 0xf) as usize] as char); + } + out +} + +fn io_with_context(e: io::Error, ctx: String) -> OCRError { + OCRError::Io(io::Error::new(e.kind(), format!("{ctx}: {e}"))) +} + +fn network_error(ctx: String, err: ureq::Error) -> OCRError { + OCRError::Io(io::Error::other(format!("{ctx}: {err}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn find_known_entry() { + let entry = find("ppocrv5_en_dict.txt").expect("registered"); + assert_eq!(entry.size, 1416); + } + + #[test] + fn find_unknown_entry_returns_none() { + assert!(find("does-not-exist.onnx").is_none()); + } + + #[test] + fn resolve_existing_file_returns_input() { + let dir = tempfile::tempdir().unwrap(); + let f = dir.path().join("local.onnx"); + std::fs::write(&f, b"hi").unwrap(); + let resolved = resolve_path(&f).unwrap(); + assert_eq!(resolved, f); + } + + #[test] + fn resolve_explicit_path_passthrough_for_unknown() { + // A nested path that doesn't exist and isn't registered must be + // returned verbatim so the caller's normal error fires. + let p = PathBuf::from("/nonexistent/dir/some_random_model.onnx"); + let resolved = resolve_path(&p).unwrap(); + assert_eq!(resolved, p); + } + + #[test] + fn resolve_bare_name_unknown_does_not_consult_network() { + // No registry hit, no existing file → returned as-is. + let p = PathBuf::from("not-in-registry.onnx"); + let resolved = resolve_path(&p).unwrap(); + assert_eq!(resolved, p); + } + + #[test] + fn cache_dir_honours_env_override() { + let dir = tempfile::tempdir().unwrap(); + unsafe { + std::env::set_var(OAR_HOME_ENV, dir.path()); + } + assert_eq!(cache_dir(), dir.path()); + unsafe { + std::env::remove_var(OAR_HOME_ENV); + } + } + + #[test] + fn cached_file_matches_detects_size_and_hash_mismatch() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("dummy.bin"); + + // Fake registry entry pinned to known SHA-256 of "hello\n" (6 bytes). + let entry = Entry { + name: "dummy.bin", + // sha256 of b"hello\n" + sha256: "5891b5b522d5df086d0ff0b110fbd9d21bb4fc7163af34d08286a2e846f6be03", + size: 6, + }; + + // Missing file → no match (no network needed). + assert!(!cached_file_matches(&path, &entry).unwrap()); + + // Correct contents → match. + std::fs::write(&path, b"hello\n").unwrap(); + assert!(cached_file_matches(&path, &entry).unwrap()); + + // Same hash but wrong size when we lie about expected size → no match. + let mismatched_size = Entry { size: 99, ..entry }; + assert!(!cached_file_matches(&path, &mismatched_size).unwrap()); + + // Wrong contents → no match (hash differs). + std::fs::write(&path, b"world!").unwrap(); + let same_len = Entry { size: 6, ..entry }; + assert!(!cached_file_matches(&path, &same_len).unwrap()); + } + + #[test] + fn fetch_unregistered_name_returns_config_error() { + let err = fetch("does-not-exist.onnx").unwrap_err(); + match err { + OCRError::ConfigError { message } => { + assert!(message.contains("not registered")); + } + other => panic!("expected ConfigError, got {other:?}"), + } + } +} diff --git a/oar-ocr-core/src/core/download/registry.rs b/oar-ocr-core/src/core/download/registry.rs new file mode 100644 index 0000000..c44a0c6 --- /dev/null +++ b/oar-ocr-core/src/core/download/registry.rs @@ -0,0 +1,149 @@ +//! Static registry of model files hosted on ModelScope under +//! [`greatv/oar-ocr`](https://www.modelscope.cn/models/greatv/oar-ocr). +//! +//! The entries are sorted by name to allow a binary search lookup. +//! This file is intended to be regenerated as new models are uploaded. + +use super::Entry; + +/// ModelScope owner/name of the repository hosting the registered files. +pub const MODELSCOPE_REPO: &str = "greatv/oar-ocr"; + +/// Default revision (branch/tag) to fetch from ModelScope. +pub const DEFAULT_REVISION: &str = "master"; + +/// Sorted-by-name registry of files mirrored on ModelScope. +/// +/// Adding an entry: upload the file with `modelscope upload greatv/oar-ocr`, +/// compute its SHA-256 (`sha256sum`) and byte size (`stat -c %s`), and insert +/// in name order. Keep the slice sorted so [`super::find`] can binary-search. +#[rustfmt::skip] +pub static REGISTRY: &[Entry] = &[ + Entry { name: "arabic_pp-ocrv3_mobile_rec.onnx", sha256: "7012a52bfe7ed9910bef1c74e295d8c3456175aa9b4e9015271892ced559687a", size: 8995821 }, + Entry { name: "arabic_pp-ocrv5_mobile_rec.onnx", sha256: "2768206d9a0ce48eba45b59619184e18161dde8f44115f029920ca17a9dc0384", size: 8026538 }, + Entry { name: "ch_repsvtr_rec.onnx", sha256: "6e7d40f0e3c4c16443c9efe7265bb9788ee3461fd6f26b41990524fbaae7ec5d", size: 25371710 }, + Entry { name: "ch_svtrv2_rec.onnx", sha256: "3fadaeecebd49d4df4f96155875be393e66161befc26258d3e62ee9968efd648", size: 84196641 }, + Entry { name: "chinese_cht_pp-ocrv3_mobile_rec.onnx", sha256: "535ec8a5d0207f34b4cc6cca28c0903991df0ec6ecb8bd7eea2f16f4468afc2b", size: 11143425 }, + Entry { name: "cyrillic_pp-ocrv3_mobile_rec.onnx", sha256: "6ab2b46cee27755f82cacd86a73706f00146f1938aa5c74549a4fb2d1f94ae9c", size: 8996341 }, + Entry { name: "cyrillic_pp-ocrv5_mobile_rec.onnx", sha256: "a18d96d7c8d73d90f2ed056549caa1de3a8e6cb744cccba16cd593ea8cd2d569", size: 8076390 }, + Entry { name: "devanagari_pp-ocrv3_mobile_rec.onnx", sha256: "97bc713646ae30442d536b47c3f0d65ad249dbc8006a7beff66825f7df691405", size: 8997381 }, + Entry { name: "devanagari_pp-ocrv5_mobile_rec.onnx", sha256: "b3d50774dfbec6ae02249ff79a925431a4381c8c6f86d342ff6e7b63e5fefa77", size: 7939902 }, + Entry { name: "el_pp-ocrv5_mobile_rec.onnx", sha256: "5a4a020e48e8783e035e1af135423c2161a363acab9ef16e48238c3d181f0f71", size: 7836326 }, + Entry { name: "en_pp-ocrv3_mobile_rec.onnx", sha256: "dcb188df82c426f251283fd4ef4ea57c039520ce6ff8355ba5aa9b2535073c33", size: 8978654 }, + Entry { name: "en_pp-ocrv4_mobile_rec.onnx", sha256: "40c07c5e431a4c59d7b5a1fefdba2fddb962c939d626c4dbf1d32965ab533431", size: 7710963 }, + Entry { name: "en_pp-ocrv5_mobile_rec.onnx", sha256: "8307465d3c9ef2ba4055c3bd0be55aafe11f518630212b7598b70ccb376028ac", size: 7876014 }, + Entry { name: "eslav_pp-ocrv5_mobile_rec.onnx", sha256: "36a66a68097e88b103e0f60f489e88c7239d3ea79d96fbac2d80ac9d134944cd", size: 7915218 }, + Entry { name: "japan_pp-ocrv3_mobile_rec.onnx", sha256: "da59d4dbb6786a92a3823c32b3f48179d18e29903d61462842aad9eb422a77a1", size: 10097703 }, + Entry { name: "ka_pp-ocrv3_mobile_rec.onnx", sha256: "e9592deee670c7ae3b21a7f6ad5d73f080628a30d8f00d26c668cfa3b493de04", size: 8993741 }, + Entry { name: "korean_pp-ocrv3_mobile_rec.onnx", sha256: "d30dbf20502044dc0e697f047564ac015e7083e4766fdc9d0fcd225aa2b0f20d", size: 9912841 }, + Entry { name: "korean_pp-ocrv5_mobile_rec.onnx", sha256: "2d7ed96308065a86103325d22af07a88c4d06afc009f21602a4882342c0cc054", size: 13446374 }, + Entry { name: "latex_ocr_rec.onnx", sha256: "b4714a09ab4b5049ef5d404b4f8212e17ee03d634756ad6ba3ad380e91745613", size: 102499156 }, + Entry { name: "latin_pp-ocrv3_mobile_rec.onnx", sha256: "e73a1fc3853b36fe99d7990858bbc9630346706db2e08738ffd352cf789de7ef", size: 9002061 }, + Entry { name: "latin_pp-ocrv5_mobile_rec.onnx", sha256: "e3a6bfeea1c8a01d6fccfd480a0bd363fd907f8c65931e228bb2736f5c3e142f", size: 8069614 }, + Entry { name: "p2o_pp-lcnet_x0_25_textline_ori.onnx", sha256: "44fdaeabcd95861fcdf8a31f8ecebf885f72ce50dac94989603a3bc60eacde54", size: 1000746 }, + Entry { name: "picodet-l_layout_17cls.onnx", sha256: "bf16cde3c9d0fe160ef74d7f9143f67f4c85fb0a6afe3923942ebe8e8854e734", size: 23481268 }, + Entry { name: "picodet-l_layout_3cls.onnx", sha256: "2112c55bb86aa59dc6bf4d71b75266e5b07fdb663eb295967086cd657d870001", size: 23445204 }, + Entry { name: "picodet-s_layout_17cls.onnx", sha256: "94da097a087d039a87575a3f1172b19265a74b40781aaf6727117a2205b4cb6d", size: 4905603 }, + Entry { name: "picodet-s_layout_3cls.onnx", sha256: "77895ad1aff11d7b4fd94d58454db0a848b90beabfac34a74feaad3956a277ca", size: 4883874 }, + Entry { name: "picodet_layout_1x.onnx", sha256: "5a95d6a17380cc5b146f515548679046708fd18d8caa1cafecd98a80b6252523", size: 7522719 }, + Entry { name: "picodet_layout_1x_table.onnx", sha256: "e62882aa0eedd7aa417e7a1ef0042f1a106a2492f32c2fed492e9b3145c2d1ba", size: 7514462 }, + Entry { name: "pp-docblocklayout.onnx", sha256: "4f2b7465a9ca1e8519848573544e1ee108ead67f8b35958b77a293b02eca44cd", size: 129331821 }, + Entry { name: "pp-doclayout-l.onnx", sha256: "094fef666d9785d001238d0b93a88c2c365b059d211d7e1e494ccd2419f3b3c1", size: 129377057 }, + Entry { name: "pp-doclayout-m.onnx", sha256: "8e458bfc919bbf7a35be9802485b5cd30151cb356364cfad09911d2ee1fc1f76", size: 23496727 }, + Entry { name: "pp-doclayout-s.onnx", sha256: "c2336493a0a13cd9b9b457ca68aea370b327c362a4a7da4917c2bba96029bceb", size: 4914918 }, + Entry { name: "pp-doclayout_plus-l.onnx", sha256: "b06cedc7ab3cca7da4ed66cf16024732149d2c29e6adcbfc69b9bb6ef94b4a48", size: 129714689 }, + Entry { name: "pp-doclayoutv2.onnx", sha256: "a325532df1c7530538ef4e8254695c091adc6afd3366c0851425491f0816d1d1", size: 213969379 }, + Entry { name: "pp-doclayoutv3.onnx", sha256: "1a7ec3812d239ad14debb87e38273012558fea27ac10b44b61260e3a88358e39", size: 129955811 }, + Entry { name: "pp-formulanet-l.onnx", sha256: "e408f1d4e6d67c694a2c8f75dfd17d2e2668b1191c3c68d029d445b133be4bb5", size: 730379948 }, + Entry { name: "pp-formulanet-s.onnx", sha256: "0ee32c7bfbd9e586364f89f71860476ccb5334e35674a61f3df5e0553d6a6dcc", size: 231878904 }, + Entry { name: "pp-formulanet-tokenizer.json", sha256: "2811d82701ec97c192fa256aa2b4516929373870ae660326cc5b1dc879b95ff2", size: 2140014 }, + Entry { name: "pp-formulanet_plus-l.onnx", sha256: "b4924d69c731365048de3d11a5d1829f3dfd8b98b4dbfd82437f934c2611934f", size: 733525676 }, + Entry { name: "pp-formulanet_plus-m.onnx", sha256: "9e3539c2b4eeed28f2d35e342fd5bb0bdaa7f6034a475fc7e890c92780910618", size: 592372919 }, + Entry { name: "pp-formulanet_plus-s.onnx", sha256: "449d205c8fb2fe0a9b134a5e4a0f2421c2e7812fd902ea67dfda4e9ef4588978", size: 231878904 }, + Entry { name: "pp-lcnet_x0_25_textline_ori.onnx", sha256: "fb402220f39b183d64a68cd48d4bd53267a21354b6fc39370c2a83fbdad85b10", size: 1018629 }, + Entry { name: "pp-lcnet_x1_0_doc_ori.onnx", sha256: "bbcd6c2b43ab15d2e605455aef2cd280ab87b570824522d24a69e9298875a1ac", size: 6787248 }, + Entry { name: "pp-lcnet_x1_0_table_cls.onnx", sha256: "61ed75151cadba903ec5182f1ffc59e961e52de501c61c5ffeb466346fc65040", size: 6776998 }, + Entry { name: "pp-lcnet_x1_0_textline_ori.onnx", sha256: "6b02efabbedd6be69e3de4c86b8dceed2d7329e75c12a796e6717bfb0d646950", size: 6776997 }, + Entry { name: "pp-ocrv3_mobile_rec.onnx", sha256: "8febeeba4792aed934be20ffffa6f050d717da059084cc294d4a88ee35130599", size: 10675943 }, + Entry { name: "pp-ocrv4_mobile_det.onnx", sha256: "ab2a50dcd2c340852f2d0fbfa547d5eec79a0d04a774eb0b622d96d0d9d2ceeb", size: 4826518 }, + Entry { name: "pp-ocrv4_mobile_rec.onnx", sha256: "5d54b59bac0f49d4561f0462630d8a6822b5b495db064f58096fd3d2392fbc4e", size: 10870526 }, + Entry { name: "pp-ocrv4_mobile_seal_det.onnx", sha256: "e6109a1022b5ebf0822fc00646ef2398a7ef387390ca5c978de79352b1314204", size: 4826518 }, + Entry { name: "pp-ocrv4_server_det.onnx", sha256: "5b676249ca4d1653675b249f134cb483ada721ac70d8013ddab09db7bcf26c1f", size: 113442336 }, + Entry { name: "pp-ocrv4_server_rec.onnx", sha256: "70939bcaabb8700dd9627ab7a38acbbbb8eea589cf27aef565f2343921a502c8", size: 90538610 }, + Entry { name: "pp-ocrv4_server_rec_doc.onnx", sha256: "1c64b0b01d5e03b931608ee366efeca868a6fd1b8015bd8f4f9a1fff43708ae3", size: 94897514 }, + Entry { name: "pp-ocrv4_server_seal_det.onnx", sha256: "8fc8b257e3841144c23b2d75b35cd95d82abefa343b6ac16f615f96c848e2357", size: 113442336 }, + Entry { name: "pp-ocrv5_mobile_det.onnx", sha256: "1eb7b4f7ab657ebd1c66d5f79bca7497f29768a2e3c15e52daecbba1a8e4a039", size: 4826518 }, + Entry { name: "pp-ocrv5_mobile_rec.onnx", sha256: "243a0f06d826761323e9045e9b113ab2c191c3aa50565585e628300b8eda0224", size: 16562373 }, + Entry { name: "pp-ocrv5_server_det.onnx", sha256: "9a910baffbefb807ff2f7bfaa72910e3e470bd17014d798386d87bb46f442839", size: 88116836 }, + Entry { name: "pp-ocrv5_server_rec.onnx", sha256: "4bfffad2c62eb1340250455856978fb9fb19cb4776b264ae3c2f91c35fbb40b4", size: 84502992 }, + Entry { name: "ppocr_keys_v1.txt", sha256: "a1c84d9bdb9ab29043c58896224d32941783eb821629618416dcb08f12886492", size: 26250 }, + Entry { name: "ppocrv4_doc_dict.txt", sha256: "a5bc3887c43c901e5a3f97b13ffadf1c5754ede7cc8c9f5abe22e875a7c48372", size: 62346 }, + Entry { name: "ppocrv5_arabic_dict.txt", sha256: "7f92f7dbb9b75a4787a83bfb4f6d14a8ab515525130c9d40a9036f61cf6999e9", size: 2369 }, + Entry { name: "ppocrv5_cyrillic_dict.txt", sha256: "db40aa52ceb112055be80c694afdf655d5d2c4f7873704524cc16a447ca913ba", size: 2781 }, + Entry { name: "ppocrv5_devanagari_dict.txt", sha256: "09c7440bfc5477e5c41052304b6b185aff8c4a5e8b2b4c23c1c706f6fe1ee9fc", size: 1943 }, + Entry { name: "ppocrv5_dict.txt", sha256: "d1979e9f794c464c0d2e0b70a7fe14dd978e9dc644c0e71f14158cdf8342af1b", size: 74012 }, + Entry { name: "ppocrv5_el_dict.txt", sha256: "31defc62c0c3ad3674a82da6192226a2ba98ef4ff014a7045cb88d59f9c3de31", size: 1103 }, + Entry { name: "ppocrv5_en_dict.txt", sha256: "e025a66d31f327ba0c232e03f407ae8d105e1e709e7ccb3f408aa778c24e70d6", size: 1416 }, + Entry { name: "ppocrv5_eslav_dict.txt", sha256: "3e95f1581557162870cacdba5af91a4c6be2890710d395b0c3c7578e7ee5e6eb", size: 1663 }, + Entry { name: "ppocrv5_korean_dict.txt", sha256: "a88071c68c01707489baa79ebe0405b7beb5cca229f4fc94cc3ef992328802d7", size: 47451 }, + Entry { name: "ppocrv5_latin_dict.txt", sha256: "ccbcc45730b3fbbd9050c5bc74db6a99067141ef1035e3d14889a84a6b9b1aff", size: 2616 }, + Entry { name: "ppocrv5_ta_dict.txt", sha256: "85b541352ae18dc6ba6d47152d8bf8adff6b0266e605d2eef2990c1bf466117b", size: 1723 }, + Entry { name: "ppocrv5_te_dict.txt", sha256: "42f83f5d3fdb50778e4fa5b66c58d99a59ab7792151c5e74f34b8ffd7b61c9d6", size: 1831 }, + Entry { name: "ppocrv5_th_dict.txt", sha256: "57f5406f94bb6688fb7077f7be65f08bbd71cecf48c01ea26c522cb5c4836b7a", size: 1767 }, + Entry { name: "rt-detr-h_layout_17cls.onnx", sha256: "079173c137540a2a56598d872e408646af12aa537140c1dc246592af4e7f9b95", size: 492056102 }, + Entry { name: "rt-detr-h_layout_3cls.onnx", sha256: "bce52ce49762f77213b2dd40ab5901c504e809880cc3e684f97e472f2a3303aa", size: 492027314 }, + Entry { name: "rt-detr-l_wired_table_cell_det.onnx", sha256: "238dfece5c48d926a3ebac07341eb197f35038f5f5ca79dc6f75fa9686853f6d", size: 129331821 }, + Entry { name: "rt-detr-l_wireless_table_cell_det.onnx", sha256: "3b373ba8467403956e2f043bfc00fba8a147fcb18c6988b898776d5cd523f520", size: 129331821 }, + Entry { name: "slanet.onnx", sha256: "ebb506f2af6ba26502bb857b6f82a06af12c5231a1c52146a473b2c90205df3b", size: 7782138 }, + Entry { name: "slanet_plus.onnx", sha256: "3a96a71719247c5d94992fca31266b598c54740388de371f0c75077e2a9e0b55", size: 7782138 }, + Entry { name: "slanet_plus_v2.onnx", sha256: "e0bff8da087f9b83629f1e1a6e0f8252fc2de85a7d80415b3510fc521338da3d", size: 7781255 }, + Entry { name: "slanext_wired.onnx", sha256: "0d1efd752685f42271326eeca93f321fc6ba6d6f75ff491f31f40556dcecc4af", size: 367743373 }, + Entry { name: "slanext_wireless.onnx", sha256: "9bc8f145da44766c11acef4e436a58da8fb192ef50dc7ffc2d7fcdf82ae66419", size: 367743373 }, + Entry { name: "ta_pp-ocrv3_mobile_rec.onnx", sha256: "de98658698cf72be6f299f04ed78032489ab502553dab31b13f49f22dab2d62f", size: 8987241 }, + Entry { name: "ta_pp-ocrv5_mobile_rec.onnx", sha256: "508d07ac0e1806a8b8857ebf20bd8837d68d962c3fdba030cf0022238ba819b8", size: 7913282 }, + Entry { name: "table_structure_dict_ch.txt", sha256: "68d344a84b726e043f390122240ff2b2ced2949b2a80ce9b61ae955054d190ef", size: 578 }, + Entry { name: "te_pp-ocrv3_mobile_rec.onnx", sha256: "5a8ba806f4fecac40bbaa468e1885530eaea483f9c63b01dbe93acebc342caa0", size: 8993221 }, + Entry { name: "te_pp-ocrv5_mobile_rec.onnx", sha256: "0957fc03425324d30afcd4342a5708687fd0c885a5c0987cae418e96dd761a60", size: 7926350 }, + Entry { name: "th_pp-ocrv5_mobile_rec.onnx", sha256: "5f6ee21242691681261fee01bc39867da9cc8ff9b889f2f048b3cb7f74380217", size: 7918606 }, + Entry { name: "unimernet.onnx", sha256: "1d64fafa0161f153dafe40823e97c4b05103030509dd6b286d7c8d4a11b068ab", size: 1842024100 }, + Entry { name: "unimernet_tokenizer.json", sha256: "2811d82701ec97c192fa256aa2b4516929373870ae660326cc5b1dc879b95ff2", size: 2140014 }, + Entry { name: "unimernet_tokenizer_config.json", sha256: "fd4d94f8b9dbb7deeb3a3ef084ca0e16c43d45774a180730b7e9cfd6359a074b", size: 4491 }, + Entry { name: "uvdoc.onnx", sha256: "1092557894d49644e7858b293df6cb9c873d53e51319b91a8614ca9c71686dc0", size: 31684150 }, +]; + +#[cfg(test)] +mod tests { + use super::REGISTRY; + + #[test] + fn registry_is_sorted_and_unique() { + for pair in REGISTRY.windows(2) { + assert!( + pair[0].name < pair[1].name, + "registry must be sorted and unique, but `{}` ≥ `{}`", + pair[0].name, + pair[1].name, + ); + } + } + + #[test] + fn registry_hashes_are_64_lowercase_hex() { + for entry in REGISTRY { + assert_eq!( + entry.sha256.len(), + 64, + "{} has non-64-char sha256", + entry.name + ); + assert!( + entry + .sha256 + .chars() + .all(|c| matches!(c, '0'..='9' | 'a'..='f')), + "{} sha256 must be lowercase hex", + entry.name + ); + } + } +} diff --git a/oar-ocr-core/src/core/mod.rs b/oar-ocr-core/src/core/mod.rs index 0db872e..92e35af 100644 --- a/oar-ocr-core/src/core/mod.rs +++ b/oar-ocr-core/src/core/mod.rs @@ -14,6 +14,8 @@ pub mod batch; pub mod config; pub mod constants; +#[cfg(feature = "auto-download")] +pub mod download; pub mod errors; pub mod image_reader; pub mod inference; diff --git a/src/lib.rs b/src/lib.rs index 1110be1..c646733 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -104,6 +104,18 @@ pub mod core { pub use oar_ocr_core::core::*; } +/// Auto-download of model files from ModelScope. +/// +/// Available only when the `auto-download` feature is enabled. See +/// [`oar_ocr_core::core::download`] for details. When the feature is on, +/// the high-level OCR builders accept either a filesystem path or a bare +/// registered file name (e.g. `"pp-ocrv5_mobile_det.onnx"`) for any model +/// path argument. +#[cfg(feature = "auto-download")] +pub mod download { + pub use oar_ocr_core::core::download::*; +} + pub mod domain { pub use oar_ocr_core::domain::*; } diff --git a/src/oarocr/builder_utils.rs b/src/oarocr/builder_utils.rs index 63b4c39..1473bd9 100644 --- a/src/oarocr/builder_utils.rs +++ b/src/oarocr/builder_utils.rs @@ -4,7 +4,29 @@ use oar_ocr_core::core::OCRError; use oar_ocr_core::core::config::OrtSessionConfig; use oar_ocr_core::core::traits::OrtConfigurable; use oar_ocr_core::core::traits::adapter::AdapterBuilder; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; + +/// Resolves a model path through the auto-download cache when the +/// `auto-download` feature is enabled. +/// +/// Behaviour: +/// +/// - With `auto-download` on, delegates to +/// [`oar_ocr_core::core::download::resolve_path`]. Bare file names that +/// match the registry are fetched from ModelScope and verified against +/// the expected SHA-256; on-disk paths are returned unchanged. +/// - Without the feature, returns the input verbatim. The caller's normal +/// error path produces the usual "model not found" message. +pub fn resolve_model_path(path: &Path) -> Result { + #[cfg(feature = "auto-download")] + { + oar_ocr_core::core::download::resolve_path(path) + } + #[cfg(not(feature = "auto-download"))] + { + Ok(path.to_path_buf()) + } +} /// Builds an optional adapter from a model path using a builder factory. /// @@ -37,11 +59,12 @@ where let Some(model) = model_path else { return Ok(None); }; + let resolved = resolve_model_path(model)?; let mut builder = create_builder(); if let Some(config) = ort_config { builder = builder.with_ort_config(config.clone()); } - Ok(Some(builder.build(model)?)) + Ok(Some(builder.build(&resolved)?)) } diff --git a/src/oarocr/ocr.rs b/src/oarocr/ocr.rs index 14b182f..08903cf 100644 --- a/src/oarocr/ocr.rs +++ b/src/oarocr/ocr.rs @@ -4,7 +4,7 @@ //! It simplifies the process of configuring text detection, recognition, and optional //! preprocessing components. -use super::builder_utils::build_optional_adapter; +use super::builder_utils::{build_optional_adapter, resolve_model_path}; use oar_ocr_core::core::config::OrtSessionConfig; use oar_ocr_core::core::constants::DEFAULT_REC_IMAGE_SHAPE; use oar_ocr_core::core::errors::OCRError; @@ -232,16 +232,21 @@ impl OAROCRBuilder { Self::validate_batch_size("region_batch_size", size)?; } + // Resolve required model paths through the auto-download cache when + // the feature is enabled. With the feature off these are no-ops. + let text_detection_model = resolve_model_path(&self.text_detection_model)?; + let text_recognition_model = resolve_model_path(&self.text_recognition_model)?; + let character_dict_path = resolve_model_path(&self.character_dict_path)?; + // Load character dictionary for text recognition - let char_dict = std::fs::read_to_string(&self.character_dict_path).map_err(|e| { - OCRError::InvalidInput { + let char_dict = + std::fs::read_to_string(&character_dict_path).map_err(|e| OCRError::InvalidInput { message: format!( "Failed to read character dictionary from '{}': {}", - self.character_dict_path.display(), + character_dict_path.display(), e ), - } - })?; + })?; // Build document rectification adapter if enabled let rectification_adapter = build_optional_adapter( @@ -325,7 +330,7 @@ impl OAROCRBuilder { detection_builder = detection_builder.text_type(text_type.clone()); } - let text_detection_adapter = detection_builder.build(&self.text_detection_model)?; + let text_detection_adapter = detection_builder.build(&text_detection_model)?; // Build text line orientation adapter if enabled let text_line_orientation_adapter = build_optional_adapter( @@ -350,7 +355,7 @@ impl OAROCRBuilder { recognition_builder = recognition_builder.with_config(rec_config.clone()); } - let text_recognition_adapter = recognition_builder.build(&self.text_recognition_model)?; + let text_recognition_adapter = recognition_builder.build(&text_recognition_model)?; let pipeline = OCRPipeline { rectification_adapter, diff --git a/src/oarocr/structure.rs b/src/oarocr/structure.rs index b8d1905..1988682 100644 --- a/src/oarocr/structure.rs +++ b/src/oarocr/structure.rs @@ -4,7 +4,7 @@ //! analysis pipelines that can detect layout elements, recognize tables, extract formulas, //! and optionally integrate OCR for text extraction. -use super::builder_utils::build_optional_adapter; +use super::builder_utils::{build_optional_adapter, resolve_model_path}; use oar_ocr_core::core::OCRError; use oar_ocr_core::core::config::OrtSessionConfig; use oar_ocr_core::core::traits::OrtConfigurable; @@ -646,7 +646,37 @@ impl OARStructureBuilder { /// Builds the structure analyzer runtime. /// /// This method instantiates all adapters and returns a ready-to-use structure analyzer. - pub fn build(self) -> Result { + pub fn build(mut self) -> Result { + // Resolve every model/dict/tokenizer path through the auto-download + // cache when the `auto-download` feature is enabled. With the feature + // off these calls are infallible no-ops. + self.layout_detection_model = resolve_model_path(&self.layout_detection_model)?; + fn ro(p: &mut Option) -> Result<(), OCRError> { + if let Some(path) = p { + *path = resolve_model_path(path)?; + } + Ok(()) + } + ro(&mut self.document_orientation_model)?; + ro(&mut self.document_rectification_model)?; + ro(&mut self.region_detection_model)?; + ro(&mut self.table_classification_model)?; + ro(&mut self.table_orientation_model)?; + ro(&mut self.table_cell_detection_model)?; + ro(&mut self.table_structure_recognition_model)?; + ro(&mut self.table_structure_dict_path)?; + ro(&mut self.wired_table_structure_model)?; + ro(&mut self.wireless_table_structure_model)?; + ro(&mut self.wired_table_cell_model)?; + ro(&mut self.wireless_table_cell_model)?; + ro(&mut self.formula_recognition_model)?; + ro(&mut self.formula_tokenizer_path)?; + ro(&mut self.seal_text_detection_model)?; + ro(&mut self.text_detection_model)?; + ro(&mut self.text_line_orientation_model)?; + ro(&mut self.text_recognition_model)?; + ro(&mut self.character_dict_path)?; + // Load character dictionary if OCR is enabled let char_dict = if let Some(ref dict_path) = self.character_dict_path { Some( From 0a62ea899861c1f321a36cd9fbd4d0653dfbdc9d Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Fri, 15 May 2026 11:41:29 +0000 Subject: [PATCH 2/4] fix(download): address PR review feedback --- docs/models.md | 9 +- oar-ocr-core/src/core/download/mod.rs | 139 ++++++++++++++++++++++++-- 2 files changed, 135 insertions(+), 13 deletions(-) diff --git a/docs/models.md b/docs/models.md index e6eb28a..fb8b05b 100644 --- a/docs/models.md +++ b/docs/models.md @@ -190,10 +190,13 @@ For each model path argument the builder applies these checks in order: | `"pp-ocrv5_mobile_det.onnx"` | `$OAR_HOME/...` exists, hash OK | Use cached copy, no network | | `"pp-ocrv5_mobile_det.onnx"` | absent or hash mismatch | Download to `$OAR_HOME`, verify, use | | `"./models/det.onnx"` | absent | Returned as-is → "model not found" | -| `"~/.oar/pp-ocrv5_mobile_det.onnx"` | (any) | Treated as a `$OAR_HOME` cache path; same as bare name | +| `"$OAR_HOME/pp-ocrv5_mobile_det.onnx"` (absolute) | (any) | Parent equals the cache dir → same as bare name | + +Note: the resolver compares paths verbatim — `~` is not expanded. Pass a bare filename, an absolute path under `$OAR_HOME`, or let your shell expand `~` for you. ### Cache layout -- Override the cache root with the `OAR_HOME` environment variable. Defaults to `~/.oar`. +- Override the cache root with the `OAR_HOME` environment variable. Defaults to `~/.oar` (resolved via the platform home directory; the literal `~` is not expanded by the library). - Files land at `$OAR_HOME/`, flat (no per-revision subdirectories). -- Downloads stream into `$OAR_HOME/..part` and are renamed atomically once the SHA-256 matches, so a crash mid-download won't poison the cache. +- Downloads stream into a unique `$OAR_HOME/....part` and are renamed atomically once the SHA-256 matches, so a crash mid-download won't poison the cache and concurrent processes don't clobber each other. +- After verification a `$OAR_HOME/..sha256` sidecar records the verified hash. Future loads with a matching cache file + sidecar skip the multi-second rehash; deleting the sidecar forces a fresh hash check. diff --git a/oar-ocr-core/src/core/download/mod.rs b/oar-ocr-core/src/core/download/mod.rs index 3ce2492..04dd078 100644 --- a/oar-ocr-core/src/core/download/mod.rs +++ b/oar-ocr-core/src/core/download/mod.rs @@ -30,6 +30,7 @@ use std::env; use std::fs::{self, File}; use std::io::{self, Read, Write}; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; use crate::core::errors::OCRError; @@ -184,8 +185,27 @@ fn cached_file_matches(path: &Path, entry: &Entry) -> Result { ); return Ok(false); } + + // Fast path: if a sidecar from a previous verified download/check exists + // and records the expected hash, skip rehashing the (potentially 1.8 GB) + // file. The sidecar is written under the same cache directory we control, + // so the threat model matches "trust the cache once we've vouched for it". + if sidecar_records_hash(path, entry.sha256) { + return Ok(true); + } + match sha256_file(path) { - Ok(hash) if hash == entry.sha256 => Ok(true), + Ok(hash) if hash == entry.sha256 => { + // Remember the verification so future loads skip the rehash. + if let Err(e) = write_sidecar(path, entry.sha256) { + tracing::debug!( + path = %path.display(), + error = %e, + "failed to write sha256 sidecar; cache will rehash next time" + ); + } + Ok(true) + } Ok(hash) => { tracing::warn!( path = %path.display(), @@ -206,12 +226,41 @@ fn cached_file_matches(path: &Path, entry: &Entry) -> Result { } } +fn sidecar_path(path: &Path) -> Option { + let name = path.file_name()?.to_str()?; + Some(path.with_file_name(format!(".{name}.sha256"))) +} + +fn sidecar_records_hash(path: &Path, expected: &str) -> bool { + let Some(sidecar) = sidecar_path(path) else { + return false; + }; + match fs::read_to_string(&sidecar) { + Ok(contents) => contents.trim().eq_ignore_ascii_case(expected), + Err(_) => false, + } +} + +fn write_sidecar(path: &Path, hash: &str) -> io::Result<()> { + let sidecar = sidecar_path(path) + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no filename for sidecar"))?; + fs::write(&sidecar, hash) +} + fn download_and_verify(entry: &Entry, target: &Path) -> Result<(), OCRError> { let url = format!( "https://www.modelscope.cn/api/v1/models/{}/repo?Revision={}&FilePath={}", MODELSCOPE_REPO, DEFAULT_REVISION, entry.name, ); + // Build the agent once so connection pooling and HTTPS handshakes survive + // across retry attempts. + let agent = ureq::Agent::config_builder() + .timeout_global(Some(Duration::from_secs(REQUEST_TIMEOUT_SECS))) + .timeout_connect(Some(Duration::from_secs(CONNECT_TIMEOUT_SECS))) + .build() + .new_agent(); + let mut last_err: Option = None; for attempt in 1..=DOWNLOAD_RETRIES { tracing::info!( @@ -220,7 +269,7 @@ fn download_and_verify(entry: &Entry, target: &Path) -> Result<(), OCRError> { attempt, "downloading from ModelScope" ); - match download_attempt(&url, entry, target) { + match download_attempt(&agent, &url, entry, target) { Ok(()) => return Ok(()), Err(e) => { tracing::warn!( @@ -238,13 +287,27 @@ fn download_and_verify(entry: &Entry, target: &Path) -> Result<(), OCRError> { })) } -fn download_attempt(url: &str, entry: &Entry, target: &Path) -> Result<(), OCRError> { - let agent = ureq::Agent::config_builder() - .timeout_global(Some(Duration::from_secs(REQUEST_TIMEOUT_SECS))) - .timeout_connect(Some(Duration::from_secs(CONNECT_TIMEOUT_SECS))) - .build() - .new_agent(); +/// Monotonic counter used to keep concurrent in-process downloads of the same +/// entry from sharing a temp path. Combined with the PID it gives a unique +/// suffix without pulling in a `rand` dependency. +static TMP_COUNTER: AtomicU64 = AtomicU64::new(0); + +fn unique_tmp_path(target: &Path, entry: &Entry) -> PathBuf { + let counter = TMP_COUNTER.fetch_add(1, Ordering::Relaxed); + target.with_file_name(format!( + ".{}.{}.{}.part", + entry.name, + std::process::id(), + counter + )) +} +fn download_attempt( + agent: &ureq::Agent, + url: &str, + entry: &Entry, + target: &Path, +) -> Result<(), OCRError> { let response = agent .get(url) .call() @@ -253,7 +316,7 @@ fn download_attempt(url: &str, entry: &Entry, target: &Path) -> Result<(), OCREr let mut body = response.into_body().into_reader(); let parent = target.parent().unwrap_or_else(|| Path::new(".")); - let tmp = target.with_file_name(format!(".{}.part", entry.name)); + let tmp = unique_tmp_path(target, entry); let mut file = File::create(&tmp).map_err(|e| { io_with_context( e, @@ -311,6 +374,16 @@ fn download_attempt(url: &str, entry: &Entry, target: &Path) -> Result<(), OCREr format!("move `{}` -> `{}`", tmp.display(), target.display()), ) })?; + + // Record the verified hash next to the file so subsequent loads can skip + // the expensive rehash. Best-effort: a failure here only costs a rehash. + if let Err(e) = write_sidecar(target, entry.sha256) { + tracing::debug!( + path = %target.display(), + error = %e, + "failed to write sha256 sidecar after download" + ); + } Ok(()) } @@ -423,12 +496,58 @@ mod tests { let mismatched_size = Entry { size: 99, ..entry }; assert!(!cached_file_matches(&path, &mismatched_size).unwrap()); - // Wrong contents → no match (hash differs). + // Wrong contents → no match (hash differs). Clear any sidecar left + // behind by the earlier matches so the rehash actually runs. + let sidecar = sidecar_path(&path).unwrap(); + let _ = std::fs::remove_file(&sidecar); std::fs::write(&path, b"world!").unwrap(); let same_len = Entry { size: 6, ..entry }; assert!(!cached_file_matches(&path, &same_len).unwrap()); } + #[test] + fn cached_file_matches_writes_sidecar_and_uses_it() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("dummy.bin"); + let entry = Entry { + name: "dummy.bin", + sha256: "5891b5b522d5df086d0ff0b110fbd9d21bb4fc7163af34d08286a2e846f6be03", + size: 6, + }; + std::fs::write(&path, b"hello\n").unwrap(); + + // First match triggers a real hash + writes the sidecar. + assert!(cached_file_matches(&path, &entry).unwrap()); + let sidecar = sidecar_path(&path).unwrap(); + assert_eq!(std::fs::read_to_string(&sidecar).unwrap(), entry.sha256); + + // Tampering with the file but keeping size: sidecar lets us trust the + // (now stale) cache. This is the deliberate tradeoff — see module docs. + std::fs::write(&path, b"world!").unwrap(); + assert!(cached_file_matches(&path, &entry).unwrap()); + + // If the sidecar disagrees with the expected hash, we fall back to + // rehashing and (here) reject the cache. + std::fs::write(&sidecar, "deadbeef").unwrap(); + assert!(!cached_file_matches(&path, &entry).unwrap()); + } + + #[test] + fn unique_tmp_path_never_repeats_in_process() { + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("model.onnx"); + let entry = Entry { + name: "model.onnx", + sha256: "00", + size: 0, + }; + let a = unique_tmp_path(&target, &entry); + let b = unique_tmp_path(&target, &entry); + assert_ne!(a, b); + let pid = std::process::id().to_string(); + assert!(a.to_string_lossy().contains(&pid)); + } + #[test] fn fetch_unregistered_name_returns_config_error() { let err = fetch("does-not-exist.onnx").unwrap_err(); From 264a44cf778ce83aa6d9d45e9bfa6f7e0c81fea3 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Fri, 15 May 2026 12:18:04 +0000 Subject: [PATCH 3/4] RAII guard --- oar-ocr-core/src/core/download/mod.rs | 74 ++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/oar-ocr-core/src/core/download/mod.rs b/oar-ocr-core/src/core/download/mod.rs index 04dd078..e82717b 100644 --- a/oar-ocr-core/src/core/download/mod.rs +++ b/oar-ocr-core/src/core/download/mod.rs @@ -302,6 +302,39 @@ fn unique_tmp_path(target: &Path, entry: &Entry) -> PathBuf { )) } +/// RAII guard that deletes a temp file on drop unless explicitly disarmed. +/// Keeps `$OAR_HOME` from accumulating stale `.part` files when a download +/// fails mid-stream (read error, write error, hash mismatch, …). +struct TempFileGuard { + path: Option, +} + +impl TempFileGuard { + fn new(path: PathBuf) -> Self { + Self { path: Some(path) } + } + + fn path(&self) -> &Path { + // Only `disarm` clears `path`, and it consumes `self`, so anywhere + // we still hold the guard the path is present. + self.path.as_deref().expect("guard already disarmed") + } + + /// Hand off ownership of the temp file to a successful rename; nothing + /// gets deleted on drop after this point. + fn disarm(mut self) { + self.path = None; + } +} + +impl Drop for TempFileGuard { + fn drop(&mut self) { + if let Some(path) = self.path.take() { + let _ = fs::remove_file(path); + } + } +} + fn download_attempt( agent: &ureq::Agent, url: &str, @@ -327,6 +360,8 @@ fn download_attempt( ), ) })?; + // From here on, any early return must not leak the temp file. + let guard = TempFileGuard::new(tmp); let mut hasher = ::new(); let mut buf = vec![0u8; READ_BUFFER_BYTES]; @@ -340,15 +375,14 @@ fn download_attempt( } sha2::Digest::update(&mut hasher, &buf[..n]); file.write_all(&buf[..n]) - .map_err(|e| io_with_context(e, format!("write `{}`", tmp.display())))?; + .map_err(|e| io_with_context(e, format!("write `{}`", guard.path().display())))?; written += n as u64; } file.sync_all() - .map_err(|e| io_with_context(e, format!("sync `{}`", tmp.display())))?; + .map_err(|e| io_with_context(e, format!("sync `{}`", guard.path().display())))?; drop(file); if written != entry.size { - let _ = fs::remove_file(&tmp); return Err(OCRError::ConfigError { message: format!( "downloaded `{}` is {} bytes but the registry expects {}", @@ -359,7 +393,6 @@ fn download_attempt( let actual_hash = encode_hex(&sha2::Digest::finalize(hasher)); if actual_hash != entry.sha256 { - let _ = fs::remove_file(&tmp); return Err(OCRError::ConfigError { message: format!( "sha256 mismatch for `{}`: expected {}, got {}", @@ -368,12 +401,19 @@ fn download_attempt( }); } - fs::rename(&tmp, target).map_err(|e| { + fs::rename(guard.path(), target).map_err(|e| { io_with_context( e, - format!("move `{}` -> `{}`", tmp.display(), target.display()), + format!( + "move `{}` -> `{}`", + guard.path().display(), + target.display() + ), ) })?; + // The temp path no longer exists (renamed onto `target`); disarm so + // Drop doesn't try to remove a now-missing file. + guard.disarm(); // Record the verified hash next to the file so subsequent loads can skip // the expensive rehash. Best-effort: a failure here only costs a rehash. @@ -532,6 +572,28 @@ mod tests { assert!(!cached_file_matches(&path, &entry).unwrap()); } + #[test] + fn temp_file_guard_removes_on_drop_and_keeps_when_disarmed() { + let dir = tempfile::tempdir().unwrap(); + + // Drop-without-disarm deletes the file. + let p = dir.path().join("dropme.part"); + std::fs::write(&p, b"x").unwrap(); + drop(TempFileGuard::new(p.clone())); + assert!(!p.exists(), "guard should remove the temp file on drop"); + + // Disarm keeps the file. + let p = dir.path().join("keepme.part"); + std::fs::write(&p, b"x").unwrap(); + let guard = TempFileGuard::new(p.clone()); + guard.disarm(); + assert!(p.exists(), "disarmed guard must not delete the file"); + + // Missing temp path on drop is silently ignored (no panic). + let p = dir.path().join("ghost.part"); + drop(TempFileGuard::new(p)); + } + #[test] fn unique_tmp_path_never_repeats_in_process() { let dir = tempfile::tempdir().unwrap(); From 21edcbc3ea91edab0cab9bb279c46c75cf989ac4 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Fri, 15 May 2026 12:49:15 +0000 Subject: [PATCH 4/4] fix(download): rename ro helper and serialize env-mutating tests --- oar-ocr-core/src/core/download/mod.rs | 15 ++++++++++ src/oarocr/structure.rs | 40 +++++++++++++-------------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/oar-ocr-core/src/core/download/mod.rs b/oar-ocr-core/src/core/download/mod.rs index e82717b..1b8b15e 100644 --- a/oar-ocr-core/src/core/download/mod.rs +++ b/oar-ocr-core/src/core/download/mod.rs @@ -462,6 +462,17 @@ fn network_error(ctx: String, err: ureq::Error) -> OCRError { #[cfg(test)] mod tests { use super::*; + use std::sync::Mutex; + + // Rust runs tests in parallel by default. Anything that reads or writes the + // process-global `OAR_HOME` env var (directly or via `cache_dir` / + // `resolve_path`) must hold this lock so the writer test does not race + // with concurrent readers and corrupt their observed cache directory. + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + fn lock_env() -> std::sync::MutexGuard<'static, ()> { + ENV_LOCK.lock().unwrap_or_else(|p| p.into_inner()) + } #[test] fn find_known_entry() { @@ -476,6 +487,7 @@ mod tests { #[test] fn resolve_existing_file_returns_input() { + let _guard = lock_env(); let dir = tempfile::tempdir().unwrap(); let f = dir.path().join("local.onnx"); std::fs::write(&f, b"hi").unwrap(); @@ -485,6 +497,7 @@ mod tests { #[test] fn resolve_explicit_path_passthrough_for_unknown() { + let _guard = lock_env(); // A nested path that doesn't exist and isn't registered must be // returned verbatim so the caller's normal error fires. let p = PathBuf::from("/nonexistent/dir/some_random_model.onnx"); @@ -494,6 +507,7 @@ mod tests { #[test] fn resolve_bare_name_unknown_does_not_consult_network() { + let _guard = lock_env(); // No registry hit, no existing file → returned as-is. let p = PathBuf::from("not-in-registry.onnx"); let resolved = resolve_path(&p).unwrap(); @@ -502,6 +516,7 @@ mod tests { #[test] fn cache_dir_honours_env_override() { + let _guard = lock_env(); let dir = tempfile::tempdir().unwrap(); unsafe { std::env::set_var(OAR_HOME_ENV, dir.path()); diff --git a/src/oarocr/structure.rs b/src/oarocr/structure.rs index 1988682..3eaa9fb 100644 --- a/src/oarocr/structure.rs +++ b/src/oarocr/structure.rs @@ -651,31 +651,31 @@ impl OARStructureBuilder { // cache when the `auto-download` feature is enabled. With the feature // off these calls are infallible no-ops. self.layout_detection_model = resolve_model_path(&self.layout_detection_model)?; - fn ro(p: &mut Option) -> Result<(), OCRError> { + fn resolve_opt_path(p: &mut Option) -> Result<(), OCRError> { if let Some(path) = p { *path = resolve_model_path(path)?; } Ok(()) } - ro(&mut self.document_orientation_model)?; - ro(&mut self.document_rectification_model)?; - ro(&mut self.region_detection_model)?; - ro(&mut self.table_classification_model)?; - ro(&mut self.table_orientation_model)?; - ro(&mut self.table_cell_detection_model)?; - ro(&mut self.table_structure_recognition_model)?; - ro(&mut self.table_structure_dict_path)?; - ro(&mut self.wired_table_structure_model)?; - ro(&mut self.wireless_table_structure_model)?; - ro(&mut self.wired_table_cell_model)?; - ro(&mut self.wireless_table_cell_model)?; - ro(&mut self.formula_recognition_model)?; - ro(&mut self.formula_tokenizer_path)?; - ro(&mut self.seal_text_detection_model)?; - ro(&mut self.text_detection_model)?; - ro(&mut self.text_line_orientation_model)?; - ro(&mut self.text_recognition_model)?; - ro(&mut self.character_dict_path)?; + resolve_opt_path(&mut self.document_orientation_model)?; + resolve_opt_path(&mut self.document_rectification_model)?; + resolve_opt_path(&mut self.region_detection_model)?; + resolve_opt_path(&mut self.table_classification_model)?; + resolve_opt_path(&mut self.table_orientation_model)?; + resolve_opt_path(&mut self.table_cell_detection_model)?; + resolve_opt_path(&mut self.table_structure_recognition_model)?; + resolve_opt_path(&mut self.table_structure_dict_path)?; + resolve_opt_path(&mut self.wired_table_structure_model)?; + resolve_opt_path(&mut self.wireless_table_structure_model)?; + resolve_opt_path(&mut self.wired_table_cell_model)?; + resolve_opt_path(&mut self.wireless_table_cell_model)?; + resolve_opt_path(&mut self.formula_recognition_model)?; + resolve_opt_path(&mut self.formula_tokenizer_path)?; + resolve_opt_path(&mut self.seal_text_detection_model)?; + resolve_opt_path(&mut self.text_detection_model)?; + resolve_opt_path(&mut self.text_line_orientation_model)?; + resolve_opt_path(&mut self.text_recognition_model)?; + resolve_opt_path(&mut self.character_dict_path)?; // Load character dictionary if OCR is enabled let char_dict = if let Some(ref dict_path) = self.character_dict_path {