Skip to content

Feat/fsdp sglang fp8 rollout#1379

Draft
ZiyiTsang wants to merge 2 commits into
areal-project:mainfrom
ZiyiTsang:feat/fsdp-sglang-fp8-rollout
Draft

Feat/fsdp sglang fp8 rollout#1379
ZiyiTsang wants to merge 2 commits into
areal-project:mainfrom
ZiyiTsang:feat/fsdp-sglang-fp8-rollout

Conversation

@ZiyiTsang
Copy link
Copy Markdown
Collaborator

@ZiyiTsang ZiyiTsang commented May 30, 2026

Description

Enable FP8 block-wise quantization for SGLang rollout while keeping FSDP training in BF16. The training engine quantizes BF16 weights to FP8 online before NCCL broadcast to SGLang.

What changed

  • FP8 kernel (areal/utils/kernel/fp8_kernel.py): Auto-detect GPU availability; fall back to PyTorch when Triton is unavailable or running on CPU.
  • Weight update meta (areal/api/io_struct.py): Add quantization and quantization_config fields to WeightUpdateMeta.
  • SGLang config (areal/api/cli_args.py): Add quantization to SGLangConfig; build_args() injects --quantization=fp8 and json_model_override_args with block-wise FP8 config.
  • Trainer (areal/trainer/rl_trainer.py): Propagate config.sglang.quantization into xccl weight meta construction.
  • FSDP engine (areal/engine/fsdp_engine.py): Insert _materialize_and_maybe_quantize() generator into _update_weights_from_distributed(). All ranks call _get_full_tensor() (collective); rank 0 optionally quantizes eligible 2D Linear layers and yields (name, fp8_weight) + (name.replace(".weight", ".weight_scale_inv"), scale).
  • Megatron engine (areal/engine/megatron_engine.py): Guard with NotImplementedError for Megatron FP8 path.
  • Examples (examples/quantization/): Add FSDP + GSM8K math + FP8 rollout config.
  • Tests (tests/utils/kernel/test_fp8_kernel.py): Param filtering + quantization shape/dtype + round-trip accuracy + CPU fallback.

Related Issue

#1378

Type of Change

  • ✨ New feature

Checklist

  • Relevant tests pass; new tests added for new functionality (tests/utils/kernel/test_fp8_kernel.py, 19 tests)
  • Documentation updated (proposal at docs/superpowers/plans/2026-05-30-fsdp-sglang-fp8-rollout-proposal.md)
  • Branch is up to date with main

Breaking Change Details (if applicable):

None. quantization defaults to empty string; existing configs work unchanged.

Additional Context

  • process_weights_after_loading is not called during SGLang weight updates (same behavior as verl's FP8 rollout). Triton/Cutlass backends use weight_scale_inv directly without additional post-processing.
  • DeepGemm UE8M0 requantization may require a follow-up if that backend is used.
  • awex colocate weight update is not available for FSDP (fsdp_adapter.py lacks colocate methods), so this PR targets the xccl path only.

ZiyiTsang and others added 2 commits May 30, 2026 11:35
- Extract Triton kernel from megatron_utils/fp8/kernels.py into shared
  areal/utils/kernel/fp8_kernel.py with no TE/Megatron dependency
- Add auto-padding for non-multiple-of-block-size dimensions
- Add pure-PyTorch fallback when Triton is unavailable
- Add should_quantize_param() for filtering Linear layers

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
- Add online FP8 block-wise quantization (128x128, e4m3fn) to FSDPEngine
  weight sync path before NCCL broadcast to SGLang
- Add unified FP8 kernel with Triton + PyTorch fallback + auto-padding
- Add should_quantize_param() for filtering Linear projection layers
- Thread quantization config through SGLangConfig -> RLTrainer ->
  WeightUpdateMeta -> FSDPEngine
- SGLangConfig.build_args() injects --quantization=fp8 and
  json_model_override_args with block-wise FP8 config
- Add NotImplementedError guard for Megatron FP8 weight update path
- Add quantization example configs in examples/quantization/
- Add unit tests for FP8 kernel (param filtering + quantization)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces online FP8 block-wise quantization support for SGLang rollout while maintaining FSDP training in BF16. It implements a unified FP8 quantization kernel with a Triton path and a PyTorch fallback, updates weight synchronization in the FSDP engine, and adds corresponding configurations, examples, and tests. The review feedback highlights three key issues: a hardcoded absolute path in the test file that will cause environment failures, in-place tensor mutation in the PyTorch fallback kernel that risks corrupting model weights, and potential loss of user-defined model overrides when setting quantization configurations in the CLI arguments.

Comment on lines +28 to +31
spec = importlib.util.spec_from_file_location(
"fp8_kernel",
"/F00120250029/lixiang_share/zengziyi_share/zengziyi/Research/Areal_sub/areal/utils/kernel/fp8_kernel.py",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The test file uses a hardcoded absolute path pointing to a specific local directory (/F00120250029/lixiang_share/...). This will cause the tests to fail on any other machine or CI/CD environment. Please resolve the path dynamically relative to the test file's location.

repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
fp8_kernel_path = os.path.join(repo_root, "areal/utils/kernel/fp8_kernel.py")
spec = importlib.util.spec_from_file_location("fp8_kernel", fp8_kernel_path)

Comment on lines +51 to +66
def _scaled_fp8_blockwise_pytorch(
data_hp: torch.Tensor,
block_size: list[int] | tuple[int, int],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Pure-PyTorch block-wise FP8 quantization.

Args:
data_hp: BF16/FP16 weight tensor, shape (M, N).
block_size: [block_m, block_n].

Returns:
(fp8_weight, scale) where scale.shape == (ceil(M/block_m), ceil(N/block_n))
and scale = absmax / FP8_MAX.
"""
block_size0, block_size1 = block_size[0], block_size[1]
original_shape = data_hp.shape
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The pure-PyTorch fallback function _scaled_fp8_blockwise_pytorch mutates the input tensor data_hp in-place (e.g., data_hp.mul_(scale_fp) and data_hp.clamp_). If the input tensor is a direct reference to the model's weights (which can happen if _get_full_tensor returns the parameter data directly), this will corrupt the model's weights during training. Cloning the input tensor at the start of the function ensures safety.

Suggested change
def _scaled_fp8_blockwise_pytorch(
data_hp: torch.Tensor,
block_size: list[int] | tuple[int, int],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Pure-PyTorch block-wise FP8 quantization.
Args:
data_hp: BF16/FP16 weight tensor, shape (M, N).
block_size: [block_m, block_n].
Returns:
(fp8_weight, scale) where scale.shape == (ceil(M/block_m), ceil(N/block_n))
and scale = absmax / FP8_MAX.
"""
block_size0, block_size1 = block_size[0], block_size[1]
original_shape = data_hp.shape
def _scaled_fp8_blockwise_pytorch(
data_hp: torch.Tensor,
block_size: list[int] | tuple[int, int],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Pure-PyTorch block-wise FP8 quantization.
Args:
data_hp: BF16/FP16 weight tensor, shape (M, N).
block_size: [block_m, block_n].
Returns:
(fp8_weight, scale) where scale.shape == (ceil(M/block_m), ceil(N/block_n))
and scale = absmax / FP8_MAX.
"""
data_hp = data_hp.clone()
block_size0, block_size1 = block_size[0], block_size[1]
original_shape = data_hp.shape

Comment thread areal/api/cli_args.py
Comment on lines +1926 to +1929
args["json_model_override_args"] = json.dumps(
{"quantization_config": fp8_quant_config},
separators=(",", ":"),
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Directly overwriting json_model_override_args with the quantization config will discard any other user-defined model overrides that might already be present in args. It is safer to parse the existing overrides (if any), merge the new quantization config, and then serialize it back to JSON.

            existing_override = args.get("json_model_override_args")
            override_dict = {}
            if existing_override:
                try:
                    override_dict = json.loads(existing_override)
                except Exception:
                    pass
            override_dict["quantization_config"] = fp8_quant_config
            args["json_model_override_args"] = json.dumps(
                override_dict,
                separators=(",", ":"),
            )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant