
[Feature] FSDP BF16 Training + SGLang FP8 Rollout #1378

@ZiyiTsang

Description


Checklist

  • New fields in SGLangConfig and WeightUpdateMeta have safe defaults.
  • Existing training configs without quantization work unchanged.
  • Only the xccl weight update mode is affected.

Background

Enable FP8 block-wise quantization for SGLang rollout while keeping FSDP training in BF16. The training engine quantizes BF16 weights to FP8 online, just before the NCCL broadcast to SGLang.

awex is not available for FSDP: its fsdp_adapter.py lacks the colocate methods (init_colocate_weight_update, execute_colocate_weight_update), which exist only in megatron_adapter.py. This PR therefore targets only the xccl path.

Proposed Paths

Two independent changes:

  1. Weight update path: Insert FP8 quantization into FSDPEngine._update_weights_from_distributed() after the all-gather and the cast to compute dtype, and before bucket assembly.
  2. Configuration path: Thread sglang.quantization through SGLangConfig → RLTrainer → WeightUpdateMeta → FSDPEngine, and have SGLangConfig.build_args() inject the FP8 server config.

Potential Solution

Weight Update Path

1. Unified FP8 Kernel (areal/utils/kernel/fp8_kernel.py)

Extract the existing Triton kernel from megatron_utils/fp8/kernels.py into a shared location with no TE/Megatron dependency. Add auto-padding for shapes that are not block multiples, a PyTorch fallback, and should_quantize_param() for selecting Linear-layer weights.
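The padding and zero-block behavior of the kernel can be checked against a slow reference. Below is a minimal pure-Python sketch, assuming 128x128 blocks, the float8_e4m3fn range (max 448), and the DeepSeek-style convention that `weight_scale_inv` is multiplied back in to dequantize. `blockwise_fp8_reference` and `dequantize_reference` are hypothetical names; the sketch models per-block scaling and clamping only, not FP8 mantissa rounding, so it checks shapes and scales rather than replacing the Triton kernel.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def blockwise_fp8_reference(weight, block=128):
    """Pure-Python model of per-block FP8 scaling (no FP8 rounding).

    weight: 2D list of floats, shape (rows, cols).
    Returns (scaled, scale_inv); scale_inv has shape
    (ceil(rows / block), ceil(cols / block)). Ragged edge blocks give the
    same amax as zero-padding them to a full block would.
    """
    rows, cols = len(weight), len(weight[0])
    n_br, n_bc = math.ceil(rows / block), math.ceil(cols / block)
    scale_inv = [[1.0] * n_bc for _ in range(n_br)]
    scaled = [[0.0] * cols for _ in range(rows)]
    for bi in range(n_br):
        r0, r1 = bi * block, min((bi + 1) * block, rows)
        for bj in range(n_bc):
            c0, c1 = bj * block, min((bj + 1) * block, cols)
            amax = max(abs(weight[r][c]) for r in range(r0, r1) for c in range(c0, c1))
            # All-zero block: use scale 1.0 to avoid dividing by zero.
            s = amax / FP8_E4M3_MAX if amax > 0 else 1.0
            scale_inv[bi][bj] = s
            for r in range(r0, r1):
                for c in range(c0, c1):
                    # Clamp into the representable FP8 range.
                    scaled[r][c] = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, weight[r][c] / s))
    return scaled, scale_inv


def dequantize_reference(scaled, scale_inv, block=128):
    """Multiply each element by the scale_inv of the block it belongs to."""
    return [[v * scale_inv[r // block][c // block] for c, v in enumerate(row)]
            for r, row in enumerate(scaled)]
```

With `block=2`, a 2x3 matrix produces a 1x2 `scale_inv` grid, and the largest magnitude in each block maps to ±448.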

2. FSDPEngine (areal/engine/fsdp_engine.py)

Import scaled_fp8_blockwise and should_quantize_param. In _update_weights_from_distributed(), replace the parameter iterator with a generator, _materialize_and_maybe_quantize():

  • All ranks call _get_full_tensor() (collective).
  • Rank 0 checks meta.quantization == "fp8". If it matches, should_quantize_param(name) is true, and the tensor is 2D, it calls scaled_fp8_blockwise() and yields (name, fp8_weight) followed by (name.replace(".weight", ".weight_scale_inv"), scale).
  • Otherwise it yields (name, tensor) unchanged.
  • The bucket loop is unchanged.
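A minimal, framework-agnostic sketch of the generator, so it runs without torch or a process group. The callables `get_full_tensor`, `quantize` and `ndim` stand in for _get_full_tensor(), scaled_fp8_blockwise() and Tensor.dim(); the name `materialize_and_maybe_quantize` and the filtering rule inside `should_quantize_param` are assumptions, not the PR's code.

```python
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple


def should_quantize_param(name: str) -> bool:
    """Hypothetical filter: quantize Linear weights only.

    Skips biases, embeddings, the LM head and norm weights; the real rule
    in fp8_kernel.py may differ.
    """
    if not name.endswith(".weight"):
        return False
    return not any(tok in name for tok in ("embed_tokens", "lm_head", "norm"))


def materialize_and_maybe_quantize(
    named_params: Iterable[Tuple[str, Any]],
    get_full_tensor: Callable[[Any], Any],
    quantize: Callable[[Any], Tuple[Any, Any]],
    ndim: Callable[[Any], int],
    rank: int,
    quantization: Optional[str],
) -> Iterator[Tuple[str, Any]]:
    for name, param in named_params:
        # Every rank must take part in the all-gather, or the collective hangs.
        full = get_full_tensor(param)
        if rank != 0:
            continue  # only rank 0 broadcasts to SGLang
        if quantization == "fp8" and should_quantize_param(name) and ndim(full) == 2:
            fp8_weight, scale_inv = quantize(full)
            yield name, fp8_weight
            # Swap only the trailing suffix so ".weight" elsewhere in the name survives.
            yield name[: -len(".weight")] + ".weight_scale_inv", scale_inv
        else:
            yield name, full
```

One deliberate difference from the bullet above: the sketch rewrites only the trailing ".weight" suffix. `str.replace` rewrites every occurrence, so a hypothetical name like `model.layers.0.weight_proj.weight` would become `model.layers.0.weight_scale_inv_proj.weight_scale_inv`.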

Configuration Path

3. WeightUpdateMeta (areal/api/io_struct.py)

Add quantization: str | None = None and quantization_config: dict | None = None, and update the from_fsdp_xccl() and from_megatron_xccl() factory methods to accept them.
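A minimal sketch of the new fields, using a hypothetical `WeightUpdateMetaSketch` stand-in because the real WeightUpdateMeta has other fields and factory parameters not described here. It assumes the factory normalizes SGLangConfig's empty-string default to None and rejects unknown modes; both are design choices for illustration, not confirmed behavior.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

SUPPORTED_QUANTIZATION = (None, "fp8")


@dataclass
class WeightUpdateMetaSketch:
    """Hypothetical stand-in for WeightUpdateMeta showing only the new fields."""

    type: str
    quantization: Optional[str] = None
    quantization_config: Optional[Dict[str, Any]] = None

    @classmethod
    def from_fsdp_xccl(
        cls,
        quantization: Optional[str] = None,
        quantization_config: Optional[Dict[str, Any]] = None,
    ) -> "WeightUpdateMetaSketch":
        # SGLangConfig uses "" to mean "disabled"; normalize it to None.
        quantization = quantization or None
        if quantization not in SUPPORTED_QUANTIZATION:
            raise ValueError(f"unsupported quantization: {quantization!r}")
        return cls(type="xccl", quantization=quantization,
                   quantization_config=quantization_config)
```

Because both new fields default to None, callers that never pass them get the same object as before.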

4. SGLangConfig (areal/api/cli_args.py)

Add quantization: str = "". In build_args(), when quantization is "fp8", inject quantization=fp8 and a json_model_override_args entry carrying the block-wise FP8 quantization config.
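A sketch of the build_args() branch, assuming SGLang's `quantization` and `json_model_override_args` server arguments (the latter takes a JSON string). The override contents are modeled on a DeepSeek-V3 style `quantization_config` with a 128x128 `weight_block_size`; that layout, the `SGLangConfigSketch` class, and the dict-of-args return type are assumptions for illustration.

```python
import json
from dataclasses import dataclass
from typing import Any, Dict

# Assumed HF-style block-wise FP8 config; not taken from the PR.
FP8_BLOCKWISE_OVERRIDE = {
    "quantization_config": {
        "quant_method": "fp8",
        "activation_scheme": "dynamic",
        "weight_block_size": [128, 128],
    }
}


@dataclass
class SGLangConfigSketch:
    """Hypothetical stand-in for SGLangConfig with only the fields used here."""

    model_path: str
    quantization: str = ""  # "" keeps the existing BF16 behavior

    def build_args(self) -> Dict[str, Any]:
        args: Dict[str, Any] = {"model_path": self.model_path, "load_format": "auto"}
        if self.quantization == "fp8":
            args["quantization"] = "fp8"
            # The override is passed to SGLang as a JSON string.
            args["json_model_override_args"] = json.dumps(FP8_BLOCKWISE_OVERRIDE)
        elif self.quantization:
            raise ValueError(f"unsupported quantization: {self.quantization!r}")
        return args
```

If a user already sets json_model_override_args, a real implementation would need to merge the two dicts instead of overwriting; the sketch ignores that case.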

5. RLTrainer (areal/trainer/rl_trainer.py)

When constructing the xccl WeightUpdateMeta, propagate config.sglang.quantization into xccl_kwargs.
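A sketch of the propagation step with a hypothetical `build_xccl_kwargs` helper; `SimpleNamespace` objects stand in for the real config classes. Only a non-empty value is forwarded, so configs without quantization produce the same kwargs as before.

```python
from types import SimpleNamespace
from typing import Any, Dict


def build_xccl_kwargs(config: Any) -> Dict[str, Any]:
    """Hypothetical helper: collect rollout quantization settings for
    WeightUpdateMeta.from_fsdp_xccl(**xccl_kwargs)."""
    xccl_kwargs: Dict[str, Any] = {}
    # getattr keeps older config objects without the field working.
    quantization = getattr(config.sglang, "quantization", "")
    if quantization:
        xccl_kwargs["quantization"] = quantization
    return xccl_kwargs


# Example configs: FP8 rollout vs. a legacy config without the field.
fp8_cfg = SimpleNamespace(sglang=SimpleNamespace(quantization="fp8"))
legacy_cfg = SimpleNamespace(sglang=SimpleNamespace())
```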

Tests

6. tests/utils/kernel/test_fp8_kernel.py

Covers parameter filtering, quantized shape and dtype, round-trip dequantization accuracy, dimensions that are not block multiples, the zero edge case, and the CPU fallback.
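The round-trip accuracy test needs a tolerance. One option, sketched below under the assumption that the kernel targets float8_e4m3fn, is to bound the relative error by 2^-4 (half a unit in the last place for a 3-bit mantissa) for values whose scaled magnitude lies in the normal range (at least 2^-6). `round_to_e4m3` and `max_relative_error` are hypothetical helpers that emulate E4M3 round-to-nearest-even in pure Python; subnormals are not modeled, which is why they are excluded.

```python
import math

E4M3_MAX = 448.0
E4M3_MIN_NORMAL = 2.0 ** -6


def round_to_e4m3(x: float) -> float:
    """Emulate float8_e4m3fn rounding for normal-range values (simplified)."""
    if x == 0.0:
        return 0.0
    x = max(-E4M3_MAX, min(E4M3_MAX, x))  # saturate like a clamped cast
    m, e = math.frexp(x)                  # x = m * 2**e, 0.5 <= |m| < 1
    m = round(m * 16) / 16                # 1 implicit + 3 mantissa bits, ties to even
    return math.ldexp(m, e)


def max_relative_error(values) -> float:
    """Worst relative round-trip error for one block scaled so amax -> 448."""
    amax = max(abs(v) for v in values)
    scale_inv = amax / E4M3_MAX if amax > 0 else 1.0
    worst = 0.0
    for v in values:
        scaled = v / scale_inv
        if abs(scaled) < E4M3_MIN_NORMAL:
            continue  # subnormal range is not modeled here
        back = round_to_e4m3(scaled) * scale_inv
        worst = max(worst, abs(back - v) / abs(v))
    return worst
```

A test could then assert that the kernel's dequantized output stays within `abs(w) / 16` of the BF16 input for normal-range elements, with a looser or absolute bound for tiny values.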

Additional Information

  • Backward compatibility: SGLangConfig.quantization defaults to an empty string and the new WeightUpdateMeta fields default to None, so existing configs work unchanged.
  • SGLang startup: load_format stays "auto". SGLang loads the checkpoint normally at startup, and the first weight update overwrites the weights with FP8 values.
  • Rollout correction: not implemented in this PR; it can be added in a follow-up if needed.

Sequence Diagram

```mermaid
sequenceDiagram
    autonumber
    participant RT as RLTrainer
    participant F0 as FSDPEngine (Rank 0)
    participant Fn as FSDPEngine (Other Ranks)
    participant SG as SGLang Server

    %% Startup: Config path
    rect rgb(200, 255, 200)
        RT->>RT: read config.sglang.quantization="fp8"
        RT->>RT: WeightUpdateMeta.from_fsdp_xccl(quantization="fp8")
    end
    RT->>F0: connect_engine(weight_update_meta)
    rect rgb(200, 255, 200)
        RT->>SG: SGLangConfig.build_args(quantization="fp8")
        SG->>SG: start with --quantization=fp8
    end

    %% Training: Weight update path
    loop every weight_update_interval
        F0->>Fn: dist.barrier(cpu_group)

        loop _materialize_and_maybe_quantize()
            Note over F0,Fn: _get_full_tensor(param) (collective on all ranks)
            F0->>F0: _cast_to_compute_dtype()
            alt should_quantize_param(name) and dim==2
                rect rgb(200, 255, 200)
                    F0->>F0: scaled_fp8_blockwise(BF16) → (FP8, scale)
                    F0->>SG: broadcast(name, fp8_weight)
                    F0->>SG: broadcast(name.replace(".weight", ".weight_scale_inv"), scale)
                end
            else skip quantization
                F0->>SG: broadcast(name, BF16_tensor)
            end
        end

        F0->>Fn: dist.barrier(cpu_group)
        SG->>SG: update_weights_from_distributed()
    end
```