Feat/fsdp sglang fp8 rollout #1379
- Extract Triton kernel from megatron_utils/fp8/kernels.py into shared areal/utils/kernel/fp8_kernel.py with no TE/Megatron dependency
- Add auto-padding for non-multiple-of-block-size dimensions
- Add pure-PyTorch fallback when Triton is unavailable
- Add should_quantize_param() for filtering Linear layers

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Add online FP8 block-wise quantization (128x128, e4m3fn) to FSDPEngine weight sync path before NCCL broadcast to SGLang
- Add unified FP8 kernel with Triton + PyTorch fallback + auto-padding
- Add should_quantize_param() for filtering Linear projection layers
- Thread quantization config through SGLangConfig -> RLTrainer -> WeightUpdateMeta -> FSDPEngine
- SGLangConfig.build_args() injects --quantization=fp8 and json_model_override_args with block-wise FP8 config
- Add NotImplementedError guard for Megatron FP8 weight update path
- Add quantization example configs in examples/quantization/
- Add unit tests for FP8 kernel (param filtering + quantization)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Code Review
This pull request introduces online FP8 block-wise quantization support for SGLang rollout while maintaining FSDP training in BF16. It implements a unified FP8 quantization kernel with a Triton path and a PyTorch fallback, updates weight synchronization in the FSDP engine, and adds corresponding configurations, examples, and tests. The review feedback highlights three key issues: a hardcoded absolute path in the test file that will cause environment failures, in-place tensor mutation in the PyTorch fallback kernel that risks corrupting model weights, and potential loss of user-defined model overrides when setting quantization configurations in the CLI arguments.
    spec = importlib.util.spec_from_file_location(
        "fp8_kernel",
        "/F00120250029/lixiang_share/zengziyi_share/zengziyi/Research/Areal_sub/areal/utils/kernel/fp8_kernel.py",
    )
The test file uses a hardcoded absolute path pointing to a specific local directory (/F00120250029/lixiang_share/...). This will cause the tests to fail on any other machine or CI/CD environment. Please resolve the path dynamically relative to the test file's location.
    repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
    fp8_kernel_path = os.path.join(repo_root, "areal/utils/kernel/fp8_kernel.py")
    spec = importlib.util.spec_from_file_location("fp8_kernel", fp8_kernel_path)

    def _scaled_fp8_blockwise_pytorch(
        data_hp: torch.Tensor,
        block_size: list[int] | tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch block-wise FP8 quantization.

        Args:
            data_hp: BF16/FP16 weight tensor, shape (M, N).
            block_size: [block_m, block_n].

        Returns:
            (fp8_weight, scale) where scale.shape == (ceil(M/block_m), ceil(N/block_n))
            and scale = absmax / FP8_MAX.
        """
        block_size0, block_size1 = block_size[0], block_size[1]
        original_shape = data_hp.shape
The pure-PyTorch fallback function _scaled_fp8_blockwise_pytorch mutates the input tensor data_hp in-place (e.g., data_hp.mul_(scale_fp) and data_hp.clamp_). If the input tensor is a direct reference to the model's weights (which can happen if _get_full_tensor returns the parameter data directly), this will corrupt the model's weights during training. Cloning the input tensor at the start of the function ensures safety.
    def _scaled_fp8_blockwise_pytorch(
        data_hp: torch.Tensor,
        block_size: list[int] | tuple[int, int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch block-wise FP8 quantization.

        Args:
            data_hp: BF16/FP16 weight tensor, shape (M, N).
            block_size: [block_m, block_n].

        Returns:
            (fp8_weight, scale) where scale.shape == (ceil(M/block_m), ceil(N/block_n))
            and scale = absmax / FP8_MAX.
        """
        data_hp = data_hp.clone()
        block_size0, block_size1 = block_size[0], block_size[1]
        original_shape = data_hp.shape
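To make the shape and scale arithmetic in the docstring concrete, here is a minimal pure-Python sketch of block-wise absmax scaling. It is not the PR's kernel: it works on nested lists instead of tensors, skips the actual FP8 cast, and the name `blockwise_scales` and the choice of `1.0` as the scale for an all-zero block are assumptions made for illustration. It does follow the two properties discussed in this review thread: the scale grid has shape `(ceil(M/block_m), ceil(N/block_n))`, and the input is copied rather than mutated.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def blockwise_scales(weight, block_m=128, block_n=128):
    """Return (scaled_copy, scales) for a 2D list-of-lists weight.

    scales[i][j] = absmax(block i, j) / FP8_E4M3_MAX. Ragged edge blocks are
    handled by slicing, which gives the same absmax as zero-padding would.
    """
    rows, cols = len(weight), len(weight[0])
    scaled = [row[:] for row in weight]  # work on a copy; never touch the input
    n_bm, n_bn = math.ceil(rows / block_m), math.ceil(cols / block_n)
    scales = [[0.0] * n_bn for _ in range(n_bm)]
    for bi in range(n_bm):
        for bj in range(n_bn):
            r0, r1 = bi * block_m, min((bi + 1) * block_m, rows)
            c0, c1 = bj * block_n, min((bj + 1) * block_n, cols)
            absmax = max(
                abs(scaled[r][c]) for r in range(r0, r1) for c in range(c0, c1)
            )
            # Assumption: use 1.0 for an all-zero block to avoid dividing by zero.
            scale = absmax / FP8_E4M3_MAX if absmax > 0 else 1.0
            scales[bi][bj] = scale
            for r in range(r0, r1):
                for c in range(c0, c1):
                    # After scaling, values lie in [-448, 448]; a real kernel
                    # would cast to FP8 at this point.
                    scaled[r][c] = scaled[r][c] / scale
    return scaled, scales
```

Multiplying each scaled block by its entry in `scales` recovers the original values (up to the rounding a real FP8 cast would add), which matches the `scale = absmax / FP8_MAX` convention stated in the docstring.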
    args["json_model_override_args"] = json.dumps(
        {"quantization_config": fp8_quant_config},
        separators=(",", ":"),
    )
Directly overwriting json_model_override_args with the quantization config will discard any other user-defined model overrides that might already be present in args. It is safer to parse the existing overrides (if any), merge the new quantization config, and then serialize it back to JSON.
    existing_override = args.get("json_model_override_args")
    override_dict = {}
    if existing_override:
        try:
            override_dict = json.loads(existing_override)
        except Exception:
            pass
    override_dict["quantization_config"] = fp8_quant_config
    args["json_model_override_args"] = json.dumps(
        override_dict,
        separators=(",", ":"),
    )
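One possible variant of the merge described above, written as a standalone helper so it can be tested in isolation. The function name `merge_model_overrides` is hypothetical, and the choice to raise on malformed input (rather than silently ignoring it) is a design assumption, not something this review prescribes. Only the standard-library `json` module is used.

```python
import json


def merge_model_overrides(existing, quantization_config):
    """Merge quantization_config into an existing JSON override string.

    Hypothetical helper: `existing` may be None, empty, or a JSON object string.
    Returns the merged overrides serialized as compact JSON.
    """
    overrides = {}
    if existing:
        try:
            parsed = json.loads(existing)
        except json.JSONDecodeError as exc:
            # Surface malformed user input instead of silently discarding it.
            raise ValueError(
                f"json_model_override_args is not valid JSON: {exc}"
            ) from exc
        if not isinstance(parsed, dict):
            raise ValueError("json_model_override_args must encode a JSON object")
        overrides = parsed
    # The quantization config takes precedence over any user-supplied entry.
    overrides["quantization_config"] = quantization_config
    return json.dumps(overrides, separators=(",", ":"))
```

Raising here trades convenience for visibility: a typo in a user's override string would otherwise be dropped without any signal. If silent fallback is preferred, the `except` branch could log a warning and continue with an empty dict instead.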
Description
Enable FP8 block-wise quantization for SGLang rollout while keeping FSDP training in BF16. The training engine quantizes BF16 weights to FP8 online before NCCL broadcast to SGLang.
What changed
- areal/utils/kernel/fp8_kernel.py: Auto-detect GPU availability; fall back to PyTorch when Triton is unavailable or running on CPU.
- areal/api/io_struct.py: Add quantization and quantization_config fields to WeightUpdateMeta.
- areal/api/cli_args.py: Add quantization to SGLangConfig; build_args() injects --quantization=fp8 and json_model_override_args with block-wise FP8 config.
- areal/trainer/rl_trainer.py: Propagate config.sglang.quantization into xccl weight meta construction.
- areal/engine/fsdp_engine.py: Insert _materialize_and_maybe_quantize() generator into _update_weights_from_distributed(). All ranks call _get_full_tensor() (collective); rank 0 optionally quantizes eligible 2D Linear layers and yields (name, fp8_weight) + (name.replace(".weight", ".weight_scale_inv"), scale).
- areal/engine/megatron_engine.py: Guard with NotImplementedError for Megatron FP8 path.
- examples/quantization/: Add FSDP + GSM8K math + FP8 rollout config.
- tests/utils/kernel/test_fp8_kernel.py: Param filtering + quantization shape/dtype + round-trip accuracy + CPU fallback.
Related Issue
#1378
Type of Change
Checklist
- tests/utils/kernel/test_fp8_kernel.py (19 tests)
- docs/superpowers/plans/2026-05-30-fsdp-sglang-fp8-rollout-proposal.md
- main
Breaking Change Details (if applicable):
None.
quantization defaults to an empty string; existing configs work unchanged.
Additional Context
- process_weights_after_loading is not called during SGLang weight updates (same behavior as verl's FP8 rollout). Triton/Cutlass backends use weight_scale_inv directly without additional post-processing.
- fsdp_adapter.py lacks colocate methods, so this PR targets the xccl path only.
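
The weight/scale naming used on the xccl path can be illustrated with a small generator. This is a rough sketch, not the code in areal/engine/fsdp_engine.py: `should_quantize` and `expand_for_rollout` are hypothetical names, the filter rule (2D weights whose name ends in `_proj.weight`) is an assumption standing in for should_quantize_param(), and `quantize_fn` is any callable returning `(fp8_weight, scale)`. Only the `.weight` to `.weight_scale_inv` renaming is taken from the description above.

```python
def should_quantize(name, shape):
    """Hypothetical filter: 2D weights of projection layers only."""
    return len(shape) == 2 and name.endswith("_proj.weight")


def expand_for_rollout(named_weights, quantize_fn):
    """Yield (name, tensor) pairs to broadcast.

    named_weights: iterable of (name, shape, tensor) triples.
    quantize_fn:   callable returning (fp8_weight, scale) for a tensor.

    Each quantized weight is followed by its scale, published under the
    same name with ".weight" replaced by ".weight_scale_inv".
    """
    for name, shape, tensor in named_weights:
        if should_quantize(name, shape):
            fp8_weight, scale = quantize_fn(tensor)
            yield name, fp8_weight
            yield name.replace(".weight", ".weight_scale_inv"), scale
        else:
            # Norms, biases, embeddings, etc. pass through unquantized.
            yield name, tensor
```

Because the scale is emitted right after its weight, a receiver that consumes the stream in order sees each weight and its `weight_scale_inv` together, without buffering the whole state dict.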