Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions areal/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,6 +1832,7 @@ class SGLangConfig:
cpu_offload_gb: int = 0
dtype: str = "bfloat16"
kv_cache_dtype: str = "auto"
quantization: str = ""
dp_size: int = 1 # only used for dp attention
ep_size: int = 1
# lora
Expand Down Expand Up @@ -1914,6 +1915,23 @@ def build_args(
)
args.pop("enable_multithread_load", None)

quantization = args.pop("quantization", "")
if quantization == "fp8":
args["quantization"] = "fp8"
fp8_quant_config = {
"quant_method": "fp8",
"activation_scheme": "dynamic",
"weight_block_size": [128, 128],
}
args["json_model_override_args"] = json.dumps(
{"quantization_config": fp8_quant_config},
separators=(",", ":"),
)
Comment on lines +1926 to +1929
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Directly overwriting json_model_override_args with the quantization config will discard any other user-defined model overrides that might already be present in args. It is safer to parse the existing overrides (if any), merge the new quantization config, and then serialize it back to JSON.

            existing_override = args.get("json_model_override_args")
            override_dict = {}
            if existing_override:
                try:
                    override_dict = json.loads(existing_override)
                except Exception:
                    pass
            override_dict["quantization_config"] = fp8_quant_config
            args["json_model_override_args"] = json.dumps(
                override_dict,
                separators=(",", ":"),
            )

elif quantization:
raise ValueError(
f"SGLangConfig.quantization must be 'fp8' or empty, got {quantization!r}"
)

args = dict(
# Model and tokenizer
tokenizer_path=sglang_config.model_path,
Expand Down
11 changes: 11 additions & 0 deletions areal/api/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ class WeightUpdateMeta:

version: int | None = None

quantization: str | None = None
quantization_config: dict | None = None

def with_version(self, version: int) -> "WeightUpdateMeta":
"""Return a copy of this meta with versioned path.

Expand Down Expand Up @@ -252,6 +255,8 @@ def from_megatron_xccl(
lora_name: str = "",
lora_int_id: int = 1,
base_model_name: str = "",
quantization: str | None = None,
quantization_config: dict | None = None,
):
return cls(
type="xccl",
Expand All @@ -261,6 +266,8 @@ def from_megatron_xccl(
lora_name=lora_name,
lora_int_id=lora_int_id,
base_model_name=base_model_name,
quantization=quantization,
quantization_config=quantization_config,
)

@classmethod
Expand All @@ -272,6 +279,8 @@ def from_fsdp_xccl(
lora_name: str = "",
lora_int_id: int = 1,
base_model_name: str = "",
quantization: str | None = None,
quantization_config: dict | None = None,
):
return cls(
type="xccl",
Expand All @@ -281,6 +290,8 @@ def from_fsdp_xccl(
lora_name=lora_name,
lora_int_id=lora_int_id,
base_model_name=base_model_name,
quantization=quantization,
quantization_config=quantization_config,
)

@classmethod
Expand Down
55 changes: 36 additions & 19 deletions areal/engine/fsdp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
)
from areal.utils.functional import gather_logprobs, gather_logprobs_entropy
from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer
from areal.utils.kernel.fp8_kernel import scaled_fp8_blockwise, should_quantize_param
from areal.utils.network import find_free_ports, format_host_for_url, gethostip
from areal.utils.offload import is_tms_enabled, torch_memory_saver
from areal.utils.perf_tracer import trace_perf, trace_scope
Expand Down Expand Up @@ -1595,30 +1596,46 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
named_tensors: list[tuple[str, torch.Tensor]] = []
pending_bucket: _PendingWeightUpdateBucket | None = None

if self.config.use_lora:
# For LoRA, only iterate over trainable LoRA parameters
param_iterator = (
(name, param)
for name, param in self._get_model_name_parameters(meta)
if param.requires_grad
def _materialize_and_maybe_quantize():
"""Materialize tensors from FSDP parameters and optionally quantize to FP8.

All ranks must call _get_full_tensor() because DTensor.full_tensor()
is a collective operation. Only rank 0 yields tensors for broadcast.
"""
_param_iterator = (
(
(name, param)
for name, param in self._get_model_name_parameters(meta)
if param.requires_grad
)
if self.config.use_lora
else self._get_model_name_parameters(meta)
)
else:
# For full model, iterate over all parameters
param_iterator = self._get_model_name_parameters(meta)

try:
for name, param in param_iterator:
# Ranks other than 0 only help to get the full tensor
# (DTensor.full_tensor() is a collective; all ranks must
# call _get_full_tensor). Only rank 0 broadcasts to the
# rollout engine, so casting is main-rank-only by design.
tensor = self._get_full_tensor(param)
q_config = (
meta.quantization_config or {} if meta.quantization == "fp8" else {}
)
block_size = q_config.get("weight_block_size", [128, 128])

for _name, _param in _param_iterator:
_tensor = self._get_full_tensor(_param)
if not main_rank:
continue
# Cast fp32 master storage to compute dtype before broadcast.
# Rollout engines (SGLang/vLLM) expect compute dtype (bf16).
tensor = self._cast_to_compute_dtype(tensor)
_tensor = self._cast_to_compute_dtype(_tensor)

if (
meta.quantization == "fp8"
and _tensor.dim() == 2
and should_quantize_param(_name)
):
fp8_weight, scale = scaled_fp8_blockwise(_tensor, block_size)
yield (_name, fp8_weight)
yield (_name.replace(".weight", ".weight_scale_inv"), scale)
else:
yield (_name, _tensor)

try:
for name, tensor in _materialize_and_maybe_quantize():
tensor_size = tensor.numel() * tensor.element_size()
bucket_overflow = (
buffer_size > 0
Expand Down
6 changes: 6 additions & 0 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1743,6 +1743,12 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta) -> None:

@trace_perf("megatron_engine.update_weights_from_distributed", category="comm")
def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None:
if meta.quantization == "fp8":
raise NotImplementedError(
"FP8 weight update is not yet supported for Megatron engine. "
"Use FSDP engine instead."
)

# Reset weight update meta with local info
meta.nccl_master_address = self.weight_update_master_addr
meta.nccl_master_port = self.weight_update_master_port
Expand Down
11 changes: 11 additions & 0 deletions areal/engine/megatron_utils/fp8/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
# - tensor_helper.py: FP8 blockwise tensor helper class
# - config.py: Configuration utilities for extracting block size from quantization config

try:
from areal.utils.kernel.fp8_kernel import (
scaled_fp8_blockwise,
should_quantize_param,
)
except ImportError:
pass

from areal.engine.megatron_utils.fp8.config import get_block_size_from_config
from areal.engine.megatron_utils.fp8.deepgemm import (
DEEPGEMM_BLACKWELL,
Expand Down Expand Up @@ -44,6 +52,9 @@
"quantize_params",
"dequantize_params",
"get_block_size_from_config",
# Unified shared kernel
"scaled_fp8_blockwise",
"should_quantize_param",
# Kernels
"blockwise_cast_to_fp8_triton",
"weight_dequant",
Expand Down
8 changes: 8 additions & 0 deletions areal/trainer/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,14 @@ def __init__(
}
)

# Propagate the SGLang quantization setting (and, for FP8, its block-size config) to the weight update meta
if hasattr(config, "sglang") and config.sglang.quantization:
xccl_kwargs["quantization"] = config.sglang.quantization
if config.sglang.quantization == "fp8":
xccl_kwargs["quantization_config"] = {
"weight_block_size": [128, 128],
}

if self.actor_alloc.backend == "megatron":
self.weight_update_meta = WeightUpdateMeta.from_megatron_xccl(
**xccl_kwargs
Expand Down
5 changes: 5 additions & 0 deletions areal/utils/kernel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

# Kernel utilities
#
# Shared Triton/PyTorch kernels used across engines.
Loading
Loading