Checklist
- New fields in SGLangConfig and WeightUpdateMeta have safe defaults.
- Existing training configs without quantization work unchanged.
- Only the xccl weight update mode is affected.

Background
Enable FP8 block-wise quantization for SGLang rollout while keeping FSDP training in BF16. The training engine quantizes BF16 weights to FP8 online before the NCCL broadcast to SGLang.
awex is not available for FSDP: its fsdp_adapter.py lacks the colocate methods (init_colocate_weight_update, execute_colocate_weight_update) that exist only in megatron_adapter.py. This PR targets the xccl path only.

Proposed Paths
Two independent changes:
- Weight update path: Insert FP8 quantization into FSDPEngine._update_weights_from_distributed() after all-gather / cast-to-compute-dtype, before bucket assembly.
- Configuration path: Thread sglang.quantization through SGLangConfig → RLTrainer → WeightUpdateMeta → FSDPEngine, and have SGLangConfig.build_args() inject the FP8 server config.

Potential Solution

Weight Update Path
1. Unified FP8 Kernel (areal/utils/kernel/fp8_kernel.py)
Extract the existing Triton kernel from megatron_utils/fp8/kernels.py into a shared location (no TE/Megatron dep). Add auto-padding, a PyTorch fallback, and should_quantize_param() for filtering Linear layers.
2. FSDPEngine (areal/engine/fsdp_engine.py)
Import scaled_fp8_blockwise and should_quantize_param. Replace the param iterator in _update_weights_from_distributed() with a generator, _materialize_and_maybe_quantize():
- All ranks call _get_full_tensor() (collective).
- Rank 0 checks meta.quantization == "fp8"; if true and should_quantize_param(name) and the tensor is 2D, call scaled_fp8_blockwise() and yield (name, fp8_weight) + (name.replace(".weight", ".weight_scale_inv"), scale).
- Otherwise, yield (name, tensor) unchanged.

Configuration Path
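As an overview of the threading named in Proposed Paths (SGLangConfig → RLTrainer → WeightUpdateMeta → FSDPEngine), here is a minimal sketch under stated assumptions. The classes SGLangConfigSketch and WeightUpdateMetaSketch and the helper build_weight_meta are hypothetical stand-ins, not the real AReaL types. The contents of the override payload (key names, 128x128 block size) and the mapping of "" to None are illustrative guesses, not taken from the PR.

```python
from __future__ import annotations

import json
from dataclasses import dataclass


@dataclass
class SGLangConfigSketch:
    """Hypothetical stand-in for SGLangConfig; only the new field is modeled."""

    quantization: str = ""  # empty default keeps existing configs unchanged

    def build_args(self) -> dict:
        args: dict = {}
        if self.quantization == "fp8":
            args["quantization"] = "fp8"
            # Assumed payload: key names and block size are illustrative only.
            args["json_model_override_args"] = json.dumps(
                {"quantization_config": {"quant_method": "fp8",
                                         "weight_block_size": [128, 128]}}
            )
        return args


@dataclass
class WeightUpdateMetaSketch:
    """Hypothetical stand-in for WeightUpdateMeta with the two new fields."""

    quantization: str | None = None
    quantization_config: dict | None = None

    @classmethod
    def from_fsdp_xccl(cls, **xccl_kwargs) -> WeightUpdateMetaSketch:
        return cls(**xccl_kwargs)


def build_weight_meta(sglang: SGLangConfigSketch) -> WeightUpdateMetaSketch:
    # Assumption: "" (config default) maps to None (meta default), so
    # unquantized runs follow the pre-existing code path.
    xccl_kwargs = {"quantization": sglang.quantization or None}
    return WeightUpdateMetaSketch.from_fsdp_xccl(**xccl_kwargs)
```

With the defaults, build_args() returns an empty dict and the meta carries quantization=None, which is the backward-compatibility property the checklist asks for.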
3. WeightUpdateMeta (areal/api/io_struct.py)
Add quantization: str | None = None and quantization_config: dict | None = None. Update the from_fsdp_xccl() / from_megatron_xccl() factory methods.
4. SGLangConfig (areal/api/cli_args.py)
Add quantization: str = "". In build_args(), when the value is "fp8", inject quantization=fp8 and json_model_override_args with the block-wise FP8 config.
5. RLTrainer (areal/trainer/rl_trainer.py)
In the xccl weight meta construction, propagate config.sglang.quantization into xccl_kwargs.

Tests
6. tests/utils/kernel/test_fp8_kernel.py
Param filtering, quantization shape/dtype, round-trip dequant accuracy, non-multiple dimensions, zero edge case, CPU fallback.
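To make the round-trip, non-multiple-dimension, and zero cases concrete, here is a dependency-free reference sketch of block-wise scaling. It is not the Triton kernel or the PyTorch fallback from this PR. The names round_to_e4m3, scaled_fp8_blockwise_ref, and dequantize_ref are hypothetical, FP8 rounding is simulated in plain Python, and it is assumed that weight_scale_inv is the per-block factor that dequantizes (w ≈ q * scale_inv).

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value of the e4m3fn format


def round_to_e4m3(x: float) -> float:
    """Simulate round-to-nearest onto the e4m3fn grid (3 mantissa bits)."""
    if x == 0.0:
        return 0.0
    x = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x))
    _, exp = math.frexp(abs(x))   # abs(x) = m * 2**exp with 0.5 <= m < 1
    exp = max(exp - 1, -6)        # unbiased exponent; -6 covers subnormals
    step = 2.0 ** (exp - 3)       # spacing between representable values
    return round(x / step) * step


def scaled_fp8_blockwise_ref(weight, block=128):
    """Quantize a 2D list of floats per (block x block) tile.

    Returns (q, scale_inv). Edge tiles are simply smaller, which gives the
    same result as zero-padding to a multiple of the block size.
    """
    rows, cols = len(weight), len(weight[0])
    n_br, n_bc = math.ceil(rows / block), math.ceil(cols / block)
    q = [[0.0] * cols for _ in range(rows)]
    scale_inv = [[1.0] * n_bc for _ in range(n_br)]
    for bi in range(n_br):
        for bj in range(n_bc):
            r0, c0 = bi * block, bj * block
            r1, c1 = min(r0 + block, rows), min(c0 + block, cols)
            amax = max(abs(weight[r][c])
                       for r in range(r0, r1) for c in range(c0, c1))
            s = amax / FP8_E4M3_MAX if amax > 0 else 1.0  # all-zero tile
            scale_inv[bi][bj] = s
            for r in range(r0, r1):
                for c in range(c0, c1):
                    q[r][c] = round_to_e4m3(weight[r][c] / s)
    return q, scale_inv


def dequantize_ref(q, scale_inv, block=128):
    """Invert the scaling: w_hat[r][c] = q[r][c] * scale of its tile."""
    return [[v * scale_inv[r // block][c // block] for c, v in enumerate(row)]
            for r, row in enumerate(q)]
```

Under this simulation the largest grid spacing is 32 (between 256 and 448), so the absolute round-trip error of any element stays within amax / 28 of its tile, which gives a simple tolerance for an accuracy test.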
Additional Information
- quantization defaults to empty string. Existing configs work unchanged.
- load_format="auto". SGLang loads normally at startup; first weight update overwrites with FP8.

Sequence Diagram
sequenceDiagram
    autonumber
    participant RT as RLTrainer
    participant F0 as FSDPEngine (Rank 0)
    participant Fn as FSDPEngine (Other Ranks)
    participant SG as SGLang Server

    %% Startup: Config path
    rect rgb(200, 255, 200)
        RT->>RT: read config.sglang.quantization="fp8"
        RT->>RT: WeightUpdateMeta.from_fsdp_xccl(quantization="fp8")
    end
    RT->>F0: connect_engine(weight_update_meta)
    rect rgb(200, 255, 200)
        RT->>SG: SGLangConfig.build_args(quantization="fp8")
        SG->>SG: start with --quantization=fp8
    end

    %% Training: Weight update path
    loop every weight_update_interval
        F0->>Fn: dist.barrier(cpu_group)
        loop _materialize_and_maybe_quantize()
            F0->>F0: _get_full_tensor(param)
            F0->>F0: _cast_to_compute_dtype()
            alt should_quantize_param(name) and dim==2
                rect rgb(200, 255, 200)
                    F0->>F0: scaled_fp8_blockwise(BF16) → (FP8, scale)
                    F0->>SG: broadcast(name, fp8_weight)
                    F0->>SG: broadcast(name+".weight_scale_inv", scale)
                end
            else skip quantization
                F0->>SG: broadcast(name, BF16_tensor)
            end
        end
        F0->>Fn: dist.barrier(cpu_group)
        SG->>SG: update_weights_from_distributed()
    end
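The inner loop of the diagram can be sketched as a plain generator. This is a simplified illustration, not the FSDPEngine method: the function name and every callable passed in (get_full_tensor, should_quantize, quantize) are hypothetical stand-ins for the engine's own helpers, the cast-to-compute-dtype step is omitted, and it is assumed that ranks other than 0 take part in the collective but yield nothing.

```python
from __future__ import annotations

from typing import Any, Callable, Iterable, Iterator


def materialize_and_maybe_quantize(
    named_params: Iterable[tuple[str, Any]],
    quantization: str | None,
    get_full_tensor: Callable[[Any], Any],
    should_quantize: Callable[[str], bool],
    quantize: Callable[[Any], tuple[Any, Any]],
    is_rank0: bool,
) -> Iterator[tuple[str, Any]]:
    """Yield (name, tensor) pairs ready for bucket assembly and broadcast."""
    for name, param in named_params:
        # Collective: every rank must call this, even if it sends nothing.
        tensor = get_full_tensor(param)
        if not is_rank0:
            continue  # assumption: only rank 0 broadcasts to the server
        if quantization == "fp8" and should_quantize(name) and tensor.ndim == 2:
            fp8_weight, scale = quantize(tensor)
            yield name, fp8_weight
            # The scale travels as a sibling tensor next to the weight.
            yield name.replace(".weight", ".weight_scale_inv"), scale
        else:
            yield name, tensor  # norms, biases, and unquantized runs
```

Keeping the gather outside the rank check matters: skipping it on non-zero ranks would leave rank 0 waiting in the collective.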