diff --git a/README.md b/README.md index a6dd7c8..31d060c 100644 --- a/README.md +++ b/README.md @@ -2,18 +2,20 @@ Transformer-based bitwise-aligned rollout for VeOmni FSDP with VeRL integration. +VeXact is our zero-mismatch rollout engine for LLM reinforcement learning. See our paper **[Diagnosing Training-Inference Mismatch in LLM Reinforcement Learning](http://arxiv.org/abs/2605.14220)** for its use as a TIM-free diagnostic baseline. + ## Key Features - ๐ŸŽฏ **Bitwise-aligned training & inference** โ€” VeOmni FSDP actor and VeXact rollout engine produce identical logprobs for dense and MoE models with verl (the legacy FSDP engine is not supported for MoE models). - - All the dense model should work out-of-the-box if they are not using ops that are different between training and inference like linear attention. - - MoE models need to patch the model with Fused MoE kernel like our Qwen3-MoE and DeepSeek-V3 example. + - All the dense model should work out-of-the-box if they are not using ops that are different between training and inference like linear attention. + - MoE models need to patch the model with Fused MoE kernel like our Qwen3-MoE and DeepSeek-V3 example. - โšก **Fast and aligned kernels** โ€” Fused MoE, fused linear cross-entropy, Flash Attention 3/4 with paged KV cache, all numerically consistent between training and inference - ๐Ÿงฉ **Simple model definitions** โ€” Transformer model code is self-contained and easy to audit, so training and inference model definitions stay in sync - ๐Ÿ“– **Readable codebase** โ€” Clean implementation with chunked prefill, pipeline parallelism, and CUDA graph support -## Effectiveness +## Effectiveness -> **Qwen3-30B-A3B ยท REINFORCE++ ยท DAPO dataset** +> **Qwen3-30B-A3B ยท REINFORCE ยท DAPO dataset** Off-policy logprob bias from vLLM causes the rollout-correction KL to explode after ~300 steps, which triggers gradient norm blow-up and ultimately training collapse. VeXact's bitwise-aligned rollout keeps the KL at exactly zero throughout, yielding stable training and a ~2ร— higher final AIME 2024 score. @@ -46,17 +48,17 @@ bash examples/getting_started/run_qwen3_1b7.sh model_dir=/path/to/model data_dir=/path/to/data bash examples/moe/run_qwen3_30B_A3B_dapo.sh ``` -| Recipe | Model | Dataset | Hardware | Algorithm | -|---|---|---|---|---| -| [`getting_started/run_qwen3_1b7.sh`](examples/getting_started/run_qwen3_1b7.sh) | Qwen3-1.7B | gsm8k | 1ร—8H100 | GRPO | -| [`moe/run_qwen3_30B_A3B_dapo.sh`](examples/moe/run_qwen3_30B_A3B_dapo.sh) | Qwen3-30B-A3B | DAPO-Math-17k / AIME 2025 | 1ร—8H100 | DAPO | -| [`moe/run_qwen3_30B_A3B_reinforce.sh`](examples/moe/run_qwen3_30B_A3B_reinforce.sh) | Qwen3-30B-A3B-Base | DAPO-Math-17k / AIME 2024 | 8ร—8H100 | REINFORCE++ | -| [`moe/run_qwen3_30B_A3B_16H100.sh`](examples/moe/run_qwen3_30B_A3B_16H100.sh) | Qwen3-30B-A3B | gsm8k | 2ร—8H100 | GRPO | -| [`moe/run_qwen3_30B_A3B_8B200.sh`](examples/moe/run_qwen3_30B_A3B_8B200.sh) | Qwen3-30B-A3B | gsm8k | 1ร—8B200 | GRPO | -| [`moe/run_moonlight_gsm8k.sh`](examples/moe/run_moonlight_gsm8k.sh) | Moonlight-16B-A3B-Instruct | gsm8k | 1ร—8B200 | GRPO | -| [`moe/run_moonlight_reinforce.sh`](examples/moe/run_moonlight_reinforce.sh) | Moonlight-16B-A3B-Instruct | DAPO-Math-17k / AIME 2024 | 1ร—8B200 | REINFORCE++ | -| [`verify/run_dense_vexact.sh`](examples/verify/run_dense_vexact.sh) | DeepSeek-R1-Distill-Qwen-1.5B | MATH / AIME 2024+2025 | 1ร—8H100 | GRPO (vexact) | -| [`verify/run_dense_vllm.sh`](examples/verify/run_dense_vllm.sh) | DeepSeek-R1-Distill-Qwen-1.5B | MATH / AIME 2024+2025 | 1ร—8H100 | GRPO (vllm) | +| Recipe | Model | Dataset | Hardware | Algorithm | +| ----------------------------------------------------------------------------------- | ----------------------------- | ------------------------- | -------- | ------------- | +| [`getting_started/run_qwen3_1b7.sh`](examples/getting_started/run_qwen3_1b7.sh) | Qwen3-1.7B | gsm8k | 1ร—8H100 | GRPO | +| [`moe/run_qwen3_30B_A3B_dapo.sh`](examples/moe/run_qwen3_30B_A3B_dapo.sh) | Qwen3-30B-A3B | DAPO-Math-17k / AIME 2025 | 1ร—8H100 | DAPO | +| [`moe/run_qwen3_30B_A3B_reinforce.sh`](examples/moe/run_qwen3_30B_A3B_reinforce.sh) | Qwen3-30B-A3B-Base | DAPO-Math-17k / AIME 2024 | 8ร—8H100 | REINFORCE | +| [`moe/run_qwen3_30B_A3B_16H100.sh`](examples/moe/run_qwen3_30B_A3B_16H100.sh) | Qwen3-30B-A3B | gsm8k | 2ร—8H100 | GRPO | +| [`moe/run_qwen3_30B_A3B_8B200.sh`](examples/moe/run_qwen3_30B_A3B_8B200.sh) | Qwen3-30B-A3B | gsm8k | 1ร—8B200 | GRPO | +| [`moe/run_moonlight_gsm8k.sh`](examples/moe/run_moonlight_gsm8k.sh) | Moonlight-16B-A3B-Instruct | gsm8k | 1ร—8B200 | GRPO | +| [`moe/run_moonlight_reinforce.sh`](examples/moe/run_moonlight_reinforce.sh) | Moonlight-16B-A3B-Instruct | DAPO-Math-17k / AIME 2024 | 1ร—8B200 | REINFORCE | +| [`verify/run_dense_vexact.sh`](examples/verify/run_dense_vexact.sh) | DeepSeek-R1-Distill-Qwen-1.5B | MATH / AIME 2024+2025 | 1ร—8H100 | GRPO (vexact) | +| [`verify/run_dense_vllm.sh`](examples/verify/run_dense_vllm.sh) | DeepSeek-R1-Distill-Qwen-1.5B | MATH / AIME 2024+2025 | 1ร—8H100 | GRPO (vllm) | See [`examples/README.md`](examples/README.md) for path configuration, attention backend selection, and an explanation of the `verify/` pair. @@ -80,10 +82,10 @@ What each extra does: - `gpu` โ€” PyTorch (CUDA 12.9), FlashAttention 2/3/4, quack-kernels, NVML. - `verl` โ€” pulls verl from `verl-project/verl` (pinned by commit in - `[tool.uv.sources]`) plus FastAPI/uvicorn/cachetools used by the trainer. + `[tool.uv.sources]`) plus FastAPI/uvicorn/cachetools used by the trainer. - `veomni` โ€” pulls VeOmni from `ByteDance-Seed/VeOmni` (pinned by commit). - `vllm` โ€” vLLM 0.18 if you prefer it as the rollout engine instead of - VeXact's native one. + VeXact's native one. - `dev` โ€” `pytest`, `pytest-asyncio`, `pre-commit` for development. ### Working on verl or VeOmni locally @@ -118,3 +120,16 @@ Besides VeRL and VeOmni, VeXact builds on and is inspired by the following proje - [batch_invariant_ops](https://github.com/thinking-machines-lab/batch_invariant_ops) โ€” Batch-invariant operators for deterministic inference - [Torch Memory Saver](https://github.com/fzyzcjy/torch_memory_saver) - Model param and KV cache offloads. - [FlashAttention](https://github.com/Dao-AILab/flash-attention) - We support FA4 for SM90+ (including SM100) GPU, including MLA shape for DeepSeek-V3 model architecture. + +## Citation + +If you find our work useful, please consider citing our paper: + +```bibtex +@article{zhong2026diagnosing, + title={Diagnosing Training Inference Mismatch in LLM Reinforcement Learning}, + author={Zhong, Tianle and Ling, Neiwen and Pi, Yifan and Wei, Zijun and Yu, Tianshu and Fox, Geoffrey and Wu, Peng and Yu, Xiao}, + journal={arXiv preprint arXiv:2605.14220}, + year={2026} +} +```