diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 90c46c9..a815ea0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,6 +33,34 @@ jobs: cargo clippy --all-targets --all -- -D warnings cargo clippy --all-targets -p oar-ocr-vl -- -D warnings + - name: Check rustdoc warnings + env: + RUSTDOCFLAGS: -D warnings + run: cargo doc --workspace --no-deps + + feature-matrix: + name: Feature matrix (${{ matrix.package }}) + needs: preflight + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - package: oar-ocr + - package: oar-ocr-core + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache Cargo dependencies + uses: Swatinem/rust-cache@v2 + + - name: Check with all features enabled + run: cargo check -p ${{ matrix.package }} --all-features + test: name: Test (${{ matrix.os }}) needs: preflight diff --git a/Cargo.toml b/Cargo.toml index d2c67fe..26b4df7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = [".", "oar-ocr-derive", "oar-ocr-core", "oar-ocr-vl"] resolver = "2" [workspace.package] -version = "0.6.3" +version = "0.7.0" edition = "2024" rust-version = "1.95" license = "Apache-2.0" @@ -12,8 +12,8 @@ repository = "https://github.com/greatv/oar-ocr" homepage = "https://github.com/greatv/oar-ocr" [workspace.dependencies] -oar-ocr-core = { version = "0.6.3", path = "oar-ocr-core", default-features = false } -oar-ocr-derive = { version = "0.6.3", path = "oar-ocr-derive", default-features = false } +oar-ocr-core = { version = "0.7.0", path = "oar-ocr-core", default-features = false } +oar-ocr-derive = { version = "0.7.0", path = "oar-ocr-derive", default-features = false } [package] name = "oar-ocr" @@ -68,5 +68,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } clap = { version = "4.5.42", features = ["derive"] } tempfile = "3.19" ab_glyph = "0.2" +fontdb = 
"0.23" hayro = "0.5" regex = "1" diff --git a/README.md b/README.md index 7161fe7..79f7823 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # OAR-OCR -![Crates.io Version](https://img.shields.io/crates/v/oar-ocr) +[![Crates.io Version](https://img.shields.io/crates/v/oar-ocr)](https://crates.io/crates/oar-ocr) ![Crates.io Downloads (recent)](https://img.shields.io/crates/dr/oar-ocr) [![dependency status](https://deps.rs/repo/github/GreatV/oar-ocr/status.svg)](https://deps.rs/repo/github/GreatV/oar-ocr) ![GitHub License](https://img.shields.io/github/license/GreatV/oar-ocr) @@ -84,12 +84,17 @@ fn main() -> Result<(), Box> { ## Vision-Language Models (VLM) -For advanced document understanding using Vision-Language Models (like PaddleOCR-VL, **PaddleOCR-VL-1.5**, UniRec, and MinerU2.5), check out the [`oar-ocr-vl`](oar-ocr-vl/README.md) crate. +For advanced document understanding using Vision-Language Models (like PaddleOCR-VL, **PaddleOCR-VL-1.5**, GLM-OCR, HunyuanOCR, and MinerU2.5), check out the [`oar-ocr-vl`](oar-ocr-vl/README.md) crate. + +### Hierarchical Speculative Decoding (HSD) + +`oar-ocr-vl` ships a training-free CUDA acceleration scheme for the VLM backbones above. A cheap pipeline drafter (layout + OCR) proposes text candidates and the target VLM verifies them in batches via tree-attention, typically delivering several-fold wall-time speedups on document-heavy pages at `τ = 0.75`. Build with `--features hsd` (implies `cuda`); see [`docs/hsd.md`](docs/hsd.md) for the algorithm overview, config knobs, supported backbones, and AAL guidance. 
## Documentation - [**Usage Guide**](docs/usage.md) - Detailed API usage, builder patterns, GPU configuration - [**Pre-trained Models**](docs/models.md) - Model download links and recommended configurations +- [**HSD**](docs/hsd.md) - Hierarchical Speculative Decoding for VLM inference ## Examples @@ -117,6 +122,4 @@ This project builds upon the excellent work of several open-source projects: - **[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)**: Baidu's awesome multilingual OCR toolkits based on PaddlePaddle. This project utilizes PaddleOCR's pre-trained models, which provide excellent accuracy and performance for text detection and recognition across multiple languages. -- **[OpenOCR](https://github.com/Topdu/OpenOCR)**: An open-source toolkit for general OCR research and applications by the FVL Laboratory at Fudan University. We use the UniRec model for unified text, formula, and table recognition. - - **[Candle](https://github.com/huggingface/candle)**: A minimalist ML framework for Rust by Hugging Face. We use Candle to implement Vision-Language model inference. diff --git a/docs/hsd.md b/docs/hsd.md new file mode 100644 index 0000000..e11fa8f --- /dev/null +++ b/docs/hsd.md @@ -0,0 +1,97 @@ +# Hierarchical Speculative Decoding (HSD) + +HSD is an optional CUDA acceleration path for document VLM decoding. It leaves the target model unchanged. A cheaper document pipeline — the paper uses PP-StructureV3 (layout analysis + element recognition) — prepares draft text once per page. The VLM then verifies those drafts with tree-batched speculative decoding and commits only accepted tokens. + +Reference: Liao et al., *"HSD: Training-Free Acceleration for Document Parsing Vision-Language Model with Hierarchical Speculative Decoding"* (arXiv:2602.12957). Section references below cite that paper. + +## When to use it + +HSD helps when the draft is close to what the VLM would generate on its own. 
That is common on text-heavy pages, tables with regular structure, and repeated document boilerplate. A good draft lets one verify pass accept several tokens. + +It is not a general CPU speedup. The implementation is intended for CUDA, where a wider tree-attention verify pass is cheap compared with repeated single-token decoding. On CPU or Metal, the verify work is effectively serialized and the benefit usually disappears. + +The paper defines the acceptance threshold on the open interval $\tau \in (0, 1)$ (§3.2). Lower values accept more near-tie tokens, which can improve speed but may change the output. This implementation also accepts `tau = 1.0` as a degenerate boundary: at $\tau = 1.0$ the acceptance test collapses to "child equals the unrestricted argmax", so HSD follows the target model's greedy path. That extension is not part of the paper. + +## Document flow + +The document-level path has two stages (§3.1): + +- **Stage 1: region-level local verification.** For each region $r_i \in \mathcal{R}$, the target VLM verifies the region draft set $\tilde{\mathcal{Y}}^{(i)}$ on the cropped image $z_i = x|_{r_i}$: + $$\hat{y}^{(i)} = \mathrm{SpecDecode}(p_\theta, z_i, \tilde{\mathcal{Y}}^{(i)}).$$ +- **Stage 2: page-level global verification.** Stage 1 outputs are aggregated into an unordered page-level draft set + $$\tilde{\mathcal{Y}}^{\mathrm{pg}} = \{\hat{y}^{(i)} \mid r_i \in \mathcal{R}\},$$ + which the target VLM then verifies in a single full-page pass: $\hat{y}^{\mathrm{pg}} = \mathrm{SpecDecode}(p_\theta, x, \tilde{\mathcal{Y}}^{\mathrm{pg}})$. Because the matcher scans each $\hat{y}^{(i)}$ independently, draft order has no semantic effect; the target model resolves the final reading order during verification. + +Backends that implement the full document path can turn either stage off through `HsdConfig`. PaddleOCR-VL is not evaluated in the paper; in this implementation it stays element-oriented by model design and uses only the region path. 
+ +## One SpecDecode step + +For the accepted prefix $\hat{y}_{1:t}$ and a draft set $\tilde{\mathcal{Y}}$ (§3.2): + +1. **Draft-target matching.** Let the reference window be the most recent $n$ accepted tokens, $w = \hat{y}_{t-n+1:t}$. For each draft $\tilde{y} \in \tilde{\mathcal{Y}}$, record every start index $j$ with $\tilde{y}_{j:j+n-1} = w$; call this set of match positions $\mathcal{J}(\tilde{y})$. Collect the non-empty suffixes that follow each match: + $$\mathcal{C} = \big\{\, \tilde{y}_{j+n:|\tilde{y}|} \,\big|\, \tilde{y} \in \tilde{\mathcal{Y}},\; j \in \mathcal{J}(\tilde{y}),\; j + n \le |\tilde{y}|\,\big\}.$$ +2. **Prefix-tree batching.** Merge $\mathcal{C}$ into a prefix tree $\mathcal{T}$ whose root represents the empty prefix and whose every root-to-leaf path is one element of $\mathcal{C}$. For a node $v$, $\pi(v)$ is the token sequence on the path root → $v$, and $\mathrm{Next}(v) = \{c_{|\pi(v)|+1} \mid c \in \mathcal{C},\; c_{1:|\pi(v)|} = \pi(v)\}$ is the set of distinct next tokens shared by candidates that pass through $v$. +3. **One tree-batched forward.** Linearize $\mathcal{T}$ into a packed sequence and run the target VLM under a tree-ancestry attention mask: a token at node $v$ attends only to $\hat{y}_{1:t}$ and to the tokens on $v$'s ancestor path. This produces $p_\theta(\cdot \mid z, \hat{y}_{1:t} \oplus \pi(v))$ at every node in one pass. +4. **Greedy traversal with the $\tau$-test.** Start at the root $s$. 
At each step, select the best child token in the tree's local candidate set and compare it with the unrestricted argmax over the full vocabulary $\mathcal{V}$: + $$u^\star = \arg\max_{u \in \mathrm{Next}(s)} p_\theta(u \mid z, \hat{y}_{1:t} \oplus \pi(s)), \qquad \hat{u} = \arg\max_{u \in \mathcal{V}} p_\theta(u \mid z, \hat{y}_{1:t} \oplus \pi(s)).$$ + Accept $u^\star$ and descend to its child node (which becomes the new $s$) iff + $$\log p_\theta(u^\star \mid z, \hat{y}_{1:t} \oplus \pi(s)) - \log p_\theta(\hat{u} \mid z, \hat{y}_{1:t} \oplus \pi(s)) \ge \log \tau.$$ + Stop when the test fails or when $\mathrm{Next}(s) = \emptyset$ (that is, $s$ is a leaf). +5. **Bonus target token.** At the terminal node $s$, append the unrestricted argmax $\hat{u}$ to extend the accepted sequence by one extra target token: + $$\hat{y}_{1:t_\mathrm{new}} = \hat{y}_{1:t} \oplus \pi(s) \oplus \hat{u}.$$ +6. **Commit KV state.** Gather the KV cache so it keeps only the accepted prefix and the path through $s$, then continue decoding from $\hat{u}$. + +If $\mathcal{C} = \emptyset$ (no draft matches the current window), $\mathcal{T}$ contains only the root, $\mathrm{Next}(\mathrm{root}) = \emptyset$, the traversal stops immediately, and step 5 falls back to a single greedy token — the paper's algorithm with no special case. + +## Correctness at `tau = 1.0` + +The paper proves training-free, near-lossless acceleration over its stated domain $\tau \in (0, 1)$. This implementation also exposes $\tau = 1.0$ as a degenerate boundary: with $\log \tau = 0$, the acceptance test in step 4 reduces to $u^\star = \hat{u}$, so a child token is accepted only when it coincides with the unrestricted argmax. The committed sequence is then independent of the drafter and identical to the target model's greedy decode. + +By default this is enforced via a strict replay path (`strict_at_tau_one = true`, see Configuration). That replay path is an OAR-side correctness oracle, not part of the paper. 
Set `strict_at_tau_one = false` to keep $\tau = 1.0$ on the same tree-batched verify path the paper describes. + +The demo harness runs this oracle check and compares HSD output with baseline output byte-for-byte. + +## Reading AAL + +The main debug metric is **Average Acceptance Length (AAL)** (§4.2). At verification step $k$, let $\alpha_k$ be the number of consecutive draft tokens accepted before the first mismatch ($\alpha_k = 0$ on a full rejection). Over $N$ verification steps: + +$$\mathrm{AAL} = \frac{1}{N} \sum_{k=1}^{N} \alpha_k.$$ + +The bonus target token appended at step 5 is not counted. Larger AAL means more decoding steps are saved per verify pass; the realized end-to-end speedup also depends on per-step verify overhead and parallel efficiency. + +For reference, the paper reports overall AAL on OmniDocBench v1.5 (Tab. 1) in the **2.5 to 4.6** range across the evaluated backbones (HunyuanOCR 4.55, dots.ocr 3.98, Qwen3-VL-8B 3.98, Qwen3-VL-2B 3.33, Qwen2.5-VL-7B 3.56, Qwen2.5-VL-3B 2.52). The ranges below are operational rules of thumb observed on this implementation, not paper numbers; use AAL as a draft-quality signal, not as a correctness metric: + +- `AAL` around `1`: the draft is doing little work. Check tokenization, window length, reading order, and whether the drafter output resembles the target output. +- `AAL` from `3` to `6`: a normal range for many text-heavy pages with OCR drafts. +- `AAL` from `8` to `15`: strong alignment, often from tables, lists, or repeated layout. +- `AAL > 20`: usually a long exact span. Inspect output quality as well as speed. + +## Configuration + +`HsdConfig` controls the two-stage document path: the stage toggles (`enable_stage1`, `enable_stage2`) and the token budgets (`max_page_tokens`, `max_region_tokens`) are set on it directly. Its `dsv` field (a `DsvConfig`) contains the per-step speculative decoding knobs listed in the table below. The first two of those fields (`window_len`, `tau`) follow the paper's defaults (§4.3); the rest are OAR-side engineering knobs not present in the paper. 
+ +| Field | Default | Source | Notes | +|-------|---------|--------|-------| +| `window_len` | `3` | paper §4.3 | Reference-window length $n$. Longer windows reduce false matches on repetitive text but also reduce draft hits. | +| `tau` | `0.75` | paper §4.3 | Acceptance threshold. Paper domain is $\tau \in (0, 1)$; lower accepts more borderline tokens. `1.0` is a boundary extension that recovers greedy decoding. | +| `max_candidates_per_step` | `32` | OAR addition | Bounds the number of candidate suffixes used to build each tree. The paper's ablations use uncapped trees. | +| `max_suffix_len` | `256` | OAR addition | Bounds candidate depth so long drafts do not create oversized trees. | +| `cold_start_full_draft` | `true` | OAR addition | Seeds the first step from draft prefixes before any accepted window exists. The paper's matcher has no cold-start fallback. | +| `strict_at_tau_one` | `true` | OAR addition | When `true` and $\tau \ge 1.0$, route through a strict replay oracle. Set `false` to keep $\tau = 1.0$ on the paper's tree-batched verify path. | + +The candidate caps are guardrails for long tables, formulas, and repeated boilerplate. To reproduce a paper-faithful matcher, set both caps to `usize::MAX`, `cold_start_full_draft = false`, and `strict_at_tau_one = false`. + +## Running it + +Build the VLM crate with the `hsd` feature. It enables CUDA transitively: + +```bash +cargo run -p oar-ocr-vl --release --features hsd,download-binaries \ + --example hsd_demo -- \ + --backend hunyuanocr \ + --model-dir models/HunyuanOCR \ + --device cuda \ + --image document.jpg +``` + +The supported backbones expose `generate_hsd*` methods next to their baseline `generate` methods: `PaddleOcrVl`, `HunyuanOcr`, `GlmOcr`, and `MinerU`. 
diff --git a/docs/usage.md b/docs/usage.md index d5e3972..40fcf1d 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -138,6 +138,8 @@ let structure = OARStructureBuilder::new("picodet-l_layout_17cls.onnx") | `.with_table_structure_recognition(path, type)` | Add table structure recognition | | `.table_structure_dict_path(path)` | Set table structure dictionary | | `.with_formula_recognition(model, tokenizer, type)` | Add formula recognition | +| `.formula_recognition_config(config)` | Set formula score threshold, max length, and batch size | +| `.formula_ort_session(config)` | Apply ONNX Runtime configuration only to formula recognition | | `.with_ocr(det, rec, dict)` | Add integrated OCR pipeline | | `.with_seal_detection(path)` | Add seal/stamp text detection | | `.image_batch_size(n)` | Set batch size for image processing | @@ -241,14 +243,14 @@ Add the VL crate to your `Cargo.toml`: ```toml [dependencies] -oar-ocr-vl = "0.6" +oar-ocr-vl = "0.7" ``` For GPU acceleration, enable CUDA: ```toml [dependencies] -oar-ocr-vl = { version = "0.6", features = ["cuda"] } +oar-ocr-vl = { version = "0.7", features = ["cuda"] } ``` ### Downloading the Model @@ -281,8 +283,12 @@ fn main() -> Result<(), Box> { let device = parse_device("cpu")?; // or "cuda", "cuda:0" let vl = PaddleOcrVl::from_dir("PaddleOCR-VL", device)?; - // Element-level OCR - let result = vl.generate(image, PaddleOcrVlTask::Ocr, 256)?; + // Element-level OCR. The API is batch-oriented, so pass one task per image. 
+ let result = vl + .generate(&[image], &[PaddleOcrVlTask::Ocr], 256) + .into_iter() + .next() + .expect("one result")?; println!("{result}"); Ok(()) @@ -302,7 +308,11 @@ fn main() -> Result<(), Box> { let device = parse_device("cpu")?; let vl = PaddleOcrVl::from_dir("PaddleOCR-VL-1.5", device)?; - let result = vl.generate(image, PaddleOcrVlTask::Seal, 256)?; + let result = vl + .generate(&[image], &[PaddleOcrVlTask::Seal], 256) + .into_iter() + .next() + .expect("one result")?; println!("{result}"); Ok(()) @@ -330,34 +340,42 @@ cargo run -p oar-ocr-vl --features cuda --example paddleocr_vl -- \ | `PaddleOcrVlTask::Spotting` | Text spotting (localization + recognition) | Structured text | | `PaddleOcrVlTask::Seal` | Seal recognition | Plain text | -## UniRec +## HunyuanOCR + +[HunyuanOCR](https://huggingface.co/tencent/HunyuanOCR) is a 1B parameter OCR expert VLM. It's available in the `oar-ocr-vl` crate and supports prompt-driven image-to-text OCR. -[UniRec](https://github.com/Topdu/OpenOCR) is a unified recognition model with only **0.1B parameters**, developed by the FVL Laboratory at Fudan University as part of the OpenOCR project. It is designed for high-accuracy and efficient recognition of plain text (words, lines, paragraphs), mathematical formulas (single-line, multi-line), and mixed content in both Chinese and English. Despite its small size, it achieves performance comparable to or better than much larger vision-language models. It's also available in the `oar-ocr-vl` crate. +Note: inputs are automatically resized to satisfy the model's image/token limits (e.g., max side length 2048). 
### Downloading the Model ```bash -hf download Topdu/UniRec-0.1B --local-dir models/unirec-0.1b +git lfs install +git clone https://huggingface.co/tencent/HunyuanOCR + +# Or using hf +hf download tencent/HunyuanOCR --local-dir HunyuanOCR ``` ### Basic Usage -```rust +```rust,no_run use oar_ocr_core::utils::load_image; -use oar_ocr_vl::UniRec; +use oar_ocr_vl::HunyuanOcr; use oar_ocr_vl::utils::parse_device; -use std::path::Path; fn main() -> Result<(), Box> { - let image = load_image(Path::new("formula.png"))?; - let device = parse_device("cpu")?; // or "cuda", "cuda:0" + let image = load_image("document.jpg")?; + let device = parse_device("cpu")?; // or "cuda", "cuda:0" - // Load UniRec model - let model = UniRec::from_dir("models/unirec-0.1b", device)?; + let model = HunyuanOcr::from_dir("HunyuanOCR", device)?; - // Generate recognition result - let result = model.generate(image, 512)?; - println!("{result}"); + let prompt = "Detect and recognize text in the image, and output the text coordinates in a formatted manner."; + let text = model + .generate(&[image], &[prompt], 1024) + .into_iter() + .next() + .expect("one result")?; + println!("{text}"); Ok(()) } @@ -366,41 +384,56 @@ fn main() -> Result<(), Box> { ### Running the Example ```bash -cargo run -p oar-ocr-vl --features cuda --example unirec -- \ - -m models/unirec-0.1b --device cuda formula.jpg +cargo run -p oar-ocr-vl --features cuda --example hunyuanocr -- \ + --model-dir HunyuanOCR \ + --device cuda \ + --prompt "Detect and recognize text in the image, and output the text coordinates in a formatted manner." \ + document.jpg ``` -## HunyuanOCR +### Application-oriented Prompts -[HunyuanOCR](https://huggingface.co/tencent/HunyuanOCR) is a 1B parameter OCR expert VLM powered by Hunyuan's multimodal architecture. It's available in the `oar-ocr-vl` crate and supports prompt-driven image-to-text OCR. 
+Prompts from the upstream HunyuanOCR README: -Note: inputs are automatically resized to satisfy the model's image/token limits (e.g., max side length 2048). +| Task | English | Chinese | +|------|---------|---------| +| **Spotting** | Detect and recognize text in the image, and output the text coordinates in a formatted manner. | 检测并识别图片中的文字,将文本坐标格式化输出。 | +| **Parsing** | • Identify the formula in the image and represent it using LaTeX format.

• Parse the table in the image into HTML.

• Parse the chart in the image; use Mermaid format for flowcharts and Markdown for other charts.

• Extract all information from the main body of the document image and represent it in markdown format, ignoring headers and footers. Tables should be expressed in HTML format, formulas in the document should be represented using LaTeX format, and the parsing should be organized according to the reading order. | • 识别图片中的公式,用 LaTeX 格式表示。

• 把图中的表格解析为 HTML。

• 解析图中的图表,对于流程图使用 Mermaid 格式表示,其他图表使用 Markdown 格式表示。

• 提取文档图片中正文的所有信息用 markdown 格式表示,其中页眉、页脚部分忽略,表格用 html 格式表达,文档中公式用 latex 格式表示,按照阅读顺序组织进行解析。 | +| **Information Extraction** | • Output the value of Key.

• Extract the content of the fields: ['key1','key2', ...] from the image and return it in JSON format.

• Extract the subtitles from the image. | • 输出 Key 的值。

• 提取图片中的: ['key1','key2', ...] 的字段内容,并按照 JSON 格式返回。

• 提取图片中的字幕。 | +| **Translation** | First extract the text, then translate the text content into English. If it is a document, ignore the header and footer. Formulas should be represented in LaTeX format, and tables should be represented in HTML format. | 先提取文字,再将文字内容翻译为英文。若是文档,则其中页眉、页脚忽略。公式用latex格式表示,表格用html格式表示。 | + +## GLM-OCR + +[GLM-OCR](https://huggingface.co/zai-org/GLM-OCR) is an OCR expert VLM in the `oar-ocr-vl` crate. It uses prompt-driven image-to-text generation and can be used directly or as a `DocParser` backend. ### Downloading the Model ```bash git lfs install -git clone https://huggingface.co/tencent/HunyuanOCR +git clone https://huggingface.co/zai-org/GLM-OCR # Or using hf -hf download tencent/HunyuanOCR --local-dir HunyuanOCR +hf download zai-org/GLM-OCR --local-dir GLM-OCR ``` ### Basic Usage ```rust,no_run use oar_ocr_core::utils::load_image; -use oar_ocr_vl::HunyuanOcr; +use oar_ocr_vl::GlmOcr; use oar_ocr_vl::utils::parse_device; fn main() -> Result<(), Box> { let image = load_image("document.jpg")?; let device = parse_device("cpu")?; // or "cuda", "cuda:0" - let model = HunyuanOcr::from_dir("HunyuanOCR", device)?; - - let prompt = "Detect and recognize text in the image, and output the text coordinates in a formatted manner."; - let text = model.generate(image, prompt, 1024)?; + let model = GlmOcr::from_dir("GLM-OCR", device)?; + let prompt = "Text Recognition:"; + let text = model + .generate(&[image], &[prompt], 1024) + .into_iter() + .next() + .expect("one result")?; println!("{text}"); Ok(()) @@ -410,34 +443,72 @@ fn main() -> Result<(), Box> { ### Running the Example ```bash -cargo run -p oar-ocr-vl --features cuda --example hunyuanocr -- \ - --model-dir HunyuanOCR \ +cargo run -p oar-ocr-vl --features cuda --example glmocr -- \ + --model-dir GLM-OCR \ --device cuda \ - --prompt "Detect and recognize text in the image, and output the text coordinates in a formatted manner." 
\ + --prompt "Text Recognition:" \ document.jpg ``` -### Application-oriented Prompts +## MinerU2.5 -Prompts from the upstream HunyuanOCR README: +[MinerU2.5](https://huggingface.co/opendatalab/MinerU2.5-2509-1.2B) is a document parsing VLM supported by `oar-ocr-vl`. For full-page documents, use its model-native two-step pipeline rather than forcing it through `DocParser`. -| Task | English | Chinese | -|------|---------|---------| -| **Spotting** | Detect and recognize text in the image, and output the text coordinates in a formatted manner. | 检测并识别图片中的文字,将文本坐标格式化输出。 | -| **Parsing** | • Identify the formula in the image and represent it using LaTeX format.

• Parse the table in the image into HTML.

• Parse the chart in the image; use Mermaid format for flowcharts and Markdown for other charts.

• Extract all information from the main body of the document image and represent it in markdown format, ignoring headers and footers. Tables should be expressed in HTML format, formulas in the document should be represented using LaTeX format, and the parsing should be organized according to the reading order. | • 识别图片中的公式,用 LaTeX 格式表示。

• 把图中的表格解析为 HTML。

• 解析图中的图表,对于流程图使用 Mermaid 格式表示,其他图表使用 Markdown 格式表示。

• 提取文档图片中正文的所有信息用 markdown 格式表示,其中页眉、页脚部分忽略,表格用 html 格式表达,文档中公式用 latex 格式表示,按照阅读顺序组织进行解析。 | -| **Information Extraction** | • Output the value of Key.

• Extract the content of the fields: ['key1','key2', ...] from the image and return it in JSON format.

• Extract the subtitles from the image. | • 输出 Key 的值。

• 提取图片中的: ['key1','key2', ...] 的字段内容,并按照 JSON 格式返回。

• 提取图片中的字幕。 | -| **Translation** | First extract the text, then translate the text content into English. If it is a document, ignore the header and footer. Formulas should be represented in LaTeX format, and tables should be represented in HTML format. | 先提取文字,再将文字内容翻译为英文。若是文档,则其中页眉、页脚忽略。公式用latex格式表示,表格用html格式表示。 | +### Downloading the Model + +```bash +git lfs install +git clone https://huggingface.co/opendatalab/MinerU2.5-2509-1.2B + +# Or using hf +hf download opendatalab/MinerU2.5-2509-1.2B --local-dir MinerU2.5-2509-1.2B +``` + +### Basic Usage + +```rust,no_run +use oar_ocr_core::utils::load_image; +use oar_ocr_vl::MinerU; +use oar_ocr_vl::utils::parse_device; + +fn main() -> Result<(), Box> { + let image = load_image("document.jpg")?; + let device = parse_device("cpu")?; // or "cuda", "cuda:0" + + let model = MinerU::from_dir("MinerU2.5-2509-1.2B", device)?; + let prompt = "\nText Recognition:"; + let text = model + .generate(&[image], &[prompt], 1024) + .into_iter() + .next() + .expect("one result")?; + println!("{text}"); + + Ok(()) +} +``` + +### Running the Example + +```bash +cargo run -p oar-ocr-vl --features cuda --example mineru -- \ + --model-dir MinerU2.5-2509-1.2B \ + --device cuda \ + document.jpg +``` ## DocParser -DocParser provides a unified API for two-stage document parsing that combines layout detection with VL-based recognition. It supports UniRec and PaddleOCR-VL (including PaddleOCR-VL-1.5) as recognition backends. +DocParser provides a unified API for external layout-first document parsing with VL-based recognition. It supports PaddleOCR-VL, PaddleOCR-VL-1.5, and GLM-OCR as recognition backends. + +Use `parse(&layout, image)` with an ONNX layout detector. HunyuanOCR and MinerU2.5 are not exposed by the `doc_parser` example because their reference-quality paths are prompt-driven full-page parsing and model-native two-step extraction, respectively. 
### Basic Usage ```rust use oar_ocr_core::utils::load_image; use oar_ocr_core::predictors::LayoutDetectionPredictor; -use oar_ocr_vl::{DocParser, DocParserConfig, UniRec, PaddleOcrVl}; +use oar_ocr_vl::{DocParser, GlmOcr, PaddleOcrVl}; use oar_ocr_vl::utils::parse_device; use std::path::Path; @@ -452,21 +523,21 @@ fn main() -> Result<(), Box> { // Load document image let image = load_image(Path::new("document.jpg"))?; - // Option 1: Using UniRec (lighter, faster) - let unirec = UniRec::from_dir("models/unirec-0.1b", device.clone())?; - let parser = DocParser::with_config(&unirec, DocParserConfig::default()); + // Option 1: Using PaddleOCR-VL + let paddleocr_vl = PaddleOcrVl::from_dir("PaddleOCR-VL", device.clone())?; + let parser = DocParser::new(&paddleocr_vl); let result = parser.parse(&layout, image.clone())?; println!("{}", result.to_markdown()); - // Option 2: Using PaddleOCR-VL (heavier, more accurate) - let paddleocr_vl = PaddleOcrVl::from_dir("PaddleOCR-VL", device)?; - let parser = DocParser::new(&paddleocr_vl); + // Option 2: Using PaddleOCR-VL-1.5 (next-gen, more accurate) + let paddleocr_vl_15 = PaddleOcrVl::from_dir("PaddleOCR-VL-1.5", device.clone())?; + let parser = DocParser::new(&paddleocr_vl_15); let result = parser.parse(&layout, image.clone())?; println!("{}", result.to_markdown()); - // Option 3: Using PaddleOCR-VL-1.5 (next-gen, more accurate) - let paddleocr_vl_15 = PaddleOcrVl::from_dir("PaddleOCR-VL-1.5", device)?; - let parser = DocParser::new(&paddleocr_vl_15); + // Option 3: Using GLM-OCR with external layout + let glmocr = GlmOcr::from_dir("GLM-OCR", device)?; + let parser = DocParser::new(&glmocr); let result = parser.parse(&layout, image)?; println!("{}", result.to_markdown()); @@ -477,15 +548,7 @@ fn main() -> Result<(), Box> { ### Running the Example ```bash -# Using UniRec (default, lighter) -cargo run -p oar-ocr-vl --features cuda --example doc_parser -- \ - --model-name unirec \ - --model-dir models/unirec-0.1b \ - --layout-model 
models/pp-doclayoutv3.onnx \ - --device cuda \ - document.jpg - -# Using PaddleOCR-VL (heavier, more accurate) +# Using PaddleOCR-VL cargo run -p oar-ocr-vl --features cuda --example doc_parser -- \ --model-name paddleocr-vl \ --model-dir PaddleOCR-VL \ @@ -500,8 +563,49 @@ cargo run -p oar-ocr-vl --features cuda --example doc_parser -- \ --layout-model models/pp-doclayoutv3.onnx \ --device cuda \ document.jpg + +# Using GLM-OCR with layout +cargo run -p oar-ocr-vl --features cuda --example doc_parser -- \ + --model-name glmocr \ + --model-dir GLM-OCR \ + --layout-model models/pp-doclayoutv3.onnx \ + --device cuda \ + document.jpg + +``` + +## Hierarchical Speculative Decoding (HSD) + +HSD is a CUDA-only acceleration path available on every VLM backbone (`PaddleOcrVl`, `HunyuanOcr`, `GlmOcr`, `MinerU`). Enable it by building with the `hsd` feature; that pulls in the per-backbone `generate_hsd*` methods and transitively turns on `cuda`. + +Each backbone exposes a `generate_hsd*` entry point taking an `HsdConfig`. A typical call site: + +```rust +use oar_ocr_vl::hsd::types::{DsvConfig, HsdConfig}; + +let cfg = HsdConfig { + dsv: DsvConfig::default(), + enable_stage1: true, + enable_stage2: true, + max_page_tokens: 16_384, + max_region_tokens: 4_096, +}; +let (text, stats) = model.generate_hsd(&image, instruction, &drafts, &cfg)?; ``` +Run the demo example end-to-end: + +```bash +cargo run -p oar-ocr-vl --release --features hsd,download-binaries \ + --example hsd_demo -- \ + --backend hunyuanocr \ + --model-dir models/HunyuanOCR \ + --device cuda \ + --image document.jpg +``` + +See [`docs/hsd.md`](hsd.md) for the algorithm. 
+ ## Configuration Options ### OrtSessionConfig diff --git a/examples/structure.rs b/examples/structure.rs index 2d15f08..e2ac01d 100644 --- a/examples/structure.rs +++ b/examples/structure.rs @@ -334,6 +334,10 @@ struct Args { #[arg(long, default_value_t = 1536)] formula_max_length: usize, + /// Preferred formula recognition batch size + #[arg(long, default_value_t = 8)] + formula_batch_size: usize, + /// Text detection score threshold (DB thresh, default: 0.3) #[arg(long, default_value = "0.3")] det_score_thresh: f32, @@ -551,6 +555,7 @@ fn main() -> Result<(), Box> { let formula_config = FormulaRecognitionConfig { score_threshold: args.formula_score_thresh, max_length: args.formula_max_length, + batch_size: args.formula_batch_size, }; let text_det_config = TextDetectionConfig { diff --git a/examples/table_structure_recognition.rs b/examples/table_structure_recognition.rs index e4a7f63..28893fd 100644 --- a/examples/table_structure_recognition.rs +++ b/examples/table_structure_recognition.rs @@ -40,7 +40,7 @@ //! ```bash //! cargo run --example table_structure_recognition -- \ //! --model-path path/to/model.onnx \ -//! --dict-path /path/to/table_structure_dict_ch.txt \ +//! --dict-path path/to/table_structure_dict_ch.txt \ //! --image-path path/to/image.jpg //! ``` //! @@ -49,7 +49,7 @@ //! ```bash //! cargo run --example table_structure_recognition -- \ //! --model-path path/to/model.onnx \ -//! --dict-path /path/to/table_structure_dict.txt \ +//! --dict-path path/to/table_structure_dict.txt \ //! --model-name SLANet_plus \ //! --image-path path/to/image.jpg //! ``` diff --git a/examples/utils/visualization.rs b/examples/utils/visualization.rs index c56421e..0469d3f 100644 --- a/examples/utils/visualization.rs +++ b/examples/utils/visualization.rs @@ -26,22 +26,25 @@ use tracing::{debug, info, warn}; /// Load a system font for text rendering. 
pub fn load_system_font() -> Option { - let font_paths = [ - "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", - "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", - "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf", - "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", - "/System/Library/Fonts/Arial.ttf", - "C:\\Windows\\Fonts\\arial.ttf", - ]; - - for path in &font_paths { - if let Ok(font_data) = std::fs::read(path) - && let Ok(font) = FontVec::try_from_vec(font_data) - { - debug!("Loaded font from {}", path); - return Some(font); - } + let mut font_db = fontdb::Database::new(); + font_db.load_system_fonts(); + + let query = fontdb::Query { + families: &[ + fontdb::Family::SansSerif, + fontdb::Family::Serif, + fontdb::Family::Monospace, + ], + ..Default::default() + }; + + if let Some(font_id) = font_db.query(&query) + && let Some((font_data, face_index)) = + font_db.with_face_data(font_id, |data, index| (data.to_vec(), index)) + && let Ok(font) = FontVec::try_from_vec_and_index(font_data, face_index) + { + debug!("Loaded system font from font database"); + return Some(font); } debug!("No system font found"); diff --git a/oar-ocr-core/Cargo.toml b/oar-ocr-core/Cargo.toml index 401e9b8..d5ab9d6 100644 --- a/oar-ocr-core/Cargo.toml +++ b/oar-ocr-core/Cargo.toml @@ -38,7 +38,6 @@ itertools = "0.14" regex = "1.11.1" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" -toml = "1.0" ort = { version = "2.0.0-rc.12", default-features = false, features = [ "std", "ndarray", "tracing", "copy-dylibs" ] } ndarray = "0.17" nalgebra = "0.34" @@ -48,7 +47,6 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokenizers = { version = "0.23", default-features = false, features = ["progressbar", "onig"] } clipper2 = "0.5.3" -html-escape = "0.2" [dev-dependencies] tempfile = "3.19" diff --git a/oar-ocr-core/src/core/errors/mod.rs b/oar-ocr-core/src/core/errors/mod.rs index 3277a56..12dc7ca 100644 --- 
a/oar-ocr-core/src/core/errors/mod.rs +++ b/oar-ocr-core/src/core/errors/mod.rs @@ -8,8 +8,8 @@ //! //! The error system is organized into several modules: //! -//! - [`types`] - Core error types (OCRError, ProcessingStage) -//! - [`constructors`] - Helper methods for creating errors with context +//! - `types` - Core error types (OCRError, ProcessingStage) +//! - `constructors` - Helper methods for creating errors with context //! //! # Main Error Types //! diff --git a/oar-ocr-core/src/core/inference/mod.rs b/oar-ocr-core/src/core/inference/mod.rs index 6391e75..ae59781 100644 --- a/oar-ocr-core/src/core/inference/mod.rs +++ b/oar-ocr-core/src/core/inference/mod.rs @@ -84,6 +84,30 @@ impl OrtInfer { _ => None, } } + + /// Returns the declared output names and tensor shapes from the first session. + /// + /// This is intended for model adapters that need to choose among multiple + /// ONNX outputs before interpreting tensors semantically. + pub fn output_shapes(&self) -> Vec<(String, Vec)> { + let Some(session_mutex) = self.sessions.first() else { + return Vec::new(); + }; + let Ok(session_guard) = session_mutex.lock() else { + return Vec::new(); + }; + session_guard + .outputs() + .iter() + .filter_map(|output| match output.dtype() { + ValueType::Tensor { shape, .. } => Some(( + output.name().to_string(), + shape.iter().copied().collect::>(), + )), + _ => None, + }) + .collect() + } } #[cfg(test)] diff --git a/oar-ocr-core/src/core/inference/tensor_output.rs b/oar-ocr-core/src/core/inference/tensor_output.rs index aa8caab..a3b3328 100644 --- a/oar-ocr-core/src/core/inference/tensor_output.rs +++ b/oar-ocr-core/src/core/inference/tensor_output.rs @@ -29,6 +29,14 @@ impl TensorOutput { } } + /// Returns a compact name for the tensor element type. + pub fn dtype_name(&self) -> &'static str { + match self { + TensorOutput::F32 { .. } => "f32", + TensorOutput::I64 { .. } => "i64", + } + } + /// Returns the number of dimensions. 
pub fn ndim(&self) -> usize { self.shape().len() diff --git a/oar-ocr-core/src/core/macros.rs b/oar-ocr-core/src/core/macros.rs index 90d6873..405aaf9 100644 --- a/oar-ocr-core/src/core/macros.rs +++ b/oar-ocr-core/src/core/macros.rs @@ -147,7 +147,7 @@ macro_rules! impl_task_type_enum { /// Macro to handle optional nested config initialization in builders. /// /// This macro eliminates the repeated pattern of: -/// ```rust,no_run +/// ```text /// // if self.config.field.is_none() { /// // self.config.field = Some(Type::new()); /// // } @@ -155,7 +155,7 @@ macro_rules! impl_task_type_enum { /// /// # Usage /// -/// ```rust,no_run +/// ```text /// // Instead of: /// // if self.config.orientation.is_none() { /// // self.config.orientation = Some(DocOrientationClassifierConfig::new()); @@ -187,7 +187,7 @@ macro_rules! with_nested { /// /// # Usage /// -/// ```rust,no_run +/// ```text /// // Instead of: /// // StageMetrics::new(success_count, failure_count) /// // .with_processing_time(start_time.elapsed()) @@ -235,7 +235,7 @@ macro_rules! metrics { /// /// # Usage /// -/// ```rust,no_run +/// ```text /// // impl_complete_builder! { /// // builder: MyBuilder, /// // config_field: config, @@ -847,7 +847,7 @@ macro_rules! impl_adapter_builder { /// Macro to conditionally apply OrtSessionConfig to any builder that has `with_ort_config`. /// /// This macro eliminates the repeated pattern: -/// ```rust,no_run +/// ```text /// // let mut builder = SomeBuilder::new(); /// // if let Some(ort_config) = ort_config { /// // builder = builder.with_ort_config(ort_config); @@ -855,13 +855,13 @@ macro_rules! 
impl_adapter_builder { /// ``` /// /// Instead, use: -/// ```rust,no_run +/// ```text /// // let builder = apply_ort_config!(SomeBuilder::new(), ort_config); /// ``` /// /// # Usage /// -/// ```rust,no_run +/// ```text /// // Works with any builder that has a `with_ort_config` method: /// // let builder = apply_ort_config!( /// // DBModelBuilder::new() diff --git a/oar-ocr-core/src/domain/adapters/formula_recognition_adapter.rs b/oar-ocr-core/src/domain/adapters/formula_recognition_adapter.rs index a63de03..77c6887 100644 --- a/oar-ocr-core/src/domain/adapters/formula_recognition_adapter.rs +++ b/oar-ocr-core/src/domain/adapters/formula_recognition_adapter.rs @@ -14,6 +14,7 @@ use crate::models::recognition::{ }; use crate::processors::normalize_latex; use std::path::{Path, PathBuf}; +use std::time::Instant; use tokenizers::Tokenizer; /// Special token IDs extracted from a tokenizer. @@ -107,12 +108,14 @@ impl FormulaModel { token_ids: &ndarray::Array2, sos_token_id: i64, eos_token_id: i64, + vocab_size: i64, ) -> Vec> { match self { FormulaModel::PPFormulaNet(_) => { let config = PPFormulaNetPostprocessConfig { sos_token_id, eos_token_id, + vocab_size, }; PPFormulaNetModel::filter_tokens(token_ids, &config) } @@ -120,6 +123,7 @@ impl FormulaModel { let config = UniMERNetPostprocessConfig { sos_token_id, eos_token_id, + vocab_size, }; UniMERNetModel::filter_tokens(token_ids, &config) } @@ -172,6 +176,7 @@ impl ModelAdapter for FormulaRecognitionAdapter { let batch_len = input.images.len(); // Preprocess and infer + let t_preprocess = Instant::now(); let batch_tensor = self .model .preprocess(input.into_owned_images()) @@ -182,6 +187,9 @@ impl ModelAdapter for FormulaRecognitionAdapter { e, ) })?; + let preprocess_dur = t_preprocess.elapsed(); + let batch_shape = batch_tensor.shape().to_vec(); + let t_infer = Instant::now(); let token_ids = self.model.infer(&batch_tensor).map_err(|e| { OCRError::adapter_execution_error( "FormulaRecognitionAdapter", @@ -189,12 
+197,15 @@ impl ModelAdapter for FormulaRecognitionAdapter { e, ) })?; + let infer_dur = t_infer.elapsed(); // Filter tokens and decode + let t_decode = Instant::now(); let filtered_tokens = self.model.filter_tokens( &token_ids, self.model_config.sos_token_id, self.model_config.eos_token_id, + self.tokenizer.get_vocab_size(true) as i64, ); let mut formulas = Vec::new(); @@ -221,11 +232,14 @@ impl ModelAdapter for FormulaRecognitionAdapter { { tracing::warn!( "Token id(s) exceed tokenizer vocab (max_id={} >= vocab_size={}). \ - This usually means model/tokenizer mismatch. If you're using external models, \ - please supply the matching tokenizer via --tokenizer-path.", + Skipping this formula to avoid emitting corrupt LaTeX. \ + This usually means model/tokenizer mismatch or an unsupported model output type.", max_id, vocab_size ); + formulas.push(String::new()); + scores.push(None); + continue; } let latex = match self.tokenizer.decode(tokens_to_decode, true) { @@ -265,6 +279,28 @@ impl ModelAdapter for FormulaRecognitionAdapter { formulas.push(latex); scores.push(None); } + let decode_dur = t_decode.elapsed(); + // Per-batch diagnostics are noisy under bulk structure parsing; emit at + // debug so production logs stay clean. Enable with + // `RUST_LOG=oar_ocr_core::domain::adapters::formula_recognition_adapter=debug`. 
+ if tracing::enabled!(tracing::Level::DEBUG) { + let token_lens: Vec = filtered_tokens.iter().map(Vec::len).collect(); + let raw_token_prefixes: Vec> = token_ids + .axis_iter(ndarray::Axis(0)) + .map(|row| row.iter().copied().take(12).collect()) + .collect(); + tracing::debug!( + "formula adapter: batch={}, tensor_shape={:?}, output_shape={:?}, token_lens={:?}, raw_prefixes={:?}, preprocess={:.1} ms, infer={:.1} ms, decode={:.1} ms", + batch_len, + batch_shape, + token_ids.shape(), + token_lens, + raw_token_prefixes, + preprocess_dur.as_secs_f64() * 1000.0, + infer_dur.as_secs_f64() * 1000.0, + decode_dur.as_secs_f64() * 1000.0 + ); + } Ok(FormulaRecognitionOutput { formulas, scores }) } @@ -274,7 +310,7 @@ impl ModelAdapter for FormulaRecognitionAdapter { } fn recommended_batch_size(&self) -> usize { - 8 + self.config.batch_size } } @@ -348,6 +384,12 @@ impl_adapter_builder! { self } + /// Sets the preferred formula recognition batch size. + pub fn batch_size(mut self, size: usize) -> Self { + self.config.task_config.batch_size = size; + self + } + /// Sets the task configuration (alias for with_config). pub fn task_config(mut self, config: FormulaRecognitionConfig) -> Self { self.config = self.config.with_task_config(config); @@ -445,6 +487,12 @@ impl_adapter_builder! { self } + /// Sets the preferred formula recognition batch size. + pub fn batch_size(mut self, size: usize) -> Self { + self.config.task_config.batch_size = size; + self + } + /// Sets the task configuration (alias for with_config). 
pub fn task_config(mut self, config: FormulaRecognitionConfig) -> Self { self.config = self.config.with_task_config(config); @@ -525,11 +573,13 @@ mod tests { let config = FormulaRecognitionConfig { score_threshold: 0.8, max_length: 512, + batch_size: 4, }; let builder = PPFormulaNetAdapterBuilder::new().with_config(config); assert_eq!(builder.config.task_config().score_threshold, 0.8); assert_eq!(builder.config.task_config().max_length, 512); + assert_eq!(builder.config.task_config().batch_size, 4); } #[test] @@ -537,10 +587,12 @@ mod tests { let builder = PPFormulaNetAdapterBuilder::new() .score_threshold(0.9) .max_length(1024) + .batch_size(2) .target_size(640, 640); assert_eq!(builder.config.task_config().score_threshold, 0.9); assert_eq!(builder.config.task_config().max_length, 1024); + assert_eq!(builder.config.task_config().batch_size, 2); assert_eq!(builder.target_size, Some((640, 640))); } @@ -551,6 +603,7 @@ mod tests { // Default config values assert_eq!(builder.config.task_config().score_threshold, 0.0); assert_eq!(builder.config.task_config().max_length, 1536); + assert_eq!(builder.config.task_config().batch_size, 8); } #[test] @@ -564,11 +617,13 @@ mod tests { let config = FormulaRecognitionConfig { score_threshold: 0.7, max_length: 2048, + batch_size: 4, }; let builder = UniMERNetAdapterBuilder::new().with_config(config); assert_eq!(builder.config.task_config().score_threshold, 0.7); assert_eq!(builder.config.task_config().max_length, 2048); + assert_eq!(builder.config.task_config().batch_size, 4); } #[test] @@ -576,10 +631,12 @@ mod tests { let builder = UniMERNetAdapterBuilder::new() .score_threshold(0.85) .max_length(768) + .batch_size(2) .target_size(512, 512); assert_eq!(builder.config.task_config().score_threshold, 0.85); assert_eq!(builder.config.task_config().max_length, 768); + assert_eq!(builder.config.task_config().batch_size, 2); assert_eq!(builder.target_size, Some((512, 512))); } @@ -590,6 +647,7 @@ mod tests { // Default config values 
assert_eq!(builder.config.task_config().score_threshold, 0.0); assert_eq!(builder.config.task_config().max_length, 1536); + assert_eq!(builder.config.task_config().batch_size, 8); } #[test] diff --git a/oar-ocr-core/src/domain/adapters/preprocessing.rs b/oar-ocr-core/src/domain/adapters/preprocessing.rs index c968360..d1f95a5 100644 --- a/oar-ocr-core/src/domain/adapters/preprocessing.rs +++ b/oar-ocr-core/src/domain/adapters/preprocessing.rs @@ -104,7 +104,7 @@ pub fn db_preprocess_for_text_type(text_type: Option<&str>) -> DBPreprocessConfi /// /// # Example /// -/// ```rust,no_run +/// ```text /// // let rgb_images: Vec = load_images(); /// // let dynamic_images = rgb_to_dynamic(rgb_images); /// ``` @@ -129,7 +129,7 @@ pub fn rgb_to_dynamic(images: Vec) -> Vec { /// /// # Example /// -/// ```rust,no_run +/// ```text /// // let tensor = resize_and_normalize( /// // images, /// // &self.resizer, @@ -165,7 +165,7 @@ where /// /// # Example /// -/// ```rust,no_run +/// ```text /// // let (tensor, scales) = detection_resize_and_normalize( /// // images, /// // &self.resizer, @@ -215,7 +215,7 @@ pub trait DetectionResizeOperation { /// /// # Example /// -/// ```rust,no_run +/// ```text /// // let tensor = PreprocessPipelineBuilder::new() /// // .rgb_images(images) /// // .resize(&resizer) diff --git a/oar-ocr-core/src/domain/tasks/formula_recognition.rs b/oar-ocr-core/src/domain/tasks/formula_recognition.rs index bdc5c13..c5dea78 100644 --- a/oar-ocr-core/src/domain/tasks/formula_recognition.rs +++ b/oar-ocr-core/src/domain/tasks/formula_recognition.rs @@ -20,6 +20,9 @@ pub struct FormulaRecognitionConfig { /// Maximum formula length in tokens (default: 1536) #[validate(min = 1)] pub max_length: usize, + /// Preferred batch size for formula recognition. 
+ #[validate(min = 1)] + pub batch_size: usize, } impl Default for FormulaRecognitionConfig { @@ -27,6 +30,7 @@ impl Default for FormulaRecognitionConfig { Self { score_threshold: 0.0, max_length: 1536, + batch_size: 8, } } } diff --git a/oar-ocr-core/src/domain/tasks/layout_detection.rs b/oar-ocr-core/src/domain/tasks/layout_detection.rs index 8eea082..58bc8a6 100644 --- a/oar-ocr-core/src/domain/tasks/layout_detection.rs +++ b/oar-ocr-core/src/domain/tasks/layout_detection.rs @@ -45,7 +45,9 @@ impl Default for UnclipRatio { /// Configuration for layout detection task. #[derive(Debug, Clone, Serialize, Deserialize, ConfigValidator)] pub struct LayoutDetectionConfig { - /// Default score threshold for detection (default: 0.5) + /// Default score threshold for detection (default: 0.5, matches PaddleX's + /// `draw_threshold: 0.5` post-NMS visibility threshold from the published + /// `inference.yml` for the PP-DocLayout / PicoDet layout families). #[validate(range(min = 0.0, max = 1.0))] pub score_threshold: f32, /// Maximum number of layout elements (default: 100) @@ -128,8 +130,8 @@ impl LayoutDetectionConfig { /// Creates a config with PP-DocLayoutV2 default thresholds and merge modes. /// - /// These defaults are aligned with OpenOCR/OpenDoc's pipeline config: - /// `OpenOCR/configs/rec/unirec/opendoc_pipeline.yml`. + /// These defaults follow the per-class thresholds and merge-mode settings + /// used by upstream PP-DocLayoutV2 deployments. /// /// Notes: /// - The postprocessor applies `score_threshold` before per-class thresholds, so we set it diff --git a/oar-ocr-core/src/models/recognition/pp_formulanet.rs b/oar-ocr-core/src/models/recognition/pp_formulanet.rs index e44aaf9..6ef88a7 100644 --- a/oar-ocr-core/src/models/recognition/pp_formulanet.rs +++ b/oar-ocr-core/src/models/recognition/pp_formulanet.rs @@ -4,7 +4,7 @@ //! The model is independent of any specific task and can be reused in different contexts. 
use crate::core::OCRError; -use crate::core::inference::{OrtInfer, TensorInput}; +use crate::core::inference::{OrtInfer, TensorInput, TensorOutput}; use crate::processors::{FormulaPreprocessParams, FormulaPreprocessor}; use image::RgbImage; use ndarray::{ArrayBase, Axis, Data, Ix2}; @@ -43,6 +43,9 @@ pub struct PPFormulaNetPostprocessConfig { pub sos_token_id: i64, /// End-of-sequence token id pub eos_token_id: i64, + /// Tokenizer vocabulary size. Non-negative IDs at or above this value are + /// treated as padding/sentinel values emitted by exported ONNX models. + pub vocab_size: i64, } impl Default for PPFormulaNetPostprocessConfig { @@ -50,6 +53,7 @@ impl Default for PPFormulaNetPostprocessConfig { Self { sos_token_id: 0, eos_token_id: 2, + vocab_size: i64::MAX, } } } @@ -129,19 +133,54 @@ impl PPFormulaNetModel { source: Box::new(e), })?; - let output = outputs + tracing::debug!( + "PP-FormulaNet declared output shapes: {:?}", + self.inference.output_shapes() + ); + + // Some exported PP-FormulaNet ONNX models emit multiple tensors + // (e.g. token IDs + per-step scores). The token-ID tensor is the + // unique 2-D i64 output; pick it explicitly rather than trusting + // graph output order, which has bitten us before when exporters + // reordered metadata vs ids. + let i64_2d_count = outputs + .iter() + .filter(|(_, t)| matches!(t, TensorOutput::I64 { shape, .. } if shape.len() == 2)) + .count(); + if i64_2d_count != 1 { + // Defer the (name, dtype, shape) walk to the error path: on the + // happy path we don't pay for a `Vec` we'd immediately drop. 
+ let candidates: Vec<(String, &'static str, Vec)> = outputs + .iter() + .map(|(name, t)| (name.clone(), t.dtype_name(), t.shape().to_vec())) + .collect(); + return Err(OCRError::Inference { + model_name: "PP-FormulaNet".to_string(), + context: format!( + "expected exactly one 2-D i64 output (token ids); found {} candidate(s) among outputs {:?}", + i64_2d_count, candidates + ), + source: Box::new(OCRError::InvalidInput { + message: "PP-FormulaNet: ambiguous or missing token-id output".to_string(), + }), + }); + } + let (name, tensor) = outputs .into_iter() - .next() - .ok_or_else(|| OCRError::InvalidInput { - message: "PP-FormulaNet: no output returned from inference".to_string(), - })?; - - output - .1 + .find(|(_, t)| matches!(t, TensorOutput::I64 { shape, .. } if shape.len() == 2)) + .expect("i64_2d_count == 1 checked above"); + tracing::debug!( + "PP-FormulaNet selected output '{}' dtype={} runtime_shape={:?}", + name, + tensor.dtype_name(), + tensor.shape() + ); + + tensor .try_into_array2_i64() .map_err(|e| OCRError::Inference { model_name: "PP-FormulaNet".to_string(), - context: "failed to convert output to 2D i64 array".to_string(), + context: format!("failed to convert output '{name}' to 2-D i64 array"), source: Box::new(e), }) } @@ -190,6 +229,7 @@ impl PPFormulaNetModel { .iter() .copied() .take_while(|&id| id != config.eos_token_id) + .take_while(|&id| id < 0 || id < config.vocab_size) .filter(|&id| id >= 0 && id != config.sos_token_id) .map(|id| id as u32) .collect(); @@ -277,3 +317,37 @@ impl PPFormulaNetModelBuilder { PPFormulaNetModel::new(inference, preprocess_config) } } + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::arr2; + + #[test] + fn filter_tokens_stops_at_vocab_sentinel() { + let token_ids = arr2(&[[0, 42, 49_999, 4_096_990_134i64, 77, 2]]); + let config = PPFormulaNetPostprocessConfig { + sos_token_id: 0, + eos_token_id: 2, + vocab_size: 50_000, + }; + + let filtered = PPFormulaNetModel::filter_tokens(&token_ids, &config); + + 
assert_eq!(filtered, vec![vec![42, 49_999]]); + } + + #[test] + fn filter_tokens_still_stops_at_eos() { + let token_ids = arr2(&[[0, 42, 2, 43]]); + let config = PPFormulaNetPostprocessConfig { + sos_token_id: 0, + eos_token_id: 2, + vocab_size: 50_000, + }; + + let filtered = PPFormulaNetModel::filter_tokens(&token_ids, &config); + + assert_eq!(filtered, vec![vec![42]]); + } +} diff --git a/oar-ocr-core/src/models/recognition/unimernet.rs b/oar-ocr-core/src/models/recognition/unimernet.rs index 89516f6..06b9409 100644 --- a/oar-ocr-core/src/models/recognition/unimernet.rs +++ b/oar-ocr-core/src/models/recognition/unimernet.rs @@ -4,6 +4,7 @@ //! The model is independent of any specific task and can be reused in different contexts. use crate::core::OCRError; +use crate::core::config::{OrtExecutionProvider, OrtGraphOptimizationLevel, OrtSessionConfig}; use crate::core::inference::{OrtInfer, TensorInput}; use crate::processors::{UniMERNetPreprocessParams, UniMERNetPreprocessor}; use image::RgbImage; @@ -43,6 +44,9 @@ pub struct UniMERNetPostprocessConfig { pub sos_token_id: i64, /// End-of-sequence token id pub eos_token_id: i64, + /// Tokenizer vocabulary size. Non-negative IDs at or above this value are + /// treated as padding/sentinel values emitted by exported ONNX models. + pub vocab_size: i64, } impl Default for UniMERNetPostprocessConfig { @@ -50,6 +54,7 @@ impl Default for UniMERNetPostprocessConfig { Self { sos_token_id: 0, eos_token_id: 2, + vocab_size: i64::MAX, } } } @@ -190,6 +195,7 @@ impl UniMERNetModel { .iter() .copied() .take_while(|&id| id != config.eos_token_id) + .take_while(|&id| id < 0 || id < config.vocab_size) .filter(|&id| id >= 0 && id != config.sos_token_id) .map(|id| id as u32) .collect(); @@ -248,10 +254,12 @@ impl UniMERNetModelBuilder { /// Builds the UniMERNet model. 
pub fn build(self, model_path: &std::path::Path) -> Result { // Create ONNX inference engine - let inference = if self.ort_config.is_some() { + let ort_config = self.ort_config.map(Self::configure_unimernet_ort_for_cuda); + + let inference = if ort_config.is_some() { use crate::core::config::ModelInferenceConfig; let common_config = ModelInferenceConfig { - ort_session: self.ort_config, + ort_session: ort_config, ..Default::default() }; OrtInfer::from_config(&common_config, model_path, None)? @@ -276,4 +284,122 @@ impl UniMERNetModelBuilder { UniMERNetModel::new(inference, preprocess_config) } + + fn configure_unimernet_ort_for_cuda(mut config: OrtSessionConfig) -> OrtSessionConfig { + if !Self::uses_cuda(&config) { + return config; + } + + config.optimization_level = Some(OrtGraphOptimizationLevel::Level1); + if config.enable_mem_pattern.is_none() { + config.enable_mem_pattern = Some(false); + } + + let entries = config + .session_config_entries + .get_or_insert_with(Default::default); + let disabled_optimizers = entries + .entry("optimization.disable_specified_optimizers".to_string()) + .or_default(); + if !disabled_optimizers + .split(',') + .any(|name| name.trim() == "ConstantFolding") + { + if !disabled_optimizers.trim().is_empty() { + disabled_optimizers.push(','); + } + disabled_optimizers.push_str("ConstantFolding"); + } + + config + } + + fn uses_cuda(config: &OrtSessionConfig) -> bool { + config + .execution_providers + .as_ref() + .is_some_and(|providers| { + providers + .iter() + .any(|provider| matches!(provider, OrtExecutionProvider::CUDA { .. 
})) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::arr2; + + #[test] + fn filter_tokens_stops_at_vocab_sentinel() { + let token_ids = arr2(&[[0, 42, 49_999, 4_096_990_134i64, 77, 2]]); + let config = UniMERNetPostprocessConfig { + sos_token_id: 0, + eos_token_id: 2, + vocab_size: 50_000, + }; + + let filtered = UniMERNetModel::filter_tokens(&token_ids, &config); + + assert_eq!(filtered, vec![vec![42, 49_999]]); + } + + #[test] + fn filter_tokens_still_stops_at_eos() { + let token_ids = arr2(&[[0, 42, 2, 43]]); + let config = UniMERNetPostprocessConfig { + sos_token_id: 0, + eos_token_id: 2, + vocab_size: 50_000, + }; + + let filtered = UniMERNetModel::filter_tokens(&token_ids, &config); + + assert_eq!(filtered, vec![vec![42]]); + } + + #[test] + fn cuda_config_disables_constant_folding_for_unimernet() { + let config = OrtSessionConfig::new().with_execution_providers(vec![ + OrtExecutionProvider::CUDA { + device_id: Some(0), + gpu_mem_limit: None, + arena_extend_strategy: None, + cudnn_conv_algo_search: None, + cudnn_conv_use_max_workspace: None, + }, + OrtExecutionProvider::CPU, + ]); + + let configured = UniMERNetModelBuilder::configure_unimernet_ort_for_cuda(config); + + assert!(matches!( + configured.optimization_level, + Some(OrtGraphOptimizationLevel::Level1) + )); + assert_eq!(configured.enable_mem_pattern, Some(false)); + assert_eq!( + configured + .session_config_entries + .as_ref() + .and_then(|entries| entries.get("optimization.disable_specified_optimizers")) + .map(String::as_str), + Some("ConstantFolding") + ); + } + + #[test] + fn cpu_config_keeps_unimernet_ort_config_unchanged() { + let config = + OrtSessionConfig::new().with_optimization_level(OrtGraphOptimizationLevel::All); + + let configured = UniMERNetModelBuilder::configure_unimernet_ort_for_cuda(config); + + assert!(matches!( + configured.optimization_level, + Some(OrtGraphOptimizationLevel::All) + )); + assert!(configured.session_config_entries.is_none()); + } } diff 
--git a/oar-ocr-core/src/predictors/formula_recognition.rs b/oar-ocr-core/src/predictors/formula_recognition.rs index fa39002..9bf74ac 100644 --- a/oar-ocr-core/src/predictors/formula_recognition.rs +++ b/oar-ocr-core/src/predictors/formula_recognition.rs @@ -105,6 +105,7 @@ impl FormulaRecognitionPredictorBuilder { state: PredictorBuilderState::new(FormulaRecognitionConfig { score_threshold: 0.0, max_length: 1536, + batch_size: 8, }), model_name: "FormulaRecognition".to_string(), tokenizer_path: None, diff --git a/oar-ocr-core/src/predictors/table_structure_recognition.rs b/oar-ocr-core/src/predictors/table_structure_recognition.rs index f4b24c6..cf1233f 100644 --- a/oar-ocr-core/src/predictors/table_structure_recognition.rs +++ b/oar-ocr-core/src/predictors/table_structure_recognition.rs @@ -34,19 +34,25 @@ enum TableStructureModelFamily { } impl TableStructureModelFamily { + /// `Wired` carries the 512×512 default that matches SLANeXt ONNX exports; + /// `Wireless` carries the 488×488 default that matches PaddleX's + /// `ResizeTableImage(max_len=488)+PaddingTableImage(488,488)` pipeline used + /// by both SLANet and SLANet_plus. 
fn from_model_name(model_name: &str) -> Option { match model_name { - "SLANet" | "SLANeXt_wired" | "SLANeXt_wireless" => Some(Self::Wired), - "SLANet_plus" => Some(Self::Wireless), + "SLANeXt_wired" | "SLANeXt_wireless" => Some(Self::Wired), + "SLANet" | "SLANet_plus" => Some(Self::Wireless), _ => None, } } fn detect_from_path(path: &Path) -> Option { let stem = path.file_stem()?.to_str()?.to_ascii_lowercase(); - if stem.contains("slanet_plus") { + if stem.contains("slanext") { + Some(Self::Wired) + } else if stem.contains("slanet_plus") || stem.contains("slanet") { Some(Self::Wireless) - } else if stem.contains("wired") || stem.contains("slanet") || stem.contains("slanext") { + } else if stem.contains("wired") { Some(Self::Wired) } else { None diff --git a/oar-ocr-core/src/processors/decode.rs b/oar-ocr-core/src/processors/decode.rs index fcc676d..ccd527b 100644 --- a/oar-ocr-core/src/processors/decode.rs +++ b/oar-ocr-core/src/processors/decode.rs @@ -19,8 +19,7 @@ pub type PositionedDecodeResult = ( ); static ALPHANUMERIC_REGEX: LazyLock = LazyLock::new(|| { - Regex::new(r"[a-zA-Z0-9 :*./%+-]") - .unwrap_or_else(|e| panic!("Failed to compile regex pattern: {e}")) + Regex::new(r"[a-zA-Z0-9 :*./%+-]").expect("static regex: alphanumeric decoder pattern") }); /// A base decoder for text recognition that handles character mapping and basic decoding operations. 
diff --git a/oar-ocr-core/src/processors/formula_preprocess.rs b/oar-ocr-core/src/processors/formula_preprocess.rs index 29bdeff..23ba69a 100644 --- a/oar-ocr-core/src/processors/formula_preprocess.rs +++ b/oar-ocr-core/src/processors/formula_preprocess.rs @@ -13,17 +13,16 @@ use std::sync::LazyLock; // Static regex patterns for LaTeX normalization static CHINESE_TEXT_PATTERN: LazyLock = LazyLock::new(|| { Regex::new(r"\\text\s*\{([^{}]*[\u{4e00}-\u{9fff}]+[^{}]*)\}") - .unwrap_or_else(|e| panic!("Failed to compile Chinese text regex pattern: {e}")) + .expect("static regex: Chinese text pattern") }); static TEXT_COMMAND_PATTERN: LazyLock = LazyLock::new(|| { Regex::new(r"(\\(operatorname|mathrm|text|mathbf)\s?\*?\s*\{.*?\})") - .unwrap_or_else(|e| panic!("Failed to compile text command regex pattern: {e}")) + .expect("static regex: text command pattern") }); static LETTER_TO_NONLETTER_PATTERN: LazyLock = LazyLock::new(|| { - Regex::new(r"([a-zA-Z])\s+([^a-zA-Z])") - .unwrap_or_else(|e| panic!("Failed to compile letter to nonletter regex pattern: {e}")) + Regex::new(r"([a-zA-Z])\s+([^a-zA-Z])").expect("static regex: letter to nonletter pattern") }); /// Configuration parameters for formula preprocessing pipeline. @@ -326,22 +325,25 @@ pub fn normalize_latex(latex: &str) -> String { prev_result = result.clone(); // Python pattern 1: r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter) - // This removes spaces between two non-letters unless preceded by backslash-space - // We need to be careful not to remove spaces after \\ + // In Python's raw regex, `\\ ` matches a literal `\` followed by space — + // i.e. the LaTeX thin-space token `\ `. The negative lookahead therefore + // refuses to match `(noletter)` when it would start with `\ `, leaving + // LaTeX thin spaces untouched. The earlier Rust port checked for `\\ ` + // (two literal backslashes + space, i.e. 
LaTeX line break) which is the + // wrong token; that bug let `\ \ ` collapse to `\\` and dropped both + // thin spaces. let mut temp = String::new(); let chars: Vec = result.chars().collect(); let mut i = 0; while i < chars.len() { - if i + 2 < chars.len() - && chars[i] == '\\' - && chars[i + 1] == '\\' - && chars[i + 2] == ' ' - { - // Keep "\\ " as is + if i + 1 < chars.len() && chars[i] == '\\' && chars[i + 1] == ' ' { + // Python's lookahead refuses to *start* a match at this `\`, + // but the following space is still a normal-space candidate for + // matches that begin at the next character. Mirror that by only + // emitting the backslash here and letting the next iteration + // decide what to do with the space. temp.push(chars[i]); - temp.push(chars[i + 1]); - temp.push(chars[i + 2]); - i += 3; + i += 1; } else if i + 1 < chars.len() && chars[i + 1].is_whitespace() { // Check if current char is noletter let is_noletter_current = !chars[i].is_ascii_alphabetic(); diff --git a/oar-ocr-core/src/processors/layout_sorting.rs b/oar-ocr-core/src/processors/layout_sorting.rs index 43b8fb3..c182f19 100644 --- a/oar-ocr-core/src/processors/layout_sorting.rs +++ b/oar-ocr-core/src/processors/layout_sorting.rs @@ -933,3 +933,102 @@ fn calculate_projection_overlap_ratio( 0.0 } } + +#[cfg(test)] +mod tests { + use super::*; + + fn elem( + x1: f32, + y1: f32, + x2: f32, + y2: f32, + element_type: LayoutElementType, + ) -> SortableElement { + SortableElement { + bbox: BoundingBox::from_coords(x1, y1, x2, y2), + element_type, + num_lines: Some(2), + } + } + + fn sort(elements: Vec) -> Vec { + sort_layout_enhanced(&elements, 400.0, 600.0) + } + + #[test] + fn sort_layout_enhanced_empty_input_returns_empty_order() { + assert!(sort_layout_enhanced(&[], 400.0, 600.0).is_empty()); + } + + #[test] + fn sort_layout_enhanced_places_headers_first_and_footers_last() { + let elements = vec![ + elem(20.0, 110.0, 380.0, 135.0, LayoutElementType::Text), + elem(20.0, 560.0, 380.0, 585.0, 
LayoutElementType::Footer), + elem(20.0, 25.0, 380.0, 45.0, LayoutElementType::Header), + elem(20.0, 5.0, 380.0, 20.0, LayoutElementType::Header), + elem(20.0, 145.0, 380.0, 170.0, LayoutElementType::Text), + ]; + + assert_eq!(sort(elements), vec![3, 2, 0, 4, 1]); + } + + #[test] + fn sort_layout_enhanced_inserts_document_title_before_body_text() { + let elements = vec![ + elem(20.0, 90.0, 380.0, 120.0, LayoutElementType::Text), + elem(20.0, 55.0, 380.0, 80.0, LayoutElementType::DocTitle), + elem(20.0, 130.0, 380.0, 160.0, LayoutElementType::Text), + ]; + + assert_eq!(sort(elements), vec![1, 0, 2]); + } + + #[test] + fn sort_layout_enhanced_orders_two_column_text_by_rows() { + let elements = vec![ + elem(215.0, 120.0, 380.0, 150.0, LayoutElementType::Text), + elem(20.0, 40.0, 185.0, 70.0, LayoutElementType::Text), + elem(215.0, 40.0, 380.0, 70.0, LayoutElementType::Text), + elem(20.0, 120.0, 185.0, 150.0, LayoutElementType::Text), + ]; + + assert_eq!(sort(elements), vec![1, 2, 3, 0]); + } + + #[test] + fn associate_child_blocks_keeps_near_vision_title_next_to_vision() { + let mut blocks = vec![ + SortableBlock::new( + BoundingBox::from_coords(20.0, 20.0, 380.0, 45.0), + 0, + LayoutElementType::Text, + Some(1), + ), + SortableBlock::new( + BoundingBox::from_coords(20.0, 90.0, 220.0, 190.0), + 1, + LayoutElementType::Image, + Some(5), + ), + SortableBlock::new( + BoundingBox::from_coords(20.0, 192.0, 220.0, 210.0), + 2, + LayoutElementType::FigureTitle, + Some(1), + ), + SortableBlock::new( + BoundingBox::from_coords(20.0, 230.0, 380.0, 255.0), + 3, + LayoutElementType::Text, + Some(1), + ), + ]; + + associate_child_blocks(&mut blocks); + + let order: Vec = blocks.iter().map(|b| b.original_index).collect(); + assert_eq!(order, vec![0, 1, 2, 3]); + } +} diff --git a/oar-ocr-core/src/utils/dict.rs b/oar-ocr-core/src/utils/dict.rs index e5185ae..5501da5 100644 --- a/oar-ocr-core/src/utils/dict.rs +++ b/oar-ocr-core/src/utils/dict.rs @@ -89,7 +89,7 @@ pub fn 
read_dict_content(path: &Path) -> Result { /// use oar_ocr_core::utils::require_path; /// use std::path::PathBuf; /// -/// let path: Option = Some(PathBuf::from("/path/to/dict.txt")); +/// let path: Option = Some(PathBuf::from("dict.txt")); /// let validated = require_path(path, "text_recognition", "character dictionary path")?; /// # Ok::<(), oar_ocr_core::core::OCRError>(()) /// ``` @@ -110,7 +110,7 @@ pub fn require_path + Clone>( mod tests { use super::*; use std::io::Write; - use tempfile::NamedTempFile; + use tempfile::{NamedTempFile, tempdir}; #[test] fn test_read_character_dict() -> Result<(), Box> { @@ -135,14 +135,17 @@ mod tests { } #[test] - fn test_read_nonexistent_file() { - let result = read_character_dict(Path::new("/nonexistent/path/dict.txt")); + fn test_read_nonexistent_file() -> Result<(), Box> { + let dir = tempdir()?; + let missing_path = dir.path().join("missing-dict.txt"); + let result = read_character_dict(&missing_path); assert!(result.is_err()); + Ok(()) } #[test] fn test_require_path_some() { - let path = Some(std::path::PathBuf::from("/some/path")); + let path = Some(std::path::PathBuf::from("some/path")); let result = require_path(path, "test", "test path"); assert!(result.is_ok()); } diff --git a/oar-ocr-vl/Cargo.toml b/oar-ocr-vl/Cargo.toml index 4276dce..af907b5 100644 --- a/oar-ocr-vl/Cargo.toml +++ b/oar-ocr-vl/Cargo.toml @@ -29,11 +29,18 @@ cuda = [ "oar-ocr-core/cuda", ] # When enabled, turns on Candle's Metal backend for GPU acceleration on Apple devices. +# Cargo features cannot be target-scoped directly, so Linux `--all-features` +# is intentionally not used for this crate; CI checks Metal on macOS instead. metal = [ "candle-core/metal", "candle-nn/metal", "candle-transformers/metal", ] +# When enabled, compiles the Hierarchical Speculative Decoding (HSD) module +# and the per-model `generate_hsd*` paths. 
HSD relies on Candle's CUDA backend +# (custom KV-cache gather, tree-attention forwards) and is not supported on +# CPU/Metal, so it transitively requires the `cuda` feature. +hsd = ["cuda"] [dependencies] oar-ocr-core.workspace = true @@ -56,3 +63,4 @@ clap = { version = "4.5.42", features = ["derive"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } criterion = { version = "0.8", features = ["html_reports"] } hayro = "0.6" +oar-ocr = { path = "..", default-features = false } diff --git a/oar-ocr-vl/README.md b/oar-ocr-vl/README.md index d212c05..5a059ad 100644 --- a/oar-ocr-vl/README.md +++ b/oar-ocr-vl/README.md @@ -2,7 +2,7 @@ Vision-Language models for document understanding in Rust. -This crate provides native Rust inference for document VLMs using [Candle](https://github.com/huggingface/candle), along with a two-stage document parsing pipeline. +This crate provides native Rust inference for document VLMs using [Candle](https://github.com/huggingface/candle), along with a document parsing pipeline for backends that work well with external layout detection. 
## Supported Models @@ -10,20 +10,18 @@ This crate provides native Rust inference for document VLMs using [Candle](https |-------|------------|-------------| | [PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL) | 0.9B | SOTA document parsing VLM supporting 109 languages, text, tables, formulas, and 11 chart types | | [PaddleOCR-VL-1.5](https://huggingface.co/PaddlePaddle/PaddleOCR-VL-1.5) | 0.9B | Next-gen PaddleOCR-VL with 94.5% on OmniDocBench v1.5, adds text spotting and seal recognition | -| [UniRec](https://huggingface.co/topdu/unirec-0.1b) | 0.1B | Ultra-lightweight unified recognition for text, formulas, and tables (Chinese/English) | | [HunyuanOCR](https://huggingface.co/tencent/HunyuanOCR) | 1B | End-to-end OCR VLM for multilingual document parsing, text spotting, and information extraction | | [GLM-OCR](https://huggingface.co/zai-org/GLM-OCR) | 0.9B | #1 on OmniDocBench v1.5 (94.62), optimized for real-world scenarios with MTP loss and RL training | -| [LightOnOCR-2](https://huggingface.co/lightonai/LightOnOCR-2-1B) | 1B | SOTA on OlmOCR-Bench, 9x smaller than competitors, processes 5.7 pages/s on H100 | | [MinerU2.5](https://huggingface.co/opendatalab/MinerU2.5-2509-1.2B) | 1.2B | Decoupled document parsing VLM with strong text, formula, and table recognition | ## Document Parsing Pipeline -**DocParser** is a two-stage document parsing API that combines: +**DocParser** is a unified document parsing API for layout-first backends. It combines: 1. **Layout detection** (ONNX models like PP-DocLayoutV3) to identify document regions -2. **VL-based recognition** (any supported model above) to extract content from each region +2. **VL-based recognition** to extract content from each region -This approach provides structured output with reading order preservation. +Use DocParser with PaddleOCR-VL, PaddleOCR-VL-1.5, and GLM-OCR. 
HunyuanOCR should be used with its model-native full-page prompts, and MinerU2.5 should use its model-native two-step extraction example. ## Installation @@ -33,8 +31,7 @@ Add `oar-ocr-vl` to your project: cargo add oar-ocr-vl ``` -If you use ONNX-based helpers from `oar-ocr-core` and want ORT binaries to be fetched -automatically during build, enable `download-binaries` explicitly: +If you use ONNX-based helpers from `oar-ocr-core` and want ORT binaries to be fetched automatically during build, enable `download-binaries` explicitly: ```bash cargo add oar-ocr-vl --features download-binaries @@ -46,6 +43,16 @@ To enable GPU acceleration (CUDA), add the feature flag: cargo add oar-ocr-vl --features cuda ``` +### Hierarchical Speculative Decoding (HSD) + +A training-free CUDA acceleration scheme for the VLMs listed above. A cheap pipeline drafter (layout + OCR) proposes per-region text candidates and the target VLM verifies them in batches via tree-attention. Each backbone exposes `generate_hsd*` methods alongside its baseline `generate`. Build with `--features hsd` (transitively enables `cuda`): + +```bash +cargo add oar-ocr-vl --features hsd,download-binaries +``` + +See [`docs/hsd.md`](../docs/hsd.md) at the workspace root for the algorithm overview, `DsvConfig` / `HsdConfig` knobs, supported backbones, and Average-Acceptance-Length (AAL) guidance. End-to-end runnable examples live under `examples/hsd_*.rs`. + ## Usage ### PaddleOCR-VL @@ -62,8 +69,12 @@ let device = candle_core::Device::Cpu; // Or Device::new_cuda(0)? // Initialize model let model = PaddleOcrVl::from_dir("PaddleOCR-VL", device)?; -// Perform OCR -let result = model.generate(image, PaddleOcrVlTask::Ocr, 256)?; +// Perform OCR. The API is batch-oriented, so pass one task per image. 
+let result = model + .generate(&[image], &[PaddleOcrVlTask::Ocr], 256) + .into_iter() + .next() + .expect("one result")?; println!("Result: {}", result); ``` @@ -76,37 +87,22 @@ use oar_ocr_vl::{PaddleOcrVl, PaddleOcrVlTask}; let image = load_image("seal.png")?; let device = candle_core::Device::Cpu; let model = PaddleOcrVl::from_dir("PaddleOCR-VL-1.5", device)?; -let result = model.generate(image, PaddleOcrVlTask::Seal, 256)?; -println!("Result: {}", result); -``` - -### UniRec - -UniRec is a unified model that handles text, mathematical formulas, and table structures in a single pass without needing task-specific prompts. - -```rust -use oar_ocr_core::utils::load_image; -use oar_ocr_vl::UniRec; - -let image = load_image("mixed_content.png")?; -let device = candle_core::Device::Cpu; - -// Initialize model -let model = UniRec::from_dir("models/unirec-0.1b", device)?; - -// Generate content (automatically handles text, formulas, etc.) -let result = model.generate(image, 512)?; +let result = model + .generate(&[image], &[PaddleOcrVlTask::Seal], 256) + .into_iter() + .next() + .expect("one result")?; println!("Result: {}", result); ``` ### DocParser -Combine layout detection with a VLM backend to parse an entire page into Markdown. +Parse an entire page into Markdown with a layout predictor. This path is intended for external layout-first backends such as PaddleOCR-VL, PaddleOCR-VL-1.5, and GLM-OCR. ```rust use oar_ocr_core::utils::load_image; use oar_ocr_core::predictors::LayoutDetectionPredictor; -use oar_ocr_vl::{DocParser, UniRec}; +use oar_ocr_vl::{DocParser, PaddleOcrVl}; let device = candle_core::Device::Cpu; @@ -115,9 +111,9 @@ let layout_predictor = LayoutDetectionPredictor::builder() .model_name("pp-doclayoutv3") .build("pp-doclayoutv3.onnx")?; -// 2. Setup Recognition Backend (UniRec or PaddleOCR-VL) -let unirec = UniRec::from_dir("models/unirec-0.1b", device)?; -let parser = DocParser::new(&unirec); +// 2. 
Setup a layout-first recognition backend +let vl = PaddleOcrVl::from_dir("models/PaddleOCR-VL-1.5", device.clone())?; +let parser = DocParser::new(&vl); // 3. Parse Document let image = load_image("page.jpg")?; @@ -135,38 +131,35 @@ use oar_ocr_vl::MinerU; let image = load_image("document.png")?; let device = candle_core::Device::Cpu; -let model = MinerU::from_dir("/path/to/MinerU2.5-2509-1.2B", device)?; -let result = model.generate(&[image], &["\nDocument Parsing:"], 4096); -println!("Result: {}", result[0].as_ref()?); +let model = MinerU::from_dir("models/MinerU2.5-2509-1.2B", device)?; +// For full documents, prefer the `mineru` example, which follows the +// model-native two-step pipeline: layout detection, then crop recognition. +let result = model + .generate(&[image], &["\nText Recognition:"], 4096) + .into_iter() + .next() + .expect("one result")?; +println!("Result: {}", result); ``` ## Running Examples The `oar-ocr-vl` crate includes several examples demonstrating its capabilities. -### DocParser (Two-Stage Pipeline) +### DocParser -This example combines layout detection (ONNX) with a VLM for recognition. +This example combines layout detection (ONNX) with a VLM for recognition. It supports PaddleOCR-VL, PaddleOCR-VL-1.5, and GLM-OCR. ```bash cargo run --release --features cuda --example doc_parser -- \ - --model-name unirec \ - --model-dir models/unirec-0.1b \ + --model-name paddleocr-vl-1.5 \ + --model-dir models/PaddleOCR-VL-1.5 \ --layout-model models/pp-doclayoutv3.onnx \ --device cuda \ document.jpg ``` -### UniRec (Direct Inference) - -Run the UniRec model directly on an image. - -```bash -cargo run --release --features cuda --example unirec -- \ - --model-dir models/unirec-0.1b \ - --device cuda \ - formula.png -``` +HunyuanOCR and MinerU2.5 are intentionally not exposed by this example because their reference-quality paths are prompt-driven full-page parsing and model-native two-step extraction, respectively. 
### PaddleOCR-VL (Direct Inference) @@ -224,11 +217,30 @@ cargo run --release --features cuda --example glmocr -- \ ### MinerU2.5 (Direct Inference) -Two-step document extraction (layout detection + content extraction): +Model-native two-step document extraction (layout prompt + content extraction): ```bash cargo run --release --features cuda --example mineru -- \ - --model-dir /path/to/MinerU2.5-2509-1.2B \ + --model-dir models/MinerU2.5-2509-1.2B \ --device cuda:0 \ document.jpg ``` + +### HSD (Hierarchical Speculative Decoding) + +The shared `hsd_demo` example runs baseline and HSD back-to-back so you can +compare wall time and outputs. Select the target VLM via `--backend`: + +```bash +# Single-page smoke test (HunyuanOCR backbone). +cargo run --release --features hsd,download-binaries --example hsd_demo -- \ + --backend hunyuanocr \ + --model-dir models/HunyuanOCR \ + --device cuda \ + --image document.jpg + +# Quality + perf matrix over OmniDocBench-style inputs. +cargo run --release --features hsd,download-binaries --example hsd_omnidocbench -- --help +``` + +See [`docs/hsd.md`](../docs/hsd.md) for the full set of HSD knobs and the backbone-by-backbone capability matrix. diff --git a/oar-ocr-vl/build.rs b/oar-ocr-vl/build.rs new file mode 100644 index 0000000..c6dd8ed --- /dev/null +++ b/oar-ocr-vl/build.rs @@ -0,0 +1,8 @@ +fn main() { + let metal_enabled = std::env::var_os("CARGO_FEATURE_METAL").is_some(); + let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default(); + + if metal_enabled && target_os != "macos" { + panic!("oar-ocr-vl feature `metal` is only supported on macOS targets"); + } +} diff --git a/oar-ocr-vl/examples/doc_parser.rs b/oar-ocr-vl/examples/doc_parser.rs index 589d7d9..3a19ce3 100644 --- a/oar-ocr-vl/examples/doc_parser.rs +++ b/oar-ocr-vl/examples/doc_parser.rs @@ -1,19 +1,16 @@ //! Unified Document Parser Example //! -//! This example demonstrates the unified DocParser API that supports multiple -//! 
recognition models for two-stage document parsing (layout detection + recognition). +//! This example demonstrates the unified DocParser API for external +//! layout-first document parsing (layout detection + region recognition). +//! +//! HunyuanOCR and MinerU2.5 are intentionally not exposed here: +//! their reference-quality usage is full-page prompt-driven parsing or a +//! model-native two-step pipeline, not forced external layout crops. //! //! # Usage //! //! ```bash -//! # Using UniRec model (default, lighter) -//! cargo run -p oar-ocr-vl --example doc_parser -- \ -//! --model-name unirec \ -//! --model-dir models/unirec-0.1b \ -//! --layout-model models/pp-doclayoutv3.onnx \ -//! document.jpg -//! -//! # Using PaddleOCR-VL model (heavier, more accurate) +//! # Using PaddleOCR-VL model //! cargo run -p oar-ocr-vl --example doc_parser -- \ //! --model-name paddleocr-vl \ //! --model-dir PaddleOCR-VL \ @@ -27,23 +24,10 @@ //! --layout-model models/pp-doclayoutv3.onnx \ //! document.jpg //! -//! # Using LightOnOCR model (end-to-end OCR) -//! cargo run -p oar-ocr-vl --example doc_parser -- \ -//! --model-name lightonocr \ -//! --model-dir LightOnOCR-2-1B \ -//! document.jpg -//! //! # Using GLM-OCR model //! cargo run -p oar-ocr-vl --example doc_parser -- \ //! --model-name glmocr \ -//! --model-dir /path/to/GLM-OCR \ -//! --layout-model models/pp-doclayoutv3.onnx \ -//! document.jpg -//! -//! # Using MinerU2.5 model -//! cargo run -p oar-ocr-vl --example doc_parser -- \ -//! --model-name mineru \ -//! --model-dir /path/to/MinerU2.5-2509-1.2B \ +//! --model-dir models/GLM-OCR \ //! --layout-model models/pp-doclayoutv3.onnx \ //! document.jpg //! 
``` @@ -65,44 +49,33 @@ use oar_ocr_vl::{DocParser, DocParserConfig}; /// Recognition model type #[derive(Debug, Clone, Copy, ValueEnum)] enum ModelName { - /// UniRec: Lightweight unified recognition - Unirec, - /// PaddleOCR-VL: Large VLM with task prompts + /// PaddleOCR-VL: VLM with task prompts #[value(name = "paddleocr-vl")] PaddleOcrVl, /// PaddleOCR-VL-1.5: Next-gen VLM with spotting and seal recognition #[value(name = "paddleocr-vl-1.5")] PaddleOcrVl15, - /// HunyuanOCR: OCR expert VLM (HunYuanVL) - #[value(name = "hunyuanocr")] - HunyuanOcr, /// GLM-OCR: OCR expert VLM (GLM-V) #[value(name = "glmocr")] GlmOcr, - /// LightOnOCR: End-to-end OCR VLM - #[value(name = "lightonocr")] - LightOnOcr, - /// MinerU2.5: Qwen2-VL document parsing model - #[value(name = "mineru")] - MinerU, } /// Command-line arguments #[derive(Parser)] #[command(name = "doc_parser")] #[command( - about = "Unified Document Parser - supports UniRec, PaddleOCR-VL, PaddleOCR-VL-1.5, HunyuanOCR, GLM-OCR, LightOnOCR, and MinerU2.5 models" + about = "Unified external-layout DocParser - supports PaddleOCR-VL, PaddleOCR-VL-1.5, and GLM-OCR" )] struct Args { /// Recognition model to use - #[arg(short = 'n', long, value_enum, default_value = "unirec")] + #[arg(short = 'n', long, value_enum, default_value = "paddleocr-vl-1.5")] model_name: ModelName, /// Path to the model directory #[arg(short, long)] model_dir: PathBuf, - /// Path to the PP-DocLayout ONNX model file (v2/v3, required unless using lightonocr) + /// Path to the PP-DocLayout ONNX model file (v2/v3, required) #[arg(short, long)] layout_model: Option, @@ -132,7 +105,7 @@ struct Args { } fn main() -> Result<(), Box> { - use oar_ocr_vl::{GlmOcr, HunyuanOcr, LightOnOcr, MinerU, PaddleOcrVl, UniRec}; + use oar_ocr_vl::{GlmOcr, PaddleOcrVl}; utils::init_tracing(); let args = Args::parse(); @@ -145,23 +118,17 @@ fn main() -> Result<(), Box> { error!("Model directory not found: {}", args.model_dir.display()); return Err("Model directory not 
found".into()); } - let needs_layout = !matches!(args.model_name, ModelName::LightOnOcr); - let layout_model_path = if needs_layout { - let path = args.layout_model.as_ref().ok_or_else(|| { - error!( - "Layout model is required for {:?} (not needed for LightOnOcr)", - args.model_name - ); - "Layout model not provided" - })?; - if !path.exists() { - error!("Layout model not found: {}", path.display()); - return Err("Layout model not found".into()); - } - Some(path) - } else { - None - }; + let layout_model_path = args.layout_model.as_ref().ok_or_else(|| { + error!( + "Layout model is required for {:?}. Use the hunyuanocr/mineru examples for model-native full-page parsing.", + args.model_name + ); + "Layout model not provided" + })?; + if !layout_model_path.exists() { + error!("Layout model not found: {}", layout_model_path.display()); + return Err("Layout model not found".into()); + } // Filter valid images let existing_images: Vec = @@ -179,32 +146,28 @@ fn main() -> Result<(), Box> { let device = parse_device(&args.device)?; info!("Device: {:?}", device); - let layout_predictor = if let Some(layout_path) = layout_model_path { - info!("Loading layout model..."); - let normalized_layout_name = args.layout_model_name.to_lowercase().replace('-', "_"); - let layout_config = match normalized_layout_name.as_str() { - "pp_doclayoutv2" | "pp_doclayout_v2" => { - Some(LayoutDetectionConfig::with_pp_doclayoutv2_defaults()) - } - "pp_doclayoutv3" | "pp_doclayout_v3" => { - Some(LayoutDetectionConfig::with_pp_doclayoutv3_defaults()) - } - "pp_structurev3" | "pp_structure_v3" => { - Some(LayoutDetectionConfig::with_pp_structurev3_defaults()) - } - _ => None, - }; - - let mut layout_builder = - LayoutDetectionPredictor::builder().model_name(&args.layout_model_name); - if let Some(config) = layout_config { - layout_builder = layout_builder.with_config(config); + info!("Loading layout model..."); + let normalized_layout_name = args.layout_model_name.to_lowercase().replace('-', 
"_"); + let layout_config = match normalized_layout_name.as_str() { + "pp_doclayoutv2" | "pp_doclayout_v2" => { + Some(LayoutDetectionConfig::with_pp_doclayoutv2_defaults()) } - Some(layout_builder.build(layout_path)?) - } else { - None + "pp_doclayoutv3" | "pp_doclayout_v3" => { + Some(LayoutDetectionConfig::with_pp_doclayoutv3_defaults()) + } + "pp_structurev3" | "pp_structure_v3" => { + Some(LayoutDetectionConfig::with_pp_structurev3_defaults()) + } + _ => None, }; + let mut layout_builder = + LayoutDetectionPredictor::builder().model_name(&args.layout_model_name); + if let Some(config) = layout_config { + layout_builder = layout_builder.with_config(config); + } + let layout_predictor = layout_builder.build(layout_model_path)?; + // Create config let config = DocParserConfig { max_tokens: args.max_tokens, @@ -213,18 +176,6 @@ fn main() -> Result<(), Box> { // Process images with the selected model match args.model_name { - ModelName::Unirec => { - info!("Loading UniRec model..."); - let load_start = Instant::now(); - let unirec = UniRec::from_dir(&args.model_dir, device)?; - info!( - "UniRec loaded in {:.2}ms", - load_start.elapsed().as_secs_f64() * 1000.0 - ); - - let parser = DocParser::with_config(&unirec, config); - process_images(&parser, layout_predictor.as_ref(), &existing_images, &args)?; - } ModelName::PaddleOcrVl | ModelName::PaddleOcrVl15 => { info!("Loading PaddleOCR-VL model..."); let load_start = Instant::now(); @@ -235,19 +186,7 @@ fn main() -> Result<(), Box> { ); let parser = DocParser::with_config(&vl, config); - process_images(&parser, layout_predictor.as_ref(), &existing_images, &args)?; - } - ModelName::HunyuanOcr => { - info!("Loading HunyuanOCR model..."); - let load_start = Instant::now(); - let model = HunyuanOcr::from_dir(&args.model_dir, device)?; - info!( - "HunyuanOCR loaded in {:.2}ms", - load_start.elapsed().as_secs_f64() * 1000.0 - ); - - let parser = DocParser::with_config(&model, config); - process_images(&parser, 
layout_predictor.as_ref(), &existing_images, &args)?; + process_images(&parser, &layout_predictor, &existing_images, &args)?; } ModelName::GlmOcr => { info!("Loading GLM-OCR model..."); @@ -259,31 +198,7 @@ fn main() -> Result<(), Box> { ); let parser = DocParser::with_config(&model, config); - process_images(&parser, layout_predictor.as_ref(), &existing_images, &args)?; - } - ModelName::LightOnOcr => { - info!("Loading LightOnOCR model..."); - let load_start = Instant::now(); - let model = LightOnOcr::from_dir(&args.model_dir, device)?; - info!( - "LightOnOCR loaded in {:.2}ms", - load_start.elapsed().as_secs_f64() * 1000.0 - ); - - let parser = DocParser::with_config(&model, config); - process_images(&parser, layout_predictor.as_ref(), &existing_images, &args)?; - } - ModelName::MinerU => { - info!("Loading MinerU2.5 model..."); - let load_start = Instant::now(); - let model = MinerU::from_dir(&args.model_dir, device)?; - info!( - "MinerU2.5 loaded in {:.2}ms", - load_start.elapsed().as_secs_f64() * 1000.0 - ); - - let parser = DocParser::with_config(&model, config); - process_images(&parser, layout_predictor.as_ref(), &existing_images, &args)?; + process_images(&parser, &layout_predictor, &existing_images, &args)?; } } Ok(()) @@ -291,7 +206,7 @@ fn main() -> Result<(), Box> { fn process_images( parser: &DocParser, - layout_predictor: Option<&LayoutDetectionPredictor>, + layout_predictor: &LayoutDetectionPredictor, images: &[PathBuf], args: &Args, ) -> Result<(), Box> { @@ -315,10 +230,7 @@ fn process_images( }; let start = Instant::now(); - let result = match layout_predictor { - Some(predictor) => parser.parse(predictor, rgb_img), - None => parser.parse_without_layout(rgb_img), - }; + let result = parser.parse(layout_predictor, rgb_img); match result { Ok(result) => { info!(" Parsed in {:.2}s", start.elapsed().as_secs_f64()); diff --git a/oar-ocr-vl/examples/glmocr.rs b/oar-ocr-vl/examples/glmocr.rs index c9e2db2..a507bf7 100644 --- 
a/oar-ocr-vl/examples/glmocr.rs +++ b/oar-ocr-vl/examples/glmocr.rs @@ -12,7 +12,7 @@ //! //! ```bash //! cargo run -p oar-ocr-vl --example glmocr -- \ -//! --model-dir /path/to/GLM-OCR \ +//! --model-dir models/GLM-OCR \ //! --prompt "Text Recognition:" \ //! document.jpg //! ``` diff --git a/oar-ocr-vl/examples/hsd_demo.rs b/oar-ocr-vl/examples/hsd_demo.rs new file mode 100644 index 0000000..1c8d6e0 --- /dev/null +++ b/oar-ocr-vl/examples/hsd_demo.rs @@ -0,0 +1,637 @@ +//! Hierarchical Speculative Decoding demo / harness, shared across backends. +//! +//! Two passes per run: +//! +//! 1. **Correctness check** (skip with `--skip-check`). +//! Use the baseline `generate(...)` output as a perfect draft and run HSD +//! with τ=1.0. The output must match the baseline exactly. Any divergence +//! indicates a bug somewhere in the HSD pipeline. +//! +//! 2. **Performance measurement.** +//! Run HSD with the user-supplied draft (or the baseline output if no +//! `--draft-text` / `--draft-file` is given) at τ=0.75 and report SR_e2e +//! along with Average Acceptance Length, fallback steps, and the per-stage +//! breakdown. +//! +//! # Usage +//! +//! ```bash +//! cargo run -p oar-ocr-vl --release --features hsd,download-binaries --example hsd_demo -- \ +//! --backend hunyuanocr \ +//! --model-dir models/HunyuanOCR \ +//! --device cuda:0 \ +//! --image document.jpg \ +//! --max-tokens 4096 +//! ``` +//! +//! Pass `--draft-text "..."` or `--draft-file path` to use a real drafter's +//! output as the speculative draft instead of the baseline (which would +//! otherwise produce an artificially high AAL). + +mod utils; + +#[cfg(not(feature = "hsd"))] +fn main() { + eprintln!("This example requires the `hsd` feature. 
Re-run with `--features hsd`."); + std::process::exit(1); +} + +#[cfg(feature = "hsd")] +fn main() -> Result<(), Box> { + imp::run() +} + +#[cfg(feature = "hsd")] +mod imp { + + use super::utils; + + use clap::{Parser, ValueEnum}; + use std::fs; + use std::path::PathBuf; + use std::time::Instant; + + use image::RgbImage; + use oar_ocr_core::utils::load_image; + use oar_ocr_core::{ + domain::structure::{LayoutElement, LayoutElementType}, + processors::BoundingBox, + }; + use oar_ocr_vl::hsd::types::{Draft, HsdStats}; + use oar_ocr_vl::utils::parse_device; + use oar_ocr_vl::{GlmOcr, HunyuanOcr, MinerU, PaddleOcrVl, PaddleOcrVlTask}; + use utils::{ + DEMO_DEFAULT_MAX_CANDIDATES, DEMO_DEFAULT_MAX_SUFFIX_LEN, auto_tune_hsd_oracle, + make_hsd_cfg, print_diff, print_hsd_stats, print_preview, + }; + + #[derive(Copy, Clone, Debug, ValueEnum)] + enum Backend { + #[value(name = "hunyuanocr")] + HunyuanOcr, + #[value(name = "paddleocr_vl")] + PaddleOcrVl, + #[value(name = "mineru")] + MinerU, + #[value(name = "glmocr")] + GlmOcr, + } + + #[derive(Copy, Clone, Debug, ValueEnum)] + enum Task { + Ocr, + Table, + Chart, + Formula, + Spotting, + Seal, + } + + impl Task { + fn to_native(self) -> PaddleOcrVlTask { + match self { + Task::Ocr => PaddleOcrVlTask::Ocr, + Task::Table => PaddleOcrVlTask::Table, + Task::Chart => PaddleOcrVlTask::Chart, + Task::Formula => PaddleOcrVlTask::Formula, + Task::Spotting => PaddleOcrVlTask::Spotting, + Task::Seal => PaddleOcrVlTask::Seal, + } + } + } + + #[derive(Parser)] + #[command(name = "hsd_demo")] + #[command(about = "HSD correctness check + performance demo across VL backends")] + struct Args { + /// Target backend. + #[arg(long, value_enum, default_value_t = Backend::HunyuanOcr)] + backend: Backend, + + /// Path to the backend model directory. + #[arg(long)] + model_dir: PathBuf, + + /// Path to a single page / region image. + #[arg(long)] + image: PathBuf, + + /// Device: cpu, cuda, cuda:N, or metal. 
+ #[arg(long, default_value = "cuda:0")] + device: String, + + /// Maximum tokens to generate. + #[arg(long, default_value_t = 4096)] + max_tokens: usize, + + /// Instruction prompt. Defaults depend on `--backend`: + /// - hunyuanocr: full-page spotting prompt + /// - glmocr: "Read the text in this image." + /// - mineru: "Read the text in this image:" + /// + /// The default for paddleocr_vl is encoded by `--task`, not by this flag. + #[arg(long)] + instruction: Option, + + /// PaddleOCR-VL task. Ignored for other backends. + #[arg(long, value_enum, default_value_t = Task::Ocr)] + task: Task, + + /// Optional pre-computed draft text from a real drafter pipeline. If + /// omitted, the baseline output is reused as an oracle draft (upper-bound + /// AAL, still useful for wall-clock validation). + #[arg(long)] + draft_text: Option, + + /// File whose contents are used as the draft. Mutually exclusive with + /// `--draft-text`. + #[arg(long)] + draft_file: Option, + + /// Skip the τ=1.0 correctness check (saves one HSD run). + #[arg(long)] + skip_check: bool, + + /// Acceptance threshold for the perf pass. + #[arg(long, default_value_t = 0.75)] + tau: f32, + + /// HunyuanOCR-only: exercise `generate_hsd_full` with a single full-page + /// region draft. Minimal real Stage-1 path; dataset region benchmarks + /// should use `hsd_omnidocbench`. + #[arg(long)] + stage1_full: bool, + + /// Reference window length n (paper §3.2). + #[arg(long, default_value_t = 3)] + window_len: usize, + + /// Maximum prefix-tree candidate count per verification step. + #[arg(long, default_value_t = DEMO_DEFAULT_MAX_CANDIDATES)] + max_candidates: usize, + + /// Maximum candidate suffix length. 
+ #[arg(long, default_value_t = DEMO_DEFAULT_MAX_SUFFIX_LEN)] + max_suffix_len: usize, + } + + enum DraftSource { + Oracle, + Cli(String), + File(String), + } + + impl DraftSource { + fn label(&self) -> &'static str { + match self { + DraftSource::Oracle => "oracle (baseline output)", + DraftSource::Cli(_) => "--draft-text", + DraftSource::File(_) => "--draft-file", + } + } + + fn text(&self) -> Option<&str> { + match self { + DraftSource::Oracle => None, + DraftSource::Cli(s) | DraftSource::File(s) => Some(s.as_str()), + } + } + } + + fn default_instruction(backend: Backend) -> &'static str { + match backend { + Backend::HunyuanOcr => { + "Detect and recognize text in the image, and output the text coordinates in a formatted manner." + } + Backend::GlmOcr => "Read the text in this image.", + Backend::MinerU => "Read the text in this image:", + // PaddleOCR-VL builds its prompt from --task; instruction is unused. + Backend::PaddleOcrVl => "", + } + } + + pub fn run() -> Result<(), Box> { + utils::init_tracing(); + let args = Args::parse(); + let device = parse_device(&args.device)?; + let image = load_image(&args.image)?; + let instruction = args + .instruction + .clone() + .unwrap_or_else(|| default_instruction(args.backend).to_string()); + + let draft_source = match (&args.draft_text, &args.draft_file) { + (Some(t), None) => DraftSource::Cli(t.clone()), + (None, Some(p)) => DraftSource::File(fs::read_to_string(p)?), + (None, None) => DraftSource::Oracle, + (Some(_), Some(_)) => { + return Err("--draft-text and --draft-file are mutually exclusive".into()); + } + }; + + if args.stage1_full && !matches!(args.backend, Backend::HunyuanOcr) { + return Err("--stage1-full is only supported with --backend hunyuanocr".into()); + } + + match args.backend { + Backend::HunyuanOcr => { + run_hunyuanocr(&args, device, &image, &instruction, &draft_source) + } + Backend::GlmOcr => run_glmocr(&args, device, &image, &instruction, &draft_source), + Backend::MinerU => 
run_mineru(&args, device, &image, &instruction, &draft_source), + Backend::PaddleOcrVl => run_paddleocr_vl(&args, device, &image, &draft_source), + } + } + + fn run_hunyuanocr( + args: &Args, + device: candle_core::Device, + image: &RgbImage, + instruction: &str, + draft_source: &DraftSource, + ) -> Result<(), Box> { + println!("Loading HunyuanOCR from {}", args.model_dir.display()); + let t_load = Instant::now(); + let model = HunyuanOcr::from_dir(&args.model_dir, device)?; + println!("Model load: {:?}", t_load.elapsed()); + println!("Loaded image: {}x{}", image.width(), image.height()); + + let (baseline_tokens, baseline_text, baseline_dur) = baseline_pass("HunyuanOCR", || { + let t = Instant::now(); + let res = + model.generate_tokens(std::slice::from_ref(image), &[instruction], args.max_tokens); + let dur = t.elapsed(); + let tokens = first_result(res, "HunyuanOCR baseline")?; + let text = model.decode_tokens(&tokens)?; + Ok::<_, Box>((tokens, text, dur)) + })?; + print_preview("BASELINE", &baseline_text); + + if !args.skip_check { + println!("\n[2/3] HSD τ=1.0 correctness check (oracle draft = baseline)..."); + let cfg = make_hsd_cfg( + args.max_tokens, + 1.0, + args.window_len, + args.max_candidates, + args.max_suffix_len, + false, + ); + let t = Instant::now(); + let token_drafts = vec![Draft::new(baseline_tokens.clone())]; + let (hsd_tokens, _stats) = model.generate_hsd_tokens_with_token_drafts( + image, + instruction, + &token_drafts, + &cfg, + )?; + let dur = t.elapsed(); + if hsd_tokens == baseline_tokens { + println!(" ✓ τ=1.0 HSD output matches baseline ({:?})", dur); + } else { + let hsd_text = model.decode_tokens(&hsd_tokens)?.trim().to_string(); + eprintln!(" ✗ τ=1.0 HSD output diverges from baseline."); + print_diff(&baseline_text, &hsd_text); + return Err("HSD τ=1.0 mismatch".into()); + } + } else { + println!("\n[2/3] correctness check skipped"); + } + + println!( + "\n[3/3] HSD τ={:.2} performance pass (draft source: {})...", + args.tau, + 
draft_source.label() + ); + let (eff_max_candidates, eff_max_suffix_len, note) = auto_tune_hsd_oracle( + matches!(draft_source, DraftSource::Oracle), + args.max_candidates, + args.max_suffix_len, + args.max_tokens, + ); + if let Some(n) = note { + println!("{n}"); + } + let cfg = make_hsd_cfg( + args.max_tokens, + args.tau, + args.window_len, + eff_max_candidates, + eff_max_suffix_len, + args.stage1_full, + ); + let t_hsd = Instant::now(); + let (hsd_text, stats): (String, HsdStats) = if args.stage1_full { + let text = match draft_source { + DraftSource::Oracle => baseline_text.as_str(), + DraftSource::Cli(s) | DraftSource::File(s) => s.as_str(), + }; + let element = full_page_text_element(image, text); + model.generate_hsd_full( + image, + oar_ocr_vl::HunyuanHsdPrompts { + page: instruction, + region: instruction, + }, + std::slice::from_ref(&element), + &[], + |elem| elem.text.iter().cloned().collect(), + &cfg, + )? + } else { + match draft_source { + DraftSource::Oracle => { + let token_drafts = vec![Draft::new(baseline_tokens.clone())]; + model.generate_hsd_with_token_drafts(image, instruction, &token_drafts, &cfg)? + } + DraftSource::Cli(s) | DraftSource::File(s) => { + model.generate_hsd(image, instruction, std::slice::from_ref(s), &cfg)? 
+ } + } + }; + let hsd_dur = t_hsd.elapsed(); + print_preview("HSD OUTPUT", &hsd_text); + print_hsd_stats(baseline_dur, hsd_dur, &stats, args.stage1_full); + oracle_note(draft_source); + Ok(()) + } + + fn run_glmocr( + args: &Args, + device: candle_core::Device, + image: &RgbImage, + instruction: &str, + draft_source: &DraftSource, + ) -> Result<(), Box> { + println!("Loading GLM-OCR from {}", args.model_dir.display()); + let model = GlmOcr::from_dir(&args.model_dir, device)?; + + let (baseline_tokens, baseline_text, baseline_dur) = baseline_pass("GLM-OCR", || { + let t = Instant::now(); + let res = + model.generate_tokens(std::slice::from_ref(image), &[instruction], args.max_tokens); + let dur = t.elapsed(); + let tokens = first_result(res, "GLM-OCR baseline")?; + let text = model.decode_tokens(&tokens)?; + Ok::<_, Box>((tokens, text, dur)) + })?; + + if !args.skip_check { + println!("\n[2/3] HSD τ=1.0 correctness check (oracle draft = baseline)..."); + let cfg = make_hsd_cfg( + args.max_tokens, + 1.0, + args.window_len, + args.max_candidates, + args.max_suffix_len, + false, + ); + let t = Instant::now(); + let token_drafts = vec![Draft::new(baseline_tokens.clone())]; + let (hsd_text, _) = + model.generate_hsd_with_token_drafts(image, instruction, &token_drafts, &cfg)?; + let dur = t.elapsed(); + if hsd_text == baseline_text { + println!(" ✓ matches baseline ({:?})", dur); + } else { + eprintln!(" ✗ diverges from baseline."); + print_diff(&baseline_text, &hsd_text); + return Err("τ=1.0 mismatch".into()); + } + } else { + println!("\n[2/3] correctness check skipped"); + } + + println!("\n[3/3] HSD τ={:.2} performance pass...", args.tau); + let (eff_max_candidates, eff_max_suffix_len, note) = auto_tune_hsd_oracle( + matches!(draft_source, DraftSource::Oracle), + args.max_candidates, + args.max_suffix_len, + args.max_tokens, + ); + if let Some(n) = note { + println!("{n}"); + } + let cfg = make_hsd_cfg( + args.max_tokens, + args.tau, + args.window_len, + 
eff_max_candidates, + eff_max_suffix_len, + false, + ); + let t = Instant::now(); + let (hsd_text, stats) = match draft_source { + DraftSource::Oracle => { + let token_drafts = vec![Draft::new(baseline_tokens.clone())]; + model.generate_hsd_with_token_drafts(image, instruction, &token_drafts, &cfg)? + } + DraftSource::Cli(s) | DraftSource::File(s) => { + model.generate_hsd(image, instruction, std::slice::from_ref(s), &cfg)? + } + }; + let hsd_dur = t.elapsed(); + print_preview("HSD OUTPUT", &hsd_text); + print_hsd_stats(baseline_dur, hsd_dur, &stats, false); + oracle_note(draft_source); + Ok(()) + } + + fn run_mineru( + args: &Args, + device: candle_core::Device, + image: &RgbImage, + instruction: &str, + draft_source: &DraftSource, + ) -> Result<(), Box> { + println!("Loading MinerU2.5 from {}", args.model_dir.display()); + let model = MinerU::from_dir(&args.model_dir, device)?; + + let (baseline_tokens, baseline_text, baseline_dur) = baseline_pass("MinerU2.5", || { + let t = Instant::now(); + let res = + model.generate_tokens(std::slice::from_ref(image), &[instruction], args.max_tokens); + let dur = t.elapsed(); + let tokens = first_result(res, "MinerU2.5 baseline")?; + let text = model.decode_tokens(&tokens)?; + Ok::<_, Box>((tokens, text, dur)) + })?; + + if !args.skip_check { + println!("\n[2/3] HSD τ=1.0 correctness check (oracle draft = baseline)..."); + let cfg = make_hsd_cfg( + args.max_tokens, + 1.0, + args.window_len, + args.max_candidates, + args.max_suffix_len, + false, + ); + let t = Instant::now(); + let token_drafts = vec![Draft::new(baseline_tokens.clone())]; + let (hsd_text, _) = + model.generate_hsd_with_token_drafts(image, instruction, &token_drafts, &cfg)?; + let dur = t.elapsed(); + if hsd_text == baseline_text { + println!(" ✓ matches baseline ({:?})", dur); + } else { + eprintln!(" ✗ diverges from baseline."); + print_diff(&baseline_text, &hsd_text); + return Err("τ=1.0 mismatch".into()); + } + } else { + println!("\n[2/3] correctness check 
skipped"); + } + + println!("\n[3/3] HSD τ={:.2} performance pass...", args.tau); + let (eff_max_candidates, eff_max_suffix_len, note) = auto_tune_hsd_oracle( + matches!(draft_source, DraftSource::Oracle), + args.max_candidates, + args.max_suffix_len, + args.max_tokens, + ); + if let Some(n) = note { + println!("{n}"); + } + let cfg = make_hsd_cfg( + args.max_tokens, + args.tau, + args.window_len, + eff_max_candidates, + eff_max_suffix_len, + false, + ); + let t = Instant::now(); + let (hsd_text, stats) = match draft_source { + DraftSource::Oracle => { + let token_drafts = vec![Draft::new(baseline_tokens.clone())]; + model.generate_hsd_with_token_drafts(image, instruction, &token_drafts, &cfg)? + } + DraftSource::Cli(s) | DraftSource::File(s) => { + model.generate_hsd(image, instruction, std::slice::from_ref(s), &cfg)? + } + }; + let hsd_dur = t.elapsed(); + print_preview("HSD OUTPUT", &hsd_text); + print_hsd_stats(baseline_dur, hsd_dur, &stats, false); + oracle_note(draft_source); + Ok(()) + } + + fn run_paddleocr_vl( + args: &Args, + device: candle_core::Device, + image: &RgbImage, + draft_source: &DraftSource, + ) -> Result<(), Box> { + let task = args.task.to_native(); + println!("Loading PaddleOCR-VL from {}", args.model_dir.display()); + let model = PaddleOcrVl::from_dir(&args.model_dir, device)?; + + let (baseline_tokens, baseline_text, baseline_dur) = baseline_pass("PaddleOCR-VL", || { + let t = Instant::now(); + let res = model.generate_tokens(std::slice::from_ref(image), &[task], args.max_tokens); + let dur = t.elapsed(); + let tokens = first_result(res, "PaddleOCR-VL baseline")?; + let (_, pp) = model.decode_tokens(&tokens, task)?; + Ok::<_, Box>((tokens, pp, dur)) + })?; + + if !args.skip_check { + println!("\n[2/3] HSD τ=1.0 correctness check (oracle draft = baseline)..."); + let cfg = make_hsd_cfg( + args.max_tokens, + 1.0, + args.window_len, + args.max_candidates, + args.max_suffix_len, + false, + ); + let t = Instant::now(); + let token_drafts = 
vec![Draft::new(baseline_tokens.clone())]; + let (hsd_text, _) = + model.generate_hsd_with_token_drafts(image, task, &token_drafts, &cfg)?; + let dur = t.elapsed(); + if hsd_text == baseline_text { + println!(" ✓ matches baseline ({:?})", dur); + } else { + eprintln!(" ✗ diverges from baseline."); + print_diff(&baseline_text, &hsd_text); + return Err("τ=1.0 mismatch".into()); + } + } else { + println!("\n[2/3] correctness check skipped"); + } + + println!("\n[3/3] HSD τ={:.2} performance pass...", args.tau); + let (eff_max_candidates, eff_max_suffix_len, note) = auto_tune_hsd_oracle( + matches!(draft_source, DraftSource::Oracle), + args.max_candidates, + args.max_suffix_len, + args.max_tokens, + ); + if let Some(n) = note { + println!("{n}"); + } + let cfg = make_hsd_cfg( + args.max_tokens, + args.tau, + args.window_len, + eff_max_candidates, + eff_max_suffix_len, + false, + ); + let t = Instant::now(); + let (hsd_text, stats) = match draft_source { + DraftSource::Oracle => { + let token_drafts = vec![Draft::new(baseline_tokens.clone())]; + model.generate_hsd_with_token_drafts(image, task, &token_drafts, &cfg)? + } + DraftSource::Cli(s) | DraftSource::File(s) => { + model.generate_hsd(image, task, std::slice::from_ref(s), &cfg)? 
+ } + }; + let hsd_dur = t.elapsed(); + print_preview("HSD OUTPUT", &hsd_text); + print_hsd_stats(baseline_dur, hsd_dur, &stats, false); + oracle_note(draft_source); + Ok(()) + } + + fn baseline_pass(label: &str, f: F) -> Result> + where + F: FnOnce() -> Result>, + { + println!("\n[1/3] {label} baseline generate..."); + f() + } + + fn first_result( + results: Vec>, + label: &str, + ) -> Result> { + match results.into_iter().next() { + Some(Ok(t)) => Ok(t), + Some(Err(e)) => Err(format!("{label} failed: {e}").into()), + None => Err(format!("{label} returned no results").into()), + } + } + + fn full_page_text_element(image: &RgbImage, text: &str) -> LayoutElement { + LayoutElement::new( + BoundingBox::from_coords(0.0, 0.0, image.width() as f32, image.height() as f32), + LayoutElementType::Text, + 1.0, + ) + .with_text(text.to_string()) + } + + fn oracle_note(draft_source: &DraftSource) { + if draft_source.text().is_none() { + println!( + "(Oracle draft = baseline output — AAL is an upper bound; \ + realistic drafters give lower numbers.)" + ); + } + } +} // mod imp diff --git a/oar-ocr-vl/examples/hsd_omnidocbench.rs b/oar-ocr-vl/examples/hsd_omnidocbench.rs new file mode 100644 index 0000000..6f25c0d --- /dev/null +++ b/oar-ocr-vl/examples/hsd_omnidocbench.rs @@ -0,0 +1,3136 @@ +//! Run HSD on OmniDocBench v1.5 with a real per-page draft. +//! +//! For each page (up to `--max-pages`): +//! 1. Load the page image from `/images/`. +//! 2. Build a GT draft from `layout_dets`: markdown/plain text for document +//! parsers, or `text<|LOC_x|><|LOC_y|>...` spotting streams for +//! PaddleOCR-VL spotting. +//! 3. Run baseline generation and time it. +//! 4. Run backend `generate_hsd(...)` with the draft, capture stats. +//! 5. Aggregate `SR_decode`, `SR_e2e`, AAL, fallback ratio across pages. +//! +//! ```bash +//! cargo run -p oar-ocr-vl --release --features hsd,download-binaries \ +//! --example hsd_omnidocbench -- \ +//! --bench-dir data/omnidocbench_v1.5 \ +//! 
--model-dir models/HunyuanOCR \ +//! --max-pages 20 +//! +//! cargo run -p oar-ocr-vl --release --features hsd,download-binaries \ +//! --example hsd_omnidocbench -- \ +//! --backend paddleocr_vl --task spotting \ +//! --bench-dir data/omnidocbench_v1.5 \ +//! --model-dir models/PaddleOCR-VL-1.5 \ +//! --max-pages 20 +//! ``` + +mod utils; + +#[cfg(not(feature = "hsd"))] +fn main() { + eprintln!("This example requires the `hsd` feature. Re-run with `--features hsd`."); + std::process::exit(1); +} + +#[cfg(feature = "hsd")] +fn main() -> Result<(), Box> { + imp::run() +} + +#[cfg(feature = "hsd")] +mod imp { + + use super::utils; + + use clap::{Parser, ValueEnum}; + use image::imageops::FilterType; + use serde::Deserialize; + use std::collections::{BTreeMap, HashMap}; + use std::path::PathBuf; + use std::time::{Duration, Instant}; + use tokenizers::Tokenizer; + + use oar_ocr::prelude::{OARStructure, OARStructureBuilder}; + use oar_ocr_core::core::{OCRError, config::OrtSessionConfig}; + use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, StructureResult}; + use oar_ocr_core::domain::tasks::FormulaRecognitionConfig; + use oar_ocr_core::predictors::TextRecognitionPredictor; + use oar_ocr_core::processors::{BoundingBox, Point}; + use oar_ocr_core::utils::{BBoxCrop, load_image}; + use oar_ocr_vl::hsd::drafting::{ + TargetDraftAdapter, bbox_xyxy, page_markdown_for, region_markdowns_for, + }; + use oar_ocr_vl::hsd::types::{ + Draft, DsvConfig, HsdConfig, HsdStats, RegionKind, SpecDecodeStats, + }; + use oar_ocr_vl::utils::parse_device; + use oar_ocr_vl::{GlmOcr, HunyuanOcr, MinerU, PaddleOcrVl, PaddleOcrVlTask}; + + use utils::structure_match::{MatchThresholds, match_region}; + + const HUNYUAN_CHINESE_PARSING_PROMPT: &str = "提取文档图片中正文的所有信息用 markdown 格式表示,其中页眉、页脚部分忽略,表格用 html 格式表达,文档中公式用 latex 格式表示,按照阅读顺序组织进行解析。"; + const HUNYUAN_REGION_PROMPT: &str = "Extract all information from the document region image and represent it in markdown format. 
Tables should be expressed in HTML format, and formulas should be represented using LaTeX format."; + const HUNYUAN_CHINESE_REGION_PROMPT: &str = "提取文档区域图片中的所有信息用 markdown 格式表示,表格用 html 格式表达,公式用 latex 格式表示。"; + const GLMOCR_TEXT_RECOGNITION_PROMPT: &str = "Text Recognition:"; + const MINERU_TEXT_RECOGNITION_PROMPT: &str = "\nText Recognition:"; + + #[derive(Copy, Clone, Debug, ValueEnum)] + enum Backend { + #[value(name = "hunyuanocr")] + HunyuanOcr, + #[value(name = "paddleocr_vl")] + PaddleOcrVl, + #[value(name = "mineru")] + MinerU, + #[value(name = "glmocr")] + GlmOcr, + } + + impl Backend { + fn as_str(self) -> &'static str { + match self { + Self::HunyuanOcr => "hunyuanocr", + Self::PaddleOcrVl => "paddleocr_vl", + Self::MinerU => "mineru", + Self::GlmOcr => "glmocr", + } + } + } + + #[derive(Copy, Clone, Debug, ValueEnum)] + enum Task { + Ocr, + Table, + Chart, + Formula, + Spotting, + Seal, + } + + impl Task { + fn to_native(self) -> PaddleOcrVlTask { + match self { + Task::Ocr => PaddleOcrVlTask::Ocr, + Task::Table => PaddleOcrVlTask::Table, + Task::Chart => PaddleOcrVlTask::Chart, + Task::Formula => PaddleOcrVlTask::Formula, + Task::Spotting => PaddleOcrVlTask::Spotting, + Task::Seal => PaddleOcrVlTask::Seal, + } + } + } + + #[derive(Copy, Clone, Debug, ValueEnum)] + enum Mode { + /// Full-page generation / verification. + Page, + /// PaddleOCR-VL region-level verification using OmniDocBench layout crops. + Region, + } + + enum BackendModel { + HunyuanOcr(HunyuanOcr), + PaddleOcrVl(PaddleOcrVl), + MinerU(MinerU), + GlmOcr(GlmOcr), + } + + struct RegionDrafts { + joined: String, + per_element: Vec>, + per_element_tokens: Vec>>, + } + + #[derive(Parser)] + #[command(name = "hsd_omnidocbench")] + struct Args { + /// Root of the unzipped OmniDocBench dataset (must contain + /// `OmniDocBench.json` and `images/`). + #[arg(long)] + bench_dir: PathBuf, + /// Backend weights directory. + #[arg(long)] + model_dir: PathBuf, + /// Target backend to benchmark. 
+ #[arg(long, value_enum, default_value_t = Backend::HunyuanOcr)] + backend: Backend, + /// PaddleOCR-VL task. Ignored by HunyuanOCR. + #[arg(long, value_enum, default_value_t = Task::Spotting)] + task: Task, + /// Benchmark mode. `region` currently supports HunyuanOCR and PaddleOCR-VL. + #[arg(long, value_enum, default_value_t = Mode::Page)] + mode: Mode, + #[arg(long, default_value = "cuda:0")] + device: String, + /// Device for ONNX-based PP-OCRv5 / PP-StructureV3 drafter models. + /// Defaults to `--device`. + #[arg(long)] + drafter_device: Option, + #[arg(long, default_value_t = 4096)] + max_tokens: usize, + /// Number of pages to evaluate. Default 5 for a quick smoke run. + #[arg(long, default_value_t = 5)] + max_pages: usize, + /// Index of the first entry to evaluate (skip this many before counting). + #[arg(long, default_value_t = 0)] + start_idx: usize, + /// Optional substring filter — only run entries whose `image_path` contains this. + #[arg(long)] + filter: Option, + /// Restrict to a specific OmniDocBench subset (e.g. `v1.5`, `equation_hard`). + #[arg(long)] + subset: Option, + /// Restrict to a specific page language (e.g. `english`, `chinese`). + #[arg(long)] + language: Option, + /// Use the matching official Chinese-language Parsing prompt for pages + /// whose language is `chinese`. (English pages still use --instruction.) + #[arg(long, default_value_t = true)] + auto_prompt_lang: bool, + /// Skip pages whose image fails to load (vs. aborting). + #[arg(long, default_value_t = true)] + skip_missing: bool, + /// HunyuanOCR instruction prompt. Defaults to HunyuanOCR's official + /// "Parsing" task prompt, which elicits markdown output matching the + /// OmniDocBench GT format. (Source: HunyuanOCR README under Quick Start + /// → Tasks.) + #[arg( + long, + default_value = "Extract all information from the main body of the document image and represent it in markdown format, ignoring headers and footers. 
Tables should be expressed in HTML format, formulas in the document should be represented using LaTeX format, and the parsing should be organized according to the reading order." + )] + instruction: String, + /// HunyuanOCR region-level instruction used for Stage 1 crop verification. + #[arg(long, default_value = HUNYUAN_REGION_PROMPT)] + hunyuanocr_region_instruction: String, + /// GLM-OCR instruction prompt. Defaults to the model-native OCR expert + /// prompt documented in the local README. + #[arg(long, default_value = GLMOCR_TEXT_RECOGNITION_PROMPT)] + glmocr_instruction: String, + /// GLM-OCR region-level instruction used for Stage 1 crop verification. + #[arg(long, default_value = GLMOCR_TEXT_RECOGNITION_PROMPT)] + glmocr_region_instruction: String, + /// MinerU2.5 instruction prompt. Defaults to the model-native content + /// extraction prompt used by the MinerU two-step example. + #[arg(long, default_value = MINERU_TEXT_RECOGNITION_PROMPT)] + mineru_instruction: String, + /// MinerU2.5 region-level instruction used for Stage 1 crop verification. + /// Pass an empty string to opt into two-step mode where each layout + /// element gets its native MinerU prompt (`\nText Recognition:`, + /// `\nTable Recognition:`, `\nFormula Recognition:`, `\nImage Analysis:`), + /// matching MinerU's official `two_step_extract` flow. + #[arg(long, default_value = MINERU_TEXT_RECOGNITION_PROMPT)] + mineru_region_instruction: String, + /// Convenience flag for `--mineru-region-instruction ""` — forces MinerU + /// Stage 1 into per-element prompt dispatch (`two_step_extract`-style). + /// Overrides any explicit `--mineru-region-instruction` value. + #[arg(long, default_value_t = false)] + mineru_two_step: bool, + /// Path to a JSON file containing pre-postprocess raw drafts from a + /// *different* VLM. Activates `--draft-source cross-vlm-file`. Schema: + /// `{"source_backend": "", "pages": {"": [{"bbox": [...], "raw_text": "..."}]}}`. 
+ /// Use the source backend's `decode_tokens_raw` to populate `raw_text`. + /// The target adapter handles any per-target surface conversion + /// (HTML↔OTSL, formula wrapping, etc.) — no explicit source hint needed. + #[arg(long)] + cross_vlm_draft_file: Option, + /// IoU floor when matching cross-VLM regions onto layout elements. + /// Defaults to 0.5, matching the structure/formula IoU thresholds used + /// elsewhere in the bench. + #[arg(long, default_value_t = 0.5)] + cross_vlm_iou_threshold: f32, + #[arg(long, default_value_t = 0.75)] + tau: f32, + /// Override `DsvConfig::window_len` (paper §3.2 `n`). Default 0 = honour + /// the preset's value (3 for all presets). + /// + /// **Use with care.** Empirical 2026-05-13 result on HunyuanOCR page+gt: a + /// 3-page smoke with `--dsv-window-len 2` *regressed* SR_e2e from 0.45× + /// to 0.17× because the matcher then found many more candidates per step + /// but most were stale matches — leading to bigger trees that the + /// verifier had to forward through anyway, then reject. When the matrix + /// reports high `dsv empty tree calls / page`, the right answer is + /// usually NOT a smaller window — it's a better drafter (closer + /// byte-level alignment with the target VLM's natural output). + #[arg(long, default_value_t = 0)] + dsv_window_len: usize, + /// Override `DsvConfig::max_candidates_per_step`. Default 0 = honour the + /// preset's value (32 for default, 128 for omnibench). Lower this when + /// `dsv avg tree nodes / page` blows up — each verify_tree forward + /// processes the whole packed tree even if all paths get rejected, so + /// big trees on divergent drafts pay full compute for no acceptance. + #[arg(long, default_value_t = 0)] + dsv_max_candidates: usize, + /// Override `DsvConfig::max_suffix_len`. Default 0 = honour the preset's + /// value (256 for all presets). 
+ #[arg(long, default_value_t = 0)] + dsv_max_suffix_len: usize, + /// Print the first N chars of baseline + draft for each page (debug). + #[arg(long, default_value_t = 0)] + preview: usize, + /// Write baseline-vs-draft token alignment diagnostics to this file. + #[arg(long)] + token_diff_output: Option, + /// Number of leading tokens to include in token diagnostics. + #[arg(long, default_value_t = 200)] + token_diff_limit: usize, + /// N-gram/window length used for token-match diagnostics. + #[arg(long, default_value_t = 3)] + token_diff_window_len: usize, + /// Stop after writing token diagnostics, before running HSD. + #[arg(long, default_value_t = false)] + token_diff_only: bool, + /// In page mode, run a full Stage-1 + Stage-2 HSD path for backends that + /// support it when region elements are available from the drafter. Use + /// `--page-dual-stage=false` to keep the page-level-only ablation. + #[arg(long, default_value_t = true, action = clap::ArgAction::Set)] + page_dual_stage: bool, + /// Count pages with AAL at or below this threshold as low-AAL outliers. + #[arg(long, default_value_t = 0.5)] + outlier_aal_threshold: f32, + /// Count pages with SR_e2e below this threshold as speed-regression + /// outliers. + #[arg(long, default_value_t = 1.0)] + outlier_sr_e2e_threshold: f64, + /// Draft source: `gt` (build from OmniDocBench layout_dets), `baseline` + /// (re-use baseline output as an oracle draft), `ppocr-rec` (region-mode + /// PP-OCRv5 recognition model output), or `structure` (OARStructureBuilder + /// / PP-StructureV3-style markdown in page mode; IoU-matched + /// text/table/formula drafts in region mode). + #[arg(long, default_value = "gt")] + draft_source: String, + /// Preserve legacy HunyuanOCR GT draft markdown heuristics (`# title` and + /// `$$\nformula\n$$`). The default is tuned for token overlap with + /// HunyuanOCR baseline output. 
+ #[arg(long, default_value_t = false)] + hunyuanocr_legacy_gt_format: bool, + /// Apply lightweight PaddleOCR-VL region GT surface normalization. Only + /// affects `--backend paddleocr_vl --mode region --draft-source gt`. + #[arg(long, default_value_t = false)] + normalize_draft: bool, + /// Pre-resize each page so its longer side fits this many pixels. + /// HunyuanOCR's vit encoder fails on very large pages even though the + /// preprocessor's clamp logic claims to handle them; 1280 is a safe + /// default that keeps text readable. Set to 0 to disable. + #[arg(long, default_value_t = 1280)] + resize_max: u32, + /// PP-OCR recognition model for `--draft-source ppocr-rec`. Default + /// resolves to `pp-ocrv5_mobile_rec.onnx` in CWD; pass a full path to + /// override. + #[arg(long, default_value = "pp-ocrv5_mobile_rec.onnx")] + ppocr_rec_model: PathBuf, + /// PP-OCR recognition dictionary for `--draft-source ppocr-rec`. Default + /// resolves to `ppocrv5_dict.txt` in CWD; pass a full path to override. + #[arg(long, default_value = "ppocrv5_dict.txt")] + ppocr_dict_path: PathBuf, + /// PP-OCR recognition score threshold. + #[arg(long, default_value_t = 0.0)] + ppocr_score_thresh: f32, + /// PP-OCR recognition max text length. + #[arg(long, default_value_t = 200)] + ppocr_max_text_length: usize, + /// Layout model for `--draft-source structure`. Default resolves to + /// `pp-doclayout_plus-l.onnx` in CWD; pass a full path to override. + #[arg(long, default_value = "pp-doclayout_plus-l.onnx")] + structure_layout_model: PathBuf, + /// Layout model preset for `--draft-source structure`. + #[arg(long, default_value = "pp-doclayout_plus-l")] + structure_layout_model_name: String, + /// PP-DocBlockLayout model for structure reading order. **Required for + /// paper-equivalent multi-column AAL** — without it, reading order + /// degrades to bbox `(y, x)` sort and Stage-2 acceptance collapses on + /// multi-column pages. 
Default resolves to `pp-docblocklayout.onnx` in + /// the current working directory; pass a full path (e.g. + /// `/some/dir/pp-docblocklayout.onnx`) to use a model placed elsewhere, + /// or pass an empty string (`--structure-region-model ""`) to explicitly + /// disable. + #[arg(long, default_value = "pp-docblocklayout.onnx")] + structure_region_model: PathBuf, + /// PP-OCR detection model for structure OCR. Default resolves to + /// `pp-ocrv5_mobile_det.onnx` in CWD; pass a full path to override. + #[arg(long, default_value = "pp-ocrv5_mobile_det.onnx")] + structure_ocr_det_model: PathBuf, + /// PP-OCR recognition model for structure OCR. Default resolves to + /// `pp-ocrv5_mobile_rec.onnx` in CWD; pass a full path to override. + #[arg(long, default_value = "pp-ocrv5_mobile_rec.onnx")] + structure_ocr_rec_model: PathBuf, + /// PP-OCR dictionary for structure OCR. Default resolves to + /// `ppocrv5_dict.txt` in CWD; pass a full path to override. + #[arg(long, default_value = "ppocrv5_dict.txt")] + structure_ocr_dict_path: PathBuf, + /// Table classifier for structure table routing. + /// + /// Default resolves to `pp-lcnet_x1_0_table_cls.onnx` in CWD; pass a full + /// path to load from elsewhere. Pass an empty string to opt out — but + /// doing so forces every detected table region to a single structure + /// model (wired *or* wireless) and produces no draft for the unused + /// branch. Without table coverage the matrix Stage-1 region kind table + /// shows 0/N drafts for `table`, leaving acceptance to drop to 0% on + /// table-heavy pages. + #[arg(long, default_value = "pp-lcnet_x1_0_table_cls.onnx")] + structure_table_cls_model: PathBuf, + /// Wired table structure model (SLANeXt) for structure table HTML. + /// Default: `slanext_wired.onnx` in CWD; pass empty string to skip. + #[arg(long, default_value = "slanext_wired.onnx")] + structure_wired_table_model: PathBuf, + /// Wireless table structure model (SLANet+) for structure table HTML. 
+ /// Default: `slanet_plus.onnx` in CWD; pass empty string to skip. + #[arg(long, default_value = "slanet_plus.onnx")] + structure_wireless_table_model: PathBuf, + /// Table structure dictionary for structure table HTML. Default: + /// `table_structure_dict_ch.txt` (PaddleX-compatible bilingual dict) in + /// CWD; pass empty string to skip. + #[arg(long, default_value = "table_structure_dict_ch.txt")] + structure_table_dict_path: PathBuf, + /// Wired table cell detection model (RT-DETR-L). Default: + /// `rt-detr-l_wired_table_cell_det.onnx` in CWD; pass empty string to + /// skip cell detection (table structure still works without it, but + /// reduced fidelity on complex tables). + #[arg(long, default_value = "rt-detr-l_wired_table_cell_det.onnx")] + structure_wired_cell_model: PathBuf, + /// Wireless table cell detection model (RT-DETR-L). Default: + /// `rt-detr-l_wireless_table_cell_det.onnx` in CWD; pass empty string to + /// skip. + #[arg(long, default_value = "rt-detr-l_wireless_table_cell_det.onnx")] + structure_wireless_cell_model: PathBuf, + /// Formula model for structure LaTeX drafts. + /// + /// Default: `pp-formulanet_plus-l.onnx` in CWD (732MB) for accuracy on the + /// quality matrix — Plus-S has noticeably worse argmax behavior on + /// OmniDocBench academic pages (e.g. `\breve` vs `\check`, dropped + /// subscripts) and is recommended only for smoke / perf runs. Pass + /// `pp-formulanet_plus-s.onnx` (232MB) explicitly if disk/RAM is tight. + /// Empty string skips formula drafting and drops formula AAL to 0 on + /// academic pages. + #[arg(long, default_value = "pp-formulanet_plus-l.onnx")] + structure_formula_model: PathBuf, + /// Formula tokenizer for structure LaTeX drafts. Default: + /// `pp-formulanet-tokenizer.json` in CWD; pass empty string to skip + /// (must match the choice of `--structure-formula-model`). 
+ #[arg(long, default_value = "pp-formulanet-tokenizer.json")] + structure_formula_tokenizer: PathBuf, + /// Formula model type for structure LaTeX drafts. + #[arg(long, default_value = "pp_formulanet")] + structure_formula_type: String, + /// Device for structure formula recognition. Defaults to the drafter + /// device through the global structure ORT session; pass `cpu`, `cuda`, + /// or `cuda:N` to override just formula recognition. + #[arg(long)] + structure_formula_device: Option, + /// Preferred formula recognition batch size for structure drafts. + #[arg(long, default_value_t = 8)] + structure_formula_batch_size: usize, + /// Maximum decoded formula length for structure drafts. + #[arg(long, default_value_t = 1536)] + structure_formula_max_length: usize, + /// Strict IoU floor for cross-category structure → region matches. The + /// previous "max IoU wins regardless of type" policy is preserved at this + /// floor as a safety net for cases where the structure pipeline assigns + /// an unexpected type to the matching region. + #[arg(long, default_value_t = 0.8)] + structure_iou_threshold: f32, + /// Relaxed IoU floor for same-`semantic_category` structure → region + /// matches. Since the type pre-filter bounds poisoning risk, a lower + /// floor here can improve coverage on partially-overlapping regions where + /// the structure pipeline and OmniDocBench layout disagree on the exact + /// bbox, but the 2026-05-06 30-page run regressed AAL/SR at 0.5. The + /// conservative default therefore matches `--structure-iou-threshold`. + #[arg(long, default_value_t = 0.8)] + structure_same_category_iou: f32, + /// Allow table/formula/chart regions to fall back to generic layout OCR + /// text when specialized structure output is missing. + #[arg(long, default_value_t = false)] + structure_allow_generic_fallback: bool, + /// Batch size for structure region OCR. 
+ #[arg(long, default_value_t = 8)] + structure_region_batch_size: usize, + /// Where to write a per-page CSV log. Default: `/hsd_results.csv`. + #[arg(long)] + output_csv: Option, + /// Where to write an aggregate markdown summary. Default: `.md`. + #[arg(long)] + output_summary: Option, + } + + #[derive(Debug, Deserialize)] + struct OmniEntry { + layout_dets: Vec, + page_info: PageInfo, + #[serde(default)] + #[allow(dead_code)] + extra: serde_json::Value, + } + + #[derive(Debug, Deserialize)] + struct LayoutDet { + #[serde(default)] + category_type: String, + #[serde(default)] + ignore: bool, + /// Reading-order index. May be `null` for skipped/abandoned regions. + order: Option, + /// Recognised text. Often `""` for non-text regions (figure/table). + #[serde(default)] + text: String, + /// OmniDocBench quadrilateral `[x1,y1,x2,y2,x3,y3,x4,y4]`. + #[serde(default)] + poly: Vec, + } + + #[derive(Debug, Deserialize)] + struct PageInfo { + image_path: String, + #[allow(dead_code)] + page_no: Option, + #[allow(dead_code)] + height: Option, + #[allow(dead_code)] + width: Option, + #[serde(default)] + page_attribute: serde_json::Value, + } + + /// JSON format consumed by `--cross-vlm-draft-file`. Lets the bench exercise + /// the [`crate::hsd::drafting::convert_raw_to_target_adapter`] un-postprocess + /// path without loading a second VLM in-process. + /// + /// The producer is expected to be a separate run of another VLM (e.g. + /// PaddleOCR-VL) that called `decode_tokens_raw` per region and serialized + /// the pre-postprocess raw text. The bench then assigns those texts onto the + /// target backend's layout elements (matched by bbox IoU), and the target's + /// `TargetDraftAdapter` does the rest of the surface conversion (e.g. + /// HTML↔OTSL, `$$ ... $$` wrapping). 
+ /// + /// Minimal schema: + /// ```json + /// { + /// "source_backend": "paddleocr_vl", + /// "pages": { + /// "page-001.png": [ + /// {"bbox": [10.0, 20.0, 200.0, 50.0], "raw_text": "$$x = 1$$"} + /// ] + /// } + /// } + /// ``` + /// `source_backend` is informational only — the target adapter handles + /// per-element form so the source hint is not required for correctness. + #[derive(Debug, Clone, Deserialize)] + struct CrossVlmDraftFile { + #[allow(dead_code)] + #[serde(default)] + source_backend: Option, + pages: HashMap>, + } + + #[derive(Debug, Clone, Deserialize)] + struct CrossVlmRegion { + /// `[x_min, y_min, x_max, y_max]` in original image pixel coordinates. + bbox: [f32; 4], + /// Pre-postprocess decoded string from the source backend (use that + /// backend's `decode_tokens_raw`, not `decode_tokens`). + raw_text: String, + } + + impl CrossVlmDraftFile { + fn load(path: &std::path::Path) -> Result> { + let bytes = std::fs::read(path).map_err(|e| { + format!( + "failed to read --cross-vlm-draft-file {}: {e}", + path.display() + ) + })?; + let parsed: Self = serde_json::from_slice(&bytes).map_err(|e| { + format!( + "failed to parse --cross-vlm-draft-file {}: {e}", + path.display() + ) + })?; + Ok(parsed) + } + + /// Look up the per-page region list. Tries the full image path first, + /// then the basename (so callers can use either convention). + fn lookup_page(&self, image_path: &str) -> Option<&[CrossVlmRegion]> { + if let Some(regions) = self.pages.get(image_path) { + return Some(regions.as_slice()); + } + let basename = std::path::Path::new(image_path) + .file_name() + .and_then(|n| n.to_str())?; + self.pages.get(basename).map(Vec::as_slice) + } + } + + /// Axis-aligned IoU between two `[x_min, y_min, x_max, y_max]` rectangles. + /// Returns 0.0 when either box has zero area. 
+ fn axis_aligned_iou(a: &[f32; 4], b: &[f32; 4]) -> f32 { + let (ax0, ay0, ax1, ay1) = (a[0], a[1], a[2], a[3]); + let (bx0, by0, bx1, by1) = (b[0], b[1], b[2], b[3]); + let area_a = ((ax1 - ax0).max(0.0)) * ((ay1 - ay0).max(0.0)); + let area_b = ((bx1 - bx0).max(0.0)) * ((by1 - by0).max(0.0)); + if area_a <= 0.0 || area_b <= 0.0 { + return 0.0; + } + let ix0 = ax0.max(bx0); + let iy0 = ay0.max(by0); + let ix1 = ax1.min(bx1); + let iy1 = ay1.min(by1); + let iw = (ix1 - ix0).max(0.0); + let ih = (iy1 - iy0).max(0.0); + let inter = iw * ih; + let union = area_a + area_b - inter; + if union <= 0.0 { 0.0 } else { inter / union } + } + + /// Find the cross-VLM region whose bbox best matches `elem_bbox` (max IoU + /// above `iou_threshold`). Returns `None` if no candidate clears the bar. + fn match_cross_vlm_region<'a>( + elem_bbox: &[f32; 4], + candidates: &'a [CrossVlmRegion], + iou_threshold: f32, + ) -> Option<&'a CrossVlmRegion> { + let mut best: Option<(&CrossVlmRegion, f32)> = None; + for cand in candidates { + let iou = axis_aligned_iou(elem_bbox, &cand.bbox); + if iou < iou_threshold { + continue; + } + match best { + Some((_, best_iou)) if best_iou >= iou => {} + _ => best = Some((cand, iou)), + } + } + best.map(|(c, _)| c) + } + + #[derive(Clone, Copy)] + struct DraftFormat { + heading_prefix: bool, + wrap_formulas: bool, + formula_newlines: bool, + space_after_sec_dot: bool, + separator_after_page_number: bool, + } + + impl DraftFormat { + fn markdown() -> Self { + Self { + heading_prefix: true, + wrap_formulas: true, + formula_newlines: true, + space_after_sec_dot: false, + separator_after_page_number: false, + } + } + + fn hunyuanocr_aligned() -> Self { + Self { + heading_prefix: false, + wrap_formulas: false, + formula_newlines: false, + space_after_sec_dot: true, + separator_after_page_number: true, + } + } + } + + fn align_hunyuan_heading(text: &str, fmt: DraftFormat) -> String { + if !fmt.space_after_sec_dot { + return text.to_string(); + } + if let 
Some(rest) = text.strip_prefix("SEC.") + && rest.chars().next().is_some_and(|ch| ch.is_ascii_digit()) + { + return format!("SEC. {rest}"); + } + text.to_string() + } + + /// Heuristic markdown-style serialisation for one page's draft. + fn build_draft(entry: &OmniEntry, fmt: DraftFormat) -> String { + let mut dets: Vec<&LayoutDet> = entry + .layout_dets + .iter() + .filter(|d| d.order.is_some()) + .filter(|d| !matches!(d.category_type.as_str(), "abandon" | "text_mask")) + .collect(); + dets.sort_by_key(|d| d.order.unwrap()); + + let mut out = String::new(); + for d in dets { + let text = d.text.trim(); + if text.is_empty() { + continue; + } + let category = d.category_type.as_str(); + let formatted = match category { + "title" | "header" if fmt.heading_prefix => format!("# {text}"), + "title" | "header" => align_hunyuan_heading(text, fmt), + "equation_isolated" | "equation_semantic" => { + if !fmt.wrap_formulas || text.starts_with("$$") || text.starts_with("\\[") { + text.to_string() + } else if fmt.formula_newlines { + format!("$$\n{text}\n$$") + } else { + format!("$${text}$$") + } + } + _ => text.to_string(), + }; + if !out.is_empty() { + out.push_str("\n\n"); + } + out.push_str(&formatted); + if fmt.separator_after_page_number && category == "page_number" { + out.push_str("\n\n---"); + } + } + out + } + + fn build_plain_draft(entry: &OmniEntry) -> String { + let mut dets: Vec<&LayoutDet> = entry + .layout_dets + .iter() + .filter(|d| d.order.is_some()) + .filter(|d| !d.ignore) + .filter(|d| !is_mask_or_abandon(d.category_type.as_str())) + .collect(); + dets.sort_by_key(|d| d.order.unwrap()); + + dets.into_iter() + .map(|d| d.text.trim()) + .filter(|s| !s.is_empty()) + .collect::>() + .join("\n") + } + + fn build_spotting_draft(entry: &OmniEntry, image_width: u32, image_height: u32) -> String { + let mut dets: Vec<&LayoutDet> = entry + .layout_dets + .iter() + .filter(|d| d.order.is_some()) + .filter(|d| !d.ignore) + .filter(|d| 
!is_mask_or_abandon(d.category_type.as_str())) + .collect(); + dets.sort_by_key(|d| d.order.unwrap()); + + let mut out = String::new(); + for d in dets { + let text = d.text.trim(); + if text.is_empty() { + continue; + } + let Some(loc_tokens) = poly_loc_tokens(&d.poly, image_width, image_height) else { + continue; + }; + if !out.is_empty() { + out.push('\n'); + } + out.push_str(text); + out.push_str(&loc_tokens); + } + out + } + + fn build_gt_draft( + entry: &OmniEntry, + backend: Backend, + task: Task, + image_size: (u32, u32), + hunyuanocr_legacy_gt_format: bool, + ) -> String { + match backend { + Backend::HunyuanOcr => build_draft( + entry, + if hunyuanocr_legacy_gt_format { + DraftFormat::markdown() + } else { + DraftFormat::hunyuanocr_aligned() + }, + ), + Backend::MinerU | Backend::GlmOcr => build_plain_draft(entry), + Backend::PaddleOcrVl => match task { + Task::Spotting => build_spotting_draft(entry, image_size.0, image_size.1), + Task::Ocr | Task::Seal => build_plain_draft(entry), + Task::Table | Task::Chart | Task::Formula => { + build_draft(entry, DraftFormat::markdown()) + } + }, + } + } + + fn target_draft_adapter(backend: Backend, task: Task) -> TargetDraftAdapter { + // Pick the target VLM's natural draft surface so structure / cross-VLM + // drafts get auto-normalized via the adapter (HTML↔OTSL for tables, + // formula wrapper handling, heading shell, etc.). PaddleOCR-VL is + // element-only and uses the same `PaddleOcrVl` adapter regardless of + // task — the adapter dispatches on element kind, not task. 
/// Counts how many length-`window_len` windows of `reference` also occur
/// somewhere in `target`.
///
/// Returns `(hits, total)`, where `total` is the number of windows in
/// `reference`. Yields `(0, 0)` when `window_len` is zero or `reference` is
/// shorter than one window, and `(0, total)` when `target` is shorter than one
/// window. A window that repeats in `reference` is counted once per occurrence.
fn count_window_hits(reference: &[u32], target: &[u32], window_len: usize) -> (usize, usize) {
    if window_len == 0 || reference.len() < window_len {
        return (0, 0);
    }
    let total = reference.len() - window_len + 1;
    if target.len() < window_len {
        return (0, total);
    }
    // Index the target windows once instead of rescanning `target` for every
    // reference window.
    let seen: std::collections::HashSet<&[u32]> = target.windows(window_len).collect();
    let hits = reference
        .windows(window_len)
        .filter(|w| seen.contains(w))
        .count();
    (hits, total)
}

/// Runs [`count_window_hits`] against each draft separately and returns the
/// `(hits, total)` pair with the most hits.
///
/// With no drafts at all the result is zero hits against the same `total`
/// that `count_window_hits` reports, so a reference shorter than one window
/// (or a zero `window_len`) yields `(0, 0)` rather than a phantom window.
fn best_per_draft_window_hits(
    reference: &[u32],
    drafts: &[Vec<u32>],
    window_len: usize,
) -> (usize, usize) {
    drafts
        .iter()
        .map(|draft| count_window_hits(reference, draft, window_len))
        .max_by_key(|(hits, _)| *hits)
        .unwrap_or_else(|| count_window_hits(reference, &[], window_len))
}
+ fn append_token_diff_report( + out: &mut String, + tokenizer: &Tokenizer, + input: TokenDiffInput<'_>, + ) { + let common = input + .baseline_tokens + .iter() + .zip(input.draft_tokens.iter()) + .take_while(|(a, b)| a == b) + .count(); + let (hits, total) = + count_window_hits(input.baseline_tokens, input.draft_tokens, input.window_len); + let hit_rate = if total > 0 { + hits as f64 / total as f64 + } else { + 0.0 + }; + out.push_str(&format!( + "# HSD Token Diff\n\n\ + | field | value |\n\ + |---|---|\n\ + | run row idx | {} |\n\ + | candidate idx | {} |\n\ + | image | {} |\n\ + | backend | {} |\n\ + | mode | {:?} |\n\ + | draft source | {} |\n\ + | baseline chars | {} |\n\ + | draft chars | {} |\n\ + | baseline tokens | {} |\n\ + | draft tokens | {} |\n\ + | common token prefix | {} |\n\ + | HSD page draft count | {} |\n\ + | diagnostic region draft count | {} |\n\ + | baseline {}-gram hits in concatenated/page draft | {}/{} ({:.3}) |\n", + input.run_row_idx, + input.candidate_idx, + input.image_path, + input.backend.as_str(), + input.mode, + input.draft_source, + input.baseline_text.chars().count(), + input.draft_text.chars().count(), + input.baseline_tokens.len(), + input.draft_tokens.len(), + common, + input.hsd_page_draft_count, + input.region_draft_count, + input.window_len, + hits, + total, + hit_rate + )); + if let Some((per_hits, per_total)) = input.per_draft_max_hits { + let per_rate = if per_total > 0 { + per_hits as f64 / per_total as f64 + } else { + 0.0 + }; + out.push_str(&format!( + "| best single-region {}-gram hits | {}/{} ({:.3}) |\n", + input.window_len, per_hits, per_total, per_rate + )); + } + out.push('\n'); + + if let Some(elements) = input.structure_elements { + out.push_str("## Structure Element Order\n\n"); + out.push_str("| idx | type | bbox | text preview |\n"); + out.push_str("|---:|---|---|---|\n"); + for (idx, elem) in elements.iter().take(40).enumerate() { + let text = elem + .text + .as_deref() + .map(|s| preview_text(s, 120)) + 
.unwrap_or_default(); + out.push_str(&format!( + "| {} | {:?} | [{:.0},{:.0},{:.0},{:.0}] | `{}` |\n", + idx, + elem.element_type, + elem.bbox.x_min(), + elem.bbox.y_min(), + elem.bbox.x_max(), + elem.bbox.y_max(), + text.replace('`', "\\`") + )); + } + out.push('\n'); + } + + out.push_str("## First Differing Tokens\n\n"); + out.push_str("| idx | baseline id | baseline piece | draft id | draft piece |\n"); + out.push_str("|---:|---:|---|---:|---|\n"); + let n = input + .limit + .min(input.baseline_tokens.len().max(input.draft_tokens.len())); + for i in common.saturating_sub(5)..n { + let b = input.baseline_tokens.get(i).copied(); + let d = input.draft_tokens.get(i).copied(); + let bp = b + .map(|t| token_piece(tokenizer, t)) + .unwrap_or_else(|| "".to_string()); + let dp = d + .map(|t| token_piece(tokenizer, t)) + .unwrap_or_else(|| "".to_string()); + out.push_str(&format!( + "| {} | {} | `{}` | {} | `{}` |\n", + i, + b.map(|t| t.to_string()).unwrap_or_default(), + bp.replace('`', "\\`"), + d.map(|t| t.to_string()).unwrap_or_default(), + dp.replace('`', "\\`") + )); + } + + let bp: String = input.baseline_text.chars().take(1000).collect(); + let dp: String = input.draft_text.chars().take(1000).collect(); + out.push_str("\n## Baseline Text Preview\n\n```text\n"); + out.push_str(&bp); + out.push_str("\n```\n\n## Draft Text Preview\n\n```text\n"); + out.push_str(&dp); + out.push_str("\n```\n"); + } + + fn page_attr<'a>(entry: &'a OmniEntry, key: &str) -> &'a str { + entry + .page_info + .page_attribute + .as_object() + .and_then(|m| m.get(key)) + .and_then(|v| v.as_str()) + .unwrap_or("") + } + + fn prompt_for_entry<'a>(entry: &OmniEntry, args: &'a Args) -> (&'a str, &'static str) { + if args.auto_prompt_lang && page_attr(entry, "language").eq_ignore_ascii_case("chinese") { + (HUNYUAN_CHINESE_PARSING_PROMPT, "hunyuanocr_parsing_zh") + } else { + (args.instruction.as_str(), "hunyuanocr_parsing_en") + } + } + + fn hunyuanocr_region_prompt<'a>(entry: &OmniEntry, args: 
/// Escapes one CSV field.
///
/// A value containing a comma, double quote, line feed or carriage return is
/// wrapped in double quotes with embedded quotes doubled; any other value is
/// returned unchanged.
fn csv_escape(value: impl AsRef<str>) -> String {
    let value = value.as_ref();
    if value.contains([',', '"', '\n', '\r']) {
        format!("\"{}\"", value.replace('"', "\"\""))
    } else {
        value.to_string()
    }
}

/// Joins `fields` into one CSV record: each field escaped with
/// [`csv_escape`], separated by commas, terminated by a single `\n`.
fn csv_row(fields: &[String]) -> String {
    let mut row = fields
        .iter()
        .map(csv_escape)
        .collect::<Vec<_>>()
        .join(",");
    row.push('\n');
    row
}
LayoutElementType::Table => "table", + LayoutElementType::Formula | LayoutElementType::FormulaNumber => "formula", + LayoutElementType::Image + | LayoutElementType::Chart + | LayoutElementType::Seal + | LayoutElementType::HeaderImage + | LayoutElementType::FooterImage => "visual", + LayoutElementType::DocTitle + | LayoutElementType::ParagraphTitle + | LayoutElementType::FigureTitle + | LayoutElementType::TableTitle + | LayoutElementType::ChartTitle + | LayoutElementType::FigureTableChartTitle => "title", + LayoutElementType::Header | LayoutElementType::Footer | LayoutElementType::Number => { + "page_artifact" + } + LayoutElementType::List => "list", + LayoutElementType::Text + | LayoutElementType::Content + | LayoutElementType::Abstract + | LayoutElementType::AsideText + | LayoutElementType::Reference + | LayoutElementType::ReferenceContent + | LayoutElementType::Footnote => "text", + _ => "other", + } + } + + fn region_kind_buckets(elements: &[LayoutElement]) -> String { + let mut counts: BTreeMap<&'static str, (usize, usize)> = BTreeMap::new(); + for elem in elements { + let entry = counts + .entry(layout_kind_bucket(elem.element_type)) + .or_insert((0, 0)); + entry.0 += 1; + if elem + .text + .as_deref() + .is_some_and(|text| !text.trim().is_empty()) + { + entry.1 += 1; + } + } + counts + .into_iter() + .map(|(kind, (total, drafted))| format!("{kind}:{drafted}/{total}")) + .collect::>() + .join(";") + } + + fn region_kind_name(kind: RegionKind) -> &'static str { + match kind { + RegionKind::Text => "text", + RegionKind::Title => "title", + RegionKind::List => "list", + RegionKind::Table => "table", + RegionKind::Formula => "formula", + RegionKind::Figure => "visual", + RegionKind::Header | RegionKind::Footer => "page_artifact", + RegionKind::Other => "other", + } + } + + fn stage1_region_kind_stats(stats: &HsdStats) -> String { + let mut by_kind: BTreeMap<&'static str, (u32, u32, u32, u32)> = BTreeMap::new(); + for region in &stats.stage1_regions { + let entry = 
/// Serialises the first four `(x, y)` corners of `poly` as location tokens of
/// the form `<|LOC_x|><|LOC_y|>`, concatenated in corner order.
///
/// `poly` is read as a flat `[x0, y0, x1, y1, ...]` list; values past the
/// eighth are ignored. Each coordinate is mapped to a 0..=1000 index relative
/// to the image extent on its axis (see [`loc_index`]).
///
/// Returns `None` when `poly` holds fewer than eight values or when either
/// image dimension is zero.
fn poly_loc_tokens(poly: &[f32], image_width: u32, image_height: u32) -> Option<String> {
    if poly.len() < 8 || image_width == 0 || image_height == 0 {
        return None;
    }
    let mut out = String::new();
    for corner in poly[..8].chunks_exact(2) {
        let x = loc_index(corner[0], image_width);
        let y = loc_index(corner[1], image_height);
        out.push_str(&format!("<|LOC_{x}|><|LOC_{y}|>"));
    }
    Some(out)
}

/// Maps `coord` to `round(coord * 1000 / extent)`, clamped to `0..=1000`.
fn loc_index(coord: f32, extent: u32) -> i32 {
    ((coord * 1000.0 / extent as f32).round() as i32).clamp(0, 1000)
}
/// Squeezes every run of consecutive ASCII spaces in `input` down to a single
/// space. All other characters, including tabs and newlines, are copied
/// through untouched.
fn collapse_horizontal_space(input: &str) -> String {
    let mut collapsed = String::with_capacity(input.len());
    let mut previous: Option<char> = None;
    for ch in input.chars() {
        // Drop a space only when the character right before it was also a space.
        let repeated_space = ch == ' ' && previous == Some(' ');
        if !repeated_space {
            collapsed.push(ch);
        }
        previous = Some(ch);
    }
    collapsed
}
should_remove_space_around_operator(chars: &[char], i: usize) -> bool { + let prev = previous_non_space(chars, i); + let next = next_non_space(chars, i + 1); + match (prev, next) { + (Some(a), Some(b)) => { + is_math_operator(a) || is_math_operator(b) || (is_numericish(a) && is_numericish(b)) + } + _ => false, + } + } + + fn previous_non_space(chars: &[char], mut i: usize) -> Option { + while i > 0 { + i -= 1; + if chars[i] != ' ' { + return Some(chars[i]); + } + } + None + } + + fn next_non_space(chars: &[char], mut i: usize) -> Option { + while i < chars.len() { + if chars[i] != ' ' { + return Some(chars[i]); + } + i += 1; + } + None + } + + fn is_math_operator(ch: char) -> bool { + matches!( + ch, + '<' | '>' | '=' | '+' | '-' | '±' | '≤' | '≥' | '×' | '÷' | '/' | '*' + ) + } + + fn is_numericish(ch: char) -> bool { + ch.is_ascii_digit() || matches!(ch, '.' | ',' | '%' | '<' | '>' | '=' | '+' | '-') + } + + fn normalize_dash_between_alnums(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + let chars: Vec = input.chars().collect(); + for (i, ch) in chars.iter().copied().enumerate() { + if ch == '-' + && i > 0 + && i + 1 < chars.len() + && chars[i - 1].is_ascii_alphanumeric() + && chars[i + 1].is_ascii_alphanumeric() + { + out.push(' '); + } else { + out.push(ch); + } + } + out + } + + fn normalize_latex_inline_wrappers(input: &str) -> String { + let mut s = input.trim().to_string(); + loop { + let trimmed = s.trim(); + let unwrapped = if trimmed.starts_with("$") + && trimmed.ends_with("$") + && trimmed.len() > 2 + { + Some(trimmed[1..trimmed.len() - 1].trim()) + } else if trimmed.starts_with("\\(") && trimmed.ends_with("\\)") && trimmed.len() > 4 { + Some(trimmed[2..trimmed.len() - 2].trim()) + } else { + None + }; + match unwrapped { + Some(inner) => s = inner.to_string(), + None => break, + } + } + s + } + + fn valid_scaled_poly(poly: &[f32], x_scale: f32, y_scale: f32) -> bool { + if poly.len() < 8 { + return false; + } + let xs = [ 
+ poly[0] * x_scale, + poly[2] * x_scale, + poly[4] * x_scale, + poly[6] * x_scale, + ]; + let ys = [ + poly[1] * y_scale, + poly[3] * y_scale, + poly[5] * y_scale, + poly[7] * y_scale, + ]; + let w = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max) + - xs.iter().copied().fold(f32::INFINITY, f32::min); + let h = ys.iter().copied().fold(f32::NEG_INFINITY, f32::max) + - ys.iter().copied().fold(f32::INFINITY, f32::min); + if w < 2.0 || h < 2.0 { + return false; + } + let ratio = (w / h).max(h / w); + ratio <= 200.0 + } + + fn layout_type_from_omni_category(category: &str) -> LayoutElementType { + match category { + "title" => LayoutElementType::DocTitle, + "header" => LayoutElementType::Header, + "footer" => LayoutElementType::Footer, + "page_number" => LayoutElementType::Number, + "text_block" => LayoutElementType::Text, + "list_group" => LayoutElementType::List, + "table" => LayoutElementType::Table, + "table_caption" => LayoutElementType::TableTitle, + "table_footnote" => LayoutElementType::Footnote, + "figure" => LayoutElementType::Image, + "figure_caption" | "figure_footnote" => LayoutElementType::FigureTitle, + "chart" => LayoutElementType::Chart, + "equation_isolated" | "equation_semantic" => LayoutElementType::Formula, + "equation_caption" | "equation_explanation" => LayoutElementType::Text, + "reference" => LayoutElementType::Reference, + "code_txt" | "code_txt_caption" => LayoutElementType::Text, + _ => LayoutElementType::Other, + } + } + + fn paddleocr_vl_task_for_layout_type(t: LayoutElementType) -> Option { + match t { + LayoutElementType::Table => Some(PaddleOcrVlTask::Table), + LayoutElementType::Chart => Some(PaddleOcrVlTask::Chart), + LayoutElementType::Formula => Some(PaddleOcrVlTask::Formula), + LayoutElementType::Image + | LayoutElementType::HeaderImage + | LayoutElementType::FooterImage + | LayoutElementType::Seal => None, + _ => Some(PaddleOcrVlTask::Ocr), + } + } + + fn run_paddleocr_vl_region_baseline( + model: &PaddleOcrVl, + image: 
&image::RgbImage, + elements: &[LayoutElement], + max_tokens: usize, + ) -> Result> { + let mut outputs = Vec::new(); + let mut per_element = vec![None; elements.len()]; + let mut per_element_tokens = vec![None; elements.len()]; + for (idx, elem) in elements.iter().enumerate() { + let Some(task) = paddleocr_vl_task_for_layout_type(elem.element_type) else { + continue; + }; + let crop = BBoxCrop::crop_bounding_box(image, &elem.bbox)?; + let tokens = model + .generate_tokens(&[crop], &[task], max_tokens) + .into_iter() + .next() + .ok_or("PaddleOCR-VL region baseline returned no result")??; + let (_, result) = model.decode_tokens(&tokens, task)?; + if !result.trim().is_empty() { + let trimmed = result.trim().to_string(); + per_element[idx] = Some(trimmed.clone()); + per_element_tokens[idx] = Some(tokens); + outputs.push(trimmed); + } + } + Ok(RegionDrafts { + joined: outputs.join("\n\n"), + per_element, + per_element_tokens, + }) + } + + fn run_ppocr_rec_drafter( + predictor: &TextRecognitionPredictor, + image: &image::RgbImage, + elements: &[LayoutElement], + ) -> Result> { + let mut crops = Vec::new(); + let mut indices = Vec::new(); + for (idx, elem) in elements.iter().enumerate() { + if elem + .text + .as_ref() + .map(|s| s.trim().is_empty()) + .unwrap_or(true) + { + continue; + } + let crop = BBoxCrop::crop_rotated_bounding_box(image, &elem.bbox) + .or_else(|_| BBoxCrop::crop_bounding_box(image, &elem.bbox))?; + crops.push(crop); + indices.push(idx); + } + + let mut per_element = vec![None; elements.len()]; + if crops.is_empty() { + return Ok(RegionDrafts { + joined: String::new(), + per_element, + per_element_tokens: vec![None; elements.len()], + }); + } + + let output = predictor.predict(crops)?; + let mut joined = Vec::new(); + for ((idx, text), score) in indices.into_iter().zip(output.texts).zip(output.scores) { + let text = text.trim().to_string(); + if text.is_empty() || score <= 0.0 { + continue; + } + per_element[idx] = Some(text.clone()); + 
joined.push(text); + } + Ok(RegionDrafts { + joined: joined.join("\n\n"), + per_element, + per_element_tokens: vec![None; elements.len()], + }) + } + + fn build_structure_drafter(args: &Args) -> Result> { + let mut builder = OARStructureBuilder::new(&args.structure_layout_model) + .layout_model_name(&args.structure_layout_model_name) + .with_ocr( + &args.structure_ocr_det_model, + &args.structure_ocr_rec_model, + &args.structure_ocr_dict_path, + ) + .region_batch_size(args.structure_region_batch_size); + + if let Some(ort_cfg) = parse_ort_device(drafter_device(args))? { + builder = builder.ort_session(ort_cfg); + } + // PP-DocBlockLayout reading-order model. Required for paper-equivalent + // multi-column AAL. Empty path = explicit user opt-out; missing file = + // warn-and-fallback to bbox (y, x) sort. + let region_path = &args.structure_region_model; + if region_path.as_os_str().is_empty() { + eprintln!( + "[WARN] --structure-region-model is empty: reading-order falls back to bbox (y,x) sort. \ + Multi-column pages may miss Stage-2 acceptance; pass the PP-DocBlockLayout ONNX path to recover paper AAL." + ); + } else if !region_path.exists() { + eprintln!( + "[WARN] --structure-region-model '{}' not found on disk: reading-order falls back to bbox (y,x) sort. \ + Multi-column pages may miss Stage-2 acceptance.", + region_path.display() + ); + } else { + builder = builder + .with_region_detection(region_path) + .region_model_name("pp-docblocklayout"); + } + // Sub-model resolver: empty path = explicit user opt-out, non-empty + + // missing file = warn-and-skip. Returns None when the path should be + // dropped so the call-site can branch. + fn resolve_submodel<'a>(label: &str, path: &'a PathBuf) -> Option<&'a PathBuf> { + if path.as_os_str().is_empty() { + return None; + } + if !path.exists() { + eprintln!( + "[WARN] {label} '{}' not found on disk: structure drafter will skip this sub-model. 
\ + Affected regions produce 0 drafts and target VLM autoregresses through them.", + path.display() + ); + return None; + } + Some(path) + } + + if let Some(path) = resolve_submodel( + "--structure-table-cls-model", + &args.structure_table_cls_model, + ) { + builder = builder.with_table_classification(path); + } + // Table structure dict applies to both wired and wireless models. Resolve + // once so the wired/wireless branches can borrow it without re-checking. + let table_dict = resolve_submodel( + "--structure-table-dict-path", + &args.structure_table_dict_path, + ); + if let Some(path) = resolve_submodel( + "--structure-wired-table-model", + &args.structure_wired_table_model, + ) { + let Some(dict) = table_dict else { + return Err( + "--structure-wired-table-model requires --structure-table-dict-path (or pass empty string to skip wired tables)".into(), + ); + }; + builder = builder + .with_wired_table_structure(path) + .wired_table_structure_model_name("slanext_wired") + .table_structure_dict_path(dict); + } + if let Some(path) = resolve_submodel( + "--structure-wireless-table-model", + &args.structure_wireless_table_model, + ) { + let Some(dict) = table_dict else { + return Err( + "--structure-wireless-table-model requires --structure-table-dict-path (or pass empty string to skip wireless tables)".into(), + ); + }; + builder = builder + .with_wireless_table_structure(path) + .wireless_table_structure_model_name("slanet_plus") + .table_structure_dict_path(dict); + } + if let Some(path) = resolve_submodel( + "--structure-wired-cell-model", + &args.structure_wired_cell_model, + ) { + builder = builder + .with_wired_table_cell_detection(path) + .wired_table_cell_model_name("rtdetr-l_wired_table_cell_det"); + } + if let Some(path) = resolve_submodel( + "--structure-wireless-cell-model", + &args.structure_wireless_cell_model, + ) { + builder = builder + .with_wireless_table_cell_detection(path) + .wireless_table_cell_model_name("rtdetr-l_wireless_table_cell_det"); + 
} + // Formula recognition requires both the ONNX model and its tokenizer. + // Treat them as a single unit: both present → enable; either missing → skip. + let formula_model = + resolve_submodel("--structure-formula-model", &args.structure_formula_model); + let formula_tokenizer = resolve_submodel( + "--structure-formula-tokenizer", + &args.structure_formula_tokenizer, + ); + match (formula_model, formula_tokenizer) { + (Some(path), Some(tokenizer)) => { + builder = builder + .with_formula_recognition(path, tokenizer, &args.structure_formula_type) + .formula_recognition_config(FormulaRecognitionConfig { + score_threshold: 0.0, + max_length: args.structure_formula_max_length, + batch_size: args.structure_formula_batch_size, + }); + if let Some(formula_device) = &args.structure_formula_device { + builder = + builder.formula_ort_session(parse_required_ort_device(formula_device)?); + } + } + (Some(_), None) => { + return Err( + "--structure-formula-model requires --structure-formula-tokenizer (or pass both as empty strings to skip formula drafting)".into(), + ); + } + (None, Some(_)) => { + // Tokenizer without model is a no-op; emit a hint but don't fail. + eprintln!( + "[WARN] --structure-formula-tokenizer set but --structure-formula-model is empty/missing: \ + skipping formula drafting. Formula regions will produce 0 drafts." + ); + } + (None, None) => {} // both opt-out + } + + Ok(builder.build()?) + } + + fn run_structure_drafter( + structure: &OARStructure, + image: &image::RgbImage, + elements: &[LayoutElement], + th: MatchThresholds, + ) -> Result> { + let result = structure.predict_image(image.clone())?; + Ok(match_structure_to_regions(&result, elements, th)) + } + + fn run_structure_page_drafter( + structure: &OARStructure, + image: &image::RgbImage, + ) -> Result> { + Ok(structure.predict_image(image.clone())?) 
+ } + + // Thin wrapper kept for readability at call sites; the real implementation + // lives in `oar_ocr_vl::hsd::drafting::structure_result_to_layout_elements` + // so other consumers can share the same OAR-structure → HSD-element bridge. + fn structure_result_hsd_elements(result: &StructureResult) -> Vec { + oar_ocr_vl::hsd::drafting::structure_result_to_layout_elements(result) + } + + fn match_structure_to_regions( + result: &StructureResult, + elements: &[LayoutElement], + th: MatchThresholds, + ) -> RegionDrafts { + let mut per_element = vec![None; elements.len()]; + let mut joined = Vec::new(); + + for (idx, elem) in elements.iter().enumerate() { + let draft = match_region(result, elem, th) + .map(|m| m.text.trim().to_string()) + .filter(|s| !s.is_empty()); + if let Some(text) = draft { + per_element[idx] = Some(text.clone()); + joined.push(text); + } + } + + RegionDrafts { + joined: joined.join("\n\n"), + per_element, + per_element_tokens: vec![None; elements.len()], + } + } + + fn parse_ort_device( + device: &str, + ) -> Result, Box> { + let device_lower = device.to_lowercase(); + if device_lower == "cpu" { + return Ok(None); + } + + #[cfg(feature = "cuda")] + { + use oar_ocr_core::core::config::OrtExecutionProvider; + if device_lower.starts_with("cuda") { + let device_id = if device_lower == "cuda" { + 0 + } else if let Some(id_str) = device_lower.strip_prefix("cuda:") { + id_str.parse::()? 
+ } else { + return Err(format!("invalid device: {device}").into()); + }; + return Ok(Some(OrtSessionConfig::new().with_execution_providers( + vec![ + OrtExecutionProvider::CUDA { + device_id: Some(device_id), + gpu_mem_limit: None, + arena_extend_strategy: None, + cudnn_conv_algo_search: None, + cudnn_conv_use_max_workspace: None, + }, + OrtExecutionProvider::CPU, + ], + ))); + } + } + + #[cfg(not(feature = "cuda"))] + { + if device_lower.starts_with("cuda") { + return Err("CUDA requested for PP-OCR but cuda feature is not enabled".into()); + } + } + + Err(format!("unsupported device for PP-OCR drafter: {device}").into()) + } + + fn parse_required_ort_device( + device: &str, + ) -> Result> { + let device_lower = device.to_lowercase(); + if device_lower == "cpu" { + use oar_ocr_core::core::config::OrtExecutionProvider; + return Ok( + OrtSessionConfig::new().with_execution_providers(vec![OrtExecutionProvider::CPU]) + ); + } + + parse_ort_device(device)? + .ok_or_else(|| format!("unsupported explicit ONNX Runtime device: {device}").into()) + } + + fn drafter_device(args: &Args) -> &str { + args.drafter_device + .as_deref() + .unwrap_or(args.device.as_str()) + } + + pub fn run() -> Result<(), Box> { + utils::init_tracing(); + let args = Args::parse(); + + // Parse the bench JSON. + let json_path = args.bench_dir.join("OmniDocBench.json"); + println!("Reading {}", json_path.display()); + let bytes = std::fs::read(&json_path)?; + let entries: Vec = serde_json::from_slice(&bytes)?; + println!("Loaded {} entries", entries.len()); + + // Load backend. + let device = parse_device(&args.device)?; + println!( + "Loading {} from {}", + args.backend.as_str(), + args.model_dir.display() + ); + let model = match args.backend { + Backend::HunyuanOcr => { + BackendModel::HunyuanOcr(HunyuanOcr::from_dir(&args.model_dir, device)?) + } + Backend::PaddleOcrVl => { + BackendModel::PaddleOcrVl(PaddleOcrVl::from_dir(&args.model_dir, device)?) 
+ } + Backend::MinerU => BackendModel::MinerU(MinerU::from_dir(&args.model_dir, device)?), + Backend::GlmOcr => BackendModel::GlmOcr(GlmOcr::from_dir(&args.model_dir, device)?), + }; + let paddleocr_vl_task = args.task.to_native(); + let ppocr_rec = if args.draft_source == "ppocr-rec" { + if !matches!(args.mode, Mode::Region) || !matches!(args.backend, Backend::PaddleOcrVl) { + return Err( + "--draft-source ppocr-rec requires --backend paddleocr_vl --mode region".into(), + ); + } + println!( + "Loading PP-OCR rec drafter from {}", + args.ppocr_rec_model.display() + ); + let mut builder = TextRecognitionPredictor::builder() + .score_threshold(args.ppocr_score_thresh) + .max_text_length(args.ppocr_max_text_length) + .dict_path(&args.ppocr_dict_path); + if let Some(ort_cfg) = parse_ort_device(drafter_device(&args))? { + builder = builder.with_ort_config(ort_cfg); + } + Some(builder.build(&args.ppocr_rec_model)?) + } else { + None + }; + let structure_drafter = if args.draft_source == "structure" { + if matches!(args.mode, Mode::Region) + && !matches!(args.backend, Backend::HunyuanOcr | Backend::PaddleOcrVl) + { + return Err( + "--draft-source structure requires --backend hunyuanocr or paddleocr_vl with --mode region".into(), + ); + } + println!( + "Loading structure drafter with layout model {}", + args.structure_layout_model.display() + ); + Some(build_structure_drafter(&args)?) 
+ } else { + None + }; + + let cross_vlm_drafts = if args.draft_source == "cross-vlm-file" { + let path = args + .cross_vlm_draft_file + .as_ref() + .ok_or("--draft-source cross-vlm-file requires --cross-vlm-draft-file ")?; + println!("Loading cross-VLM drafts from {}", path.display()); + let parsed = CrossVlmDraftFile::load(path)?; + if let Some(name) = parsed.source_backend.as_deref() { + println!(" source_backend = {name}, pages = {}", parsed.pages.len()); + } else { + println!(" pages = {}", parsed.pages.len()); + } + Some(parsed) + } else { + if args.cross_vlm_draft_file.is_some() { + eprintln!( + "[warn] --cross-vlm-draft-file is set but --draft-source is not cross-vlm-file; ignoring." + ); + } + None + }; + + // Per-page log. + let csv_path = args + .output_csv + .clone() + .unwrap_or_else(|| args.bench_dir.join("hsd_results.csv")); + let summary_path = args.output_summary.clone().unwrap_or_else(|| { + let mut p = csv_path.clone().into_os_string(); + p.push(".md"); + PathBuf::from(p) + }); + let mut csv = String::from( + "page_idx,image,subset,language,backend,mode,task,device,drafter_device,draft_source,page_dual_stage,hsd_entry,regions,draft_regions,draft_coverage,region_kind_buckets,stage1_region_kind_stats,draft_3gram_hit_rate,tau,max_tokens,resize_max,start_idx,prompt_kind,prompt,region_prompt_kind,region_prompt,baseline_ms,hsd_ms,drafter_ms,decode_ms,prefill_ms,stage1_decode_ms,stage1_prefill_ms,stage1_verify_steps,stage1_fallback_steps,stage1_aal,stage2_decode_ms,stage2_prefill_ms,stage2_verify_steps,stage2_fallback_steps,stage2_aal,dsv_candidate_ms,dsv_verify_ms,dsv_traverse_ms,dsv_commit_ms,dsv_step_one_ms,dsv_fallback_argmax_ms,dsv_verify_calls,dsv_step_one_calls,dsv_fallback_argmax_calls,dsv_avg_candidates,dsv_max_candidates,dsv_empty_tree_calls,dsv_rejected_tree_calls,dsv_accepted_tree_calls,dsv_avg_tree_nodes,dsv_max_tree_nodes,emitted_tokens,verify_steps,fallback_steps,aal,sr_decode,sr_e2e\n", + ); + + let mut dsv = DsvConfig { + tau: args.tau, 
+ ..Default::default() + }; + // Per-knob CLI overrides. Each defaults to 0 = "honour the preset"; any + // non-zero value wins so a user can do e.g. `--dsv-window-len 2` on top of + // `--config-preset omnibench` to relax matching on divergent drafts + // without having to copy the rest of the preset. + if args.dsv_window_len > 0 { + dsv.window_len = args.dsv_window_len; + } + if args.dsv_max_candidates > 0 { + dsv.max_candidates_per_step = args.dsv_max_candidates; + } + if args.dsv_max_suffix_len > 0 { + dsv.max_suffix_len = args.dsv_max_suffix_len; + } + let cfg = HsdConfig { + dsv, + enable_stage1: matches!(args.mode, Mode::Region) + || (matches!(args.mode, Mode::Page) && args.page_dual_stage), + enable_stage2: true, + max_page_tokens: args.max_tokens, + max_region_tokens: args.max_tokens, + }; + if matches!(args.mode, Mode::Region) + && !matches!(args.backend, Backend::HunyuanOcr | Backend::PaddleOcrVl) + { + return Err( + "--mode region currently supports --backend hunyuanocr or paddleocr_vl".into(), + ); + } + + // Build the candidate pool with optional substring + subset filter. 
+ let candidates: Vec<&OmniEntry> = entries + .iter() + .filter(|e| match &args.filter { + Some(s) => e.page_info.image_path.contains(s), + None => true, + }) + .filter(|e| match &args.subset { + Some(want) => { + let got = e + .page_info + .page_attribute + .as_object() + .and_then(|m| m.get("subset")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + got == want + } + None => true, + }) + .filter(|e| match &args.language { + Some(want) => { + let got = e + .page_info + .page_attribute + .as_object() + .and_then(|m| m.get("language")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + got == want + } + None => true, + }) + .skip(args.start_idx) + .collect(); + let n_pages = args.max_pages.min(candidates.len()); + println!( + "Running HSD on {} of {} candidates (backend = {}, mode = {:?}, task = {:?}, draft = {}, filter = {:?}, start = {}, τ = {}, max_tokens = {})", + n_pages, + candidates.len(), + args.backend.as_str(), + args.mode, + args.task, + args.draft_source, + args.filter, + args.start_idx, + args.tau, + args.max_tokens + ); + + let mut sum_baseline_ms: f64 = 0.0; + let mut sum_hsd_ms: f64 = 0.0; + let mut sum_decode_ms: f64 = 0.0; + let mut sum_prefill_ms: f64 = 0.0; + let mut sum_drafter_ms: f64 = 0.0; + let mut sum_aal: f64 = 0.0; + let mut sum_sr_decode: f64 = 0.0; + let mut sum_sr_e2e: f64 = 0.0; + let mut sum_emitted: u64 = 0; + let mut sum_steps: u64 = 0; + let mut sum_fallbacks: u64 = 0; + let mut sum_dsv = SpecDecodeStats::default(); + let mut low_aal_outliers = 0u32; + let mut sr_e2e_outliers = 0u32; + let mut worst_aal = f32::INFINITY; + let mut worst_sr_e2e = f64::INFINITY; + let mut worst_aal_page = String::new(); + let mut worst_sr_e2e_page = String::new(); + let mut hsd_entry_counts: BTreeMap<&'static str, u32> = BTreeMap::new(); + let mut counted = 0u32; + let mut skipped = 0u32; + + println!( + "\n{:>4} {:>9} {:>9} {:>5} {:>5} {:>5} | {:<60}", + "idx", "base_ms", "hsd_ms", "AAL", "ndec", "fb", "page" + ); + println!("{}", "-".repeat(110)); + + 
for (i, entry) in candidates.iter().take(n_pages).enumerate() { + let candidate_idx = args.start_idx + i; + let img_path = args + .bench_dir + .join("images") + .join(&entry.page_info.image_path); + if !img_path.exists() { + if args.skip_missing { + skipped += 1; + continue; + } else { + return Err(format!("missing image: {}", img_path.display()).into()); + } + } + let mut image = match load_image(&img_path) { + Ok(img) => img, + Err(e) => { + if args.skip_missing { + eprintln!("[skip] {} ({e})", entry.page_info.image_path); + skipped += 1; + continue; + } else { + return Err(e.into()); + } + } + }; + let gt_image_size = image.dimensions(); + let mut x_scale = 1.0f32; + let mut y_scale = 1.0f32; + if args.resize_max > 0 { + let long = image.width().max(image.height()); + if long > args.resize_max { + let scale = args.resize_max as f32 / long as f32; + let nw = (image.width() as f32 * scale).round() as u32; + let nh = (image.height() as f32 * scale).round() as u32; + x_scale = nw as f32 / image.width() as f32; + y_scale = nh as f32 / image.height() as f32; + // Match upstream Python's pre-resize filter (PIL.Image.LANCZOS) + // — using CatmullRom here was the dominant source of pixel + // drift vs the Python pipeline (per-patch cos 0.998 → + // 1.000 after this change), which compounded into ~18 % + // divergence at the prefill's last-token logits. 
+ image = image::imageops::resize(&image, nw, nh, FilterType::Lanczos3); + } + } + let need_layout_elements = matches!(args.mode, Mode::Region) + || (matches!(args.mode, Mode::Page) + && args.page_dual_stage + && args.draft_source == "gt"); + let elements = if need_layout_elements { + let normalize_text = args.normalize_draft && args.draft_source == "gt"; + let require_text = matches!(args.draft_source.as_str(), "gt" | "ppocr-rec"); + build_layout_elements(entry, x_scale, y_scale, normalize_text, require_text) + } else { + Vec::new() + }; + let draft = match args.mode { + Mode::Page => build_gt_draft( + entry, + args.backend, + args.task, + gt_image_size, + args.hunyuanocr_legacy_gt_format, + ), + Mode::Region => elements + .iter() + .filter_map(|e| e.text.as_ref()) + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect::>() + .join("\n\n"), + }; + if draft.trim().is_empty() { + let draft_filled_later = matches!( + args.draft_source.as_str(), + "baseline" | "structure" | "cross-vlm-file" + ) && (matches!(args.mode, Mode::Page) + || matches!(args.mode, Mode::Region)); + if !draft_filled_later { + // No usable draft text — skip rather than running with empty draft + // (which would just be a baseline run with extra overhead). 
+ skipped += 1; + continue; + } + } + let (hunyuanocr_prompt, hunyuanocr_prompt_kind) = prompt_for_entry(entry, &args); + let (hunyuanocr_region_prompt, hunyuanocr_region_prompt_kind) = + hunyuanocr_region_prompt(entry, &args); + let (mineru_prompt, mineru_prompt_kind) = mineru_prompt(&args); + let (mineru_region_prompt, mineru_region_prompt_kind) = mineru_region_prompt(&args); + let (glmocr_prompt, glmocr_prompt_kind) = glmocr_prompt(&args); + let (glmocr_region_prompt, glmocr_region_prompt_kind) = glmocr_region_prompt(&args); + let (prompt_text, prompt_kind) = match args.backend { + Backend::HunyuanOcr => (hunyuanocr_prompt, hunyuanocr_prompt_kind), + Backend::PaddleOcrVl => (paddleocr_vl_task.prompt(), "paddleocr_vl_task"), + Backend::MinerU => (mineru_prompt, mineru_prompt_kind), + Backend::GlmOcr => (glmocr_prompt, glmocr_prompt_kind), + }; + let (region_prompt_text, region_prompt_kind) = match args.backend { + Backend::HunyuanOcr => (hunyuanocr_region_prompt, hunyuanocr_region_prompt_kind), + Backend::PaddleOcrVl => (paddleocr_vl_task.prompt(), "paddleocr_vl_region_task"), + Backend::MinerU => (mineru_region_prompt, mineru_region_prompt_kind), + Backend::GlmOcr => (glmocr_region_prompt, glmocr_region_prompt_kind), + }; + + // Baseline. 
+ let t0 = Instant::now(); + let mut region_baseline_drafts: Option>> = None; + let mut region_baseline_token_drafts: Option>>> = None; + let mut baseline_tokens: Option> = None; + let baseline_result: Result> = + match (&model, args.mode) { + (BackendModel::HunyuanOcr(model), Mode::Page) => { + let toks_result = model + .generate_tokens( + &[image.clone()], + &[hunyuanocr_prompt], + args.max_tokens, + ) + .into_iter() + .next() + .ok_or("baseline returned no results")?; + let toks = toks_result?; + baseline_tokens = Some(toks.clone()); + model.decode_tokens(&toks).map_err(|e| e.into()) + } + (BackendModel::PaddleOcrVl(model), Mode::Page) => { + let toks_result = model + .generate_tokens( + &[image.clone()], + &[paddleocr_vl_task], + args.max_tokens, + ) + .into_iter() + .next() + .ok_or("baseline returned no results")?; + let toks = toks_result?; + baseline_tokens = Some(toks.clone()); + model + .decode_tokens(&toks, paddleocr_vl_task) + .map(|(_, processed)| processed) + .map_err(|e| e.into()) + } + (BackendModel::MinerU(model), Mode::Page) => { + let toks_result = model + .generate_tokens(&[image.clone()], &[mineru_prompt], args.max_tokens) + .into_iter() + .next() + .ok_or("baseline returned no results")?; + let toks = toks_result?; + baseline_tokens = Some(toks.clone()); + model.decode_tokens(&toks).map_err(|e| e.into()) + } + (BackendModel::GlmOcr(model), Mode::Page) => { + let toks_result = model + .generate_tokens(&[image.clone()], &[glmocr_prompt], args.max_tokens) + .into_iter() + .next() + .ok_or("baseline returned no results")?; + let toks = toks_result?; + baseline_tokens = Some(toks.clone()); + model.decode_tokens(&toks).map_err(|e| e.into()) + } + (BackendModel::PaddleOcrVl(model), Mode::Region) => { + let drafts = run_paddleocr_vl_region_baseline( + model, + &image, + &elements, + args.max_tokens, + )?; + region_baseline_drafts = Some(drafts.per_element); + region_baseline_token_drafts = Some(drafts.per_element_tokens); + Ok(drafts.joined) + } + 
(BackendModel::HunyuanOcr(model), Mode::Region) => { + let toks_result = model + .generate_tokens( + &[image.clone()], + &[hunyuanocr_prompt], + args.max_tokens, + ) + .into_iter() + .next() + .ok_or("baseline returned no results")?; + let toks = toks_result?; + baseline_tokens = Some(toks.clone()); + model.decode_tokens(&toks).map_err(|e| e.into()) + } + (_, Mode::Region) => Err( + "--mode region currently supports --backend hunyuanocr or paddleocr_vl" + .into(), + ), + }; + let baseline_dur = t0.elapsed(); + let baseline_text = match baseline_result { + Ok(s) => s, + Err(e) => { + eprintln!( + "[skip] baseline failed for {}: {}", + entry.page_info.image_path, e + ); + skipped += 1; + continue; + } + }; + if args.preview > 0 { + let bp: String = baseline_text.chars().take(args.preview).collect(); + let dp: String = draft.chars().take(args.preview).collect(); + println!( + "--- BASELINE ({} chars total) ---\n{bp}\n", + baseline_text.len() + ); + println!("--- DRAFT ({} chars total) ---\n{dp}\n", draft.len()); + } + + // Pick draft per --draft-source. 
+ let mut hsd_elements: Option> = None; + let mut external_drafter_dur = Duration::ZERO; + let mut diagnostic_region_drafts: Vec = Vec::new(); + let actual_draft = match args.draft_source.as_str() { + "gt" => { + if matches!(args.mode, Mode::Page) && args.page_dual_stage { + hsd_elements = Some(elements.clone()); + } + draft.clone() + } + "baseline" if matches!(args.mode, Mode::Region) => { + let per_element = region_baseline_drafts + .as_ref() + .ok_or("missing region baseline drafts")?; + let mut oracle_elements = elements.clone(); + for (elem, baseline) in oracle_elements.iter_mut().zip(per_element.iter()) { + if let Some(text) = baseline { + elem.text = Some(text.clone()); + } + } + hsd_elements = Some(oracle_elements); + baseline_text.clone() + } + "ppocr-rec" if matches!(args.mode, Mode::Region) => { + let predictor = ppocr_rec + .as_ref() + .ok_or("missing PP-OCR recognition drafter")?; + let t_drafter = Instant::now(); + let ppocr_drafts = run_ppocr_rec_drafter(predictor, &image, &elements)?; + external_drafter_dur += t_drafter.elapsed(); + let mut drafter_elements = elements.clone(); + for (elem, draft) in drafter_elements + .iter_mut() + .zip(ppocr_drafts.per_element.iter()) + { + if let Some(text) = draft { + elem.text = Some(text.clone()); + } + } + hsd_elements = Some(drafter_elements); + ppocr_drafts.joined + } + "structure" if matches!(args.mode, Mode::Region) => { + let structure = structure_drafter + .as_ref() + .ok_or("missing structure drafter")?; + let t_drafter = Instant::now(); + let structure_drafts = run_structure_drafter( + structure, + &image, + &elements, + MatchThresholds::new( + args.structure_same_category_iou, + args.structure_iou_threshold, + args.structure_allow_generic_fallback, + ), + )?; + external_drafter_dur += t_drafter.elapsed(); + let mut drafter_elements = elements.clone(); + for (elem, draft) in drafter_elements + .iter_mut() + .zip(structure_drafts.per_element.iter()) + { + elem.text = draft.clone(); + } + hsd_elements = 
Some(drafter_elements); + structure_drafts.joined + } + "structure" if matches!(args.mode, Mode::Page) => { + let structure = structure_drafter + .as_ref() + .ok_or("missing structure drafter")?; + let t_drafter = Instant::now(); + let result = run_structure_page_drafter(structure, &image)?; + external_drafter_dur += t_drafter.elapsed(); + let elems = structure_result_hsd_elements(&result); + let adapter = target_draft_adapter(args.backend, args.task); + diagnostic_region_drafts = region_markdowns_for(&elems, &[], adapter); + hsd_elements = Some(elems); + page_markdown_for(hsd_elements.as_deref().unwrap_or(&[]), &[], adapter) + } + "baseline" => baseline_text.clone(), + "cross-vlm-file" => { + // Per-page raw drafts from another VLM. The target backend's + // adapter handles surface conversion (HTML↔OTSL, formula + // wrapping); we just match elements by bbox IoU and stash the + // raw text on `elem.text` so downstream `generate_hsd_full` + // picks it up via `region_markdown_for`. + let cross = cross_vlm_drafts + .as_ref() + .ok_or("missing cross-VLM drafts (loaded earlier?)")?; + let regions = cross.lookup_page(&entry.page_info.image_path); + let mut drafter_elements = elements.clone(); + let mut matched = 0usize; + if let Some(regions) = regions { + for elem in drafter_elements.iter_mut() { + let elem_bbox = bbox_xyxy(&elem.bbox); + if let Some(region) = match_cross_vlm_region( + &elem_bbox, + regions, + args.cross_vlm_iou_threshold, + ) { + elem.text = Some(region.raw_text.clone()); + matched += 1; + } else { + // No matching cross-VLM region — drop the text so + // the element won't contribute a stale draft to + // Stage 1 / Stage 2. + elem.text = None; + } + } + } else { + // No page entry — clear all element texts so the bench + // reports `draft_regions=0` rather than silently reusing + // the OmniDocBench GT text on this page. 
+ for elem in drafter_elements.iter_mut() { + elem.text = None; + } + } + if regions.is_none() { + eprintln!( + "[warn] cross-vlm-file has no entry for {}", + entry.page_info.image_path + ); + } + let adapter = target_draft_adapter(args.backend, args.task); + diagnostic_region_drafts = + region_markdowns_for(&drafter_elements, &[], adapter); + let joined = page_markdown_for(&drafter_elements, &[], adapter); + hsd_elements = Some(drafter_elements); + println!( + " cross-vlm-file matched {matched}/{} regions for {}", + elements.len(), + entry.page_info.image_path, + ); + joined + } + other => return Err(format!("unknown --draft-source: {other}").into()), + }; + if actual_draft.trim().is_empty() { + skipped += 1; + continue; + } + + // Build the Stage-2 draft set `Ỹ^pg` per paper Eq. 3. Most draft + // sources are inherently single-document (gt = ground-truth page, + // baseline = the VLM's own page output), so they form a 1-element + // set. The structure+page route splits per layout element so the + // matcher can scan each region draft independently (Eqs. 1+2), + // preserving per-region n-gram locality even when the drafter's + // page-level format diverges from the target VLM's. + let actual_drafts: Vec = if matches!(args.mode, Mode::Page) + && args.draft_source == "structure" + && !diagnostic_region_drafts.is_empty() + { + diagnostic_region_drafts.clone() + } else { + vec![actual_draft.clone()] + }; + + let page_draft_tokens = if matches!(args.mode, Mode::Page) { + Some(match &model { + BackendModel::HunyuanOcr(model) => { + tokenize_draft(model.tokenizer(), &actual_draft)? + } + BackendModel::PaddleOcrVl(model) => { + tokenize_draft(model.tokenizer(), &actual_draft)? + } + BackendModel::MinerU(model) => { + tokenize_draft(model.tokenizer(), &actual_draft)? + } + BackendModel::GlmOcr(model) => { + tokenize_draft(model.tokenizer(), &actual_draft)? 
+ } + }) + } else { + None + }; + let draft_3gram_hit_rate = if let (Some(baseline_tokens), Some(draft_tokens)) = + (baseline_tokens.as_ref(), page_draft_tokens.as_ref()) + { + let (hits, total) = + count_window_hits(baseline_tokens, draft_tokens, args.token_diff_window_len); + if total > 0 { + hits as f64 / total as f64 + } else { + 0.0 + } + } else { + 0.0 + }; + let diagnostic_region_draft_tokens = + if matches!(args.mode, Mode::Page) && !diagnostic_region_drafts.is_empty() { + let tokenizer = match &model { + BackendModel::HunyuanOcr(model) => model.tokenizer(), + BackendModel::PaddleOcrVl(model) => model.tokenizer(), + BackendModel::MinerU(model) => model.tokenizer(), + BackendModel::GlmOcr(model) => model.tokenizer(), + }; + Some( + diagnostic_region_drafts + .iter() + .map(|d| tokenize_draft(tokenizer, d)) + .collect::, _>>()?, + ) + } else { + None + }; + let per_draft_max_hits = if let (Some(baseline_tokens), Some(region_drafts)) = ( + baseline_tokens.as_ref(), + diagnostic_region_draft_tokens.as_ref(), + ) { + Some(best_per_draft_window_hits( + baseline_tokens, + region_drafts, + args.token_diff_window_len, + )) + } else { + None + }; + + if let Some(path) = &args.token_diff_output { + if !matches!(args.mode, Mode::Page) { + return Err("--token-diff-output currently supports --mode page only".into()); + } + let baseline_tokens = baseline_tokens + .as_ref() + .ok_or("--token-diff-output requires page-mode baseline tokens")?; + let draft_tokens = page_draft_tokens + .as_ref() + .ok_or("--token-diff-output requires page-mode draft tokens")?; + let tokenizer = match &model { + BackendModel::HunyuanOcr(model) => model.tokenizer(), + BackendModel::PaddleOcrVl(model) => model.tokenizer(), + BackendModel::MinerU(model) => model.tokenizer(), + BackendModel::GlmOcr(model) => model.tokenizer(), + }; + let mut report = String::new(); + append_token_diff_report( + &mut report, + tokenizer, + TokenDiffInput { + run_row_idx: i, + candidate_idx, + image_path: 
&entry.page_info.image_path, + backend: args.backend, + mode: args.mode, + draft_source: &args.draft_source, + baseline_text: &baseline_text, + draft_text: &actual_draft, + baseline_tokens, + draft_tokens, + structure_elements: hsd_elements.as_deref(), + hsd_page_draft_count: actual_drafts.len(), + region_draft_count: diagnostic_region_drafts.len(), + per_draft_max_hits, + limit: args.token_diff_limit, + window_len: args.token_diff_window_len, + }, + ); + std::fs::write(path, report)?; + println!("Token diff report -> {}", path.display()); + if args.token_diff_only { + return Ok(()); + } + } + + // For page mode, `draft_region_count` is the size of `Ỹ^pg` (the + // Stage-2 draft set per paper Eq. 3) — `actual_drafts.len()` after + // the multi-draft refactor, NOT a 0/1 indicator of "is the draft + // non-empty?". For region mode it's the count of elements with + // non-empty draft text, unchanged. + let hsd_element_count = hsd_elements.as_ref().map_or(elements.len(), Vec::len); + let draft_region_count = if matches!(args.mode, Mode::Page) { + if let Some(elems) = hsd_elements.as_ref().filter(|elems| !elems.is_empty()) { + elems + .iter() + .filter(|e| e.text.as_deref().is_some_and(|s| !s.trim().is_empty())) + .count() + } else { + actual_drafts + .iter() + .filter(|d| !d.trim().is_empty()) + .count() + } + } else { + hsd_elements + .as_ref() + .unwrap_or(&elements) + .iter() + .filter(|e| e.text.as_deref().is_some_and(|s| !s.trim().is_empty())) + .count() + }; + let draft_coverage = if hsd_element_count == 0 { + if matches!(args.mode, Mode::Page) && draft_region_count > 0 { + 1.0 + } else { + 0.0 + } + } else { + draft_region_count as f64 / hsd_element_count as f64 + }; + let region_kind_buckets = hsd_elements + .as_deref() + .map(region_kind_buckets) + .unwrap_or_else(|| region_kind_buckets(&elements)); + + // HSD with the draft. 
+ let t1 = Instant::now(); + let oracle_draft = (args.draft_source == "baseline") + .then(|| { + baseline_tokens + .as_ref() + .map(|t| vec![Draft::new(t.clone())]) + }) + .flatten(); + let (hsd_entry, hsd) = match (&model, args.mode, oracle_draft.as_deref()) { + (BackendModel::HunyuanOcr(model), Mode::Page, Some(token_drafts)) => ( + "hunyuanocr.generate_hsd_with_token_drafts", + model.generate_hsd_with_token_drafts( + &image, + hunyuanocr_prompt, + token_drafts, + &cfg, + ), + ), + (BackendModel::HunyuanOcr(model), Mode::Page, None) => { + if args.page_dual_stage { + let elems = require_hsd_elements(hsd_elements.as_deref(), "hunyuanocr")?; + ( + "hunyuanocr.generate_hsd_full", + model.generate_hsd_full( + &image, + oar_ocr_vl::HunyuanHsdPrompts { + page: hunyuanocr_prompt, + region: hunyuanocr_region_prompt, + }, + elems, + &[], + |elem| elem.text.iter().cloned().collect(), + &cfg, + ), + ) + } else { + ( + "hunyuanocr.generate_hsd", + model.generate_hsd(&image, hunyuanocr_prompt, &actual_drafts, &cfg), + ) + } + } + (BackendModel::PaddleOcrVl(model), Mode::Page, Some(token_drafts)) => ( + "paddleocr_vl.generate_hsd_with_token_drafts", + model.generate_hsd_with_token_drafts( + &image, + paddleocr_vl_task, + token_drafts, + &cfg, + ), + ), + (BackendModel::PaddleOcrVl(model), Mode::Page, None) => ( + "paddleocr_vl.generate_hsd", + model.generate_hsd(&image, paddleocr_vl_task, &actual_drafts, &cfg), + ), + (BackendModel::MinerU(model), Mode::Page, Some(token_drafts)) => ( + "mineru.generate_hsd_with_token_drafts", + model.generate_hsd_with_token_drafts(&image, mineru_prompt, token_drafts, &cfg), + ), + (BackendModel::MinerU(model), Mode::Page, None) => { + if args.page_dual_stage { + let elems = require_hsd_elements(hsd_elements.as_deref(), "mineru")?; + ( + "mineru.generate_hsd_full", + model.generate_hsd_full( + &image, + elems, + &[], + mineru_prompt, + mineru_region_prompt, + &cfg, + ), + ) + } else { + ( + "mineru.generate_hsd", + model.generate_hsd(&image, 
mineru_prompt, &actual_drafts, &cfg), + ) + } + } + (BackendModel::GlmOcr(model), Mode::Page, Some(token_drafts)) => ( + "glmocr.generate_hsd_with_token_drafts", + model.generate_hsd_with_token_drafts(&image, glmocr_prompt, token_drafts, &cfg), + ), + (BackendModel::GlmOcr(model), Mode::Page, None) => { + if args.page_dual_stage { + let elems = require_hsd_elements(hsd_elements.as_deref(), "glmocr")?; + ( + "glmocr.generate_hsd_full", + model.generate_hsd_full( + &image, + elems, + &[], + glmocr_prompt, + glmocr_region_prompt, + &cfg, + ), + ) + } else { + ( + "glmocr.generate_hsd", + model.generate_hsd(&image, glmocr_prompt, &actual_drafts, &cfg), + ) + } + } + (BackendModel::PaddleOcrVl(model), Mode::Region, _) => { + let elems = hsd_elements.as_deref().unwrap_or(elements.as_slice()); + if args.draft_source == "baseline" { + let token_drafts = + region_baseline_token_drafts.as_ref().ok_or_else(|| { + oar_ocr_core::core::OCRError::InvalidInput { + message: "missing region baseline token drafts".to_string(), + } + })?; + ( + "paddleocr_vl.generate_hsd_full_with_token_drafts", + model.generate_hsd_full_with_token_drafts( + &image, + elems, + &[], + token_drafts, + &cfg, + ), + ) + } else { + ( + "paddleocr_vl.generate_hsd_full", + model.generate_hsd_full(&image, elems, &[], &cfg), + ) + } + } + (BackendModel::HunyuanOcr(model), Mode::Region, _) => { + let elems = hsd_elements.as_deref().unwrap_or(elements.as_slice()); + ( + "hunyuanocr.generate_hsd_full", + model.generate_hsd_full( + &image, + oar_ocr_vl::HunyuanHsdPrompts { + page: hunyuanocr_prompt, + region: hunyuanocr_region_prompt, + }, + elems, + &[], + |elem| elem.text.iter().cloned().collect(), + &cfg, + ), + ) + } + (_, Mode::Region, _) => { + unreachable!( + "--mode region is rejected for unsupported backends before the loop" + ) + } + }; + let hsd = match hsd { + Ok(v) => v, + Err(e) => { + eprintln!("[skip] HSD failed for {}: {e}", entry.page_info.image_path); + let mut cur: Option<&dyn 
std::error::Error> = std::error::Error::source(&e); + while let Some(s) = cur { + eprintln!(" caused by: {s}"); + cur = s.source(); + } + skipped += 1; + continue; + } + }; + let hsd_dur = t1.elapsed() + external_drafter_dur; + let (_text, mut stats) = hsd; + stats.drafter += external_drafter_dur; + *hsd_entry_counts.entry(hsd_entry).or_insert(0) += 1; + + let baseline_ms = baseline_dur.as_secs_f64() * 1000.0; + let hsd_ms = hsd_dur.as_secs_f64() * 1000.0; + let drafter_ms = stats.drafter.as_secs_f64() * 1000.0; + let stage1_decode_ms = stats.stage1.decode.as_secs_f64() * 1000.0; + let stage1_prefill_ms = stats.stage1.vision_prefill.as_secs_f64() * 1000.0; + let stage1_aal = stats.stage1.accept.aal(); + let stage2_decode_ms = stats.stage2.decode.as_secs_f64() * 1000.0; + let stage2_prefill_ms = stats.stage2.vision_prefill.as_secs_f64() * 1000.0; + let stage2_aal = stats.stage2.accept.aal(); + let stage1_region_kind_stats = stage1_region_kind_stats(&stats); + let stage = match args.mode { + Mode::Page => &stats.stage2, + Mode::Region => &stats.stage1, + }; + let decode_ms = stage.decode.as_secs_f64() * 1000.0; + let prefill_ms = stage.vision_prefill.as_secs_f64() * 1000.0; + let baseline_decode_estimate_ms = (baseline_ms - prefill_ms).max(0.0); + let sr_decode = if decode_ms > 0.0 { + baseline_decode_estimate_ms / decode_ms + } else { + 0.0 + }; + let sr_e2e = if hsd_ms > 0.0 { + baseline_ms / hsd_ms + } else { + 0.0 + }; + let aal = stage.accept.aal(); + + sum_baseline_ms += baseline_ms; + sum_hsd_ms += hsd_ms; + sum_decode_ms += decode_ms; + sum_prefill_ms += prefill_ms; + sum_drafter_ms += drafter_ms; + sum_aal += aal as f64; + sum_sr_decode += sr_decode; + sum_sr_e2e += sr_e2e; + sum_emitted += stage.emitted_tokens as u64; + sum_steps += stage.accept.num_steps as u64; + sum_fallbacks += stage.accept.num_fallbacks as u64; + sum_dsv.add_assign(&stage.dsv); + if aal <= args.outlier_aal_threshold { + low_aal_outliers += 1; + } + if sr_e2e < 
args.outlier_sr_e2e_threshold { + sr_e2e_outliers += 1; + } + if aal < worst_aal { + worst_aal = aal; + worst_aal_page = entry.page_info.image_path.clone(); + } + if sr_e2e < worst_sr_e2e { + worst_sr_e2e = sr_e2e; + worst_sr_e2e_page = entry.page_info.image_path.clone(); + } + counted += 1; + + let subset = page_attr(entry, "subset"); + let language = page_attr(entry, "language"); + let img_short: String = entry.page_info.image_path.chars().take(60).collect(); + println!( + "{:>4} {:>9.0} {:>9.0} {:>5.1} {:>5} {:>5} | {:<60}", + i, + baseline_ms, + hsd_ms, + aal, + stage.accept.num_steps, + stage.accept.num_fallbacks, + img_short + ); + + csv.push_str(&csv_row(&[ + i.to_string(), + entry.page_info.image_path.clone(), + subset.to_string(), + language.to_string(), + args.backend.as_str().to_string(), + format!("{:?}", args.mode).to_lowercase(), + format!("{:?}", args.task).to_lowercase(), + args.device.clone(), + drafter_device(&args).to_string(), + args.draft_source.clone(), + args.page_dual_stage.to_string(), + hsd_entry.to_string(), + hsd_element_count.to_string(), + draft_region_count.to_string(), + format!("{draft_coverage:.3}"), + region_kind_buckets, + stage1_region_kind_stats, + format!("{draft_3gram_hit_rate:.3}"), + format!("{:.3}", args.tau), + args.max_tokens.to_string(), + args.resize_max.to_string(), + args.start_idx.to_string(), + prompt_kind.to_string(), + prompt_text.to_string(), + region_prompt_kind.to_string(), + region_prompt_text.to_string(), + format!("{baseline_ms:.1}"), + format!("{hsd_ms:.1}"), + format!("{drafter_ms:.1}"), + format!("{decode_ms:.1}"), + format!("{prefill_ms:.1}"), + format!("{stage1_decode_ms:.1}"), + format!("{stage1_prefill_ms:.1}"), + stats.stage1.accept.num_steps.to_string(), + stats.stage1.accept.num_fallbacks.to_string(), + format!("{stage1_aal:.2}"), + format!("{stage2_decode_ms:.1}"), + format!("{stage2_prefill_ms:.1}"), + stats.stage2.accept.num_steps.to_string(), + stats.stage2.accept.num_fallbacks.to_string(), + 
format!("{stage2_aal:.2}"), + format_duration_ms(stage.dsv.candidate_build), + format_duration_ms(stage.dsv.verify_tree), + format_duration_ms(stage.dsv.traverse), + format_duration_ms(stage.dsv.commit), + format_duration_ms(stage.dsv.step_one), + format_duration_ms(stage.dsv.fallback_argmax), + stage.dsv.verify_tree_calls.to_string(), + stage.dsv.step_one_calls.to_string(), + stage.dsv.fallback_argmax_calls.to_string(), + format!("{:.1}", stage.dsv.avg_candidates()), + stage.dsv.candidates_max.to_string(), + stage.dsv.empty_tree_calls.to_string(), + stage.dsv.rejected_tree_calls.to_string(), + stage.dsv.accepted_tree_calls.to_string(), + format!("{:.1}", stage.dsv.avg_tree_nodes()), + stage.dsv.tree_nodes_max.to_string(), + stage.emitted_tokens.to_string(), + stage.accept.num_steps.to_string(), + stage.accept.num_fallbacks.to_string(), + format!("{aal:.2}"), + format!("{sr_decode:.3}"), + format!("{sr_e2e:.3}"), + ])); + std::fs::write(&csv_path, &csv)?; + } + + if counted == 0 { + eprintln!("No pages produced valid measurements (skipped: {skipped})."); + return Err("nothing measured".into()); + } + let n = counted as f64; + + println!("{}", "-".repeat(110)); + println!( + "\n=== AGGREGATE ({} pages, {} skipped) ===", + counted, skipped + ); + println!("baseline e2e (mean): {:.1} ms", sum_baseline_ms / n); + println!("HSD e2e (mean): {:.1} ms", sum_hsd_ms / n); + println!("drafter (mean): {:.1} ms", sum_drafter_ms / n); + println!("HSD decode (mean): {:.1} ms", sum_decode_ms / n); + println!("HSD prefill (mean): {:.1} ms", sum_prefill_ms / n); + println!("emitted tokens (mean): {}", sum_emitted / counted as u64); + println!("verify steps (mean): {}", sum_steps / counted as u64); + println!("fallback steps (mean): {}", sum_fallbacks / counted as u64); + println!("AAL (mean): {:.2}", sum_aal / n); + println!( + "low-AAL outliers: {} (AAL <= {:.2})", + low_aal_outliers, args.outlier_aal_threshold + ); + println!( + "SR_e2e outliers: {} (SR_e2e < {:.2})", + 
sr_e2e_outliers, args.outlier_sr_e2e_threshold + ); + println!( + "DSV candidate (mean): {:.1} ms", + sum_dsv.candidate_build.as_secs_f64() * 1000.0 / n + ); + println!( + "DSV verify_tree mean: {:.1} ms (calls/page {:.1}, avg nodes {:.1}, max nodes {})", + sum_dsv.verify_tree.as_secs_f64() * 1000.0 / n, + sum_dsv.verify_tree_calls as f64 / n, + sum_dsv.avg_tree_nodes(), + sum_dsv.tree_nodes_max + ); + println!( + "DSV candidates: avg {:.1}, max {}, empty/reject/accept {} / {} / {}", + sum_dsv.avg_candidates(), + sum_dsv.candidates_max, + sum_dsv.empty_tree_calls, + sum_dsv.rejected_tree_calls, + sum_dsv.accepted_tree_calls + ); + println!( + "DSV traverse/commit: {:.1} / {:.1} ms", + sum_dsv.traverse.as_secs_f64() * 1000.0 / n, + sum_dsv.commit.as_secs_f64() * 1000.0 / n + ); + println!( + "DSV step_one mean: {:.1} ms (calls/page {:.1})", + sum_dsv.step_one.as_secs_f64() * 1000.0 / n, + sum_dsv.step_one_calls as f64 / n + ); + println!(); + println!("SR_decode (mean): {:.2}×", sum_sr_decode / n); + println!("SR_e2e (mean): {:.2}×", sum_sr_e2e / n); + // Throughput-style aggregate (sum baseline / sum HSD). 
+ let sr_e2e_total = sum_baseline_ms / sum_hsd_ms; + let sr_decode_total = (sum_baseline_ms - sum_prefill_ms).max(0.0) / sum_decode_ms; + println!("SR_e2e (total time): {:.2}×", sr_e2e_total); + println!("SR_decode (total): {:.2}×", sr_decode_total); + + let fallback_rate = if sum_steps > 0 { + sum_fallbacks as f64 / sum_steps as f64 + } else { + 0.0 + }; + let hsd_entries = hsd_entry_counts + .iter() + .map(|(entry, count)| format!("{entry}={count}")) + .collect::>() + .join(", "); + let summary = format!( + "# HSD OmniDocBench Summary\n\n\ + ## Run\n\n\ + | field | value |\n\ + |---|---|\n\ + | backend | {backend} |\n\ + | mode | {mode:?} |\n\ + | task | {task:?} |\n\ + | device | {device} |\n\ + | drafter device | {drafter_device} |\n\ + | draft source | {draft_source} |\n\ + | page dual stage | {page_dual_stage} |\n\ + | HSD entries | {hsd_entries} |\n\ + | region prompt | {region_prompt_kind} |\n\ + | tau | {tau:.3} |\n\ + | max tokens | {max_tokens} |\n\ + | resize max | {resize_max} |\n\ + | start idx | {start_idx} |\n\ + | max pages | {max_pages} |\n\ + | subset filter | {subset_filter} |\n\ + | language filter | {language_filter} |\n\ + | outlier AAL threshold | {outlier_aal_threshold:.2} |\n\ + | outlier SR_e2e threshold | {outlier_sr_e2e_threshold:.2} |\n\ + | CSV | {csv_path} |\n\n\ + ## Aggregate\n\n\ + | metric | value |\n\ + |---|---:|\n\ + | measured pages | {counted} |\n\ + | skipped pages | {skipped} |\n\ + | baseline e2e mean ms | {baseline_mean:.1} |\n\ + | HSD e2e mean ms | {hsd_mean:.1} |\n\ + | drafter mean ms | {drafter_mean:.1} |\n\ + | HSD decode mean ms | {decode_mean:.1} |\n\ + | HSD prefill mean ms | {prefill_mean:.1} |\n\ + | DSV candidate mean ms | {dsv_candidate_mean:.1} |\n\ + | DSV verify_tree mean ms | {dsv_verify_mean:.1} |\n\ + | DSV traverse mean ms | {dsv_traverse_mean:.1} |\n\ + | DSV commit mean ms | {dsv_commit_mean:.1} |\n\ + | DSV step_one mean ms | {dsv_step_one_mean:.1} |\n\ + | DSV verify calls/page | 
{dsv_verify_calls_mean:.1} |\n\ + | DSV step_one calls/page | {dsv_step_one_calls_mean:.1} |\n\ + | DSV avg candidates | {dsv_avg_candidates:.1} |\n\ + | DSV max candidates | {dsv_max_candidates} |\n\ + | DSV empty tree calls | {dsv_empty_tree_calls} |\n\ + | DSV rejected tree calls | {dsv_rejected_tree_calls} |\n\ + | DSV accepted tree calls | {dsv_accepted_tree_calls} |\n\ + | DSV avg tree nodes | {dsv_avg_tree_nodes:.1} |\n\ + | DSV max tree nodes | {dsv_max_tree_nodes} |\n\ + | emitted tokens mean | {emitted_mean} |\n\ + | verify steps mean | {steps_mean} |\n\ + | fallback steps mean | {fallback_mean} |\n\ + | fallback total | {sum_fallbacks} |\n\ + | fallback rate | {fallback_rate:.3} |\n\ + | AAL mean | {aal_mean:.2} |\n\ + | low-AAL outliers | {low_aal_outliers} |\n\ + | low-AAL outlier rate | {low_aal_outlier_rate:.3} |\n\ + | worst AAL | {worst_aal:.2} |\n\ + | worst AAL page | {worst_aal_page} |\n\ + | SR_e2e outliers | {sr_e2e_outliers} |\n\ + | SR_e2e outlier rate | {sr_e2e_outlier_rate:.3} |\n\ + | worst SR_e2e | {worst_sr_e2e:.2}x |\n\ + | worst SR_e2e page | {worst_sr_e2e_page} |\n\ + | SR_decode mean | {sr_decode_mean:.2}x |\n\ + | SR_e2e mean | {sr_e2e_mean:.2}x |\n\ + | SR_decode total | {sr_decode_total:.2}x |\n\ + | SR_e2e total time | {sr_e2e_total:.2}x |\n", + backend = args.backend.as_str(), + mode = args.mode, + task = args.task, + device = args.device, + drafter_device = drafter_device(&args), + draft_source = args.draft_source, + page_dual_stage = args.page_dual_stage, + hsd_entries = hsd_entries, + region_prompt_kind = match args.backend { + Backend::HunyuanOcr => "hunyuanocr_region", + Backend::PaddleOcrVl => "paddleocr_vl_region_task", + Backend::MinerU => "mineru_region_text_recognition", + Backend::GlmOcr => "glmocr_region_text_recognition", + }, + tau = args.tau, + max_tokens = args.max_tokens, + resize_max = args.resize_max, + start_idx = args.start_idx, + max_pages = args.max_pages, + subset_filter = 
args.subset.as_deref().unwrap_or(""), + language_filter = args.language.as_deref().unwrap_or(""), + outlier_aal_threshold = args.outlier_aal_threshold, + outlier_sr_e2e_threshold = args.outlier_sr_e2e_threshold, + csv_path = csv_path.display(), + baseline_mean = sum_baseline_ms / n, + hsd_mean = sum_hsd_ms / n, + drafter_mean = sum_drafter_ms / n, + decode_mean = sum_decode_ms / n, + prefill_mean = sum_prefill_ms / n, + dsv_candidate_mean = sum_dsv.candidate_build.as_secs_f64() * 1000.0 / n, + dsv_verify_mean = sum_dsv.verify_tree.as_secs_f64() * 1000.0 / n, + dsv_traverse_mean = sum_dsv.traverse.as_secs_f64() * 1000.0 / n, + dsv_commit_mean = sum_dsv.commit.as_secs_f64() * 1000.0 / n, + dsv_step_one_mean = sum_dsv.step_one.as_secs_f64() * 1000.0 / n, + dsv_verify_calls_mean = sum_dsv.verify_tree_calls as f64 / n, + dsv_step_one_calls_mean = sum_dsv.step_one_calls as f64 / n, + dsv_avg_candidates = sum_dsv.avg_candidates(), + dsv_max_candidates = sum_dsv.candidates_max, + dsv_empty_tree_calls = sum_dsv.empty_tree_calls, + dsv_rejected_tree_calls = sum_dsv.rejected_tree_calls, + dsv_accepted_tree_calls = sum_dsv.accepted_tree_calls, + dsv_avg_tree_nodes = sum_dsv.avg_tree_nodes(), + dsv_max_tree_nodes = sum_dsv.tree_nodes_max, + emitted_mean = sum_emitted / counted as u64, + steps_mean = sum_steps / counted as u64, + fallback_mean = sum_fallbacks / counted as u64, + aal_mean = sum_aal / n, + low_aal_outliers = low_aal_outliers, + low_aal_outlier_rate = low_aal_outliers as f64 / n, + worst_aal = worst_aal, + worst_aal_page = worst_aal_page, + sr_e2e_outliers = sr_e2e_outliers, + sr_e2e_outlier_rate = sr_e2e_outliers as f64 / n, + worst_sr_e2e = worst_sr_e2e, + worst_sr_e2e_page = worst_sr_e2e_page, + sr_decode_mean = sum_sr_decode / n, + sr_e2e_mean = sum_sr_e2e / n, + ); + + std::fs::write(&csv_path, csv)?; + std::fs::write(&summary_path, summary)?; + println!("\nPer-page CSV → {}", csv_path.display()); + println!("Markdown summary → {}", summary_path.display()); + 
Ok(()) + } + + fn format_duration_ms(duration: std::time::Duration) -> String { + format!("{:.3}", duration.as_secs_f64() * 1000.0) + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn axis_aligned_iou_self_is_one() { + let b = [0.0, 0.0, 10.0, 10.0]; + assert!((axis_aligned_iou(&b, &b) - 1.0).abs() < 1e-6); + } + + #[test] + fn axis_aligned_iou_disjoint_is_zero() { + let a = [0.0, 0.0, 5.0, 5.0]; + let b = [10.0, 10.0, 20.0, 20.0]; + assert_eq!(axis_aligned_iou(&a, &b), 0.0); + } + + #[test] + fn axis_aligned_iou_half_overlap() { + // a = 10x10 at origin; b = 10x10 shifted right by 5 → 5x10 overlap. + // intersection = 50, union = 100 + 100 - 50 = 150, IoU = 1/3. + let a = [0.0, 0.0, 10.0, 10.0]; + let b = [5.0, 0.0, 15.0, 10.0]; + let iou = axis_aligned_iou(&a, &b); + assert!((iou - (1.0 / 3.0)).abs() < 1e-6); + } + + #[test] + fn axis_aligned_iou_zero_area_returns_zero() { + let degenerate = [5.0, 5.0, 5.0, 10.0]; // zero width + let b = [0.0, 0.0, 10.0, 10.0]; + assert_eq!(axis_aligned_iou(&degenerate, &b), 0.0); + assert_eq!(axis_aligned_iou(&b, &degenerate), 0.0); + } + + fn region(bbox: [f32; 4], raw: &str) -> CrossVlmRegion { + CrossVlmRegion { + bbox, + raw_text: raw.to_string(), + } + } + + #[test] + fn match_cross_vlm_picks_best_iou_above_threshold() { + let elem = [0.0, 0.0, 10.0, 10.0]; + let regions = vec![ + region([100.0, 100.0, 200.0, 200.0], "far"), + region([0.0, 0.0, 9.0, 10.0], "best"), // IoU = 0.9 + region([2.0, 2.0, 12.0, 12.0], "ok"), // IoU ≈ 0.471 + ]; + let m = match_cross_vlm_region(&elem, &regions, 0.5).expect("match"); + assert_eq!(m.raw_text, "best"); + } + + #[test] + fn match_cross_vlm_returns_none_below_threshold() { + let elem = [0.0, 0.0, 10.0, 10.0]; + let regions = vec![region([100.0, 100.0, 200.0, 200.0], "far")]; + assert!(match_cross_vlm_region(&elem, &regions, 0.5).is_none()); + } + + #[test] + fn cross_vlm_draft_file_parses_minimal_json() { + let json = r#"{ + "source_backend": "paddleocr_vl", + "pages": { + 
"page-001.png": [ + {"bbox": [10.0, 20.0, 200.0, 50.0], "raw_text": "$$x = 1$$"} + ] + } + }"#; + let parsed: CrossVlmDraftFile = serde_json::from_str(json).expect("parse"); + assert_eq!(parsed.source_backend.as_deref(), Some("paddleocr_vl")); + let regions = parsed.lookup_page("page-001.png").expect("page"); + assert_eq!(regions.len(), 1); + assert_eq!(regions[0].raw_text, "$$x = 1$$"); + assert_eq!(regions[0].bbox, [10.0, 20.0, 200.0, 50.0]); + } + + #[test] + fn cross_vlm_draft_file_falls_back_to_basename() { + let json = r#"{"pages": {"page-001.png": [{"bbox": [0,0,1,1], "raw_text": "x"}]}}"#; + let parsed: CrossVlmDraftFile = serde_json::from_str(json).expect("parse"); + // Caller may pass a nested path; lookup should fall back to basename. + let regions = parsed + .lookup_page("images/subset/page-001.png") + .expect("page by basename"); + assert_eq!(regions.len(), 1); + } + } +} // mod imp diff --git a/oar-ocr-vl/examples/hunyuanocr.rs b/oar-ocr-vl/examples/hunyuanocr.rs index 38f6869..38d7189 100644 --- a/oar-ocr-vl/examples/hunyuanocr.rs +++ b/oar-ocr-vl/examples/hunyuanocr.rs @@ -12,7 +12,7 @@ //! //! ```bash //! cargo run -p oar-ocr-vl --example hunyuanocr -- \\ -//! --model-dir /path/to/HunyuanOCR \\ +//! --model-dir models/HunyuanOCR \\ //! --prompt "Detect and recognize text in the image, and output the text coordinates in a formatted manner." \\ //! document.jpg //! ``` diff --git a/oar-ocr-vl/examples/mineru.rs b/oar-ocr-vl/examples/mineru.rs index 3fbad75..dcff374 100644 --- a/oar-ocr-vl/examples/mineru.rs +++ b/oar-ocr-vl/examples/mineru.rs @@ -13,7 +13,7 @@ //! //! ```bash //! cargo run -p oar-ocr-vl --example mineru -- \ -//! --model-dir /path/to/MinerU2.5-2509-1.2B \ +//! --model-dir models/MinerU2.5-2509-1.2B \ //! --device cuda:0 \ //! document.jpg //! 
``` diff --git a/oar-ocr-vl/examples/unirec.rs b/oar-ocr-vl/examples/unirec.rs deleted file mode 100644 index ca9ba32..0000000 --- a/oar-ocr-vl/examples/unirec.rs +++ /dev/null @@ -1,187 +0,0 @@ -//! UniRec Unified Recognition Example (Candle-based) -//! -//! This example demonstrates how to use the UniRec model for unified recognition -//! of text, formulas, tables, and more using the Candle ML framework for native Rust inference. -//! -//! # Usage -//! -//! ```bash -//! cargo run -p oar-ocr-vl --example unirec -- [OPTIONS] ... -//! ``` -//! -//! # Arguments -//! -//! * `-m, --model-dir` - Path to the UniRec model directory (containing model.safetensors, config.json, tokenizer.json) -//! * `-d, --device` - Device to run on: cpu, cuda, cuda:N, or metal (default: cpu) -//! * `--max-tokens` - Maximum number of tokens to generate (default: 512) -//! * `-v, --verbose` - Enable verbose output -//! * `...` - Paths to input images to process -//! -//! # Examples -//! -//! ```bash -//! # Run on CPU -//! cargo run -p oar-ocr-vl --example unirec -- \ -//! -m models/unirec-0.1b \ -//! formula.jpg text.jpg -//! -//! # Run on CUDA GPU -//! cargo run -p oar-ocr-vl --features cuda --example unirec -- \ -//! -m models/unirec-0.1b -d cuda \ -//! formula.jpg text.jpg -//! 
``` - -mod utils; - -use clap::Parser; -use std::path::PathBuf; -use std::time::Instant; - -use tracing::{error, info}; - -use oar_ocr_core::utils::load_image; -use oar_ocr_vl::UniRec; -use oar_ocr_vl::utils::parse_device; - -/// Command-line arguments for the UniRec example -#[derive(Parser)] -#[command(name = "unirec")] -#[command( - about = "UniRec Unified Recognition Example - recognizes text, formulas, tables using Candle for native Rust inference" -)] -struct Args { - /// Path to the UniRec model directory (containing model.safetensors, config.json, tokenizer.json) - #[arg(short, long)] - model_dir: PathBuf, - - /// Paths to input images to process - #[arg(required = true)] - images: Vec, - - /// Device to run on: cpu, cuda, cuda:N, or metal (default: cpu) - #[arg(short, long, default_value = "cpu")] - device: String, - - /// Maximum number of tokens to generate (default: 512) - #[arg(long, default_value = "512")] - max_tokens: usize, - - /// Enable verbose output - #[arg(short, long)] - verbose: bool, -} - -fn main() -> Result<(), Box> { - // Initialize tracing for logging - utils::init_tracing(); - - // Parse command-line arguments - let args = Args::parse(); - - info!("UniRec Unified Recognition Example (Candle-based)"); - - // Verify that the model directory exists - if !args.model_dir.exists() { - error!("Model directory not found: {}", args.model_dir.display()); - return Err("Model directory not found".into()); - } - - // Check for required files - let model_file = args.model_dir.join("model.safetensors"); - let config_file = args.model_dir.join("config.json"); - let tokenizer_file = args.model_dir.join("tokenizer.json"); - - if !model_file.exists() { - error!("Model file not found: {}", model_file.display()); - return Err("model.safetensors not found in model directory".into()); - } - if !config_file.exists() { - error!("Config file not found: {}", config_file.display()); - return Err("config.json not found in model directory".into()); - } - if 
!tokenizer_file.exists() { - error!("Tokenizer file not found: {}", tokenizer_file.display()); - return Err("tokenizer.json not found in model directory".into()); - } - - // Filter out non-existent image files - let existing_images: Vec = args - .images - .iter() - .filter(|path| { - let exists = path.exists(); - if !exists { - error!("Image file not found: {}", path.display()); - } - exists - }) - .cloned() - .collect(); - - if existing_images.is_empty() { - error!("No valid image files found"); - return Err("No valid image files found".into()); - } - - // Determine device - let device = parse_device(&args.device)?; - info!("Using device: {:?}", device); - - // Load the UniRec model - info!("Loading UniRec model from: {}", args.model_dir.display()); - let load_start = Instant::now(); - let model = UniRec::from_dir(&args.model_dir, device)?; - let load_duration = load_start.elapsed(); - info!( - "Model loaded in {:.2}ms", - load_duration.as_secs_f64() * 1000.0 - ); - - if args.verbose { - let cfg = model.config(); - info!("Model configuration:"); - info!(" d_model: {}", cfg.d_model); - info!(" vocab_size: {}", cfg.vocab_size); - info!(" decoder_layers: {}", cfg.decoder_layers); - info!(" decoder_attention_heads: {}", cfg.decoder_attention_heads); - info!(" input_size: {}x{}", cfg.input_width, cfg.input_height); - } - - // Process each image - info!("\n=== Processing {} images ===", existing_images.len()); - - for image_path in &existing_images { - info!("\nProcessing: {}", image_path.display()); - - // Load image - let rgb_img = match load_image(image_path) { - Ok(img) => { - if args.verbose { - info!(" Loaded image: {}x{}", img.width(), img.height()); - } - img - } - Err(e) => { - error!(" Failed to load image: {}", e); - continue; - } - }; - - // Run inference - let infer_start = Instant::now(); - match model.generate(&[rgb_img], args.max_tokens).pop().unwrap() { - Ok(result) => { - let infer_duration = infer_start.elapsed(); - info!( - " Inference time: 
{:.2}ms", - infer_duration.as_secs_f64() * 1000.0 - ); - info!(" Result: {}", result); - } - Err(e) => { - error!(" Inference failed: {}", e); - } - } - } - Ok(()) -} diff --git a/oar-ocr-vl/examples/utils/mod.rs b/oar-ocr-vl/examples/utils/mod.rs index 937b16d..e6cee3c 100644 --- a/oar-ocr-vl/examples/utils/mod.rs +++ b/oar-ocr-vl/examples/utils/mod.rs @@ -1,6 +1,16 @@ //! Common utilities for oar-ocr-vl examples. +#[allow(dead_code)] +pub mod structure_match; + +#[cfg(feature = "hsd")] +use std::time::Duration; + +#[cfg(feature = "hsd")] +use oar_ocr_vl::hsd::types::{DsvConfig, HsdConfig, HsdStats, SpecDecodeStats}; + /// Initializes the tracing subscriber for logging in examples. +#[allow(dead_code)] pub fn init_tracing() { use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -12,3 +22,186 @@ pub fn init_tracing() { .with(tracing_subscriber::fmt::layer()) .init(); } + +#[cfg(feature = "hsd")] +#[allow(dead_code)] +pub fn make_hsd_cfg( + max_tokens: usize, + tau: f32, + window_len: usize, + max_candidates_per_step: usize, + max_suffix_len: usize, + enable_stage1: bool, +) -> HsdConfig { + HsdConfig { + dsv: DsvConfig { + tau, + window_len, + max_candidates_per_step, + max_suffix_len, + ..Default::default() + }, + enable_stage1, + enable_stage2: true, + max_page_tokens: max_tokens, + max_region_tokens: max_tokens, + } +} + +/// Default `max_candidates_per_step` for the smoke examples' real-drafter path. +/// Kept in a constant so per-example clap defaults and the oracle auto-tune +/// detector ("did the user pass a value different from the default?") agree. +#[cfg(feature = "hsd")] +#[allow(dead_code)] +pub const DEMO_DEFAULT_MAX_CANDIDATES: usize = 32; + +/// Default `max_suffix_len` for the smoke examples' real-drafter path. See +/// [`DEMO_DEFAULT_MAX_CANDIDATES`] for why this is a constant rather than +/// inlined per-CLI. 
+#[cfg(feature = "hsd")] +#[allow(dead_code)] +pub const DEMO_DEFAULT_MAX_SUFFIX_LEN: usize = 64; + +/// Resolve `(max_candidates_per_step, max_suffix_len)` for an HSD example's +/// perf pass, transparently auto-tuning to the optimal config when the +/// caller is on the oracle path (draft = baseline output) and hasn't +/// overridden the defaults. +/// +/// Why this exists: the smoke `hsd_*` examples default to oracle (no +/// `--draft-text` / `--draft-file`) for correctness verification. With the +/// real-drafter defaults `max_suffix_len = 64, max_candidates = 32`, a 448+ +/// token oracle baseline gets chopped into 64-token chunks; the matcher then +/// finds many candidates per window on repetition-heavy outputs (table HTML, +/// formula LaTeX) and explodes the prefix tree to ~2000 nodes per verify +/// step. The resulting verify_tree forward approaches baseline forward cost +/// and end-to-end speedup regresses below 1.0×. +/// +/// When the oracle path is detected: +/// - `max_suffix_len → max_tokens` so cold-start emits the entire baseline +/// as a single candidate (its prefix is guaranteed to match target token +/// for token). +/// - `max_candidates → 4` because there is only one viable candidate; the +/// smaller width avoids speculative tree branching that just costs +/// verify forward time. +/// +/// If the user explicitly passed `--max-suffix-len` or `--max-candidates`, +/// their values pass through unchanged. +/// +/// Returns `(eff_max_candidates, eff_max_suffix_len, note)`; `note` is +/// `Some(message)` if auto-tune fired, `None` if the CLI values pass through. 
+#[cfg(feature = "hsd")] +#[allow(dead_code)] +pub fn auto_tune_hsd_oracle( + is_oracle: bool, + cli_max_candidates: usize, + cli_max_suffix_len: usize, + max_tokens: usize, +) -> (usize, usize, Option) { + if is_oracle + && cli_max_candidates == DEMO_DEFAULT_MAX_CANDIDATES + && cli_max_suffix_len == DEMO_DEFAULT_MAX_SUFFIX_LEN + { + let note = format!( + " (oracle path detected: auto-tuning max_suffix_len={max_tokens} and \ + max_candidates=4 — pass --max-suffix-len / --max-candidates to override.)" + ); + (4, max_tokens, Some(note)) + } else { + (cli_max_candidates, cli_max_suffix_len, None) + } +} + +#[cfg(feature = "hsd")] +#[allow(dead_code)] +pub fn print_hsd_stats( + baseline_dur: Duration, + hsd_dur: Duration, + stats: &HsdStats, + include_stage1: bool, +) { + let baseline_decode_estimate = baseline_dur.saturating_sub(stats.stage2.vision_prefill); + let sr_decode = if stats.stage2.decode.as_secs_f64() > 0.0 { + baseline_decode_estimate.as_secs_f64() / stats.stage2.decode.as_secs_f64() + } else { + 0.0 + }; + let sr_e2e = if hsd_dur.as_secs_f64() > 0.0 { + baseline_dur.as_secs_f64() / hsd_dur.as_secs_f64() + } else { + 0.0 + }; + + println!("\n=== STATS ==="); + println!("baseline e2e: {:?}", baseline_dur); + println!("HSD e2e: {:?}", hsd_dur); + println!(" drafter prep: {:?}", stats.drafter); + if include_stage1 { + println!(" stage1 prep: {:?}", stats.stage1.draft_prep); + println!(" stage1 vision+prefill:{:?}", stats.stage1.vision_prefill); + println!(" stage1 decode: {:?}", stats.stage1.decode); + println!(" stage1 forward passes:{}", stats.stage1.forward_passes); + } + println!(" vision+prefill: {:?}", stats.stage2.vision_prefill); + println!(" decode: {:?}", stats.stage2.decode); + println!(" forward passes: {}", stats.stage2.forward_passes); + println!(" emitted tokens: {}", stats.stage2.emitted_tokens); + println!(" verify steps: {}", stats.stage2.accept.num_steps); + println!( + " fallback steps: {}", + stats.stage2.accept.num_fallbacks + ); + 
println!(" AAL: {:.2}", stats.stage2.accept.aal()); + print_dsv_stats(&stats.stage2.dsv); + println!("\nSR_decode (estimated): {:.2}x", sr_decode); + println!("SR_e2e: {:.2}x", sr_e2e); +} + +#[cfg(feature = "hsd")] +#[allow(dead_code)] +pub fn print_dsv_stats(dsv: &SpecDecodeStats) { + println!(" DSV candidate build: {:?}", dsv.candidate_build); + println!( + " DSV verify_tree: {:?} (calls={}, avg_nodes={:.1}, max_nodes={})", + dsv.verify_tree, + dsv.verify_tree_calls, + dsv.avg_tree_nodes(), + dsv.tree_nodes_max + ); + println!(" DSV traverse: {:?}", dsv.traverse); + println!(" DSV commit: {:?}", dsv.commit); + println!( + " DSV step_one: {:?} (calls={})", + dsv.step_one, dsv.step_one_calls + ); +} + +#[allow(dead_code)] +pub fn print_preview(tag: &str, text: &str) { + let preview: String = text.chars().take(400).collect(); + let preview_chars = preview.chars().count(); + let total_chars = text.chars().count(); + println!("--- {tag} (first {preview_chars} chars) ---"); + println!("{preview}"); + if total_chars > preview_chars { + println!("... [{} more chars]", total_chars - preview_chars); + } +} + +#[allow(dead_code)] +pub fn print_diff(baseline: &str, hsd: &str) { + let common = baseline + .chars() + .zip(hsd.chars()) + .take_while(|(a, b)| a == b) + .count(); + eprintln!( + "Lengths: baseline={}, hsd={}, common prefix={}", + baseline.len(), + hsd.len(), + common + ); + let snippet = + |s: &str, start: usize| -> String { s.chars().skip(start).take(80).collect::() }; + eprintln!("baseline[{common}..]: {:?}", snippet(baseline, common)); + eprintln!("hsd[{common}..]: {:?}", snippet(hsd, common)); +} diff --git a/oar-ocr-vl/examples/utils/structure_match.rs b/oar-ocr-vl/examples/utils/structure_match.rs new file mode 100644 index 0000000..2386dcb --- /dev/null +++ b/oar-ocr-vl/examples/utils/structure_match.rs @@ -0,0 +1,367 @@ +//! Source-aware matching from `StructureResult` candidates to OmniDocBench +//! target regions. +//! +//! Two-pass policy: +//! +//! 1. 
**Same-category pass**: only candidates whose `LayoutElementType` +//! shares the target's `semantic_category()` are eligible, using a +//! relaxed IoU floor (`same_category_iou`). The category pre-filter +//! bounds poisoning risk so a lower IoU is safe. +//! 2. **Cross-category fallback**: any candidate at the strict IoU floor +//! (`cross_category_iou`). Preserves the previous "max IoU wins" safety +//! net for cases where the structure pipeline assigns an unexpected +//! type to the matching region. +//! +//! Tables and formulas are pre-typed by the structure pipeline (they live +//! on `StructureResult::tables` / `::formulas`), so they always use the +//! same-category threshold. They optionally fall back to generic layout +//! text if `allow_generic_fallback` is set. +//! +//! For target types whose `semantic_category()` is `"region"` or +//! `"other"`, the same-category pass is skipped (the category carries no +//! useful signal) and we go straight to the cross-category fallback. 
+ +use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, StructureResult}; +use oar_ocr_core::processors::BoundingBox; + +#[derive(Debug, Clone, Copy)] +pub struct MatchThresholds { + pub same_category_iou: f32, + pub cross_category_iou: f32, + pub allow_generic_fallback: bool, +} + +impl MatchThresholds { + pub fn new( + same_category_iou: f32, + cross_category_iou: f32, + allow_generic_fallback: bool, + ) -> Self { + Self { + same_category_iou, + cross_category_iou, + allow_generic_fallback, + } + } +} + +#[derive(Debug, Clone)] +pub struct StructureMatch { + pub source: &'static str, + pub text: String, + pub iou: f32, + pub same_category: bool, +} + +pub fn match_region( + result: &StructureResult, + elem: &LayoutElement, + th: MatchThresholds, +) -> Option { + match elem.element_type { + LayoutElementType::Table => best_table(result, &elem.bbox, th), + LayoutElementType::Chart => None, + LayoutElementType::Formula => best_formula(result, &elem.bbox, th), + LayoutElementType::Image + | LayoutElementType::HeaderImage + | LayoutElementType::FooterImage => None, + other => best_layout(result, &elem.bbox, other, th), + } +} + +fn best_layout( + result: &StructureResult, + target: &BoundingBox, + target_type: LayoutElementType, + th: MatchThresholds, +) -> Option { + let target_cat = target_type.semantic_category(); + let same_cat_useful = !matches!(target_cat, "region" | "other"); + + if same_cat_useful { + let same = result + .layout_elements + .iter() + .filter_map(|c| { + let text = c.text.as_ref()?.trim(); + if text.is_empty() { + return None; + } + if c.element_type.semantic_category() != target_cat { + return None; + } + let iou = target.iou(&c.bbox); + (iou >= th.same_category_iou).then(|| (iou, text.to_string())) + }) + .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); + if let Some((iou, text)) = same { + return Some(StructureMatch { + source: "layout", + text, + iou, + same_category: true, + }); + } + } + + 
result + .layout_elements + .iter() + .filter_map(|c| { + let text = c.text.as_ref()?.trim(); + if text.is_empty() { + return None; + } + let iou = target.iou(&c.bbox); + (iou >= th.cross_category_iou).then(|| (iou, text.to_string())) + }) + .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(iou, text)| StructureMatch { + source: "layout", + text, + iou, + same_category: false, + }) +} + +fn best_table( + result: &StructureResult, + target: &BoundingBox, + th: MatchThresholds, +) -> Option { + let direct = result + .tables + .iter() + .filter_map(|table| { + let html = table.html_structure.as_ref()?.trim(); + if html.is_empty() { + return None; + } + let iou = target.iou(&table.bbox); + (iou >= th.same_category_iou).then(|| (iou, html.to_string())) + }) + .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(iou, text)| StructureMatch { + source: "table", + text, + iou, + same_category: true, + }); + + direct.or_else(|| { + if !th.allow_generic_fallback { + return None; + } + best_layout(result, target, LayoutElementType::Table, th) + }) +} + +fn best_formula( + result: &StructureResult, + target: &BoundingBox, + th: MatchThresholds, +) -> Option { + let direct = result + .formulas + .iter() + .filter_map(|formula| { + let latex = formula.latex.trim(); + if latex.is_empty() { + return None; + } + let iou = target.iou(&formula.bbox); + (iou >= th.same_category_iou).then(|| (iou, latex.to_string())) + }) + .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)) + .map(|(iou, text)| StructureMatch { + source: "formula", + text, + iou, + same_category: true, + }); + + direct.or_else(|| { + if !th.allow_generic_fallback { + return None; + } + best_layout(result, target, LayoutElementType::Formula, th) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use oar_ocr_core::domain::structure::{ + FormulaResult, LayoutElement, LayoutElementType, StructureResult, TableResult, TableType, + }; 
+ use oar_ocr_core::processors::BoundingBox; + + fn bb(x1: f32, y1: f32, x2: f32, y2: f32) -> BoundingBox { + BoundingBox::from_coords(x1, y1, x2, y2) + } + + fn cand(t: LayoutElementType, b: BoundingBox, text: &str) -> LayoutElement { + LayoutElement::new(b, t, 1.0).with_text(text) + } + + fn target(t: LayoutElementType, b: BoundingBox) -> LayoutElement { + LayoutElement::new(b, t, 1.0) + } + + fn empty_result() -> StructureResult { + StructureResult::new("test.jpg", 0) + } + + fn th_default() -> MatchThresholds { + MatchThresholds::new(0.5, 0.8, false) + } + + /// Same-category at relaxed IoU wins over cross-category that would + /// otherwise be ineligible (below the strict floor) — exactly the + /// poisoning case the policy is designed to avoid. + #[test] + fn same_category_beats_lower_iou_cross_category() { + let mut r = empty_result(); + // Cross-category candidate at IoU below strict 0.8 (would poison + // under a pure max-IoU policy if relaxed): + r.layout_elements.push(cand( + LayoutElementType::Text, + bb(0.0, 30.0, 100.0, 100.0), + "BODY TEXT (wrong type)", + )); + // Same-category candidate at IoU ~0.66 (passes 0.5 floor): + r.layout_elements.push(cand( + LayoutElementType::DocTitle, + bb(10.0, 10.0, 100.0, 60.0), + "TITLE TEXT", + )); + + let t = target(LayoutElementType::DocTitle, bb(0.0, 0.0, 100.0, 50.0)); + let m = match_region(&r, &t, th_default()).unwrap(); + assert_eq!(m.text, "TITLE TEXT"); + assert!(m.same_category); + } + + /// Cross-category falls through at strict 0.8 floor when no same-cat + /// candidate is eligible. + #[test] + fn cross_category_only_at_strict_threshold() { + let mut r = empty_result(); + // No same-category candidate; only a Text candidate at IoU = 1.0. 
+ r.layout_elements.push(cand( + LayoutElementType::Text, + bb(0.0, 0.0, 100.0, 50.0), + "FALLBACK BODY", + )); + + let t = target(LayoutElementType::DocTitle, bb(0.0, 0.0, 100.0, 50.0)); + let m = match_region(&r, &t, th_default()).unwrap(); + assert_eq!(m.text, "FALLBACK BODY"); + assert!(!m.same_category); + } + + /// Cross-category candidate below strict threshold yields no match — + /// don't poison with a partial overlap of the wrong type. + #[test] + fn cross_category_below_strict_returns_none() { + let mut r = empty_result(); + r.layout_elements.push(cand( + LayoutElementType::Text, + bb(40.0, 0.0, 100.0, 50.0), + "PARTIAL OVERLAP", + )); + let t = target(LayoutElementType::DocTitle, bb(0.0, 0.0, 100.0, 50.0)); + assert!(match_region(&r, &t, th_default()).is_none()); + } + + /// "region" / "other" semantic categories skip the same-cat pass and go + /// straight to the cross-category fallback. + #[test] + fn region_target_skips_same_category_pass() { + let mut r = empty_result(); + r.layout_elements.push(cand( + LayoutElementType::Text, + bb(0.0, 0.0, 100.0, 50.0), + "ANY TEXT", + )); + let t = target(LayoutElementType::Region, bb(0.0, 0.0, 100.0, 50.0)); + let m = match_region(&r, &t, th_default()).unwrap(); + assert_eq!(m.text, "ANY TEXT"); + assert!(!m.same_category); + } + + /// Table → table candidate at relaxed threshold (the candidate set + /// is already type-restricted, so a low IoU floor is safe). + #[test] + fn table_target_uses_relaxed_threshold() { + let mut r = empty_result(); + r.tables.push( + TableResult::new(bb(0.0, 0.0, 100.0, 60.0), TableType::Wired) + .with_html_structure("x
"), + ); + let t = target(LayoutElementType::Table, bb(0.0, 0.0, 100.0, 50.0)); + let m = match_region(&r, &t, th_default()).unwrap(); + assert_eq!(m.source, "table"); + assert!(m.same_category); + } + + /// Formula → formula candidate at relaxed threshold. + #[test] + fn formula_target_uses_relaxed_threshold() { + let mut r = empty_result(); + r.formulas.push(FormulaResult::new( + bb(0.0, 0.0, 100.0, 60.0), + r"\sum x", + 1.0, + )); + let t = target(LayoutElementType::Formula, bb(0.0, 0.0, 100.0, 50.0)); + let m = match_region(&r, &t, th_default()).unwrap(); + assert_eq!(m.source, "formula"); + assert!(m.same_category); + } + + /// Without `allow_generic_fallback`, a missing table candidate yields + /// None even if a generic layout candidate would have matched. + #[test] + fn table_no_generic_fallback_by_default() { + let mut r = empty_result(); + r.layout_elements.push(cand( + LayoutElementType::Table, + bb(0.0, 0.0, 100.0, 50.0), + "table-as-text", + )); + let t = target(LayoutElementType::Table, bb(0.0, 0.0, 100.0, 50.0)); + assert!(match_region(&r, &t, th_default()).is_none()); + } + + /// With `allow_generic_fallback`, the same case finds a draft via + /// the generic layout pass. + #[test] + fn table_generic_fallback_when_enabled() { + let mut r = empty_result(); + r.layout_elements.push(cand( + LayoutElementType::Table, + bb(0.0, 0.0, 100.0, 50.0), + "table-as-text", + )); + let t = target(LayoutElementType::Table, bb(0.0, 0.0, 100.0, 50.0)); + let th = MatchThresholds::new(0.5, 0.8, true); + let m = match_region(&r, &t, th).unwrap(); + assert_eq!(m.source, "layout"); + } + + /// Image / Chart targets are intentionally non-drafted. 
+ #[test] + fn image_and_chart_targets_return_none() { + let mut r = empty_result(); + r.layout_elements.push(cand( + LayoutElementType::Image, + bb(0.0, 0.0, 100.0, 50.0), + "alt text", + )); + let img = target(LayoutElementType::Image, bb(0.0, 0.0, 100.0, 50.0)); + let chart = target(LayoutElementType::Chart, bb(0.0, 0.0, 100.0, 50.0)); + assert!(match_region(&r, &img, th_default()).is_none()); + assert!(match_region(&r, &chart, th_default()).is_none()); + } +} diff --git a/oar-ocr-vl/src/attention.rs b/oar-ocr-vl/src/attention.rs index cdb1347..22b1728 100644 --- a/oar-ocr-vl/src/attention.rs +++ b/oar-ocr-vl/src/attention.rs @@ -1,14 +1,15 @@ //! Unified attention implementation for all VLM models. //! //! This module provides shared attention and rotary embedding implementations -//! to ensure consistent behavior across LightOnOCR, PaddleOCR-VL, HunyuanOCR, and UniRec models. +//! to ensure consistent behavior across PaddleOCR-VL, HunyuanOCR, GLM-OCR, +//! and MinerU2.5 models. //! //! ## Benefits //! //! - Single place for attention and RoPE optimizations //! - Consistent mask handling across models //! - Shared KV cache logic -//! - Support for multiple RoPE variants (standard, MRoPE, XDRoPE) +//! - Support for multi-axis RoPE variants (MRoPE, XDRoPE) //! - Easier testing and maintenance //! //! ## Usage @@ -22,10 +23,6 @@ //! // Create causal mask for autoregressive decoding //! let mask = create_causal_mask(seq_len, kv_len, dtype, device)?; //! -//! // Standard RoPE (for LightOnOCR) -//! let rope = RotaryEmbedding::new(base, head_dim, max_pos, device)?; -//! let (cos, sin) = rope.get_cos_sin(seq_len, dtype)?; -//! //! // Multi-axis RoPE (for PaddleOCR-VL, HunyuanOCR) //! let rope = RotaryEmbedding::new_multi_axis(head_dim, rope_theta, num_dims, device)?; //! 
let (cos, sin) = rope.forward_multi_axis(&position_ids, dtype)?; @@ -218,6 +215,103 @@ pub fn combine_masks(causal_mask: &Tensor, padding_mask: &Tensor) -> Result], + prefix_kv_len: usize, + dtype: DType, + device: &Device, +) -> Result { + let n = parents.len(); + let total_kv = prefix_kv_len + n; + + if n == 0 { + return on_compute_device(device, |compute_device| { + Tensor::zeros((1, 1, 0, total_kv), dtype, compute_device) + }); + } + + // Host-side mask buffer. Initialised to -inf, then we punch holes for + // (a) the accepted prefix (whole left block) and (b) each node's ancestor + // chain (sparse hits in the right block). + // + // Earlier revisions materialised an O(N²) `Vec>` ancestor matrix + // up front; that's redundant because we only consume each ancestor set + // once. Walking the parent chain inline during the buffer fill removes: + // - the N² bool allocation (up to 32² = 1 KiB, irrelevant for big-O but + // one extra heap call per HSD step), and + // - the per-row O(N) "is this an ancestor?" check (now a tight + // parent-pointer walk of length = node depth, typically ≤ 8 for HSD + // trees of width 32). + // + // The accepted-prefix block uses `slice::fill` (memset under the hood) + // instead of a per-cell loop, which is the dominant cost on long pages — + // for a 16k-token prefix and 32 candidates that's ~512K float writes per + // verify step, and memset is 5-10× faster than the per-element write. + let mut buf = vec![f32::NEG_INFINITY; n * total_kv]; + for i in 0..n { + let row_off = i * total_kv; + // (a) Allow attending to the entire accepted prefix — single memset. + if prefix_kv_len > 0 { + buf[row_off..row_off + prefix_kv_len].fill(0.0); + } + // (b) Allow attending to this node + its ancestors. Walk parent + // pointers in place; no auxiliary bitset. 
+ let mut cur = Some(i); + while let Some(j) = cur { + buf[row_off + prefix_kv_len + j] = 0.0; + cur = parents[j]; + } + } + + on_compute_device(device, move |compute_device| { + Tensor::from_vec(buf, (1, 1, n, total_kv), compute_device)?.to_dtype(dtype) + }) +} + /// Create a left-padding mask for batched sequences. /// /// Left-padding aligns sequences at the right edge, which is standard for @@ -235,7 +329,7 @@ pub fn combine_masks(causal_mask: &Tensor, padding_mask: &Tensor) -> Result Result { - let last_dim = x.dim(D::Minus1)?; - let x1 = x.narrow(D::Minus1, 0, last_dim / 2)?; - let x2 = x.narrow(D::Minus1, last_dim / 2, last_dim / 2)?; - Tensor::cat(&[&x2.neg()?, &x1], D::Minus1) -} - // ============================================================================ // Rotary Positional Embedding (RoPE) // ============================================================================ @@ -296,20 +383,17 @@ fn rotate_half(x: &Tensor) -> Result { /// Unified Rotary Positional Embedding implementation. /// /// Supports multiple RoPE variants: -/// - **Standard RoPE**: Precomputed cos/sin for efficient lookup (LightOnOCR) -/// - **Dynamic RoPE**: On-the-fly computation from position IDs +/// - **Dynamic RoPE**: On-the-fly computation from position IDs (single-axis) /// - **Multi-axis RoPE (MRoPE)**: 3-axis encoding for text/height/width (PaddleOCR-VL) /// - **Extended Dimension RoPE (XDRoPE)**: Configurable num_dims (HunyuanOCR) /// /// ## Architecture /// -/// The implementation uses an enum to support different RoPE modes while maintaining -/// a unified interface. This eliminates code duplication across models. +/// All variants share a single `Dynamic` representation parameterized by +/// `num_dims`; the constructor functions pick the right `num_dims` for each +/// model family. #[derive(Debug, Clone)] pub enum RotaryEmbedding { - /// Precomputed cos/sin for standard RoPE (used by LightOnOCR). - /// Efficient for inference with fixed max sequence length. 
- Precomputed { cos: Tensor, sin: Tensor }, /// Dynamic computation from inverse frequencies (used by PaddleOCR-VL, HunyuanOCR). /// Supports multi-axis position encoding. Dynamic { @@ -320,48 +404,6 @@ pub enum RotaryEmbedding { } impl RotaryEmbedding { - /// Create a standard RoPE with precomputed cos/sin (LightOnOCR style). - /// - /// This precomputes embeddings for all positions up to max_position_embeddings, - /// enabling efficient lookup during inference. - /// - /// # Arguments - /// * `base` - RoPE base frequency (typically 10000.0) - /// * `head_dim` - Dimension of each attention head - /// * `max_position_embeddings` - Maximum sequence length to precompute - /// * `device` - Device for tensor allocation - /// * `dtype` - Data type for cos/sin tensors - /// - /// # Returns - /// RotaryEmbedding in Precomputed mode - pub fn new_precomputed( - base: f32, - head_dim: usize, - max_position_embeddings: usize, - device: &Device, - dtype: DType, - ) -> Result { - let inv_freq: Vec<_> = (0..head_dim) - .step_by(2) - .map(|i| 1f32 / base.powf(i as f32 / head_dim as f32)) - .collect(); - let inv_freq_len = inv_freq.len(); - - // Use on_compute_device to handle Metal's lack of support for arange - let freqs = on_compute_device(device, |compute_device| { - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), compute_device)?; - let t = Tensor::arange(0u32, max_position_embeddings as u32, compute_device)? - .to_dtype(DType::F32)? - .reshape((max_position_embeddings, 1))?; - t.matmul(&inv_freq) - })?; - - let sin = freqs.sin()?.to_dtype(dtype)?; - let cos = freqs.cos()?.to_dtype(dtype)?; - - Ok(Self::Precomputed { cos, sin }) - } - /// Create a dynamic RoPE with on-the-fly computation (standard single-axis). /// /// This computes embeddings dynamically from position IDs, suitable for @@ -423,35 +465,7 @@ impl RotaryEmbedding { Ok(Self::Dynamic { inv_freq, num_dims }) } - /// Get precomputed cos/sin for standard RoPE (Precomputed mode only). 
- /// - /// Used by LightOnOCR for efficient embedding lookup. - /// - /// # Arguments - /// * `seq_len` - Sequence length to retrieve embeddings for - /// * `dtype` - Target data type - /// - /// # Returns - /// Tuple of (cos, sin) tensors, shape: (seq_len, head_dim/2) - /// - /// # Panics - /// Panics if called on Dynamic mode - pub fn get_cos_sin(&self, seq_len: usize, dtype: DType) -> Result<(Tensor, Tensor)> { - match self { - Self::Precomputed { cos, sin } => { - let cos = cos.narrow(0, 0, seq_len)?.to_dtype(dtype)?; - let sin = sin.narrow(0, 0, seq_len)?.to_dtype(dtype)?; - Ok((cos, sin)) - } - Self::Dynamic { .. } => { - panic!( - "get_cos_sin() called on Dynamic RoPE mode. Use forward_multi_axis() instead." - ) - } - } - } - - /// Forward pass for multi-axis RoPE (Dynamic mode only). + /// Forward pass for multi-axis RoPE. /// /// Computes cos/sin from position IDs dynamically. Supports multi-dimensional /// position encoding. @@ -462,9 +476,6 @@ impl RotaryEmbedding { /// /// # Returns /// Tuple of (cos, sin) tensors, shape: (num_dims, batch, seq, head_dim) - /// - /// # Panics - /// Panics if called on Precomputed mode pub fn forward_multi_axis( &self, position_ids: &Tensor, @@ -570,95 +581,6 @@ impl RotaryEmbedding { })?; Ok((cos, sin)) } - Self::Precomputed { .. } => { - panic!( - "forward_multi_axis() called on Precomputed RoPE mode. Use get_cos_sin() instead." - ) - } - } - } - - /// Apply rotary embeddings to query and key tensors (for Precomputed mode). - /// - /// This is a convenience method for LightOnOCR-style usage with precomputed embeddings. 
- /// - /// # Arguments - /// * `q` - Query tensor: (batch, heads, seq, head_dim) - /// * `k` - Key tensor: (batch, heads, seq, head_dim) - /// * `seqlen_offsets` - Sequence length offsets for each batch item - /// - /// # Returns - /// Tuple of (rotated_q, rotated_k) - pub fn apply_rotary_emb( - &self, - q: &Tensor, - k: &Tensor, - seqlen_offsets: &[usize], - ) -> Result<(Tensor, Tensor)> { - match self { - Self::Precomputed { cos, sin } => { - let (b_sz, _qh, seq_len, _n_embd) = q.dims4()?; - - // Helper to apply RoPE with full broadcasting support - let apply_rope = |x: &Tensor, cos: &Tensor, sin: &Tensor| -> Result { - let x_rot = rotate_half(x)?; - // x * cos + x_rot * sin - // cos/sin shapes must be broadcastable to x shape - let a = x.broadcast_mul(cos)?; - let b = x_rot.broadcast_mul(sin)?; - a.add(&b) - }; - - // Fast path: if all offsets are equal (e.g. 0 during prefill) - if seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]) { - let offset = seqlen_offsets[0]; - let cos_seq = cos.narrow(0, offset, seq_len)?; - let sin_seq = sin.narrow(0, offset, seq_len)?; - - // Repeat to match head_dim: (S, D/2) -> (S, D) - let cos_seq = Tensor::cat(&[&cos_seq, &cos_seq], D::Minus1)?; - let sin_seq = Tensor::cat(&[&sin_seq, &sin_seq], D::Minus1)?; - - // (S, D) -> (1, 1, S, D) for broadcasting - let cos_b = cos_seq.unsqueeze(0)?.unsqueeze(0)?; - let sin_b = sin_seq.unsqueeze(0)?.unsqueeze(0)?; - - let q_out = apply_rope(q, &cos_b, &sin_b)?; - let k_out = apply_rope(k, &cos_b, &sin_b)?; - return Ok((q_out, k_out)); - } - - // General path: gather embeddings for each batch item - let device = q.device(); - let mut indices = Vec::with_capacity(b_sz * seq_len); - for &offset in seqlen_offsets { - for i in 0..seq_len { - indices.push((offset + i) as u32); - } - } - let indices = Tensor::from_vec(indices, (b_sz * seq_len,), device)?; - let cos_gather = cos.index_select(&indices, 0)?; // (B*S, D/2) - let sin_gather = sin.index_select(&indices, 0)?; - - // Repeat to match 
head_dim: (B*S, D/2) -> (B*S, D) - let cos_gather = Tensor::cat(&[&cos_gather, &cos_gather], D::Minus1)?; - let sin_gather = Tensor::cat(&[&sin_gather, &sin_gather], D::Minus1)?; - - // Reshape to (B, 1, S, D) for broadcasting against (B, H, S, D) - // Note: cos.dim(1) is D/2, so after cat it is D. - let shape = (b_sz, 1, seq_len, cos_gather.dim(1)?); - let cos_b = cos_gather.reshape(shape)?; - let sin_b = sin_gather.reshape(shape)?; - - let q_out = apply_rope(q, &cos_b, &sin_b)?; - let k_out = apply_rope(k, &cos_b, &sin_b)?; - Ok((q_out, k_out)) - } - Self::Dynamic { .. } => { - panic!( - "apply_rotary_emb() requires Precomputed mode. Use forward_multi_axis() for Dynamic mode." - ) - } } } } @@ -851,6 +773,223 @@ mod tests { Ok(()) } + #[test] + fn test_tree_attention_mask_shape() -> Result<()> { + let device = Device::Cpu; + // chain: 0 -> 1 -> 2 (linear), prefix_kv_len = 2 + let parents = vec![None, Some(0), Some(1)]; + let mask = create_tree_attention_mask(&parents, 2, DType::F32, &device)?; + assert_eq!(mask.dims(), &[1, 1, 3, 5]); + let m: Vec = mask.flatten_all()?.to_vec1()?; + // Row 0: prefix [0,1] = 0; cand cols [0]=self -> 0; [1,2]=-inf + assert_eq!(m[0], 0.0); + assert_eq!(m[1], 0.0); + assert_eq!(m[2], 0.0); + assert!(m[3].is_infinite()); + assert!(m[4].is_infinite()); + // Row 1: prefix=0; cand cols [0,1]=ancestors -> 0; [2]=-inf + assert_eq!(m[5], 0.0); + assert_eq!(m[5 + 1], 0.0); + assert_eq!(m[5 + 2], 0.0); + assert_eq!(m[5 + 3], 0.0); + assert!(m[5 + 4].is_infinite()); + // Row 2: prefix=0; cand cols [0,1,2]=full path -> all 0 + for col in 0..5 { + assert_eq!(m[10 + col], 0.0); + } + Ok(()) + } + + #[test] + fn test_tree_attention_mask_branching() -> Result<()> { + let device = Device::Cpu; + // tree: + // root + // | + // 0 (tok) + // / \ + // 1 2 + let parents = vec![None, Some(0), Some(0)]; + let mask = create_tree_attention_mask(&parents, 0, DType::F32, &device)?; + assert_eq!(mask.dims(), &[1, 1, 3, 3]); + let m: Vec = 
mask.flatten_all()?.to_vec1()?; + // Node 1 should NOT attend to node 2 (its sibling). + // Layout: row*3 + col + // Row 1 (node 1): col 0 = parent (allowed), col 1 = self (allowed), col 2 = sibling (FORBIDDEN) + assert_eq!(m[3], 0.0); + assert_eq!(m[3 + 1], 0.0); + assert!(m[3 + 2].is_infinite()); + // Row 2 (node 2): col 0 = parent, col 1 = sibling, col 2 = self + assert_eq!(m[6], 0.0); + assert!(m[6 + 1].is_infinite()); + assert_eq!(m[6 + 2], 0.0); + Ok(()) + } + + #[test] + fn test_tree_attention_mask_changes_sibling_logits() -> Result<()> { + let device = Device::Cpu; + let parents = vec![None, Some(0), Some(0)]; + let tree_mask = create_tree_attention_mask(&parents, 0, DType::F32, &device)?; + let no_op_mask = Tensor::zeros((1, 1, 3, 3), DType::F32, &device)?; + + let q = Tensor::from_vec( + vec![ + 0.0f32, 0.0, // node 0 + 1.0, 0.0, // node 1 strongly matches sibling key + 0.0, 0.0, // node 2 + ], + (1, 1, 3, 2), + &device, + )?; + let k = Tensor::from_vec( + vec![ + 0.0f32, 0.0, // node 0 + 0.0, 0.0, // node 1 + 10.0, 0.0, // node 2 sibling + ], + (1, 1, 3, 2), + &device, + )?; + let v = Tensor::from_vec( + vec![ + 0.0f32, 0.0, // node 0 + 0.0, 0.0, // node 1 + 100.0, 0.0, // node 2 sibling + ], + (1, 1, 3, 2), + &device, + )?; + + let masked = scaled_dot_product_attention(&q, &k, &v, Some(&tree_mask), 1.0, false)?; + let unmasked = scaled_dot_product_attention(&q, &k, &v, Some(&no_op_mask), 1.0, false)?; + + let masked_node_1: Vec = masked.i((0, 0, 1, ..))?.to_vec1()?; + let unmasked_node_1: Vec = unmasked.i((0, 0, 1, ..))?.to_vec1()?; + assert!( + unmasked_node_1[0] - masked_node_1[0] > 90.0, + "sibling unmask should materially change node logits: masked={masked_node_1:?}, unmasked={unmasked_node_1:?}" + ); + + Ok(()) + } + + #[test] + fn test_tree_attention_mask_empty() -> Result<()> { + let device = Device::Cpu; + let mask = create_tree_attention_mask(&[], 4, DType::F32, &device)?; + assert_eq!(mask.dims(), &[1, 1, 0, 4]); + Ok(()) + } + + /// Reference 
implementation kept from the pre-optimization revision + /// (the `Vec>` ancestor-bitset version). Lets the parity test + /// compare the optimized parent-walk path against the original + /// O(N²)-bitset path across a battery of randomly shaped trees. + fn create_tree_attention_mask_reference( + parents: &[Option], + prefix_kv_len: usize, + dtype: DType, + device: &Device, + ) -> Result { + let n = parents.len(); + let total_kv = prefix_kv_len + n; + if n == 0 { + return Tensor::zeros((1, 1, 0, total_kv), dtype, device); + } + let mut ancestors: Vec> = vec![vec![false; n]; n]; + for (i, row) in ancestors.iter_mut().enumerate() { + let mut cur = Some(i); + while let Some(j) = cur { + row[j] = true; + cur = parents[j]; + } + } + let mut buf = vec![f32::NEG_INFINITY; n * total_kv]; + for (row_buf, ancestor_row) in buf.chunks_mut(total_kv).zip(ancestors.iter()) { + row_buf[..prefix_kv_len].fill(0.0); + for (j, &is_anc) in ancestor_row.iter().enumerate() { + if is_anc { + row_buf[prefix_kv_len + j] = 0.0; + } + } + } + Tensor::from_vec(buf, (1, 1, n, total_kv), device)?.to_dtype(dtype) + } + + #[test] + fn test_tree_attention_mask_matches_reference_on_varied_shapes() -> Result<()> { + let device = Device::Cpu; + + // Deterministic pseudo-random tree shapes (we hand-pick parents so the + // test is reproducible without a seeded RNG dependency). + let shapes: Vec<(Vec>, usize)> = vec![ + // Linear chain, no prefix. + (vec![None, Some(0), Some(1), Some(2)], 0), + // Linear chain, with prefix. + (vec![None, Some(0), Some(1)], 7), + // Branching tree, with prefix. + ( + vec![ + None, + Some(0), + Some(0), + Some(1), + Some(1), + Some(2), + Some(2), + Some(2), + ], + 12, + ), + // Multi-root forest, no prefix. + (vec![None, None, Some(0), Some(1), Some(2)], 0), + // Single root, deep+wide, with prefix matching a realistic HSD + // size: ~16 candidates, 4096-token prefix. 
+ ( + { + let mut p = vec![None]; + for i in 1..16 { + // alternate root-children and chains, like an + // expanded prefix-tree from many drafts. + p.push(if i % 4 == 0 { None } else { Some(i - 1) }); + } + p + }, + 4096, + ), + ]; + + for (parents, prefix_kv_len) in shapes { + let opt = create_tree_attention_mask(&parents, prefix_kv_len, DType::F32, &device)?; + let reference = + create_tree_attention_mask_reference(&parents, prefix_kv_len, DType::F32, &device)?; + assert_eq!( + opt.dims(), + reference.dims(), + "shape mismatch for parents={:?}", + parents + ); + let opt_vals: Vec = opt.flatten_all()?.to_vec1()?; + let ref_vals: Vec = reference.flatten_all()?.to_vec1()?; + assert_eq!(opt_vals.len(), ref_vals.len()); + for (i, (&a, &b)) in opt_vals.iter().zip(ref_vals.iter()).enumerate() { + let same = match (a.is_infinite(), b.is_infinite()) { + (true, true) => a.is_sign_negative() == b.is_sign_negative(), + (false, false) => (a - b).abs() < 1e-6, + _ => false, + }; + assert!( + same, + "mismatch at idx {i} for parents={:?} prefix={}: opt={a} ref={b}", + parents, prefix_kv_len + ); + } + } + + Ok(()) + } + #[test] fn test_repeat_kv() -> Result<()> { let device = Device::Cpu; @@ -934,20 +1073,6 @@ mod tests { // RoPE Tests // ======================================================================== - #[test] - fn test_rotary_embedding_precomputed() -> Result<()> { - let device = Device::Cpu; - let rope = RotaryEmbedding::new_precomputed(10000.0, 64, 512, &device, DType::F32) - .expect("Failed to create RoPE"); - - // Test getting cos/sin for a sequence - let (cos, sin) = rope.get_cos_sin(128, DType::F32)?; - assert_eq!(cos.dims(), &[128, 32]); // head_dim/2 - assert_eq!(sin.dims(), &[128, 32]); - - Ok(()) - } - #[test] fn test_rotary_embedding_dynamic_single_axis() -> std::result::Result<(), OCRError> { let device = Device::Cpu; @@ -1037,22 +1162,4 @@ mod tests { Ok(()) } - - #[test] - fn test_rotary_embedding_apply() -> Result<()> { - let device = Device::Cpu; - 
let rope = RotaryEmbedding::new_precomputed(10000.0, 64, 512, &device, DType::F32) - .expect("Failed to create RoPE"); - - let q = Tensor::randn(0f32, 1., (2, 4, 8, 64), &device)?; // (batch, heads, seq, dim) - let k = Tensor::randn(0f32, 1., (2, 4, 8, 64), &device)?; - - let seqlen_offsets = vec![0, 0]; - let (q_rot, k_rot) = rope.apply_rotary_emb(&q, &k, &seqlen_offsets)?; - - assert_eq!(q_rot.dims(), q.dims()); - assert_eq!(k_rot.dims(), k.dims()); - - Ok(()) - } } diff --git a/oar-ocr-vl/src/doc_parser.rs b/oar-ocr-vl/src/doc_parser.rs index 423aa0a..a1ddc75 100644 --- a/oar-ocr-vl/src/doc_parser.rs +++ b/oar-ocr-vl/src/doc_parser.rs @@ -7,11 +7,9 @@ //! 4. Returns structured document results //! //! Supported backends: -//! - `UniRec` - Lightweight unified recognition -//! - `PaddleOcrVl` - Larger VLM with task-specific prompts +//! - `PaddleOcrVl` - VLM with task-specific prompts //! - `HunyuanOcr` - OCR expert VLM (HunYuanVL) //! - `GlmOcr` - GLM-OCR OCR expert VLM -//! - `LightOnOcr` - End-to-end OCR VLM //! - `MinerU` - MinerU2.5 document parsing VLM (Qwen2-VL backbone) use super::utils::{ @@ -59,7 +57,7 @@ pub trait RecognitionBackend { ) -> Result; /// Whether this backend requires post-processing for table output. - /// UniRec outputs OTSL format that needs conversion; PaddleOCR-VL outputs HTML directly. + /// Some backends emit OTSL and need conversion; PaddleOCR-VL outputs HTML directly. fn needs_table_postprocess(&self) -> bool { false } @@ -150,7 +148,7 @@ impl<'a, B: RecognitionBackend> DocParser<'a, B> { /// Parse a document image without layout detection (single full-image OCR). /// - /// Use this for end-to-end models (e.g. LightOnOCR) that handle layout internally. + /// Use this for end-to-end models that handle layout internally. /// For models requiring separate layout detection, use [`parse`](Self::parse) instead. 
pub fn parse_without_layout(&self, image: RgbImage) -> Result { self.recognize_full_image("".into(), 0, image) @@ -437,42 +435,8 @@ impl<'a, B: RecognitionBackend> DocParser<'a, B> { } } -use super::unirec::UniRec; - -impl RecognitionBackend for UniRec { - fn recognize( - &self, - image: RgbImage, - _task: RecognitionTask, - max_tokens: usize, - ) -> Result { - // UniRec doesn't use task-specific prompts; it's a unified model - self.generate(&[image], max_tokens) - .into_iter() - .next() - .unwrap_or_else(|| { - Err(OCRError::InvalidInput { - message: "UniRec: no result returned".to_string(), - }) - }) - } - - fn needs_table_postprocess(&self) -> bool { - true // UniRec outputs OTSL format - } - - fn needs_formula_preprocess(&self) -> bool { - true // Benefit from margin cropping - } - - fn needs_repetition_truncation(&self) -> bool { - true // May produce repetitive output - } -} - use super::glmocr::GlmOcr; use super::hunyuanocr::HunyuanOcr; -use super::lightonocr::LightOnOcr; use super::mineru::MinerU; use super::paddleocr_vl::{PaddleOcrVl, PaddleOcrVlTask}; @@ -599,50 +563,6 @@ impl RecognitionBackend for GlmOcr { } } -impl RecognitionBackend for LightOnOcr { - fn recognize( - &self, - image: RgbImage, - task: RecognitionTask, - max_tokens: usize, - ) -> Result { - let prompt = match task { - RecognitionTask::Ocr => "", - RecognitionTask::Table => "Parse the table in the image into HTML.", - RecognitionTask::Formula => { - "Identify the formula in the image and represent it using LaTeX format." - } - RecognitionTask::Chart => { - "Parse the chart in the image; use Mermaid format for flowcharts and Markdown for other charts." 
- } - }; - let out = self - .generate(&[image], &[prompt], max_tokens) - .into_iter() - .next() - .unwrap_or_else(|| { - Err(OCRError::InvalidInput { - message: "LightOnOCR: no result returned".to_string(), - }) - })?; - Ok(truncate_repetitive_content(&out, 10, 10, 10) - .trim() - .to_string()) - } - - fn needs_table_postprocess(&self) -> bool { - false - } - - fn needs_formula_preprocess(&self) -> bool { - false - } - - fn needs_repetition_truncation(&self) -> bool { - false // handled inside `recognize()` - } -} - impl RecognitionBackend for MinerU { fn recognize( &self, @@ -765,9 +685,6 @@ fn should_have_order_index(element_type: LayoutElementType) -> bool { ) } -/// Document parser using UniRec backend. -pub type UniRecDocParser<'a> = DocParser<'a, UniRec>; - /// Document parser using PaddleOCR-VL backend. pub type PaddleOcrVlDocParser<'a> = DocParser<'a, PaddleOcrVl>; diff --git a/oar-ocr-vl/src/glmocr/model.rs b/oar-ocr-vl/src/glmocr/model.rs index 75e1791..4343dae 100644 --- a/oar-ocr-vl/src/glmocr/model.rs +++ b/oar-ocr-vl/src/glmocr/model.rs @@ -2,14 +2,37 @@ use super::config::{EosTokenId, GlmOcrConfig, GlmOcrImageProcessorConfig}; use super::processing::{GlmOcrImageInputs, preprocess_image}; use super::text::GlmOcrTextModel; use super::vision::GlmOcrVisionModel; +#[cfg(feature = "hsd")] +use crate::attention::create_tree_attention_mask; +#[cfg(feature = "hsd")] +use crate::hsd::backend_util::{commit_keep_indices, step_pos_ids, tree_pos_ids}; +#[cfg(feature = "hsd")] +use crate::hsd::drafting::{ + TargetDraftAdapter, bbox_xyxy, crop_region_image, format_verified_region, map_layout_kind, + region_markdown_for, region_markdowns_for, structure_result_to_layout_elements, +}; +#[cfg(feature = "hsd")] +use crate::hsd::prefix_tree::PrefixTree; +#[cfg(feature = "hsd")] +use crate::hsd::types::{AcceptStats, Draft, HsdConfig, HsdStats, RegionStageStats}; +#[cfg(feature = "hsd")] +use crate::hsd::verify::{SpecBackend, spec_decode}; use crate::utils::{ 
candle_to_ocr_inference, candle_to_ocr_processing, truncate_repetitive_content, }; +#[cfg(feature = "hsd")] +use candle_core::Result as CandleResult; use candle_core::{D, DType, Device, IndexOp, Tensor}; +#[cfg(feature = "hsd")] +use candle_nn::ops as cnn_ops; use candle_nn::{Linear, Module, VarBuilder, linear_no_bias}; use image::RgbImage; use oar_ocr_core::core::OCRError; +#[cfg(feature = "hsd")] +use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, StructureResult}; use std::path::Path; +#[cfg(feature = "hsd")] +use std::time::{Duration, Instant}; use tokenizers::Tokenizer; pub struct GlmOcr { @@ -122,7 +145,48 @@ impl GlmOcr { })]; } - match self.generate_internal(images, instructions, max_new_tokens) { + match self.generate_tokens_internal(images, instructions, max_new_tokens) { + Ok(results) => results + .into_iter() + .map(|tokens| self.decode_generated_tokens(&tokens)) + .collect(), + Err(e) => { + let msg = format!("generation failed: {e}"); + (0..images.len()) + .map(|_| { + Err(OCRError::InvalidInput { + message: msg.clone(), + }) + }) + .collect() + } + } + } + + /// Generate raw baseline tokens for oracle-draft / tokenizer round-trip + /// experiments. Tokens are exactly the ids emitted by the decode loop, + /// excluding stop tokens, before tokenizer decoding or repetition + /// truncation. 
+ pub fn generate_tokens( + &self, + images: &[RgbImage], + instructions: &[impl AsRef], + max_new_tokens: usize, + ) -> Vec, OCRError>> { + if images.is_empty() { + return Vec::new(); + } + if images.len() != instructions.len() { + return vec![Err(OCRError::InvalidInput { + message: format!( + "GLM-OCR: images count ({}) != instructions count ({})", + images.len(), + instructions.len() + ), + })]; + } + + match self.generate_tokens_internal(images, instructions, max_new_tokens) { Ok(results) => results.into_iter().map(Ok).collect(), Err(e) => { let msg = format!("generation failed: {e}"); @@ -137,12 +201,12 @@ impl GlmOcr { } } - fn generate_internal( + fn generate_tokens_internal( &self, images: &[RgbImage], instructions: &[impl AsRef], max_new_tokens: usize, - ) -> Result, OCRError> { + ) -> Result>, OCRError> { let mut results = Vec::with_capacity(images.len()); for (image, instruction) in images.iter().zip(instructions.iter()) { @@ -243,19 +307,41 @@ impl GlmOcr { logits = self.logits_from_hidden(&last)?; } - let decoded = - self.tokenizer - .decode(&generated, true) - .map_err(|e| OCRError::InvalidInput { - message: format!("GLM-OCR: tokenizer decode failed: {e}"), - })?; - let decoded = truncate_repetitive_content(&decoded, 10, 10, 10); - results.push(decoded.trim().to_string()); + results.push(generated); } Ok(results) } + pub fn decode_tokens(&self, tokens: &[u32]) -> Result { + self.decode_generated_tokens(tokens) + } + + /// Decode tokens **without** applying GLM-OCR's repetition-collapse + /// post-process. Use this when feeding GLM-OCR output as a draft to + /// another target VLM — DSV matches at token granularity, and any + /// repetition collapse on the source side will byte-mismatch the target's + /// natural output, destroying acceptance length. 
+ pub fn decode_tokens_raw(&self, tokens: &[u32]) -> Result { + let decoded = self + .tokenizer + .decode(tokens, true) + .map_err(|e| OCRError::InvalidInput { + message: format!("GLM-OCR: tokenizer decode failed: {e}"), + })?; + Ok(decoded.trim().to_string()) + } + + pub fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } + + fn decode_generated_tokens(&self, tokens: &[u32]) -> Result { + let raw = self.decode_tokens_raw(tokens)?; + let truncated = truncate_repetitive_content(&raw, 10, 10, 10); + Ok(truncated.trim().to_string()) + } + fn prepare_inputs( &self, input_ids: &[u32], @@ -388,6 +474,332 @@ impl GlmOcr { Ok(embeds) } + /// Hierarchical Speculative Decoding entry for a single image / region. + /// + /// Use `generate_hsd_full` for the two-stage document flow. + #[cfg(feature = "hsd")] + pub fn generate_hsd( + &self, + image: &RgbImage, + instruction: &str, + drafts: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let t_drafter = Instant::now(); + let tokenized = self.tokenize_drafts(drafts)?; + self.generate_hsd_tokenized( + image, + instruction, + &tokenized, + hsd_cfg, + hsd_cfg.max_region_tokens, + t_drafter.elapsed(), + ) + } + + /// HSD entry that consumes already-tokenized drafts. This is the oracle + /// path used by benchmarks to avoid `decode -> encode` tokenizer + /// round-trips when the draft comes from this backend's own baseline. 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_with_token_drafts( + &self, + image: &RgbImage, + instruction: &str, + drafts: &[Draft], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + self.generate_hsd_tokenized( + image, + instruction, + drafts, + hsd_cfg, + hsd_cfg.max_region_tokens, + Duration::ZERO, + ) + } + + #[cfg(feature = "hsd")] + fn tokenize_drafts(&self, drafts: &[String]) -> Result, OCRError> { + let mut tokenized: Vec = Vec::with_capacity(drafts.len()); + for d in drafts { + if d.trim().is_empty() { + continue; + } + let enc = + self.tokenizer + .encode(d.as_str(), false) + .map_err(|e| OCRError::InvalidInput { + message: format!("GLM-OCR HSD: tokenizer encode failed: {e}"), + })?; + let tokens = enc.get_ids().to_vec(); + if !tokens.is_empty() { + tokenized.push(Draft::new(tokens)); + } + } + Ok(tokenized) + } + + #[cfg(feature = "hsd")] + fn generate_hsd_tokenized( + &self, + image: &RgbImage, + instruction: &str, + tokenized: &[Draft], + hsd_cfg: &HsdConfig, + max_new_tokens: usize, + drafter_elapsed: Duration, + ) -> Result<(String, HsdStats), OCRError> { + if !self.device.is_cuda() { + return Err(OCRError::ConfigError { + message: "HSD requires CUDA device".to_string(), + }); + } + + let mut stats = HsdStats { + drafter: drafter_elapsed, + ..Default::default() + }; + let t_pre = Instant::now(); + let (initial_lp, rope_delta) = self.hsd_prefill_single(image, instruction)?; + stats.stage2.vision_prefill = t_pre.elapsed(); + stats.stage2.forward_passes = 1; + + let t_dec = Instant::now(); + let mut backend = GlmOcrSpecBackend::new(self, rope_delta); + let mut accept = AcceptStats::default(); + let mut dsv = Default::default(); + let generated = spec_decode( + &mut backend, + tokenized, + initial_lp, + max_new_tokens, + &hsd_cfg.dsv, + &mut accept, + &mut dsv, + ) + .map_err(|e| candle_to_ocr_inference("GLM-OCR", "spec_decode", e))?; + stats.stage2.decode = t_dec.elapsed(); + stats.stage2.emitted_tokens = generated.len() as 
u32; + stats.stage2.accept = accept; + stats.stage2.dsv = dsv; + stats.stage2.forward_passes += backend.forward_passes; + + // Strip the first stop token and anything after it before decoding. + let stop_pos = generated + .iter() + .position(|t| self.eos_token_ids.contains(t)) + .unwrap_or(generated.len()); + let trimmed = &generated[..stop_pos]; + + let decoded = self + .tokenizer + .decode(trimmed, true) + .map_err(|e| OCRError::InvalidInput { + message: format!("GLM-OCR HSD: tokenizer decode failed: {e}"), + })?; + let decoded = truncate_repetitive_content(&decoded, 10, 10, 10); + Ok((decoded.trim().to_string(), stats)) + } + + /// Run the full two-stage HSD: Stage 1 verifies each layout-detected + /// region against the layout drafter's text, then Stage 2 (gated by + /// `hsd_cfg.enable_stage2`) verifies the Stage-1-aggregated markdown on + /// the full image with `hsd_cfg.max_page_tokens` budget. + /// + /// - `enable_stage1 = false`: skip per-region verification; build the + /// Stage 2 draft set directly from the layout drafter's per-element + /// markdowns (`region_markdowns`). Mirrors the paper's Table 8 + /// "Page-level Spec. Decoding only" ablation. + /// - `enable_stage2 = false`: return the Stage-1-only aggregation (lossy + /// ablation matching paper Table 8). + /// + /// `region_instruction` is used only for Stage 1 crop verification; + /// `page_instruction` is used for Stage 2 full-page verification. 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_full( + &self, + image: &RgbImage, + elements: &[LayoutElement], + ignore_labels: &[String], + page_instruction: &str, + region_instruction: &str, + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let mut stats = HsdStats::default(); + let mut region_md: Vec<(usize, String)> = Vec::with_capacity(elements.len()); + + if hsd_cfg.enable_stage1 { + for (idx, elem) in elements.iter().enumerate() { + if let Some(label) = &elem.label + && ignore_labels.iter().any(|l| l == label) + { + continue; + } + if matches!( + elem.element_type, + LayoutElementType::Image + | LayoutElementType::HeaderImage + | LayoutElementType::FooterImage + | LayoutElementType::Seal + ) { + continue; + } + let draft = region_markdown_for(elem, TargetDraftAdapter::GlmOcr); + if draft.trim().is_empty() { + continue; + } + + let bbox = bbox_xyxy(&elem.bbox); + let crop = crop_region_image(image, &bbox)?; + let drafts = vec![draft]; + let (region_text, region_stats) = + self.generate_hsd(&crop, region_instruction, &drafts, hsd_cfg)?; + stats.drafter += region_stats.drafter; + + let kind = map_layout_kind(elem.element_type); + stats.stage1_regions.push(RegionStageStats { + kind, + stats: region_stats.stage2.clone(), + }); + stats.stage1.add_assign(region_stats.stage2); + let order = elem.order_index.map(|x| x as usize).unwrap_or(idx); + region_md.push((order, format_verified_region(®ion_text, kind))); + } + } + + region_md.sort_by_key(|(order, _)| *order); + let region_md: Vec = region_md + .into_iter() + .map(|(_, text)| text) + .filter(|s| !s.trim().is_empty()) + .collect(); + + // Stage 2 — page-level global verification on the full image. Per + // paper Eq. 3 the page draft is the *unordered set* `Ỹ^pg = {ŷ^(i)}`, + // one draft per region. We pass the Vec straight to `spec_decode` + // instead of pre-joining: `collect_candidates` scans each draft + // independently (Eqs. 
1+2), so per-region n-gram locality is + // preserved even when full-page transitions don't appear naturally + // in the target VLM's output. Budget = `max_page_tokens`. + if hsd_cfg.enable_stage2 { + let t_drafter = Instant::now(); + let page_drafts: Vec = if !region_md.is_empty() { + region_md.clone() + } else { + region_markdowns_for(elements, ignore_labels, TargetDraftAdapter::GlmOcr) + }; + if !page_drafts.is_empty() { + let tokenized = self.tokenize_drafts(&page_drafts)?; + let (text, s2_stats) = self.generate_hsd_tokenized( + image, + page_instruction, + &tokenized, + hsd_cfg, + hsd_cfg.max_page_tokens, + t_drafter.elapsed(), + )?; + stats.stage2 = s2_stats.stage2; + stats.drafter += s2_stats.drafter; + return Ok((text, stats)); + } + } + + // Stage 2 disabled or no draft to verify — return Stage-1-only join + // as a human-readable fallback. The `\n\n` separator here is for the + // *output* (caller-facing), not for any further HSD input. + Ok((region_md.join("\n\n"), stats)) + } + + /// One-call HSD entry that consumes a `StructureResult` (the output of + /// the OARStructure / PP-StructureV3 pipeline) directly. + /// + /// Backfills table HTML / formula LaTeX via + /// [`structure_result_to_layout_elements`] then delegates to + /// [`Self::generate_hsd_full`]. See the HunyuanOCR sibling for the + /// full design discussion — GLM-OCR keeps single-draft-per-region semantics + /// (its public Recognition prompts emit one canonical output). 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_with_structure( + &self, + image: &RgbImage, + page_instruction: &str, + region_instruction: &str, + structure: &StructureResult, + ignore_labels: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let elements = structure_result_to_layout_elements(structure); + self.generate_hsd_full( + image, + &elements, + ignore_labels, + page_instruction, + region_instruction, + hsd_cfg, + ) + } + + /// Run a single-image prefill with the supplied instruction. Returns + /// the F32 last-position log-probabilities and the MRoPE delta. + #[cfg(feature = "hsd")] + fn hsd_prefill_single( + &self, + image: &RgbImage, + instruction: &str, + ) -> Result<(Tensor, i64), OCRError> { + let image_inputs = preprocess_image( + image, + &self.image_cfg, + &self.cfg.vision_config, + &self.device, + self.dtype, + )?; + + let prompt = build_prompt(instruction); + let prompt = expand_image_tokens(&prompt, image_inputs.num_image_tokens)?; + let enc = self + .tokenizer + .encode(prompt, false) + .map_err(|e| OCRError::InvalidInput { + message: format!("GLM-OCR HSD: tokenizer encode failed: {e}"), + })?; + let input_ids = enc.get_ids().to_vec(); + if input_ids.is_empty() { + return Err(OCRError::InvalidInput { + message: "GLM-OCR HSD: empty prompt after tokenization".to_string(), + }); + } + let seq_len = input_ids.len(); + + let inputs_embeds = self.prepare_inputs(&input_ids, &image_inputs)?; + let (position_ids, max_pos) = build_position_ids( + &input_ids, + image_inputs.grid_thw, + self.cfg.vision_config.spatial_merge_size, + self.image_token_id, + &self.device, + )?; + let rope_delta = max_pos + 1 - seq_len as i64; + + // GLM-OCR's existing prefill calls `text.forward(..., None)` — the + // implementation builds its own internal causal mask via the rotary + // path. Match that behaviour here. 
+ self.text.clear_kv_cache(); + let hidden = self.text.forward(&inputs_embeds, &position_ids, None)?; + let last = hidden + .i((0, seq_len - 1, ..)) + .map_err(|e| candle_to_ocr_inference("GLM-OCR", "get last hidden", e))?; + let logits = self.logits_from_hidden(&last)?; + let lp = cnn_ops::log_softmax( + &logits + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("GLM-OCR", "logits to f32", e))?, + D::Minus1, + ) + .map_err(|e| candle_to_ocr_inference("GLM-OCR", "log_softmax prefill", e))?; + Ok((lp, rope_delta)) + } + fn logits_from_hidden(&self, hidden: &Tensor) -> Result { let hidden = hidden.unsqueeze(0).map_err(|e| { candle_to_ocr_processing( @@ -410,6 +822,110 @@ impl GlmOcr { } } +/// HSD adapter for GLM-OCR. 3-axis MRoPE, independent lm_head, rope_delta +/// captured at prefill (same shape as MinerU / PaddleOCR-VL). +#[cfg(feature = "hsd")] +struct GlmOcrSpecBackend<'a> { + model: &'a GlmOcr, + rope_delta: i64, + pre_verify_kv: usize, + forward_passes: u32, +} + +#[cfg(feature = "hsd")] +impl<'a> GlmOcrSpecBackend<'a> { + fn new(model: &'a GlmOcr, rope_delta: i64) -> Self { + Self { + model, + rope_delta, + pre_verify_kv: 0, + forward_passes: 0, + } + } + + fn project_logprobs_2d(&self, hidden_2d: &Tensor) -> CandleResult { + // (N, hidden) → (N, vocab) → log-softmax F32. + // GLM-OCR's lm_head expects shape (..., hidden); 2D works directly. + let logits = self.model.lm_head.forward(hidden_2d)?; + cnn_ops::log_softmax(&logits.to_dtype(DType::F32)?, D::Minus1) + } + + fn project_logprobs_1d(&self, hidden_1d: &Tensor) -> CandleResult { + // lm_head is a Linear that requires ≥ 2-D input. + let logits = self + .model + .lm_head + .forward(&hidden_1d.unsqueeze(0)?)? 
+ .squeeze(0)?; + cnn_ops::log_softmax(&logits.to_dtype(DType::F32)?, D::Minus1) + } +} + +#[cfg(feature = "hsd")] +impl<'a> SpecBackend for GlmOcrSpecBackend<'a> { + fn step_one(&mut self, token: u32) -> CandleResult { + let model = self.model; + let device = &model.device; + + let tok_t = Tensor::new(vec![token], device)?.reshape((1usize, 1usize))?; + let embeds = model + .text + .embed(&tok_t) + .map_err(|e| candle_core::Error::Msg(format!("GLM-OCR HSD step_one embed: {e}")))?; + + let pos_ids = step_pos_ids(3, model.text.current_kv_len(), self.rope_delta, device)?; + + let hidden = model + .text + .forward(&embeds, &pos_ids, None) + .map_err(|e| candle_core::Error::Msg(format!("GLM-OCR HSD step_one forward: {e}")))?; + self.forward_passes += 1; + let last = hidden.i((0, 0, ..))?; + self.project_logprobs_1d(&last) + } + + fn verify_tree(&mut self, tree: &PrefixTree) -> CandleResult { + let n = tree.num_nodes(); + let model = self.model; + let device = &model.device; + let dtype = model.dtype; + + let prefix_kv = model.text.current_kv_len(); + self.pre_verify_kv = prefix_kv; + + let tok_t = Tensor::new(tree.tokens.clone(), device)?.reshape((1usize, n))?; + let embeds = model + .text + .embed(&tok_t) + .map_err(|e| candle_core::Error::Msg(format!("GLM-OCR HSD verify_tree embed: {e}")))?; + + let pos_ids = tree_pos_ids(3, prefix_kv, self.rope_delta, tree, device)?; + let mask = create_tree_attention_mask(&tree.parents, prefix_kv, dtype, device)?; + + let hidden = model + .text + .forward(&embeds, &pos_ids, Some(&mask)) + .map_err(|e| { + candle_core::Error::Msg(format!("GLM-OCR HSD verify_tree forward: {e}")) + })?; + self.forward_passes += 1; + let h2 = hidden.squeeze(0)?; + self.project_logprobs_2d(&h2) + } + + fn commit_verify(&mut self, accepted_path: &[usize]) -> CandleResult<()> { + let indices = commit_keep_indices(self.pre_verify_kv, accepted_path); + self.model + .text + .keep_kv_indices(&indices) + .map_err(|e| candle_core::Error::Msg(format!("GLM-OCR HSD 
commit_verify: {e}"))) + } + + fn is_eos(&self, tok: u32) -> bool { + self.model.eos_token_ids.contains(&tok) + } +} + fn build_prompt(instruction: &str) -> String { format!( "[gMASK]<|user|>\n<|begin_of_image|><|image|><|end_of_image|>{instruction}<|assistant|>\n" diff --git a/oar-ocr-vl/src/glmocr/text.rs b/oar-ocr-vl/src/glmocr/text.rs index ebc228f..5d90f9c 100644 --- a/oar-ocr-vl/src/glmocr/text.rs +++ b/oar-ocr-vl/src/glmocr/text.rs @@ -1,10 +1,13 @@ use super::config::GlmOcrTextConfig; use crate::attention::{repeat_kv, scaled_dot_product_attention}; +#[cfg(feature = "hsd")] +use crate::hsd::TrimmableKvCache; +#[cfg(not(feature = "hsd"))] +use crate::kv_trim::TrimmableKvCache; use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing}; use candle_core::{D, DType, Device, IndexOp, Tensor}; use candle_nn::{ - Embedding, Linear, Module, RmsNorm, VarBuilder, embedding, kv_cache::KvCache, linear_no_bias, - rms_norm, + Embedding, Linear, Module, RmsNorm, VarBuilder, embedding, linear_no_bias, rms_norm, }; use oar_ocr_core::core::OCRError; use std::cell::RefCell; @@ -511,7 +514,7 @@ struct GlmOcrTextAttention { num_kv_groups: usize, head_dim: usize, scaling: f64, - kv_cache: RefCell, + kv_cache: RefCell, } impl GlmOcrTextAttention { @@ -554,7 +557,8 @@ impl GlmOcrTextAttention { .map_err(|e| candle_to_ocr_inference("GLM-OCR", "text o_proj", e))?; let cache_cap = cfg.max_position_embeddings.min(16384); - let kv_cache = KvCache::new(2, cache_cap); + // Trim/gather-capable KV cache (HSD verification path). 
+ let kv_cache = TrimmableKvCache::new(2, cache_cap); Ok(Self { q_proj, @@ -741,6 +745,19 @@ impl GlmOcrTextAttention { fn clear_kv_cache(&self) { self.kv_cache.borrow_mut().reset(); } + + #[cfg(feature = "hsd")] + fn current_kv_len(&self) -> usize { + self.kv_cache.borrow().current_seq_len() + } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.kv_cache + .borrow_mut() + .keep_indices(indices) + .map_err(|e| candle_to_ocr_inference("GLM-OCR", "keep_kv_indices", e)) + } } #[derive(Debug, Clone)] @@ -837,6 +854,11 @@ impl GlmOcrTextDecoderLayer { fn clear_kv_cache(&self) { self.self_attn.clear_kv_cache(); } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.self_attn.keep_kv_indices(indices) + } } #[derive(Debug, Clone)] @@ -900,4 +922,23 @@ impl GlmOcrTextModel { layer.clear_kv_cache(); } } + + /// Current sequence length held in the KV cache (read from layer 0). + #[cfg(feature = "hsd")] + pub fn current_kv_len(&self) -> usize { + self.layers + .first() + .map(|l| l.self_attn.current_kv_len()) + .unwrap_or(0) + } + + /// Gather every layer's KV cache to keep only the supplied positions + /// (in order). Used by HSD after tree-attention verification. + #[cfg(feature = "hsd")] + pub fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + for layer in &self.layers { + layer.keep_kv_indices(indices)?; + } + Ok(()) + } } diff --git a/oar-ocr-vl/src/hsd/backend_util.rs b/oar-ocr-vl/src/hsd/backend_util.rs new file mode 100644 index 0000000..84b27df --- /dev/null +++ b/oar-ocr-vl/src/hsd/backend_util.rs @@ -0,0 +1,128 @@ +//! Mechanical helpers shared by every `SpecBackend` implementation. +//! +//! Each VLM backbone (GLM-OCR, MinerU, PaddleOCR-VL, HunyuanOCR) implements +//! [`super::verify::SpecBackend`] over its own text model and uses its own +//! lm_head / rep-penalty conventions. Those parts are intentionally per-backend. +//! 
What is *not* per-backend is the index/position-id arithmetic — that lives +//! here so the backends only have to spell out the truly model-specific work. + +use candle_core::{Device, Result as CandleResult, Tensor}; + +use super::prefix_tree::PrefixTree; + +/// Build the 3D position-id tensor for a single-token decode step. +/// +/// Returns a tensor of shape `(axes, 1, 1)` filled with `kv_len + rope_delta`. +/// `axes` is the number of MRoPE axes the backbone uses (3 for GLM-OCR / +/// MinerU / PaddleOCR-VL, 4 for HunyuanOCR). +pub fn step_pos_ids( + axes: usize, + kv_len: usize, + rope_delta: i64, + device: &Device, +) -> CandleResult { + let pos = kv_len as i64 + rope_delta; + let data = vec![pos; axes]; + Tensor::from_vec(data, (axes, 1usize, 1usize), device) +} + +/// Build the 3D position-id tensor for a tree-verification forward pass. +/// +/// Returns a tensor of shape `(axes, 1, num_nodes)` where node `i` is placed +/// at logical position `prefix_kv + rope_delta + tree.depths[i] - 1` along +/// every MRoPE axis (depth-1 = first newly-generated token). +pub fn tree_pos_ids( + axes: usize, + prefix_kv: usize, + rope_delta: i64, + tree: &PrefixTree, + device: &Device, +) -> CandleResult { + let n = tree.num_nodes(); + let mut pos_data: Vec = Vec::with_capacity(axes * n); + for _axis in 0..axes { + for d in &tree.depths { + pos_data.push(prefix_kv as i64 + rope_delta + (*d as i64) - 1); + } + } + Tensor::from_vec(pos_data, (axes, 1usize, n), device) +} + +/// Build the KV-cache `keep_indices` vector for `commit_verify`. +/// +/// The cache keeps `[0, prefix_kv)` (the accepted history) followed by the +/// path-node positions `prefix_kv + p` for each `p` in `accepted_path`. 
+pub fn commit_keep_indices(prefix_kv: usize, accepted_path: &[usize]) -> Vec { + let mut indices: Vec = Vec::with_capacity(prefix_kv + accepted_path.len()); + for i in 0..prefix_kv { + indices.push(i as u32); + } + for &p in accepted_path { + indices.push((prefix_kv + p) as u32); + } + indices +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hsd::matching::Candidate; + use crate::hsd::prefix_tree::build_prefix_tree; + + fn candidate(tokens: Vec) -> Candidate { + Candidate { + draft_idx: 0, + suffix_start: 0, + tokens, + } + } + + #[test] + fn step_pos_ids_three_axis() -> CandleResult<()> { + let t = step_pos_ids(3, 10, 4, &Device::Cpu)?; + assert_eq!(t.dims(), &[3, 1, 1]); + let v: Vec = t.flatten_all()?.to_vec1()?; + assert_eq!(v, vec![14, 14, 14]); + Ok(()) + } + + #[test] + fn step_pos_ids_four_axis_zero_delta() -> CandleResult<()> { + let t = step_pos_ids(4, 7, 0, &Device::Cpu)?; + assert_eq!(t.dims(), &[4, 1, 1]); + let v: Vec = t.flatten_all()?.to_vec1()?; + assert_eq!(v, vec![7, 7, 7, 7]); + Ok(()) + } + + #[test] + fn tree_pos_ids_shape_and_values() -> CandleResult<()> { + // Build a small tree: root → 10 → 11; root → 20. + let cands = vec![candidate(vec![10u32, 11u32]), candidate(vec![20u32])]; + let tree = build_prefix_tree(&cands); + let t = tree_pos_ids(3, 5, 2, &tree, &Device::Cpu)?; + let n = tree.num_nodes(); + assert_eq!(t.dims(), &[3, 1, n]); + let v: Vec = t.flatten_all()?.to_vec1()?; + // First axis should equal prefix_kv + rope_delta + depth - 1. 
+ for axis in 0..3 { + for (i, &d) in tree.depths.iter().enumerate() { + let expected = 5 + 2 + (d as i64) - 1; + assert_eq!(v[axis * n + i], expected); + } + } + Ok(()) + } + + #[test] + fn commit_indices_layout() { + let indices = commit_keep_indices(3, &[0, 2, 5]); + assert_eq!(indices, vec![0, 1, 2, 3, 5, 8]); + } + + #[test] + fn commit_indices_empty_path() { + let indices = commit_keep_indices(4, &[]); + assert_eq!(indices, vec![0, 1, 2, 3]); + } +} diff --git a/oar-ocr-vl/src/hsd/drafting.rs b/oar-ocr-vl/src/hsd/drafting.rs new file mode 100644 index 0000000..1e4a7f8 --- /dev/null +++ b/oar-ocr-vl/src/hsd/drafting.rs @@ -0,0 +1,1499 @@ +//! Bridge between the layout drafter pipeline and HSD's region/page drafts. +//! +//! The HSD algorithm itself ([`super::matching`], [`super::prefix_tree`], +//! [`super::verify`]) is intentionally tokenizer-agnostic; this module is the +//! one place that knows about [`LayoutElement`] and accepts an injected +//! tokenizer closure. Backends call into here to turn the drafter's +//! recognition output into the [`RegionDraft`] / [`Draft`] values consumed by +//! [`super::verify::spec_decode`]. +//! +//! ## Tokenizer requirement +//! +//! The closure passed to [`build_region_drafts`] / [`build_page_draft`] / +//! [`page_draft_from_region_outputs`] **must be the target VLM's tokenizer**. +//! HSD matches drafts against the verifier's accepted-token tail at token +//! granularity; using a different tokenizer (even one that's "close enough") +//! will quietly destroy the acceptance length. +//! +//! ### Tokenizer parity with the paper's HF Transformers stack +//! +//! Both this crate and HF Transformers ultimately call the same +//! [`tokenizers`](https://docs.rs/tokenizers) Rust crate, so there is **no +//! algorithmic divergence** between the paper's tokenization and this stack's +//! tokenization as long as the same `tokenizer.json` is loaded. All four VLM +//! HSD paths consistently use `tokenizer.encode(text, false)` (i.e. 
+//! `add_special_tokens = False`), which matches HF Transformers +//! `tokenizer(text, add_special_tokens=False)`. +//! +//! Any remaining byte-level AAL loss is therefore **adapter-level**, not +//! tokenizer-level — i.e. the drafter's serialized string doesn't exactly +//! match the target VLM's natural output convention. That is what +//! [`TargetDraftAdapter`] exists to address; new long-tail divergences +//! (heading prefix style, HTML attribute order, math-wrapper spacing) belong +//! in a new adapter branch with a unit test, not in tokenizer wrapping. + +use image::RgbImage; +use oar_ocr_core::core::OCRError; +use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, StructureResult}; +use oar_ocr_core::processors::BoundingBox; +use oar_ocr_core::utils::BBoxCrop; + +use super::types::{Draft, RegionDraft, RegionKind}; +use crate::utils::table::{convert_html_to_otsl, convert_otsl_to_html, looks_like_table_tokens}; +use crate::utils::to_markdown; + +/// Target-side text surface used to serialize drafter regions before tokenizing +/// them for DSV. Keeping this explicit prevents benchmark and backend code from +/// silently feeding one model another model's natural output convention. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TargetDraftAdapter { + /// Generic markdown conversion used by the original structure pipeline. + Markdown, + /// HunyuanOCR parsing style: headings are plain text, isolated formulas are + /// inline `$$ ... $$`, and page numbers keep the separator HunyuanOCR tends to emit. + HunyuanOcr, + /// Plain region text with no markdown shell. This matches the current + /// GLM-OCR / MinerU page benchmark prompt better than HunyuanOCR markdown. + PlainText, + /// PaddleOCR-VL element-level raw output form (pre-postprocess): + /// - Formula: `$$ ... $$` wrapped (post-process strips the wrapper). 
+ /// - Table: OTSL tokens pass through; HTML is converted to OTSL when the + /// table parser recognizes the structure, otherwise it passes through. + /// - Headings / text / list / page-number: plain text, no `#` shell. + PaddleOcrVl, + /// GLM-OCR Recognition-prompt style: model emits LaTeX without `$$` + /// wrappers, HTML tables pass through, headings are plain text. JSON + /// outputs (information extraction prompts) are not handled here — they + /// require a separate JSON-shaped draft path. + GlmOcr, + /// MinerU2.5 two-step-extract per-element style: + /// - Formula: `$$\n ... \n$$` block form. + /// - Table: HTML pass through. + /// - Headings: plain text (no `#` — the two-step extractor emits the + /// layout label separately and per-element text is bare). + /// - Page numbers: plain text without the HunyuanOCR-style `---` separator. + MinerU, +} + +/// Map a layout element type to the coarse HSD region kind. +pub fn map_layout_kind(t: LayoutElementType) -> RegionKind { + use LayoutElementType::*; + match t { + DocTitle + | ParagraphTitle + | FigureTitle + | TableTitle + | ChartTitle + | FigureTableChartTitle => RegionKind::Title, + Text | Content | Abstract | AsideText | Reference | ReferenceContent | Footnote + | Number => RegionKind::Text, + List => RegionKind::List, + Table => RegionKind::Table, + Formula | FormulaNumber => RegionKind::Formula, + Image | Chart | Seal | HeaderImage | FooterImage => RegionKind::Figure, + Header => RegionKind::Header, + Footer => RegionKind::Footer, + Algorithm | Region | Other => RegionKind::Other, + } +} + +/// Extract axis-aligned `[x_min, y_min, x_max, y_max]` from a bounding box. +pub fn bbox_xyxy(bbox: &BoundingBox) -> [f32; 4] { + [bbox.x_min(), bbox.y_min(), bbox.x_max(), bbox.y_max()] +} + +/// Crop an image to an HSD `[x_min, y_min, x_max, y_max]` bounding box. 
+pub fn crop_region_image(image: &RgbImage, bbox: &[f32; 4]) -> Result { + let bb = BoundingBox::from_coords(bbox[0], bbox[1], bbox[2], bbox[3]); + BBoxCrop::crop_bounding_box(image, &bb) +} + +/// Serialize a single layout element to the markdown the target VLM would +/// emit for that region in isolation. Falls back to plain text on unknown +/// element types. +pub fn region_markdown(elem: &LayoutElement) -> String { + region_markdown_for(elem, TargetDraftAdapter::Markdown) +} + +fn raw_region_text(elem: &LayoutElement) -> Option<&str> { + elem.text + .as_deref() + .map(str::trim) + .filter(|s| !s.is_empty()) +} + +fn is_visual_only(elem: &LayoutElement) -> bool { + matches!( + elem.element_type, + LayoutElementType::Image + | LayoutElementType::HeaderImage + | LayoutElementType::FooterImage + | LayoutElementType::Seal + ) +} + +fn align_hunyuanocr_heading(text: &str) -> String { + if let Some(rest) = text.strip_prefix("SEC.") + && rest.chars().next().is_some_and(|ch| ch.is_ascii_digit()) + { + return format!("SEC. {rest}"); + } + text.to_string() +} + +/// Normalize a table draft into the HTML form HunyuanOCR / GLM-OCR / MinerU +/// naturally emit. OTSL token streams (` ... `) are converted via +/// `convert_otsl_to_html`; HTML pass-through is preserved; other text is +/// returned unchanged so the per-target adapter caller can decide. 
+fn table_text_as_html(text: &str) -> String { + if looks_like_table_tokens(text) { + convert_otsl_to_html(text) + } else { + text.to_string() + } +} + +fn hunyuanocr_region_markdown(elem: &LayoutElement) -> String { + if is_visual_only(elem) { + return String::new(); + } + let Some(text) = raw_region_text(elem) else { + return String::new(); + }; + match elem.element_type { + LayoutElementType::Header + | LayoutElementType::DocTitle + | LayoutElementType::ParagraphTitle => align_hunyuanocr_heading(text), + LayoutElementType::Formula | LayoutElementType::FormulaNumber => { + if text.starts_with("$$") || text.starts_with("\\[") { + text.to_string() + } else { + format!("$$ {text} $$") + } + } + LayoutElementType::Number => format!("{text}\n\n---"), + // HunyuanOCR emits HTML tables; auto-convert OTSL drafts so a + // PaddleOCR-VL drafter's OTSL output can be matched. + LayoutElementType::Table => table_text_as_html(text), + _ => text.to_string(), + } +} + +fn plain_region_markdown(elem: &LayoutElement) -> String { + if is_visual_only(elem) { + return String::new(); + } + raw_region_text(elem).unwrap_or("").to_string() +} + +fn paddleocr_vl_region_markdown(elem: &LayoutElement) -> String { + if is_visual_only(elem) { + return String::new(); + } + let Some(text) = raw_region_text(elem) else { + return String::new(); + }; + match elem.element_type { + LayoutElementType::Formula | LayoutElementType::FormulaNumber => { + if text.starts_with("$$") || text.starts_with("\\[") || text.starts_with('$') { + text.to_string() + } else { + format!("$${text}$$") + } + } + // PaddleOCR-VL emits raw OTSL tokens for tables. If the drafter + // supplied OTSL, pass it through. If the drafter supplied HTML, + // attempt the inverse `convert_html_to_otsl` so the draft matches + // what PaddleOCR-VL actually emits before its post-process step. + // Fall back to the original text only when the input is neither + // recognizable form. 
+ LayoutElementType::Table => { + if looks_like_table_tokens(text) { + text.to_string() + } else if text.contains(" text.to_string(), + } +} + +fn glmocr_region_markdown(elem: &LayoutElement) -> String { + if is_visual_only(elem) { + return String::new(); + } + let Some(text) = raw_region_text(elem) else { + return String::new(); + }; + match elem.element_type { + LayoutElementType::Formula | LayoutElementType::FormulaNumber => { + // GLM-OCR's "Formula Recognition:" prompt emits bare LaTeX + // without `$$` delimiters. Strip them if the drafter wrapped. + strip_math_wrappers_str(text).to_string() + } + // GLM-OCR's "Table Recognition:" prompt emits HTML — convert OTSL + // drafts to HTML so source-PaddleOCR-VL → target-GLM-OCR matches. + LayoutElementType::Table => table_text_as_html(text), + // Headings / page numbers / lists: plain text. GLM-OCR's + // Recognition prompts never emit `#` headings or the `---` separator. + _ => text.to_string(), + } +} + +fn mineru_region_markdown(elem: &LayoutElement) -> String { + if is_visual_only(elem) { + return String::new(); + } + let Some(text) = raw_region_text(elem) else { + return String::new(); + }; + match elem.element_type { + LayoutElementType::Formula | LayoutElementType::FormulaNumber => { + // MinerU2.5 two_step_extract emits display formulas in block form + // `$$\n ... \n$$`. Strip any existing wrapper before re-wrapping + // so callers can pass either form. + let core = strip_math_wrappers_str(text); + format!("$$\n{core}\n$$") + } + // MinerU's per-element table prompt emits HTML; convert OTSL drafts + // (e.g. from a PaddleOCR-VL drafter) so the byte form matches. + LayoutElementType::Table => table_text_as_html(text), + // Other elements stay plain — MinerU's per-element step emits + // content without surrounding markdown shell; the layout step adds + // heading levels separately. + _ => text.to_string(), + } +} + +/// Strip `$$ ... $$` or `$ ... $` wrappers from a math snippet. 
Mirrors +/// `oar_ocr_vl::utils::text::strip_math_wrappers` but kept here to avoid a +/// crate-internal dependency cycle when running the HSD unit tests in +/// isolation. +fn strip_math_wrappers_str(input: &str) -> &str { + let mut trimmed = input.trim(); + trimmed = trimmed + .strip_prefix("$$") + .and_then(|s| s.strip_suffix("$$")) + .unwrap_or(trimmed); + trimmed = trimmed + .strip_prefix('$') + .and_then(|s| s.strip_suffix('$')) + .unwrap_or(trimmed); + trimmed.trim() +} + +/// Serialize a single layout element using the target VLM's draft adapter. +pub fn region_markdown_for(elem: &LayoutElement, adapter: TargetDraftAdapter) -> String { + match adapter { + TargetDraftAdapter::Markdown => { + // Reuse the existing converter on a singleton slice; it trims + // trailing `\n\n` so the output is suitable for direct tokenization. + to_markdown(std::slice::from_ref(elem), &[]) + } + TargetDraftAdapter::HunyuanOcr => hunyuanocr_region_markdown(elem), + TargetDraftAdapter::PlainText => plain_region_markdown(elem), + TargetDraftAdapter::PaddleOcrVl => paddleocr_vl_region_markdown(elem), + TargetDraftAdapter::GlmOcr => glmocr_region_markdown(elem), + TargetDraftAdapter::MinerU => mineru_region_markdown(elem), + } +} + +/// Cross-VLM un-postprocess + re-adapt for a single piece of raw region text. +/// +/// DSV matches drafts against the target VLM's natural output at token +/// granularity, so when a different VLM (source) supplies the draft we must: +/// +/// 1. Obtain the source's **pre-postprocess** decoded string (use the +/// source backend's `decode_tokens_raw` rather than `decode_tokens`). +/// 2. Re-shape that raw string into the target's natural surface via the +/// target's [`TargetDraftAdapter`]. +/// +/// Step 2 is what this helper does. Step 1 must be performed by the caller +/// using the source backend's API — the source's post-process is per-backend +/// and not encoded here. 
+/// +/// ## Table surface handling +/// +/// PaddleOCR-VL emits tables as raw OTSL tokens; HunyuanOCR / GLM-OCR / MinerU +/// emit tables as HTML. The adapters now bridge the two with +/// [`crate::utils::table::convert_otsl_to_html`] / +/// [`crate::utils::table::convert_html_to_otsl`] so cross-VLM table drafts +/// land on the target's natural surface (PaddleOCR-VL -> HTML-emitting backends +/// and vice versa). Cells with structure that neither parser recognizes still +/// fall back to pass-through. +/// +/// `element_type` and `bbox` are used to drive the adapter's per-kind +/// dispatch; only `text` is actually consumed beyond that. +pub fn convert_raw_to_target_adapter( + raw_text: &str, + element_type: LayoutElementType, + target_adapter: TargetDraftAdapter, +) -> String { + let stub_bbox = BoundingBox::from_coords(0.0, 0.0, 1.0, 1.0); + let mut elem = LayoutElement::new(stub_bbox, element_type, 1.0); + elem.text = Some(raw_text.to_string()); + region_markdown_for(&elem, target_adapter) +} + +/// Wrap verified region text in the coarse markdown shell used when Stage-1 +/// outputs are reassembled into a Stage-2 page draft. +pub fn format_verified_region(text: &str, kind: RegionKind) -> String { + let trimmed = text.trim(); + if trimmed.is_empty() { + return String::new(); + } + match kind { + RegionKind::Title => format!("# {trimmed}"), + RegionKind::Formula => { + if trimmed.starts_with("$$") { + trimmed.to_string() + } else { + format!("$$\n{trimmed}\n$$") + } + } + _ => trimmed.to_string(), + } +} + +/// Aggregate elements into the full-page markdown draft (reading order +/// already applied by the drafter pipeline). +/// +/// Use [`region_markdowns`] instead when feeding Stage 2: the paper's Eq. 3 +/// formulates `Ỹ^pg` as an unordered *set* of per-region drafts, not a single +/// concatenated string. Concatenating breaks the sliding-window matcher when +/// inter-region transitions don't appear in the target VLM's natural output. 
+pub fn page_markdown(elements: &[LayoutElement], ignore_labels: &[String]) -> String { + to_markdown(elements, ignore_labels) +} + +/// Aggregate target-adapted region drafts into a human-readable page draft. +/// DSV Stage 2 should still prefer [`region_markdowns_for`] so each region +/// remains an independent draft candidate. +pub fn page_markdown_for( + elements: &[LayoutElement], + ignore_labels: &[String], + adapter: TargetDraftAdapter, +) -> String { + region_markdowns_for(elements, ignore_labels, adapter).join("\n\n") +} + +/// One markdown draft per layout element, in input (reading) order. +/// +/// This is the Stage-2 draft set when Stage 1 is disabled and the layout +/// drafter pipeline is the sole draft source (paper §3.1 Eq. 3 with the +/// pipeline outputs in place of `ŷ^(i)`). Each element's markdown becomes a +/// separate entry in `Ỹ^pg`, which the DSV matcher scans independently — +/// this preserves per-region n-gram locality even when the target VLM's +/// full-page output format differs significantly from the drafter's +/// markdown style. +/// +/// Elements are skipped when their `label` matches one of `ignore_labels`, +/// when their serialized markdown is empty / whitespace-only, or when they +/// represent purely visual regions (`Image` / `HeaderImage` / `FooterImage` +/// / `Seal`) that have no recognized text. +pub fn region_markdowns(elements: &[LayoutElement], ignore_labels: &[String]) -> Vec { + region_markdowns_for(elements, ignore_labels, TargetDraftAdapter::Markdown) +} + +/// One target-adapted draft per layout element, in input (reading) order. 
+pub fn region_markdowns_for( + elements: &[LayoutElement], + ignore_labels: &[String], + adapter: TargetDraftAdapter, +) -> Vec { + let mut out: Vec = Vec::with_capacity(elements.len()); + for elem in elements { + if let Some(label) = &elem.label + && ignore_labels.iter().any(|l| l == label) + { + continue; + } + let md = region_markdown_for(elem, adapter); + let trimmed = md.trim(); + if !trimmed.is_empty() { + out.push(trimmed.to_string()); + } + } + out +} + +/// Serialize multiple text candidates per layout element through a target +/// adapter. The returned Vec is flat because Stage 2 consumes an unordered +/// draft set; Stage 1 grouping is handled by +/// [`build_region_draft_candidates_with_adapter`]. +pub fn region_markdown_candidates_for( + elements: &[LayoutElement], + ignore_labels: &[String], + adapter: TargetDraftAdapter, + text_candidates: C, +) -> Vec +where + C: Fn(&LayoutElement) -> Vec, +{ + let mut out: Vec = Vec::new(); + for elem in elements { + if let Some(label) = &elem.label + && ignore_labels.iter().any(|l| l == label) + { + continue; + } + for candidate in text_candidates(elem) { + let mut candidate_elem = elem.clone(); + candidate_elem.text = Some(candidate); + let md = region_markdown_for(&candidate_elem, adapter); + let trimmed = md.trim(); + if trimmed.is_empty() || out.iter().any(|prev| prev == trimmed) { + continue; + } + out.push(trimmed.to_string()); + } + } + out +} + +/// Build per-region drafts using the supplied target-VLM tokenizer. +/// +/// Regions are skipped when: +/// - their `label` matches one of `ignore_labels`, +/// - their text is empty / whitespace-only after markdown serialization, or +/// - the tokenizer yields zero tokens. 
+pub fn build_region_drafts<F>(
+    elements: &[LayoutElement],
+    ignore_labels: &[String],
+    tokenize: F,
+) -> Vec<RegionDraft>
+where
+    F: Fn(&str) -> Vec<u32>,
+{
+    build_region_drafts_with_adapter(
+        elements,
+        ignore_labels,
+        TargetDraftAdapter::Markdown,
+        tokenize,
+    )
+}
+
+/// Build per-region drafts using the supplied target-VLM tokenizer and text adapter.
+pub fn build_region_drafts_with_adapter<F>(
+    elements: &[LayoutElement],
+    ignore_labels: &[String],
+    adapter: TargetDraftAdapter,
+    tokenize: F,
+) -> Vec<RegionDraft>
+where
+    F: Fn(&str) -> Vec<u32>,
+{
+    build_region_draft_candidates_with_adapter(
+        elements,
+        ignore_labels,
+        adapter,
+        |elem| elem.text.iter().cloned().collect(),
+        tokenize,
+    )
+}
+
+/// Build per-region drafts from multiple raw text candidates per layout element.
+///
+/// This is the Stage-1 multi-draft plumbing point: OCR top-k candidates or
+/// outputs from multiple drafters can be supplied through `text_candidates`.
+/// Each candidate is serialized through the target adapter, tokenized, and
+/// deduplicated before being packed into one [`RegionDraft`].
+pub fn build_region_draft_candidates_with_adapter<C, T>(
+    elements: &[LayoutElement],
+    ignore_labels: &[String],
+    adapter: TargetDraftAdapter,
+    text_candidates: C,
+    tokenize: T,
+) -> Vec<RegionDraft>
+where
+    C: Fn(&LayoutElement) -> Vec<String>,
+    T: Fn(&str) -> Vec<u32>,
+{
+    let mut out: Vec<RegionDraft> = Vec::with_capacity(elements.len());
+    for elem in elements {
+        if let Some(label) = &elem.label
+            && ignore_labels.iter().any(|l| l == label)
+        {
+            continue;
+        }
+        let mut drafts: Vec<Draft> = Vec::new();
+        for candidate in text_candidates(elem) {
+            let mut candidate_elem = elem.clone();
+            candidate_elem.text = Some(candidate);
+            let md = region_markdown_for(&candidate_elem, adapter);
+            if md.trim().is_empty() {
+                continue;
+            }
+            let toks = tokenize(&md);
+            if toks.is_empty() || drafts.iter().any(|draft| draft.tokens == toks) {
+                continue;
+            }
+            drafts.push(Draft::new(toks));
+        }
+        if drafts.is_empty() {
+            continue;
+        }
+        out.push(RegionDraft {
+            bbox: bbox_xyxy(&elem.bbox),
+            drafts,
+            reading_order: elem.order_index.map(|x| x as usize),
+            kind: map_layout_kind(elem.element_type),
+        });
+    }
+    out
+}
+
+/// Build the Stage-2 page-level draft by tokenizing the full-page markdown
+/// (as derived by [`page_markdown`]).
+pub fn build_page_draft<F>(
+    elements: &[LayoutElement],
+    ignore_labels: &[String],
+    tokenize: F,
+) -> Draft
+where
+    F: Fn(&str) -> Vec<u32>,
+{
+    let md = page_markdown(elements, ignore_labels);
+    Draft::new(tokenize(&md))
+}
+
+/// Build the Stage-2 page-level draft by joining already-verified Stage-1
+/// region outputs in reading order. Regions are separated with a blank line so
+/// they look like distinct paragraphs / blocks to the target VLM.
+pub fn page_draft_from_region_outputs<F>(region_outputs_in_order: &[String], tokenize: F) -> Draft
+where
+    F: Fn(&str) -> Vec<u32>,
+{
+    let mut joined = String::new();
+    for m in region_outputs_in_order
+        .iter()
+        .map(|m| m.trim())
+        .filter(|m| !m.is_empty())
+    {
+        if !joined.is_empty() {
+            joined.push_str("\n\n");
+        }
+        joined.push_str(m);
+    }
+    Draft::new(tokenize(&joined))
+}
+
+/// Adapt an [`OARStructure`](https://docs.rs/oar-ocr)-style [`StructureResult`]
+/// into the `Vec<LayoutElement>` shape every model's `generate_hsd_full`
+/// expects.
+///
+/// The OAR structure pipeline (PP-DocLayout + PP-OCRv5 + table / formula /
+/// seal predictors) populates per-element fields slightly differently than
+/// the HSD drafting code expects:
+///
+/// - Text-bearing regions (paragraphs, titles, captions, etc.) already have
+///   their recognized text in `LayoutElement.text`.
+/// - **Tables** keep their HTML in a separate [`TableResult`] keyed by bbox;
+///   `LayoutElement.text` on the `Table` element is usually empty.
+/// - **Formulas** keep their LaTeX in a separate [`FormulaResult`] keyed by
+///   bbox; `LayoutElement.text` on the `Formula` element is usually empty.
+///
+/// This helper backfills those two cases by IoU-matching (>0.5) the table /
+/// formula side records onto their corresponding layout elements, then
+/// dedups elements whose `(element_type, bbox)` IoU > 0.98 (the structure
+/// pipeline occasionally double-emits a region as both layout and table /
+/// formula). The result is ready to feed to
+/// `HunyuanOcr::generate_hsd_full` / `GlmOcr::generate_hsd_full` /
+/// `MinerU::generate_hsd_full` / `PaddleOcrVl::generate_hsd_full` without
+/// further glue.
+///
+/// [`OARStructure`]: https://docs.rs/oar-ocr/latest/oar_ocr/oarocr/structure/struct.OARStructure.html
+/// [`TableResult`]: oar_ocr_core::domain::structure::TableResult
+/// [`FormulaResult`]: oar_ocr_core::domain::structure::FormulaResult
+pub fn structure_result_to_layout_elements(result: &StructureResult) -> Vec<LayoutElement> {
+    let mut elements = result.layout_elements.clone();
+    for elem in &mut elements {
+        match elem.element_type {
+            LayoutElementType::Table => {
+                if elem.text.as_deref().is_none_or(|s| s.trim().is_empty())
+                    && let Some(html) = result
+                        .tables
+                        .iter()
+                        .filter_map(|table| {
+                            let html = table.html_structure.as_deref()?.trim();
+                            (!html.is_empty()).then(|| (table.bbox.iou(&elem.bbox), html))
+                        })
+                        .filter(|(iou, _)| *iou > 0.5)
+                        .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
+                        .map(|(_, html)| html.to_string())
+                {
+                    elem.text = Some(html);
+                }
+            }
+            LayoutElementType::Formula => {
+                if elem.text.as_deref().is_none_or(|s| s.trim().is_empty())
+                    && let Some(latex) = result
+                        .formulas
+                        .iter()
+                        .filter_map(|formula| {
+                            let latex = formula.latex.trim();
+                            (!latex.is_empty()).then(|| (formula.bbox.iou(&elem.bbox), latex))
+                        })
+                        .filter(|(iou, _)| *iou > 0.5)
+                        .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal))
+                        .map(|(_, latex)| latex.to_string())
+                {
+                    elem.text = Some(latex);
+                }
+            }
+            _ => {}
+        }
+    }
+    let mut unique: Vec<LayoutElement> = Vec::with_capacity(elements.len());
+    for elem in elements {
+        let duplicate = unique
+            .iter()
+            .any(|prev| prev.element_type == elem.element_type && prev.bbox.iou(&elem.bbox) > 0.98);
+        if !duplicate {
+            unique.push(elem);
+        }
+    }
+    unique
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use oar_ocr_core::processors::Point;
+
+    /// Deterministic byte-tokenizer for tests: each UTF-8 byte → one token id.
+ fn byte_tok(s: &str) -> Vec { + s.bytes().map(|b| b as u32).collect() + } + + fn elem( + ty: LayoutElementType, + text: &str, + x1: f32, + y1: f32, + x2: f32, + y2: f32, + order: Option, + ) -> LayoutElement { + let bbox = BoundingBox::new(vec![ + Point::new(x1, y1), + Point::new(x2, y1), + Point::new(x2, y2), + Point::new(x1, y2), + ]); + let mut e = LayoutElement::new(bbox, ty, 0.99); + e.text = Some(text.to_string()); + e.order_index = order; + e + } + + #[test] + fn region_markdowns_skips_visual_only_and_ignored_and_empty() { + let elements = vec![ + elem( + LayoutElementType::Text, + "hello world", + 0.0, + 0.0, + 10.0, + 10.0, + Some(0), + ), + // Visual-only — must be skipped (no text to verify). + elem(LayoutElementType::Image, "", 10.0, 0.0, 20.0, 10.0, Some(1)), + // Empty text — must be skipped. + elem( + LayoutElementType::Text, + " ", + 0.0, + 10.0, + 10.0, + 20.0, + Some(2), + ), + elem( + LayoutElementType::Formula, + "x = 1", + 10.0, + 10.0, + 20.0, + 20.0, + Some(3), + ), + // Ignored by label. + { + let mut e = elem( + LayoutElementType::Text, + "skip me", + 0.0, + 20.0, + 10.0, + 30.0, + Some(4), + ); + e.label = Some("footer".to_string()); + e + }, + ]; + let got = region_markdowns(&elements, &["footer".to_string()]); + // Two surviving drafts: the text element and the formula element. + // Order matches input order (reading order), not sorted by anything. + assert_eq!(got.len(), 2); + assert!(got[0].contains("hello world")); + assert!(got[1].contains("x = 1")); + } + + #[test] + fn region_markdowns_per_region_independence() { + // Two text regions: each becomes a separate draft. The output is NOT + // a joined string — that's the paper Eq. 3 alignment. 
+ let elements = vec![ + elem( + LayoutElementType::Text, + "first region", + 0.0, + 0.0, + 10.0, + 10.0, + Some(0), + ), + elem( + LayoutElementType::Text, + "second region", + 0.0, + 10.0, + 10.0, + 20.0, + Some(1), + ), + ]; + let got = region_markdowns(&elements, &[]); + assert_eq!(got.len(), 2); + // Neither draft contains the *other* draft's content — there's no + // pre-joining via "\n\n" inside the helper. + assert!(!got[0].contains("second region")); + assert!(!got[1].contains("first region")); + } + + #[test] + fn map_layout_kind_covers_main_types() { + assert_eq!( + map_layout_kind(LayoutElementType::DocTitle), + RegionKind::Title + ); + assert_eq!(map_layout_kind(LayoutElementType::Text), RegionKind::Text); + assert_eq!(map_layout_kind(LayoutElementType::Table), RegionKind::Table); + assert_eq!( + map_layout_kind(LayoutElementType::Formula), + RegionKind::Formula + ); + assert_eq!( + map_layout_kind(LayoutElementType::Image), + RegionKind::Figure + ); + assert_eq!( + map_layout_kind(LayoutElementType::Header), + RegionKind::Header + ); + assert_eq!(map_layout_kind(LayoutElementType::List), RegionKind::List); + } + + #[test] + fn bbox_xyxy_returns_axis_aligned() { + let b = BoundingBox::new(vec![ + Point::new(10.0, 20.0), + Point::new(110.0, 25.0), + Point::new(115.0, 80.0), + Point::new(15.0, 75.0), + ]); + let xy = bbox_xyxy(&b); + assert!((xy[0] - 10.0).abs() < 1e-3); + assert!((xy[1] - 20.0).abs() < 1e-3); + assert!((xy[2] - 115.0).abs() < 1e-3); + assert!((xy[3] - 80.0).abs() < 1e-3); + } + + #[test] + fn region_markdown_title_gets_heading_prefix() { + let e = elem( + LayoutElementType::DocTitle, + "Hello world", + 0.0, + 0.0, + 100.0, + 20.0, + Some(1), + ); + let md = region_markdown(&e); + assert!(md.starts_with("# "), "expected H1 prefix, got: {md:?}"); + assert!(md.contains("Hello world")); + } + + #[test] + fn region_markdown_formula_wrapped_in_dollars() { + let e = elem( + LayoutElementType::Formula, + "x^2 + y^2 = 1", + 0.0, + 0.0, + 100.0, 
+ 20.0, + None, + ); + let md = region_markdown(&e); + assert!(md.contains("$$"), "expected $$ wrapping, got: {md:?}"); + } + + #[test] + fn region_markdown_text_passes_through() { + let e = elem( + LayoutElementType::Text, + "plain paragraph", + 0.0, + 0.0, + 100.0, + 20.0, + None, + ); + let md = region_markdown(&e); + assert!(md.contains("plain paragraph")); + } + + #[test] + fn hunyuanocr_adapter_uses_hunyuanocr_surface() { + let title = elem( + LayoutElementType::DocTitle, + "SEC.1 Introduction", + 0.0, + 0.0, + 100.0, + 20.0, + Some(0), + ); + let formula = elem( + LayoutElementType::Formula, + "x^2 + y^2 = 1", + 0.0, + 20.0, + 100.0, + 40.0, + Some(1), + ); + let page_number = elem( + LayoutElementType::Number, + "12", + 0.0, + 40.0, + 100.0, + 60.0, + Some(2), + ); + + assert_eq!( + region_markdown_for(&title, TargetDraftAdapter::HunyuanOcr), + "SEC. 1 Introduction" + ); + assert_eq!( + region_markdown_for(&formula, TargetDraftAdapter::HunyuanOcr), + "$$ x^2 + y^2 = 1 $$" + ); + assert_eq!( + region_markdown_for(&page_number, TargetDraftAdapter::HunyuanOcr), + "12\n\n---" + ); + } + + #[test] + fn paddleocr_vl_adapter_uses_raw_form() { + let title = elem( + LayoutElementType::DocTitle, + "A heading", + 0.0, + 0.0, + 100.0, + 20.0, + Some(0), + ); + let formula = elem( + LayoutElementType::Formula, + "x = 1", + 0.0, + 20.0, + 100.0, + 40.0, + Some(1), + ); + let formula_wrapped = elem( + LayoutElementType::Formula, + "$$y = 2$$", + 0.0, + 40.0, + 100.0, + 60.0, + Some(2), + ); + let table_otsl = elem( + LayoutElementType::Table, + "ab", + 0.0, + 60.0, + 100.0, + 80.0, + Some(3), + ); + let page_number = elem( + LayoutElementType::Number, + "12", + 0.0, + 80.0, + 100.0, + 100.0, + Some(4), + ); + + // Headings stay plain — PaddleOCR-VL has no `# ` shell. + assert_eq!( + region_markdown_for(&title, TargetDraftAdapter::PaddleOcrVl), + "A heading" + ); + // Bare LaTeX gets wrapped into `$$..$$` raw form (post-process strips it). 
+ assert_eq!( + region_markdown_for(&formula, TargetDraftAdapter::PaddleOcrVl), + "$$x = 1$$" + ); + // Already wrapped formulas pass through unchanged. + assert_eq!( + region_markdown_for(&formula_wrapped, TargetDraftAdapter::PaddleOcrVl), + "$$y = 2$$" + ); + // OTSL passes through as-is (PaddleOCR-VL emits OTSL natively). + assert_eq!( + region_markdown_for(&table_otsl, TargetDraftAdapter::PaddleOcrVl), + "ab" + ); + // Page numbers stay plain — no HunyuanOCR-style `---` separator. + assert_eq!( + region_markdown_for(&page_number, TargetDraftAdapter::PaddleOcrVl), + "12" + ); + } + + #[test] + fn glmocr_adapter_strips_formula_wrappers() { + let title = elem( + LayoutElementType::DocTitle, + "A heading", + 0.0, + 0.0, + 100.0, + 20.0, + Some(0), + ); + let formula_wrapped = elem( + LayoutElementType::Formula, + "$$x = 1$$", + 0.0, + 20.0, + 100.0, + 40.0, + Some(1), + ); + let formula_bare = elem( + LayoutElementType::Formula, + "y = 2", + 0.0, + 40.0, + 100.0, + 60.0, + Some(2), + ); + let table = elem( + LayoutElementType::Table, + "
a
", + 0.0, + 60.0, + 100.0, + 80.0, + Some(3), + ); + + assert_eq!( + region_markdown_for(&title, TargetDraftAdapter::GlmOcr), + "A heading" + ); + // `$$..$$` stripped to bare LaTeX (GLM-OCR emits unwrapped LaTeX). + assert_eq!( + region_markdown_for(&formula_wrapped, TargetDraftAdapter::GlmOcr), + "x = 1" + ); + // Already bare LaTeX passes through. + assert_eq!( + region_markdown_for(&formula_bare, TargetDraftAdapter::GlmOcr), + "y = 2" + ); + // Tables pass through as-is. + assert_eq!( + region_markdown_for(&table, TargetDraftAdapter::GlmOcr), + "
a
" + ); + } + + #[test] + fn convert_raw_to_target_adapter_pipes_paddleocr_vl_raw_into_hunyuanocr() { + // Scenario: PaddleOCR-VL's `decode_tokens_raw` returns `$$x = 1$$` for + // a formula region. Feeding that as a draft for HunyuanOCR should pass + // through unchanged (HunyuanOCR also emits `$$ ... $$`). + let out = convert_raw_to_target_adapter( + "$$x = 1$$", + LayoutElementType::Formula, + TargetDraftAdapter::HunyuanOcr, + ); + assert_eq!(out, "$$x = 1$$"); + } + + #[test] + fn convert_raw_to_target_adapter_strips_wrapper_for_glmocr() { + // PaddleOCR-VL raw `$$x = 1$$` → GLM-OCR adapter strips wrapper to + // bare LaTeX since GLM-OCR's Formula Recognition prompt emits bare. + let out = convert_raw_to_target_adapter( + "$$x = 1$$", + LayoutElementType::Formula, + TargetDraftAdapter::GlmOcr, + ); + assert_eq!(out, "x = 1"); + } + + #[test] + fn convert_raw_to_target_adapter_rewraps_for_mineru() { + // PaddleOCR-VL raw `$$x = 1$$` → MinerU adapter rewraps to block form. + let out = convert_raw_to_target_adapter( + "$$x = 1$$", + LayoutElementType::Formula, + TargetDraftAdapter::MinerU, + ); + assert_eq!(out, "$$\nx = 1\n$$"); + } + + #[test] + fn convert_raw_to_target_adapter_table_passthrough_for_non_paddleocr_vl() { + // HunyuanOCR / GLM-OCR / MinerU adapters keep HTML tables as-is — these + // targets emit HTML natively. + let html = "
a
"; + assert_eq!( + convert_raw_to_target_adapter( + html, + LayoutElementType::Table, + TargetDraftAdapter::HunyuanOcr, + ), + html, + ); + assert_eq!( + convert_raw_to_target_adapter( + html, + LayoutElementType::Table, + TargetDraftAdapter::MinerU, + ), + html, + ); + } + + #[test] + fn paddleocr_vl_adapter_converts_html_table_to_otsl() { + // PaddleOCR-VL emits OTSL natively; feeding an HTML table draft now + // gets converted to OTSL via convert_html_to_otsl, closing the major + // byte-mismatch source that previously tanked AAL on table regions. + let html = "
ab
"; + assert_eq!( + convert_raw_to_target_adapter( + html, + LayoutElementType::Table, + TargetDraftAdapter::PaddleOcrVl, + ), + "ab", + ); + } + + #[test] + fn hunyuanocr_adapter_converts_otsl_to_html() { + // PaddleOCR-VL → HunyuanOCR: source emits OTSL, target wants HTML. + // The adapter must run convert_otsl_to_html so the table draft + // matches HunyuanOCR's natural output form. + let out = convert_raw_to_target_adapter( + "ab", + LayoutElementType::Table, + TargetDraftAdapter::HunyuanOcr, + ); + assert!(out.contains(""), "expected HTML, got: {out:?}"); + assert!(out.contains("")); + assert!(out.contains("")); + } + + #[test] + fn glmocr_adapter_converts_otsl_to_html() { + let out = convert_raw_to_target_adapter( + "ab", + LayoutElementType::Table, + TargetDraftAdapter::GlmOcr, + ); + assert!(out.contains("
ab
")); + assert!(out.contains("")); + } + + #[test] + fn mineru_adapter_converts_otsl_to_html() { + let out = convert_raw_to_target_adapter( + "ab", + LayoutElementType::Table, + TargetDraftAdapter::MinerU, + ); + assert!(out.contains("
a
")); + assert!(out.contains("")); + } + + #[test] + fn html_target_adapters_passthrough_existing_html() { + // Already-HTML drafts pass through untouched for HTML targets. + let html = "
a
x
"; + for adapter in [ + TargetDraftAdapter::HunyuanOcr, + TargetDraftAdapter::GlmOcr, + TargetDraftAdapter::MinerU, + ] { + assert_eq!( + convert_raw_to_target_adapter(html, LayoutElementType::Table, adapter), + html, + ); + } + } + + #[test] + fn paddleocr_vl_adapter_passes_through_otsl_unchanged() { + // Already-OTSL drafts pass through (no double conversion). + let otsl = "a"; + assert_eq!( + convert_raw_to_target_adapter( + otsl, + LayoutElementType::Table, + TargetDraftAdapter::PaddleOcrVl, + ), + otsl, + ); + } + + #[test] + fn mineru_adapter_uses_block_formula_and_plain_headings() { + let title = elem( + LayoutElementType::DocTitle, + "A heading", + 0.0, + 0.0, + 100.0, + 20.0, + Some(0), + ); + let formula = elem( + LayoutElementType::Formula, + "x = 1", + 0.0, + 20.0, + 100.0, + 40.0, + Some(1), + ); + let formula_wrapped = elem( + LayoutElementType::Formula, + "$$y = 2$$", + 0.0, + 40.0, + 100.0, + 60.0, + Some(2), + ); + let page_number = elem( + LayoutElementType::Number, + "12", + 0.0, + 60.0, + 100.0, + 80.0, + Some(3), + ); + + // Headings plain — MinerU's per-element step emits text only. + assert_eq!( + region_markdown_for(&title, TargetDraftAdapter::MinerU), + "A heading" + ); + // Block-form display formula `$$\n...\n$$`. + assert_eq!( + region_markdown_for(&formula, TargetDraftAdapter::MinerU), + "$$\nx = 1\n$$" + ); + // Re-wraps already-wrapped formulas to canonical block form. + assert_eq!( + region_markdown_for(&formula_wrapped, TargetDraftAdapter::MinerU), + "$$\ny = 2\n$$" + ); + // No `---` separator for page numbers. + assert_eq!( + region_markdown_for(&page_number, TargetDraftAdapter::MinerU), + "12" + ); + } + + #[test] + fn all_vlm_adapters_skip_visual_only_elements_regardless_of_text() { + // Image / HeaderImage / FooterImage / Seal carry no text the verifier + // can match against the target VLM's natural output. 
Every per-VLM + // adapter must return the empty string for these element types even + // when the drafter populated `text` (e.g. a caption transcribed under + // the image). Otherwise a stray draft like "*Figure 1: ...*" leaks + // into the matcher and tanks AAL. + let visual_types = [ + LayoutElementType::Image, + LayoutElementType::HeaderImage, + LayoutElementType::FooterImage, + LayoutElementType::Seal, + ]; + let vlm_adapters = [ + TargetDraftAdapter::HunyuanOcr, + TargetDraftAdapter::GlmOcr, + TargetDraftAdapter::MinerU, + TargetDraftAdapter::PaddleOcrVl, + ]; + for ty in visual_types { + let e = elem(ty, "nonempty caption", 0.0, 0.0, 100.0, 50.0, Some(0)); + for adapter in vlm_adapters { + let out = region_markdown_for(&e, adapter); + assert!( + out.is_empty(), + "{:?} adapter must skip visual-only {:?}, got: {:?}", + adapter, + ty, + out + ); + } + } + } + + #[test] + fn plain_adapter_does_not_add_markdown_shell() { + let title = elem( + LayoutElementType::DocTitle, + "A heading", + 0.0, + 0.0, + 100.0, + 20.0, + Some(0), + ); + let formula = elem( + LayoutElementType::Formula, + "x = 1", + 0.0, + 20.0, + 100.0, + 40.0, + Some(1), + ); + + assert_eq!( + region_markdown_for(&title, TargetDraftAdapter::PlainText), + "A heading" + ); + assert_eq!( + region_markdown_for(&formula, TargetDraftAdapter::PlainText), + "x = 1" + ); + } + + #[test] + fn build_region_drafts_skips_empty_and_filters_labels() { + let mut e1 = elem( + LayoutElementType::Text, + "first", + 0.0, + 0.0, + 10.0, + 10.0, + Some(1), + ); + let mut e2 = elem( + LayoutElementType::Text, + " ", + 0.0, + 0.0, + 10.0, + 10.0, + Some(2), + ); + let mut e3 = elem( + LayoutElementType::Header, + "skip-me", + 0.0, + 0.0, + 10.0, + 10.0, + Some(3), + ); + let mut e4 = elem( + LayoutElementType::Text, + "second", + 0.0, + 0.0, + 10.0, + 10.0, + Some(4), + ); + e1.label = Some("Text".into()); + e2.label = Some("Text".into()); + e3.label = Some("Header".into()); + e4.label = Some("Text".into()); + + let 
drafts = build_region_drafts(&[e1, e2, e3, e4], &["Header".to_string()], byte_tok); + assert_eq!(drafts.len(), 2); + assert_eq!(drafts[0].kind, RegionKind::Text); + assert_eq!(drafts[0].reading_order, Some(1)); + assert_eq!(drafts[1].reading_order, Some(4)); + // Tokens are non-empty. + assert_eq!(drafts[0].drafts.len(), 1); + assert_eq!(drafts[1].drafts.len(), 1); + assert!(!drafts[0].drafts[0].is_empty()); + assert!(!drafts[1].drafts[0].is_empty()); + } + + #[test] + fn build_region_drafts_keeps_kind_and_bbox() { + let e = elem( + LayoutElementType::Table, + "
a
", + 1.0, + 2.0, + 3.0, + 4.0, + Some(7), + ); + let drafts = build_region_drafts(&[e], &[], byte_tok); + assert_eq!(drafts.len(), 1); + assert_eq!(drafts[0].kind, RegionKind::Table); + assert_eq!(drafts[0].reading_order, Some(7)); + assert_eq!(drafts[0].bbox, [1.0, 2.0, 3.0, 4.0]); + assert_eq!(drafts[0].drafts.len(), 1); + } + + #[test] + fn build_region_draft_candidates_packs_and_dedups_multi_drafts() { + let e = elem( + LayoutElementType::Text, + "unused", + 1.0, + 2.0, + 3.0, + 4.0, + Some(7), + ); + let drafts = build_region_draft_candidates_with_adapter( + &[e], + &[], + TargetDraftAdapter::PlainText, + |_| vec!["alpha".to_string(), "alpha".to_string(), "beta".to_string()], + byte_tok, + ); + assert_eq!(drafts.len(), 1); + assert_eq!(drafts[0].drafts.len(), 2); + assert_eq!( + drafts[0].drafts[0].tokens, + b"alpha".iter().map(|&b| b as u32).collect::>() + ); + assert_eq!( + drafts[0].drafts[1].tokens, + b"beta".iter().map(|&b| b as u32).collect::>() + ); + assert_eq!(drafts[0].kind, RegionKind::Text); + assert_eq!(drafts[0].reading_order, Some(7)); + } + + #[test] + fn region_markdown_candidates_serializes_and_dedups_multi_drafts() { + let e = elem( + LayoutElementType::Formula, + "unused", + 1.0, + 2.0, + 3.0, + 4.0, + Some(7), + ); + let drafts = + region_markdown_candidates_for(&[e], &[], TargetDraftAdapter::PlainText, |_| { + vec!["x+y".to_string(), "x+y".to_string(), "a=b".to_string()] + }); + assert_eq!(drafts, vec!["x+y".to_string(), "a=b".to_string()]); + } + + #[test] + fn build_page_draft_concatenates() { + let e1 = elem( + LayoutElementType::DocTitle, + "T", + 0.0, + 0.0, + 10.0, + 10.0, + Some(1), + ); + let e2 = elem( + LayoutElementType::Text, + "body", + 0.0, + 0.0, + 10.0, + 10.0, + Some(2), + ); + let pd = build_page_draft(&[e1, e2], &[], byte_tok); + assert!(!pd.is_empty()); + // Page draft should be at least as long as title alone (sanity). 
+ let title_only = byte_tok("# T"); + assert!(pd.tokens.len() >= title_only.len()); + } + + #[test] + fn page_draft_from_region_outputs_uses_blank_line_separator() { + let outputs = vec!["alpha".to_string(), "beta".to_string(), "gamma".to_string()]; + let d = page_draft_from_region_outputs(&outputs, byte_tok); + let s = String::from_utf8(d.tokens.iter().map(|&t| t as u8).collect()).unwrap(); + assert_eq!(s, "alpha\n\nbeta\n\ngamma"); + } + + #[test] + fn page_draft_from_region_outputs_handles_empty_input() { + let d = page_draft_from_region_outputs(&[], byte_tok); + assert!(d.is_empty()); + } + + #[test] + fn page_draft_from_region_outputs_trims_per_region_padding() { + let outputs = vec![" spaced ".to_string(), "next ".to_string()]; + let d = page_draft_from_region_outputs(&outputs, byte_tok); + let s = String::from_utf8(d.tokens.iter().map(|&t| t as u8).collect()).unwrap(); + assert_eq!(s, "spaced\n\nnext"); + } + + #[test] + fn page_draft_from_region_outputs_skips_empty_regions() { + let outputs = vec![ + "alpha".to_string(), + " ".to_string(), + String::new(), + "beta".to_string(), + ]; + let d = page_draft_from_region_outputs(&outputs, byte_tok); + let s = String::from_utf8(d.tokens.iter().map(|&t| t as u8).collect()).unwrap(); + assert_eq!(s, "alpha\n\nbeta"); + } +} diff --git a/oar-ocr-vl/src/hsd/kv_trim.rs b/oar-ocr-vl/src/hsd/kv_trim.rs new file mode 100644 index 0000000..a8d7ee1 --- /dev/null +++ b/oar-ocr-vl/src/hsd/kv_trim.rs @@ -0,0 +1,329 @@ +//! KV-cache wrapper supporting head-only trimming and gather. +//! +//! `candle_nn::kv_cache::KvCache` exposes `append` / `reset` but no way to +//! roll back the cache to an earlier sequence length, nor to keep an +//! arbitrary subset of positions. HSD's verifier needs both: after a +//! tree-attention forward pass we may accept fewer tokens than were +//! appended, and the unaccepted suffix must be discarded so the next forward +//! pass sees a clean prefix. +//! +//! ## Implementation note +//! +//! 
Append uses `Tensor::cat`, matching the public-API behaviour of
+//! `candle_nn::KvCache` before its preallocation rewrite. We tried a
+//! preallocation + `slice_set` strategy (see git history) for parity with
+//! `candle_nn::kv_cache::Cache`, but the resulting K/V values caused HSD
+//! acceptance to collapse on the same workloads where the cat-based
+//! implementation was correct, despite kv-only unit tests passing. The cause
+//! turned out to be subtle and we reverted; rare per-page slowdowns observed
+//! on long benchmarks remain an open issue tracked separately.
+
+use candle_core::{Result, Tensor};
+
+/// Append-and-trim KV cache for use during HSD verification.
+///
+/// `Clone` mirrors `candle_nn::kv_cache::KvCache::Clone`: it produces a
+/// shallow copy that shares the same underlying `Tensor` storage. Cheap; only
+/// useful for structures that need to derive `Clone` (e.g. GLM-OCR's text
+/// model, which is held by value in multiple places).
+#[derive(Debug, Clone)]
+pub struct TrimmableKvCache {
+    /// Concatenation axis (typically `2` for the seq dim of `(B, H, T, D)` tensors).
+    cat_dim: usize,
+    /// Cached `(keys, values)`, each shape `(B, H, cur_len, D)`. `None` until
+    /// first `append`. K and V are stored together so the "both populated or
+    /// both empty" invariant is enforced by the type system — earlier
+    /// revisions used parallel `Option<Tensor>` fields and had to assert this
+    /// invariant manually.
+    kv: Option<(Tensor, Tensor)>,
+    cur_len: usize,
+    /// Configured sequence-length capacity retained for parity with
+    /// `candle_nn::kv_cache::KvCache::new`. This wrapper does not enforce it.
+    /// Only read by HSD's `max_seq_len()` accessor.
+    #[cfg_attr(not(feature = "hsd"), allow(dead_code))]
+    max_len: usize,
+}
+
+// `TrimmableKvCache` lives at the crate root so every model's attention path
+// can store one. Most of its methods are only consumed by the HSD verify
+// driver (`trim_to`, `keep_indices`, `current_seq_len`, `max_seq_len`, `k`,
+// `v`) — silence dead-code warnings for those when `hsd` is off without
+// gating the methods themselves (they remain available for external
+// callers / future re-enabling).
+#[cfg_attr(not(feature = "hsd"), allow(dead_code))]
+impl TrimmableKvCache {
+    pub fn new(cat_dim: usize, max_len: usize) -> Self {
+        Self {
+            cat_dim,
+            kv: None,
+            cur_len: 0,
+            max_len,
+        }
+    }
+
+    /// Append `(k_new, v_new)` to the cache and return the concatenated
+    /// `(K_all, V_all)` tensors that the attention path will consume.
+    ///
+    /// Cat-based growth (not preallocated). We tried a `slice_set` /
+    /// preallocated-buffer rewrite to match `candle_nn::kv_cache::Cache` —
+    /// nsys profiling had pointed to per-step cat as a candidate bottleneck
+    /// on long-output pages. The rewrite passed unit tests, didn't change
+    /// per-page wall time on either fast or slow images, and *regressed*
+    /// HSD acceptance length (AAL collapsed from ~22 → ~15 on 30 v1.5
+    /// pages). Reverted: the wall-time outliers we were chasing are
+    /// dominated by candle's per-op CPU dispatch on long decode loops, not
+    /// by KV-cache copy overhead.
+    pub fn append(&mut self, k_new: &Tensor, v_new: &Tensor) -> Result<(Tensor, Tensor)> {
+        let new_len = k_new.dim(self.cat_dim)?;
+        let (k_all, v_all) = match self.kv.as_ref() {
+            None => (k_new.clone(), v_new.clone()),
+            Some((k_old, v_old)) => {
+                let k = Tensor::cat(&[k_old, k_new], self.cat_dim)?.contiguous()?;
+                let v = Tensor::cat(&[v_old, v_new], self.cat_dim)?.contiguous()?;
+                (k, v)
+            }
+        };
+        self.kv = Some((k_all.clone(), v_all.clone()));
+        self.cur_len += new_len;
+        Ok((k_all, v_all))
+    }
+
+    /// Drop everything at sequence indices `>= len`. No-op if `len >= cur_len`.
+    pub fn trim_to(&mut self, len: usize) -> Result<()> {
+        if len >= self.cur_len {
+            return Ok(());
+        }
+        if len == 0 {
+            self.reset();
+            return Ok(());
+        }
+        // `cur_len > 0` implies `kv.is_some()` by the invariant maintained in
+        // `append` / `reset`.
+        let Some((k_old, v_old)) = self.kv.as_ref() else {
+            return Err(candle_core::Error::Msg(
+                "TrimmableKvCache::trim_to: cache empty but cur_len > 0".into(),
+            ));
+        };
+        let k = k_old.narrow(self.cat_dim, 0, len)?.contiguous()?;
+        let v = v_old.narrow(self.cat_dim, 0, len)?.contiguous()?;
+        self.kv = Some((k, v));
+        self.cur_len = len;
+        Ok(())
+    }
+
+    /// Gather the cache to keep only the supplied positions, in the supplied
+    /// order. This is the operation HSD performs after a tree-attention
+    /// verification pass: keep `[0..prefix_kv_len)` (the accepted history)
+    /// then append the path-node positions.
+    ///
+    /// Each index must be `< current_seq_len()`. Indices may be repeated, but
+    /// in normal HSD use they are distinct.
+    pub fn keep_indices(&mut self, indices: &[u32]) -> Result<()> {
+        if indices.is_empty() {
+            self.reset();
+            return Ok(());
+        }
+        for &i in indices {
+            if (i as usize) >= self.cur_len {
+                return Err(candle_core::Error::Msg(format!(
+                    "TrimmableKvCache::keep_indices: index {} out of bounds (cur_len={})",
+                    i, self.cur_len
+                )));
+            }
+        }
+        let Some((k, v)) = self.kv.as_ref() else {
+            return Err(candle_core::Error::Msg(
+                "TrimmableKvCache::keep_indices on empty cache".into(),
+            ));
+        };
+        if indices.iter().enumerate().all(|(i, &x)| x as usize == i) {
+            return self.trim_to(indices.len());
+        }
+        let device = k.device();
+        // `Tensor::new(&[u32], device)` lands the slice directly via candle's
+        // `NdArray for &[S]` impl — no `indices.to_vec()` allocation needed.
+        let idx_t = Tensor::new(indices, device)?;
+        let new_k = k.index_select(&idx_t, self.cat_dim)?.contiguous()?;
+        let new_v = v.index_select(&idx_t, self.cat_dim)?.contiguous()?;
+        self.kv = Some((new_k, new_v));
+        self.cur_len = indices.len();
+        Ok(())
+    }
+
+    pub fn current_seq_len(&self) -> usize {
+        self.cur_len
+    }
+
+    pub fn max_seq_len(&self) -> usize {
+        self.max_len
+    }
+
+    pub fn reset(&mut self) {
+        self.kv = None;
+        self.cur_len = 0;
+    }
+
+    /// Borrow the current K cache, if any.
+    pub fn k(&self) -> Option<&Tensor> {
+        self.kv.as_ref().map(|(k, _)| k)
+    }
+
+    /// Borrow the current V cache, if any.
+    pub fn v(&self) -> Option<&Tensor> {
+        self.kv.as_ref().map(|(_, v)| v)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use candle_core::{DType, Device};
+
+    fn dev() -> Device {
+        Device::Cpu
+    }
+
+    #[test]
+    fn append_grows_and_returns_full_cache() -> Result<()> {
+        let mut c = TrimmableKvCache::new(2, 64);
+        let a = Tensor::zeros((1, 2, 3, 4), DType::F32, &dev())?;
+        let b = Tensor::ones((1, 2, 5, 4), DType::F32, &dev())?;
+        let (k1, _) = c.append(&a, &a)?;
+        assert_eq!(k1.dims(), &[1, 2, 3, 4]);
+        assert_eq!(c.current_seq_len(), 3);
+        let (k2, _) = c.append(&b, &b)?;
+        assert_eq!(k2.dims(), &[1, 2, 8, 4]);
+        assert_eq!(c.current_seq_len(), 8);
+        Ok(())
+    }
+
+    #[test]
+    fn trim_to_shorter_drops_tail() -> Result<()> {
+        let mut c = TrimmableKvCache::new(2, 64);
+        let t = Tensor::zeros((1, 2, 6, 4), DType::F32, &dev())?;
+        c.append(&t, &t)?;
+        c.trim_to(4)?;
+        assert_eq!(c.current_seq_len(), 4);
+        assert_eq!(c.k().unwrap().dims(), &[1, 2, 4, 4]);
+        Ok(())
+    }
+
+    #[test]
+    fn trim_to_zero_resets() -> Result<()> {
+        let mut c = TrimmableKvCache::new(2, 64);
+        let t = Tensor::zeros((1, 2, 6, 4), DType::F32, &dev())?;
+        c.append(&t, &t)?;
+        c.trim_to(0)?;
+        assert_eq!(c.current_seq_len(), 0);
+        assert!(c.k().is_none());
+        assert!(c.v().is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn trim_to_longer_is_noop() -> Result<()> {
+        let mut c = TrimmableKvCache::new(2, 64);
+        let t = Tensor::zeros((1, 2, 3, 4), DType::F32, &dev())?;
+        c.append(&t, &t)?;
+        c.trim_to(10)?;
+        assert_eq!(c.current_seq_len(), 3);
+        Ok(())
+    }
+
+    #[test]
+    fn reset_then_append_works() -> Result<()> {
+        let mut c = TrimmableKvCache::new(2, 64);
+        let t = Tensor::zeros((1, 2, 3, 4), DType::F32, &dev())?;
+        c.append(&t, &t)?;
+        c.reset();
+        let s = Tensor::ones((1, 2, 2, 4), DType::F32, &dev())?;
+        let (k, _) = c.append(&s, &s)?;
+        assert_eq!(k.dims(), &[1, 2, 2, 4]);
+        assert_eq!(c.current_seq_len(), 2);
+        Ok(())
+    }
+
+    #[test]
+    fn trim_then_append_concats_correctly() -> Result<()> {
+        let mut c = TrimmableKvCache::new(2, 64);
+        let a = Tensor::zeros((1, 2, 6, 4), DType::F32, &dev())?;
+        let b = Tensor::ones((1, 2, 3, 4), DType::F32, &dev())?;
+        c.append(&a, &a)?;
+        c.trim_to(4)?;
+        let (k, _) = c.append(&b, &b)?;
+        assert_eq!(k.dims(), &[1, 2, 7, 4]);
+        assert_eq!(c.current_seq_len(), 7);
+        Ok(())
+    }
+
+    #[test]
+    fn empty_then_trim_is_noop() -> Result<()> {
+        let mut c = TrimmableKvCache::new(2, 64);
+        c.trim_to(5)?;
+        assert_eq!(c.current_seq_len(), 0);
+        Ok(())
+    }
+
+    /// Build a deterministic cache where K[..., t, 0] == t (so we can verify
+    /// the gathered ordering after `keep_indices`).
+    fn build_indexed_cache(len: usize) -> Result<TrimmableKvCache> {
+        let mut c = TrimmableKvCache::new(2, 128);
+        for t in 0..len {
+            let k = Tensor::from_vec(vec![t as f32, 0.0, 0.0, 0.0], (1, 1, 1, 4), &dev())?;
+            c.append(&k, &k)?;
+        }
+        Ok(c)
+    }
+
+    #[test]
+    fn keep_indices_gathers_in_order() -> Result<()> {
+        let mut c = build_indexed_cache(8)?;
+        c.keep_indices(&[0, 1, 3, 5])?;
+        assert_eq!(c.current_seq_len(), 4);
+        let k = c.k().unwrap();
+        let raw: Vec<f32> = k.flatten_all()?.to_vec1()?;
+        assert_eq!(raw[0], 0.0);
+        assert_eq!(raw[4], 1.0);
+        assert_eq!(raw[8], 3.0);
+        assert_eq!(raw[12], 5.0);
+        Ok(())
+    }
+
+    #[test]
+    fn keep_indices_prefix_uses_trim_fast_path() -> Result<()> {
+        let mut c = build_indexed_cache(6)?;
+        c.keep_indices(&[0, 1, 2])?;
+        assert_eq!(c.current_seq_len(), 3);
+        Ok(())
+    }
+
+    #[test]
+    fn keep_indices_empty_resets() -> Result<()> {
+        let mut c = build_indexed_cache(3)?;
+        c.keep_indices(&[])?;
+        assert_eq!(c.current_seq_len(), 0);
+        assert!(c.k().is_none());
+        Ok(())
+    }
+
+    #[test]
+    fn keep_indices_out_of_bounds_errors() {
+        let mut c = build_indexed_cache(3).unwrap();
+        let err = c.keep_indices(&[0, 5]).unwrap_err().to_string();
+        assert!(err.contains("out of bounds"), "unexpected error: {err}");
+    }
+
+    #[test]
+    fn keep_indices_then_append_works() -> Result<()> {
+        let mut c = build_indexed_cache(5)?;
+        c.keep_indices(&[1, 3])?;
+        let extra = Tensor::from_vec(vec![99.0f32, 0.0, 0.0, 0.0], (1, 1, 1, 4), &dev())?;
+        c.append(&extra, &extra)?;
+        assert_eq!(c.current_seq_len(), 3);
+        let raw: Vec<f32> = c.k().unwrap().flatten_all()?.to_vec1()?;
+        assert_eq!(raw[0], 1.0);
+        assert_eq!(raw[4], 3.0);
+        assert_eq!(raw[8], 99.0);
+        Ok(())
+    }
+}
diff --git a/oar-ocr-vl/src/hsd/matching.rs b/oar-ocr-vl/src/hsd/matching.rs
new file mode 100644
index 0000000..ad0cf2b
--- /dev/null
+++ b/oar-ocr-vl/src/hsd/matching.rs
@@ -0,0 +1,369 @@
+//! Draft-target matching with a sliding reference window (paper §3.2).
+//!
+//! 
Given the target VLM's accepted token sequence `ŷ_{1:t}` and a set of fixed +//! drafts `Ỹ`, this module finds every position in every draft where the last +//! `n` accepted tokens reappear, and extracts the suffix that strictly follows +//! each such match. The collected suffixes form the candidate set `C` consumed +//! by the [`super::prefix_tree`] builder. + +use super::types::{Draft, DsvConfig}; + +/// A candidate suffix extracted from one draft after a successful window match. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Candidate { + /// Index of the source draft inside the input slice (debug / introspection). + pub draft_idx: usize, + /// Position in the source draft where the *suffix* starts (i.e. `j + n` in + /// the paper's notation). Mostly useful for diagnostics. + pub suffix_start: usize, + /// Token sequence following the matched window. + pub tokens: Vec, +} + +/// Slide a length-`n` reference window over each draft and collect every suffix +/// that follows a match. +/// +/// The effective window length is `min(accepted_tail.len(), cfg.window_len)`, +/// so this also works during the first few generation steps when the accepted +/// history is shorter than `n`. If `accepted_tail` is empty, every draft's +/// leading prefix becomes a candidate (interpreted as a zero-length window +/// matching the empty string at position 0). +/// +/// Each candidate is truncated to at most `cfg.max_suffix_len` tokens. The +/// total number of candidates returned is capped at +/// `cfg.max_candidates_per_step`. When the cap is hit, half of the budget keeps +/// early scan-order candidates and the rest keeps the longest remaining +/// suffixes. This preserves some draft/position diversity while still favoring +/// paths that provide more parallel verification headroom. 
+pub fn collect_candidates( + accepted_tail: &[u32], + accepted_len: usize, + drafts: &[Draft], + cfg: &DsvConfig, +) -> Vec { + let mut out: Vec = Vec::new(); + + let win_len = accepted_tail.len().min(cfg.window_len); + let window = &accepted_tail[accepted_tail.len() - win_len..]; + + for (di, draft) in drafts.iter().enumerate() { + if draft.tokens.len() <= win_len { + // Even with a perfect match there'd be no suffix to extract. + continue; + } + + if win_len == 0 { + // No accepted history. Paper's formulation has no match with an + // empty window → no candidates at step 0 (driver falls back to + // step_one for the first token). `cold_start_full_draft = true` + // (default) softens this by emitting each draft's leading prefix + // as a single candidate so the very first step still has tree + // material to verify. + if cfg.cold_start_full_draft { + let take = cfg.max_suffix_len.min(draft.tokens.len()); + if take > 0 { + out.push(Candidate { + draft_idx: di, + suffix_start: 0, + tokens: draft.tokens[..take].to_vec(), + }); + } + } + continue; + } + + // Naive O(|draft| * n) scan. The drafts here are page-level Markdown, + // typically a few thousand tokens, so this is comfortably fast. If a + // profile ever shows it as a hotspot, switch to a rolling-hash search. 
+ let last_start = draft.tokens.len() - win_len; + let mut j = 0; + while j <= last_start { + if &draft.tokens[j..j + win_len] == window { + let suf_start = j + win_len; + if suf_start < draft.tokens.len() { + let take = cfg.max_suffix_len.min(draft.tokens.len() - suf_start); + if take > 0 { + out.push(Candidate { + draft_idx: di, + suffix_start: suf_start, + tokens: draft.tokens[suf_start..suf_start + take].to_vec(), + }); + } + } + } + j += 1; + } + } + + cap_candidates(out, accepted_len, cfg.max_candidates_per_step) +} + +fn cap_candidates( + out: Vec, + accepted_len: usize, + max_candidates: usize, +) -> Vec { + if out.len() <= max_candidates { + return out; + } + if max_candidates == 0 { + return Vec::new(); + } + + // Position-aware cap: prefer candidates whose match position is closest to + // the **current decode point** (`accepted_len`). For an oracle / aligned + // draft, `suffix_start == accepted_len` exactly — and that's the only + // candidate whose suffix is the correct continuation. Both the previous + // "earliest by scan order" and the (briefly-shipped) "latest by scan + // order" policies were wrong in opposite ways: earliest dropped the + // current-position match in favor of stale repeats from before; latest + // dropped it in favor of stale repeats from after. Distance to + // `accepted_len` is the only signal that works across patterns. + // + // Reserve a smaller "long-suffix diversity" slot in case the ideal-position + // match happens to have a short or empty suffix (e.g. the 3-gram appears + // near the end of the draft and there's not much past it). With a 128-cap, + // that's 96 by position + 32 by suffix length. 
+ let position_quota = max_candidates.saturating_sub(max_candidates / 4).max(1); + let long_quota = max_candidates - position_quota; + let mut selected = vec![false; out.len()]; + let mut capped: Vec = Vec::with_capacity(max_candidates); + + let mut by_distance: Vec = (0..out.len()).collect(); + by_distance.sort_by(|&a, &b| { + let da = out[a].suffix_start.abs_diff(accepted_len); + let db = out[b].suffix_start.abs_diff(accepted_len); + da.cmp(&db) + .then_with(|| out[b].tokens.len().cmp(&out[a].tokens.len())) + .then_with(|| a.cmp(&b)) + }); + for &idx in by_distance.iter().take(position_quota) { + selected[idx] = true; + capped.push(out[idx].clone()); + } + + let mut remaining: Vec = (0..out.len()).filter(|&idx| !selected[idx]).collect(); + remaining.sort_by(|&a, &b| { + out[b] + .tokens + .len() + .cmp(&out[a].tokens.len()) + .then_with(|| a.cmp(&b)) + }); + + for idx in remaining.into_iter().take(long_quota) { + capped.push(out[idx].clone()); + } + + capped +} + +#[cfg(test)] +mod tests { + use super::*; + + fn d(tokens: &[u32]) -> Draft { + Draft::new(tokens.to_vec()) + } + + fn cfg() -> DsvConfig { + DsvConfig { + window_len: 3, + tau: 0.75, + max_candidates_per_step: 32, + max_suffix_len: 256, + ..Default::default() + } + } + + #[test] + fn empty_inputs() { + assert!(collect_candidates(&[], 0, &[], &cfg()).is_empty()); + assert!(collect_candidates(&[1, 2, 3], 3, &[], &cfg()).is_empty()); + } + + #[test] + fn empty_draft_skipped() { + let drafts = vec![d(&[])]; + assert!(collect_candidates(&[1, 2, 3], 3, &drafts, &cfg()).is_empty()); + } + + #[test] + fn single_match_extracts_suffix() { + let drafts = vec![d(&[10, 20, 30, 1, 2, 3, 40, 50, 60])]; + let got = collect_candidates(&[1, 2, 3], 3, &drafts, &cfg()); + assert_eq!(got.len(), 1); + assert_eq!(got[0].draft_idx, 0); + assert_eq!(got[0].suffix_start, 6); + assert_eq!(got[0].tokens, vec![40, 50, 60]); + } + + #[test] + fn no_match_no_candidates() { + let drafts = vec![d(&[10, 20, 30, 40, 50])]; + 
assert!(collect_candidates(&[1, 2, 3], 3, &drafts, &cfg()).is_empty()); + } + + #[test] + fn match_at_end_yields_no_suffix() { + let drafts = vec![d(&[10, 1, 2, 3])]; + // window matches at position 1, but there's no token after it. + assert!(collect_candidates(&[1, 2, 3], 3, &drafts, &cfg()).is_empty()); + } + + #[test] + fn multiple_matches_same_draft() { + let drafts = vec![d(&[1, 2, 3, 7, 1, 2, 3, 8, 9])]; + let got = collect_candidates(&[1, 2, 3], 3, &drafts, &cfg()); + // Two matches at positions 0 and 4, two distinct suffixes. + let suffixes: Vec<_> = got.iter().map(|c| c.tokens.clone()).collect(); + assert!(suffixes.contains(&vec![7, 1, 2, 3, 8, 9])); + assert!(suffixes.contains(&vec![8, 9])); + } + + #[test] + fn matches_across_multiple_drafts() { + let drafts = vec![d(&[1, 2, 3, 4, 5]), d(&[9, 1, 2, 3, 6])]; + let got = collect_candidates(&[1, 2, 3], 3, &drafts, &cfg()); + assert_eq!(got.len(), 2); + let by_draft: Vec<_> = got + .iter() + .map(|c| (c.draft_idx, c.tokens.clone())) + .collect(); + assert!(by_draft.contains(&(0, vec![4, 5]))); + assert!(by_draft.contains(&(1, vec![6]))); + } + + #[test] + fn shorter_window_when_history_short() { + // accepted_tail shorter than cfg.window_len: window shrinks gracefully. 
+ let drafts = vec![d(&[7, 1, 2, 3, 4])]; + let got = collect_candidates(&[1], 1, &drafts, &cfg()); + assert_eq!(got.len(), 1); + assert_eq!(got[0].tokens, vec![2, 3, 4]); + } + + #[test] + fn empty_history_uses_each_draft_prefix() { + let cfg = DsvConfig { + max_suffix_len: 3, + ..cfg() + }; + let drafts = vec![d(&[10, 11, 12, 13]), d(&[20, 21])]; + let got = collect_candidates(&[], 0, &drafts, &cfg); + assert_eq!(got.len(), 2); + assert_eq!(got[0].tokens, vec![10, 11, 12]); // truncated by max_suffix_len + assert_eq!(got[1].tokens, vec![20, 21]); + } + + #[test] + fn suffix_length_capped() { + let cfg = DsvConfig { + max_suffix_len: 2, + ..cfg() + }; + let drafts = vec![d(&[1, 2, 3, 7, 8, 9, 10, 11])]; + let got = collect_candidates(&[1, 2, 3], 3, &drafts, &cfg); + assert_eq!(got.len(), 1); + assert_eq!(got[0].tokens, vec![7, 8]); + } + + #[test] + fn cap_prefers_match_at_current_decode_position() { + // Position-aware cap: with `accepted_len = 100`, the candidate whose + // `suffix_start` is closest to 100 wins the position-quota slot, even + // if it's not the longest or first/last in scan order. This is the + // 2026-05-14 fix that replaced the earlier "head" and "latest" cap + // heuristics — both were wrong for opposite reasons (each dropped + // the current-position match in favor of stale repeats on the wrong + // side). For oracle / well-aligned drafts, the current decode point + // is the only signal that picks the correct continuation across + // patterns. + // + // Setup: a single 200-token draft where the 3-gram `[1, 2, 3]` + // appears at positions 0, 50, 100, 150 — each followed by a unique + // marker (10, 50, 100, 150). With cap=1 and accepted_len=103 + // (suffix would start at draft position 103), the match at + // suffix_start=103 (closest to 103) must win. 
+ let cfg = DsvConfig { + max_candidates_per_step: 1, + max_suffix_len: 10, + ..cfg() + }; + let mut tokens = vec![99u32; 200]; + // place markers at positions: window starts at 0, 50, 100, 150 + // so suffix_starts are 3, 53, 103, 153. + for (i, &start) in [0usize, 50, 100, 150].iter().enumerate() { + tokens[start] = 1; + tokens[start + 1] = 2; + tokens[start + 2] = 3; + // distinctive marker as the first suffix token at start+3. + tokens[start + 3] = 1000 + (i as u32); + } + let drafts = vec![d(&tokens)]; + // accepted_len = 103 means the current decode point is draft pos 103. + // Closest match suffix_start = 103 → marker 1002. + let got = collect_candidates(&[1, 2, 3], 103, &drafts, &cfg); + assert_eq!(got.len(), 1, "cap=1 should keep exactly one candidate"); + assert_eq!( + got[0].tokens.first(), + Some(&1002), + "must keep match closest to accepted_len=103, got tokens={:?}", + got[0].tokens + ); + } + + #[test] + fn cap_at_repeated_3gram_keeps_current_position() { + // Regression test (refined 2026-05-14). Pattern `[X, Y, Z]` repeats + // 20 times in a row, each followed by a unique marker. With + // `accepted_len` set to the position WHERE WE ARE in the draft (here + // simulating "we just accepted up to and including marker 9, so we + // expect marker 10 next"), the position-aware cap must keep the + // marker-10 match, not the earliest (marker 0 — original head bug) + // or the latest (marker 19 — the intermediate "latest" bug). + let cfg = DsvConfig { + max_candidates_per_step: 1, + ..cfg() + }; + let mut tokens: Vec = Vec::new(); + for i in 0..20u32 { + tokens.extend_from_slice(&[100, 101, 102, 200 + i]); + } + let drafts = vec![d(&tokens)]; + // Each unit is 4 tokens. The 3-gram at unit i starts at position + // 4*i; its suffix begins at position 4*i + 3 (marker 200 + i). + // Simulating "we have accepted up to position 4*10 + 3 = 43, so the + // next correct token is marker 210 at suffix_start = 43." 
+ let got = collect_candidates(&[100, 101, 102], 43, &drafts, &cfg); + assert_eq!(got.len(), 1); + assert_eq!( + got[0].tokens.first(), + Some(&210), + "must keep current-position match (marker 210), got {:?}", + got[0].tokens + ); + } + + #[test] + fn candidate_count_zero_drops_all_candidates() { + let cfg = DsvConfig { + max_candidates_per_step: 0, + ..cfg() + }; + let drafts = vec![d(&[1, 2, 3, 9])]; + assert!(collect_candidates(&[1, 2, 3], 3, &drafts, &cfg).is_empty()); + } + + #[test] + fn cold_start_disabled_yields_no_candidates() { + // With cold_start_full_draft = false and an empty accepted tail the + // matcher must produce *no* candidates. + let cfg = DsvConfig { + cold_start_full_draft: false, + ..cfg() + }; + let drafts = vec![d(&[10, 11, 12, 13])]; + assert!(collect_candidates(&[], 0, &drafts, &cfg).is_empty()); + } +} diff --git a/oar-ocr-vl/src/hsd/mod.rs b/oar-ocr-vl/src/hsd/mod.rs new file mode 100644 index 0000000..0a588d4 --- /dev/null +++ b/oar-ocr-vl/src/hsd/mod.rs @@ -0,0 +1,45 @@ +//! Hierarchical Speculative Decoding (HSD) for VLM-based document parsers. +//! +//! Reference: Liao et al., "HSD: Training-Free Acceleration for Document Parsing +//! Vision-Language Model with Hierarchical Speculative Decoding" (arXiv:2602.12957). +//! +//! ## High-level flow +//! +//! 1. A lightweight pipeline drafter (e.g., PP-DocLayout + PP-OCRv5) emits +//! a region partition `R = {r_i}` and one or more coarse text candidates +//! for each region. +//! 2. Each region candidate is tokenized with the **target VLM's** tokenizer, yielding +//! [`RegionDraft`]. Stage 2 keeps page-level drafts as an unordered collection +//! of [`Draft`] values, one entry per region/output. +//! 3. Stage 1 (region-level): for each `r_i`, the target VLM runs `SpecDecode` on the +//! cropped region image and the corresponding candidate set, in parallel across regions. +//! 4. Stage 2 (page-level): the Stage-1 outputs are passed as the paper's unordered +//! set `Y^pg = {y^(i)}`. 
Each output remains a separate [`Draft`], so the +//! sliding-window matcher scans each draft independently before the full-page +//! verification restores global coherence. +//! +//! ## Module layout +//! +//! - [`types`] — shared data structures (drafts, configs, stats). +//! - [`drafting`] — layout-output to HSD draft conversion helpers. +//! - [`kv_trim`] — append/trim/gather KV-cache wrapper used by backends. +//! - [`matching`] — draft-target matching with a sliding reference window. +//! - [`prefix_tree`] — prefix-tree construction over candidate suffixes. +//! - [`verify`] — DSV `SpecDecode` operator (model-side hooks live here). +//! - [`backend_util`] — mechanical helpers (pos-id / keep-index construction) +//! shared by every `SpecBackend` impl. + +pub mod backend_util; +pub mod drafting; +pub mod matching; +pub mod prefix_tree; +pub mod types; +pub mod verify; + +// `kv_trim` is declared at crate root (see `lib.rs`) so the cache type stays +// available without the `hsd` feature; re-export the type here so +// HSD-internal callers keep using `crate::hsd::TrimmableKvCache`. +pub use crate::kv_trim::TrimmableKvCache; +pub use types::{ + AcceptStats, Draft, DsvConfig, HsdConfig, HsdStats, RegionDraft, RegionStageStats, StageStats, +}; diff --git a/oar-ocr-vl/src/hsd/prefix_tree.rs b/oar-ocr-vl/src/hsd/prefix_tree.rs new file mode 100644 index 0000000..632ea97 --- /dev/null +++ b/oar-ocr-vl/src/hsd/prefix_tree.rs @@ -0,0 +1,198 @@ +//! Prefix tree for parallel verification of multiple candidate suffixes +//! (paper §3.2 / Fig. 2b). +//! +//! Each candidate suffix produced by [`super::matching::collect_candidates`] is +//! inserted into a tree that shares common prefixes. Common prefixes are merged +//! so a single packed forward pass through the model can verify all candidates +//! at once (under a tree-ancestry attention mask, built in [`super::verify`]). +//! +//! ## Indexing convention +//! +//! 
- The root is implicit: it carries no token, has no numeric node id, and is +//! represented as parent `None`. +//! - All `Vec`s in [`PrefixTree`] are indexed by non-root node id starting at +//! 0 — i.e. `tokens[0]` is the first inserted child of the root. +//! - `parents[i] = None` means the parent is the root. + +use super::matching::Candidate; +use std::collections::HashMap; + +/// A flattened prefix tree, ready to be turned into a packed token sequence +/// plus a tree-ancestry attention mask. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct PrefixTree { + /// Token id at each non-root node. + pub tokens: Vec, + /// Parent index for each non-root node. `None` ⇒ child of root. + pub parents: Vec>, + /// Distance from root, counted in tokens (root-children have depth 1). + pub depths: Vec, + /// `leaf_for[i] = Some((cand_idx, depth))` if a candidate ends exactly at + /// node `i`. If multiple candidates end here, the deepest candidate wins; + /// equal-depth ties keep the first inserted candidate. + pub leaf_for: Vec>, +} + +impl PrefixTree { + pub fn num_nodes(&self) -> usize { + self.tokens.len() + } + + pub fn is_empty(&self) -> bool { + self.tokens.is_empty() + } + + /// Walk from `node` up to the root, collecting tokens in root → node order. + /// Useful for diagnostics and for reconstructing accepted segments after + /// greedy traversal. + pub fn path_tokens(&self, node: usize) -> Vec { + let mut path = Vec::with_capacity(self.depths[node] as usize); + let mut cur = Some(node); + while let Some(i) = cur { + path.push(self.tokens[i]); + cur = self.parents[i]; + } + path.reverse(); + path + } + + /// Indices of `node`'s direct children. + /// + /// This is a linear scan over the flattened tree. Tree sizes are bounded by + /// `DsvConfig::{max_candidates_per_step,max_suffix_len}`; if larger trees + /// become common, cache adjacency lists during construction. 
+ pub fn children_of(&self, node: Option) -> Vec { + self.parents + .iter() + .enumerate() + .filter_map(|(i, p)| if *p == node { Some(i) } else { None }) + .collect() + } +} + +/// Build a [`PrefixTree`] from candidate suffixes. +/// +/// Candidates are inserted in the order given. Duplicate paths collapse onto +/// the same set of nodes. A candidate that is a prefix of another marks its +/// terminal node as a leaf without breaking the longer candidate's path. +pub fn build_prefix_tree(candidates: &[Candidate]) -> PrefixTree { + let mut tree = PrefixTree::default(); + // (parent_node, token) -> child_node, where parent_node == None means root. + let mut child_map: HashMap<(Option, u32), usize> = HashMap::new(); + + for (cand_idx, cand) in candidates.iter().enumerate() { + let mut parent: Option = None; + let mut depth: u32 = 0; + for &tok in &cand.tokens { + depth += 1; + let node = match child_map.get(&(parent, tok)) { + Some(&existing) => existing, + None => { + let new_idx = tree.tokens.len(); + tree.tokens.push(tok); + tree.parents.push(parent); + tree.depths.push(depth); + tree.leaf_for.push(None); + child_map.insert((parent, tok), new_idx); + new_idx + } + }; + parent = Some(node); + } + // Record this candidate's terminal node — deepest wins, equal-depth + // duplicates keep the first inserted candidate. 
+ if let Some(end) = parent { + let prev = tree.leaf_for[end]; + let new_depth = tree.depths[end]; + tree.leaf_for[end] = match prev { + Some((_, d)) if d >= new_depth => prev, + _ => Some((cand_idx, new_depth)), + }; + } + } + + tree +} + +#[cfg(test)] +mod tests { + use super::*; + + fn c(idx: usize, toks: &[u32]) -> Candidate { + Candidate { + draft_idx: idx, + suffix_start: 0, + tokens: toks.to_vec(), + } + } + + #[test] + fn empty_input() { + let t = build_prefix_tree(&[]); + assert!(t.is_empty()); + assert_eq!(t.num_nodes(), 0); + } + + #[test] + fn single_candidate_is_a_chain() { + let t = build_prefix_tree(&[c(0, &[7, 8, 9])]); + assert_eq!(t.tokens, vec![7, 8, 9]); + assert_eq!(t.parents, vec![None, Some(0), Some(1)]); + assert_eq!(t.depths, vec![1, 2, 3]); + assert_eq!(t.leaf_for, vec![None, None, Some((0, 3))]); + } + + #[test] + fn shared_prefix_merges() { + let t = build_prefix_tree(&[c(0, &[1, 2, 3]), c(1, &[1, 2, 4])]); + // Nodes (in insertion order): 1 (root child), 2, 3 (leaf for cand 0), 4 (leaf for cand 1) + assert_eq!(t.tokens, vec![1, 2, 3, 4]); + assert_eq!(t.parents, vec![None, Some(0), Some(1), Some(1)]); + assert_eq!(t.depths, vec![1, 2, 3, 3]); + assert_eq!(t.leaf_for, vec![None, None, Some((0, 3)), Some((1, 3))]); + } + + #[test] + fn candidate_that_is_prefix_of_another() { + let t = build_prefix_tree(&[c(0, &[1, 2]), c(1, &[1, 2, 3])]); + // Node 1 (token 2) is a leaf for cand 0; node 2 (token 3) is a leaf for cand 1. + assert_eq!(t.tokens, vec![1, 2, 3]); + assert_eq!(t.parents, vec![None, Some(0), Some(1)]); + assert_eq!(t.leaf_for[1], Some((0, 2))); + assert_eq!(t.leaf_for[2], Some((1, 3))); + } + + #[test] + fn duplicate_candidate_collapses() { + let t = build_prefix_tree(&[c(0, &[5, 6]), c(1, &[5, 6])]); + assert_eq!(t.tokens, vec![5, 6]); + // Longest-leaf rule keeps cand 0 (first inserted, equal depth). 
/// A single draft sequence: the tokenized output of the pipeline drafter for one
/// region (Stage 1) or one member of the page-level draft set (Stage 2).
///
/// Tokens are encoded with the **target VLM's** tokenizer so the verification path
/// can match prefixes at token granularity without re-tokenization.
#[derive(Debug, Clone)]
pub struct Draft {
    /// Token id sequence as produced by the target VLM's tokenizer.
    pub tokens: Vec<u32>,
}

impl Draft {
    /// Wrap an already-tokenized sequence.
    pub fn new(tokens: Vec<u32>) -> Self {
        Self { tokens }
    }

    /// Number of tokens in the draft.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// `true` when the draft holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}
+/// +/// The crop itself is materialised by the caller (we keep this struct image-agnostic +/// so the algorithm layer stays free of `image` / `candle` dependencies). +#[derive(Debug, Clone)] +pub struct RegionDraft { + /// Region bounding box `[x0, y0, x1, y1]` in original image pixel coordinates. + pub bbox: [f32; 4], + /// Tokenized draft candidates for this region. Stage 1 verifies all + /// candidates together through the same prefix-tree batching path used by + /// page-level DSV. + pub drafts: Vec, + /// Optional reading-order index assigned by the drafter. + pub reading_order: Option, + /// Drafter-side category label (paragraph / table / formula / figure / ...). + pub kind: RegionKind, +} + +/// Coarse semantic kind reported by the drafter. Mirrors the layout categories used +/// by `oar_ocr_core::domain::structure::LayoutElementType` but stays decoupled so the +/// HSD module compiles without pulling layout dependencies. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub enum RegionKind { + Text, + Title, + List, + Table, + Formula, + Figure, + Header, + Footer, + #[default] + Other, +} + +/// Configuration for Decoupled Speculative Verification. +/// +/// Defaults follow the paper's experimental setup (Section 4.3): n=3, τ=0.75, +/// **but** with three engineering deviations enabled by default: +/// +/// - `max_candidates_per_step = 32`, `max_suffix_len = 256` cap candidate-tree +/// width / depth (paper uses *all* matching suffixes — see [`paper_mode`]). +/// - `cold_start_full_draft = true`: on the very first step (empty accepted +/// history) the matcher emits each draft's leading prefix as a single +/// candidate. The paper's window-only formulation would emit no candidates +/// at step 0 and fall back to a `step_one` argmax for the first token. +/// - `strict_at_tau_one = true`: when `τ ≥ 1.0` the driver redirects to +/// [`crate::hsd::verify::spec_decode_strict`] (per-token replay) for a +/// cheap oracle/correctness path. 
Setting this to `false` keeps τ=1.0 on +/// the tree-batched verify path — that matches the paper's "τ as a +/// tolerance threshold" formulation and produces equivalent outputs at +/// higher per-step verify cost. +/// +/// [`paper_mode`]: Self::paper_mode +#[derive(Debug, Clone, Copy)] +pub struct DsvConfig { + /// Reference-window length `n`. Token sequences are matched by sliding a window + /// of this size over each draft. + pub window_len: usize, + /// Acceptance threshold `τ ∈ (0, 1]`. At each tree node, the best child + /// token is accepted iff its log-probability is within `log τ` of the + /// model's unrestricted argmax token. + pub tau: f32, + /// Cap on the number of candidate paths kept per verification step. Protects + /// against pathological prefix trees on very long drafts. Set to + /// `usize::MAX` to match the paper's unbounded candidate set. + pub max_candidates_per_step: usize, + /// Cap on the depth of any candidate suffix considered. Drafts longer than this + /// are truncated in the suffix extraction step. Set to `usize::MAX` to + /// match the paper's unbounded suffix length. + pub max_suffix_len: usize, + /// When `true` (default) and the accepted history is empty, the matcher + /// emits each draft's leading prefix as a candidate so the very first + /// verification step has something to verify. The paper's window-only + /// formulation has no such fallback and would fall back to `step_one`. + pub cold_start_full_draft: bool, + /// When `true` (default) and `τ ≥ 1.0`, the driver replays drafts one + /// token at a time via [`crate::hsd::verify::spec_decode_strict`]. Set to + /// `false` to keep τ=1.0 on the tree-batched verify path (paper-style). 
/// Per-step acceptance bookkeeping. Used to compute Average Acceptance Length (AAL)
/// in the spirit of Leviathan et al. 2023: the number of *draft* tokens accepted at
/// each verification step (excludes the bonus token sampled by the target).
#[derive(Debug, Clone, Default)]
pub struct AcceptStats {
    /// Per-step accepted draft-token counts (`α_k` in the paper's notation).
    pub per_step_accepted: Vec<u32>,
    /// Number of verification steps (`N`).
    pub num_steps: u32,
    /// Number of fallback steps where the prefix tree was empty / fully rejected.
    pub num_fallbacks: u32,
}

impl AcceptStats {
    /// Average Acceptance Length over recorded steps.
    ///
    /// Returns `0.0` when no step has been recorded. The per-step counts are
    /// accumulated in `u64` so the sum cannot overflow `u32` on long runs.
    pub fn aal(&self) -> f32 {
        if self.num_steps == 0 {
            return 0.0;
        }
        let sum: u64 = self.per_step_accepted.iter().map(|&a| u64::from(a)).sum();
        sum as f32 / self.num_steps as f32
    }

    /// Record one verification step that accepted `accepted` draft tokens.
    pub fn record(&mut self, accepted: u32) {
        self.per_step_accepted.push(accepted);
        self.num_steps += 1;
    }

    /// Record one fallback step: it counts as a step with zero accepted
    /// draft tokens and bumps the fallback counter.
    pub fn record_fallback(&mut self) {
        self.per_step_accepted.push(0);
        self.num_steps += 1;
        self.num_fallbacks += 1;
    }

    /// Merge `other` into `self`: its per-step counts are appended after the
    /// existing ones and its counters are added.
    pub fn add_assign(&mut self, other: Self) {
        self.per_step_accepted.extend(other.per_step_accepted);
        self.num_steps += other.num_steps;
        self.num_fallbacks += other.num_fallbacks;
    }
}
+ pub fallback_argmax: Duration, + pub verify_tree_calls: u32, + pub step_one_calls: u32, + pub fallback_argmax_calls: u32, + pub candidate_steps: u32, + pub candidates_total: u64, + pub candidates_max: u32, + pub empty_tree_calls: u32, + pub rejected_tree_calls: u32, + pub accepted_tree_calls: u32, + pub tree_nodes_total: u64, + pub tree_nodes_max: u32, +} + +impl SpecDecodeStats { + pub fn add_assign(&mut self, other: &Self) { + self.candidate_build += other.candidate_build; + self.verify_tree += other.verify_tree; + self.traverse += other.traverse; + self.commit += other.commit; + self.step_one += other.step_one; + self.fallback_argmax += other.fallback_argmax; + self.verify_tree_calls += other.verify_tree_calls; + self.step_one_calls += other.step_one_calls; + self.fallback_argmax_calls += other.fallback_argmax_calls; + self.candidate_steps += other.candidate_steps; + self.candidates_total += other.candidates_total; + self.candidates_max = self.candidates_max.max(other.candidates_max); + self.empty_tree_calls += other.empty_tree_calls; + self.rejected_tree_calls += other.rejected_tree_calls; + self.accepted_tree_calls += other.accepted_tree_calls; + self.tree_nodes_total += other.tree_nodes_total; + self.tree_nodes_max = self.tree_nodes_max.max(other.tree_nodes_max); + } + + pub fn avg_candidates(&self) -> f32 { + if self.candidate_steps == 0 { + 0.0 + } else { + self.candidates_total as f32 / self.candidate_steps as f32 + } + } + + pub fn avg_tree_nodes(&self) -> f32 { + if self.verify_tree_calls == 0 { + 0.0 + } else { + self.tree_nodes_total as f32 / self.verify_tree_calls as f32 + } + } +} + +/// Wall-clock and counter stats for one HSD stage (Stage 1 or Stage 2). +#[derive(Debug, Clone, Default)] +pub struct StageStats { + pub vision_prefill: Duration, + pub draft_prep: Duration, + pub decode: Duration, + pub accept: AcceptStats, + /// Total *target* tokens emitted by this stage. 
+ pub emitted_tokens: u32, + /// Number of forward passes through the LLM (prefill counted as 1). + pub forward_passes: u32, + /// Shared HSD driver internals for profiler attribution. + pub dsv: SpecDecodeStats, +} + +impl StageStats { + pub fn add_assign(&mut self, other: Self) { + self.vision_prefill += other.vision_prefill; + self.draft_prep += other.draft_prep; + self.decode += other.decode; + self.emitted_tokens += other.emitted_tokens; + self.forward_passes += other.forward_passes; + self.dsv.add_assign(&other.dsv); + self.accept.add_assign(other.accept); + } +} + +/// Stage-1 stats for one verified region, retained for region-kind diagnostics. +#[derive(Debug, Clone, Default)] +pub struct RegionStageStats { + pub kind: RegionKind, + pub stats: StageStats, +} + +/// End-to-end HSD timing breakdown for one page, plus per-stage details. +#[derive(Debug, Clone, Default)] +pub struct HsdStats { + pub stage1: StageStats, + pub stage2: StageStats, + /// Per-region Stage-1 stats, one entry per region that actually ran HSD. + pub stage1_regions: Vec, + /// Drafter pipeline wall-clock (layout + region recognition). + pub drafter: Duration, +} + +impl HsdStats { + /// Total wall-clock from page-image input to final parsing result. + pub fn total(&self) -> Duration { + self.drafter + + self.stage1.vision_prefill + + self.stage1.draft_prep + + self.stage1.decode + + self.stage2.vision_prefill + + self.stage2.draft_prep + + self.stage2.decode + } + + /// Combined AAL across both stages (weighted by step count). 
+ pub fn overall_aal(&self) -> f32 { + let total_steps = self.stage1.accept.num_steps + self.stage2.accept.num_steps; + if total_steps == 0 { + 0.0 + } else { + let sum: u32 = self + .stage1 + .accept + .per_step_accepted + .iter() + .chain(self.stage2.accept.per_step_accepted.iter()) + .sum(); + sum as f32 / total_steps as f32 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_hsd_config_enables_both_stages() { + // Locks the paper-aligned default. Stage 2 is the load-bearing pass + // (paper Table 7: Stage-1-only on dots.ocr drops 88.41 → 70.47 on + // OmniDocBench). Anyone flipping enable_stage2 = false in Default + // silently loses ~18 accuracy points on multi-block pages. + let cfg = HsdConfig::default(); + assert!(cfg.enable_stage1, "Stage 1 should default to enabled"); + assert!(cfg.enable_stage2, "Stage 2 should default to enabled"); + assert!(cfg.max_page_tokens > 0); + assert!(cfg.max_region_tokens > 0); + } + + #[test] + fn default_dsv_config_matches_paper_n_and_tau() { + // Paper §4.3: n = 3, τ = 0.75. The engineering caps deviate (paper has + // no caps), but n and τ must match the paper out of the box. + let cfg = DsvConfig::default(); + assert_eq!(cfg.window_len, 3, "paper §4.3 sets n = 3"); + assert!( + (cfg.tau - 0.75).abs() < 1e-6, + "paper §4.3 sets τ = 0.75, got {}", + cfg.tau + ); + } +} diff --git a/oar-ocr-vl/src/hsd/verify.rs b/oar-ocr-vl/src/hsd/verify.rs new file mode 100644 index 0000000..93962ac --- /dev/null +++ b/oar-ocr-vl/src/hsd/verify.rs @@ -0,0 +1,747 @@ +//! Decoupled Speculative Verification (DSV) — `SpecDecode` operator. +//! +//! This is the model-facing surface of HSD. Concrete VLM backends provide a +//! thin adapter implementing [`SpecBackend`]; the verification driver in this +//! file ([`spec_decode`]) is shared across all backends. +//! +//! ## Algorithm (paper §3.2) +//! +//! At every iteration `spec_decode` does the following: +//! +//! 1. 
Take the most recent `n = cfg.window_len` accepted tokens as the +//! reference window and slide it over each draft. Extract the suffixes that +//! follow each match ([`super::matching::collect_candidates`]). +//! 2. Merge the suffixes into a prefix tree +//! ([`super::prefix_tree::build_prefix_tree`]). +//! 3. If the tree is empty, fall back to a single-token greedy step. Otherwise +//! call [`SpecBackend::verify_tree`] which appends the packed tree tokens to +//! the KV cache under a tree-ancestry mask and returns per-node log-probs. +//! 4. Greedily traverse the tree from the root. At each node, choose the child +//! with the highest target-model log-probability and accept it iff it is +//! within `log τ` of the model's unrestricted argmax token (paper eq. 11). +//! Stop at a leaf or upon rejection. +//! 5. Append the accepted path tokens to the output, then append the greedy +//! bonus token `û` from the terminal distribution. +//! 6. [`SpecBackend::commit_verify`] gathers the KV cache to keep only the +//! accepted-path positions, then `step_one(û)` populates the cache for the +//! bonus token and produces the next-iteration log-probs. +//! +//! [`spec_decode_strict`] is a debugging variant: it replays draft tokens one +//! by one through [`SpecBackend::step_one`] instead of using tree verification. + +use candle_core::{DType, Device, Result as CandleResult, Tensor}; +use std::time::Instant; + +use super::matching::collect_candidates; +use super::prefix_tree::{PrefixTree, build_prefix_tree}; +use super::types::{AcceptStats, Draft, DsvConfig, SpecDecodeStats}; + +/// Backend adapter for HSD verification. Each VLM backend (HunyuanOCR, +/// PaddleOCR-VL, MinerU, …) implements this trait once, and reuses the +/// [`spec_decode`] driver. +/// +/// All log-probability tensors must be **post log-softmax in F32**. The driver +/// only reads them via `to_vec1::()`; intermediate compute can run in +/// BF16/F16 inside the backend, but the surface contract is F32. 
+pub trait SpecBackend { + /// Decode a single token and advance the KV cache by one. Returns the + /// next-token log-probability tensor of shape `(vocab,)`. + fn step_one(&mut self, token: u32) -> CandleResult; + + /// Run a verification forward pass over a packed prefix tree. + /// + /// Implementations must: + /// - Append `tree.num_nodes()` tokens to the KV cache. + /// - Use a tree-ancestry attention mask + /// ([`crate::attention::create_tree_attention_mask`]) so each node only + /// sees the accepted prefix and its own ancestor chain. + /// - Use position ids `accepted_kv_len + tree.depths[i]` for node `i`. + /// - Return `(num_nodes, vocab)` log-probabilities. + /// + /// The KV cache is left in the post-append state; the driver will call + /// [`SpecBackend::commit_verify`] to gather it down to the accepted path. + fn verify_tree(&mut self, tree: &PrefixTree) -> CandleResult; + + /// Gather the KV cache so it keeps only the accepted prefix plus the + /// supplied path-node positions. An empty `accepted_path` means trim back + /// to the prefix (full rejection). + /// + /// `accepted_path` is given as packed-tree indices in walk order + /// (root → leaf). Each implementation is responsible for translating these + /// into KV-cache absolute positions (`prefix_kv_len + idx`). + fn commit_verify(&mut self, accepted_path: &[usize]) -> CandleResult<()>; + + /// Returns true if `tok` is any of the backend's end-of-generation + /// tokens. Backends typically have several (eod, eos, end-of-turn marker) + /// — returning a single id and comparing with `==` would let HSD generate + /// past a real stop token if the model emits a *different* stop than the + /// one we know about. + fn is_eos(&self, tok: u32) -> bool; +} + +/// Move a 1-D log-prob tensor to CPU/F32 for cheap host-side scanning. +fn lp_to_host(t: &Tensor) -> CandleResult> { + t.to_dtype(DType::F32)?.to_device(&Device::Cpu)?.to_vec1() +} + +/// Argmax over a host-side log-prob slice. 
Returns `(token_id, log_prob)`. +fn argmax_host(v: &[f32]) -> (u32, f32) { + let mut best_idx = 0usize; + let mut best_val = f32::NEG_INFINITY; + for (i, &x) in v.iter().enumerate() { + if x > best_val { + best_val = x; + best_idx = i; + } + } + (best_idx as u32, best_val) +} + +/// GPU-side argmax that copies only the resulting scalar token id back to the +/// host. Replaces `lp_to_host` + `argmax_host` on the fallback / bonus paths, +/// where we don't need the full distribution — only the argmax. On long +/// sequences with many empty-tree fallbacks this dominates HSD wall time: +/// per-call cost drops from ~21 ms (D2H of `vocab × 4 B` bytes) to ~1 ms +/// (kernel launch + 4-byte D2H). nsys: `cuMemcpyDtoHAsync_v2` was 87 % of +/// HSD time on long-output pages before this change. +fn argmax_on_device(t: &Tensor) -> CandleResult { + t.argmax(candle_core::D::Minus1)? + .to_dtype(DType::U32)? + .to_scalar::() +} + +/// Greedy traversal of the prefix tree under the τ-tolerance acceptance test. +/// +/// Returns +/// - `path`: tree-node indices visited along the accepted path, root → leaf. +/// - `terminal_lp`: log-prob distribution at the final node (used to sample +/// the greedy bonus token û). Pre-materialised on the host so the caller can +/// `argmax_host` it without another GPU→CPU sync. +/// +/// We pull the entire `(num_nodes, vocab)` log-prob matrix to the host in a +/// single transfer at the top, plus the `(vocab,)` root distribution. Per +/// nsys, the previous per-step `cuMemcpyDtoHAsync` calls (one per traversal +/// step, each of size `4 × vocab`) accounted for ~87 % of HSD wall time on +/// long-output pages because each transfer paid the full GPU sync latency. A +/// single bulk transfer is dominated by PCIe bandwidth regardless of how many +/// traversal steps follow. 
+fn greedy_traverse( + tree: &PrefixTree, + node_logprobs: &Tensor, + root_logprobs: &Tensor, + tau: f32, +) -> CandleResult<(Vec, Vec)> { + let log_tau = tau.ln(); + let mut path: Vec = Vec::new(); + let mut s: Option = None; + + // Bulk D2H copies: root's (vocab,) and the full (num_nodes, vocab) matrix. + let root_host: Vec = lp_to_host(root_logprobs)?; + let vocab = root_host.len(); + let num_nodes = tree.num_nodes(); + let nodes_host: Vec = if num_nodes == 0 { + Vec::new() + } else { + // Reshape rather than `flatten_all` so a backend handing us a 1-D + // or batched (B, N, V) tensor still yields a flat `num_nodes * vocab` + // buffer — and we get an explicit shape error if the totals disagree + // (which would otherwise show up as an out-of-bounds slice in `row`). + node_logprobs + .reshape((num_nodes, vocab))? + .to_dtype(DType::F32)? + .to_device(&Device::Cpu)? + .flatten_all()? + .to_vec1()? + }; + + let row = |node_idx: usize| -> &[f32] { + let start = node_idx * vocab; + &nodes_host[start..start + vocab] + }; + + let mut cur_lp_view: &[f32] = &root_host[..]; + let mut terminal_lp_owned: Vec = root_host.clone(); + + loop { + let children = tree.children_of(s); + if children.is_empty() { + break; + } + let (_, u_hat_lp) = argmax_host(cur_lp_view); + + let mut best_node: Option = None; + let mut best_lp = f32::NEG_INFINITY; + for &c in &children { + let tok = tree.tokens[c] as usize; + let lp = cur_lp_view.get(tok).copied().unwrap_or(f32::NEG_INFINITY); + if lp > best_lp { + best_lp = lp; + best_node = Some(c); + } + } + let best = best_node.expect("non-empty children list"); + let margin = best_lp - u_hat_lp; + + if margin >= log_tau { + path.push(best); + cur_lp_view = row(best); + s = Some(best); + } else { + break; + } + } + + // Save the terminal distribution for the bonus-token sample. If we + // accepted at least one tree node, the terminal distribution lives in + // `nodes_host`; otherwise it's the root. 
+ if let Some(last) = path.last().copied() { + terminal_lp_owned = row(last).to_vec(); + } + Ok((path, terminal_lp_owned)) +} + +/// `SpecDecode(p_θ, z, Ỹ)` — the page-/region-level verification driver. +/// +/// `initial_logprobs` is the prefill's last-position distribution (shape +/// `(vocab,)`). `drafts` may contain Stage-1 region drafts (one verify per +/// region) or Stage-2 page drafts (one verify per page). +pub fn spec_decode( + backend: &mut B, + drafts: &[Draft], + initial_logprobs: Tensor, + max_new_tokens: usize, + cfg: &DsvConfig, + stats: &mut AcceptStats, + timings: &mut SpecDecodeStats, +) -> CandleResult> { + if cfg.tau >= 1.0 && cfg.strict_at_tau_one { + return spec_decode_strict( + backend, + drafts, + initial_logprobs, + max_new_tokens, + cfg, + stats, + timings, + ); + } + + let mut accepted: Vec = Vec::with_capacity(max_new_tokens); + let mut cur_logprobs = initial_logprobs; + + while accepted.len() < max_new_tokens { + // 1. Build candidates from the most recent accepted-token window. + let n = accepted.len().min(cfg.window_len); + let tail = &accepted[accepted.len() - n..]; + let t_build = Instant::now(); + let tree = { + let candidates = collect_candidates(tail, accepted.len(), drafts, cfg); + let candidate_count = candidates.len() as u32; + timings.candidate_steps += 1; + timings.candidates_total += candidate_count as u64; + timings.candidates_max = timings.candidates_max.max(candidate_count); + build_prefix_tree(&candidates) + }; + timings.candidate_build += t_build.elapsed(); + + // 2. Empty tree → fall back to a single-token greedy step. + // Argmax on-device so we only D2H a single scalar instead of the + // full vocab. On pages where most steps are fallbacks this saves + // ~20 ms / iteration vs the previous host-side argmax path. 
+ if tree.is_empty() { + timings.empty_tree_calls += 1; + let t_argmax = Instant::now(); + let u_hat = argmax_on_device(&cur_logprobs)?; + timings.fallback_argmax += t_argmax.elapsed(); + timings.fallback_argmax_calls += 1; + // Mirror baseline `generate_tokens_internal`: EOS terminates the + // loop without being appended to the output sequence. The + // tokenizer would strip it on decode anyway, but keeping it in + // `accepted` makes the τ=1.0 oracle check (raw token equality) + // diverge by one token at the tail. The fallback step itself + // still ran, so `num_fallbacks` counts it. + stats.record_fallback(); + if backend.is_eos(u_hat) { + break; + } + accepted.push(u_hat); + if accepted.len() >= max_new_tokens { + break; + } + let t_step = Instant::now(); + cur_logprobs = backend.step_one(u_hat)?; + timings.step_one += t_step.elapsed(); + timings.step_one_calls += 1; + continue; + } + + // 3. Verify the tree in one packed forward pass. + let nodes = tree.num_nodes() as u32; + let t_verify = Instant::now(); + let node_logprobs = backend.verify_tree(&tree)?; + timings.verify_tree += t_verify.elapsed(); + timings.verify_tree_calls += 1; + timings.tree_nodes_total += nodes as u64; + timings.tree_nodes_max = timings.tree_nodes_max.max(nodes); + + // 4. Greedy traversal under the τ test (returns the accepted path + // and the *host-side* terminal distribution — only one D2H copy + // per verify step now, vs one per traversal step before). + let t_traverse = Instant::now(); + let (path, term_host) = greedy_traverse(&tree, &node_logprobs, &cur_logprobs, cfg.tau)?; + timings.traverse += t_traverse.elapsed(); + + // 5. Truncate the accepted path to the remaining token budget *before* + // committing so the KV cache and `accepted` stay in lockstep. The + // tree may have accepted more nodes than we're allowed to emit; we + // keep at most `remaining` tokens (or the path's natural EOS, which + // terminates this step). 
+ let remaining = max_new_tokens.saturating_sub(accepted.len()); + let mut take = 0usize; + let mut path_eos = false; + for &node_idx in &path { + if take >= remaining { + break; + } + let tok = tree.tokens[node_idx]; + if backend.is_eos(tok) { + path_eos = true; + take += 1; // include the EOS slot in the commit length + break; + } + take += 1; + } + let path_cap = &path[..take]; + + // 6. Commit only the path prefix we'll actually emit. + let t_commit = Instant::now(); + backend.commit_verify(path_cap)?; + timings.commit += t_commit.elapsed(); + + // 7. Append accepted-path tokens (EOS is consumed without emit). AAL + // is recorded as the count of *draft* tokens accepted in this step + // (paper §4.2 / Leviathan et al. 2023), excluding the bonus û. + for &node_idx in path_cap { + let tok = tree.tokens[node_idx]; + if backend.is_eos(tok) { + break; + } + accepted.push(tok); + } + stats.record(path_cap.len() as u32); + if path_cap.is_empty() { + timings.rejected_tree_calls += 1; + } else { + timings.accepted_tree_calls += 1; + } + if path_eos || accepted.len() >= max_new_tokens { + break; + } + + // 8. Take the greedy bonus token û from the terminal distribution + // (already on host from greedy_traverse). + let (u_hat, _) = argmax_host(&term_host); + if backend.is_eos(u_hat) { + break; + } + accepted.push(u_hat); + if accepted.len() >= max_new_tokens { + break; + } + + // 9. Step once on û to populate KV and seed the next iteration. 
+ let t_step = Instant::now(); + cur_logprobs = backend.step_one(u_hat)?; + timings.step_one += t_step.elapsed(); + timings.step_one_calls += 1; + } + + Ok(accepted) +} + +fn spec_decode_strict( + backend: &mut B, + drafts: &[Draft], + initial_logprobs: Tensor, + max_new_tokens: usize, + cfg: &DsvConfig, + stats: &mut AcceptStats, + timings: &mut SpecDecodeStats, +) -> CandleResult> { + let mut accepted: Vec = Vec::with_capacity(max_new_tokens); + let mut cur_logprobs = initial_logprobs; + + while accepted.len() < max_new_tokens { + let n = accepted.len().min(cfg.window_len); + let tail = &accepted[accepted.len() - n..]; + let t_build = Instant::now(); + let tree = { + let candidates = collect_candidates(tail, accepted.len(), drafts, cfg); + let candidate_count = candidates.len() as u32; + timings.candidate_steps += 1; + timings.candidates_total += candidate_count as u64; + timings.candidates_max = timings.candidates_max.max(candidate_count); + build_prefix_tree(&candidates) + }; + timings.candidate_build += t_build.elapsed(); + + let mut step_accepted = 0u32; + // Tracks whether the inner loop terminated by hitting EOS. We can no + // longer derive this from `accepted.last()` because the driver no + // longer pushes EOS into `accepted` (matches baseline; see the + // corresponding rework in the main `spec_decode` path). + let mut step_eos = false; + let mut s: Option = None; + loop { + let children = tree.children_of(s); + if children.is_empty() || accepted.len() >= max_new_tokens { + break; + } + + let t_traverse = Instant::now(); + // Strict τ=1.0 is an oracle correctness check: it must agree with + // the baseline greedy path token-for-token, including tie-break + // direction. Baseline (`argmax_with_repetition_penalty`) and + // `argmax_host` both pick the *first* index via strict `>`; candle's + // CUDA `argmax` may break ties differently (block-reduction order), + // so the host-side path stays for correctness. 
The bulk D2H cost + // here is acceptable — strict mode is a debugging tool, not a + // production decode path. + let best_child = { + let cur_host = lp_to_host(&cur_logprobs)?; + let (u_hat, _) = argmax_host(&cur_host); + children + .iter() + .copied() + .find(|&child| tree.tokens[child] == u_hat) + }; + timings.traverse += t_traverse.elapsed(); + + let Some(child) = best_child else { + break; + }; + + let tok = tree.tokens[child]; + if backend.is_eos(tok) { + step_accepted += 1; + step_eos = true; + break; + } + accepted.push(tok); + step_accepted += 1; + if accepted.len() >= max_new_tokens { + break; + } + let t_step = Instant::now(); + cur_logprobs = backend.step_one(tok)?; + timings.step_one += t_step.elapsed(); + timings.step_one_calls += 1; + s = Some(child); + } + + if step_accepted > 0 { + timings.accepted_tree_calls += 1; + stats.record(step_accepted); + if step_eos || accepted.len() >= max_new_tokens { + break; + } + continue; + } + + if tree.is_empty() { + timings.empty_tree_calls += 1; + } else { + timings.rejected_tree_calls += 1; + } + let t_argmax = Instant::now(); + let u_hat = argmax_on_device(&cur_logprobs)?; + timings.fallback_argmax += t_argmax.elapsed(); + timings.fallback_argmax_calls += 1; + if backend.is_eos(u_hat) { + stats.record_fallback(); + break; + } + accepted.push(u_hat); + stats.record_fallback(); + if accepted.len() >= max_new_tokens { + break; + } + let t_step = Instant::now(); + cur_logprobs = backend.step_one(u_hat)?; + timings.step_one += t_step.elapsed(); + timings.step_one_calls += 1; + } + + Ok(accepted) +} + +#[cfg(test)] +mod tests { + use super::*; + use candle_core::Device; + + /// Deterministic oracle backend used to validate the driver in isolation + /// from any model. Logprobs are concentrated (~1.0) at the next "true" + /// token, so: + /// - Drafts that match the oracle are fully accepted. + /// - Drafts that diverge are rejected and the driver falls through to + /// single-token decoding. 
+ struct OracleBackend { + vocab: usize, + eos: u32, + oracle: Vec, + /// Number of post-prefill tokens already "decoded" (== KV positions + /// past the prefill prefix). The next step_one / verify_tree result + /// derives from `oracle[position + d]` for some depth d ≥ 1. + position: usize, + /// Tally of `step_one` and `verify_tree` calls — useful in tests to + /// confirm parallel verification actually saved forward passes. + n_step_one: usize, + n_verify_tree: usize, + } + + impl OracleBackend { + fn new(vocab: usize, eos: u32, oracle: Vec) -> Self { + Self { + vocab, + eos, + oracle, + position: 0, + n_step_one: 0, + n_verify_tree: 0, + } + } + + fn lp_for(&self, tok: u32) -> Tensor { + // Concentrated distribution: target token ≈ log 1, others ≈ log 0. + let mut v = vec![-100.0f32; self.vocab]; + v[tok as usize] = 0.0; + Tensor::from_vec(v, (self.vocab,), &Device::Cpu).unwrap() + } + + fn next_oracle(&self, offset: usize) -> u32 { + self.oracle + .get(self.position + offset) + .copied() + .unwrap_or(self.eos) + } + } + + impl SpecBackend for OracleBackend { + fn step_one(&mut self, _token: u32) -> CandleResult { + self.n_step_one += 1; + self.position += 1; + let nxt = self.next_oracle(0); + Ok(self.lp_for(nxt)) + } + + fn verify_tree(&mut self, tree: &PrefixTree) -> CandleResult { + self.n_verify_tree += 1; + let n = tree.num_nodes(); + let mut buf = vec![-100.0f32; n * self.vocab]; + for i in 0..n { + let depth = tree.depths[i] as usize; + let nxt = self.next_oracle(depth); + buf[i * self.vocab + nxt as usize] = 0.0; + } + Tensor::from_vec(buf, (n, self.vocab), &Device::Cpu) + } + + fn commit_verify(&mut self, path: &[usize]) -> CandleResult<()> { + self.position += path.len(); + Ok(()) + } + + fn is_eos(&self, tok: u32) -> bool { + tok == self.eos + } + } + + fn cfg() -> DsvConfig { + DsvConfig::default() + } + + fn run_spec_decode( + backend: &mut B, + drafts: &[Draft], + initial_logprobs: Tensor, + max_new_tokens: usize, + cfg: &DsvConfig, + stats: &mut 
AcceptStats, + ) -> CandleResult> { + let mut timings = SpecDecodeStats::default(); + spec_decode( + backend, + drafts, + initial_logprobs, + max_new_tokens, + cfg, + stats, + &mut timings, + ) + } + + #[test] + fn perfect_match_accepts_full_path() { + let oracle = vec![10u32, 20, 30, 40, 50, 99]; + let mut backend = OracleBackend::new(128, 99, oracle.clone()); + let init_lp = backend.lp_for(oracle[0]); + let drafts = vec![Draft::new(oracle.clone())]; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 64, &cfg(), &mut stats).unwrap(); + + // EOS terminates the output without being appended (matches the + // baseline `generate_tokens_internal` contract). + assert_eq!(out, &oracle[..oracle.len() - 1]); + // One verify call accepted the entire chain (including the EOS that + // wasn't emitted). + assert_eq!(backend.n_verify_tree, 1); + // step_one is never invoked because EOS was reached inside the path. + assert_eq!(backend.n_step_one, 0); + assert!(stats.aal() > 0.0); + } + + #[test] + fn no_match_falls_back_to_step_one() { + let oracle = vec![10u32, 20, 30, 40, 99]; + let mut backend = OracleBackend::new(128, 99, oracle.clone()); + let init_lp = backend.lp_for(oracle[0]); + // Drafts that share no tokens with the oracle. + let drafts = vec![Draft::new(vec![1, 2, 3, 4])]; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 64, &cfg(), &mut stats).unwrap(); + + // EOS is consumed by the driver and not emitted. + assert_eq!(out, &oracle[..oracle.len() - 1]); + // Step 0: empty-tail tree from the draft → all rejected → fallback. + // Subsequent steps: tail = [oracle_i], no match in [1,2,3,4] → + // tree empty → fallback, so every subsequent token also falls back. + // The terminating EOS step also falls back: the driver does not push EOS + // into `accepted`, but `record_fallback()` runs before the EOS check, so + // that step is still counted in `num_fallbacks`.
+ assert!( + backend.n_step_one as usize >= oracle.len() - 1, + "n_step_one = {} < {}", + backend.n_step_one, + oracle.len() - 1 + ); + let expected_fallbacks = (oracle.len() as u32) - 1; + assert!( + stats.num_fallbacks >= expected_fallbacks, + "num_fallbacks = {} < {}", + stats.num_fallbacks, + expected_fallbacks + ); + } + + #[test] + fn partial_match_accepts_prefix_then_diverges() { + // Oracle starts the same as the draft, then diverges. + let oracle = vec![10u32, 20, 30, 77, 88, 99]; + let draft_tokens = vec![10u32, 20, 30, 40, 50]; + let mut backend = OracleBackend::new(128, 99, oracle.clone()); + let init_lp = backend.lp_for(oracle[0]); + let drafts = vec![Draft::new(draft_tokens)]; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 64, &cfg(), &mut stats).unwrap(); + // EOS terminates output without being emitted. + assert_eq!(out, &oracle[..oracle.len() - 1]); + // First verify accepts [10, 20, 30] (depths 1..=3 match the oracle), + // then the depth-4 child carries token 40 which the oracle rejects. + // After that, the accepted-tail [_,_,30] no longer matches the draft + // (the draft has 30 at index 2, but its trailing window [10,20,30] + // only appears once with suffix [40,50]; reject again at first child). + assert!(backend.n_verify_tree >= 1); + // AAL across all steps should reflect the early acceptance. + assert!(stats.aal() > 0.0); + } + + #[test] + fn max_new_tokens_caps_output() { + // EOS=63 is unreachable here (oracle is [1..=10]) so the only termination + // path is the budget cap. The tree-verify branch must truncate the + // accepted path before commit; otherwise the output would overshoot. 
+ let oracle = vec![1u32, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + let mut backend = OracleBackend::new(64, 63, oracle.clone()); + let init_lp = backend.lp_for(oracle[0]); + let drafts = vec![Draft::new(oracle.clone())]; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 4, &cfg(), &mut stats).unwrap(); + // Hard cap: the driver must not emit more than `max_new_tokens` tokens + // even when the tree path could accept the entire draft in one verify. + assert_eq!(out.len(), 4); + assert_eq!(&out[..], &oracle[..4]); + } + + #[test] + fn eos_in_initial_logprobs_short_circuits_via_fallback() { + // Initial logprobs concentrated at EOS: with a draft starting at EOS, + // the first verify accepts EOS as path[0] and we exit immediately. + // The EOS token itself is *not* emitted (matches baseline). + let mut backend = OracleBackend::new(32, 31, vec![31]); + let init_lp = backend.lp_for(31); + let drafts = vec![Draft::new(vec![31])]; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 64, &cfg(), &mut stats).unwrap(); + assert!(out.is_empty()); + } + + #[test] + fn multiple_drafts_get_merged_in_tree() { + // Two drafts sharing a prefix; the longer one should be accepted. + let oracle = vec![10u32, 20, 30, 40, 50, 99]; + let mut backend = OracleBackend::new(128, 99, oracle.clone()); + let init_lp = backend.lp_for(oracle[0]); + let drafts = vec![ + Draft::new(vec![10, 20, 30]), + Draft::new(vec![10, 20, 30, 40, 50, 99]), + ]; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 64, &cfg(), &mut stats).unwrap(); + // EOS at the tail of the chain terminates without being emitted. 
+ assert_eq!(out, &oracle[..oracle.len() - 1]); + assert_eq!(backend.n_verify_tree, 1); + } + + #[test] + fn tau_one_is_strict() { + // With τ = 1.0 and the default `strict_at_tau_one = true` the driver takes + // the strict-replay route: a draft token is accepted only when it equals + // the target's host-side argmax, which holds for every token of this draft. + // EOS terminates without being emitted, matching baseline. + let oracle = vec![5u32, 6, 7, 99]; + let mut backend = OracleBackend::new(128, 99, oracle.clone()); + let init_lp = backend.lp_for(oracle[0]); + let drafts = vec![Draft::new(oracle.clone())]; + let cfg = DsvConfig { + tau: 1.0, + ..DsvConfig::default() + }; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 32, &cfg, &mut stats).unwrap(); + assert_eq!(out, &oracle[..oracle.len() - 1]); + } + + #[test] + fn tau_one_tree_path_matches_strict() { + // With strict_at_tau_one = false the driver stays on the tree-verify + // path even at τ = 1.0. Output must still match the strict-replay + // route (paper §3.3: τ is a tolerance, the tree path subsumes strict + // replay when the threshold is set to 1.0). + let oracle = vec![5u32, 6, 7, 99]; + let mut backend = OracleBackend::new(128, 99, oracle.clone()); + let init_lp = backend.lp_for(oracle[0]); + let drafts = vec![Draft::new(oracle.clone())]; + let cfg = DsvConfig { + tau: 1.0, + strict_at_tau_one: false, + ..DsvConfig::default() + }; + let mut stats = AcceptStats::default(); + let out = run_spec_decode(&mut backend, &drafts, init_lp, 32, &cfg, &mut stats).unwrap(); + assert_eq!(out, &oracle[..oracle.len() - 1]); + // Verify we actually went through the tree path (not the strict path, + // which never increments verify-tree calls).
+ assert!(backend.n_verify_tree >= 1); + } +} diff --git a/oar-ocr-vl/src/hunyuanocr/llm.rs b/oar-ocr-vl/src/hunyuanocr/llm.rs index de8c6df..5b727e2 100644 --- a/oar-ocr-vl/src/hunyuanocr/llm.rs +++ b/oar-ocr-vl/src/hunyuanocr/llm.rs @@ -2,68 +2,91 @@ use super::config::HunyuanOcrConfig; use crate::attention::{ RotaryEmbedding, repeat_kv, scaled_dot_product_attention, select_rope_sections, }; +#[cfg(feature = "hsd")] +use crate::hsd::TrimmableKvCache; +#[cfg(not(feature = "hsd"))] +use crate::kv_trim::TrimmableKvCache; use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing, rotate_half}; use candle_core::Tensor; -use candle_nn::{Module, kv_cache::KvCache}; +use candle_nn::Module; use oar_ocr_core::core::OCRError; use std::cell::RefCell; +/// Apply XDRoPE to `(q, k)` using already-section-mixed F32 `cos`/`sin`. +/// +/// The section-mix (`select_rope_sections`) and the F32 cast of cos/sin are +/// layer-invariant — only q/k change between layers — so the caller +/// ([`HunyuanLlm::forward`]) hoists those steps out of the layer loop and +/// hands us the prepared tensors. Each layer then only pays the q/k F32 cast +/// and the actual rotary multiply. fn apply_xdrope_rotary_pos_emb( q: &Tensor, k: &Tensor, cos: &Tensor, sin: &Tensor, - xdrope_section: &[usize], ) -> Result<(Tensor, Tensor), OCRError> { - // XDRoPE uses 4 position dimensions - let cos = select_rope_sections(cos, xdrope_section, 4)?; - let sin = select_rope_sections(sin, xdrope_section, 4)?; - - let q_mul = q.broadcast_mul(&cos).map_err(|e| { + // Match upstream HF (`apply_rotary_pos_emb_xdrope`): apply the rotary + // mix in F32, then cast q/k back to the original dtype. 
+ use candle_core::DType; + let origin_dtype = q.dtype(); + let q_f32 = q + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "xdrope q to_dtype f32", e))?; + let k_f32 = k + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "xdrope k to_dtype f32", e))?; + + let q_mul = q_f32.broadcast_mul(cos).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, "HunyuanOCR: xdrope q*cos failed", e, ) })?; - let q_half = rotate_half(q)?; - let q_half_mul = q_half.broadcast_mul(&sin).map_err(|e| { + let q_half = rotate_half(&q_f32)?; + let q_half_mul = q_half.broadcast_mul(sin).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, "HunyuanOCR: xdrope rotate_half(q)*sin failed", e, ) })?; - let q_rot = (&q_mul + &q_half_mul).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: xdrope apply on q failed", - e, - ) - })?; - - let k_mul = k.broadcast_mul(&cos).map_err(|e| { + let q_rot = (&q_mul + &q_half_mul) + .map_err(|e| { + candle_to_ocr_processing( + oar_ocr_core::core::errors::ProcessingStage::TensorOperation, + "HunyuanOCR: xdrope apply on q failed", + e, + ) + })? 
+ .to_dtype(origin_dtype) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "xdrope q to_dtype back", e))?; + + let k_mul = k_f32.broadcast_mul(cos).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, "HunyuanOCR: xdrope k*cos failed", e, ) })?; - let k_half = rotate_half(k)?; - let k_half_mul = k_half.broadcast_mul(&sin).map_err(|e| { + let k_half = rotate_half(&k_f32)?; + let k_half_mul = k_half.broadcast_mul(sin).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, "HunyuanOCR: xdrope rotate_half(k)*sin failed", e, ) })?; - let k_rot = (&k_mul + &k_half_mul).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: xdrope apply on k failed", - e, - ) - })?; + let k_rot = (&k_mul + &k_half_mul) + .map_err(|e| { + candle_to_ocr_processing( + oar_ocr_core::core::errors::ProcessingStage::TensorOperation, + "HunyuanOCR: xdrope apply on k failed", + e, + ) + })? + .to_dtype(origin_dtype) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "xdrope k to_dtype back", e))?; Ok((q_rot, k_rot)) } @@ -124,8 +147,7 @@ struct HunyuanAttention { num_kv_groups: usize, head_dim: usize, scaling: f64, - xdrope_section: Vec, - kv_cache: RefCell, + kv_cache: RefCell, } impl HunyuanAttention { @@ -180,12 +202,11 @@ impl HunyuanAttention { (None, None) }; - // Create KvCache with dim=2 for seq_len dimension - // Pre-allocate enough space to avoid O(N) reallocation during generation - // Conservative estimate: vision tokens + max_generation_tokens - // Typical: ~1000-2000 vision tokens + 4096 generation tokens = ~6000-8000 total - // Use 16384 to handle worst case without reallocation - let kv_cache = KvCache::new(2, 16384); + // Cat-along-seq KV cache. Capacity 16384 covers ~1000-2000 vision tokens + // plus the longest realistic generation. 
Same growth strategy as + // candle_nn::kv_cache::KvCache (Tensor::cat per append). Trim/gather + // support is required by HSD's tree-verification path. + let kv_cache = TrimmableKvCache::new(2, 16384); Ok(Self { q_proj, @@ -199,7 +220,6 @@ impl HunyuanAttention { num_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads, head_dim: cfg.head_dim, scaling: (cfg.head_dim as f64).powf(-0.5), - xdrope_section: cfg.rope_scaling.xdrope_section.clone(), kv_cache: RefCell::new(kv_cache), }) } @@ -242,6 +262,11 @@ impl HunyuanAttention { .transpose(1, 2) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "attn v transpose", e))?; + let (q, k) = apply_xdrope_rotary_pos_emb(&q, &k, cos, sin)?; + + // Match upstream HunyuanVL: apply XDRoPE first, then Q/K RMSNorm. + // The learned RMSNorm weight is per head dimension, so it does not + // commute with the rotary half-dimension mixing. let q = match &self.query_layernorm { Some(ln) => ln .forward(&q) @@ -255,8 +280,6 @@ impl HunyuanAttention { None => k, }; - let (q, k) = apply_xdrope_rotary_pos_emb(&q, &k, cos, sin, &self.xdrope_section)?; - let q = q .contiguous() .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "attn q contiguous", e))?; @@ -285,7 +308,9 @@ impl HunyuanAttention { candle_to_ocr_inference("HunyuanOCR", "attn value_states contiguous", e) })?; - // Use unified attention implementation + // Use unified attention implementation (BF16 Q·K, F32 softmax, BF16 A·V). + // The main Hunyuan-specific numerical requirement is above: upstream + // applies XDRoPE in F32 before Q/K RMSNorm. 
let attn_output = scaled_dot_product_attention( &q, &key_states, @@ -310,6 +335,19 @@ impl HunyuanAttention { fn clear_kv_cache(&self) { self.kv_cache.borrow_mut().reset(); } + + #[cfg(feature = "hsd")] + fn current_kv_len(&self) -> usize { + self.kv_cache.borrow().current_seq_len() + } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.kv_cache + .borrow_mut() + .keep_indices(indices) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "keep_kv_indices", e)) + } } #[derive(Debug)] @@ -374,6 +412,11 @@ impl HunyuanDecoderLayer { fn clear_kv_cache(&self) { self.self_attn.clear_kv_cache(); } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.self_attn.keep_kv_indices(indices) + } } #[derive(Debug)] @@ -382,6 +425,10 @@ pub struct HunyuanLlm { layers: Vec, norm: candle_nn::RmsNorm, rotary: RotaryEmbedding, + /// XDRoPE section sizes (`config.rope_scaling.xdrope_section`). Used once + /// per `forward` to section-mix the rotary `cos`/`sin` tensors before they + /// fan out to every layer — see [`apply_xdrope_rotary_pos_emb`]. 
+ xdrope_section: Vec, } impl HunyuanLlm { @@ -399,13 +446,20 @@ impl HunyuanLlm { let norm = candle_nn::rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm")) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "load final norm", e))?; - let rotary = RotaryEmbedding::new_multi_axis(cfg.head_dim, cfg.rope_theta, 4, vb.device())?; + let rope_theta = match cfg.rope_scaling.alpha { + Some(alpha) if alpha != 0.0 => { + cfg.rope_theta * alpha.powf(cfg.head_dim as f64 / (cfg.head_dim as f64 - 2.0)) + } + _ => cfg.rope_theta, + }; + let rotary = RotaryEmbedding::new_multi_axis(cfg.head_dim, rope_theta, 4, vb.device())?; Ok(Self { embed_tokens, layers, norm, rotary, + xdrope_section: cfg.rope_scaling.xdrope_section.clone(), }) } @@ -425,12 +479,26 @@ impl HunyuanLlm { position_ids: &Tensor, causal_mask: Option<&Tensor>, ) -> Result { + use candle_core::DType; + let (cos, sin) = self .rotary .forward_multi_axis(position_ids, inputs_embeds.dtype())?; + // XDRoPE section-mix + F32 cast: the result is layer-invariant + // (depends only on position_ids and xdrope_section), so do it once + // here instead of per-layer inside `apply_xdrope_rotary_pos_emb`. + // Saves ~num_layers × (2 select_rope_sections + 2 to_dtype) ops per + // forward. + let cos = select_rope_sections(&cos, &self.xdrope_section, 4)? + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "xdrope cos to_dtype f32", e))?; + let sin = select_rope_sections(&sin, &self.xdrope_section, 4)? + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "xdrope sin to_dtype f32", e))?; + let mut hidden_states = inputs_embeds.clone(); - for layer in &self.layers { + for layer in self.layers.iter() { hidden_states = layer.forward(&hidden_states, &cos, &sin, causal_mask)?; } let hidden_states = self @@ -445,4 +513,25 @@ impl HunyuanLlm { layer.clear_kv_cache(); } } + + /// Current sequence length held in the KV cache. All layers stay in sync, + /// so we read it from layer 0. 
+ #[cfg(feature = "hsd")] + pub fn current_kv_len(&self) -> usize { + self.layers + .first() + .map(|l| l.self_attn.current_kv_len()) + .unwrap_or(0) + } + + /// Gather every layer's KV cache to keep only the supplied positions + /// (in order). Used by HSD after tree-attention verification to retain the + /// accepted-path KV entries and drop the rejected-tree positions. + #[cfg(feature = "hsd")] + pub fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + for layer in &self.layers { + layer.keep_kv_indices(indices)?; + } + Ok(()) + } } diff --git a/oar-ocr-vl/src/hunyuanocr/mod.rs b/oar-ocr-vl/src/hunyuanocr/mod.rs index c6f6979..04579af 100644 --- a/oar-ocr-vl/src/hunyuanocr/mod.rs +++ b/oar-ocr-vl/src/hunyuanocr/mod.rs @@ -12,4 +12,6 @@ mod vision; pub use config::{ HunyuanOcrConfig, HunyuanOcrImageProcessorConfig, HunyuanOcrRopeScaling, HunyuanOcrVisionConfig, }; +#[cfg(feature = "hsd")] +pub use model::HunyuanHsdPrompts; pub use model::HunyuanOcr; diff --git a/oar-ocr-vl/src/hunyuanocr/model.rs b/oar-ocr-vl/src/hunyuanocr/model.rs index dfa5398..37d03cb 100644 --- a/oar-ocr-vl/src/hunyuanocr/model.rs +++ b/oar-ocr-vl/src/hunyuanocr/model.rs @@ -4,14 +4,112 @@ use super::config::{HunyuanOcrConfig, HunyuanOcrImageProcessorConfig}; use super::llm::HunyuanLlm; use super::processing::{HunyuanOcrImageInputs, preprocess_image}; use super::vision::HunyuanVisionModel; +#[cfg(feature = "hsd")] +use crate::attention::create_tree_attention_mask; use crate::attention::{combine_masks, create_causal_mask, create_left_padding_mask}; +#[cfg(feature = "hsd")] +use crate::hsd::backend_util::{commit_keep_indices, step_pos_ids, tree_pos_ids}; +#[cfg(feature = "hsd")] +use crate::hsd::drafting::{ + TargetDraftAdapter, build_region_draft_candidates_with_adapter, crop_region_image, + region_markdown_candidates_for, structure_result_to_layout_elements, +}; +#[cfg(feature = "hsd")] +use crate::hsd::prefix_tree::PrefixTree; +#[cfg(feature = "hsd")] +use 
crate::hsd::types::{ + AcceptStats, Draft, HsdConfig, HsdStats, RegionDraft, RegionStageStats, StageStats, +}; +#[cfg(feature = "hsd")] +use crate::hsd::verify::{SpecBackend, spec_decode}; use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing}; +#[cfg(feature = "hsd")] +use candle_core::Result as CandleResult; use candle_core::{D, DType, Device, IndexOp, Tensor}; +#[cfg(feature = "hsd")] +use candle_nn::ops as cnn_ops; use image::RgbImage; use oar_ocr_core::core::OCRError; +#[cfg(feature = "hsd")] +use oar_ocr_core::domain::structure::{LayoutElement, StructureResult}; use std::path::{Path, PathBuf}; +#[cfg(feature = "hsd")] +use std::time::{Duration, Instant}; use tokenizers::Tokenizer; +/// Read `generation_config.json::repetition_penalty`. Returns 1.0 (no-op) if +/// the file is missing, unparseable, or the field is absent — matches +/// HuggingFace's default. Local HunyuanOCR config ships 1.03. +fn load_repetition_penalty(model_dir: &Path) -> f64 { + let path = model_dir.join("generation_config.json"); + let Ok(contents) = std::fs::read_to_string(&path) else { + return 1.0; + }; + let Ok(v) = serde_json::from_str::(&contents) else { + return 1.0; + }; + v.get("repetition_penalty") + .and_then(|x| x.as_f64()) + .unwrap_or(1.0) +} + +/// Apply HuggingFace's `RepetitionPenaltyLogitsProcessor` rule to a 1D logits +/// tensor and return the argmax id. For each token id that appears in +/// `seen`, the rule lowers its logit **once** (for `penalty > 1`): +/// `logit /= penalty` when positive, `logit *= penalty` when non-positive +/// (see `transformers.generation.logits_process.RepetitionPenaltyLogitsProcessor`).
+/// HF computes this with `scatter(input_ids, …)`, which collapses duplicate +/// positions in `input_ids` down to a single penalty per unique vocab id — +/// applying the penalty per *occurrence* would compound to `penalty^k` for a +/// token repeated `k` times and quickly suppresses legitimate high-frequency +/// tokens like `` in a structured HTML page. We dedup before applying. +fn argmax_with_repetition_penalty( + logits: &Tensor, + seen: &[u32], + penalty: f32, +) -> Result { + let mut vec = logits + .to_dtype(DType::F32) + .and_then(|t| t.to_vec1::()) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "rep penalty to_vec1", e))?; + let vocab = vec.len(); + let mut unique: Vec = seen.to_vec(); + unique.sort_unstable(); + unique.dedup(); + for &id in &unique { + let idx = id as usize; + if idx >= vocab { + continue; + } + let v = vec[idx]; + vec[idx] = if v > 0.0 { v / penalty } else { v * penalty }; + } + let mut best_idx = 0usize; + let mut best_val = f32::NEG_INFINITY; + for (i, &v) in vec.iter().enumerate() { + if v > best_val { + best_val = v; + best_idx = i; + } + } + Ok(best_idx as u32) +} + +/// Page-level and region-level instructions for +/// [`HunyuanOcr::generate_hsd_full`]. +/// +/// `page` is used for Stage 2 full-page verification; `region` is used for +/// Stage 1 crop verification (each region image is paired with this single +/// prompt). Kept as a separate struct so the call site reads as +/// `HunyuanHsdPrompts { page: "...", region: "..." }` rather than two +/// adjacent untyped `&str`s. +#[cfg(feature = "hsd")] +#[derive(Debug, Clone, Copy)] +pub struct HunyuanHsdPrompts<'a> { + pub page: &'a str, + pub region: &'a str, +} + pub struct HunyuanOcr { device: Device, dtype: DType, @@ -21,6 +119,14 @@ pub struct HunyuanOcr { llm: HunyuanLlm, vision: HunyuanVisionModel, stop_token_ids: Vec, + /// `generation_config.json::repetition_penalty`. 
HuggingFace's + /// `generate(do_sample=False)` still applies repetition_penalty via the + /// LogitsProcessor list before the argmax. Without it, large-context chart + /// inputs can collapse into runaway-repeat loops (e.g. Mermaid node IDs + /// `A, B, … BZ, BZW, BZWW, BZWWZ …`) that never hit EOS. Default 1.0 means + /// the value isn't applied. Only the baseline greedy path consumes this; + /// the HSD verification paths intentionally keep raw logits. + repetition_penalty: f64, } impl HunyuanOcr { @@ -59,6 +165,8 @@ impl HunyuanOcr { let llm = HunyuanLlm::load(&cfg, vb.pp("model"))?; let vision = HunyuanVisionModel::load(&cfg.vision_config, vb.pp("vit"))?; + let repetition_penalty = load_repetition_penalty(model_dir); + Ok(Self { device, dtype, @@ -68,6 +176,7 @@ impl HunyuanOcr { llm, vision, stop_token_ids, + repetition_penalty, }) } @@ -101,10 +210,64 @@ impl HunyuanOcr { })]; } - match self.generate_internal(images, instructions, max_new_tokens) { + match self.generate_tokens_internal(images, instructions, max_new_tokens) { + Ok(results) => results + .into_iter() + .map(|tokens| self.decode_generated_tokens(&tokens)) + .collect(), + Err(e) => { + // Walk the source chain so the underlying candle / CUDA + // failure isn't hidden behind the top-level OCRError. + let mut chain = format!("generation failed: {e}"); + let mut cur: Option<&dyn std::error::Error> = std::error::Error::source(&e); + while let Some(s) = cur { + chain.push_str(&format!("\n caused by: {s}")); + cur = s.source(); + } + let msg = chain; + (0..images.len()) + .map(|_| { + Err(OCRError::InvalidInput { + message: msg.clone(), + }) + }) + .collect() + } + } + } + + /// Generate raw baseline tokens for oracle-draft / tokenizer round-trip + /// experiments. Tokens are exactly the ids emitted by the decode loop, + /// excluding stop tokens, before tokenizer decoding or trimming. 
+ pub fn generate_tokens( + &self, + images: &[RgbImage], + instructions: &[impl AsRef], + max_new_tokens: usize, + ) -> Vec, OCRError>> { + if images.is_empty() { + return Vec::new(); + } + if images.len() != instructions.len() { + return vec![Err(OCRError::InvalidInput { + message: format!( + "HunyuanOCR: images count ({}) != instructions count ({})", + images.len(), + instructions.len() + ), + })]; + } + + match self.generate_tokens_internal(images, instructions, max_new_tokens) { Ok(results) => results.into_iter().map(Ok).collect(), Err(e) => { - let msg = format!("generation failed: {e}"); + let mut chain = format!("generation failed: {e}"); + let mut cur: Option<&dyn std::error::Error> = std::error::Error::source(&e); + while let Some(s) = cur { + chain.push_str(&format!("\n caused by: {s}")); + cur = s.source(); + } + let msg = chain; (0..images.len()) .map(|_| { Err(OCRError::InvalidInput { @@ -117,12 +280,12 @@ impl HunyuanOcr { } /// Internal generation implementation supporting batched inference. - fn generate_internal( + fn generate_tokens_internal( &self, images: &[RgbImage], instructions: &[impl AsRef], max_new_tokens: usize, - ) -> Result, OCRError> { + ) -> Result>, OCRError> { let batch_size = images.len(); // 1. Preprocess all images and build prompts @@ -187,16 +350,16 @@ impl HunyuanOcr { }); } - // Fuse image embeddings + // Fuse image embeddings — see hsd_prefill_single for the spec. 
let (start_pos, end_pos) = find_image_span(input_ids, &self.cfg)?; - let region_len = end_pos - start_pos + 1; + let inner_len = end_pos.saturating_sub(start_pos + 1); let (img_len, _) = image_embeds .dims2() .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "image_embeds dims2", e))?; - if region_len != img_len { + if inner_len != img_len { return Err(OCRError::InvalidInput { message: format!( - "HunyuanOCR: image span length mismatch: tokens={region_len} embeds={img_len}" + "HunyuanOCR: image-token run length mismatch: tokens={inner_len} embeds={img_len}" ), }); } @@ -206,16 +369,16 @@ impl HunyuanOcr { })?; let mut parts: Vec = Vec::with_capacity(3); - if start_pos > 0 { - parts.push(token_embeds.i((0..start_pos, ..)).map_err(|e| { - candle_to_ocr_inference("HunyuanOCR", "slice prefix embeddings", e) - })?); - } + // Prefix incl. image_start (text-embedded). + parts.push(token_embeds.i((0..=start_pos, ..)).map_err(|e| { + candle_to_ocr_inference("HunyuanOCR", "slice prefix embeddings", e) + })?); parts.push(image_embeds); - if end_pos + 1 < input_ids.len() { + if end_pos < input_ids.len() { + // Suffix incl. image_end (text-embedded). parts.push( token_embeds - .i((end_pos + 1..input_ids.len(), ..)) + .i((end_pos..input_ids.len(), ..)) .map_err(|e| { candle_to_ocr_inference("HunyuanOCR", "slice suffix embeddings", e) })?, @@ -305,10 +468,27 @@ impl HunyuanOcr { if finished[i] { next_tokens.push(0); // Padding token for finished samples } else { - let tok = logits - .argmax(D::Minus1) - .and_then(|t| t.to_scalar::()) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "argmax", e))?; + // Mirror HuggingFace's `generate(do_sample=False)`: even + // greedy decoding runs the LogitsProcessor list, so the + // `repetition_penalty` from generation_config.json gets + // applied to logits before argmax. 
Without this the model + // can spiral into runaway-repeat loops on large-context + // inputs (observed on chart_01.jpg, seq≈11584, producing + // 33K chars of synthetic Mermaid node IDs `BZ, BZW, …`). + // HSD applies the same processor inside HunyuanSpecBackend + // so τ=1.0 comparisons stay aligned with greedy decoding. + let tok = if self.repetition_penalty > 1.0 && !generated[i].is_empty() { + argmax_with_repetition_penalty( + logits, + &generated[i], + self.repetition_penalty as f32, + )? + } else { + logits + .argmax(D::Minus1) + .and_then(|t| t.to_scalar::()) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "argmax", e))? + }; if self.stop_token_ids.contains(&tok) { finished[i] = true; @@ -353,19 +533,540 @@ impl HunyuanOcr { } } - // 8. Decode results - let mut results = Vec::with_capacity(batch_size); - for tokens in generated { - let decoded = + self.llm.clear_kv_cache(); + Ok(generated) + } + + pub fn decode_tokens(&self, tokens: &[u32]) -> Result { + self.decode_generated_tokens(tokens) + } + + /// Decode tokens in the form the model actually emitted. HunyuanOCR's + /// `decode_tokens` only applies `trim()` post-process, so this is + /// effectively an alias provided for API symmetry with backends that do + /// have non-trivial post-process (PaddleOCR-VL / GLM-OCR). + pub fn decode_tokens_raw(&self, tokens: &[u32]) -> Result { + self.tokenizer + .decode(tokens, true) + .map_err(|e| OCRError::InvalidInput { + message: format!("HunyuanOCR: tokenizer decode failed: {e}"), + }) + } + + pub fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } + + fn decode_generated_tokens(&self, tokens: &[u32]) -> Result { + Ok(self.decode_tokens_raw(tokens)?.trim().to_string()) + } + + /// Generate OCR output for a single image using Hierarchical Speculative + /// Decoding (currently page-level / Stage-2-style: drafts are matched and + /// verified against a single full-page forward pass). 
+ /// + /// `drafts` are markdown-style strings produced by the lightweight pipeline + /// drafter (PP-DocLayout + region recognizers). Each is tokenized with the + /// HunyuanOCR tokenizer (this is mandatory — HSD's prefix matching must + /// happen in the target VLM's token space). + /// + /// Returns `(generated_text, stats)` where `stats.stage2.accept` records + /// the AAL and step counts needed to compute SR_decode / SR_e2e. + #[cfg(feature = "hsd")] + pub fn generate_hsd( + &self, + image: &RgbImage, + instruction: &str, + drafts: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let t_drafter_prep = Instant::now(); + + // Tokenize all drafts up-front with the HunyuanOCR tokenizer. + let mut tokenized: Vec = Vec::with_capacity(drafts.len()); + for d in drafts { + if d.trim().is_empty() { + continue; + } + let enc = self.tokenizer - .decode(&tokens, true) + .encode(d.as_str(), false) .map_err(|e| OCRError::InvalidInput { - message: format!("HunyuanOCR: tokenizer decode failed: {e}"), + message: format!("HunyuanOCR HSD: tokenizer encode failed: {e}"), })?; - results.push(decoded.trim().to_string()); + let tokens = enc.get_ids().to_vec(); + if !tokens.is_empty() { + tokenized.push(Draft::new(tokens)); + } + } + self.generate_hsd_tokenized( + image, + instruction, + &tokenized, + hsd_cfg, + t_drafter_prep.elapsed(), + ) + } + + /// HSD entry that consumes already-tokenized drafts. This is the oracle + /// path used by benchmarks to avoid `decode -> encode` tokenizer + /// round-trips when the draft comes from this backend's own baseline. 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_with_token_drafts( + &self, + image: &RgbImage, + instruction: &str, + drafts: &[Draft], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let (tokens, stats) = self.generate_hsd_tokens_tokenized( + image, + instruction, + drafts, + hsd_cfg, + Duration::ZERO, + )?; + let text = self.decode_tokens(&tokens)?; + Ok((text.trim().to_string(), stats)) + } + + /// Token-returning HSD entry for diagnostics and exact oracle checks. + #[cfg(feature = "hsd")] + pub fn generate_hsd_tokens_with_token_drafts( + &self, + image: &RgbImage, + instruction: &str, + drafts: &[Draft], + hsd_cfg: &HsdConfig, + ) -> Result<(Vec, HsdStats), OCRError> { + self.generate_hsd_tokens_tokenized(image, instruction, drafts, hsd_cfg, Duration::ZERO) + } + + #[cfg(feature = "hsd")] + fn generate_hsd_tokenized( + &self, + image: &RgbImage, + instruction: &str, + tokenized: &[Draft], + hsd_cfg: &HsdConfig, + drafter_elapsed: Duration, + ) -> Result<(String, HsdStats), OCRError> { + let (generated, stats) = self.generate_hsd_tokens_tokenized( + image, + instruction, + tokenized, + hsd_cfg, + drafter_elapsed, + )?; + let text = self + .tokenizer + .decode(&generated, true) + .map_err(|e| OCRError::InvalidInput { + message: format!("HunyuanOCR HSD: tokenizer decode failed: {e}"), + })?; + Ok((text.trim().to_string(), stats)) + } + + #[cfg(feature = "hsd")] + fn generate_hsd_tokens_tokenized( + &self, + image: &RgbImage, + instruction: &str, + tokenized: &[Draft], + hsd_cfg: &HsdConfig, + drafter_elapsed: Duration, + ) -> Result<(Vec, HsdStats), OCRError> { + if !self.device.is_cuda() { + return Err(OCRError::ConfigError { + message: "HSD requires CUDA device".to_string(), + }); + } + + let mut stats = HsdStats { + drafter: drafter_elapsed, + ..Default::default() + }; + // Stage 2 (page-level) prefill. 
+ let t_prefill = Instant::now(); + let initial_lp = self.hsd_prefill_single(image, instruction)?; + stats.stage2.vision_prefill = t_prefill.elapsed(); + stats.stage2.forward_passes = 1; + + // Drive HSD verification. + let t_decode = Instant::now(); + let mut backend = HunyuanSpecBackend::new(self); + let mut accept = AcceptStats::default(); + let mut dsv = Default::default(); + let generated = spec_decode( + &mut backend, + tokenized, + initial_lp, + hsd_cfg.max_page_tokens, + &hsd_cfg.dsv, + &mut accept, + &mut dsv, + ) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "spec_decode", e))?; + stats.stage2.decode = t_decode.elapsed(); + stats.stage2.emitted_tokens = generated.len() as u32; + stats.stage2.accept = accept; + stats.stage2.dsv = dsv; + stats.stage2.forward_passes += backend.forward_passes; + self.llm.clear_kv_cache(); + Ok((generated, stats)) + } + + /// Full Hierarchical Speculative Decoding entry: Stage 1 (region-level + /// local verification) followed by Stage 2 (page-level global verification). + /// + /// Full HSD entry: Stage 1 verifies region-level candidate drafts on + /// cropped images, then Stage 2 verifies the Stage-1 output set on the + /// full page image. + /// + /// `text_candidates(elem)` can return top-k recognizer outputs or outputs + /// from multiple independent drafters. Candidates are serialized through + /// HunyuanOCR's target adapter, tokenized with HunyuanOCR's tokenizer, + /// deduplicated per region, and verified together in Stage 1. + /// + /// `hsd_cfg.enable_stage1` and `hsd_cfg.enable_stage2` independently gate + /// the two stages (used for ablations — `enable_stage2 = false` reproduces + /// the lossy "Stage 1 only" line in the paper's Tab. 7). 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_full( + &self, + image: &RgbImage, + prompts: HunyuanHsdPrompts<'_>, + elements: &[LayoutElement], + ignore_labels: &[String], + text_candidates: C, + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> + where + C: Fn(&LayoutElement) -> Vec, + { + let HunyuanHsdPrompts { + page: page_instruction, + region: region_instruction, + } = prompts; + let mut stats = HsdStats::default(); + + // 1. Build region drafts using the target VLM's tokenizer. + let t_drafter = Instant::now(); + let tokenizer = &self.tokenizer; + let region_drafts = build_region_draft_candidates_with_adapter( + elements, + ignore_labels, + TargetDraftAdapter::HunyuanOcr, + &text_candidates, + |s: &str| { + tokenizer + .encode(s, false) + .map(|enc| enc.get_ids().to_vec()) + .unwrap_or_default() + }, + ); + stats.drafter = t_drafter.elapsed(); + let original_region_drafts = region_markdown_candidates_for( + elements, + ignore_labels, + TargetDraftAdapter::HunyuanOcr, + &text_candidates, + ); + + // 2. Stage 1 — build independent region work items, then run + // target-model verification. Each HunyuanOCR instance owns a mutable + // LLM KV cache, so parallel verification uses separate model workers. + // + // Collect (reading_order, text) pairs so Stage 2 can join in the + // reading order the layout drafter assigned, not the + // worker-completion order. The `RegionDraft::reading_order` field is + // populated by `build_region_drafts` and was previously ignored. + let mut stage1_outputs: Vec<(usize, String)> = Vec::with_capacity(region_drafts.len()); + if hsd_cfg.enable_stage1 && !region_drafts.is_empty() { + let t_prep = Instant::now(); + let stage1_work = build_stage1_work_items(image, ®ion_drafts)?; + stats.stage1.draft_prep += t_prep.elapsed(); + + let stage1_results = self.run_stage1_work(&stage1_work, region_instruction, hsd_cfg)?; + // Pair each Stage-1 output with its layout-drafter reading_order. 
+ // Regions without an explicit order fall back to their input + // index (stable, deterministic). + for (idx, ((text, item_stats), (region, _img))) in stage1_results + .into_iter() + .zip(stage1_work.iter()) + .enumerate() + { + let order = region.reading_order.unwrap_or(idx); + stats.stage1_regions.push(RegionStageStats { + kind: region.kind, + stats: item_stats.clone(), + }); + stats.stage1.add_assign(item_stats); + stage1_outputs.push((order, text)); + } + stage1_outputs.sort_by_key(|(order, _)| *order); + } + + // 3. Stage 2 — page-level global verification. Per paper Eq. 3 the + // page draft is the *unordered set* `Ỹ^pg = {ŷ^(i)}`, one draft + // per region. We pass that set straight to `spec_decode` instead + // of concatenating into a single markdown string — the sliding + // window in `collect_candidates` scans each draft independently + // (Eqs. 1+2), so per-region n-gram locality is preserved even + // when full-page transitions don't appear naturally in the + // target VLM's output. + if hsd_cfg.enable_stage2 { + let page_drafts: Vec = if !stage1_outputs.is_empty() { + stage1_outputs.iter().map(|(_, t)| t.clone()).collect() + } else { + original_region_drafts + }; + if !page_drafts.is_empty() { + let (text, s2_stats) = + self.generate_hsd(image, page_instruction, &page_drafts, hsd_cfg)?; + stats.stage2 = s2_stats.stage2; + stats.drafter += s2_stats.drafter; + return Ok((text, stats)); + } + } + + // 4. Stage 2 disabled — return Stage-1-only aggregation (lossy ablation). + let stage1_only = stage1_outputs + .into_iter() + .map(|(_, t)| t) + .collect::>() + .join("\n\n"); + Ok((stage1_only, stats)) + } + + /// One-call HSD entry that consumes a `StructureResult` (i.e. the output of + /// the OARStructure / PP-StructureV3 pipeline) directly. 
+ /// + /// This is the paper's PP-StructureV3 → SpecDecode integration point + /// realized as a single Rust call: + /// + /// ```no_run + /// # use oar_ocr::prelude::{OARStructure, OARStructureBuilder}; + /// # use oar_ocr_vl::HunyuanOcr; + /// # use oar_ocr_vl::hsd::types::HsdConfig; + /// # fn main() -> Result<(), Box> { + /// # let device = oar_ocr_vl::utils::parse_device("cuda:0")?; + /// # let structure: OARStructure = unimplemented!(); + /// # let model: HunyuanOcr = unimplemented!(); + /// # let image: image::RgbImage = unimplemented!(); + /// let s = structure.predict_image(image.clone())?; + /// let (output, stats) = model.generate_hsd_with_structure( + /// &image, + /// "Extract and parse the document content as markdown.", + /// "Extract and parse this region.", + /// &s, + /// &[], + /// &HsdConfig::default(), + /// )?; + /// # Ok(()) } + /// ``` + /// + /// Internally: + /// 1. Calls [`structure_result_to_layout_elements`] to backfill table HTML + /// and formula LaTeX from the structure's side records onto the layout + /// elements. + /// 2. Delegates to [`Self::generate_hsd_full`] with `text_candidates = + /// |elem| elem.text.iter().cloned().collect()` (single-candidate per + /// region — equivalent to the paper's PP-StructureV3 drafter). Callers + /// that want multi-candidate per region (e.g. top-k OCR) should keep + /// using `generate_hsd_full` directly with a custom closure. 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_with_structure( + &self, + image: &RgbImage, + page_instruction: &str, + region_instruction: &str, + structure: &StructureResult, + ignore_labels: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let elements = structure_result_to_layout_elements(structure); + self.generate_hsd_full( + image, + HunyuanHsdPrompts { + page: page_instruction, + region: region_instruction, + }, + &elements, + ignore_labels, + |elem| elem.text.iter().cloned().collect(), + hsd_cfg, + ) + } + + #[cfg(feature = "hsd")] + fn run_stage1_work( + &self, + stage1_work: &[(RegionDraft, RgbImage)], + instruction: &str, + hsd_cfg: &HsdConfig, + ) -> Result, OCRError> { + stage1_work + .iter() + .map(|(region, crop)| self.run_stage1_item(region, crop, instruction, hsd_cfg)) + .collect() + } + + #[cfg(feature = "hsd")] + fn run_stage1_item( + &self, + region: &RegionDraft, + crop: &RgbImage, + instruction: &str, + hsd_cfg: &HsdConfig, + ) -> Result<(String, StageStats), OCRError> { + let mut stats = StageStats::default(); + + let t_pre = Instant::now(); + let init_lp = self.hsd_prefill_single(crop, instruction)?; + stats.vision_prefill += t_pre.elapsed(); + stats.forward_passes += 1; + + let t_dec = Instant::now(); + let mut backend = HunyuanSpecBackend::new(self); + let mut accept = AcceptStats::default(); + let mut dsv = Default::default(); + let toks = spec_decode( + &mut backend, + ®ion.drafts, + init_lp, + hsd_cfg.max_region_tokens, + &hsd_cfg.dsv, + &mut accept, + &mut dsv, + ) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "stage1 spec_decode", e))?; + stats.decode += t_dec.elapsed(); + stats.emitted_tokens += toks.len() as u32; + stats.forward_passes += backend.forward_passes; + stats.dsv = dsv; + stats.accept = accept; + self.llm.clear_kv_cache(); + + let text = self + .tokenizer + .decode(&toks, true) + .map_err(|e| OCRError::InvalidInput { + message: format!("HunyuanOCR HSD: tokenizer decode failed: 
{e}"), + })?; + Ok((text.trim().to_string(), stats)) + } + + /// Run the prefill forward pass for HSD, leaving the LLM's KV cache + /// populated. Returns the F32 log-probabilities at the last prompt + /// position (shape `(vocab,)`). + #[cfg(feature = "hsd")] + fn hsd_prefill_single(&self, image: &RgbImage, instruction: &str) -> Result { + // 1. Preprocess. + let image_inputs = preprocess_image( + image, + &self.image_cfg, + &self.cfg.vision_config, + &self.device, + self.dtype, + )?; + let prompt = build_prompt(instruction); + let enc = self + .tokenizer + .encode(prompt, false) + .map_err(|e| OCRError::InvalidInput { + message: format!("HunyuanOCR HSD: tokenizer encode failed: {e}"), + })?; + let mut input_ids = enc.get_ids().to_vec(); + expand_image_tokens_in_place(&mut input_ids, &self.cfg, &image_inputs)?; + let seq_len = input_ids.len(); + + // 2. Vision features. + let (image_embeds, merged_hw) = self.vision.forward(&image_inputs.pixel_values)?; + if merged_hw + != ( + image_inputs.grid_thw_merged.1, + image_inputs.grid_thw_merged.2, + ) + { + return Err(OCRError::InvalidInput { + message: format!( + "HunyuanOCR: merged grid mismatch: vision={:?} preprocessor={:?}", + merged_hw, image_inputs.grid_thw_merged + ), + }); } - Ok(results) + // 3. Build fused embeddings. + let input_ids_t = Tensor::new(input_ids.clone(), &self.device) + .and_then(|t| t.reshape((1, seq_len))) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "create input_ids", e))?; + let token_embeds = + self.llm.embed(&input_ids_t)?.squeeze(0).map_err(|e| { + candle_to_ocr_inference("HunyuanOCR", "squeeze token embeddings", e) + })?; + + // Splice: keep image_start (120118) and image_end (120119) as text + // tokens (matching the upstream HF processor), replace only the + // contiguous image_token_id run between them with the vit output. 
+ let (start_pos, end_pos) = find_image_span(&input_ids, &self.cfg)?; + let inner_len = end_pos.saturating_sub(start_pos + 1); + let (img_len, _) = image_embeds + .dims2() + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "image_embeds dims2", e))?; + if inner_len != img_len { + return Err(OCRError::InvalidInput { + message: format!( + "HunyuanOCR: image-token run length mismatch: tokens={inner_len} embeds={img_len}" + ), + }); + } + let mut parts: Vec = Vec::with_capacity(3); + // Prefix: [0, start_pos] inclusive — keeps image_start as text-embedded. + parts.push( + token_embeds + .i((0..=start_pos, ..)) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "slice prefix embeddings", e))?, + ); + parts.push(image_embeds); + if end_pos < input_ids.len() { + // Suffix: [end_pos, input_ids.len()) — keeps image_end as text-embedded. + parts.push( + token_embeds + .i((end_pos..input_ids.len(), ..)) + .map_err(|e| { + candle_to_ocr_inference("HunyuanOCR", "slice suffix embeddings", e) + })?, + ); + } + let refs: Vec<&Tensor> = parts.iter().collect(); + let inputs_embeds = Tensor::cat(&refs, 0) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "cat embeds", e))? + .unsqueeze(0) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "unsqueeze embeds", e))?; + + // 4. Position ids and causal mask. + let pos_ids = build_position_ids(&input_ids, &self.cfg, &image_inputs)?; + let causal = create_causal_mask(seq_len, seq_len, self.dtype, &self.device) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "create causal", e))?; + + // 5. Prefill. + self.llm.clear_kv_cache(); + let hidden = self.llm.forward(&inputs_embeds, &pos_ids, Some(&causal))?; + + // 6. Last-position log-probabilities, F32. 
+ let last = hidden + .i((0, seq_len - 1, ..)) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "get last hidden", e))?; + let logits = self.logits_from_hidden(&last)?; // (vocab,) + let lp = cnn_ops::log_softmax( + &logits + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "logits to f32", e))?, + D::Minus1, + ) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "log_softmax prefill", e))?; + Ok(lp) } fn logits_from_hidden(&self, hidden: &Tensor) -> Result { @@ -392,6 +1093,257 @@ impl HunyuanOcr { } } +#[cfg(feature = "hsd")] +fn build_stage1_work_items( + image: &RgbImage, + region_drafts: &[RegionDraft], +) -> Result, OCRError> { + region_drafts + .iter() + .map(|region| crop_region_image(image, ®ion.bbox).map(|crop| (region.clone(), crop))) + .collect() +} + +/// HSD adapter for HunyuanOCR. Borrows the model and drives the LLM's KV +/// cache through tree-attention verifications and single-token decode steps. +#[cfg(feature = "hsd")] +struct HunyuanSpecBackend<'a> { + model: &'a HunyuanOcr, + /// KV cache length captured at the start of the most recent + /// [`SpecBackend::verify_tree`] call. [`SpecBackend::commit_verify`] uses + /// this to translate path-node indices into absolute KV positions. + pre_verify_kv: usize, + /// Number of LLM forward passes (verify_tree + step_one) — populated for + /// the per-stage StageStats accounting in `generate_hsd`. + forward_passes: u32, + /// Generated tokens accepted so far (post-prompt). Used to apply + /// `generation_config.json::repetition_penalty` to log-probs in step_one / + /// verify_tree, mirroring the baseline greedy path. Without this the + /// τ=1.0 oracle correctness check fails when the baseline picks a + /// rep-penalized token but HSD's raw argmax picks the same token's + /// repetition. + committed_tokens: Vec, + /// Token-ids buffer passed into the most recent `verify_tree` call, + /// indexed by node id. 
Needed by `commit_verify` to extend + /// `committed_tokens` with the actually accepted tokens. + last_verify_tokens: Vec, + /// Parent-pointer array for the most recent `verify_tree` call, indexed + /// by node id. Used to compute the per-row "seen" set when applying + /// repetition penalty (committed prefix + ancestor chain to that node). + last_verify_parents: Vec>, +} + +#[cfg(feature = "hsd")] +impl<'a> HunyuanSpecBackend<'a> { + fn new(model: &'a HunyuanOcr) -> Self { + Self { + model, + pre_verify_kv: 0, + forward_passes: 0, + committed_tokens: Vec::new(), + last_verify_tokens: Vec::new(), + last_verify_parents: Vec::new(), + } + } + + fn project_logits_2d(&self, hidden_2d: &Tensor) -> CandleResult { + // hidden_2d: (N, hidden). Project to (N, vocab) raw logits in F32. + let w = self.model.llm.token_embedding_weight(); + let wt = w.transpose(0, 1)?; + let logits = hidden_2d.matmul(&wt)?.to_dtype(DType::F32)?; + Ok(logits) + } + + fn project_logits_1d(&self, hidden_1d: &Tensor) -> CandleResult { + // hidden_1d: (hidden,). Project to (vocab,) raw logits in F32. + let w = self.model.llm.token_embedding_weight(); + let wt = w.transpose(0, 1)?; + let logits = hidden_1d + .unsqueeze(0)? + .matmul(&wt)? + .squeeze(0)? + .to_dtype(DType::F32)?; + Ok(logits) + } + + /// Apply HF's `RepetitionPenaltyLogitsProcessor` rule to one row of raw + /// logits in-place. Deduped seen set: each vocab id pays the penalty at + /// most once, mirroring HF's `scatter`. Logits are host-side after a + /// `to_vec1`; the caller rebuilds a fresh device tensor afterwards. 
+ fn apply_rep_penalty_row(row: &mut [f32], seen: &[u32], penalty: f32) { + if penalty == 1.0 || seen.is_empty() { + return; + } + let vocab = row.len(); + let mut unique: Vec = seen.to_vec(); + unique.sort_unstable(); + unique.dedup(); + for &id in &unique { + let idx = id as usize; + if idx >= vocab { + continue; + } + let v = row[idx]; + row[idx] = if v > 0.0 { v / penalty } else { v * penalty }; + } + } + + /// Apply rep_penalty + log_softmax to a 1D logits tensor with the + /// committed-tokens seen set. + fn penalize_and_log_softmax_1d(&self, logits_1d: &Tensor) -> CandleResult { + let penalty = self.model.repetition_penalty as f32; + if penalty == 1.0 || self.committed_tokens.is_empty() { + return cnn_ops::log_softmax(logits_1d, D::Minus1); + } + let device = logits_1d.device().clone(); + let mut row = logits_1d.to_vec1::()?; + Self::apply_rep_penalty_row(&mut row, &self.committed_tokens, penalty); + let len = row.len(); + let penalized = Tensor::from_vec(row, (len,), &device)?; + cnn_ops::log_softmax(&penalized, D::Minus1) + } + + /// Apply per-row rep_penalty + log_softmax to a 2D logits tensor where + /// row `i` corresponds to verify-tree node `i`. The "seen" set for each + /// row is `committed_tokens` plus the ancestor-chain tokens from the + /// tree (the node's own token is *included* — the model already saw it + /// in the forward pass, so it counts toward the rep-penalty set for + /// predicting the next position). + fn penalize_and_log_softmax_verify( + &self, + logits_2d: &Tensor, + tree_tokens: &[u32], + tree_parents: &[Option], + ) -> CandleResult { + let penalty = self.model.repetition_penalty as f32; + if penalty == 1.0 || (self.committed_tokens.is_empty() && tree_tokens.is_empty()) { + return cnn_ops::log_softmax(logits_2d, D::Minus1); + } + let (n, vocab) = logits_2d.dims2()?; + let device = logits_2d.device().clone(); + let mut flat: Vec = logits_2d.to_vec2::()?.into_iter().flatten().collect(); + // Build per-node ancestor chains once. 
+ let ancestors: Vec> = (0..n) + .map(|i| { + let mut chain = Vec::new(); + let mut cur = Some(i); + while let Some(j) = cur { + chain.push(tree_tokens[j]); + cur = tree_parents[j]; + } + chain + }) + .collect(); + for (row, ancestor_chain) in flat.chunks_mut(vocab).zip(ancestors.iter()) { + let mut seen = self.committed_tokens.clone(); + seen.extend_from_slice(ancestor_chain); + Self::apply_rep_penalty_row(row, &seen, penalty); + } + let penalized = Tensor::from_vec(flat, (n, vocab), &device)?; + cnn_ops::log_softmax(&penalized, D::Minus1) + } +} + +#[cfg(feature = "hsd")] +impl<'a> SpecBackend for HunyuanSpecBackend<'a> { + fn step_one(&mut self, token: u32) -> CandleResult { + let model = self.model; + let device = &model.device; + + // (1, 1) token tensor → (1, 1, hidden). + let tok_t = Tensor::new(vec![token], device)?.reshape((1usize, 1usize))?; + let embeds = model + .llm + .embed(&tok_t) + .map_err(|e| candle_core::Error::Msg(format!("HunyuanOCR HSD step_one embed: {e}")))?; + + // Position id = current cache length (next slot). HunyuanOCR uses + // 4-axis MRoPE with rope_delta = 0. + let pos_ids = step_pos_ids(4, model.llm.current_kv_len(), 0, device)?; + + // Continuation forward — no mask (autoregressive on growing cache). + let hidden = model.llm.forward(&embeds, &pos_ids, None).map_err(|e| { + candle_core::Error::Msg(format!("HunyuanOCR HSD step_one forward: {e}")) + })?; + self.forward_passes += 1; + let last = hidden.i((0, 0, ..))?; + // Record the just-decoded token before scoring the next position so + // rep_penalty includes it in the seen set (mirrors baseline greedy + // which calls `argmax_with_repetition_penalty(logits, &generated[..])` + // after `generated.push(prev_tok)`). 
+ self.committed_tokens.push(token); + let logits = self.project_logits_1d(&last)?; + self.penalize_and_log_softmax_1d(&logits) + } + + fn verify_tree(&mut self, tree: &PrefixTree) -> CandleResult { + let n = tree.num_nodes(); + let model = self.model; + let device = &model.device; + let dtype = model.dtype; + + let prefix_kv = model.llm.current_kv_len(); + self.pre_verify_kv = prefix_kv; + + // Packed tree tokens: (1, N). + let tok_t = Tensor::new(tree.tokens.clone(), device)?.reshape((1usize, n))?; + let embeds = model.llm.embed(&tok_t).map_err(|e| { + candle_core::Error::Msg(format!("HunyuanOCR HSD verify_tree embed: {e}")) + })?; + + // Position ids: depth-`d` node represents the d-th newly generated + // token, so its absolute sequence position is `prefix_kv + d - 1` + // (depth-1 token sits in the very next cache slot after the prompt, + // which is at index `prefix_kv`). HunyuanOCR uses 4-axis MRoPE with + // rope_delta = 0. + let pos_ids = tree_pos_ids(4, prefix_kv, 0, tree, device)?; + + // Tree-ancestry mask — each candidate token sees prefix + its own + // ancestor chain only (paper Fig. 2c). + let mask = create_tree_attention_mask(&tree.parents, prefix_kv, dtype, device)?; + + let hidden = model + .llm + .forward(&embeds, &pos_ids, Some(&mask)) + .map_err(|e| { + candle_core::Error::Msg(format!("HunyuanOCR HSD verify_tree forward: {e}")) + })?; + self.forward_passes += 1; + // (1, N, hidden) → (N, hidden) → (N, vocab) log-probs. + let h2 = hidden.squeeze(0)?; + let logits = self.project_logits_2d(&h2)?; + // Cache tree shape so `commit_verify` can read off the accepted + // tokens and extend `committed_tokens`. 
+ self.last_verify_tokens = tree.tokens.clone(); + self.last_verify_parents = tree.parents.clone(); + self.penalize_and_log_softmax_verify(&logits, &tree.tokens, &tree.parents) + } + + fn commit_verify(&mut self, accepted_path: &[usize]) -> CandleResult<()> { + let indices = commit_keep_indices(self.pre_verify_kv, accepted_path); + // Extend the rep-penalty seen set with the tokens we just committed + // (each accepted path index maps to a verify-tree node). + for &p in accepted_path { + if let Some(&tok) = self.last_verify_tokens.get(p) { + self.committed_tokens.push(tok); + } + } + self.model.llm.keep_kv_indices(&indices).map_err(|e| { + let mut chain = format!("HunyuanOCR HSD commit_verify: {e}"); + let mut cur: Option<&dyn std::error::Error> = std::error::Error::source(&e); + while let Some(s) = cur { + chain.push_str(&format!("\n caused by: {s}")); + cur = s.source(); + } + candle_core::Error::Msg(chain) + }) + } + + fn is_eos(&self, tok: u32) -> bool { + self.model.stop_token_ids.contains(&tok) + } +} + fn resolve_safetensors_shards(model_dir: &Path) -> Result, OCRError> { let single = model_dir.join("model.safetensors"); if single.exists() { @@ -435,7 +1387,17 @@ fn expand_image_tokens_in_place( image_inputs: &HunyuanOcrImageInputs, ) -> Result<(), OCRError> { let (_, hm, wm) = image_inputs.grid_thw_merged; - let expected_tokens = hm.saturating_mul(wm.saturating_add(1)); + // Match the upstream HuggingFace processor: + // transformers/models/hunyuan_vl/processing_hunyuan_vl.py:62 + // num_image_tokens = patch_h * (patch_w + 1) + 2 + // The `+ 2` accounts for the begin/end markers that the vit's perceive + // step prepends/appends to the spatial sequence — those positions also + // get replaced by image embeddings (rather than carrying separate + // `image_start` / `image_end` text embeddings, which is the scheme an + // earlier internal Tencent variant used and which this Rust port + // originally followed). 
The placeholder run is contiguous and uses + // `image_token_id` exclusively — no `image_newline_token_id` interleaving. + let expected_tokens = hm.saturating_mul(wm.saturating_add(1)).saturating_add(2); if expected_tokens == 0 { return Err(OCRError::InvalidInput { message: "HunyuanOCR: empty merged grid".to_string(), @@ -458,13 +1420,10 @@ fn expand_image_tokens_in_place( }); }; - let mut expanded: Vec = Vec::with_capacity(expected_tokens); - for _r in 0..hm { - expanded.extend(std::iter::repeat_n(cfg.image_token_id, wm)); - expanded.push(cfg.image_newline_token_id); - } - - // Replace the single placeholder image_token_id with the expanded sequence. + let expanded: Vec = std::iter::repeat_n(cfg.image_token_id, expected_tokens).collect(); + // Replace the single image_token_id placeholder with the expanded run. + // image_start_token_id and image_end_token_id stay in input_ids on + // either side; they receive plain text embeddings via embed_tokens. input_ids.splice(pos..pos + 1, expanded); Ok(()) } @@ -493,6 +1452,26 @@ fn build_position_ids( image_inputs: &HunyuanOcrImageInputs, ) -> Result { let seq_len = input_ids.len(); + // 4-axis XDRoPE position ids matching the upstream HF processor exactly: + // transformers/models/hunyuan_vl/processing_hunyuan_vl.py:74-94. + // + // Axis order is `[seq, w, h, t]` (the order `select_rope_sections` + // expects for `xdrope_section`). For non-image tokens all four axes hold + // the plain sequence index. For the spatial run inside the image span we + // overwrite axes w/h/t: + // - w cycles `0..(patch_w+1)`, repeated `patch_h` times, + // - h is `[h]*(patch_w+1)` for `h` in `0..patch_h`, + // - t is 0 across the run. + // The run starts at `first_image_token + 1` and spans `(patch_w+1)*patch_h` + // tokens — i.e. the *middle* of the expanded `patch_h*(patch_w+1) + 2` + // image-token block. The first and last image_tokens (perceive begin/end + // markers) keep their default arange position. 
+ // + // Earlier this port used pure-sequential position ids for all four axes, + // which made the model produce hallucinated text (e.g. "The text in the + // image is not complete.") instead of OCR output: the trained weights + // expect the spatial xdrope rotation to encode 2-D image structure, and + // collapsing to 1-D destroys the geometry. let mut pos: Vec = vec![0; 4 * seq_len]; for i in 0..seq_len { let p = i as i64; @@ -502,36 +1481,27 @@ fn build_position_ids( pos[3 * seq_len + i] = p; } - let (start_pos, _end_pos) = find_image_span(input_ids, cfg)?; - let vision_start = input_ids[start_pos + 1..] - .iter() - .position(|&id| id == cfg.image_token_id) - .map(|p| start_pos + 1 + p) - .ok_or_else(|| OCRError::InvalidInput { - message: "HunyuanOCR: image_token_id not found after image_start_token_id".to_string(), - })?; - - let (_, hm, wm) = image_inputs.grid_thw_merged; - let vision_tokens = hm.saturating_mul(wm.saturating_add(1)); - let base = vision_start as i64; - - for j in 0..vision_tokens { - let idx = vision_start + j; - if idx >= seq_len { + let first_image_pos = input_ids.iter().position(|&id| id == cfg.image_token_id); + if let Some(first) = first_image_pos { + let (_, hm, wm) = image_inputs.grid_thw_merged; + let start = first + 1; + let replace_num = (wm + 1) * hm; + if start + replace_num > seq_len { return Err(OCRError::InvalidInput { message: format!( - "HunyuanOCR: vision token span exceeds input length (start={vision_start} count={vision_tokens} len={seq_len})" + "HunyuanOCR: image span ({} positions starting at {}) exceeds input length {}", + replace_num, start, seq_len ), }); } - let row = j / (wm + 1); - let col = j % (wm + 1); - let t_pos = base; - let h_pos = base + row as i64; - let w_pos = base + col as i64; - pos[seq_len + idx] = t_pos; - pos[2 * seq_len + idx] = h_pos; - pos[3 * seq_len + idx] = w_pos; + for j in 0..replace_num { + let idx = start + j; + let row = j / (wm + 1); // 0..hm + let col = j % (wm + 1); // 0..wm 
(inclusive of newline column) + pos[seq_len + idx] = col as i64; // axis 1 = w + pos[2 * seq_len + idx] = row as i64; // axis 2 = h + pos[3 * seq_len + idx] = 0; // axis 3 = t + } } Tensor::from_vec( @@ -547,3 +1517,31 @@ fn build_position_ids( ) }) } + +#[cfg(all(test, feature = "hsd"))] +mod tests { + use super::*; + + #[test] + fn hsd_repetition_penalty_matches_baseline_argmax_for_nondefault_penalty() { + let device = Device::Cpu; + let logits = Tensor::from_vec(vec![0.0f32, 8.0, 5.0], (3,), &device).unwrap(); + let seen = [1u32, 1u32]; + let penalty = 1.7f32; + + let baseline = argmax_with_repetition_penalty(&logits, &seen, penalty).unwrap(); + + let mut row = logits.to_vec1::().unwrap(); + HunyuanSpecBackend::apply_rep_penalty_row(&mut row, &seen, penalty); + let hsd = row + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Less)) + .map(|(idx, _)| idx as u32) + .unwrap(); + + assert_eq!(baseline, hsd); + assert_eq!(hsd, 2); + assert!((row[1] - 8.0 / penalty).abs() < 1e-6); + } +} diff --git a/oar-ocr-vl/src/hunyuanocr/processing.rs b/oar-ocr-vl/src/hunyuanocr/processing.rs index 95e0785..887876f 100644 --- a/oar-ocr-vl/src/hunyuanocr/processing.rs +++ b/oar-ocr-vl/src/hunyuanocr/processing.rs @@ -1,7 +1,7 @@ use super::config::{HunyuanOcrImageProcessorConfig, HunyuanOcrVisionConfig}; use crate::utils::{ candle_to_ocr_processing, - image::{clamp_to_max_image_size, image_to_chw, pil_resample_to_filter_type, smart_resize}, + image::{clamp_to_max_image_size, image_to_chw, smart_resize}, }; use candle_core::{DType, Device, Tensor}; use image::{RgbImage, imageops::FilterType}; @@ -86,10 +86,11 @@ pub fn preprocess_image( }); } - let resize_filter = cfg - .resample - .and_then(pil_resample_to_filter_type) - .unwrap_or(FilterType::CatmullRom); + // Match transformers' HunYuanVLImageProcessor._preprocess: it accepts a + // resample argument but calls PIL `image.resize((w, h))` without passing it. 
+ // HunyuanOCR therefore ignores cfg.resample and always uses the Pillow + // default BICUBIC-equivalent path, even when the config says `resample: 1`. + let resize_filter = FilterType::CatmullRom; let (h, w) = (image.height(), image.width()); let factor = (cfg.patch_size * cfg.merge_size) as u32; diff --git a/oar-ocr-vl/src/hunyuanocr/vision.rs b/oar-ocr-vl/src/hunyuanocr/vision.rs index 8ec5b61..db897a8 100644 --- a/oar-ocr-vl/src/hunyuanocr/vision.rs +++ b/oar-ocr-vl/src/hunyuanocr/vision.rs @@ -4,6 +4,13 @@ use candle_core::{D, DType, Device, IndexOp, Tensor}; use candle_nn::{Conv2d, Conv2dConfig, LayerNorm, LayerNormConfig, Linear, Module}; use oar_ocr_core::core::OCRError; +/// Late vit layers are the cross-implementation drift hotspot: BF16 Q·K +/// accumulation drift redirects attention "sink" positions enough to swap the +/// dominant attention head by layer 26 (cosine to upstream drops from +/// ~0.999 at layer 11 to ~0.95). Running attention in F32 from this layer +/// onwards is the empirically stable compromise. See `VisionAttention::forward`. +const VIT_LATE_F32_THRESHOLD: usize = 20; + #[derive(Debug, Clone)] struct VisionEmbeddings { patch_embedding: Conv2d, @@ -152,6 +159,7 @@ impl VisionEmbeddings { }) } + #[allow(dead_code)] fn extra_pos(&self) -> Result { self.position_embedding .embeddings() @@ -175,10 +183,17 @@ struct VisionAttention { num_heads: usize, head_dim: usize, scaling: f64, + /// Precomputed `layer_idx >= VIT_LATE_F32_THRESHOLD` — true when this + /// layer's attention runs in F32 (see [`VIT_LATE_F32_THRESHOLD`]). 
+ use_f32: bool, } impl VisionAttention { - fn load(cfg: &HunyuanOcrVisionConfig, vb: candle_nn::VarBuilder) -> Result { + fn load( + cfg: &HunyuanOcrVisionConfig, + layer_idx: usize, + vb: candle_nn::VarBuilder, + ) -> Result { if !cfg.hidden_size.is_multiple_of(cfg.num_attention_heads) { return Err(OCRError::ConfigError { message: format!( @@ -205,6 +220,7 @@ impl VisionAttention { num_heads: cfg.num_attention_heads, head_dim, scaling: (head_dim as f64).powf(-0.5), + use_f32: layer_idx >= VIT_LATE_F32_THRESHOLD, }) } @@ -217,28 +233,32 @@ impl VisionAttention { ) })?; - let q = self + let q_proj = self .q_proj .forward(hidden_states) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn q_proj", e))? + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn q_proj", e))?; + let k_proj = self + .k_proj + .forward(hidden_states) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn k_proj", e))?; + let v_proj = self + .v_proj + .forward(hidden_states) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn v_proj", e))?; + + let q = q_proj .reshape((b, seq_len, self.num_heads, self.head_dim)) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn q reshape", e))? .transpose(1, 2) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn q transpose", e))?; - let k = self - .k_proj - .forward(hidden_states) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn k_proj", e))? + let k = k_proj .reshape((b, seq_len, self.num_heads, self.head_dim)) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn k reshape", e))? .transpose(1, 2) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn k transpose", e))?; - let v = self - .v_proj - .forward(hidden_states) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn v_proj", e))? + let v = v_proj .reshape((b, seq_len, self.num_heads, self.head_dim)) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn v reshape", e))? 
.transpose(1, 2) @@ -254,31 +274,90 @@ impl VisionAttention { .contiguous() .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn v contiguous", e))?; - let attn_weights = q - .matmul( - &k.transpose(2, 3) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn k t23", e))? - .contiguous() - .map_err(|e| { - candle_to_ocr_inference("HunyuanOCR", "vit attn k t23 contiguous", e) - })?, - ) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn qk matmul", e))? - .affine(self.scaling, 0.0) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn scaling", e))?; - - let attn_weights = candle_nn::ops::softmax_last_dim( - &attn_weights + // Chunked attention over the query dimension. Without chunking the + // (B, H, N, N) attention matrix at N=4320 (the vit's full patch + // sequence) needs ~4 GB just for the BF16 buffer, OOM'ing the 4090. + // + // Late layers (idx >= VIT_LATE_F32_THRESHOLD) run attention in F32 + // because BF16 Q·K accumulation drift redirects attention to different + // sink tokens and gets amplified by the final MLPs — see the constant + // for the cross-implementation cosine numbers. + const VIT_ATTN_QUERY_CHUNK: usize = 1024; + let use_f32 = self.use_f32; + let v_dtype = v.dtype(); + let kt = k + .transpose(2, 3) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn k t23", e))? + .contiguous() + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn k t23 contiguous", e))?; + let (kt_attn, v_attn) = if use_f32 { + let kt_f32 = kt .to_dtype(DType::F32) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn cast f32", e))?, - ) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn softmax", e))? 
- .to_dtype(v.dtype()) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn cast back", e))?; + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn kt to f32", e))?; + let v_f32 = v + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn v to f32", e))?; + (kt_f32, v_f32) + } else { + (kt, v.clone()) + }; - let attn_output = attn_weights - .matmul(&v) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn av matmul", e))? + let attend_chunk = |q_in: &Tensor| -> Result { + let q_use = if use_f32 { + q_in.to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn q to f32", e))? + } else { + q_in.clone() + }; + let attn = q_use + .matmul(&kt_attn) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn qk matmul", e))? + .affine(self.scaling, 0.0) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn scaling", e))?; + let attn = if use_f32 { + candle_nn::ops::softmax_last_dim(&attn) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn softmax", e))? + } else { + let attn_f32 = attn + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn cast f32", e))?; + let attn_f32 = candle_nn::ops::softmax_last_dim(&attn_f32) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn softmax", e))?; + attn_f32 + .to_dtype(v_dtype) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn cast back", e))? + }; + let out = attn + .matmul(&v_attn) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn av matmul", e))?; + if use_f32 { + out.to_dtype(v_dtype) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn out to bf16", e)) + } else { + Ok(out) + } + }; + + let attn_output = if seq_len <= VIT_ATTN_QUERY_CHUNK { + attend_chunk(&q)? 
+ } else { + let mut chunks: Vec = + Vec::with_capacity(seq_len.div_ceil(VIT_ATTN_QUERY_CHUNK)); + let mut start = 0; + while start < seq_len { + let len = (seq_len - start).min(VIT_ATTN_QUERY_CHUNK); + let q_chunk = q.narrow(2, start, len).map_err(|e| { + candle_to_ocr_inference("HunyuanOCR", "vit attn chunked q narrow", e) + })?; + chunks.push(attend_chunk(&q_chunk)?); + start += len; + } + let chunk_refs: Vec<&Tensor> = chunks.iter().collect(); + Tensor::cat(&chunk_refs, 2) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn chunked cat", e))? + }; + + let attn_output = attn_output .transpose(1, 2) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit attn out transpose", e))? .reshape((b, seq_len, self.num_heads * self.head_dim)) @@ -318,9 +397,16 @@ impl VisionMlp { .fc1 .forward(xs) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit mlp fc1", e))?; + // Match PyTorch `nn.GELU()` (exact erf formula). candle's `.gelu()` + // uses the tanh approximation + // (`0.5 * v * (1 + tanh(sqrt(2/π)*(v + 0.044715*v³)))`), which + // diverges by up to ~0.001 per element from the erf formula. Across + // 27 vit MLPs that drift compounds enough to swap which positions + // become attention sinks in late layers (max-abs delta jumped from + // 1.4 at layer 1 to 13419 at layer 26). 
let hidden = hidden - .gelu() - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit mlp gelu", e))?; + .gelu_erf() + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit mlp gelu_erf", e))?; self.fc2 .forward(&hidden) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "vit mlp fc2", e)) @@ -336,8 +422,12 @@ struct VisionEncoderLayer { } impl VisionEncoderLayer { - fn load(cfg: &HunyuanOcrVisionConfig, vb: candle_nn::VarBuilder) -> Result { - let self_attn = VisionAttention::load(cfg, vb.pp("self_attn"))?; + fn load( + cfg: &HunyuanOcrVisionConfig, + layer_idx: usize, + vb: candle_nn::VarBuilder, + ) -> Result { + let self_attn = VisionAttention::load(cfg, layer_idx, vb.pp("self_attn"))?; let mlp = VisionMlp::load(cfg, vb.pp("mlp"))?; let ln_cfg = LayerNormConfig { @@ -396,7 +486,6 @@ struct VisionPerceive { after_rms: candle_nn::RmsNorm, image_begin: Tensor, image_end: Tensor, - image_sep: Tensor, image_newline: Tensor, } @@ -446,9 +535,10 @@ impl VisionPerceive { let image_end = vb .get(1024usize, "image_end") .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "load perceive image_end", e))?; - let image_sep = vb - .get(1024usize, "image_sep") - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "load perceive image_sep", e))?; + // `image_sep` exists in the trained weights but is *never* used by + // upstream's `HunYuanVisionPatchMerger.forward` — see + // `transformers/models/hunyuan_vl/modeling_hunyuan_vl.py:189-206`. We + // skip loading it. let image_newline = vb .get(4608usize, "image_newline") .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "load perceive image_newline", e))?; @@ -462,7 +552,6 @@ impl VisionPerceive { after_rms, image_begin, image_end, - image_sep, image_newline, }) } @@ -519,9 +608,11 @@ impl VisionPerceive { .proj_0 .forward(&feat_map) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "perceive proj.0 forward", e))?; + // Match PyTorch `nn.GELU()` exact erf formula here too — see + // `VisionMlp::forward` for the rationale. 
let feat = feat - .gelu() - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "perceive proj.0 gelu", e))?; + .gelu_erf() + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "perceive proj.0 gelu_erf", e))?; let feat = self .proj_2 .forward(&feat) @@ -590,26 +681,27 @@ impl VisionPerceive { .mlp .forward(&tokens) .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "perceive mlp forward", e))?; - let tokens = self - .after_rms - .forward(&tokens) - .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "perceive after_rms forward", e))?; - - let sep = self.image_sep.reshape((1usize, 1024usize)).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: reshape image_sep failed", - e, - ) - })?; - let tokens = tokens.broadcast_add(&sep).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: add image_sep failed", - e, - ) - })?; + // Match upstream HF (`HunYuanVisionPatchMerger.forward` in + // modeling_hunyuan_vl.py:189-206) exactly: + // 1. mlp(x) + // 2. cat([image_begin, mlp_out, image_end]) + // 3. after_rms(cat) + // + // We previously (a) applied `after_rms` to the mlp output BEFORE + // concatenating the begin/end markers and (b) broadcast-added an + // unused `image_sep` parameter to every token. Both wrong: + // upstream's `after_rms` runs once over the full begin+tokens+end + // sequence, which lifts the image_begin / image_end embedding + // magnitudes from their stored norm (~0.9) up to the post-RMSNorm + // scale (~22), matching the surrounding patch tokens. Our pre-fix + // perceive output had image_begin / image_end at norm ~0.9 — 25× + // smaller than upstream — so the LLM saw those marker positions as + // near-zero vectors and the prefill's last-position logits diverged + // (cos 0.69 vs upstream → wrong argmax → hallucinated + // continuations like "The presence of factors…" instead of OCR text). 
+ // The `image_sep` Parameter is declared in upstream weights but + // *never used in the forward path*; we now drop it on the floor too. let begin = self.image_begin.reshape((1usize, 1024usize)).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, @@ -624,13 +716,22 @@ impl VisionPerceive { e, ) })?; - Tensor::cat(&[&begin, &tokens, &end], 0).map_err(|e| { + let begin = begin.to_dtype(tokens.dtype()).map_err(|e| { + candle_to_ocr_inference("HunyuanOCR", "perceive image_begin to dtype", e) + })?; + let end = end + .to_dtype(tokens.dtype()) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "perceive image_end to dtype", e))?; + let cat = Tensor::cat(&[&begin, &tokens, &end], 0).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, "HunyuanOCR: concat begin/tokens/end failed", e, ) - }) + })?; + self.after_rms + .forward(&cat) + .map_err(|e| candle_to_ocr_inference("HunyuanOCR", "perceive after_rms forward", e)) } } @@ -647,7 +748,11 @@ impl HunyuanVisionModel { let embeddings = VisionEmbeddings::load(cfg, vb.pp("embeddings"))?; let mut layers = Vec::with_capacity(cfg.num_hidden_layers); for i in 0..cfg.num_hidden_layers { - layers.push(VisionEncoderLayer::load(cfg, vb.pp(format!("layers.{i}")))?); + layers.push(VisionEncoderLayer::load( + cfg, + i, + vb.pp(format!("layers.{i}")), + )?); } let perceive = VisionPerceive::load(cfg, vb.pp("perceive"))?; Ok(Self { @@ -741,50 +846,24 @@ impl HunyuanVisionModel { ) })?; - // Add a lightweight extra token (mean-pooled) to match the original position embedding layout. 
- let extra_pos = self.embeddings.extra_pos()?.unsqueeze(0).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: unsqueeze extra_pos failed", - e, - ) - })?; - let extra = patch_tokens.mean_keepdim(0).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: mean_keepdim vit extra token failed", - e, - ) - })?; - let extra = extra.broadcast_add(&extra_pos).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: add extra position embedding failed", - e, - ) - })?; - let extra = extra.unsqueeze(0).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: unsqueeze extra token failed", - e, - ) - })?; - - let patch_tokens = patch_tokens.unsqueeze(0).map_err(|e| { + // No extra/cls token. The upstream HunYuanVisionPatchEmbed + // (`transformers/models/hunyuan_vl/modeling_hunyuan_vl.py`) declares + // `num_positions = max_num_patches + 1` but the runtime path uses + // `position_embedding.weight[1:, :]` and feeds only the patch tokens + // through the encoder — the slot-0 entry is a vestigial cls token + // present in the trained weights for compatibility but never + // propagated. An earlier internal Tencent variant (which this Rust + // port originally followed) prepended a mean-pooled extra token plus + // a learned `extra_pos`, which contributed unwanted attention scores + // to every patch and accumulated noise across 27 encoder layers. + // Removing that prepend halves the residual vit_out drift vs upstream. 
+ let hidden = patch_tokens.unsqueeze(0).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, "HunyuanOCR: add batch dim to vit tokens failed", e, ) })?; - let hidden = Tensor::cat(&[&extra, &patch_tokens], 1).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: concat vit extra token failed", - e, - ) - })?; let mut hidden_states = hidden; for (i, layer) in self.layers.iter().enumerate() { @@ -797,15 +876,8 @@ impl HunyuanVisionModel { })?; } - // Drop the extra token before spatial perceiver merge. - let patch_out = hidden_states.i((.., 1.., ..)).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "HunyuanOCR: slice vit patch outputs failed", - e, - ) - })?; - let patch_out = patch_out.squeeze(0).map_err(|e| { + // No extra token to drop now (see above). Just unbatch. + let patch_out = hidden_states.squeeze(0).map_err(|e| { candle_to_ocr_processing( oar_ocr_core::core::errors::ProcessingStage::TensorOperation, "HunyuanOCR: squeeze vit patch outputs failed", @@ -848,17 +920,26 @@ fn interpolate_bilinear_align_corners_false( return out; } - let scale_y = in_h as f32 / out_h as f32; - let scale_x = in_w as f32 / out_w as f32; + // Match upstream HF (`HunYuanVisionPatchEmbed.forward` in + // modeling_hunyuan_vl.py:143-148): the sample stride is computed from + // `(out_h + 0.1) / in_h` (a deliberate `+0.1` to "avoid floating point + // error in the interpolation" — see the comment + facebookresearch/dino#8). + // PyTorch's `interpolate(scale_factor)` then derives the source coord as + // `(out_x + 0.5) / scale_factor - 0.5`, which is *not* the same as + // `(out_x + 0.5) * (in / out) - 0.5` we used before. 
+ let scale_factor_y = (out_h as f32 + 0.1) / in_h as f32; + let scale_factor_x = (out_w as f32 + 0.1) / in_w as f32; + let inv_scale_y = 1.0 / scale_factor_y; + let inv_scale_x = 1.0 / scale_factor_x; for oy in 0..out_h { - let fy = (oy as f32 + 0.5) * scale_y - 0.5; + let fy = ((oy as f32 + 0.5) * inv_scale_y - 0.5).max(0.0); let y0 = fy.floor().max(0.0) as usize; let y1 = (y0 + 1).min(in_h - 1); let wy = fy - y0 as f32; for ox in 0..out_w { - let fx = (ox as f32 + 0.5) * scale_x - 0.5; + let fx = ((ox as f32 + 0.5) * inv_scale_x - 0.5).max(0.0); let x0 = fx.floor().max(0.0) as usize; let x1 = (x0 + 1).min(in_w - 1); let wx = fx - x0 as f32; diff --git a/oar-ocr-vl/src/lib.rs b/oar-ocr-vl/src/lib.rs index 063e79d..fccd7f5 100644 --- a/oar-ocr-vl/src/lib.rs +++ b/oar-ocr-vl/src/lib.rs @@ -8,18 +8,21 @@ //! ## Module Structure //! //! - `paddleocr_vl` - PaddleOCR-VL for OCR, table, formula, chart, spotting, and seal recognition -//! - `unirec` - UniRec unified text/formula/table recognition //! - `hunyuanocr` - HunyuanOCR OCR expert VLM //! - `glmocr` - GLM-OCR OCR expert VLM -//! - `lightonocr` - LightOnOCR end-to-end OCR VLM //! - `mineru` - MinerU2.5 document parsing VLM (Qwen2-VL backbone) //! - `doc_parser` - Unified document parsing with pluggable recognition backends //! - `utils` - Utility functions (device parsing, candle helpers, markdown, OTSL conversion) //! - `attention` - Unified attention implementation shared by all models +//! - `hsd` - Hierarchical Speculative Decoding (DSV-on-OAR engineering of paper +//! arXiv:2602.12957 §3.2); gated behind the `hsd` feature, which requires +//! `cuda` for tree-attention and KV-cache gather. See the module docs for +//! the two-stage flow, prefix-tree batching, and per-backend integration. //! //! ## Features //! //! - `cuda` - Enable CUDA support for GPU acceleration +//! - `hsd` - Enable Hierarchical Speculative Decoding (implies `cuda`) //! //! ## Device Configuration //! 
@@ -39,25 +42,34 @@ pub mod doc_parser; pub mod glmocr; pub mod hunyuanocr; -pub mod lightonocr; pub mod mineru; pub mod paddleocr_vl; -pub mod unirec; pub mod utils; // Shared attention implementation pub mod attention; +// `TrimmableKvCache` backs the KV cache used by every model's attention +// forward path, so it must remain accessible regardless of the `hsd` feature. +// The source lives under `hsd/kv_trim.rs` and is also re-exported from +// `crate::hsd` when that module is compiled in. +#[path = "hsd/kv_trim.rs"] +pub(crate) mod kv_trim; + +// Hierarchical Speculative Decoding (requires CUDA-backed Candle for KV-cache +// gather and tree-attention; gated behind the `hsd` cargo feature). +#[cfg(feature = "hsd")] +pub mod hsd; + // Re-exports for convenience pub use paddleocr_vl::{ PaddleOcrVl, PaddleOcrVlConfig, PaddleOcrVlImageProcessorConfig, PaddleOcrVlTask, }; -pub use unirec::UniRec; - pub use glmocr::GlmOcr; +#[cfg(feature = "hsd")] +pub use hunyuanocr::HunyuanHsdPrompts; pub use hunyuanocr::HunyuanOcr; -pub use lightonocr::LightOnOcr; pub use mineru::MinerU; pub use doc_parser::{DocParser, DocParserConfig, RecognitionBackend, RecognitionTask}; diff --git a/oar-ocr-vl/src/lightonocr/config.rs b/oar-ocr-vl/src/lightonocr/config.rs deleted file mode 100644 index 3e4388a..0000000 --- a/oar-ocr-vl/src/lightonocr/config.rs +++ /dev/null @@ -1,138 +0,0 @@ -use candle_nn::Activation; -use oar_ocr_core::core::OCRError; -use serde::Deserialize; -use std::path::Path; - -#[derive(Debug, Clone, Deserialize)] -pub struct LightOnOcrTextConfig { - pub vocab_size: usize, - pub hidden_size: usize, - pub intermediate_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub num_key_value_heads: usize, - pub head_dim: usize, - pub attention_bias: bool, - #[serde(default)] - pub attention_dropout: f32, - pub hidden_act: Activation, - pub max_position_embeddings: usize, - pub rms_norm_eps: f64, - pub rope_theta: f64, - #[serde(default)] - pub 
sliding_window: Option, - #[serde(default)] - pub max_window_layers: usize, - #[serde(default)] - pub use_sliding_window: bool, - #[serde(default)] - pub tie_word_embeddings: bool, - #[serde(default)] - pub use_qk_norm: bool, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct LightOnOcrVisionConfig { - pub hidden_size: usize, - pub num_channels: usize, - pub image_size: usize, - pub patch_size: usize, - pub rope_theta: f64, - pub intermediate_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub head_dim: usize, - pub hidden_act: Activation, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct LightOnOcrConfig { - pub text_config: LightOnOcrTextConfig, - pub vision_config: LightOnOcrVisionConfig, - pub image_token_id: u32, - pub eos_token_id: u32, - pub pad_token_id: u32, - pub spatial_merge_size: usize, - pub projector_hidden_act: Activation, - pub multimodal_projector_bias: bool, - #[serde(default)] - pub vision_feature_layer: i32, -} - -impl LightOnOcrConfig { - pub fn from_path(path: impl AsRef) -> Result { - let contents = std::fs::read_to_string(path)?; - serde_json::from_str(&contents).map_err(|e| OCRError::ConfigError { - message: format!("failed to parse LightOnOCR config.json: {e}"), - }) - } -} - -#[derive(Debug, Clone, Deserialize)] -pub struct LightOnOcrImageProcessorSize { - pub longest_edge: u32, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct LightOnOcrImageProcessorConfig { - pub do_resize: bool, - pub do_rescale: bool, - pub do_normalize: bool, - pub do_convert_rgb: bool, - pub rescale_factor: f32, - pub image_mean: Vec, - pub image_std: Vec, - pub patch_size: usize, - #[serde(default)] - pub resample: Option, - pub size: LightOnOcrImageProcessorSize, -} - -impl LightOnOcrImageProcessorConfig { - pub fn validate(&self) -> Result<(), OCRError> { - if self.image_mean.len() != 3 || self.image_std.len() != 3 { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR image_mean/std must have 
length 3, got mean={} std={}", - self.image_mean.len(), - self.image_std.len() - ), - }); - } - if self.image_std.contains(&0.0) { - return Err(OCRError::ConfigError { - message: "LightOnOCR image_std values must be non-zero (used as divisor)" - .to_string(), - }); - } - if self.patch_size == 0 { - return Err(OCRError::ConfigError { - message: "LightOnOCR patch_size must be > 0".to_string(), - }); - } - if self.size.longest_edge == 0 { - return Err(OCRError::ConfigError { - message: "LightOnOCR longest_edge must be > 0".to_string(), - }); - } - Ok(()) - } -} - -#[derive(Debug, Clone, Deserialize)] -pub struct LightOnOcrProcessorConfig { - pub image_processor: LightOnOcrImageProcessorConfig, - #[serde(default)] - pub spatial_merge_size: Option, - #[serde(default)] - pub patch_size: Option, -} - -impl LightOnOcrProcessorConfig { - pub fn from_path(path: impl AsRef) -> Result { - let contents = std::fs::read_to_string(path)?; - serde_json::from_str(&contents).map_err(|e| OCRError::ConfigError { - message: format!("failed to parse LightOnOCR processor_config.json: {e}"), - }) - } -} diff --git a/oar-ocr-vl/src/lightonocr/mod.rs b/oar-ocr-vl/src/lightonocr/mod.rs deleted file mode 100644 index 2c4167f..0000000 --- a/oar-ocr-vl/src/lightonocr/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! LightOnOCR (LightOnOCR-2) Vision-Language model implementation. 
- -mod config; -mod model; -mod processing; -mod text; -mod vision; - -pub use config::{ - LightOnOcrConfig, LightOnOcrImageProcessorConfig, LightOnOcrProcessorConfig, - LightOnOcrTextConfig, LightOnOcrVisionConfig, -}; -pub use model::LightOnOcr; diff --git a/oar-ocr-vl/src/lightonocr/model.rs b/oar-ocr-vl/src/lightonocr/model.rs deleted file mode 100644 index 1f3caf4..0000000 --- a/oar-ocr-vl/src/lightonocr/model.rs +++ /dev/null @@ -1,683 +0,0 @@ -use super::config::{LightOnOcrConfig, LightOnOcrProcessorConfig}; -use super::processing::{LightOnOcrImageInputs, preprocess_image}; -use super::text::LightOnOcrTextModel; -use super::vision::{PixtralVisionConfig, PixtralVisionModel}; -use crate::attention::{combine_masks, create_causal_mask, create_left_padding_mask}; -use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing}; -use candle_core::{D, DType, Device, IndexOp, Tensor}; -use candle_nn::{Activation, Linear, RmsNorm, VarBuilder, linear_b, linear_no_bias, rms_norm}; -use image::RgbImage; -use oar_ocr_core::core::OCRError; -use std::path::Path; -use tokenizers::Tokenizer; - -pub struct LightOnOcr { - device: Device, - dtype: DType, - cfg: LightOnOcrConfig, - image_cfg: super::config::LightOnOcrImageProcessorConfig, - tokenizer: Tokenizer, - text: LightOnOcrTextModel, - vision: PixtralVisionModel, - projector: VisionProjection, - eos_token_id: u32, - image_token_id: u32, -} - -impl LightOnOcr { - pub fn from_dir(model_dir: impl AsRef, device: Device) -> Result { - let model_dir = model_dir.as_ref(); - let cfg = LightOnOcrConfig::from_path(model_dir.join("config.json"))?; - let processor_cfg = - LightOnOcrProcessorConfig::from_path(model_dir.join("processor_config.json"))?; - let image_cfg = processor_cfg.image_processor; - - if let Some(patch_size) = processor_cfg.patch_size - && patch_size != image_cfg.patch_size - { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR patch_size mismatch: processor {} != image_processor {}", - 
patch_size, image_cfg.patch_size - ), - }); - } - if let Some(merge_size) = processor_cfg.spatial_merge_size - && merge_size != cfg.spatial_merge_size - { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR spatial_merge_size mismatch: processor {} != config {}", - merge_size, cfg.spatial_merge_size - ), - }); - } - if image_cfg.patch_size != cfg.vision_config.patch_size { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR patch_size mismatch: image_processor {} != vision_config {}", - image_cfg.patch_size, cfg.vision_config.patch_size - ), - }); - } - - let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json")).map_err(|e| { - OCRError::ConfigError { - message: format!("failed to load LightOnOCR tokenizer.json: {e}"), - } - })?; - let tok_image_id = - tokenizer - .token_to_id("<|image_pad|>") - .ok_or_else(|| OCRError::ConfigError { - message: "LightOnOCR tokenizer is missing <|image_pad|> token".to_string(), - })?; - if tok_image_id != cfg.image_token_id { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR image_token_id mismatch: tokenizer {tok_image_id} != config {}", - cfg.image_token_id - ), - }); - } - if let Some(tok_eos_id) = tokenizer.token_to_id("<|im_end|>") - && tok_eos_id != cfg.eos_token_id - { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR eos_token_id mismatch: tokenizer {tok_eos_id} != config {}", - cfg.eos_token_id - ), - }); - } - - let dtype = device.bf16_default_to_f32(); - // SAFETY: The mmap'd file must not be modified or deleted while in use. - // This is upheld because model files are read-only assets loaded at initialization. - let vb = unsafe { - VarBuilder::from_mmaped_safetensors( - &[model_dir.join("model.safetensors")], - dtype, - &device, - ) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "load model.safetensors", e))? 
- }; - - let vision_cfg = PixtralVisionConfig::from(&cfg.vision_config); - - let vision = PixtralVisionModel::new(&vision_cfg, vb.pp("model").pp("vision_encoder")) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "load vision encoder", e))?; - - let text = LightOnOcrTextModel::new(&cfg.text_config, vb.clone()) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "load text model", e))?; - - let projector = VisionProjection::load(&cfg, vb.pp("model").pp("vision_projection"))?; - - let eos_token_id = cfg.eos_token_id; - let image_token_id = cfg.image_token_id; - - Ok(Self { - device, - dtype, - cfg, - image_cfg, - tokenizer, - text, - vision, - projector, - eos_token_id, - image_token_id, - }) - } - - /// Generate OCR output for one or more images with optional instructions. - /// - /// Supports true GPU batching when multiple images are provided. - /// An empty instruction performs plain text extraction. For structured output, - /// pass a task-specific prompt (e.g. "Parse the table in the image into HTML."). - /// - /// # Arguments - /// * `images` - Input images - /// * `instructions` - Instruction for each image (must match images length) - /// * `max_new_tokens` - Maximum tokens to generate per image - /// - /// # Returns - /// Vector of results, one per input image. 
- pub fn generate( - &self, - images: &[RgbImage], - instructions: &[impl AsRef], - max_new_tokens: usize, - ) -> Vec> { - if images.is_empty() { - return Vec::new(); - } - if images.len() != instructions.len() { - return vec![Err(OCRError::InvalidInput { - message: format!( - "LightOnOCR: images count ({}) != instructions count ({})", - images.len(), - instructions.len() - ), - })]; - } - - match self.generate_internal(images, instructions, max_new_tokens) { - Ok(results) => results.into_iter().map(Ok).collect(), - Err(e) => { - let msg = format!("generation failed: {e}"); - (0..images.len()) - .map(|_| { - Err(OCRError::InvalidInput { - message: msg.clone(), - }) - }) - .collect() - } - } - } - - /// Internal generation implementation supporting batched inference. - fn generate_internal( - &self, - images: &[RgbImage], - instructions: &[impl AsRef], - max_new_tokens: usize, - ) -> Result, OCRError> { - let batch_size = images.len(); - - // 1. Preprocess all images and build prompts - let mut all_input_ids: Vec> = Vec::with_capacity(batch_size); - let mut all_image_inputs: Vec = Vec::with_capacity(batch_size); - let mut all_image_token_counts: Vec = Vec::with_capacity(batch_size); - - for (image, instruction) in images.iter().zip(instructions.iter()) { - let instruction = instruction.as_ref(); - let image_inputs = preprocess_image( - image, - &self.image_cfg, - self.cfg.spatial_merge_size, - &self.device, - self.dtype, - )?; - - let image_token_count = (image_inputs.grid_h / self.cfg.spatial_merge_size) - * (image_inputs.grid_w / self.cfg.spatial_merge_size); - - // Build prompt tokens - let prefix = "<|im_start|>system<|im_end|>\n<|im_start|>user\n"; - let instruction_trimmed = instruction.trim(); - let suffix = if instruction_trimmed.is_empty() { - "<|im_end|>\n<|im_start|>assistant\n" - } else { - "\n" - }; - let after_instruction = if instruction_trimmed.is_empty() { - "" - } else { - "<|im_end|>\n<|im_start|>assistant\n" - }; - - let prefix_enc = - 
self.tokenizer - .encode(prefix, false) - .map_err(|e| OCRError::InvalidInput { - message: format!("LightOnOCR: tokenizer encode prefix failed: {e}"), - })?; - let suffix_enc = - self.tokenizer - .encode(suffix, false) - .map_err(|e| OCRError::InvalidInput { - message: format!("LightOnOCR: tokenizer encode suffix failed: {e}"), - })?; - - let mut input_ids = - Vec::with_capacity(prefix_enc.len() + image_token_count + suffix_enc.len() + 50); - input_ids.extend_from_slice(prefix_enc.get_ids()); - input_ids.extend(std::iter::repeat_n(self.image_token_id, image_token_count)); - input_ids.extend_from_slice(suffix_enc.get_ids()); - - if !instruction_trimmed.is_empty() { - let instruction_enc = - self.tokenizer - .encode(instruction_trimmed, false) - .map_err(|e| OCRError::InvalidInput { - message: format!( - "LightOnOCR: tokenizer encode instruction failed: {e}" - ), - })?; - input_ids.extend_from_slice(instruction_enc.get_ids()); - - let after_enc = self - .tokenizer - .encode(after_instruction, false) - .map_err(|e| OCRError::InvalidInput { - message: format!( - "LightOnOCR: tokenizer encode after_instruction failed: {e}" - ), - })?; - input_ids.extend_from_slice(after_enc.get_ids()); - } - - all_input_ids.push(input_ids); - all_image_inputs.push(image_inputs); - all_image_token_counts.push(image_token_count); - } - - // 2. Compute vision features and project for each image - let mut all_image_embeds: Vec = Vec::with_capacity(batch_size); - for image_inputs in &all_image_inputs { - let vision_tokens = self - .vision - .forward(&image_inputs.pixel_values) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "vision forward", e))?; - let projected = - self.projector - .forward(&vision_tokens, image_inputs.grid_h, image_inputs.grid_w)?; - all_image_embeds.push(projected); - } - - // 3. 
Build embeddings per sample with left-padding - let seq_lens: Vec = all_input_ids.iter().map(|ids| ids.len()).collect(); - let max_seq_len = *seq_lens.iter().max().unwrap(); - - let mut batch_embeds: Vec = Vec::with_capacity(batch_size); - - for (i, input_ids) in all_input_ids.iter().enumerate() { - let seq_len = input_ids.len(); - let pad_len = max_seq_len - seq_len; - let image_token_count = all_image_token_counts[i]; - - // Embed tokens - let input_ids_t = Tensor::new(input_ids.clone(), &self.device) - .and_then(|t| t.reshape((1, seq_len))) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "create input_ids", e))?; - let mut inputs_embeds = self - .text - .embed_tokens(&input_ids_t) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "embed tokens", e))?; - - // Insert image embeddings - let spans = find_image_spans(input_ids, self.image_token_id); - let total_image_tokens: usize = spans.iter().map(|(s, e)| e - s).sum(); - if total_image_tokens != image_token_count { - return Err(OCRError::InvalidInput { - message: format!( - "LightOnOCR: image token count mismatch: expected {image_token_count}, got {total_image_tokens}" - ), - }); - } - inputs_embeds = - insert_image_embeds(inputs_embeds, all_image_embeds[i].clone(), &spans)?; - - // Left-pad if needed - if pad_len > 0 { - let hidden_size = inputs_embeds - .dim(2) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "get hidden_size", e))?; - let pad = Tensor::zeros( - (1, pad_len, hidden_size), - inputs_embeds.dtype(), - &self.device, - ) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "create pad", e))?; - inputs_embeds = Tensor::cat(&[&pad, &inputs_embeds], 1) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "cat pad", e))?; - } - batch_embeds.push(inputs_embeds); - } - - // 4. Stack batched tensors - let batch_refs: Vec<&Tensor> = batch_embeds.iter().collect(); - let inputs_embeds = Tensor::cat(&batch_refs, 0) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "stack embeds", e))?; - - // 5. 
Create attention mask - let causal = create_causal_mask(max_seq_len, max_seq_len, self.dtype, &self.device) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "create causal", e))?; - let padding = create_left_padding_mask(&seq_lens, max_seq_len, self.dtype, &self.device) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "create padding", e))?; - let mask = combine_masks(&causal, &padding) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "combine masks", e))?; - - // 6. Prefill - self.text.clear_kv_cache(); - // Build seqlen_offsets: for left-padded sequences, offset is the padding length - let seqlen_offsets: Vec = seq_lens.iter().map(|&len| max_seq_len - len).collect(); - let logits = self - .text - .forward_embeds(inputs_embeds, Some(&mask), &seqlen_offsets) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "prefill forward", e))?; - - // 7. Get initial logits per sample - let mut logits_list: Vec = Vec::with_capacity(batch_size); - for i in 0..batch_size { - let sample_logits = logits - .i(i) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "get sample logits", e))?; - logits_list.push(sample_logits); - } - - // 8. 
Autoregressive decode - let mut generated: Vec> = vec![Vec::new(); batch_size]; - let mut finished: Vec = vec![false; batch_size]; - let mut positions: Vec = seq_lens.clone(); - - for _ in 0..max_new_tokens { - if finished.iter().all(|&f| f) { - break; - } - - let mut next_tokens: Vec = Vec::with_capacity(batch_size); - for (i, logits) in logits_list.iter().enumerate() { - if finished[i] { - next_tokens.push(0); // Padding token for finished samples - } else { - let tok = logits - .argmax(D::Minus1) - .and_then(|t| t.to_scalar::()) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "argmax", e))?; - - if tok == self.eos_token_id { - finished[i] = true; - } else { - generated[i].push(tok); - } - next_tokens.push(tok); - } - } - - if finished.iter().all(|&f| f) { - break; - } - - // Batch forward for next tokens - let tokens = Tensor::new(next_tokens, &self.device) - .and_then(|t| t.reshape((batch_size, 1))) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "create tokens", e))?; - let embeds = self - .text - .embed_tokens(&tokens) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "embed next tokens", e))?; - - // For decode step, no mask needed, use positions as offsets - let next_logits = self - .text - .forward_embeds(embeds, None, &positions) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "decode forward", e))?; - - logits_list.clear(); - for i in 0..batch_size { - let sample_logits = next_logits - .i(i) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "get decode logits", e))?; - logits_list.push(sample_logits); - } - - for (i, p) in positions.iter_mut().enumerate() { - if !finished[i] { - *p += 1; - } - } - } - - // 9. 
Decode results - let mut results = Vec::with_capacity(batch_size); - for tokens in generated { - let text = - self.tokenizer - .decode(&tokens, false) - .map_err(|e| OCRError::InvalidInput { - message: format!("LightOnOCR: tokenizer decode failed: {e}"), - })?; - results.push(text); - } - - Ok(results) - } -} - -struct VisionProjection { - merge_size: usize, - patch_merger: Linear, - norm: RmsNorm, - linear_1: Linear, - linear_2: Linear, - act: Activation, -} - -impl VisionProjection { - fn load(cfg: &LightOnOcrConfig, vb: VarBuilder) -> Result { - let hidden = cfg.vision_config.hidden_size; - let merge_size = cfg.spatial_merge_size; - let in_dim = hidden * merge_size * merge_size; - - let patch_merger = - linear_no_bias(in_dim, hidden, vb.pp("patch_merger").pp("merging_layer")) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "load patch_merger", e))?; - let norm = rms_norm(hidden, cfg.text_config.rms_norm_eps, vb.pp("norm")) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "load vision_projection norm", e))?; - - let linear_1 = linear_b( - hidden, - hidden, - cfg.multimodal_projector_bias, - vb.pp("linear_1"), - ) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "load vision_projection linear_1", e))?; - let linear_2 = linear_b( - hidden, - hidden, - cfg.multimodal_projector_bias, - vb.pp("linear_2"), - ) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "load vision_projection linear_2", e))?; - - Ok(Self { - merge_size, - patch_merger, - norm, - linear_1, - linear_2, - act: cfg.projector_hidden_act, - }) - } - - fn forward(&self, xs: &Tensor, grid_h: usize, grid_w: usize) -> Result { - let normed = xs - .apply(&self.norm) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "vision_projection norm", e))?; - let merged = self.merge_patches(&normed, grid_h, grid_w)?; - let x = merged - .apply(&self.linear_1) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "vision_projection linear_1", e))?; - let x = x - .apply(&self.act) - .map_err(|e| 
candle_to_ocr_inference("LightOnOCR", "vision_projection act", e))?; - x.apply(&self.linear_2) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "vision_projection linear_2", e)) - } - - fn merge_patches(&self, xs: &Tensor, grid_h: usize, grid_w: usize) -> Result { - let (b, seq_len, hidden) = xs - .dims3() - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches dims3", e))?; - if grid_h * grid_w != seq_len { - return Err(OCRError::InvalidInput { - message: format!( - "LightOnOCR merge_patches: expected grid {}x{} to match seq_len={seq_len}", - grid_h, grid_w - ), - }); - } - let merged_h = grid_h / self.merge_size; - let merged_w = grid_w / self.merge_size; - if merged_h == 0 || merged_w == 0 { - return Err(OCRError::InvalidInput { - message: format!( - "LightOnOCR merge_patches: grid {}x{} too small for merge_size={}", - grid_h, grid_w, self.merge_size - ), - }); - } - - let trim_h = merged_h * self.merge_size; - let trim_w = merged_w * self.merge_size; - - // Match Mistral3PatchMerger.forward from HuggingFace transformers - // (models/mistral3/modular_mistral3.py), which uses F.unfold(kernel=merge, stride=merge). - // Here we implement the equivalent via reshape + permute. - // 1. (b, seq, hidden) -> (b, h, w, hidden) - let xs = xs - .reshape((b, grid_h, grid_w, hidden)) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches reshape1", e))?; - let xs = xs - .narrow(1, 0, trim_h) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches narrow1", e))?; - let xs = xs - .narrow(2, 0, trim_w) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches narrow2", e))?; - - // 2. (b, h, w, hidden) -> (b, h', merge_h, w', merge_w, hidden) - // indices: (0, 1, 2, 3, 4, 5) - let xs = xs - .reshape(( - b, - merged_h, - self.merge_size, - merged_w, - self.merge_size, - hidden, - )) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches reshape2", e))?; - - // 3. 
permute to (b, h', w', hidden, merge_h, merge_w) - // unfold order: hidden varies slowest in the flattened output - // i.e. unfold gives (h', w', hidden, merge_h, merge_w), view flattens last 3 dims - let xs = xs - .permute((0, 1, 3, 5, 2, 4)) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches permute", e))?; - - // 4. flatten to (b, h'*w', hidden*merge*merge) - // Note: this is hidden*merge*merge, not merge*merge*hidden - let xs = xs - .reshape(( - b, - merged_h * merged_w, - hidden * self.merge_size * self.merge_size, - )) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches reshape3", e))?; - - xs.apply(&self.patch_merger) - .map_err(|e| candle_to_ocr_inference("LightOnOCR", "merge_patches patch_merger", e)) - } -} - -fn find_image_spans(input_ids: &[u32], image_token_id: u32) -> Vec<(usize, usize)> { - let mut spans = Vec::new(); - let mut start = None; - for (idx, &token) in input_ids.iter().enumerate() { - if token == image_token_id { - if start.is_none() { - start = Some(idx); - } - } else if let Some(s) = start.take() { - spans.push((s, idx)); - } - } - if let Some(s) = start { - spans.push((s, input_ids.len())); - } - spans -} - -fn insert_image_embeds( - mut input_embeds: Tensor, - image_embeds: Tensor, - spans: &[(usize, usize)], -) -> Result { - let (b, _seq_len, hidden) = input_embeds.dims3().map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "LightOnOCR: input_embeds dims3 failed", - e, - ) - })?; - let image_embeds = image_embeds.to_dtype(input_embeds.dtype()).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "LightOnOCR: cast image_embeds dtype failed", - e, - ) - })?; - - let mut offset = 0usize; - for &(start, end) in spans { - let len = end - start; - let chunk = image_embeds.narrow(1, offset, len).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, 
- "LightOnOCR: slice image_embeds failed", - e, - ) - })?; - input_embeds = input_embeds - .slice_assign(&[0..b, start..end, 0..hidden], &chunk) - .map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "LightOnOCR: slice_assign image_embeds failed", - e, - ) - })?; - offset += len; - } - Ok(input_embeds) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_find_image_spans_single_span() { - let input_ids = [1, 2, 99, 99, 99, 3, 4]; - let spans = find_image_spans(&input_ids, 99); - assert_eq!(spans, vec![(2, 5)]); - } - - #[test] - fn test_find_image_spans_multiple_spans() { - let input_ids = [99, 99, 1, 2, 99, 99, 99, 3]; - let spans = find_image_spans(&input_ids, 99); - assert_eq!(spans, vec![(0, 2), (4, 7)]); - } - - #[test] - fn test_find_image_spans_trailing() { - let input_ids = [1, 2, 99, 99]; - let spans = find_image_spans(&input_ids, 99); - assert_eq!(spans, vec![(2, 4)]); - } - - #[test] - fn test_find_image_spans_none() { - let input_ids = [1, 2, 3, 4]; - let spans = find_image_spans(&input_ids, 99); - assert!(spans.is_empty()); - } - - #[test] - fn test_find_image_spans_all() { - let input_ids = [99, 99, 99]; - let spans = find_image_spans(&input_ids, 99); - assert_eq!(spans, vec![(0, 3)]); - } - - #[test] - fn test_find_image_spans_single_token_spans() { - let input_ids = [99, 1, 99, 2, 99]; - let spans = find_image_spans(&input_ids, 99); - assert_eq!(spans, vec![(0, 1), (2, 3), (4, 5)]); - } -} diff --git a/oar-ocr-vl/src/lightonocr/processing.rs b/oar-ocr-vl/src/lightonocr/processing.rs deleted file mode 100644 index b04241d..0000000 --- a/oar-ocr-vl/src/lightonocr/processing.rs +++ /dev/null @@ -1,129 +0,0 @@ -use super::config::LightOnOcrImageProcessorConfig; -use crate::utils::{ - candle_to_ocr_processing, - image::{image_to_chw, pil_resample_to_filter_type, round_up_to_multiple}, -}; -use candle_core::{DType, Device, Tensor}; -use image::{RgbImage, imageops::FilterType}; -use 
oar_ocr_core::core::OCRError; - -#[derive(Debug)] -pub struct LightOnOcrImageInputs { - pub pixel_values: Tensor, - pub grid_h: usize, - pub grid_w: usize, -} - -pub fn preprocess_image( - image: &RgbImage, - cfg: &LightOnOcrImageProcessorConfig, - spatial_merge_size: usize, - device: &Device, - dtype: DType, -) -> Result { - cfg.validate()?; - if spatial_merge_size == 0 { - return Err(OCRError::ConfigError { - message: "LightOnOCR spatial_merge_size must be > 0".to_string(), - }); - } - - let (orig_h, orig_w) = (image.height(), image.width()); - let max_edge = cfg.size.longest_edge; - let mut target_h = orig_h; - let mut target_w = orig_w; - - if cfg.do_resize { - let max_dim = orig_h.max(orig_w); - if max_dim > max_edge { - let ratio = max_dim as f32 / max_edge as f32; - target_h = ((orig_h as f32) / ratio).ceil() as u32; - target_w = ((orig_w as f32) / ratio).ceil() as u32; - } - } - - let factor = cfg.patch_size as u32; - if cfg.do_resize { - target_h = round_up_to_multiple(target_h, factor); - target_w = round_up_to_multiple(target_w, factor); - } - - let resize_filter = cfg - .resample - .and_then(pil_resample_to_filter_type) - .unwrap_or(FilterType::CatmullRom); - - let resized = if target_h != orig_h || target_w != orig_w { - image::imageops::resize(image, target_w, target_h, resize_filter) - } else { - image.clone() - }; - - if target_h % (cfg.patch_size as u32) != 0 || target_w % (cfg.patch_size as u32) != 0 { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR preprocess produced non-divisible dims: {target_h}x{target_w} not divisible by patch_size={}", - cfg.patch_size - ), - }); - } - - let grid_h = (target_h / cfg.patch_size as u32) as usize; - let grid_w = (target_w / cfg.patch_size as u32) as usize; - let merged_h = grid_h / spatial_merge_size; - let merged_w = grid_w / spatial_merge_size; - if merged_h == 0 || merged_w == 0 { - return Err(OCRError::ConfigError { - message: format!( - "LightOnOCR preprocess produced grid {}x{} too 
small for spatial_merge_size={}", - grid_h, grid_w, spatial_merge_size - ), - }); - } - - let scale = if cfg.do_rescale { - Some(cfg.rescale_factor) - } else { - None - }; - - let mean: &[f32] = if cfg.do_normalize { - &cfg.image_mean - } else { - &[0.0, 0.0, 0.0] - }; - let std: &[f32] = if cfg.do_normalize { - &cfg.image_std - } else { - &[1.0, 1.0, 1.0] - }; - - let data = image_to_chw(&resized, mean, std, scale); - - let pixel_values = Tensor::from_vec( - data, - (1usize, 3usize, target_h as usize, target_w as usize), - device, - ) - .map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "LightOnOCR: failed to create pixel_values tensor", - e, - ) - })?; - - let pixel_values = pixel_values.to_dtype(dtype).map_err(|e| { - candle_to_ocr_processing( - oar_ocr_core::core::errors::ProcessingStage::TensorOperation, - "LightOnOCR: failed to cast pixel_values dtype", - e, - ) - })?; - - Ok(LightOnOcrImageInputs { - pixel_values, - grid_h, - grid_w, - }) -} diff --git a/oar-ocr-vl/src/lightonocr/text.rs b/oar-ocr-vl/src/lightonocr/text.rs deleted file mode 100644 index 75662d5..0000000 --- a/oar-ocr-vl/src/lightonocr/text.rs +++ /dev/null @@ -1,309 +0,0 @@ -use std::cell::RefCell; -use std::sync::Arc; - -use candle_core::{IndexOp, Result, Tensor}; -use candle_nn::{ - Activation, Embedding, Linear, Module, RmsNorm, VarBuilder, embedding, kv_cache::KvCache, - linear_b, rms_norm, -}; - -use super::config::LightOnOcrTextConfig; -use crate::attention::{RotaryEmbedding, repeat_kv, scaled_dot_product_attention}; - -struct Mlp { - gate_proj: Linear, - up_proj: Linear, - down_proj: Linear, - act_fn: Activation, -} - -impl Mlp { - fn new(cfg: &LightOnOcrTextConfig, vb: VarBuilder) -> Result { - let hidden_sz = cfg.hidden_size; - let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_b(hidden_sz, intermediate_sz, false, vb.pp("gate_proj"))?; - let up_proj = linear_b(hidden_sz, intermediate_sz, false, 
vb.pp("up_proj"))?; - let down_proj = linear_b(intermediate_sz, hidden_sz, false, vb.pp("down_proj"))?; - Ok(Self { - gate_proj, - up_proj, - down_proj, - act_fn: cfg.hidden_act, - }) - } - - fn forward(&self, xs: &Tensor) -> Result { - let lhs = self.gate_proj.forward(xs)?.apply(&self.act_fn)?; - let rhs = self.up_proj.forward(xs)?; - self.down_proj.forward(&(lhs * rhs)?) - } -} - -struct Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, - q_norm: RmsNorm, - k_norm: RmsNorm, - num_heads: usize, - num_kv_heads: usize, - head_dim: usize, - rotary_emb: Arc, - n_kv_groups: usize, - softmax_scale: f64, - kv_cache: RefCell, -} - -impl Attention { - fn new( - rotary_emb: Arc, - cfg: &LightOnOcrTextConfig, - vb: VarBuilder, - ) -> Result { - let hidden_sz = cfg.hidden_size; - let num_heads = cfg.num_attention_heads; - let num_kv_heads = cfg.num_key_value_heads; - let q_proj = linear_b( - hidden_sz, - num_heads * cfg.head_dim, - cfg.attention_bias, - vb.pp("q_proj"), - )?; - let k_proj = linear_b( - hidden_sz, - num_kv_heads * cfg.head_dim, - cfg.attention_bias, - vb.pp("k_proj"), - )?; - let v_proj = linear_b( - hidden_sz, - num_kv_heads * cfg.head_dim, - cfg.attention_bias, - vb.pp("v_proj"), - )?; - let o_proj = linear_b( - num_heads * cfg.head_dim, - hidden_sz, - cfg.attention_bias, - vb.pp("o_proj"), - )?; - let q_norm = rms_norm(cfg.head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?; - let k_norm = rms_norm(cfg.head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?; - - // Create KvCache with dim=2 for seq_len dimension - // Cap initial capacity to avoid excessive memory allocation - let kv_cache_capacity = cfg.max_position_embeddings.min(8192); - let kv_cache = KvCache::new(2, kv_cache_capacity); - - Ok(Self { - q_proj, - k_proj, - v_proj, - o_proj, - q_norm, - k_norm, - num_heads, - num_kv_heads, - head_dim: cfg.head_dim, - rotary_emb, - n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads, - softmax_scale: 1.0 / (cfg.head_dim as 
f64).sqrt(), - kv_cache: RefCell::new(kv_cache), - }) - } - - fn forward( - &self, - xs: &Tensor, - attention_mask: Option<&Tensor>, - seqlen_offsets: &[usize], - ) -> Result { - let (b_sz, q_len, _) = xs.dims3()?; - let mut q = self.q_proj.forward(xs)?; - let mut k = self.k_proj.forward(xs)?; - let mut v = self.v_proj.forward(xs)?; - - q = q - .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; - k = k - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - v = v - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - - q = q.apply(&self.q_norm)?; - k = k.apply(&self.k_norm)?; - - (q, k) = self.rotary_emb.apply_rotary_emb(&q, &k, seqlen_offsets)?; - - let q = q.contiguous()?; - let k = k.contiguous()?; - let v = v.contiguous()?; - - let (k, v) = self.kv_cache.borrow_mut().append(&k, &v)?; - - let k = repeat_kv(&k, self.n_kv_groups)?.contiguous()?; - let v = repeat_kv(&v, self.n_kv_groups)?.contiguous()?; - - // Use unified attention implementation - let is_causal = attention_mask.is_none(); - let attn_output = scaled_dot_product_attention( - &q, - &k, - &v, - attention_mask, - self.softmax_scale, - is_causal, - )?; - let attn_output = attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?; - - self.o_proj.forward(&attn_output) - } - - fn clear_kv_cache(&self) { - self.kv_cache.borrow_mut().reset(); - } -} - -pub struct DecoderLayer { - self_attn: Attention, - mlp: Mlp, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, -} - -impl DecoderLayer { - fn new( - rotary_emb: Arc, - cfg: &LightOnOcrTextConfig, - vb: VarBuilder, - ) -> Result { - let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; - let mlp = Mlp::new(cfg, vb.pp("mlp"))?; - let input_layernorm = - rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; - let post_attention_layernorm = rms_norm( - cfg.hidden_size, - cfg.rms_norm_eps, - vb.pp("post_attention_layernorm"), - )?; - 
Ok(Self { - self_attn, - mlp, - input_layernorm, - post_attention_layernorm, - }) - } - - fn forward( - &self, - xs: &Tensor, - attention_mask: Option<&Tensor>, - seqlen_offsets: &[usize], - ) -> Result { - let residual = xs; - let xs = self.input_layernorm.forward(xs)?; - let xs = self - .self_attn - .forward(&xs, attention_mask, seqlen_offsets)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = self - .mlp - .forward(&xs.apply(&self.post_attention_layernorm)?)?; - residual + xs - } - - fn clear_kv_cache(&self) { - self.self_attn.clear_kv_cache(); - } -} - -pub struct LightOnOcrTextModel { - embed_tokens: Embedding, - norm: RmsNorm, - layers: Vec, - lm_head: Linear, -} - -impl LightOnOcrTextModel { - /// Maximum positions to precompute for RoPE embeddings. - /// Capped to avoid excessive memory usage when config specifies very large values. - const MAX_PRECOMPUTED_POSITIONS: usize = 8192; - - pub fn new(cfg: &LightOnOcrTextConfig, vb: VarBuilder) -> Result { - let vb_m = vb.pp("model").pp("language_model"); - let embed_tokens = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; - - // Cap precomputed RoPE size to avoid memory waste from large max_position_embeddings - let rope_max_pos = cfg - .max_position_embeddings - .min(Self::MAX_PRECOMPUTED_POSITIONS); - let rotary_emb = Arc::new(RotaryEmbedding::new_precomputed( - cfg.rope_theta as f32, - cfg.head_dim, - rope_max_pos, - vb.device(), - vb_m.dtype(), - )?); - - let vb_l = vb_m.pp("layers"); - let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - for layer_idx in 0..cfg.num_hidden_layers { - layers.push(DecoderLayer::new( - rotary_emb.clone(), - cfg, - vb_l.pp(layer_idx), - )?); - } - let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = if cfg.tie_word_embeddings { - Linear::new(embed_tokens.embeddings().clone(), None) - } else { - linear_b(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))? 
- }; - - Ok(Self { - embed_tokens, - norm, - layers, - lm_head, - }) - } - - pub fn embed_tokens(&self, input_ids: &Tensor) -> Result { - self.embed_tokens.forward(input_ids) - } - - pub fn forward_embeds( - &self, - mut xs: Tensor, - attention_mask: Option<&Tensor>, - seqlen_offsets: &[usize], - ) -> Result { - let (_, seq_len, _) = xs.dims3()?; - - for layer in &self.layers { - let mask = attention_mask - .map(|m| m.to_device(xs.device())) - .transpose()?; - xs = layer.forward(&xs, mask.as_ref(), seqlen_offsets)?; - } - - xs = xs.apply(&self.norm)?; - - self.lm_head - .forward(&xs)? - .i((.., seq_len - 1, ..))? - .contiguous() - } - - pub fn clear_kv_cache(&self) { - for layer in &self.layers { - layer.clear_kv_cache(); - } - } -} diff --git a/oar-ocr-vl/src/lightonocr/vision.rs b/oar-ocr-vl/src/lightonocr/vision.rs deleted file mode 100644 index b22833e..0000000 --- a/oar-ocr-vl/src/lightonocr/vision.rs +++ /dev/null @@ -1,568 +0,0 @@ -//! Pixtral Vision Model implementation for LightOnOCR. -//! -//! This is a custom implementation that matches HuggingFace Transformers' -//! `modeling_pixtral.py` more closely than the Candle version, particularly -//! in the RoPE (Rotary Position Embedding) computation. - -use crate::attention::on_compute_device; -use candle_core::{D, DType, Device, Module, Result, Tensor}; -use candle_nn::{ - Conv2d, Conv2dConfig, Linear, RmsNorm, VarBuilder, conv2d_no_bias, linear_b, rms_norm, -}; -use tracing::debug; - -use super::config::LightOnOcrVisionConfig; - -/// Pixtral Vision Model configuration. 
-#[derive(Debug, Clone)] -pub struct PixtralVisionConfig { - pub hidden_size: usize, - pub num_channels: usize, - pub image_size: usize, - pub patch_size: usize, - pub rope_theta: f64, - pub intermediate_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub head_dim: usize, - pub hidden_act: candle_nn::Activation, -} - -impl From<&LightOnOcrVisionConfig> for PixtralVisionConfig { - fn from(cfg: &LightOnOcrVisionConfig) -> Self { - Self { - hidden_size: cfg.hidden_size, - num_channels: cfg.num_channels, - image_size: cfg.image_size, - patch_size: cfg.patch_size, - rope_theta: cfg.rope_theta, - intermediate_size: cfg.intermediate_size, - num_hidden_layers: cfg.num_hidden_layers, - num_attention_heads: cfg.num_attention_heads, - head_dim: cfg.head_dim, - hidden_act: cfg.hidden_act, - } - } -} - -/// Compute 2D position IDs in meshgrid format. -/// -/// This matches `position_ids_in_meshgrid` from HuggingFace Transformers. -fn position_ids_in_meshgrid( - num_patches_h: usize, - num_patches_w: usize, - max_width: usize, - device: &Device, -) -> Result { - on_compute_device(device, |compute_device| { - let h = Tensor::arange(0u32, num_patches_h as u32, compute_device)?; - let w = Tensor::arange(0u32, num_patches_w as u32, compute_device)?; - - // meshgrid with indexing="ij" - let h_grid = h - .unsqueeze(1)? - .broadcast_as((num_patches_h, num_patches_w))?; - let w_grid = w - .unsqueeze(0)? - .broadcast_as((num_patches_h, num_patches_w))?; - - // ids = h_grid * max_width + w_grid - let ids = (h_grid.to_dtype(DType::U32)? * (max_width as f64))? - .add(&w_grid.to_dtype(DType::U32)?)? - .flatten_all()?; - - Ok(ids) - }) -} - -/// Pixtral Rotary Embedding. -/// -/// The key difference from standard RoPE is that Pixtral uses 2D position encoding -/// where half the frequencies are for height and half for width. 
-#[derive(Debug, Clone)] -pub struct PixtralRotaryEmbedding { - /// Precomputed inverse frequencies, shape: (max_patches^2, head_dim) - inv_freq: Tensor, - max_patches_per_side: usize, -} - -impl PixtralRotaryEmbedding { - pub fn new(cfg: &PixtralVisionConfig, device: &Device) -> Result { - let dim = cfg.head_dim; - let base = cfg.rope_theta as f32; - let max_patches_per_side = cfg.image_size / cfg.patch_size; - - debug!( - "PixtralRotaryEmbedding: dim={}, base={}, max_patches={}", - dim, base, max_patches_per_side - ); - - // Compute base frequencies: 1 / (theta ^ (2i/dim)) for i in 0..dim/2 - let freqs: Vec = (0..dim) - .step_by(2) - .map(|i| 1f32 / base.powf(i as f32 / dim as f32)) - .collect(); - - debug!( - "PixtralRotaryEmbedding: freqs len={}, first 4: {:?}", - freqs.len(), - &freqs[..4.min(freqs.len())] - ); - - // Split frequencies: even indices for height, odd indices for width - let freqs_h: Vec = freqs.iter().step_by(2).copied().collect(); - let freqs_w: Vec = freqs.iter().skip(1).step_by(2).copied().collect(); - - debug!( - "PixtralRotaryEmbedding: freqs_h len={}, first 4: {:?}", - freqs_h.len(), - &freqs_h[..4.min(freqs_h.len())] - ); - debug!( - "PixtralRotaryEmbedding: freqs_w len={}, first 4: {:?}", - freqs_w.len(), - &freqs_w[..4.min(freqs_w.len())] - ); - - // Use on_compute_device to handle Metal's lack of support for arange and broadcast_as - let inv_freq = on_compute_device(device, |compute_device| { - let freqs_h = Tensor::new(freqs_h.clone(), compute_device)?; - let freqs_w = Tensor::new(freqs_w.clone(), compute_device)?; - - // Position indices - let h = Tensor::arange(0u32, max_patches_per_side as u32, compute_device)? - .to_dtype(DType::F32)?; - let w = Tensor::arange(0u32, max_patches_per_side as u32, compute_device)? 
- .to_dtype(DType::F32)?; - - // Compute outer products: (max_patches, dim/4) - let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?; - let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?; - - debug!( - "PixtralRotaryEmbedding: freqs_h outer shape {:?}, freqs_w outer shape {:?}", - freqs_h.dims(), - freqs_w.dims() - ); - - // Build the full inv_freq tensor: - // freqs_h: (max_patches, 1, dim/4) repeated along width - // freqs_w: (1, max_patches, dim/4) repeated along height - // concat along last dim -> (max_patches, max_patches, dim/2) - // reshape -> (max_patches^2, dim/2) - let freqs_h = freqs_h - .unsqueeze(1)? - .broadcast_as((max_patches_per_side, max_patches_per_side, freqs_h.dim(1)?))? - .contiguous()?; - let freqs_w = freqs_w - .unsqueeze(0)? - .broadcast_as((max_patches_per_side, max_patches_per_side, freqs_w.dim(1)?))? - .contiguous()?; - - debug!( - "PixtralRotaryEmbedding: broadcast freqs_h {:?}, freqs_w {:?}", - freqs_h.dims(), - freqs_w.dims() - ); - - let inv_freq = Tensor::cat(&[freqs_h, freqs_w], D::Minus1)? 
- .reshape((max_patches_per_side * max_patches_per_side, dim / 2))?; - - debug!( - "PixtralRotaryEmbedding: inv_freq after concat and reshape {:?}", - inv_freq.dims() - ); - - // CRITICAL: Duplicate inv_freq to full dimension (dim/2 -> dim) - // This matches: inv_freq = torch.cat((inv_freq, inv_freq), dim=-1) - Tensor::cat(&[&inv_freq, &inv_freq], D::Minus1) - })?; - - debug!( - "PixtralRotaryEmbedding: final inv_freq shape {:?}", - inv_freq.dims() - ); - - // Debug: print inv_freq values at key positions for verification - if let Ok(inv_freq_vec) = inv_freq.to_vec2::() { - // Position 0 (h=0, w=0) - should be all zeros - debug!( - "PixtralRotaryEmbedding: inv_freq[0, :8]: {:?}", - &inv_freq_vec[0][..8] - ); - // Position 1 (h=0, w=1) - should have zeros in first 16 (freqs_h[0]=0), and freqs_w[1] in next 16 - debug!( - "PixtralRotaryEmbedding: inv_freq[1, :8]: {:?}", - &inv_freq_vec[1][..8] - ); - // Position 110 (h=1, w=0) - should have freqs_h[1] in first 16, and zeros in next 16 - if inv_freq_vec.len() > 110 { - debug!( - "PixtralRotaryEmbedding: inv_freq[110, :8]: {:?}", - &inv_freq_vec[110][..8] - ); - debug!( - "PixtralRotaryEmbedding: inv_freq[110, 16:24]: {:?}", - &inv_freq_vec[110][16..24] - ); - } - } - - Ok(Self { - inv_freq, - max_patches_per_side, - }) - } - - /// Compute cos and sin embeddings for given position IDs. - pub fn forward(&self, position_ids: &Tensor, dtype: DType) -> Result<(Tensor, Tensor)> { - // Select frequencies for the given positions - let freqs = self.inv_freq.index_select(position_ids, 0)?; - - debug!( - "PixtralRotaryEmbedding forward: position_ids shape {:?}, freqs shape {:?}", - position_ids.dims(), - freqs.dims() - ); - - let cos = freqs.cos()?.to_dtype(dtype)?; - let sin = freqs.sin()?.to_dtype(dtype)?; - - Ok((cos, sin)) - } - - pub fn max_patches_per_side(&self) -> usize { - self.max_patches_per_side - } -} - -/// Rotate half of the hidden dims. 
-fn rotate_half(x: &Tensor) -> Result { - let last_dim = x.dim(D::Minus1)?; - let x1 = x.narrow(D::Minus1, 0, last_dim / 2)?; - let x2 = x.narrow(D::Minus1, last_dim / 2, last_dim / 2)?; - Tensor::cat(&[&x2.neg()?, &x1], D::Minus1) -} - -/// Apply rotary position embedding to query and key tensors. -/// -/// Args: -/// q: Query tensor of shape (batch, heads, seq_len, head_dim) -/// k: Key tensor of shape (batch, heads, seq_len, head_dim) -/// cos: Cosine embedding of shape (seq_len, head_dim) -/// sin: Sine embedding of shape (seq_len, head_dim) -fn apply_rotary_pos_emb( - q: &Tensor, - k: &Tensor, - cos: &Tensor, - sin: &Tensor, -) -> Result<(Tensor, Tensor)> { - // unsqueeze_dim=0 for Pixtral: (seq_len, head_dim) -> (1, seq_len, head_dim) - // This broadcasts with (batch, heads, seq_len, head_dim) - let cos = cos.unsqueeze(0)?; - let sin = sin.unsqueeze(0)?; - - let q_embed = q - .broadcast_mul(&cos)? - .add(&rotate_half(q)?.broadcast_mul(&sin)?)?; - let k_embed = k - .broadcast_mul(&cos)? - .add(&rotate_half(k)?.broadcast_mul(&sin)?)?; - - Ok((q_embed, k_embed)) -} - -/// Pixtral MLP layer. -#[derive(Debug, Clone)] -struct PixtralMlp { - gate_proj: Linear, - up_proj: Linear, - down_proj: Linear, - act_fn: candle_nn::Activation, -} - -impl PixtralMlp { - fn new(cfg: &PixtralVisionConfig, vb: VarBuilder) -> Result { - let h = cfg.hidden_size; - let i = cfg.intermediate_size; - let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?; - let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?; - let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?; - Ok(Self { - gate_proj, - up_proj, - down_proj, - act_fn: cfg.hidden_act, - }) - } -} - -impl Module for PixtralMlp { - fn forward(&self, xs: &Tensor) -> Result { - let gate = self.gate_proj.forward(xs)?.apply(&self.act_fn)?; - let up = self.up_proj.forward(xs)?; - self.down_proj.forward(&(gate * up)?) - } -} - -/// Pixtral Attention layer. 
-#[derive(Debug, Clone)] -struct PixtralAttention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - o_proj: Linear, - num_heads: usize, - head_dim: usize, - scale: f64, -} - -impl PixtralAttention { - fn new(cfg: &PixtralVisionConfig, vb: VarBuilder) -> Result { - let h = cfg.hidden_size; - let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?; - let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?; - let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?; - let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?; - let scale = (cfg.head_dim as f64).powf(-0.5); - Ok(Self { - q_proj, - k_proj, - v_proj, - o_proj, - num_heads: cfg.num_attention_heads, - head_dim: cfg.head_dim, - scale, - }) - } - - fn forward( - &self, - xs: &Tensor, - cos: &Tensor, - sin: &Tensor, - attention_mask: Option<&Tensor>, - ) -> Result { - let (batch, patches, _) = xs.dims3()?; - - let q = self.q_proj.forward(xs)?; - let k = self.k_proj.forward(xs)?; - let v = self.v_proj.forward(xs)?; - - // Reshape to (batch, heads, patches, head_dim) - let shape = (batch, patches, self.num_heads, self.head_dim); - let q = q.reshape(shape)?.transpose(1, 2)?.contiguous()?; - let k = k.reshape(shape)?.transpose(1, 2)?.contiguous()?; - let v = v.reshape(shape)?.transpose(1, 2)?.contiguous()?; - - // Apply rotary embeddings - let (q, k) = apply_rotary_pos_emb(&q, &k, cos, sin)?; - - // Scaled dot-product attention - let attn_weights = (q.matmul(&k.transpose(D::Minus2, D::Minus1)?)? 
* self.scale)?; - - let attn_weights = match attention_mask { - Some(mask) => attn_weights.broadcast_add(mask)?, - None => attn_weights, - }; - - // Compute softmax in float32 for numerical stability, then cast back - let input_dtype = attn_weights.dtype(); - let attn_weights = attn_weights.to_dtype(DType::F32)?; - let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - let attn_weights = attn_weights.to_dtype(input_dtype)?; - let attn_output = attn_weights.matmul(&v)?; - - // Reshape back to (batch, patches, hidden) - let attn_output = attn_output.transpose(1, 2)?.reshape((batch, patches, ()))?; - - self.o_proj.forward(&attn_output) - } -} - -/// Pixtral Attention Layer (attention + MLP with residual connections). -#[derive(Debug, Clone)] -struct PixtralAttentionLayer { - attention_norm: RmsNorm, - attention: PixtralAttention, - ffn_norm: RmsNorm, - feed_forward: PixtralMlp, -} - -impl PixtralAttentionLayer { - fn new(cfg: &PixtralVisionConfig, vb: VarBuilder) -> Result { - let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?; - let attention = PixtralAttention::new(cfg, vb.pp("attention"))?; - let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?; - let feed_forward = PixtralMlp::new(cfg, vb.pp("feed_forward"))?; - Ok(Self { - attention_norm, - attention, - ffn_norm, - feed_forward, - }) - } - - fn forward( - &self, - xs: &Tensor, - cos: &Tensor, - sin: &Tensor, - attention_mask: Option<&Tensor>, - ) -> Result { - // Pre-norm attention with residual - let residual = xs; - let xs = self.attention_norm.forward(xs)?; - let xs = self.attention.forward(&xs, cos, sin, attention_mask)?; - let xs = (residual + xs)?; - - // Pre-norm FFN with residual - let residual = &xs; - let xs = self.ffn_norm.forward(&xs)?; - let xs = self.feed_forward.forward(&xs)?; - residual + xs - } -} - -/// Pixtral Transformer. 
-#[derive(Debug, Clone)] -struct PixtralTransformer { - layers: Vec, -} - -impl PixtralTransformer { - fn new(cfg: &PixtralVisionConfig, vb: VarBuilder) -> Result { - let vb_layers = vb.pp("layers"); - let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - for i in 0..cfg.num_hidden_layers { - layers.push(PixtralAttentionLayer::new(cfg, vb_layers.pp(i))?); - } - Ok(Self { layers }) - } - - fn forward( - &self, - xs: &Tensor, - cos: &Tensor, - sin: &Tensor, - attention_mask: Option<&Tensor>, - ) -> Result { - let mut hidden_states = xs.clone(); - for layer in &self.layers { - hidden_states = layer.forward(&hidden_states, cos, sin, attention_mask)?; - } - Ok(hidden_states) - } -} - -/// Generate block attention mask for multiple images. -/// -/// Each image's patches can only attend to patches from the same image. -fn generate_block_attention_mask( - patch_counts: &[usize], - dtype: DType, - device: &Device, -) -> Result { - let seq_len: usize = patch_counts.iter().sum(); - // Use dtype.min for numerical stability (matches HuggingFace) - // BF16/F16 min is ~-3.39e+38, which is more stable than NEG_INFINITY - let d_min: f32 = match dtype { - DType::F32 => f32::MIN, - DType::F16 => -65504.0, // half::f16::MIN - DType::BF16 => -3.3895313e+38, // half::bf16::MIN - _ => f32::MIN, - }; - - // Start with all d_min (no attention) - let mut mask_data = vec![d_min; seq_len * seq_len]; - - // For each block, allow full attention within the block - let mut offset = 0usize; - for &count in patch_counts { - for i in 0..count { - for j in 0..count { - mask_data[(offset + i) * seq_len + (offset + j)] = 0.0; - } - } - offset += count; - } - - let mask = Tensor::from_vec(mask_data, (seq_len, seq_len), device)?; - // Expand to (1, 1, seq_len, seq_len) for broadcasting - mask.unsqueeze(0)?.unsqueeze(0)?.to_dtype(dtype) -} - -/// Pixtral Vision Model. 
-#[derive(Debug, Clone)] -pub struct PixtralVisionModel { - patch_conv: Conv2d, - ln_pre: RmsNorm, - transformer: PixtralTransformer, - rotary_emb: PixtralRotaryEmbedding, -} - -impl PixtralVisionModel { - pub fn new(cfg: &PixtralVisionConfig, vb: VarBuilder) -> Result { - let conv_cfg = Conv2dConfig { - stride: cfg.patch_size, - ..Default::default() - }; - let patch_conv = conv2d_no_bias( - cfg.num_channels, - cfg.hidden_size, - cfg.patch_size, - conv_cfg, - vb.pp("patch_conv"), - )?; - - let ln_pre = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?; - let transformer = PixtralTransformer::new(cfg, vb.pp("transformer"))?; - let rotary_emb = PixtralRotaryEmbedding::new(cfg, vb.device())?; - - Ok(Self { - patch_conv, - ln_pre, - transformer, - rotary_emb, - }) - } - - /// Forward pass for a single image. - /// - /// Args: - /// pixel_values: Image tensor of shape (1, C, H, W) - /// - /// Returns: - /// Hidden states of shape (1, num_patches, hidden_size) - pub fn forward(&self, pixel_values: &Tensor) -> Result { - let dtype = pixel_values.dtype(); - - // Apply patch convolution: (1, C, H, W) -> (1, hidden, grid_h, grid_w) - let patch_embeds = self.patch_conv.forward(pixel_values)?; - let (_, _, grid_h, grid_w) = patch_embeds.dims4()?; - - // Flatten and transpose: (1, hidden, grid_h * grid_w) -> (1, grid_h * grid_w, hidden) - let patch_embeds = patch_embeds.flatten_from(2)?.transpose(1, 2)?; - - // Apply pre-norm - let patch_embeds = self.ln_pre.forward(&patch_embeds)?; - - // Compute position IDs and rotary embeddings - let position_ids = position_ids_in_meshgrid( - grid_h, - grid_w, - self.rotary_emb.max_patches_per_side(), - patch_embeds.device(), - )?; - let (cos, sin) = self.rotary_emb.forward(&position_ids, dtype)?; - - // For single image, use block attention mask (though it's effectively full attention) - let patch_count = grid_h * grid_w; - let attention_mask = - generate_block_attention_mask(&[patch_count], dtype, patch_embeds.device())?; - - // Run 
transformer - self.transformer - .forward(&patch_embeds, &cos, &sin, Some(&attention_mask)) - } -} diff --git a/oar-ocr-vl/src/mineru/model.rs b/oar-ocr-vl/src/mineru/model.rs index f77e9eb..7e95510 100644 --- a/oar-ocr-vl/src/mineru/model.rs +++ b/oar-ocr-vl/src/mineru/model.rs @@ -2,22 +2,105 @@ use super::config::{MinerUConfig, MinerUImageProcessorConfig}; use super::processing::preprocess_images; use super::text::MinerUTextModel; use super::vision::MinerUVisionModel; +#[cfg(feature = "hsd")] +use crate::attention::create_tree_attention_mask; use crate::attention::{ combine_masks, create_causal_mask, create_left_padding_mask, on_compute_device, }; +#[cfg(feature = "hsd")] +use crate::hsd::backend_util::{commit_keep_indices, step_pos_ids, tree_pos_ids}; +#[cfg(feature = "hsd")] +use crate::hsd::drafting::{ + TargetDraftAdapter, bbox_xyxy, crop_region_image, format_verified_region, map_layout_kind, + region_markdown_for, region_markdowns_for, structure_result_to_layout_elements, +}; +#[cfg(feature = "hsd")] +use crate::hsd::prefix_tree::PrefixTree; +#[cfg(feature = "hsd")] +use crate::hsd::types::{AcceptStats, Draft, HsdConfig, HsdStats, RegionStageStats}; +#[cfg(feature = "hsd")] +use crate::hsd::verify::{SpecBackend, spec_decode}; use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing}; +#[cfg(feature = "hsd")] +use candle_core::{D, Result as CandleResult}; use candle_core::{DType, Device, IndexOp, Tensor}; +#[cfg(feature = "hsd")] +use candle_nn::ops as cnn_ops; use candle_nn::{Linear, Module, VarBuilder, linear_no_bias}; use image::RgbImage; use oar_ocr_core::core::OCRError; +use oar_ocr_core::domain::structure::LayoutElementType; +#[cfg(feature = "hsd")] +use oar_ocr_core::domain::structure::{LayoutElement, StructureResult}; use rand::distr::weighted::WeightedIndex; use rand::prelude::*; use serde::Deserialize; use std::cmp::Ordering; use std::collections::HashSet; use std::path::Path; +#[cfg(feature = "hsd")] +use std::time::{Duration, 
Instant}; use tokenizers::Tokenizer; +/// Canonical MinerU2.5 per-element prompts as defined by the official +/// `mineru_vl_utils` package (`DEFAULT_PROMPTS` in `mineru_client.py`). +/// +/// MinerU's `two_step_extract` flow first runs a layout pass, then routes each +/// cropped region to a per-type recognizer with the matching prompt. Outside +/// of `two_step_extract`, callers can still mix and match: a single +/// `Text Recognition:` prompt fed an entire page yields a generic markdown +/// output (the non-standard usage we previously defaulted to in +/// `hsd_omnidocbench`). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(not(feature = "hsd"), allow(dead_code))] +pub enum MinerUTaskPrompt { + /// `\nText Recognition:` — default for body text, titles, paragraphs, + /// lists, captions, references, footnotes, page numbers, etc. + Text, + /// `\nFormula Recognition:` — display formulas (`Formula`, + /// `FormulaNumber`). + Formula, + /// `\nTable Recognition:` — tables. + Table, + /// `\nImage Analysis:` — figure / image / chart blocks. + ImageAnalysis, + /// `\nLayout Detection:` — full-page layout dump (only used by + /// `two_step_extract` Stage 0, not HSD verify). Kept for completeness + /// with the official `mineru_vl_utils` prompt set so callers can drive + /// the layout pass externally if they choose. + #[allow(dead_code)] + LayoutDetection, +} + +#[cfg_attr(not(feature = "hsd"), allow(dead_code))] +impl MinerUTaskPrompt { + /// Canonical prompt string (with the leading `\n` that MinerU's + /// `two_step_extract` builds via its chat-template wrapper). + pub fn prompt(self) -> &'static str { + match self { + Self::Text => "\nText Recognition:", + Self::Formula => "\nFormula Recognition:", + Self::Table => "\nTable Recognition:", + Self::ImageAnalysis => "\nImage Analysis:", + Self::LayoutDetection => "\nLayout Detection:", + } + } + + /// Map an OAR `LayoutElementType` to the MinerU element prompt that best + /// matches its content kind. 
Mirrors the heuristic the official `mineru_vl_utils` + /// client uses when picking a per-block prompt (text-like → `[default]`, + /// table → `table`, equation → `equation`, image/chart → `image`). + pub fn for_layout(t: LayoutElementType) -> Self { + use LayoutElementType::*; + match t { + Table => Self::Table, + Formula | FormulaNumber => Self::Formula, + Image | Chart | Seal | HeaderImage | FooterImage => Self::ImageAnalysis, + _ => Self::Text, + } + } +} + pub struct MinerU { device: Device, dtype: DType, @@ -241,7 +324,47 @@ impl MinerU { })]; } - match self.generate_internal(images, instructions, max_new_tokens) { + match self.generate_tokens_internal(images, instructions, max_new_tokens) { + Ok(results) => results + .into_iter() + .map(|tokens| self.decode_generated_tokens(&tokens)) + .collect(), + Err(e) => { + let msg = format!("generation failed: {e}"); + (0..images.len()) + .map(|_| { + Err(OCRError::InvalidInput { + message: msg.clone(), + }) + }) + .collect() + } + } + } + + /// Generate raw baseline tokens for oracle-draft / tokenizer round-trip + /// experiments. Tokens are exactly the ids emitted by the decode loop, + /// excluding stop tokens, before skip-token filtering and tokenizer decode. 
+ pub fn generate_tokens( + &self, + images: &[RgbImage], + instructions: &[impl AsRef], + max_new_tokens: usize, + ) -> Vec, OCRError>> { + if images.is_empty() { + return Vec::new(); + } + if images.len() != instructions.len() { + return vec![Err(OCRError::InvalidInput { + message: format!( + "MinerU2.5: images count ({}) != instructions count ({})", + images.len(), + instructions.len() + ), + })]; + } + + match self.generate_tokens_internal(images, instructions, max_new_tokens) { Ok(results) => results.into_iter().map(Ok).collect(), Err(e) => { let msg = format!("generation failed: {e}"); @@ -256,12 +379,12 @@ impl MinerU { } } - fn generate_internal( + fn generate_tokens_internal( &self, images: &[RgbImage], instructions: &[impl AsRef], max_new_tokens: usize, - ) -> Result, OCRError> { + ) -> Result>, OCRError> { let batch_size = images.len(); let image_inputs = preprocess_images(images, &self.image_cfg, &self.device, self.dtype)?; @@ -307,7 +430,11 @@ impl MinerU { } let seq_lens: Vec = all_input_ids.iter().map(|ids| ids.len()).collect(); - let max_seq_len = *seq_lens.iter().max().unwrap(); + let Some(&max_seq_len) = seq_lens.iter().max() else { + return Err(OCRError::InvalidInput { + message: "MinerU2.5: empty batch is not supported".to_string(), + }); + }; let mut batch_embeds: Vec = Vec::with_capacity(batch_size); let mut rope_deltas: Vec = Vec::with_capacity(batch_size); @@ -458,14 +585,7 @@ impl MinerU { break; } - let sampling_params = SamplingParams { - repetition_penalty: self.repetition_penalty, - no_repeat_ngram_size: self.no_repeat_ngram_size, - do_sample: self.do_sample, - temperature: self.temperature, - top_p: self.top_p, - top_k: self.top_k, - }; + let sampling_params = self.sampling_params(); let mut next_tokens: Vec = Vec::with_capacity(batch_size); for (i, logits) in logits_list.iter().enumerate() { if finished[i] { @@ -528,23 +648,590 @@ impl MinerU { } } - let mut results = Vec::with_capacity(batch_size); - for tokens in 
generated.into_iter() { - // Filter out bos/eos/pad tokens before decoding (matching official implementation) - let filtered: Vec = tokens - .into_iter() - .filter(|t| !self.skip_token_ids.contains(t)) - .collect(); - let decoded = self - .tokenizer - .decode(&filtered, false) // skip_special_tokens=false to preserve special tokens - .map_err(|e| OCRError::InvalidInput { - message: format!("decode failed: {e}"), - })?; - results.push(decoded); + Ok(generated) + } + + pub fn decode_tokens(&self, tokens: &[u32]) -> Result { + self.decode_generated_tokens(tokens) + } + + /// Decode tokens in the form the model actually emitted. MinerU2.5's + /// `decode_tokens` only filters bos/eos/pad before `tokenizer.decode` — + /// there is no markdown / wrapping / layout post-process at this layer + /// (layout-aware reordering happens in `two_step_extract`, not here). + /// This alias exists for API symmetry with PaddleOCR-VL / GLM-OCR. + pub fn decode_tokens_raw(&self, tokens: &[u32]) -> Result { + self.decode_generated_tokens(tokens) + } + + pub fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } + + fn sampling_params(&self) -> SamplingParams { + SamplingParams { + repetition_penalty: self.repetition_penalty, + no_repeat_ngram_size: self.no_repeat_ngram_size, + do_sample: self.do_sample, + temperature: self.temperature, + top_p: self.top_p, + top_k: self.top_k, + } + } + + fn decode_generated_tokens(&self, tokens: &[u32]) -> Result { + // Filter out bos/eos/pad tokens before decoding (matching official implementation). + let filtered: Vec = tokens + .iter() + .copied() + .filter(|t| !self.skip_token_ids.contains(t)) + .collect(); + self.tokenizer + .decode(&filtered, false) // skip_special_tokens=false to preserve special tokens + .map_err(|e| OCRError::InvalidInput { + message: format!("decode failed: {e}"), + }) + } + + /// Hierarchical Speculative Decoding entry for a single image / region. 
+ /// + /// MinerU2.5 is naturally per-region (decoupled VLM in the paper's + /// taxonomy), so the only HSD stage that applies is region-level. The + /// verifier applies the same repetition / n-gram / sampling logits + /// processors as the baseline generator before DSV acceptance decisions. + #[cfg(feature = "hsd")] + pub fn generate_hsd( + &self, + image: &RgbImage, + instruction: &str, + drafts: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let t_drafter = Instant::now(); + let tokenized = self.tokenize_drafts(drafts)?; + self.generate_hsd_tokenized( + image, + instruction, + &tokenized, + hsd_cfg, + hsd_cfg.max_region_tokens, + t_drafter.elapsed(), + ) + } + + /// HSD entry that consumes already-tokenized drafts. This is the oracle + /// path used by benchmarks to avoid `decode -> encode` tokenizer + /// round-trips when the draft comes from this backend's own baseline. + #[cfg(feature = "hsd")] + pub fn generate_hsd_with_token_drafts( + &self, + image: &RgbImage, + instruction: &str, + drafts: &[Draft], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + self.generate_hsd_tokenized( + image, + instruction, + drafts, + hsd_cfg, + hsd_cfg.max_region_tokens, + Duration::ZERO, + ) + } + + #[cfg(feature = "hsd")] + fn tokenize_drafts(&self, drafts: &[String]) -> Result, OCRError> { + let mut tokenized: Vec = Vec::with_capacity(drafts.len()); + for d in drafts { + if d.trim().is_empty() { + continue; + } + let enc = + self.tokenizer + .encode(d.as_str(), false) + .map_err(|e| OCRError::InvalidInput { + message: format!("MinerU2.5 HSD: tokenizer encode failed: {e}"), + })?; + let tokens = enc.get_ids().to_vec(); + if !tokens.is_empty() { + tokenized.push(Draft::new(tokens)); + } + } + Ok(tokenized) + } + + #[cfg(feature = "hsd")] + fn generate_hsd_tokenized( + &self, + image: &RgbImage, + instruction: &str, + tokenized: &[Draft], + hsd_cfg: &HsdConfig, + max_new_tokens: usize, + drafter_elapsed: Duration, + ) 
-> Result<(String, HsdStats), OCRError> { + if !self.device.is_cuda() { + return Err(OCRError::ConfigError { + message: "HSD requires CUDA device".to_string(), + }); + } + + let mut stats = HsdStats { + drafter: drafter_elapsed, + ..Default::default() + }; + // Stage 2 fields are reused for stat bookkeeping in the single-image path. + let t_pre = Instant::now(); + let (initial_lp, rope_delta, prompt_tokens) = + self.hsd_prefill_single(image, instruction)?; + stats.stage2.vision_prefill = t_pre.elapsed(); + stats.stage2.forward_passes = 1; + + let t_dec = Instant::now(); + let mut backend = MinerUSpecBackend::new(self, rope_delta, prompt_tokens); + let mut accept = AcceptStats::default(); + let mut dsv = Default::default(); + let generated = spec_decode( + &mut backend, + tokenized, + initial_lp, + max_new_tokens, + &hsd_cfg.dsv, + &mut accept, + &mut dsv, + ) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "spec_decode", e))?; + stats.stage2.decode = t_dec.elapsed(); + stats.stage2.emitted_tokens = generated.len() as u32; + stats.stage2.accept = accept; + stats.stage2.dsv = dsv; + stats.stage2.forward_passes += backend.forward_passes; + + // Strip the first stop token and anything after it before decoding. + let stop_pos = generated + .iter() + .position(|t| self.eos_token_ids.contains(t)) + .unwrap_or(generated.len()); + let trimmed = &generated[..stop_pos]; + + // Match `decode_generated_tokens`: filter bos/eos/pad before decode, preserve + // other special tokens (skip_special_tokens=false).
+ let filtered: Vec = trimmed + .iter() + .copied() + .filter(|t| !self.skip_token_ids.contains(t)) + .collect(); + let text = self + .tokenizer + .decode(&filtered, false) + .map_err(|e| OCRError::InvalidInput { + message: format!("MinerU2.5 HSD: tokenizer decode failed: {e}"), + })?; + Ok((text, stats)) + } + + /// Run the full two-stage HSD: Stage 1 verifies each layout-detected + /// region against the layout drafter's text, then Stage 2 (gated by + /// `hsd_cfg.enable_stage2`) verifies the Stage-1-aggregated markdown on + /// the full image with `hsd_cfg.max_page_tokens` budget. + /// + /// - `enable_stage1 = false`: skip per-region verification; build the + /// Stage 2 draft set directly from the layout drafter's per-element + /// markdowns (`region_markdowns`). Mirrors the paper's Table 8 + /// "Page-level Spec. Decoding only" ablation. + /// - `enable_stage2 = false`: return the Stage-1-only aggregation (lossy + /// ablation matching paper Table 8). + /// + /// `region_instruction` is used only for Stage 1 crop verification; + /// `page_instruction` is used for Stage 2 full-page verification. + /// + /// **Two-step mode**: when `region_instruction` is empty, Stage 1 + /// dispatches a per-element prompt via [`MinerUTaskPrompt::for_layout`] + /// (e.g. `\nText Recognition:`, `\nTable Recognition:`, + /// `\nFormula Recognition:`). This mirrors MinerU's official + /// `two_step_extract` flow where each layout-detected block is routed to + /// its matching recognizer. Passing a non-empty `region_instruction` keeps + /// the legacy "one prompt for all regions" behaviour for ablation. 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_full( + &self, + image: &RgbImage, + elements: &[LayoutElement], + ignore_labels: &[String], + page_instruction: &str, + region_instruction: &str, + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let mut stats = HsdStats::default(); + let mut region_md: Vec<(usize, String)> = Vec::with_capacity(elements.len()); + let two_step_mode = region_instruction.trim().is_empty(); + + if hsd_cfg.enable_stage1 { + for (idx, elem) in elements.iter().enumerate() { + if let Some(label) = &elem.label + && ignore_labels.iter().any(|l| l == label) + { + continue; + } + // Visual-only regions have no text to verify. + if matches!( + elem.element_type, + LayoutElementType::Image + | LayoutElementType::HeaderImage + | LayoutElementType::FooterImage + | LayoutElementType::Seal + ) { + continue; + } + let draft = region_markdown_for(elem, TargetDraftAdapter::MinerU); + if draft.trim().is_empty() { + continue; + } + + let bbox = bbox_xyxy(&elem.bbox); + let crop = crop_region_image(image, &bbox)?; + let drafts = vec![draft]; + // Two-step dispatch: pick the official MinerU per-element + // prompt based on `LayoutElementType`. The legacy fixed-prompt + // path is taken only when the caller explicitly passes a + // non-empty `region_instruction`. 
+ let effective_region_instruction = if two_step_mode { + MinerUTaskPrompt::for_layout(elem.element_type).prompt() + } else { + region_instruction + }; + let (region_text, region_stats) = + self.generate_hsd(&crop, effective_region_instruction, &drafts, hsd_cfg)?; + stats.drafter += region_stats.drafter; + + let kind = map_layout_kind(elem.element_type); + stats.stage1_regions.push(RegionStageStats { + kind, + stats: region_stats.stage2.clone(), + }); + stats.stage1.add_assign(region_stats.stage2); + let order = elem.order_index.map(|x| x as usize).unwrap_or(idx); + region_md.push((order, format_verified_region(&region_text, kind))); + } + } + + region_md.sort_by_key(|(order, _)| *order); + let region_md: Vec<String> = region_md + .into_iter() + .map(|(_, text)| text) + .filter(|s| !s.trim().is_empty()) + .collect(); + + // Stage 2 — page-level global verification on the full image. Per + // paper Eq. 3 the page draft is the *unordered set* `Ỹ^pg = {ŷ^(i)}`, + // one draft per region. We pass the Vec straight to `spec_decode` + // instead of pre-joining: `collect_candidates` scans each draft + // independently (Eqs. 1+2), so per-region n-gram locality is + // preserved even when full-page transitions don't appear naturally + // in the target VLM's output. Budget = `max_page_tokens`. + if hsd_cfg.enable_stage2 { + let t_drafter = Instant::now(); + let page_drafts: Vec<String> = if !region_md.is_empty() { + region_md.clone() + } else { + region_markdowns_for(elements, ignore_labels, TargetDraftAdapter::MinerU) + }; + if !page_drafts.is_empty() { + let tokenized = self.tokenize_drafts(&page_drafts)?; + let (text, s2_stats) = self.generate_hsd_tokenized( + image, + page_instruction, + &tokenized, + hsd_cfg, + hsd_cfg.max_page_tokens, + t_drafter.elapsed(), + )?; + stats.stage2 = s2_stats.stage2; + stats.drafter += s2_stats.drafter; + return Ok((text, stats)); + } + } + + // Stage 2 disabled or no draft to verify — return Stage-1-only join + // as a human-readable fallback.
The `\n\n` separator here is for the + // *output* (caller-facing), not for any further HSD input. + Ok((region_md.join("\n\n"), stats)) + } + + /// One-call HSD entry that consumes a `StructureResult` (the output of + /// the OARStructure / PP-StructureV3 pipeline) directly. + /// + /// Backfills table HTML / formula LaTeX via + /// [`structure_result_to_layout_elements`] then delegates to + /// [`Self::generate_hsd_full`]. When `region_instruction` is empty the + /// MinerU two-step mode kicks in and each region uses its canonical + /// per-type prompt (`MinerUTaskPrompt::for_layout`). + #[cfg(feature = "hsd")] + pub fn generate_hsd_with_structure( + &self, + image: &RgbImage, + page_instruction: &str, + region_instruction: &str, + structure: &StructureResult, + ignore_labels: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let elements = structure_result_to_layout_elements(structure); + self.generate_hsd_full( + image, + &elements, + ignore_labels, + page_instruction, + region_instruction, + hsd_cfg, + ) + } + + /// Run a single-image prefill with the supplied instruction. Returns + /// the F32 last-position log-probabilities, the MRoPE delta, and the + /// image-expanded prompt token ids (used to seed the verifier history). + #[cfg(feature = "hsd")] + fn hsd_prefill_single( + &self, + image: &RgbImage, + instruction: &str, + ) -> Result<(Tensor, i64, Vec<u32>), OCRError> { + // Preprocess single image. + let image_inputs = preprocess_images( + std::slice::from_ref(image), + &self.image_cfg, + &self.device, + self.dtype, + )?; + let (t, h, w) = image_inputs.image_grid_thw[0]; + let image_token_count = (t * h * w) / (self.spatial_merge_size * self.spatial_merge_size); + + // Build prompt and expand image placeholders.
+ let prompt = build_prompt(instruction); + let enc = self + .tokenizer + .encode(prompt, false) + .map_err(|e| OCRError::InvalidInput { + message: format!("MinerU2.5 HSD: tokenizer encode failed: {e}"), + })?; + let input_ids = + expand_image_tokens(enc.get_ids(), self.image_token_id, &[image_token_count])?; + let seq_len = input_ids.len(); + + // Vision features. + let image_embeds = self + .vision + .forward(&image_inputs.pixel_values, &image_inputs.image_grid_thw)?; + let actual = image_embeds + .dim(0) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "image_embeds dim", e))?; + if actual != image_token_count { + return Err(OCRError::InvalidInput { + message: format!( + "MinerU2.5 HSD: image embeds count mismatch: got {actual}, expected {image_token_count}" + ), + }); + } + + // Build embeddings, splice in the image tokens. + let input_ids_t = Tensor::new(input_ids.clone(), &self.device) + .and_then(|t| t.reshape((1, seq_len))) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "create input_ids", e))?; + let mut inputs_embeds = self.text.embed(&input_ids_t)?; + + if let Some(first_pos) = input_ids.iter().position(|&id| id == self.image_token_id) { + let image_end = first_pos + image_token_count; + let mut parts: Vec = Vec::with_capacity(3); + if first_pos > 0 { + parts.push( + inputs_embeds + .narrow(1, 0, first_pos) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "narrow prefix", e))?, + ); + } + parts.push( + image_embeds + .unsqueeze(0) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "unsqueeze img", e))?, + ); + if image_end < seq_len { + parts.push( + inputs_embeds + .narrow(1, image_end, seq_len - image_end) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "narrow suffix", e))?, + ); + } + let refs: Vec<&Tensor> = parts.iter().collect(); + inputs_embeds = Tensor::cat(&refs, 1) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "cat embeds", e))?; + } + + // 3-axis MRoPE position ids + delta. 
+ let (pos_ids, rope_delta) = get_rope_index( + &self.cfg, + &input_ids, + &[image_inputs.image_grid_thw[0]], + self.vision_start_token_id, + self.video_token_id, + self.spatial_merge_size, + &self.device, + )?; + + let causal = create_causal_mask(seq_len, seq_len, self.dtype, &self.device) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "create causal", e))?; + + self.text.clear_kv_cache(); + let hidden = self.text.forward(&inputs_embeds, &pos_ids, Some(&causal))?; + + let last = hidden + .i((0, seq_len - 1, ..)) + .and_then(|t| t.unsqueeze(0)) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "get last hidden", e))?; + let logits = self + .lm_head + .forward(&last) + .and_then(|t| t.squeeze(0)) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "lm_head prefill", e))?; + let lp = processed_logprobs_from_logits( + &logits, + &input_ids, + &self.sampling_params(), + &self.device, + ) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "logits processors prefill", e))?; + Ok((lp, rope_delta, input_ids)) + } +} + +/// HSD adapter for MinerU2.5. Same shape as PaddleOCR-VL (3-axis MRoPE, +/// independent lm_head, rope_delta captured at prefill). 
+#[cfg(feature = "hsd")] +struct MinerUSpecBackend<'a> { + model: &'a MinerU, + rope_delta: i64, + history: Vec<u32>, + pending_tree: Option<PrefixTree>, + pre_verify_kv: usize, + forward_passes: u32, +} + +#[cfg(feature = "hsd")] +impl<'a> MinerUSpecBackend<'a> { + fn new(model: &'a MinerU, rope_delta: i64, prompt_tokens: Vec<u32>) -> Self { + Self { + model, + rope_delta, + history: prompt_tokens, + pending_tree: None, + pre_verify_kv: 0, + forward_passes: 0, + } + } + + fn project_logprobs_2d(&self, hidden_2d: &Tensor, tree: &PrefixTree) -> CandleResult<Tensor> { + let logits = self.model.lm_head.forward(hidden_2d)?; + let params = self.model.sampling_params(); + let mut rows: Vec<Tensor> = Vec::with_capacity(tree.num_nodes()); + for node_idx in 0..tree.num_nodes() { + let mut node_history = self.history.clone(); + node_history.extend(tree.path_tokens(node_idx)); + let row = logits.i(node_idx)?; + rows.push(processed_logprobs_from_logits( + &row, + &node_history, + &params, + &self.model.device, + )?); + } + let refs: Vec<&Tensor> = rows.iter().collect(); + Tensor::stack(&refs, 0) + } + + fn project_logprobs_1d(&self, hidden_1d: &Tensor) -> CandleResult<Tensor> { + let logits = self + .model + .lm_head + .forward(&hidden_1d.unsqueeze(0)?)?
+ .squeeze(0)?; + processed_logprobs_from_logits( + &logits, + &self.history, + &self.model.sampling_params(), + &self.model.device, + ) + } +} + +#[cfg(feature = "hsd")] +impl<'a> SpecBackend for MinerUSpecBackend<'a> { + fn step_one(&mut self, token: u32) -> CandleResult { + let model = self.model; + let device = &model.device; + self.history.push(token); + + let tok_t = Tensor::new(vec![token], device)?.reshape((1usize, 1usize))?; + let embeds = model + .text + .embed(&tok_t) + .map_err(|e| candle_core::Error::Msg(format!("MinerU2.5 HSD step_one embed: {e}")))?; + + let pos_ids = step_pos_ids(3, model.text.current_kv_len(), self.rope_delta, device)?; + + let hidden = model + .text + .forward(&embeds, &pos_ids, None) + .map_err(|e| candle_core::Error::Msg(format!("MinerU2.5 HSD step_one forward: {e}")))?; + self.forward_passes += 1; + let last = hidden.i((0, 0, ..))?; + self.project_logprobs_1d(&last) + } + + fn verify_tree(&mut self, tree: &PrefixTree) -> CandleResult { + let n = tree.num_nodes(); + let model = self.model; + let device = &model.device; + let dtype = model.dtype; + + let prefix_kv = model.text.current_kv_len(); + self.pre_verify_kv = prefix_kv; + + let tok_t = Tensor::new(tree.tokens.clone(), device)?.reshape((1usize, n))?; + let embeds = model.text.embed(&tok_t).map_err(|e| { + candle_core::Error::Msg(format!("MinerU2.5 HSD verify_tree embed: {e}")) + })?; + + let pos_ids = tree_pos_ids(3, prefix_kv, self.rope_delta, tree, device)?; + let mask = create_tree_attention_mask(&tree.parents, prefix_kv, dtype, device)?; + + let hidden = model + .text + .forward(&embeds, &pos_ids, Some(&mask)) + .map_err(|e| { + candle_core::Error::Msg(format!("MinerU2.5 HSD verify_tree forward: {e}")) + })?; + self.forward_passes += 1; + let h2 = hidden.squeeze(0)?; + self.pending_tree = Some(tree.clone()); + self.project_logprobs_2d(&h2, tree) + } + + fn commit_verify(&mut self, accepted_path: &[usize]) -> CandleResult<()> { + let indices = 
commit_keep_indices(self.pre_verify_kv, accepted_path); + self.model + .text + .keep_kv_indices(&indices) + .map_err(|e| candle_core::Error::Msg(format!("MinerU2.5 HSD commit_verify: {e}")))?; + + if let Some(tree) = self.pending_tree.take() { + for &p in accepted_path { + self.history.push(tree.tokens[p]); + } } + Ok(()) + } - Ok(results) + fn is_eos(&self, tok: u32) -> bool { + self.model.eos_token_ids.contains(&tok) } } @@ -637,33 +1324,57 @@ fn select_next_token( .to_vec1::() .map_err(|e| candle_to_ocr_inference("MinerU2.5", "logits to vec", e))?; - apply_repetition_penalty(&mut logits_vec, history, params.repetition_penalty); - apply_no_repeat_ngram(&mut logits_vec, history, params.no_repeat_ngram_size); + apply_sampling_processors(&mut logits_vec, history, params); if !params.do_sample || params.top_k == 1 { return Ok(argmax_token(&logits_vec)); } + let probs = softmax(&logits_vec); + if let Some(idx) = sample_from_probs(&probs) { + Ok(idx as u32) + } else { + Ok(argmax_token(&logits_vec)) + } +} + +#[cfg(feature = "hsd")] +fn processed_logprobs_from_logits( + logits: &Tensor, + history: &[u32], + params: &SamplingParams, + device: &Device, +) -> CandleResult { + let logits = logits.to_dtype(DType::F32)?.to_device(&Device::Cpu)?; + let mut logits_vec = logits.to_vec1::()?; + apply_sampling_processors(&mut logits_vec, history, params); + + let vocab = logits_vec.len(); + let processed = Tensor::from_vec(logits_vec, vocab, device)?; + cnn_ops::log_softmax(&processed, D::Minus1) +} + +fn apply_sampling_processors(logits: &mut [f32], history: &[u32], params: &SamplingParams) { + apply_repetition_penalty(logits, history, params.repetition_penalty); + apply_no_repeat_ngram(logits, history, params.no_repeat_ngram_size); + + if !params.do_sample || params.top_k == 1 { + return; + } + let temp = if params.temperature <= 0.0 { 1.0 } else { params.temperature }; if (temp - 1.0).abs() > f32::EPSILON { - for val in logits_vec.iter_mut() { + for val in logits.iter_mut() { 
*val /= temp; } } - apply_top_k(&mut logits_vec, params.top_k); - apply_top_p(&mut logits_vec, params.top_p); - - let probs = softmax(&logits_vec); - if let Some(idx) = sample_from_probs(&probs) { - Ok(idx as u32) - } else { - Ok(argmax_token(&logits_vec)) - } + apply_top_k(logits, params.top_k); + apply_top_p(logits, params.top_p); } fn argmax_token(logits: &[f32]) -> u32 { @@ -941,3 +1652,98 @@ fn get_rope_index( Ok((position_ids, rope_delta)) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mineru_task_prompt_text_recognition_matches_official() { + assert_eq!(MinerUTaskPrompt::Text.prompt(), "\nText Recognition:"); + } + + #[test] + fn mineru_task_prompt_formula_recognition_matches_official() { + assert_eq!(MinerUTaskPrompt::Formula.prompt(), "\nFormula Recognition:"); + } + + #[test] + fn mineru_task_prompt_table_recognition_matches_official() { + assert_eq!(MinerUTaskPrompt::Table.prompt(), "\nTable Recognition:"); + } + + #[test] + fn mineru_task_prompt_image_analysis_matches_official() { + assert_eq!( + MinerUTaskPrompt::ImageAnalysis.prompt(), + "\nImage Analysis:" + ); + } + + #[test] + fn mineru_task_prompt_layout_detection_matches_official() { + assert_eq!( + MinerUTaskPrompt::LayoutDetection.prompt(), + "\nLayout Detection:" + ); + } + + #[test] + fn for_layout_routes_table_kinds_to_table() { + assert_eq!( + MinerUTaskPrompt::for_layout(LayoutElementType::Table), + MinerUTaskPrompt::Table + ); + } + + #[test] + fn for_layout_routes_formula_kinds_to_formula() { + assert_eq!( + MinerUTaskPrompt::for_layout(LayoutElementType::Formula), + MinerUTaskPrompt::Formula + ); + assert_eq!( + MinerUTaskPrompt::for_layout(LayoutElementType::FormulaNumber), + MinerUTaskPrompt::Formula + ); + } + + #[test] + fn for_layout_routes_visual_kinds_to_image_analysis() { + for ty in [ + LayoutElementType::Image, + LayoutElementType::Chart, + LayoutElementType::Seal, + LayoutElementType::HeaderImage, + LayoutElementType::FooterImage, + ] { + assert_eq!( + 
MinerUTaskPrompt::for_layout(ty), + MinerUTaskPrompt::ImageAnalysis, + "expected ImageAnalysis for {ty:?}", + ); + } + } + + #[test] + fn for_layout_defaults_text_for_text_like_kinds() { + for ty in [ + LayoutElementType::Text, + LayoutElementType::Content, + LayoutElementType::DocTitle, + LayoutElementType::ParagraphTitle, + LayoutElementType::List, + LayoutElementType::Reference, + LayoutElementType::Footnote, + LayoutElementType::Number, + LayoutElementType::Header, + LayoutElementType::Footer, + ] { + assert_eq!( + MinerUTaskPrompt::for_layout(ty), + MinerUTaskPrompt::Text, + "expected Text for {ty:?}", + ); + } + } +} diff --git a/oar-ocr-vl/src/mineru/text.rs b/oar-ocr-vl/src/mineru/text.rs index 5df042f..acbe74c 100644 --- a/oar-ocr-vl/src/mineru/text.rs +++ b/oar-ocr-vl/src/mineru/text.rs @@ -2,11 +2,14 @@ use super::config::MinerUConfig; use crate::attention::{ RotaryEmbedding, repeat_kv, scaled_dot_product_attention, select_rope_sections, }; +#[cfg(feature = "hsd")] +use crate::hsd::TrimmableKvCache; +#[cfg(not(feature = "hsd"))] +use crate::kv_trim::TrimmableKvCache; use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing, rotate_half}; use candle_core::Tensor; use candle_nn::{ - Embedding, Linear, Module, VarBuilder, embedding, kv_cache::KvCache, linear, linear_no_bias, - rms_norm, + Embedding, Linear, Module, VarBuilder, embedding, linear, linear_no_bias, rms_norm, }; use oar_ocr_core::core::OCRError; use std::cell::RefCell; @@ -132,7 +135,7 @@ struct MinerUAttention { head_dim: usize, scaling: f64, mrope_section: Vec, - kv_cache: RefCell, + kv_cache: RefCell, } impl MinerUAttention { @@ -174,7 +177,8 @@ impl MinerUAttention { ) .map_err(|e| candle_to_ocr_inference("MinerU2.5", "load o_proj", e))?; - let kv_cache = KvCache::new(2, cfg.max_position_embeddings.max(8192)); + // Trim/gather-capable KV cache (HSD verification path). 
+ let kv_cache = TrimmableKvCache::new(2, cfg.max_position_embeddings.max(8192)); Ok(Self { q_proj, @@ -265,6 +269,19 @@ impl MinerUAttention { fn clear_kv_cache(&self) { self.kv_cache.borrow_mut().reset(); } + + #[cfg(feature = "hsd")] + fn current_kv_len(&self) -> usize { + self.kv_cache.borrow().current_seq_len() + } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.kv_cache + .borrow_mut() + .keep_indices(indices) + .map_err(|e| candle_to_ocr_inference("MinerU2.5", "keep_kv_indices", e)) + } } pub struct MinerUDecoderLayer { @@ -335,6 +352,11 @@ impl MinerUDecoderLayer { fn clear_kv_cache(&self) { self.self_attn.clear_kv_cache(); } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.self_attn.keep_kv_indices(indices) + } } pub struct MinerUTextModel { @@ -407,4 +429,24 @@ impl MinerUTextModel { layer.clear_kv_cache(); } } + + /// Current sequence length held in the KV cache. All layers stay in sync, + /// so we read it from layer 0. + #[cfg(feature = "hsd")] + pub fn current_kv_len(&self) -> usize { + self.layers + .first() + .map(|l| l.self_attn.current_kv_len()) + .unwrap_or(0) + } + + /// Gather every layer's KV cache to keep only the supplied positions + /// (in order). Used by HSD after tree-attention verification. 
+ #[cfg(feature = "hsd")] + pub fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + for layer in &self.layers { + layer.keep_kv_indices(indices)?; + } + Ok(()) + } } diff --git a/oar-ocr-vl/src/paddleocr_vl/ernie.rs b/oar-ocr-vl/src/paddleocr_vl/ernie.rs index b6ab950..3669cf2 100644 --- a/oar-ocr-vl/src/paddleocr_vl/ernie.rs +++ b/oar-ocr-vl/src/paddleocr_vl/ernie.rs @@ -2,9 +2,13 @@ use super::config::PaddleOcrVlConfig; use crate::attention::{ RotaryEmbedding, repeat_kv, scaled_dot_product_attention, select_rope_sections, }; +#[cfg(feature = "hsd")] +use crate::hsd::TrimmableKvCache; +#[cfg(not(feature = "hsd"))] +use crate::kv_trim::TrimmableKvCache; use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing, rotate_half}; use candle_core::Tensor; -use candle_nn::{Module, kv_cache::KvCache}; +use candle_nn::Module; use oar_ocr_core::core::OCRError; use std::cell::RefCell; @@ -134,7 +138,7 @@ struct Ernie4_5Attention { head_dim: usize, scaling: f64, mrope_section: Vec, - kv_cache: RefCell, + kv_cache: RefCell, } impl Ernie4_5Attention { @@ -186,7 +190,8 @@ impl Ernie4_5Attention { // Conservative estimate: vision tokens + max_generation_tokens // Typical: ~1000-2000 vision tokens + 4096 generation tokens = ~6000-8000 total // Use 16384 to handle worst case without reallocation - let kv_cache = KvCache::new(2, 16384); + // Trim/gather-capable KV cache (HSD verification path). 
+ let kv_cache = TrimmableKvCache::new(2, 16384); Ok(Self { q_proj, @@ -296,6 +301,19 @@ impl Ernie4_5Attention { fn clear_kv_cache(&self) { self.kv_cache.borrow_mut().reset(); } + + #[cfg(feature = "hsd")] + fn current_kv_len(&self) -> usize { + self.kv_cache.borrow().current_seq_len() + } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.kv_cache + .borrow_mut() + .keep_indices(indices) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "keep_kv_indices", e)) + } } #[derive(Debug)] @@ -360,6 +378,11 @@ impl Ernie4_5DecoderLayer { fn clear_kv_cache(&self) { self.self_attn.clear_kv_cache(); } + + #[cfg(feature = "hsd")] + fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + self.self_attn.keep_kv_indices(indices) + } } #[derive(Debug)] @@ -425,4 +448,24 @@ impl Ernie4_5Model { layer.clear_kv_cache(); } } + + /// Current sequence length held in the KV cache (read from layer 0; all + /// layers stay in sync). + #[cfg(feature = "hsd")] + pub fn current_kv_len(&self) -> usize { + self.layers + .first() + .map(|l| l.self_attn.current_kv_len()) + .unwrap_or(0) + } + + /// Gather every layer's KV cache to keep only the supplied positions. + /// Used by HSD after tree-attention verification. 
+ #[cfg(feature = "hsd")] + pub fn keep_kv_indices(&self, indices: &[u32]) -> Result<(), OCRError> { + for layer in &self.layers { + layer.keep_kv_indices(indices)?; + } + Ok(()) + } } diff --git a/oar-ocr-vl/src/paddleocr_vl/model.rs b/oar-ocr-vl/src/paddleocr_vl/model.rs index e4a3787..259cbf1 100644 --- a/oar-ocr-vl/src/paddleocr_vl/model.rs +++ b/oar-ocr-vl/src/paddleocr_vl/model.rs @@ -5,14 +5,37 @@ use super::ernie::Ernie4_5Model; use super::processing; use super::projector::Projector; use super::vision::VisionModel; +#[cfg(feature = "hsd")] +use crate::attention::create_tree_attention_mask; use crate::attention::{combine_masks, create_causal_mask, create_left_padding_mask}; +#[cfg(feature = "hsd")] +use crate::hsd::backend_util::{commit_keep_indices, step_pos_ids, tree_pos_ids}; +#[cfg(feature = "hsd")] +use crate::hsd::drafting::{ + TargetDraftAdapter, bbox_xyxy, crop_region_image, format_verified_region, map_layout_kind, + region_markdown_for, structure_result_to_layout_elements, +}; +#[cfg(feature = "hsd")] +use crate::hsd::prefix_tree::PrefixTree; +#[cfg(feature = "hsd")] +use crate::hsd::types::{AcceptStats, Draft, HsdConfig, HsdStats, RegionStageStats}; +#[cfg(feature = "hsd")] +use crate::hsd::verify::{SpecBackend, spec_decode}; use crate::utils::image::pil_resample_to_filter_type; use crate::utils::{candle_to_ocr_inference, candle_to_ocr_processing}; +#[cfg(feature = "hsd")] +use candle_core::Result as CandleResult; use candle_core::{D, DType, Device, IndexOp, Tensor}; use candle_nn::Module; +#[cfg(feature = "hsd")] +use candle_nn::ops as cnn_ops; use image::{RgbImage, imageops::FilterType}; use oar_ocr_core::core::OCRError; +#[cfg(feature = "hsd")] +use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, StructureResult}; use std::path::Path; +#[cfg(feature = "hsd")] +use std::time::{Duration, Instant}; use tokenizers::Tokenizer; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -186,7 +209,49 @@ impl PaddleOcrVl { })]; } - match 
self.generate_internal(images, tasks, max_new_tokens) { + match self.generate_tokens_internal(images, tasks, max_new_tokens) { + Ok(results) => results + .into_iter() + .enumerate() + .map(|(i, tokens)| self.decode_generated_tokens(&tokens, tasks[i])) + .collect(), + Err(e) => { + let msg = format!("generation failed: {e}"); + (0..images.len()) + .map(|_| { + Err(OCRError::InvalidInput { + message: msg.clone(), + }) + }) + .collect() + } + } + } + + /// Generate raw baseline tokens for oracle-draft / tokenizer round-trip + /// experiments. Tokens are exactly the ids emitted by the decode loop, + /// excluding EOS / separator stop tokens, before tokenizer decoding or + /// task postprocessing. + pub fn generate_tokens( + &self, + images: &[RgbImage], + tasks: &[PaddleOcrVlTask], + max_new_tokens: usize, + ) -> Vec, OCRError>> { + if images.is_empty() { + return Vec::new(); + } + if images.len() != tasks.len() { + return vec![Err(OCRError::InvalidInput { + message: format!( + "PaddleOCR-VL: images count ({}) != tasks count ({})", + images.len(), + tasks.len() + ), + })]; + } + + match self.generate_tokens_internal(images, tasks, max_new_tokens) { Ok(results) => results.into_iter().map(Ok).collect(), Err(e) => { let msg = format!("generation failed: {e}"); @@ -202,12 +267,12 @@ impl PaddleOcrVl { } /// Internal generation implementation supporting batched inference. - fn generate_internal( + fn generate_tokens_internal( &self, images: &[RgbImage], tasks: &[PaddleOcrVlTask], max_new_tokens: usize, - ) -> Result, OCRError> { + ) -> Result>, OCRError> { let batch_size = images.len(); // 1. Preprocess all images @@ -303,7 +368,11 @@ impl PaddleOcrVl { // 4. 
Build embeddings per sample let seq_lens: Vec = all_input_ids.iter().map(|ids| ids.len()).collect(); - let max_seq_len = *seq_lens.iter().max().unwrap(); + let Some(&max_seq_len) = seq_lens.iter().max() else { + return Err(OCRError::InvalidInput { + message: "PaddleOCR-VL: empty batch is not supported".to_string(), + }); + }; let mut batch_embeds: Vec = Vec::with_capacity(batch_size); let mut rope_deltas: Vec = Vec::with_capacity(batch_size); @@ -502,20 +571,571 @@ impl PaddleOcrVl { } } - // 10. Decode results - let mut results = Vec::with_capacity(batch_size); - for (i, tokens) in generated.into_iter().enumerate() { - let decoded = + Ok(generated) + } + + pub fn decode_tokens( + &self, + tokens: &[u32], + task: PaddleOcrVlTask, + ) -> Result<(String, String), OCRError> { + self.decode_generated_tokens(tokens, task) + } + + /// Decode tokens **without** applying PaddleOCR-VL's task-specific + /// post-process (OTSL→HTML for tables, `$$..$$` stripping for formulas). + /// This is the raw pre-postprocess string the model actually emitted — + /// use this when feeding PaddleOCR-VL output as a draft for another + /// target VLM. DSV matches at token granularity, so any post-process on + /// the source side will byte-mismatch the target's natural output. + pub fn decode_tokens_raw(&self, tokens: &[u32]) -> Result { + self.tokenizer + .decode(tokens, true) + .map_err(|e| OCRError::InvalidInput { + message: format!("decode failed: {e}"), + }) + } + + pub fn tokenizer(&self) -> &Tokenizer { + &self.tokenizer + } + + fn decode_generated_tokens( + &self, + tokens: &[u32], + task: PaddleOcrVlTask, + ) -> Result<(String, String), OCRError> { + let decoded = self.decode_tokens_raw(tokens)?; + let processed = task.postprocess(decoded.clone()); + Ok((decoded, processed)) + } + + /// Hierarchical Speculative Decoding entry for a single image / region. + /// + /// PaddleOCR-VL is naturally per-region (no full-page mode), so the only + /// HSD stage that applies is region-level. 
`task` selects the prompt + /// prefix (Ocr / Table / Formula / …); `drafts` are the lightweight + /// pipeline drafter's region-text candidates, tokenized with PaddleOCR-VL's + /// own tokenizer before being matched in the verifier. + #[cfg(feature = "hsd")] + pub fn generate_hsd( + &self, + image: &RgbImage, + task: PaddleOcrVlTask, + drafts: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let t_drafter = Instant::now(); + + let mut tokenized: Vec = Vec::with_capacity(drafts.len()); + for d in drafts { + if d.trim().is_empty() { + continue; + } + let enc = self.tokenizer - .decode(&tokens, true) + .encode(d.as_str(), false) .map_err(|e| OCRError::InvalidInput { - message: format!("decode failed: {e}"), + message: format!("PaddleOCR-VL HSD: tokenizer encode failed: {e}"), })?; - let processed = tasks[i].postprocess(decoded.clone()); - results.push((decoded, processed)); + let tokens = enc.get_ids().to_vec(); + if !tokens.is_empty() { + tokenized.push(Draft::new(tokens)); + } + } + self.generate_hsd_tokenized(image, task, &tokenized, hsd_cfg, t_drafter.elapsed()) + } + + /// HSD entry that consumes already-tokenized drafts. This is the oracle + /// path used by benchmarks to avoid `decode -> encode` tokenizer + /// round-trips when the draft comes from this backend's own baseline. 
+ #[cfg(feature = "hsd")] + pub fn generate_hsd_with_token_drafts( + &self, + image: &RgbImage, + task: PaddleOcrVlTask, + drafts: &[Draft], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + self.generate_hsd_tokenized(image, task, drafts, hsd_cfg, Duration::ZERO) + } + + #[cfg(feature = "hsd")] + fn generate_hsd_tokenized( + &self, + image: &RgbImage, + task: PaddleOcrVlTask, + tokenized: &[Draft], + hsd_cfg: &HsdConfig, + drafter_elapsed: Duration, + ) -> Result<(String, HsdStats), OCRError> { + if !self.device.is_cuda() { + return Err(OCRError::ConfigError { + message: "HSD requires CUDA device".to_string(), + }); + } + + let mut stats = HsdStats { + drafter: drafter_elapsed, + ..Default::default() + }; + // Stage 2 (page-level) terminology is reused here for stat bookkeeping. + let t_pre = Instant::now(); + let (initial_lp, rope_delta) = self.hsd_prefill_single(image, task)?; + stats.stage2.vision_prefill = t_pre.elapsed(); + stats.stage2.forward_passes = 1; + + let t_dec = Instant::now(); + let mut backend = PaddleOcrVlSpecBackend::new(self, rope_delta); + let mut accept = AcceptStats::default(); + let mut dsv = Default::default(); + let generated = spec_decode( + &mut backend, + tokenized, + initial_lp, + hsd_cfg.max_region_tokens, + &hsd_cfg.dsv, + &mut accept, + &mut dsv, + ) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "spec_decode", e))?; + stats.stage2.decode = t_dec.elapsed(); + stats.stage2.emitted_tokens = generated.len() as u32; + stats.stage2.accept = accept; + stats.stage2.dsv = dsv; + stats.stage2.forward_passes += backend.forward_passes; + + // Strip the first stop token and anything after it before decoding. 
+ let stop_pos = generated + .iter() + .position(|&t| t == self.eos_token_id || self.sep_token_id.is_some_and(|id| id == t)) + .unwrap_or(generated.len()); + let truncated = &generated[..stop_pos]; + + let decoded = + self.tokenizer + .decode(truncated, true) + .map_err(|e| OCRError::InvalidInput { + message: format!("PaddleOCR-VL HSD: tokenizer decode failed: {e}"), + })?; + let postprocessed = task.postprocess(decoded); + Ok((postprocessed, stats)) + } + + /// Run HSD per element across an entire layout-detected page, then + /// aggregate the per-region outputs into a markdown-style document. + /// + /// **No Stage 2 page-level verify** — PaddleOCR-VL is element-only by + /// design (its prompts are task-scoped: "OCR:", "Table Recognition:", + /// etc.), so there is no native single-prompt page-level inference to + /// verify against. The HunyuanOCR / GLM-OCR / MinerU `generate_hsd_full` + /// paths *do* run Stage 2 because their target prompts can describe a + /// whole page. For full HSD over a PaddleOCR-VL document use the + /// `doc_parser` flow (layout + per-region HSD) — what this function + /// already does. + /// + /// Element-type → task mapping follows the same heuristic as + /// `doc_parser::DocParser` (Table → Table, Formula → Formula, etc.). + /// Visual-only regions (Image / Seal / HeaderImage / FooterImage) are + /// skipped, mirroring the `task_for_element_type` logic. + #[cfg(feature = "hsd")] + pub fn generate_hsd_full( + &self, + image: &RgbImage, + elements: &[LayoutElement], + ignore_labels: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + self.generate_hsd_full_impl(image, elements, ignore_labels, hsd_cfg, None) + } + + /// One-call HSD entry that consumes a `StructureResult` (the output of + /// the OARStructure / PP-StructureV3 pipeline) directly. + /// + /// Backfills table HTML / formula LaTeX via + /// [`structure_result_to_layout_elements`] then delegates to + /// [`Self::generate_hsd_full`]. 
PaddleOCR-VL remains element-level (no + /// Stage 2) — `page_instruction` / `region_instruction` are not part of + /// this signature because the per-element prompt is picked from the + /// layout type by `task_for_layout_type` inside `generate_hsd_full_impl`. + #[cfg(feature = "hsd")] + pub fn generate_hsd_with_structure( + &self, + image: &RgbImage, + structure: &StructureResult, + ignore_labels: &[String], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + let elements = structure_result_to_layout_elements(structure); + self.generate_hsd_full(image, &elements, ignore_labels, hsd_cfg) + } + + /// Full per-region HSD where some regions already have target-token drafts. + /// `token_drafts[i]`, when present, is used for `elements[i]` directly and + /// avoids re-tokenizing decoded baseline text. + #[cfg(feature = "hsd")] + pub fn generate_hsd_full_with_token_drafts( + &self, + image: &RgbImage, + elements: &[LayoutElement], + ignore_labels: &[String], + token_drafts: &[Option>], + hsd_cfg: &HsdConfig, + ) -> Result<(String, HsdStats), OCRError> { + if token_drafts.len() != elements.len() { + return Err(OCRError::InvalidInput { + message: format!( + "PaddleOCR-VL HSD: token draft count ({}) != element count ({})", + token_drafts.len(), + elements.len() + ), + }); } + self.generate_hsd_full_impl(image, elements, ignore_labels, hsd_cfg, Some(token_drafts)) + } + + #[cfg(feature = "hsd")] + fn generate_hsd_full_impl( + &self, + image: &RgbImage, + elements: &[LayoutElement], + ignore_labels: &[String], + hsd_cfg: &HsdConfig, + token_drafts: Option<&[Option>]>, + ) -> Result<(String, HsdStats), OCRError> { + let mut stats = HsdStats::default(); + let mut region_md: Vec<(usize, String)> = Vec::with_capacity(elements.len()); + + for (idx, elem) in elements.iter().enumerate() { + if let Some(label) = &elem.label + && ignore_labels.iter().any(|l| l == label) + { + continue; + } + let Some(task) = task_for_layout_type(elem.element_type) else { +
continue; + }; + let token_draft = token_drafts + .and_then(|drafts| drafts.get(idx)) + .and_then(|d| d.as_ref()) + .filter(|d| !d.is_empty()); + // Format `elem.text` into PaddleOCR-VL's raw (pre-postprocess) form: + // tables stay OTSL, formulas get `$$..$$` wrapping (post-process + // strips it), other elements stay plain. This avoids the trap + // where a layout pipeline's HTML/markdown is byte-incompatible + // with what the VLM actually emits as logits. + let text_draft = region_markdown_for(elem, TargetDraftAdapter::PaddleOcrVl); + let text_draft = text_draft.trim(); + let text_draft = if text_draft.is_empty() { + None + } else { + Some(text_draft) + }; + if token_draft.is_none() && text_draft.is_none() { + continue; + } + + let bbox = bbox_xyxy(&elem.bbox); + let crop = crop_region_image(image, &bbox)?; + let (region_text, region_stats) = if let Some(tokens) = token_draft { + let drafts = vec![Draft::new(tokens.clone())]; + self.generate_hsd_with_token_drafts(&crop, task, &drafts, hsd_cfg)? + } else { + let drafts = vec![text_draft.expect("checked above").to_string()]; + self.generate_hsd(&crop, task, &drafts, hsd_cfg)? + }; + // Accumulate stats (region-level passes are stored under stage1). + stats.drafter += region_stats.drafter; + + let kind = map_layout_kind(elem.element_type); + stats.stage1_regions.push(RegionStageStats { + kind, + stats: region_stats.stage2.clone(), + }); + stats.stage1.add_assign(region_stats.stage2); + let order = elem.order_index.map(|x| x as usize).unwrap_or(idx); + region_md.push((order, format_verified_region(&region_text, kind))); + } + + region_md.sort_by_key(|(order, _)| *order); + let merged = region_md + .into_iter() + .map(|(_, text)| text) + .filter(|s| !s.trim().is_empty()) + .collect::>() + .join("\n\n"); + Ok((merged, stats)) + } + + /// Run a single-image prefill with the supplied task prompt.
Returns + /// the F32 last-position log-probabilities and the MRoPE delta that + /// post-image text positions need to add to their token index. + #[cfg(feature = "hsd")] + fn hsd_prefill_single( + &self, + image: &RgbImage, + task: PaddleOcrVlTask, + ) -> Result<(Tensor, i64), OCRError> { + // Optional spotting upscaling — mirror generate_tokens_internal's behaviour. + let resized; + let image_for_pp: &RgbImage = if task.needs_spotting_preprocess() + && image.width() < SPOTTING_UPSCALE_THRESHOLD + && image.height() < SPOTTING_UPSCALE_THRESHOLD + { + let resize_filter = self + .image_cfg + .resample + .and_then(pil_resample_to_filter_type) + .unwrap_or(FilterType::CatmullRom); + resized = image::imageops::resize( + image, + image.width().saturating_mul(2), + image.height().saturating_mul(2), + resize_filter, + ); + &resized + } else { + image + }; + let max_pixels = if task.needs_spotting_preprocess() { + let factor = (self.image_cfg.patch_size * self.image_cfg.merge_size) as u32; + let spotting_max_pixels = SPOTTING_MAX_LONG_SIDE + .saturating_mul(factor) + .saturating_mul(factor); + self.image_cfg.max_pixels.max(spotting_max_pixels) + } else { + self.image_cfg.max_pixels + }; + + let image_inputs = processing::preprocess_images_with_max_pixels( + std::slice::from_ref(image_for_pp), + &self.image_cfg, + &self.device, + self.dtype, + max_pixels, + )?; + + let (t, h, w) = image_inputs.image_grid_thw[0]; + let image_token_count = + (t * h * w) / (self.image_cfg.merge_size * self.image_cfg.merge_size); + + // Build prompt: prefix + N×<|IMAGE_PLACEHOLDER|> + suffix.
+ let prefix = "<|begin_of_sentence|>User: <|IMAGE_START|>"; + let suffix = format!("<|IMAGE_END|>{}\n{}", task.prompt(), self.assistant_prefix); + let prefix_enc = + self.tokenizer + .encode(prefix, false) + .map_err(|e| OCRError::InvalidInput { + message: format!("PaddleOCR-VL HSD: tokenizer encode failed: {e}"), + })?; + let suffix_enc = + self.tokenizer + .encode(suffix.as_str(), false) + .map_err(|e| OCRError::InvalidInput { + message: format!("PaddleOCR-VL HSD: tokenizer encode failed: {e}"), + })?; + let mut input_ids = + Vec::with_capacity(prefix_enc.len() + image_token_count + suffix_enc.len()); + input_ids.extend_from_slice(prefix_enc.get_ids()); + input_ids.extend(std::iter::repeat_n( + self.image_placeholder_token_id, + image_token_count, + )); + input_ids.extend_from_slice(suffix_enc.get_ids()); + let seq_len = input_ids.len(); + + // Vision features → projector → image embeddings. + let vision_feats = self + .vision + .forward(&image_inputs.pixel_values, &image_inputs.image_grid_thw)?; + let image_embeds = self + .projector + .forward(&vision_feats, &image_inputs.image_grid_thw)?; + + // Build token embeddings, then splice in the image embeddings. 
+ let input_ids_t = Tensor::new(input_ids.clone(), &self.device) + .and_then(|t| t.reshape((1, seq_len))) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "create input_ids", e))?; + let mut inputs_embeds = self.llm.embed(&input_ids_t)?; + + if let Some(first_pos) = input_ids + .iter() + .position(|&id| id == self.cfg.image_token_id) + { + let image_end = first_pos + image_token_count; + let mut parts: Vec = Vec::with_capacity(3); + if first_pos > 0 { + parts.push( + inputs_embeds + .narrow(1, 0, first_pos) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "narrow prefix", e))?, + ); + } + parts.push( + image_embeds + .unsqueeze(0) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "unsqueeze img", e))?, + ); + if image_end < seq_len { + parts.push( + inputs_embeds + .narrow(1, image_end, seq_len - image_end) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "narrow suffix", e))?, + ); + } + let refs: Vec<&Tensor> = parts.iter().collect(); + inputs_embeds = Tensor::cat(&refs, 1) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "cat embeds", e))?; + } + + // Position IDs (3-axis MRoPE) + delta for post-image text positions. + let (pos_ids, rope_delta) = get_rope_index( + &self.cfg, + &input_ids, + &[image_inputs.image_grid_thw[0]], + &self.device, + )?; + + // Causal mask for prefill. + let causal = create_causal_mask(seq_len, seq_len, self.dtype, &self.device) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "create causal", e))?; + + self.llm.clear_kv_cache(); + let hidden = self.llm.forward(&inputs_embeds, &pos_ids, Some(&causal))?; + + let last = hidden + .i((0, seq_len - 1, ..)) + .and_then(|t| t.unsqueeze(0)) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "get last hidden", e))?; + // Project via lm_head → log-softmax in F32. 
+ let logits = self + .lm_head + .forward(&last) + .and_then(|t| t.squeeze(0)) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "lm_head prefill", e))?; + let lp = cnn_ops::log_softmax( + &logits + .to_dtype(DType::F32) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "logits to f32", e))?, + D::Minus1, + ) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "log_softmax prefill", e))?; + Ok((lp, rope_delta)) + } +} + +/// Map a layout element type to the PaddleOCR-VL recognition task. Returns +/// `None` for pure-visual regions that have no textual content to verify. +#[cfg(feature = "hsd")] +fn task_for_layout_type(t: LayoutElementType) -> Option { + use LayoutElementType::*; + match t { + Table => Some(PaddleOcrVlTask::Table), + Chart => Some(PaddleOcrVlTask::Chart), + Formula => Some(PaddleOcrVlTask::Formula), + FormulaNumber => Some(PaddleOcrVlTask::Ocr), + Image | HeaderImage | FooterImage | Seal => None, + _ => Some(PaddleOcrVlTask::Ocr), + } +} + +/// HSD adapter for PaddleOCR-VL. The MRoPE position offset (`rope_delta`) +/// is captured during prefill and re-used by every subsequent step / verify. +#[cfg(feature = "hsd")] +struct PaddleOcrVlSpecBackend<'a> { + model: &'a PaddleOcrVl, + rope_delta: i64, + pre_verify_kv: usize, + forward_passes: u32, +} + +#[cfg(feature = "hsd")] +impl<'a> PaddleOcrVlSpecBackend<'a> { + fn new(model: &'a PaddleOcrVl, rope_delta: i64) -> Self { + Self { + model, + rope_delta, + pre_verify_kv: 0, + forward_passes: 0, + } + } + + fn project_logprobs_2d(&self, hidden_2d: &Tensor) -> CandleResult { + // (N, hidden) → (N, vocab) → log-softmax F32. + let logits = self.model.lm_head.forward(hidden_2d)?; + cnn_ops::log_softmax(&logits.to_dtype(DType::F32)?, D::Minus1) + } + + fn project_logprobs_1d(&self, hidden_1d: &Tensor) -> CandleResult { + // (hidden,) → (1, hidden) → (1, vocab) → (vocab,) → log-softmax F32. + let logits = self + .model + .lm_head + .forward(&hidden_1d.unsqueeze(0)?)? 
+ .squeeze(0)?; + cnn_ops::log_softmax(&logits.to_dtype(DType::F32)?, D::Minus1) + } +} + +#[cfg(feature = "hsd")] +impl<'a> SpecBackend for PaddleOcrVlSpecBackend<'a> { + fn step_one(&mut self, token: u32) -> CandleResult { + let model = self.model; + let device = &model.device; + + let tok_t = Tensor::new(vec![token], device)?.reshape((1usize, 1usize))?; + let embeds = model.llm.embed(&tok_t).map_err(|e| { + candle_core::Error::Msg(format!("PaddleOCR-VL HSD step_one embed: {e}")) + })?; + + let pos_ids = step_pos_ids(3, model.llm.current_kv_len(), self.rope_delta, device)?; + + let hidden = model.llm.forward(&embeds, &pos_ids, None).map_err(|e| { + candle_core::Error::Msg(format!("PaddleOCR-VL HSD step_one forward: {e}")) + })?; + self.forward_passes += 1; + let last = hidden.i((0, 0, ..))?; + self.project_logprobs_1d(&last) + } + + fn verify_tree(&mut self, tree: &PrefixTree) -> CandleResult { + let n = tree.num_nodes(); + let model = self.model; + let device = &model.device; + let dtype = model.dtype; + + let prefix_kv = model.llm.current_kv_len(); + self.pre_verify_kv = prefix_kv; + + let tok_t = Tensor::new(tree.tokens.clone(), device)?.reshape((1usize, n))?; + let embeds = model.llm.embed(&tok_t).map_err(|e| { + candle_core::Error::Msg(format!("PaddleOCR-VL HSD verify_tree embed: {e}")) + })?; + + let pos_ids = tree_pos_ids(3, prefix_kv, self.rope_delta, tree, device)?; + let mask = create_tree_attention_mask(&tree.parents, prefix_kv, dtype, device)?; + + let hidden = model + .llm + .forward(&embeds, &pos_ids, Some(&mask)) + .map_err(|e| { + candle_core::Error::Msg(format!("PaddleOCR-VL HSD verify_tree forward: {e}")) + })?; + self.forward_passes += 1; + let h2 = hidden.squeeze(0)?; + self.project_logprobs_2d(&h2) + } + + fn commit_verify(&mut self, accepted_path: &[usize]) -> CandleResult<()> { + let indices = commit_keep_indices(self.pre_verify_kv, accepted_path); + self.model + .llm + .keep_kv_indices(&indices) + .map_err(|e| 
candle_core::Error::Msg(format!("PaddleOCR-VL HSD commit_verify: {e}"))) + } - Ok(results) + fn is_eos(&self, tok: u32) -> bool { + tok == self.model.eos_token_id || self.model.sep_token_id.is_some_and(|id| id == tok) } } diff --git a/oar-ocr-vl/src/paddleocr_vl/vision.rs b/oar-ocr-vl/src/paddleocr_vl/vision.rs index 928881c..88ea04c 100644 --- a/oar-ocr-vl/src/paddleocr_vl/vision.rs +++ b/oar-ocr-vl/src/paddleocr_vl/vision.rs @@ -5,6 +5,28 @@ use candle_nn::Module; use oar_ocr_core::core::OCRError; use rayon::prelude::*; +/// Above this seq length the vision attention computes softmax in chunks +/// along the query dim, instead of allocating the full `[seq, seq]` F32 buffer. +/// Mirrors the memory profile of PyTorch's `F.scaled_dot_product_attention`, +/// which is what `modeling_paddleocr_vl.py::PaddleOCRAttention` picks at +/// runtime when `_attn_implementation == "sdpa"` (the transformers default for +/// models with `_supports_sdpa = True`). +/// +/// The threshold is set above PaddleOCR-VL-1.5's worst-case seq length so +/// that v1.5 always takes the single-shot full-matrix path. Vision attention +/// runs on the pre-merge SigLIP patch grid (patch_size² = 196), so v1.5 at +/// max_pixels = 1_003_520 produces seq ≈ 5040 and v1 at max_pixels = +/// 2_822_400 produces seq ≈ 14200. Going through chunked matmul changes +/// cuBLAS' kernel tiling and shifts a few low-confidence argmax tokens; +/// keeping v1.5 on its original path preserves byte-stable output. Above +/// 8192 the chunked path kicks in — that's the regime where the full +/// `[seq, seq]` F32 buffer would OOM (12+ GB on v1 / chart_01.jpg). +const ATTN_FULL_SEQ_THRESHOLD: usize = 8192; +/// Query chunk size when chunked attention kicks in. 512 keeps each chunk's +/// F32 softmax scratch at `chunk * seq_k * num_heads * 4 bytes` — about +/// 110 MB at seq_k = 3600 / heads = 16, well within VRAM headroom. 
+const ATTN_CHUNK_SIZE: usize = 512; + /// SigLIP-style 2D rotary embedding for vision encoder #[derive(Debug, Clone)] struct SigLIPRotaryEmbedding { @@ -341,33 +363,82 @@ impl VisionAttention { ) }; - let attn_weights = q - .matmul( - &k.transpose(2, 3) - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision k t23", e))? - .contiguous() + let kt = k + .transpose(2, 3) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision k t23", e))? + .contiguous() + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision k t23 contiguous", e))?; + + // Chunked attention over the query dim to mirror PyTorch's `sdpa_attention_forward` + // memory profile (no full [seq, seq] materialization). Each iteration only allocates + // O(chunk * seq_k) F32 scratch for the softmax, instead of O(seq_q * seq_k) for the + // full eager path. Numerically equivalent to the original code, and matches Python's + // PaddleOCRAttention which auto-selects sdpa at runtime (modeling_paddleocr_vl.py:1298). + let attn_output = if seq <= ATTN_FULL_SEQ_THRESHOLD { + // Original eager attention path — kept byte-identical to the pre-fix + // implementation so that small/medium seq inputs (incl. all v1.5 + // inputs) produce the exact same output. Removing the final + // `.contiguous()` here changes cuBLAS' matmul-shape selection and + // shifts low-confidence argmax tokens, so we preserve it. + let scores = q + .matmul(&kt) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision qk matmul", e))? + .affine(self.scale, 0.0) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision scaling", e))?; + let probs = + candle_nn::ops::softmax_last_dim(&scores.to_dtype(DType::F32).map_err(|e| { + candle_to_ocr_inference("PaddleOCR-VL", "vision attn cast f32", e) + })?) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision attn softmax", e))? + .to_dtype(v.dtype()) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision attn cast back", e))? 
+ .contiguous() + .map_err(|e| { + candle_to_ocr_inference("PaddleOCR-VL", "vision attn contiguous", e) + })?; + probs + .matmul(&v) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision av matmul", e))? + } else { + let mut chunks: Vec = Vec::with_capacity(seq.div_ceil(ATTN_CHUNK_SIZE)); + let mut start = 0; + while start < seq { + let len = ATTN_CHUNK_SIZE.min(seq - start); + let q_chunk = q + .narrow(2, start, len) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision q narrow", e))?; + let scores = q_chunk + .matmul(&kt) + .map_err(|e| { + candle_to_ocr_inference("PaddleOCR-VL", "vision qk matmul (chunk)", e) + })? + .affine(self.scale, 0.0) .map_err(|e| { - candle_to_ocr_inference("PaddleOCR-VL", "vision k t23 contiguous", e) + candle_to_ocr_inference("PaddleOCR-VL", "vision scaling (chunk)", e) + })?; + let probs = candle_nn::ops::softmax_last_dim( + &scores.to_dtype(DType::F32).map_err(|e| { + candle_to_ocr_inference("PaddleOCR-VL", "vision attn cast f32 (chunk)", e) })?, - ) - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision qk matmul", e))? - .affine(self.scale, 0.0) - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision scaling", e))?; - - let attn_weights = candle_nn::ops::softmax_last_dim( - &attn_weights - .to_dtype(DType::F32) - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision attn cast f32", e))?, - ) - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision attn softmax", e))? - .to_dtype(v.dtype()) - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision attn cast back", e))? - .contiguous() - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision attn contiguous", e))?; - - let attn_output = attn_weights - .matmul(&v) - .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision av matmul", e))? + ) + .map_err(|e| { + candle_to_ocr_inference("PaddleOCR-VL", "vision attn softmax (chunk)", e) + })? 
+ .to_dtype(v.dtype()) + .map_err(|e| { + candle_to_ocr_inference("PaddleOCR-VL", "vision attn cast back (chunk)", e) + })?; + let out_chunk = probs.matmul(&v).map_err(|e| { + candle_to_ocr_inference("PaddleOCR-VL", "vision av matmul (chunk)", e) + })?; + chunks.push(out_chunk); + start += len; + } + Tensor::cat(&chunks, 2) + .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision av concat", e))? + }; + + let attn_output = attn_output .transpose(1, 2) .map_err(|e| candle_to_ocr_inference("PaddleOCR-VL", "vision out transpose", e))? .reshape((b, seq, embed_dim)) diff --git a/oar-ocr-vl/src/unirec/config.rs b/oar-ocr-vl/src/unirec/config.rs deleted file mode 100644 index a3b68fb..0000000 --- a/oar-ocr-vl/src/unirec/config.rs +++ /dev/null @@ -1,240 +0,0 @@ -//! UniRec model configuration. - -use serde::{Deserialize, Serialize}; -use std::path::Path; - -use oar_ocr_core::core::OCRError; - -/// UniRec model configuration. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UniRecConfig { - /// Model dimensionality (d_model) - #[serde(default = "default_d_model")] - pub d_model: usize, - - /// Vocabulary size - #[serde(default = "default_vocab_size")] - pub vocab_size: usize, - - /// Number of decoder layers - #[serde(default = "default_decoder_layers")] - pub decoder_layers: usize, - - /// Number of decoder attention heads - #[serde(default = "default_decoder_attention_heads")] - pub decoder_attention_heads: usize, - - /// Decoder FFN dimension - #[serde(default = "default_decoder_ffn_dim")] - pub decoder_ffn_dim: usize, - - /// Encoder sequence length (vision features) - #[serde(default = "default_encoder_seq_len")] - pub encoder_seq_len: usize, - - /// Input image height - #[serde(default = "default_input_height")] - pub input_height: usize, - - /// Input image width - #[serde(default = "default_input_width")] - pub input_width: usize, - - /// Maximum sequence length for generation - #[serde(default = "default_max_length")] - pub max_length: usize, - - 
/// BOS token id - #[serde(default = "default_bos_token_id")] - pub bos_token_id: u32, - - /// EOS token id - #[serde(default = "default_eos_token_id")] - pub eos_token_id: u32, - - /// PAD token id - #[serde(default = "default_pad_token_id")] - pub pad_token_id: u32, - - /// Decoder start token id - #[serde(default = "default_decoder_start_token_id")] - pub decoder_start_token_id: u32, - - /// Dropout rate - #[serde(default = "default_dropout")] - pub dropout: f64, - - /// Attention dropout rate - #[serde(default = "default_attention_dropout")] - pub attention_dropout: f64, - - /// Whether to scale embeddings - #[serde(default = "default_scale_embedding")] - pub scale_embedding: bool, - - // FocalSVTR encoder config - /// Vision encoder embed dimension (base) - #[serde(default = "default_encoder_embed_dim")] - pub encoder_embed_dim: usize, - - /// Vision encoder layer depths - #[serde(default = "default_encoder_depths")] - pub encoder_depths: Vec, - - /// Vision encoder focal levels - #[serde(default = "default_focal_levels")] - pub focal_levels: Vec, - - /// Vision encoder focal windows - #[serde(default = "default_focal_windows")] - pub focal_windows: Vec, - - /// Vision encoder max kernel heights - #[serde(default = "default_max_khs")] - pub max_khs: Vec, - - /// Vision encoder subsampling kernels - #[serde(default = "default_sub_k")] - pub sub_k: Vec<(usize, usize)>, -} - -fn default_d_model() -> usize { - 768 -} - -fn default_vocab_size() -> usize { - 56371 -} - -fn default_decoder_layers() -> usize { - 6 -} - -fn default_decoder_attention_heads() -> usize { - 12 -} - -fn default_decoder_ffn_dim() -> usize { - 3072 -} - -fn default_encoder_seq_len() -> usize { - 1320 -} - -fn default_input_height() -> usize { - 1408 // max_height from Python max_side[1] -} - -fn default_input_width() -> usize { - 960 // max_width from Python max_side[0] -} - -fn default_max_length() -> usize { - 2048 -} - -fn default_bos_token_id() -> u32 { - 0 -} - -fn default_eos_token_id() 
-> u32 { - 2 -} - -fn default_pad_token_id() -> u32 { - 1 -} - -fn default_decoder_start_token_id() -> u32 { - 0 -} - -fn default_dropout() -> f64 { - 0.0 -} - -fn default_attention_dropout() -> f64 { - 0.0 -} - -fn default_scale_embedding() -> bool { - true -} - -fn default_encoder_embed_dim() -> usize { - 96 -} - -fn default_encoder_depths() -> Vec { - vec![2, 2, 9, 2] -} - -fn default_focal_levels() -> Vec { - vec![3, 3, 3, 3] -} - -fn default_focal_windows() -> Vec { - vec![3, 3, 3, 3] -} - -fn default_max_khs() -> Vec { - vec![7, 3, 3, 3] -} - -fn default_sub_k() -> Vec<(usize, usize)> { - vec![(2, 2), (2, 2), (2, 2), (0, 0)] // (0,0) means no downsampling for last stage -} - -impl Default for UniRecConfig { - fn default() -> Self { - Self { - d_model: default_d_model(), - vocab_size: default_vocab_size(), - decoder_layers: default_decoder_layers(), - decoder_attention_heads: default_decoder_attention_heads(), - decoder_ffn_dim: default_decoder_ffn_dim(), - encoder_seq_len: default_encoder_seq_len(), - input_height: default_input_height(), - input_width: default_input_width(), - max_length: default_max_length(), - bos_token_id: default_bos_token_id(), - eos_token_id: default_eos_token_id(), - pad_token_id: default_pad_token_id(), - decoder_start_token_id: default_decoder_start_token_id(), - dropout: default_dropout(), - attention_dropout: default_attention_dropout(), - scale_embedding: default_scale_embedding(), - encoder_embed_dim: default_encoder_embed_dim(), - encoder_depths: default_encoder_depths(), - focal_levels: default_focal_levels(), - focal_windows: default_focal_windows(), - max_khs: default_max_khs(), - sub_k: default_sub_k(), - } - } -} - -impl UniRecConfig { - /// Load configuration from a JSON file. 
- pub fn from_path(path: impl AsRef) -> Result { - let path = path.as_ref(); - let content = std::fs::read_to_string(path).map_err(|e| OCRError::ConfigError { - message: format!("Failed to read UniRec config from {:?}: {}", path, e), - })?; - serde_json::from_str(&content).map_err(|e| OCRError::ConfigError { - message: format!("Failed to parse UniRec config: {}", e), - }) - } - - /// Get the number of attention heads per key-value head. - /// For standard multi-head attention, this is 1. - pub fn num_key_value_heads(&self) -> usize { - self.decoder_attention_heads - } - - /// Get the dimension per attention head. - pub fn head_dim(&self) -> usize { - self.d_model / self.decoder_attention_heads - } -} diff --git a/oar-ocr-vl/src/unirec/decoder.rs b/oar-ocr-vl/src/unirec/decoder.rs deleted file mode 100644 index 5e56386..0000000 --- a/oar-ocr-vl/src/unirec/decoder.rs +++ /dev/null @@ -1,415 +0,0 @@ -//! M2M100-style decoder implementation for UniRec. -//! -//! Based on the M2M100 (Multilingual Translation) decoder architecture -//! with learned positional embeddings and scaled word embeddings. - -use candle_core::{D, DType, Device, Result, Tensor}; -use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder, kv_cache::KvCache}; -use std::cell::RefCell; - -use super::config::UniRecConfig; -use crate::attention::{on_compute_device, scaled_dot_product_attention}; -use crate::utils::candle_to_ocr_inference; -use oar_ocr_core::core::OCRError; - -/// Scaled word embedding for M2M100. 
-#[derive(Debug, Clone)] -pub struct ScaledWordEmbedding { - embedding: Embedding, - embed_scale: f64, -} - -impl ScaledWordEmbedding { - fn load(vocab_size: usize, embed_dim: usize, scale: bool, vb: VarBuilder) -> Result { - let embedding = candle_nn::embedding(vocab_size, embed_dim, vb)?; - let embed_scale = if scale { - (embed_dim as f64).sqrt() - } else { - 1.0 - }; - Ok(Self { - embedding, - embed_scale, - }) - } -} - -impl Module for ScaledWordEmbedding { - fn forward(&self, x: &Tensor) -> Result { - let emb = self.embedding.forward(x)?; - if self.embed_scale != 1.0 { - &emb * self.embed_scale - } else { - Ok(emb) - } - } -} - -/// Sinusoidal positional embedding (computed, not learned). -#[derive(Debug, Clone)] -struct SinusoidalPositionalEmbedding { - embed_dim: usize, - offset: usize, -} - -impl SinusoidalPositionalEmbedding { - fn new(embed_dim: usize) -> Self { - // M2M100 uses offset=2 for padding - Self { - embed_dim, - offset: 2, - } - } - - fn forward(&self, position_ids: &Tensor, device: &Device, dtype: DType) -> Result { - let half_dim = self.embed_dim / 2; - let emb_scale = -(10000f64.ln()) / (half_dim as f64); - - // Use on_compute_device to handle Metal's lack of support for arange and broadcast_* - on_compute_device(device, |compute_device| { - // Create frequency tensor: exp(-i * log(10000) / half_dim) for i in [0, half_dim) - // Shape: [half_dim] - let freq_indices = - Tensor::arange(0u32, half_dim as u32, compute_device)?.to_dtype(DType::F32)?; - let freqs = (&freq_indices * emb_scale)?.exp()?; - - // Transfer position_ids to compute device if needed - let position_ids = position_ids.to_device(compute_device)?; - - // Add offset to positions: [batch, seq_len] - let positions = position_ids - .to_dtype(DType::F32)? 
- .broadcast_add(&Tensor::new(self.offset as f32, compute_device)?)?; - - // Compute angles: positions [batch, seq_len, 1] * freqs [1, 1, half_dim] - // Result shape: [batch, seq_len, half_dim] - let positions_expanded = positions.unsqueeze(D::Minus1)?; - let freqs_expanded = freqs.reshape((1, 1, half_dim))?; - let angles = positions_expanded.broadcast_mul(&freqs_expanded)?; - - // Compute sin and cos, then concatenate: [batch, seq_len, embed_dim] - let sin_emb = angles.sin()?; - let cos_emb = angles.cos()?; - Tensor::cat(&[&sin_emb, &cos_emb], D::Minus1)?.to_dtype(dtype) - }) - } -} - -/// Multi-head attention for decoder. -#[derive(Debug)] -struct M2M100Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, - out_proj: Linear, - num_heads: usize, - head_dim: usize, - scale: f64, - kv_cache: RefCell, -} - -impl M2M100Attention { - fn load( - embed_dim: usize, - num_heads: usize, - _is_cross_attention: bool, - vb: VarBuilder, - ) -> Result { - let head_dim = embed_dim / num_heads; - let q_proj = candle_nn::linear(embed_dim, embed_dim, vb.pp("q_proj"))?; - let k_proj = candle_nn::linear(embed_dim, embed_dim, vb.pp("k_proj"))?; - let v_proj = candle_nn::linear(embed_dim, embed_dim, vb.pp("v_proj"))?; - let out_proj = candle_nn::linear(embed_dim, embed_dim, vb.pp("out_proj"))?; - - // Create KvCache with dim=2 for seq_len dimension - // Pre-allocate 8192 to avoid reallocation (double of typical 4096 max_tokens) - let kv_cache = KvCache::new(2, 8192); - - Ok(Self { - q_proj, - k_proj, - v_proj, - out_proj, - num_heads, - head_dim, - scale: (head_dim as f64).powf(-0.5), - kv_cache: RefCell::new(kv_cache), - }) - } - - fn forward( - &self, - hidden_states: &Tensor, - key_value_states: Option<&Tensor>, - attention_mask: Option<&Tensor>, - is_cross_attention: bool, - ) -> Result { - let (batch_size, seq_len, _) = hidden_states.dims3()?; - - // Query projection - let query_states = self.q_proj.forward(hidden_states)?; - let query_states = query_states - 
.reshape((batch_size, seq_len, self.num_heads, self.head_dim))? - .transpose(1, 2)? - .contiguous()?; // (B, num_heads, seq_len, head_dim) - - // For cross-attention, check if we already have cached KV - let use_cached = if is_cross_attention { - self.kv_cache.borrow().current_seq_len() > 0 - } else { - false - }; - - let (key_states, value_states) = if use_cached { - // Cross-attention: reuse cached encoder KV - let cache = self.kv_cache.borrow(); - match (cache.k()?, cache.v()?) { - (Some(k), Some(v)) => (k.clone(), v.clone()), - _ => return Err(candle_core::Error::Msg("kv cache is empty".into())), - } - } else { - // Self-attention or first cross-attention step: compute new KV - let kv_source = key_value_states.unwrap_or(hidden_states); - let (_, kv_len, _) = kv_source.dims3()?; - - let k = self.k_proj.forward(kv_source)?; - let v = self.v_proj.forward(kv_source)?; - let k = k - .reshape((batch_size, kv_len, self.num_heads, self.head_dim))? - .transpose(1, 2)? - .contiguous()?; - let v = v - .reshape((batch_size, kv_len, self.num_heads, self.head_dim))? - .transpose(1, 2)? - .contiguous()?; - - // Append to cache - self.kv_cache.borrow_mut().append(&k, &v)? - }; - - // Use unified scaled dot-product attention - let attn_output = scaled_dot_product_attention( - &query_states, - &key_states, - &value_states, - attention_mask, - self.scale, - false, // is_causal=false, mask is passed explicitly - )?; - - // Reshape back - let attn_output = attn_output.transpose(1, 2)?.contiguous()?.reshape(( - batch_size, - seq_len, - self.num_heads * self.head_dim, - ))?; - - self.out_proj.forward(&attn_output) - } - - fn clear_cache(&self) { - self.kv_cache.borrow_mut().reset(); - } -} - -/// M2M100 decoder layer. 
-#[derive(Debug)] -struct M2M100DecoderLayer { - self_attn: M2M100Attention, - self_attn_layer_norm: LayerNorm, - encoder_attn: M2M100Attention, - encoder_attn_layer_norm: LayerNorm, - fc1: Linear, - fc2: Linear, - final_layer_norm: LayerNorm, -} - -impl M2M100DecoderLayer { - fn load(cfg: &UniRecConfig, vb: VarBuilder) -> Result { - let self_attn = M2M100Attention::load( - cfg.d_model, - cfg.decoder_attention_heads, - false, - vb.pp("self_attn"), - )?; - let self_attn_layer_norm = - candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?; - - let encoder_attn = M2M100Attention::load( - cfg.d_model, - cfg.decoder_attention_heads, - true, - vb.pp("encoder_attn"), - )?; - let encoder_attn_layer_norm = - candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?; - - let fc1 = candle_nn::linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?; - let fc2 = candle_nn::linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?; - let final_layer_norm = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?; - - Ok(Self { - self_attn, - self_attn_layer_norm, - encoder_attn, - encoder_attn_layer_norm, - fc1, - fc2, - final_layer_norm, - }) - } - - fn forward( - &self, - hidden_states: &Tensor, - encoder_hidden_states: &Tensor, - self_attn_mask: Option<&Tensor>, - cross_attn_mask: Option<&Tensor>, - ) -> Result { - let residual = hidden_states.clone(); - - // Self attention - let hidden_states = self.self_attn_layer_norm.forward(hidden_states)?; - let hidden_states = self.self_attn.forward( - &hidden_states, - None, - self_attn_mask, - false, // is_cross_attention = false for self-attention - )?; - let hidden_states = (&residual + &hidden_states)?; - - // Cross attention - let residual = hidden_states.clone(); - let hidden_states = self.encoder_attn_layer_norm.forward(&hidden_states)?; - let hidden_states = self.encoder_attn.forward( - &hidden_states, - Some(encoder_hidden_states), - cross_attn_mask, - true, // is_cross_attention 
= true for cross-attention - )?; - let hidden_states = (&residual + &hidden_states)?; - - // FFN - let residual = hidden_states.clone(); - let hidden_states = self.final_layer_norm.forward(&hidden_states)?; - let hidden_states = self.fc1.forward(&hidden_states)?; - let hidden_states = hidden_states.relu()?; - let hidden_states = self.fc2.forward(&hidden_states)?; - &residual + &hidden_states - } - - fn clear_kv_cache(&self) { - self.self_attn.clear_cache(); - self.encoder_attn.clear_cache(); - } -} - -/// M2M100 Decoder. -#[derive(Debug)] -pub struct M2M100Decoder { - embed_tokens: ScaledWordEmbedding, - embed_positions: SinusoidalPositionalEmbedding, - layers: Vec, - layer_norm: LayerNorm, -} - -impl M2M100Decoder { - /// Load M2M100 decoder from weights. - pub fn load(cfg: &UniRecConfig, vb: VarBuilder) -> std::result::Result { - let embed_tokens = ScaledWordEmbedding::load( - cfg.vocab_size, - cfg.d_model, - cfg.scale_embedding, - vb.pp("embed_tokens"), - ) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "load embed_tokens", e))?; - - // Use sinusoidal positional embeddings (computed, not learned) - let embed_positions = SinusoidalPositionalEmbedding::new(cfg.d_model); - - let mut layers = Vec::new(); - for i in 0..cfg.decoder_layers { - let layer = - M2M100DecoderLayer::load(cfg, vb.pp(format!("layers.{}", i))).map_err(|e| { - candle_to_ocr_inference("M2M100Decoder", format!("load layer.{}", i), e) - })?; - layers.push(layer); - } - - let layer_norm = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("layer_norm")) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "load layer_norm", e))?; - - Ok(Self { - embed_tokens, - embed_positions, - layers, - layer_norm, - }) - } - - /// Forward pass through the decoder. 
- pub fn forward( - &self, - input_ids: &Tensor, - encoder_hidden_states: &Tensor, - position_offset: usize, - self_attn_mask: Option<&Tensor>, - ) -> std::result::Result { - let (batch_size, seq_len) = input_ids - .dims2() - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "get input dims", e))?; - let device = input_ids.device(); - let dtype = encoder_hidden_states.dtype(); - - // Token embeddings - let inputs_embeds = self - .embed_tokens - .forward(input_ids) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "embed_tokens forward", e))?; - - // Position ids - let position_ids: Vec = (position_offset..(position_offset + seq_len)) - .map(|p| p as u32) - .collect(); - let position_ids = Tensor::new(&position_ids[..], device) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "create position_ids", e))? - .unsqueeze(0) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "unsqueeze position_ids", e))? - .broadcast_as((batch_size, seq_len)) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "broadcast position_ids", e))?; - - let positions = self - .embed_positions - .forward(&position_ids, device, dtype) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "embed_positions forward", e))?; - - // Combine embeddings - let mut hidden_states = (&inputs_embeds + &positions) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "add embeddings", e))?; - - // Process through layers - for (i, layer) in self.layers.iter().enumerate() { - hidden_states = layer - .forward( - &hidden_states, - encoder_hidden_states, - self_attn_mask, - None, // cross_attn_mask - encoder hidden states don't need masking - ) - .map_err(|e| { - candle_to_ocr_inference("M2M100Decoder", format!("layer.{} forward", i), e) - })?; - } - - // Final layer norm - self.layer_norm - .forward(&hidden_states) - .map_err(|e| candle_to_ocr_inference("M2M100Decoder", "layer_norm forward", e)) - } - - pub fn clear_kv_cache(&self) { - for layer in &self.layers { - 
layer.clear_kv_cache(); - } - } -} diff --git a/oar-ocr-vl/src/unirec/encoder.rs b/oar-ocr-vl/src/unirec/encoder.rs deleted file mode 100644 index aaa8800..0000000 --- a/oar-ocr-vl/src/unirec/encoder.rs +++ /dev/null @@ -1,609 +0,0 @@ -//! FocalSVTR visual encoder implementation. -//! -//! FocalSVTR combines Focal Modulation Networks with SVTR-style patch embedding -//! for vision encoding in OCR tasks. - -use candle_core::{Module, Result, Tensor}; -use candle_nn::{Conv2d, Conv2dConfig, Dropout, LayerNorm, Linear, VarBuilder}; - -use super::config::UniRecConfig; -use crate::utils::candle_to_ocr_inference; -use oar_ocr_core::core::OCRError; - -/// Drop path (stochastic depth) for regularization. -#[derive(Debug, Clone)] -struct DropPath { - drop_prob: f64, -} - -impl DropPath { - fn new(drop_prob: f64) -> Self { - Self { drop_prob } - } -} - -impl Module for DropPath { - fn forward(&self, x: &Tensor) -> Result { - if self.drop_prob == 0.0 { - return Ok(x.clone()); - } - // During inference, we don't apply drop path - Ok(x.clone()) - } -} - -/// Convolution + BatchNorm + Activation layer. 
-#[derive(Debug, Clone)] -struct ConvBNLayer { - conv: Conv2d, - bn_weight: Tensor, - bn_bias: Tensor, - bn_running_mean: Tensor, - bn_running_var: Tensor, - eps: f64, -} - -impl ConvBNLayer { - fn load( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - stride: usize, - padding: usize, - vb: VarBuilder, - ) -> Result { - let conv_cfg = Conv2dConfig { - stride, - padding, - ..Default::default() - }; - // Conv layer has no bias in the original model - let conv = candle_nn::conv2d_no_bias( - in_channels, - out_channels, - kernel_size, - conv_cfg, - vb.pp("conv"), - )?; - - // Load BatchNorm parameters - let norm_vb = vb.pp("norm"); - let bn_weight = norm_vb.get(out_channels, "weight")?; - let bn_bias = norm_vb.get(out_channels, "bias")?; - let bn_running_mean = norm_vb.get(out_channels, "running_mean")?; - let bn_running_var = norm_vb.get(out_channels, "running_var")?; - - Ok(Self { - conv, - bn_weight, - bn_bias, - bn_running_mean, - bn_running_var, - eps: 1e-5, - }) - } -} - -impl Module for ConvBNLayer { - fn forward(&self, x: &Tensor) -> Result { - let x = self.conv.forward(x)?; - // Apply batch norm in inference mode: y = (x - mean) / sqrt(var + eps) * weight + bias - let (_, c, _, _) = x.dims4()?; - - // Reshape stats for broadcasting: (C,) -> (1, C, 1, 1) - let mean = self.bn_running_mean.reshape((1, c, 1, 1))?; - let var = self.bn_running_var.reshape((1, c, 1, 1))?; - let weight = self.bn_weight.reshape((1, c, 1, 1))?; - let bias = self.bn_bias.reshape((1, c, 1, 1))?; - - // Normalize - let x = x.broadcast_sub(&mean)?; - let std = (var + self.eps)?.sqrt()?; - let x = x.broadcast_div(&std)?; - let x = x.broadcast_mul(&weight)?; - let x = x.broadcast_add(&bias)?; - - // GELU activation - x.gelu() - } -} - -/// MLP block with two linear layers. 
-#[derive(Debug, Clone)] -struct Mlp { - fc1: Linear, - fc2: Linear, - drop: Dropout, -} - -impl Mlp { - fn load(in_features: usize, hidden_features: usize, drop: f64, vb: VarBuilder) -> Result { - let fc1 = candle_nn::linear(in_features, hidden_features, vb.pp("fc1"))?; - let fc2 = candle_nn::linear(hidden_features, in_features, vb.pp("fc2"))?; - let drop = Dropout::new(drop as f32); - Ok(Self { fc1, fc2, drop }) - } -} - -impl Module for Mlp { - fn forward(&self, x: &Tensor) -> Result { - let x = self.fc1.forward(x)?; - let x = x.gelu()?; - let x = self.drop.forward(&x, false)?; - let x = self.fc2.forward(&x)?; - self.drop.forward(&x, false) - } -} - -/// Focal Modulation block for multi-scale context aggregation. -#[derive(Debug)] -struct FocalModulation { - focal_level: usize, - f_proj: Linear, - h_conv: Conv2d, - proj: Linear, - focal_layers: Vec, -} - -impl FocalModulation { - fn load(dim: usize, focal_window: usize, focal_level: usize, vb: VarBuilder) -> Result { - // f projects to 2*dim + (focal_level + 1) for gating - let f_proj = candle_nn::linear(dim, 2 * dim + focal_level + 1, vb.pp("f"))?; - - // h is a 1x1 conv for modulator - let h_conv = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp("h"))?; - - // Output projection - let proj = candle_nn::linear(dim, dim, vb.pp("proj"))?; - - // Focal convolution layers at different scales - let focal_factor = 2; - let mut focal_layers = Vec::new(); - for k in 0..focal_level { - let kernel_size = focal_factor * k + focal_window; - // Padding should be (kernel_size - 1) / 2 to preserve spatial dimensions - let padding = kernel_size / 2; - - // Depthwise conv with groups=dim - let cfg = Conv2dConfig { - padding, - groups: dim, - ..Default::default() - }; - // Weight key is focal_layers.{k}.0.weight (the .0 is from Sequential) - let conv = candle_nn::conv2d_no_bias( - dim, - dim, - kernel_size, - cfg, - vb.pp(format!("focal_layers.{}.0", k)), - )?; - focal_layers.push(conv); - } - - Ok(Self { - focal_level, 
- f_proj, - h_conv, - proj, - focal_layers, - }) - } - - fn forward(&self, x: &Tensor) -> Result { - // x: (B, H, W, C) - let (b, h, w, c) = x.dims4()?; - - // Pre linear projection: (B, H, W, C) -> (B, H, W, 2*C + focal_level + 1) - let projected = self.f_proj.forward(x)?; - // Permute to (B, 2*C + focal_level + 1, H, W) and make contiguous for CUDA - let projected = projected.permute((0, 3, 1, 2))?.contiguous()?; - - // Split into q, ctx, gates - let q = projected.narrow(1, 0, c)?.contiguous()?; - let mut ctx = projected.narrow(1, c, c)?.contiguous()?; - let gates = projected.narrow(1, 2 * c, self.focal_level + 1)?; - - // Context aggregation with focal convolutions - let mut ctx_all = Tensor::zeros((b, c, h, w), x.dtype(), x.device())?; - for l in 0..self.focal_level { - ctx = self.focal_layers[l].forward(&ctx)?; - ctx = ctx.gelu()?; - let gate = gates.narrow(1, l, 1)?; - let weighted = ctx.broadcast_mul(&gate)?; - ctx_all = (&ctx_all + &weighted)?; - } - - // Global context - let ctx_global = ctx.mean_keepdim(2)?.mean_keepdim(3)?; - let ctx_global = ctx_global.gelu()?; - let gate_global = gates.narrow(1, self.focal_level, 1)?; - let weighted_global = ctx_global.broadcast_mul(&gate_global)?; - ctx_all = (&ctx_all + &weighted_global)?; - - // Focal modulation - let modulator = self.h_conv.forward(&ctx_all)?; - let x_out = (&q * &modulator)?; - - // Permute back to (B, H, W, C) and make contiguous for linear projection - let x_out = x_out.permute((0, 2, 3, 1))?.contiguous()?; - - // Post linear projection - self.proj.forward(&x_out) - } -} - -/// Focal Network Block. 
-#[derive(Debug)] -struct FocalNetBlock { - norm1: LayerNorm, - modulation: FocalModulation, - drop_path: DropPath, - norm2: LayerNorm, - mlp: Mlp, -} - -impl FocalNetBlock { - fn load( - dim: usize, - mlp_ratio: f64, - drop: f64, - drop_path: f64, - focal_level: usize, - focal_window: usize, - vb: VarBuilder, - ) -> Result { - let norm1 = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm1"))?; - let modulation = - FocalModulation::load(dim, focal_window, focal_level, vb.pp("modulation"))?; - let drop_path = DropPath::new(drop_path); - let norm2 = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm2"))?; - let mlp_hidden_dim = (dim as f64 * mlp_ratio) as usize; - let mlp = Mlp::load(dim, mlp_hidden_dim, drop, vb.pp("mlp"))?; - - Ok(Self { - norm1, - modulation, - drop_path, - norm2, - mlp, - }) - } - - fn forward(&self, x: &Tensor, h: usize, w: usize) -> Result { - let (b, _l, c) = x.dims3()?; - let shortcut = x.clone(); - - // Focal modulation - let x = self - .norm1 - .forward(x) - .map_err(|e| candle_core::Error::Msg(format!("norm1 failed: {}", e)))?; - let x = x.reshape((b, h, w, c)).map_err(|e| { - candle_core::Error::Msg(format!("reshape for modulation failed: {}", e)) - })?; - let x = self - .modulation - .forward(&x) - .map_err(|e| candle_core::Error::Msg(format!("modulation failed: {}", e)))?; - let x = x.reshape((b, h * w, c)).map_err(|e| { - candle_core::Error::Msg(format!("reshape after modulation failed: {}", e)) - })?; - - // Residual connection with drop path - let x = self.drop_path.forward(&x)?; - let x = (&shortcut + &x)?; - - // MLP - let mlp_out = self - .norm2 - .forward(&x) - .map_err(|e| candle_core::Error::Msg(format!("norm2 failed: {}", e)))?; - let mlp_out = self - .mlp - .forward(&mlp_out) - .map_err(|e| candle_core::Error::Msg(format!("mlp failed: {}", e)))?; - let mlp_out = self.drop_path.forward(&mlp_out)?; - &x + &mlp_out - } -} - -/// Patch embedding layer. 
-#[derive(Debug)] -struct PatchEmbed { - proj: Conv2d, - norm: Option, -} - -impl PatchEmbed { - fn load( - in_chans: usize, - embed_dim: usize, - patch_size: (usize, usize), - use_conv_embed: bool, - is_stem: bool, - vb: VarBuilder, - ) -> Result { - let (kernel_size, stride, padding) = if use_conv_embed { - if is_stem { (7, 4, 2) } else { (3, 2, 1) } - } else { - ( - patch_size.0.max(patch_size.1), - patch_size.0.max(patch_size.1), - 0, - ) - }; - - let cfg = Conv2dConfig { - stride, - padding, - ..Default::default() - }; - let proj = candle_nn::conv2d(in_chans, embed_dim, kernel_size, cfg, vb.pp("proj"))?; - let norm = candle_nn::layer_norm(embed_dim, 1e-5, vb.pp("norm")).ok(); - - Ok(Self { proj, norm }) - } - - fn forward(&self, x: &Tensor) -> Result<(Tensor, usize, usize)> { - let x = self.proj.forward(x)?; - let (_, _, h, w) = x.dims4()?; - // Flatten to (B, H*W, C) - let x = x.flatten(2, 3)?.permute((0, 2, 1))?; - let x = if let Some(ref norm) = self.norm { - norm.forward(&x)? - } else { - x - }; - Ok((x, h, w)) - } -} - -/// Configuration for BasicLayer. -struct BasicLayerConfig<'a> { - dim: usize, - out_dim: Option, - depth: usize, - mlp_ratio: f64, - drop: f64, - drop_path: &'a [f64], - focal_level: usize, - focal_window: usize, - downsample_kernel: Option<(usize, usize)>, -} - -/// Basic layer containing multiple FocalNetBlocks. 
-#[derive(Debug)] -struct BasicLayer { - blocks: Vec, - downsample: Option, -} - -impl BasicLayer { - fn load(cfg: BasicLayerConfig<'_>, vb: VarBuilder) -> Result { - let mut blocks = Vec::new(); - for i in 0..cfg.depth { - let block = FocalNetBlock::load( - cfg.dim, - cfg.mlp_ratio, - cfg.drop, - cfg.drop_path.get(i).copied().unwrap_or(0.0), - cfg.focal_level, - cfg.focal_window, - vb.pp(format!("blocks.{}", i)), - )?; - blocks.push(block); - } - - let downsample = if let (Some(out_d), Some((kh, kw))) = (cfg.out_dim, cfg.downsample_kernel) - { - if kh > 0 && kw > 0 { - Some(PatchEmbed::load( - cfg.dim, - out_d, - (kh, kw), - false, - false, - vb.pp("downsample"), - )?) - } else { - None - } - } else { - None - }; - - Ok(Self { blocks, downsample }) - } - - fn forward(&self, x: &Tensor, h: usize, w: usize) -> Result<(Tensor, usize, usize)> { - let mut x = x.clone(); - let (mut out_h, mut out_w) = (h, w); - - for (i, block) in self.blocks.iter().enumerate() { - x = block - .forward(&x, out_h, out_w) - .map_err(|e| candle_core::Error::Msg(format!("block.{} failed: {}", i, e)))?; - } - - if let Some(ref downsample) = self.downsample { - let (b, _, c) = x.dims3()?; - // Reshape to (B, C, H, W) for downsampling - let x_2d = x.permute((0, 2, 1))?.reshape((b, c, out_h, out_w))?; - let (x_new, new_h, new_w) = downsample.forward(&x_2d)?; - x = x_new; - out_h = new_h; - out_w = new_w; - } - - Ok((x, out_h, out_w)) - } -} - -/// FocalSVTR visual encoder. -#[derive(Debug)] -pub struct FocalSVTR { - patch_embed_0: ConvBNLayer, - patch_embed_1: ConvBNLayer, - pos_drop: Dropout, - layers: Vec, - vision_fc: Linear, -} - -impl FocalSVTR { - /// Load FocalSVTR encoder from weights. 
- pub fn load(cfg: &UniRecConfig, vb: VarBuilder) -> std::result::Result { - // Calculate embed dimensions for each stage - let base_dim = cfg.encoder_embed_dim; - let embed_dims: Vec = (0..cfg.encoder_depths.len()) - .map(|i| base_dim * (1 << i)) - .collect(); - let num_features = *embed_dims.last().unwrap_or(&base_dim); - - // Patch embedding (two ConvBNLayers) - let patch_embed_0 = ConvBNLayer::load( - 3, - embed_dims[0] / 2, - 3, - 2, - 1, - vb.pp("vision_encoder.patch_embed.0"), - ) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "load patch_embed.0", e))?; - - let patch_embed_1 = ConvBNLayer::load( - embed_dims[0] / 2, - embed_dims[0], - 3, - 2, - 1, - vb.pp("vision_encoder.patch_embed.1"), - ) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "load patch_embed.1", e))?; - - let pos_drop = Dropout::new(0.0); - - // Calculate drop path rates - let total_depth: usize = cfg.encoder_depths.iter().sum(); - let drop_path_rate = 0.1; - let dpr: Vec = (0..total_depth) - .map(|i| drop_path_rate * (i as f64) / (total_depth as f64 - 1.0).max(1.0)) - .collect(); - - // Build layers - let mut layers = Vec::new(); - let num_layers = cfg.encoder_depths.len(); - let mut depth_offset = 0; - - for i in 0..num_layers { - let depth = cfg.encoder_depths[i]; - let layer_dpr = &dpr[depth_offset..depth_offset + depth]; - depth_offset += depth; - - let out_dim = if i < num_layers - 1 { - Some(embed_dims[i + 1]) - } else { - None - }; - - let downsample_kernel = if i < num_layers - 1 { - let (kh, kw) = cfg.sub_k[i]; - if kh > 0 && kw > 0 { - Some((kh, kw)) - } else { - None - } - } else { - None - }; - - let focal_level = cfg.focal_levels.get(i).copied().unwrap_or(3); - let focal_window = cfg.focal_windows.get(i).copied().unwrap_or(3); - - let layer = BasicLayer::load( - BasicLayerConfig { - dim: embed_dims[i], - out_dim, - depth, - mlp_ratio: 4.0, - drop: 0.0, - drop_path: layer_dpr, - focal_level, - focal_window, - downsample_kernel, - }, - 
vb.pp(format!("vision_encoder.layers.{}", i)), - ) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", format!("load layer.{}", i), e))?; - - layers.push(layer); - } - - // Vision FC layer to project to d_model - let vision_fc = candle_nn::linear(num_features, cfg.d_model, vb.pp("vision_fc")) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "load vision_fc", e))?; - - Ok(Self { - patch_embed_0, - patch_embed_1, - pos_drop, - layers, - vision_fc, - }) - } - - /// Forward pass through the encoder. - pub fn forward(&self, x: &Tensor) -> std::result::Result { - // Patch embedding - let x = self - .patch_embed_0 - .forward(x) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "patch_embed_0", e))?; - let x = self - .patch_embed_1 - .forward(&x) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "patch_embed_1", e))?; - - let (_, _, h, w) = x - .dims4() - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "get dims", e))?; - - // Flatten to (B, H*W, C) - let x = x - .flatten(2, 3) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "flatten", e))? 
- .permute((0, 2, 1)) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "permute", e))?; - - // Position dropout - let mut x = self - .pos_drop - .forward(&x, false) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "pos_drop", e))?; - let (mut out_h, mut out_w) = (h, w); - - // Process through layers - for (i, layer) in self.layers.iter().enumerate() { - let result = layer.forward(&x, out_h, out_w); - match result { - Ok((new_x, new_h, new_w)) => { - x = new_x; - out_h = new_h; - out_w = new_w; - } - Err(e) => { - return Err(candle_to_ocr_inference( - "FocalSVTR", - format!("layer.{}: {}", i, e), - candle_core::Error::Msg(format!("{}", e)), - )); - } - } - } - - // Vision FC projection - self.vision_fc - .forward(&x) - .map_err(|e| candle_to_ocr_inference("FocalSVTR", "vision_fc", e)) - } -} diff --git a/oar-ocr-vl/src/unirec/mod.rs b/oar-ocr-vl/src/unirec/mod.rs deleted file mode 100644 index 500da5a..0000000 --- a/oar-ocr-vl/src/unirec/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -//! UniRec Vision-Language model implementation using Candle. -//! -//! UniRec is a unified text and formula recognition model that combines: -//! - FocalSVTR visual encoder (Focal Modulation Networks) -//! - M2M100 decoder (multilingual translation decoder architecture) -//! -//! This module provides a native Rust implementation for efficient inference. - -mod config; -mod decoder; -mod encoder; -mod model; - -pub use config::UniRecConfig; -pub use model::UniRec; diff --git a/oar-ocr-vl/src/unirec/model.rs b/oar-ocr-vl/src/unirec/model.rs deleted file mode 100644 index 53897a3..0000000 --- a/oar-ocr-vl/src/unirec/model.rs +++ /dev/null @@ -1,367 +0,0 @@ -//! UniRec Vision-Language model for unified text, formula, and table recognition. 
- -use candle_core::{D, DType, Device, IndexOp, Tensor}; -use candle_nn::{Linear, Module}; -use image::RgbImage; -use once_cell::sync::Lazy; -use regex::Regex; -use std::path::Path; -use tokenizers::Tokenizer; -use tokenizers::decoders::byte_level::ByteLevel; - -use super::config::UniRecConfig; -use super::decoder::M2M100Decoder; -use super::encoder::FocalSVTR; -use crate::utils::{candle_to_ocr_inference, image::image_to_chw}; -use oar_ocr_core::core::OCRError; - -// Static regexes for postprocessing (compiled once) -static UNDERSCORE_RE: Lazy = - Lazy::new(|| Regex::new(r"_{4,}").expect("static underscore regex")); -static DOTS_RE: Lazy = Lazy::new(|| Regex::new(r"\.{4,}").expect("static dots regex")); -static SPACES_RE: Lazy = Lazy::new(|| Regex::new(r"[ ]{2,}").expect("static spaces regex")); - -/// UniRec model for unified text, formula, and table recognition. -pub struct UniRec { - device: Device, - dtype: DType, - cfg: UniRecConfig, - tokenizer: Tokenizer, - encoder: FocalSVTR, - decoder: M2M100Decoder, - lm_head: Linear, -} - -impl UniRec { - /// Load UniRec model from a directory containing model weights and config. - pub fn from_dir(model_dir: impl AsRef, device: Device) -> Result { - let model_dir = model_dir.as_ref(); - - // Load config - let cfg = UniRecConfig::from_path(model_dir.join("config.json"))?; - - // Load tokenizer with ByteLevel decoder for proper BPE decoding - let mut tokenizer = - Tokenizer::from_file(model_dir.join("tokenizer.json")).map_err(|e| { - OCRError::ConfigError { - message: format!("Failed to load UniRec tokenizer: {}", e), - } - })?; - tokenizer.with_decoder(Some(ByteLevel::default())); - - // Determine dtype - let dtype = device.bf16_default_to_f32(); - - // Load model weights - let vb = unsafe { - candle_nn::VarBuilder::from_mmaped_safetensors( - &[model_dir.join("model.safetensors")], - dtype, - &device, - ) - .map_err(|e| candle_to_ocr_inference("UniRec", "load model.safetensors", e))? 
- }; - - // Load encoder - let encoder = FocalSVTR::load(&cfg, vb.pp("model.encoder"))?; - - // Load decoder - let decoder = M2M100Decoder::load(&cfg, vb.pp("model.decoder"))?; - - // Load language model head - let lm_head = candle_nn::linear_no_bias(cfg.d_model, cfg.vocab_size, vb.pp("lm_head")) - .map_err(|e| candle_to_ocr_inference("UniRec", "load lm_head", e))?; - - Ok(Self { - device, - dtype, - cfg, - tokenizer, - encoder, - decoder, - lm_head, - }) - } - - /// Compute target dimensions for an image preserving aspect ratio. - fn compute_target_size(&self, image: &RgbImage) -> (usize, usize) { - let (max_w, max_h) = (self.cfg.input_width, self.cfg.input_height); - let (orig_w, orig_h) = (image.width() as usize, image.height() as usize); - - // Compute resize dimensions preserving aspect ratio - let (new_w, new_h) = if orig_w > max_w || orig_h > max_h { - let aspect_ratio = orig_w as f64 / orig_h as f64; - if (max_w as f64 / max_h as f64) >= aspect_ratio { - // Height limited - let new_h = max_h; - let new_w = (new_h as f64 * aspect_ratio) as usize; - (new_w, new_h) - } else { - // Width limited - let new_w = max_w; - let new_h = (new_w as f64 / aspect_ratio) as usize; - (new_w, new_h) - } - } else { - (orig_w, orig_h) - }; - - // Round to multiples of 64 (minimum 64) - let target_h = ((new_h / 64) * 64).max(64); - let target_w = ((new_w / 64) * 64).max(64); - - (target_w, target_h) - } - - /// Preprocess an image with specific target dimensions (for batching). 
- fn preprocess_image_with_size( - &self, - image: &RgbImage, - target_w: usize, - target_h: usize, - ) -> Result { - // First resize to natural dimensions preserving aspect ratio - let (natural_w, natural_h) = self.compute_target_size(image); - - let resized = image::imageops::resize( - image, - natural_w as u32, - natural_h as u32, - image::imageops::FilterType::CatmullRom, - ); - - // Convert to tensor with normalization and padding to target size - let chw_data = image_to_chw( - &resized, - &[0.5, 0.5, 0.5], - &[0.5, 0.5, 0.5], - Some(1.0 / 255.0), - ); - - let mut data = vec![0f32; 3 * target_h * target_w]; - let natural_area = natural_h * natural_w; - let target_area = target_h * target_w; - - // Copy into padded buffer respecting stride - for c in 0..3 { - let src_channel_offset = c * natural_area; - let dst_channel_offset = c * target_area; - for y in 0..natural_h { - let src_row_start = src_channel_offset + y * natural_w; - let dst_row_start = dst_channel_offset + y * target_w; - data[dst_row_start..dst_row_start + natural_w] - .copy_from_slice(&chw_data[src_row_start..src_row_start + natural_w]); - } - } - - Tensor::from_vec(data, (1, 3, target_h, target_w), &self.device) - .map_err(|e| candle_to_ocr_inference("UniRec", "create padded image tensor", e))? - .to_dtype(self.dtype) - .map_err(|e| candle_to_ocr_inference("UniRec", "cast padded image dtype", e)) - } - - /// Generate text from one or more images using greedy decoding. - /// - /// Supports true GPU batching when multiple images are provided. - /// - /// # Arguments - /// * `images` - Input images - /// * `max_new_tokens` - Maximum tokens to generate per image - /// - /// # Returns - /// Vector of results, one per input image. 
- pub fn generate( - &self, - images: &[RgbImage], - max_new_tokens: usize, - ) -> Vec> { - if images.is_empty() { - return Vec::new(); - } - - match self.generate_internal(images, max_new_tokens) { - Ok(results) => results.into_iter().map(Ok).collect(), - Err(e) => { - let msg = format!("generation failed: {e}"); - (0..images.len()) - .map(|_| { - Err(OCRError::InvalidInput { - message: msg.clone(), - }) - }) - .collect() - } - } - } - - /// Internal generation implementation supporting batched inference. - fn generate_internal( - &self, - images: &[RgbImage], - max_new_tokens: usize, - ) -> Result, OCRError> { - let batch_size = images.len(); - - // 1. Compute max dimensions across all images - let mut max_w = 0usize; - let mut max_h = 0usize; - for image in images { - let (w, h) = self.compute_target_size(image); - max_w = max_w.max(w); - max_h = max_h.max(h); - } - - // 2. Preprocess all images with padding to max size - let mut pixel_tensors: Vec = Vec::with_capacity(batch_size); - for image in images { - let tensor = self.preprocess_image_with_size(image, max_w, max_h)?; - pixel_tensors.push(tensor); - } - - // 3. Stack into batch tensor - let refs: Vec<&Tensor> = pixel_tensors.iter().collect(); - let pixel_values = Tensor::cat(&refs, 0) - .map_err(|e| candle_to_ocr_inference("UniRec", "stack pixel values", e))?; - - // 4. Encode all images - let encoder_hidden_states = self.encoder.forward(&pixel_values)?; - - // 5. Clear KV cache before generation - self.decoder.clear_kv_cache(); - - // 6. 
Initialize generation state - let mut generated_tokens: Vec> = vec![vec![self.cfg.bos_token_id]; batch_size]; - let mut finished: Vec = vec![false; batch_size]; - let mut position_offset = 0usize; - - for _ in 0..max_new_tokens { - if finished.iter().all(|&f| f) { - break; - } - - // Create input tensor - let input_ids = if position_offset == 0 { - // First step - use BOS token for all - Tensor::new(vec![self.cfg.bos_token_id; batch_size], &self.device) - .map_err(|e| candle_to_ocr_inference("UniRec", "create input_ids", e))? - .reshape((batch_size, 1)) - .map_err(|e| candle_to_ocr_inference("UniRec", "reshape input_ids", e))? - } else { - // Subsequent steps - use last token for each sample - let last_tokens: Vec = generated_tokens - .iter() - .map(|tokens| *tokens.last().unwrap_or(&self.cfg.bos_token_id)) - .collect(); - Tensor::new(last_tokens, &self.device) - .map_err(|e| candle_to_ocr_inference("UniRec", "create input_ids", e))? - .reshape((batch_size, 1)) - .map_err(|e| candle_to_ocr_inference("UniRec", "reshape input_ids", e))? - }; - - // No causal mask needed for single token decode - let self_attn_mask = None; - - // Decoder forward - let hidden_states = self.decoder.forward( - &input_ids, - &encoder_hidden_states, - position_offset, - self_attn_mask, - )?; - - // Get logits for last position - let (_, seq_len, _) = hidden_states - .dims3() - .map_err(|e| candle_to_ocr_inference("UniRec", "get hidden dims", e))?; - let last_hidden = hidden_states - .i((.., seq_len - 1, ..)) - .map_err(|e| candle_to_ocr_inference("UniRec", "select last hidden", e))?; - - let logits = self - .lm_head - .forward(&last_hidden) - .map_err(|e| candle_to_ocr_inference("UniRec", "lm_head forward", e))?; - - // Greedy decoding for each sample - let next_tokens = logits - .argmax(D::Minus1) - .map_err(|e| candle_to_ocr_inference("UniRec", "argmax", e))? 
- .to_vec1::() - .map_err(|e| candle_to_ocr_inference("UniRec", "to_vec1", e))?; - - for (i, &token) in next_tokens.iter().enumerate() { - if finished[i] { - continue; - } - if token == self.cfg.eos_token_id { - finished[i] = true; - } else { - generated_tokens[i].push(token); - } - } - - // Update position offset - position_offset += 1; - } - - // 7. Decode tokens for each sample - let mut results = Vec::with_capacity(batch_size); - for tokens in generated_tokens { - // Skip BOS token - let tokens_to_decode: Vec = if tokens.len() > 1 { - tokens[1..].to_vec() - } else { - vec![] - }; - - let raw_output = self - .tokenizer - .decode(&tokens_to_decode, false) - .map_err(|e| OCRError::InvalidInput { - message: format!("UniRec: tokenizer decode failed: {}", e), - })?; - - results.push(postprocess_unirec_output(&raw_output)); - } - - Ok(results) - } - - /// Get model configuration. - pub fn config(&self) -> &UniRecConfig { - &self.cfg - } - - /// Get the device the model is running on. - pub fn device(&self) -> &Device { - &self.device - } -} - -/// Postprocess UniRec output to clean up special tokens. -/// -/// Matches the Python `clean_special_tokens` function from OpenOCR. -/// Note: ByteLevel decoder handles Ġ→space and Ċ→newline conversion automatically. -fn postprocess_unirec_output(text: &str) -> String { - // Clean up special tokens (ByteLevel decoder handles Ġ and Ċ) - let result = text - .replace("<|bos|>", "") - .replace("<|eos|>", "") - .replace("<|pad|>", "") - .replace("<|unk|>", "") - .replace("-<|sn|>", "") - .replace("<|sn|>", " ") - .replace("", "") - .replace("", "") - .replace('\u{FFFF}', ""); - - // Match OpenOCR's extra cleanup rules (using static regexes for efficiency). - let result = UNDERSCORE_RE.replace_all(&result, "___"); - let result = DOTS_RE.replace_all(&result, "..."); - - // Collapse repeated spaces introduced during token cleanup. 
- let result = SPACES_RE.replace_all(&result, " "); - - // Trim leading/trailing whitespace - result.trim().to_string() -} diff --git a/oar-ocr-vl/src/utils.rs b/oar-ocr-vl/src/utils.rs index 40ed74c..68befe9 100644 --- a/oar-ocr-vl/src/utils.rs +++ b/oar-ocr-vl/src/utils.rs @@ -55,9 +55,9 @@ use std::collections::HashSet; /// let cuda1 = parse_device("cuda:1")?; /// /// // Metal examples (only when metal feature is enabled) -/// # #[cfg(feature = "metal")] +/// # #[cfg(all(feature = "metal", target_os = "macos"))] /// let metal = parse_device("metal")?; -/// # #[cfg(feature = "metal")] +/// # #[cfg(all(feature = "metal", target_os = "macos"))] /// let metal1 = parse_device("metal:1")?; /// # Ok(()) /// # } @@ -69,15 +69,15 @@ fn cuda_not_enabled() -> OCRError { } } -#[cfg(not(feature = "metal"))] +#[cfg(not(all(feature = "metal", target_os = "macos")))] fn metal_not_enabled() -> OCRError { OCRError::ConfigError { - message: "Metal support not enabled. Compile with --features metal".to_string(), + message: "Metal support not enabled. Compile on macOS with --features metal".to_string(), } } /// Helper function to parse a device string with an ordinal (e.g., "cuda:1", "metal:0"). 
-#[cfg(any(feature = "cuda", feature = "metal"))] +#[cfg(any(feature = "cuda", all(feature = "metal", target_os = "macos")))] fn parse_device_with_ordinal( s: &str, prefix: &str, @@ -114,13 +114,13 @@ pub fn parse_device(device_str: &str) -> Result { } } "metal" => { - #[cfg(feature = "metal")] + #[cfg(all(feature = "metal", target_os = "macos"))] { Device::new_metal(0).map_err(|e| OCRError::ConfigError { message: format!("Failed to create Metal device: {}", e), }) } - #[cfg(not(feature = "metal"))] + #[cfg(not(all(feature = "metal", target_os = "macos")))] { Err(metal_not_enabled()) } @@ -136,11 +136,11 @@ pub fn parse_device(device_str: &str) -> Result { } } s if s.starts_with("metal:") => { - #[cfg(feature = "metal")] + #[cfg(all(feature = "metal", target_os = "macos"))] { parse_device_with_ordinal(s, "metal:", "Metal", Device::new_metal) } - #[cfg(not(feature = "metal"))] + #[cfg(not(all(feature = "metal", target_os = "macos")))] { Err(metal_not_enabled()) } @@ -708,7 +708,7 @@ mod tests { } } - #[cfg(feature = "metal")] + #[cfg(all(feature = "metal", target_os = "macos"))] #[test] fn test_parse_device_metal_with_ordinal() { // Test parsing "metal:0" - should always work on Apple devices @@ -719,7 +719,7 @@ mod tests { // so we don't test them here. The parsing logic itself works correctly. 
} - #[cfg(feature = "metal")] + #[cfg(all(feature = "metal", target_os = "macos"))] #[test] fn test_parse_device_metal_invalid_ordinal() { let result = parse_device("metal:abc"); @@ -729,7 +729,7 @@ mod tests { } } - #[cfg(not(feature = "metal"))] + #[cfg(not(all(feature = "metal", target_os = "macos")))] #[test] fn test_parse_device_metal_not_enabled() { let result = parse_device("metal"); diff --git a/oar-ocr-vl/src/utils/image.rs b/oar-ocr-vl/src/utils/image.rs index 3db2d83..abdea60 100644 --- a/oar-ocr-vl/src/utils/image.rs +++ b/oar-ocr-vl/src/utils/image.rs @@ -7,7 +7,13 @@ struct UnsafeSlice { slice: *mut [T], } +// SAFETY: `UnsafeSlice` only transfers a raw slice pointer between threads. +// Mutation remains gated behind `write`, whose caller must uphold disjoint +// in-bounds access. The element type must be safe to send to another thread. unsafe impl Send for UnsafeSlice {} +// SAFETY: Shared references to `UnsafeSlice` do not expose safe mutation. +// Concurrent writes are only possible through the unsafe `write` contract, +// which requires callers to avoid aliasing the same element across threads. unsafe impl Sync for UnsafeSlice {} impl UnsafeSlice { @@ -67,6 +73,9 @@ pub fn image_to_chw( let b_norm = (b - mean[2]) / std[2]; // Write to CHW planes + // SAFETY: each parallel iteration owns a unique pixel index `i`. + // The three destinations are in separate channel planes and are + // therefore distinct for every `i` in `0..num_pixels`. unsafe { output_slice.write(i, r_norm); output_slice.write(num_pixels + i, g_norm); diff --git a/oar-ocr-vl/src/utils/table.rs b/oar-ocr-vl/src/utils/table.rs index 7467ca5..5d01ce1 100644 --- a/oar-ocr-vl/src/utils/table.rs +++ b/oar-ocr-vl/src/utils/table.rs @@ -14,6 +14,181 @@ const OTSL_LCEL: &str = ""; const OTSL_UCEL: &str = ""; const OTSL_XCEL: &str = ""; +/// Convert an HTML `` snippet to PaddleOCR-VL's raw OTSL token form. 
+/// +/// PaddleOCR-VL's "Table Recognition:" prompt emits OTSL tokens +/// (``, ``, ``, ``, ``, ``) which the +/// model's post-process then translates into HTML. To feed an HTML-shaped +/// draft (e.g. from a layout pipeline) into HSD against PaddleOCR-VL we +/// need the inverse: HTML → OTSL. +/// +/// The parser is intentionally tolerant — it uses regex-based extraction +/// rather than a full HTML parser and mirrors `clean_html_table`'s repair of +/// common attribute typos (`...`) and within each row, cells +/// (`` / ``) with optional `colspan` / `rowspan`. +/// 2. Lay cells onto a 2D grid, skipping positions occupied by an earlier +/// cell's row/col span. +/// 3. Emit row-by-row: each original cell anchor becomes `content` (or +/// `` if empty); spanned positions emit `` (col-only), +/// `` (row-only), or `` (both). End each row with ``. +/// +/// Returns `None` when the input is empty, contains no `` tag, or has no +/// parseable cells. Callers can then decide whether to skip the draft. +pub fn convert_html_to_otsl(input: &str) -> Option { + let trimmed = input.trim(); + if trimmed.is_empty() || !TR_OPEN_RE.is_match(trimmed) { + return None; + } + // Repair the common `>> = vec![vec![None; num_cols]; num_rows]; + for (r, cells) in rows.into_iter().enumerate() { + let mut c = 0usize; + for (rowspan, colspan, text) in cells { + // Skip positions already occupied by a previous rowspan. + while c < num_cols && grid[r][c].is_some() { + c += 1; + } + if c >= num_cols { + break; + } + let rs = rowspan.max(1); + let cs = colspan.max(1); + let rs_end = (r + rs).min(num_rows); + let cs_end = (c + cs).min(num_cols); + for row in grid[r..rs_end].iter_mut() { + for slot in row[c..cs_end].iter_mut() { + *slot = Some((r, c, rs, cs, text.clone())); + } + } + c += cs; + } + } + + // Emit OTSL row by row. 
+ let mut out = String::new(); + for (r, row) in grid.iter().enumerate() { + for (c, slot) in row.iter().enumerate() { + match slot { + None => out.push_str(OTSL_ECEL), + Some((anchor_r, anchor_c, _rs, _cs, text)) => { + let is_row_anchor = *anchor_r == r; + let is_col_anchor = *anchor_c == c; + match (is_row_anchor, is_col_anchor) { + (true, true) => { + if text.is_empty() { + out.push_str(OTSL_ECEL); + } else { + out.push_str(OTSL_FCEL); + out.push_str(text); + } + } + (true, false) => out.push_str(OTSL_LCEL), + (false, true) => out.push_str(OTSL_UCEL), + (false, false) => out.push_str(OTSL_XCEL), + } + } + } + } + out.push_str(OTSL_NL); + } + Some(out) +} + +static TR_RE: Lazy = + Lazy::new(|| Regex::new(r"(?is)]*>(.*?)").expect("static regex: ...")); +static TR_OPEN_RE: Lazy = + Lazy::new(|| Regex::new(r"(?i)]").expect("static regex: = Lazy::new(|| { + Regex::new(r"(?is)]*)>(.*?)").expect("static regex: ...") +}); +static STRIP_TAG_RE: Lazy = + Lazy::new(|| Regex::new(r"<[^>]*>").expect("static regex: html tag stripper")); +// `colspan` / `rowspan` attribute scanners. The leading `(?:^|\s)` anchors +// the match so substrings like `data-colspan=` or `class="mycolspan"` don't +// trip the parser (mis-extracting a span from an unrelated attribute). +static COLSPAN_RE: Lazy = Lazy::new(|| { + Regex::new(r#"(?i)(?:^|\s)colspan\s*=\s*"?(\d+)"?"#).expect("static regex: colspan attr") +}); +static ROWSPAN_RE: Lazy = Lazy::new(|| { + Regex::new(r#"(?i)(?:^|\s)rowspan\s*=\s*"?(\d+)"?"#).expect("static regex: rowspan attr") +}); + +fn extract_span(attrs: &str, name: &str) -> usize { + let re: &Regex = match name { + "colspan" => &COLSPAN_RE, + "rowspan" => &ROWSPAN_RE, + _ => return 1, + }; + re.captures(attrs) + .and_then(|caps| caps.get(1)) + .and_then(|m| m.as_str().parse::().ok()) + .filter(|n| *n > 0) + .unwrap_or(1) +} + +fn clean_cell_text(body: &str) -> String { + // Strip any nested tags (rare in OCR tables but possible — e.g.
, ). + let stripped = STRIP_TAG_RE.replace_all(body, ""); + // Decode the few HTML entities the existing post-process emits via + // `html_escape::encode_text`. We avoid pulling a full entity decoder for + // a hot path — these five cover what `otsl_export_to_html` produces. + let decoded = stripped + .replace("&", "&") + .replace("<", "<") + .replace(">", ">") + .replace(""", "\"") + .replace("'", "'"); + decoded.trim().to_string() +} + /// Convert OTSL table tokens (or TSV text) to HTML table. pub fn convert_otsl_to_html(input: &str) -> String { let trimmed = input.trim(); @@ -437,4 +612,101 @@ mod tests { assert!(html.contains("
......
")); assert!(html.contains("")); } + + #[test] + fn convert_html_to_otsl_simple_grid() { + // 2x2 plain table, no merges. + let html = "
a
ab
cd
"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "abcd"); + } + + #[test] + fn convert_html_to_otsl_empty_cells_become_ecel() { + let html = "
a
"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "a"); + } + + #[test] + fn convert_html_to_otsl_colspan_emits_lcel() { + // Row 1: A → A + let html = "
A
xy
"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "Axy"); + } + + #[test] + fn convert_html_to_otsl_rowspan_emits_ucel() { + // Column 0 spans both rows; column 1 is two separate cells. + let html = "
Ab
c
"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "Abc"); + } + + #[test] + fn convert_html_to_otsl_xcel_for_combined_span() { + // 2x2 cell merged together; the bottom-right corner must be . + let html = "
A
"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "A"); + } + + #[test] + fn convert_html_to_otsl_handles_tdcolspan_typo() { + // PaddleOCR-VL post-process repairs `Axy"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "Axy"); + } + + #[test] + fn convert_html_to_otsl_decodes_html_entities() { + // The forward converter html_escapes content; the inverse must + // round-trip the most common entities. + let html = "
a & bx < y
"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "a & bx < y"); + } + + #[test] + fn convert_html_to_otsl_returns_none_for_non_table_input() { + assert!(convert_html_to_otsl("plain text").is_none()); + assert!(convert_html_to_otsl("

not a table

").is_none()); + assert!(convert_html_to_otsl("").is_none()); + } + + #[test] + fn convert_html_to_otsl_accepts_uppercase_tags() { + // The cell regexes are case-insensitive; the precheck must be too, + // otherwise mixed-case HTML drops on the floor. + let html = "
ab
"; + let otsl = convert_html_to_otsl(html).expect("conversion"); + assert_eq!(otsl, "ab"); + } + + #[test] + fn extract_span_ignores_substring_attribute_matches() { + // `data-colspan=` / `xrowspan=` are unrelated attributes that happen + // to end with our needle. The scanner must not pick numbers out of + // them — a plain `` should still yield colspan=1, rowspan=1. + assert_eq!(extract_span(r#" data-colspan="7""#, "colspan"), 1); + assert_eq!(extract_span(r#" xrowspan="9""#, "rowspan"), 1); + assert_eq!(extract_span(r#" class="mycolspan""#, "colspan"), 1); + // Genuine match still works, including unquoted and mixed-case forms. + assert_eq!(extract_span(r#" colspan="3""#, "colspan"), 3); + assert_eq!(extract_span(" COLSPAN=4", "colspan"), 4); + assert_eq!(extract_span(r#" class="data" rowspan="2""#, "rowspan"), 2); + } + + #[test] + fn convert_html_to_otsl_roundtrips_through_otsl_to_html() { + // OTSL → HTML → OTSL should reconstruct the same OTSL for a simple + // grid (modulo whitespace handling). This guards against drift between + // the two converters. + let otsl_in = "abcd"; + let html = convert_otsl_to_html(otsl_in); + let otsl_out = convert_html_to_otsl(&html).expect("round-trip"); + assert_eq!(otsl_out, otsl_in); + } } diff --git a/src/oarocr/structure.rs b/src/oarocr/structure.rs index e479ddc..b8d1905 100644 --- a/src/oarocr/structure.rs +++ b/src/oarocr/structure.rs @@ -28,6 +28,7 @@ use oar_ocr_core::domain::tasks::{ }; use std::path::PathBuf; use std::sync::Arc; +use std::time::Instant; /// IoU threshold for removing overlapping layout elements (0.5 = 50% overlap). 
const LAYOUT_OVERLAP_IOU_THRESHOLD: f32 = 0.5; @@ -158,6 +159,7 @@ pub struct OARStructureBuilder { formula_recognition_model: Option, formula_recognition_type: Option, // "pp_formulanet" or "unimernet" formula_tokenizer_path: Option, + formula_ort_session_config: Option, // Optional seal text detection seal_text_detection_model: Option, @@ -224,6 +226,7 @@ impl OARStructureBuilder { formula_recognition_model: None, formula_recognition_type: None, formula_tokenizer_path: None, + formula_ort_session_config: None, seal_text_detection_model: None, text_detection_model: None, text_line_orientation_model: None, @@ -588,6 +591,12 @@ impl OARStructureBuilder { self } + /// Sets an ONNX Runtime session configuration only for formula recognition. + pub fn formula_ort_session(mut self, config: OrtSessionConfig) -> Self { + self.formula_ort_session_config = Some(config); + self + } + /// Integrates OCR into the pipeline for text extraction. /// /// # Arguments @@ -991,7 +1000,11 @@ impl OARStructureBuilder { builder = builder.task_config(config.clone()); } - if let Some(ref ort_config) = self.ort_session_config { + if let Some(ort_config) = self + .formula_ort_session_config + .as_ref() + .or(self.ort_session_config.as_ref()) + { builder = builder.with_ort_config(ort_config.clone()); } @@ -1008,7 +1021,11 @@ impl OARStructureBuilder { builder = builder.task_config(config.clone()); } - if let Some(ref ort_config) = self.ort_session_config { + if let Some(ort_config) = self + .formula_ort_session_config + .as_ref() + .or(self.ort_session_config.as_ref()) + { builder = builder.with_ort_config(ort_config.clone()); } @@ -1656,10 +1673,12 @@ impl OARStructure { use oar_ocr_core::domain::structure::{LayoutElement, LayoutElementType, RegionBlock}; let input = ImageTaskInput::new(vec![page_image.clone()]); + let t_layout = Instant::now(); let layout_result = self .pipeline .layout_detection_adapter .execute(input, None)?; + let layout_dur = t_layout.elapsed(); let mut layout_elements: 
Vec = Vec::new(); if let Some(elements) = layout_result.elements.first() { @@ -1675,6 +1694,7 @@ impl OARStructure { let mut detected_region_blocks: Option> = None; if let Some(ref region_adapter) = self.pipeline.region_detection_adapter { let region_input = ImageTaskInput::new(vec![page_image.clone()]); + let t_region = Instant::now(); if let Ok(region_result) = region_adapter.execute(region_input, None) && let Some(region_elements) = region_result.elements.first() && !region_elements.is_empty() @@ -1690,6 +1710,11 @@ impl OARStructure { .collect(); detected_region_blocks = Some(blocks); } + tracing::debug!( + "structure stage: region detection {:.1} ms, blocks={}", + t_region.elapsed().as_secs_f64() * 1000.0, + detected_region_blocks.as_ref().map_or(0, Vec::len) + ); } if layout_elements.len() > 1 { @@ -1707,6 +1732,11 @@ impl OARStructure { } crate::domain::structure::apply_standardized_layout_label_fixes(&mut layout_elements); + tracing::debug!( + "structure stage: layout detection {:.1} ms, elements={}", + layout_dur.as_secs_f64() * 1000.0, + layout_elements.len() + ); Ok((layout_elements, detected_region_blocks)) } @@ -1759,15 +1789,32 @@ impl OARStructure { return Ok(Vec::new()); } - let input = ImageTaskInput::new(crops); - let formula_result = formula_adapter.execute(input, None)?; + let t_formula = Instant::now(); + let batch_size = formula_adapter.recommended_batch_size().max(1); + let crop_count = bboxes.len(); + let mut formula_results = Vec::with_capacity(crop_count); + let mut score_results = Vec::with_capacity(crop_count); + let mut remaining_crops = crops; + while !remaining_crops.is_empty() { + let chunk_len = batch_size.min(remaining_crops.len()); + let rest = remaining_crops.split_off(chunk_len); + let chunk_vec = remaining_crops; + remaining_crops = rest; + + let output = formula_adapter.execute(ImageTaskInput::new(chunk_vec), None)?; + formula_results.extend(output.formulas); + score_results.extend(output.scores); + } + tracing::debug!( + 
"structure stage: formula recognition {:.1} ms, crops={}, batches={}, batch_size={}", + t_formula.elapsed().as_secs_f64() * 1000.0, + crop_count, + crop_count.div_ceil(batch_size), + batch_size + ); let mut formulas = Vec::new(); - for ((bbox, formula), score) in bboxes - .into_iter() - .zip(formula_result.formulas) - .zip(formula_result.scores) - { + for ((bbox, formula), score) in bboxes.into_iter().zip(formula_results).zip(score_results) { let width = bbox.x_max() - bbox.x_min(); let height = bbox.y_max() - bbox.y_min(); if width <= 0.0 || height <= 0.0 { @@ -1983,7 +2030,9 @@ impl OARStructure { // Text detection (on masked image). let input = ImageTaskInput::new(vec![ocr_image.clone()]); + let t_text_det = Instant::now(); let det_result = text_detection_adapter.execute(input, None)?; + let text_det_dur = t_text_det.elapsed(); let mut detection_boxes = if let Some(detections) = det_result.detections.first() { detections @@ -2164,6 +2213,8 @@ impl OARStructure { let batch_size = self.pipeline.region_batch_size.unwrap_or(8).max(1); let mut recognized_by_det_idx: Vec> = vec![None; detection_boxes.len()]; + let mut rec_batches = 0usize; + let t_text_rec = Instant::now(); while !items.is_empty() { let take_n = batch_size.min(items.len()); @@ -2178,6 +2229,7 @@ impl OARStructure { } let rec_input = ImageTaskInput::new(rec_imgs); + rec_batches += 1; if let Ok(rec_result) = text_recognition_adapter.execute(rec_input, None) { for ((det_idx, text), score) in det_indices .into_iter() @@ -2193,6 +2245,13 @@ impl OARStructure { } } } + tracing::debug!( + "structure stage: text recognition {:.1} ms, crops={}, batches={}, batch_size={}", + t_text_rec.elapsed().as_secs_f64() * 1000.0, + detection_boxes.len(), + rec_batches, + batch_size + ); // Emit OCR regions in original detection order, matching PaddleX. 
for (det_idx, rec) in recognized_by_det_idx.into_iter().enumerate() { @@ -2223,6 +2282,12 @@ impl OARStructure { text_recognition_adapter, batch_size, )?; + tracing::debug!( + "structure stage: text detection {:.1} ms, boxes={}, recognized_regions={}", + text_det_dur.as_secs_f64() * 1000.0, + detection_boxes.len(), + text_regions.len() + ); Ok(text_regions) } @@ -2319,13 +2384,16 @@ impl OARStructure { Self::assign_region_block_membership(regions, &layout_elements); } + let t_ocr = Instant::now(); let mut text_regions = self.run_overall_ocr( ¤t_image, &layout_elements, detected_region_blocks.as_deref(), )?; + let ocr_dur = t_ocr.elapsed(); { + let t_tables = Instant::now(); let analyzer = crate::oarocr::table_analyzer::TableAnalyzer::new( crate::oarocr::table_analyzer::TableAnalyzerConfig { table_classification_adapter: self @@ -2367,7 +2435,17 @@ impl OARStructure { &formulas, &text_regions, )?); + tracing::debug!( + "structure stage: table analysis {:.1} ms, tables={}", + t_tables.elapsed().as_secs_f64() * 1000.0, + tables.len() + ); } + tracing::debug!( + "structure stage: overall OCR total {:.1} ms, regions={}", + ocr_dur.as_secs_f64() * 1000.0, + text_regions.len() + ); // 5b. Optional OCR box splitting by table cell boundaries. // @@ -2512,17 +2590,23 @@ impl OARStructure { /// Analyzes the structure of a single document image. pub fn predict_image(&self, image: image::RgbImage) -> Result { + let t_total = Instant::now(); let prepared = self.prepare_page(image)?; let formulas = self.recognize_formulas(&prepared.current_image, &prepared.layout_elements)?; - self.complete_page(prepared, formulas) + let result = self.complete_page(prepared, formulas)?; + tracing::debug!( + "structure stage: total predict_image {:.1} ms", + t_total.elapsed().as_secs_f64() * 1000.0 + ); + Ok(result) } /// Analyzes multiple document page images with cross-page formula batching. 
/// /// All formula crops from every page are collected first and forwarded to the /// formula adapter in a single `execute` call, reducing ONNX inference overhead - /// compared to calling [`predict_image`] sequentially. Layout detection and all + /// compared to calling [`Self::predict_image`] sequentially. Layout detection and all /// other per-page steps are still performed independently per page. /// /// Per-page errors are returned individually so that a failure on one page does