diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 7c53445722..07a6fac07c 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -1832,6 +1832,7 @@ class SGLangConfig: cpu_offload_gb: int = 0 dtype: str = "bfloat16" kv_cache_dtype: str = "auto" + quantization: str = "" dp_size: int = 1 # only used for dp attention ep_size: int = 1 # lora @@ -1914,6 +1915,23 @@ def build_args( ) args.pop("enable_multithread_load", None) + quantization = args.pop("quantization", "") + if quantization == "fp8": + args["quantization"] = "fp8" + fp8_quant_config = { + "quant_method": "fp8", + "activation_scheme": "dynamic", + "weight_block_size": [128, 128], + } + args["json_model_override_args"] = json.dumps( + {"quantization_config": fp8_quant_config}, + separators=(",", ":"), + ) + elif quantization: + raise ValueError( + f"SGLangConfig.quantization must be 'fp8' or empty, got {quantization!r}" + ) + args = dict( # Model and tokenizer tokenizer_path=sglang_config.model_path, diff --git a/areal/api/io_struct.py b/areal/api/io_struct.py index 84d099831e..6e7c899987 100644 --- a/areal/api/io_struct.py +++ b/areal/api/io_struct.py @@ -200,6 +200,9 @@ class WeightUpdateMeta: version: int | None = None + quantization: str | None = None + quantization_config: dict | None = None + def with_version(self, version: int) -> "WeightUpdateMeta": """Return a copy of this meta with versioned path. @@ -252,6 +255,8 @@ def from_megatron_xccl( lora_name: str = "", lora_int_id: int = 1, base_model_name: str = "", + quantization: str | None = None, + quantization_config: dict | None = None, ): return cls( type="xccl", @@ -261,6 +266,8 @@ def from_megatron_xccl( lora_name=lora_name, lora_int_id=lora_int_id, base_model_name=base_model_name, + quantization=quantization, + quantization_config=quantization_config, ) @classmethod @@ -272,6 +279,8 @@ def from_fsdp_xccl( lora_name: str = "", lora_int_id: int = 1, base_model_name: str = "", + quantization: str | None = None, + quantization_config: dict | None = None, ): return cls( type="xccl", @@ -281,6 +290,8 @@ def from_fsdp_xccl( lora_name=lora_name, lora_int_id=lora_int_id, base_model_name=base_model_name, + quantization=quantization, + quantization_config=quantization_config, ) @classmethod diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index b17bcf6967..f8e2e2bd3f 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -128,6 +128,7 @@ ) from areal.utils.functional import gather_logprobs, gather_logprobs_entropy from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer +from areal.utils.kernel.fp8_kernel import scaled_fp8_blockwise, should_quantize_param from areal.utils.network import find_free_ports, format_host_for_url, gethostip from areal.utils.offload import is_tms_enabled, torch_memory_saver from areal.utils.perf_tracer import trace_perf, trace_scope @@ -1595,30 +1596,46 @@ def _update_weights_from_distributed(self, meta: WeightUpdateMeta): named_tensors: list[tuple[str, torch.Tensor]] = [] pending_bucket: _PendingWeightUpdateBucket | None = None - if self.config.use_lora: - # For LoRA, only iterate over trainable LoRA parameters - param_iterator = ( - (name, param) - for name, param in self._get_model_name_parameters(meta) - if param.requires_grad + def _materialize_and_maybe_quantize(): + """Materialize tensors from FSDP parameters and optionally quantize to FP8. + + All ranks must call _get_full_tensor() because DTensor.full_tensor() + is a collective operation. Only rank 0 yields tensors for broadcast. + """ + _param_iterator = ( + ( + (name, param) + for name, param in self._get_model_name_parameters(meta) + if param.requires_grad + ) + if self.config.use_lora + else self._get_model_name_parameters(meta) ) - else: - # For full model, iterate over all parameters - param_iterator = self._get_model_name_parameters(meta) - try: - for name, param in param_iterator: - # Ranks other than 0 only help to get the full tensor - # (DTensor.full_tensor() is a collective; all ranks must - # call _get_full_tensor). Only rank 0 broadcasts to the - # rollout engine, so casting is main-rank-only by design. - tensor = self._get_full_tensor(param) + q_config = ( + meta.quantization_config or {} if meta.quantization == "fp8" else {} + ) + block_size = q_config.get("weight_block_size", [128, 128]) + + for _name, _param in _param_iterator: + _tensor = self._get_full_tensor(_param) if not main_rank: continue - # Cast fp32 master storage to compute dtype before broadcast. - # Rollout engines (SGLang/vLLM) expect compute dtype (bf16). - tensor = self._cast_to_compute_dtype(tensor) + _tensor = self._cast_to_compute_dtype(_tensor) + + if ( + meta.quantization == "fp8" + and _tensor.dim() == 2 + and should_quantize_param(_name) + ): + fp8_weight, scale = scaled_fp8_blockwise(_tensor, block_size) + yield (_name, fp8_weight) + yield (_name.replace(".weight", ".weight_scale_inv"), scale) + else: + yield (_name, _tensor) + try: + for name, tensor in _materialize_and_maybe_quantize(): tensor_size = tensor.numel() * tensor.element_size() bucket_overflow = ( buffer_size > 0 diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index f972fd0834..2a76d11de9 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1743,6 +1743,12 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta) -> None: @trace_perf("megatron_engine.update_weights_from_distributed", category="comm") def _update_weights_from_distributed(self, meta: WeightUpdateMeta) -> None: + if meta.quantization == "fp8": + raise NotImplementedError( + "FP8 weight update is not yet supported for Megatron engine. " + "Use FSDP engine instead." + ) + # Reset weight weight meta with local info meta.nccl_master_address = self.weight_update_master_addr meta.nccl_master_port = self.weight_update_master_port diff --git a/areal/engine/megatron_utils/fp8/__init__.py b/areal/engine/megatron_utils/fp8/__init__.py index 84a088bacb..ea1e7d3b07 100644 --- a/areal/engine/megatron_utils/fp8/__init__.py +++ b/areal/engine/megatron_utils/fp8/__init__.py @@ -13,6 +13,14 @@ # - tensor_helper.py: FP8 blockwise tensor helper class # - config.py: Configuration utilities for extracting block size from quantization config +try: + from areal.utils.kernel.fp8_kernel import ( + scaled_fp8_blockwise, + should_quantize_param, + ) +except ImportError: + pass + from areal.engine.megatron_utils.fp8.config import get_block_size_from_config from areal.engine.megatron_utils.fp8.deepgemm import ( DEEPGEMM_BLACKWELL, @@ -44,6 +52,9 @@ "quantize_params", "dequantize_params", "get_block_size_from_config", + # Unified shared kernel + "scaled_fp8_blockwise", + "should_quantize_param", # Kernels "blockwise_cast_to_fp8_triton", "weight_dequant", diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index 73b98c21fc..e58dc1ec94 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -345,6 +345,14 @@ def __init__( } ) + # Propagate SGLang FP8 config to weight update meta + if hasattr(config, "sglang") and config.sglang.quantization: + xccl_kwargs["quantization"] = config.sglang.quantization + if config.sglang.quantization == "fp8": + xccl_kwargs["quantization_config"] = { + "weight_block_size": [128, 128], + } + if self.actor_alloc.backend == "megatron": self.weight_update_meta = WeightUpdateMeta.from_megatron_xccl( **xccl_kwargs diff --git a/areal/utils/kernel/__init__.py b/areal/utils/kernel/__init__.py new file mode 100644 index 0000000000..fcd2bac499 --- /dev/null +++ b/areal/utils/kernel/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Kernel utilities +# +# Shared Triton/PyTorch kernels used across engines. diff --git a/areal/utils/kernel/fp8_kernel.py b/areal/utils/kernel/fp8_kernel.py new file mode 100644 index 0000000000..2d92196464 --- /dev/null +++ b/areal/utils/kernel/fp8_kernel.py @@ -0,0 +1,297 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Unified FP8 block-wise quantization kernel. + +Compatible with SGLang/vLLM FP8 block-wise weight format. +Uses 128x128 blocks by default, e4m3fn dtype. + +Triton path: high-performance GPU kernel. +PyTorch fallback: pure-PyTorch, no Triton dependency. +""" + +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING + +import torch +import torch.nn.functional as F + +from areal.utils.math import ceil_div + +if TYPE_CHECKING: + pass + +logger = logging.getLogger("FP8Kernel") + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- +FP8_DTYPE = torch.float8_e4m3fn +FP8_MAX = torch.finfo(FP8_DTYPE).max # 448.0 +FP8_MIN = -FP8_MAX + +# --------------------------------------------------------------------------- +# Optional Triton +# --------------------------------------------------------------------------- +_TRITON_AVAILABLE = False +try: + import triton + import triton.language as tl + + _TRITON_AVAILABLE = True +except ImportError: + pass + + +# --------------------------------------------------------------------------- +# PyTorch fallback (always available) +# --------------------------------------------------------------------------- +def _scaled_fp8_blockwise_pytorch( + data_hp: torch.Tensor, + block_size: list[int] | tuple[int, int], +) -> tuple[torch.Tensor, torch.Tensor]: + """Pure-PyTorch block-wise FP8 quantization. + + Args: + data_hp: BF16/FP16 weight tensor, shape (M, N). + block_size: [block_m, block_n]. + + Returns: + (fp8_weight, scale) where scale.shape == (ceil(M/block_m), ceil(N/block_n)) + and scale = absmax / FP8_MAX. + """ + block_size0, block_size1 = block_size[0], block_size[1] + original_shape = data_hp.shape + + # Pad to multiples of block size + pad_dim0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0 + pad_dim1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1 + if pad_dim0 > 0 or pad_dim1 > 0: + data_hp = F.pad(data_hp, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0) + + max_dtype = FP8_MAX + padded_shape = data_hp.shape + blk_m = data_hp.shape[0] // block_size0 + blk_n = data_hp.shape[1] // block_size1 + + # Reshape to (blk_m, block_m, blk_n, block_n) -> permute -> flatten blocks + data_hp = data_hp.reshape(blk_m, block_size0, blk_n, block_size1) + data_hp = data_hp.permute(0, 2, 1, 3).contiguous() + data_hp = data_hp.to(torch.float32).flatten(start_dim=2) + + # Per-block absmax + max_abs = data_hp.abs().amax(dim=-1, keepdim=True) + scale_fp = torch.empty_like(max_abs) + torch.div(max_dtype, max_abs.clamp_min(1e-10), out=scale_fp) + scale_fp = torch.where(max_abs == 0, torch.ones_like(scale_fp), scale_fp) + scale_fp = torch.where(max_abs.isinf(), torch.ones_like(scale_fp), scale_fp) + + descale_fp = torch.reciprocal(scale_fp) + data_hp.mul_(scale_fp) + data_hp.clamp_(min=-max_dtype, max=max_dtype) + + fp_data = data_hp.to(FP8_DTYPE) + + # Reshape back + fp_data = fp_data.reshape(blk_m, blk_n, block_size0, block_size1) + fp_data = fp_data.permute(0, 2, 1, 3).reshape(padded_shape) + + # Crop padding + if original_shape[0] != padded_shape[0] or original_shape[1] != padded_shape[1]: + fp_data = fp_data[: original_shape[0], : original_shape[1]].contiguous() + + return fp_data, descale_fp.squeeze(-1) + + +# --------------------------------------------------------------------------- +# Triton kernel (optional, preferred) +# --------------------------------------------------------------------------- +if _TRITON_AVAILABLE: + + @triton.jit + def _blockwise_cast_to_fp8_triton( + X, + Y, + S, + stride_xm, + stride_xn, + stride_ym, + stride_yn, + stride_sm, + stride_sn, + M, + N, + eps, + fp8_min, + fp8_max, + BLOCK_M: tl.constexpr = 128, + BLOCK_N: tl.constexpr = 128, + ): + pid_m = tl.cast(tl.program_id(axis=0), tl.int64) + pid_n = tl.cast(tl.program_id(axis=1), tl.int64) + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = off_m < M + mask_n = off_n < N + mask = mask_m[:, None] & mask_n[None, :] + + x = tl.load( + X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn, + mask=mask, + other=0.0, + ).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(x)), eps) + x_s = _absmax / fp8_max + s_inv = 1.0 / x_s + y_q = tl.clamp(x * s_inv, fp8_min, fp8_max).to(Y.dtype.element_ty) + + tl.store( + Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn, y_q, mask=mask + ) + tl.store(S + pid_m * stride_sm + pid_n * stride_sn, x_s) + + def _blockwise_cast_to_fp8_triton_wrapper( + x: torch.Tensor, + block_size: list[int] | tuple[int, int], + ) -> tuple[torch.Tensor, torch.Tensor]: + BLOCK_M, BLOCK_N = block_size[0], block_size[1] + M, N = x.shape + y = torch.empty(M, N, device=x.device, dtype=FP8_DTYPE) + s = torch.empty( + ceil_div(M, BLOCK_M), + ceil_div(N, BLOCK_N), + dtype=torch.float32, + device=x.device, + ) + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"])) + + kwargs = { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "num_warps": 8 if x.is_contiguous() else 1, + "num_stages": 2 if x.is_contiguous() else 4, + } + _blockwise_cast_to_fp8_triton[grid]( + x, + y, + s, + *x.stride(), + *y.stride(), + *s.stride(), + M, + N, + 1e-10, + FP8_MIN, + FP8_MAX, + **kwargs, + ) + return y, s + + +# --------------------------------------------------------------------------- +# Unified public API +# --------------------------------------------------------------------------- +def scaled_fp8_blockwise( + data_hp: torch.Tensor, + weight_block_size: list[int] | tuple[int, int] | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Cast a 2D tensor to FP8 with block-wise quantization. + + Args: + data_hp: Input tensor of shape (M, N). Must be 2D. + weight_block_size: Block size as [BLOCK_M, BLOCK_N]. + Defaults to [128, 128]. + + Returns: + Tuple of (fp8_data, scale): + - fp8_data: FP8 quantized tensor of original shape. + - scale: Per-block scale factors of shape + (ceil(M/BLOCK_M), ceil(N/BLOCK_N)). + scale = absmax / FP8_MAX. Dequantize with: weight * scale. + """ + assert len(data_hp.shape) == 2, f"Only 2D input supported, got {data_hp.shape}" + + if weight_block_size is None: + weight_block_size = [128, 128] + + use_triton = ( + _TRITON_AVAILABLE + and os.environ.get("DISABLE_TRITON_FP8", "0") != "1" + and data_hp.device.type == "cuda" + ) + if use_triton: + # Triton path with auto-padding + block_size0, block_size1 = weight_block_size[0], weight_block_size[1] + original_shape = data_hp.shape + pad_dim0 = (block_size0 - data_hp.shape[0] % block_size0) % block_size0 + pad_dim1 = (block_size1 - data_hp.shape[1] % block_size1) % block_size1 + + if pad_dim0 > 0 or pad_dim1 > 0: + data_hp = F.pad( + data_hp, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0 + ) + + fp_data, scale = _blockwise_cast_to_fp8_triton_wrapper( + data_hp, weight_block_size + ) + + if pad_dim0 > 0 or pad_dim1 > 0: + fp_data = fp_data[: original_shape[0], : original_shape[1]].contiguous() + + return fp_data, scale + + # PyTorch fallback + logger.debug("Triton unavailable or disabled, using PyTorch fallback for FP8 quant") + return _scaled_fp8_blockwise_pytorch(data_hp, weight_block_size) + + +# --------------------------------------------------------------------------- +# Parameter filtering (which layers to quantize) +# --------------------------------------------------------------------------- +def should_quantize_param(param_name: str) -> bool: + """Determine whether a parameter should be quantized to FP8. + + Matches SGLang's FP8 quantization rules. Only Linear weight layers + are quantized; embeddings, norms, and output heads are skipped. + """ + if not param_name.endswith(".weight"): + return False + + param_lower = param_name.lower() + + # Exclude patterns + exclude_patterns = [ + "embed_tokens", + "lm_head", + "layernorm", + "norm", + "ln_", + "embeddings", + "mlp.gate.weight", # MoE router + ] + for pattern in exclude_patterns: + if pattern in param_lower: + return False + + # Include patterns (Linear layers) + include_patterns = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + "fc1", + "fc2", + "mlp", + ] + for pattern in include_patterns: + if pattern in param_lower: + return True + + return False diff --git a/examples/quantization/README.md b/examples/quantization/README.md new file mode 100644 index 0000000000..e1141b8495 --- /dev/null +++ b/examples/quantization/README.md @@ -0,0 +1,44 @@ +# Quantization Examples + +FSDP BF16 Training + SGLang FP8 Rollout + +## Overview + +These configs demonstrate online FP8 block-wise quantization for SGLang inference rollout while keeping FSDP training in BF16. The training engine quantizes BF16 weights to FP8 (128x128 blocks, e4m3fn) before NCCL broadcast to SGLang. + +## Configs + +| Config | Engine | Task | Quantization | +|--------|--------|------|-------------| +| `fsdp_math_grpo_fp8.yaml` | FSDP | GSM8K math (GRPO) | FP8 block-wise | + +## How It Works + +1. **Training**: FSDPEngine keeps weights in BF16, computes gradients in BF16 +2. **Weight sync**: Before each weight update broadcast, FSDPEngine all-gathers sharded weights, then quantizes eligible 2D Linear layers to FP8 with per-128x128-block scales +3. **Broadcast**: `fp8_weight` (float8_e4m3fn) and `weight_scale_inv` (float32) are broadcast separately via NCCL +4. **Rollout**: SGLang receives FP8 weights and scales, uses them directly for block-wise FP8 GEMM + +## Parameters Filtered + +Quantized: `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`, `fc1`, `fc2` + +Skipped: `embed_tokens`, `lm_head`, `layernorm`, `norm`, `ln_`, `embeddings`, `mlp.gate.weight` (MoE router) + +## Usage + +```bash +python -m areal.examples.math.gsm8k_rl \ + --config examples/quantization/fsdp_math_grpo_fp8.yaml +``` + +## Requirements + +- SGLang with FP8 support (`--quantization=fp8`) +- CUDA GPU with FP8 compute capability (SM89+) +- `sglang.quantization: fp8` in config + +## See Also + +- `areal/utils/kernel/fp8_kernel.py` - Unified FP8 quantization kernel +- `docs/superpowers/plans/2026-05-30-fsdp-sglang-fp8-rollout-proposal.md` - Design proposal diff --git a/examples/quantization/fsdp_math_grpo_fp8.yaml b/examples/quantization/fsdp_math_grpo_fp8.yaml new file mode 100644 index 0000000000..2924ce7fe6 --- /dev/null +++ b/examples/quantization/fsdp_math_grpo_fp8.yaml @@ -0,0 +1,186 @@ +experiment_name: fsdp-math-grpo-fp8 +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + + +scheduler: + type: null + +rollout: + backend: "sglang:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: false + agent: + mode: inline + export_style: individual + turn_discount: 1.0 + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + max_tokens: 2048 + greedy: false + temperature: 1.0 + +actor: + backend: "fsdp:d4p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + packing_algorithm: ffd + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + metric: ratio + upper: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + weight_update_mode: xccl + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + packing_algorithm: ffd + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + quantization: fp8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.8 + +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/tests/experimental/weight_update/test_fp8_nccl_integration.py b/tests/experimental/weight_update/test_fp8_nccl_integration.py new file mode 100644 index 0000000000..b69169d9c9 --- /dev/null +++ b/tests/experimental/weight_update/test_fp8_nccl_integration.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import subprocess +import sys + +import pytest +import torch + +from areal.infra.platforms import current_platform +from areal.infra.utils.proc import kill_process_tree +from areal.utils.network import find_free_ports + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" +) + +# Project root so that torchrun workers can resolve `from tests.*` imports. +# pytest adds "." via pyproject.toml `pythonpath`, but subprocesses don't inherit that. +_PROJECT_ROOT = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..") +) + + +def _run_weight_update_test(n_gpus: int, test_type: str, output: str): + port = find_free_ports(1)[0] + env = os.environ.copy() + env["PYTHONPATH"] = _PROJECT_ROOT + os.pathsep + env.get("PYTHONPATH", "") + proc = subprocess.Popen( + [ + "torchrun", + f"--nproc_per_node={n_gpus}", + "--nnodes=1", + "--master-addr=localhost", + f"--master_port={port}", + "tests/experimental/weight_update/torchrun/run_fp8_weight_transfer.py", + f"--test_type={test_type}", + f"--output={output}", + ], + text=True, + stderr=sys.stdout, + stdout=sys.stdout, + env=env, + ) + try: + proc.wait() + except BaseException: + kill_process_tree(proc.pid) + raise + if proc.returncode != 0: + pytest.fail(f"torchrun exited with code {proc.returncode}") + + with open(output) as f: + result = f.read().strip() + assert result == "Passed", f"Test failed: {result}" + + +@pytest.mark.multi_gpu +@pytest.mark.slow +def test_fp8_weight_transfer_2gpu(tmp_path_factory): + """Test FP8 block-wise quantized weight transfer over NCCL with 2 GPUs. + + Rank 0 quantizes a BF16 weight to FP8, broadcasts the FP8 tensor and + per-block scale to rank 1, plus a non-quantized 1D norm weight. + Rank 1 verifies all tensors match exactly. + """ + if current_platform.device_count() < 2: + pytest.skip("This test requires 2 GPUs") + output = tmp_path_factory.mktemp("test_output") / "fp8_weight_transfer.out" + _run_weight_update_test(2, "fp8_weight_transfer", str(output)) diff --git a/tests/experimental/weight_update/test_fp8_weight_sync_unit.py b/tests/experimental/weight_update/test_fp8_weight_sync_unit.py new file mode 100644 index 0000000000..ab25117c02 --- /dev/null +++ b/tests/experimental/weight_update/test_fp8_weight_sync_unit.py @@ -0,0 +1,382 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for FP8 weight synchronization logic. + +Tests generator yield order, bucket assembly, and ParamSpec generation +without requiring GPU or distributed environment. +""" + +from __future__ import annotations + +import torch + +from areal.api import ParamSpec +from areal.api.alloc_mode import ParallelStrategy +from areal.api.cli_args import SchedulingStrategy +from areal.api.io_struct import WeightUpdateMeta +from areal.utils.kernel.fp8_kernel import scaled_fp8_blockwise, should_quantize_param + +# --------------------------------------------------------------------------- +# Standalone helpers matching FSDPEngine logic +# --------------------------------------------------------------------------- + + +def _materialize_and_maybe_quantize(params, quantization, block_size, main_rank): + """Standalone version of the generator for unit testing.""" + for name, tensor in params: + if not main_rank: + continue + if quantization == "fp8" and tensor.dim() == 2 and should_quantize_param(name): + fp8_weight, scale = scaled_fp8_blockwise(tensor, block_size) + yield (name, fp8_weight) + yield (name.replace(".weight", ".weight_scale_inv"), scale) + else: + yield (name, tensor) + + +def _assemble_buckets(generator, chunk_size_mb): + """Assemble yield items into buckets based on memory limit. + + Mirrors the bucket assembly logic in + FSDPEngine._update_weights_from_distributed. + """ + chunk_size = chunk_size_mb * 1024 * 1024 + buffer_size = 0 + named_tensors = [] + buckets = [] + + for name, tensor in generator: + tensor_size = tensor.numel() * tensor.element_size() + bucket_overflow = buffer_size > 0 and tensor_size + buffer_size > chunk_size + if bucket_overflow: + buckets.append(named_tensors) + named_tensors = [] + buffer_size = 0 + buffer_size += tensor_size + named_tensors.append((name, tensor)) + + if named_tensors: + buckets.append(named_tensors) + + return buckets + + +def _build_param_specs(named_tensors): + """Build ParamSpec list from named tensors. + + Mirrors the logic in _update_bucket_weights_from_distributed_async. + """ + return [ + ParamSpec( + name=name, + shape=tuple(tensor.shape), + dtype=str(tensor.dtype).split("torch.")[1], + ) + for name, tensor in named_tensors + ] + + +# --------------------------------------------------------------------------- +# Test generator yield behavior +# --------------------------------------------------------------------------- + + +class TestMaterializeAndMaybeQuantize: + """Tests for _materialize_and_maybe_quantize generator.""" + + def test_yield_order_weight_then_scale(self): + """Eligible 2D weight yields (weight, scale) in that order.""" + params = [ + ( + "model.layers.0.self_attn.q_proj.weight", + torch.randn(256, 512, dtype=torch.bfloat16), + ), + ] + result = list( + _materialize_and_maybe_quantize(params, "fp8", [128, 128], main_rank=True) + ) + + assert len(result) == 2 + assert result[0][0] == "model.layers.0.self_attn.q_proj.weight" + assert result[0][1].dtype == torch.float8_e4m3fn + assert result[1][0] == "model.layers.0.self_attn.q_proj.weight_scale_inv" + assert result[1][1].dtype == torch.float32 + + def test_skip_non_2d_tensor(self): + """1D tensors (bias, norm) are not quantized.""" + params = [ + ( + "model.layers.0.input_layernorm.weight", + torch.randn(256, dtype=torch.bfloat16), + ), + ] + result = list( + _materialize_and_maybe_quantize(params, "fp8", [128, 128], main_rank=True) + ) + + assert len(result) == 1 + assert result[0][1].dtype == torch.bfloat16 + + def test_skip_embedding(self): + """Embedding weights are not quantized.""" + params = [ + ("model.embed_tokens.weight", torch.randn(1000, 256, dtype=torch.bfloat16)), + ] + result = list( + _materialize_and_maybe_quantize(params, "fp8", [128, 128], main_rank=True) + ) + + assert len(result) == 1 + assert result[0][1].dtype == torch.bfloat16 + + def test_non_main_rank_yields_nothing(self): + """Non-main ranks do not yield tensors.""" + params = [ + ( + "model.layers.0.self_attn.q_proj.weight", + torch.randn(256, 512, dtype=torch.bfloat16), + ), + ] + result = list( + _materialize_and_maybe_quantize(params, "fp8", [128, 128], main_rank=False) + ) + + assert len(result) == 0 + + def test_no_quantization_passthrough(self): + """When quantization is None, all params pass through unchanged.""" + params = [ + ( + "model.layers.0.self_attn.q_proj.weight", + torch.randn(256, 512, dtype=torch.bfloat16), + ), + ( + "model.layers.0.input_layernorm.weight", + torch.randn(256, dtype=torch.bfloat16), + ), + ] + result = list( + _materialize_and_maybe_quantize(params, None, [128, 128], main_rank=True) + ) + + assert len(result) == 2 + assert result[0][1].dtype == torch.bfloat16 + assert result[1][1].dtype == torch.bfloat16 + + def test_mixed_quantizable_and_non_quantizable(self): + """Mixed params: some quantized, some pass-through.""" + params = [ + ( + "model.layers.0.self_attn.q_proj.weight", + torch.randn(256, 512, dtype=torch.bfloat16), + ), + ( + "model.layers.0.input_layernorm.weight", + torch.randn(256, dtype=torch.bfloat16), + ), + ( + "model.layers.0.mlp.gate_proj.weight", + torch.randn(256, 512, dtype=torch.bfloat16), + ), + ] + result = list( + _materialize_and_maybe_quantize(params, "fp8", [128, 128], main_rank=True) + ) + + assert len(result) == 5 # 2 weights + 2 scales + 1 norm + names = [r[0] for r in result] + assert "model.layers.0.self_attn.q_proj.weight" in names + assert "model.layers.0.self_attn.q_proj.weight_scale_inv" in names + assert "model.layers.0.mlp.gate_proj.weight" in names + assert "model.layers.0.mlp.gate_proj.weight_scale_inv" in names + assert "model.layers.0.input_layernorm.weight" in names + + +# --------------------------------------------------------------------------- +# Test bucket assembly +# --------------------------------------------------------------------------- + + +class TestBucketAssembly: + """Tests for _assemble_buckets memory-chunked grouping.""" + + def test_single_bucket(self): + """Small tensors all fit in one bucket.""" + params = [ + ("w1", torch.randn(256, 512, dtype=torch.bfloat16)), + ("s1", torch.randn(2, 4, dtype=torch.float32)), + ("w2", torch.randn(256, 512, dtype=torch.bfloat16)), + ("s2", torch.randn(2, 4, dtype=torch.float32)), + ] + buckets = _assemble_buckets(iter(params), chunk_size_mb=100) + + assert len(buckets) == 1 + assert len(buckets[0]) == 4 + + def test_weight_scale_paired_in_same_bucket(self): + """Weight and its scale naturally fit in same bucket.""" + params = [ + ("layers.0.q_proj.weight", torch.randn(256, 512).to(torch.float8_e4m3fn)), + ( + "layers.0.q_proj.weight_scale_inv", + torch.randn(2, 4, dtype=torch.float32), + ), + ] + buckets = _assemble_buckets(iter(params), chunk_size_mb=100) + + assert len(buckets) == 1 + names = [n for n, _ in buckets[0]] + assert "layers.0.q_proj.weight" in names + assert "layers.0.q_proj.weight_scale_inv" in names + + def test_weight_scale_split_when_weight_oversized(self): + """When single weight exceeds bucket, scale goes to next bucket.""" + large_weight = torch.randn(4096, 4096, dtype=torch.bfloat16) # ~32 MB + params = [ + ("large.weight", large_weight), + ("large.weight_scale_inv", torch.randn(32, 32, dtype=torch.float32)), + ] + buckets = _assemble_buckets(iter(params), chunk_size_mb=10) + + assert len(buckets) == 2 + assert buckets[0][0][0] == "large.weight" + assert buckets[1][0][0] == "large.weight_scale_inv" + + def test_multiple_buckets(self): + """Many tensors split across multiple buckets.""" + params = [] + for i in range(10): + params.append((f"w{i}", torch.randn(1024, 1024, dtype=torch.bfloat16))) + buckets = _assemble_buckets(iter(params), chunk_size_mb=1) + + assert len(buckets) > 1 + total_tensors = sum(len(b) for b in buckets) + assert total_tensors == 10 + + +# --------------------------------------------------------------------------- +# Test ParamSpec generation +# --------------------------------------------------------------------------- + + +class TestParamSpecGeneration: + """Tests for ParamSpec list generation from named tensors.""" + + def test_fp8_weight_dtype(self): + """FP8 weight ParamSpec has float8_e4m3fn dtype.""" + named_tensors = [ + ("q_proj.weight", torch.randn(256, 512).to(torch.float8_e4m3fn)), + ] + specs = _build_param_specs(named_tensors) + + assert len(specs) == 1 + assert specs[0].name == "q_proj.weight" + assert specs[0].dtype == "float8_e4m3fn" + assert specs[0].shape == (256, 512) + + def test_scale_dtype_float32(self): + """Scale ParamSpec has float32 dtype.""" + named_tensors = [ + ("q_proj.weight_scale_inv", torch.randn(2, 4, dtype=torch.float32)), + ] + specs = _build_param_specs(named_tensors) + + assert len(specs) == 1 + assert specs[0].dtype == "float32" + + def test_bf16_weight_dtype(self): + """Non-quantized weight ParamSpec has bfloat16 dtype.""" + named_tensors = [ + ("norm.weight", torch.randn(256, dtype=torch.bfloat16)), + ] + specs = _build_param_specs(named_tensors) + + assert specs[0].dtype == "bfloat16" + + def test_mixed_specs(self): + """Mixed FP8 + BF16 tensors produce correct spec list.""" + named_tensors = [ + ("q_proj.weight", torch.randn(256, 512).to(torch.float8_e4m3fn)), + ("q_proj.weight_scale_inv", torch.randn(2, 4, dtype=torch.float32)), + ("norm.weight", torch.randn(256, dtype=torch.bfloat16)), + ] + specs = _build_param_specs(named_tensors) + + assert len(specs) == 3 + dtypes = [s.dtype for s in specs] + assert "float8_e4m3fn" in dtypes + assert "float32" in dtypes + assert "bfloat16" in dtypes + + +# --------------------------------------------------------------------------- +# Test WeightUpdateMeta serialization +# --------------------------------------------------------------------------- + + +class TestWeightUpdateMetaSerialization: + """Tests for WeightUpdateMeta with quantization fields.""" + + def test_from_fsdp_xccl_with_quantization(self): + """from_fsdp_xccl preserves quantization fields.""" + from areal.api import ModelAllocation + + alloc = ModelAllocation( + backend="fsdp", + name="test", + parallel=ParallelStrategy(), + scheduling_strategy=SchedulingStrategy(), + ) + meta = WeightUpdateMeta.from_fsdp_xccl( + gen_allocation=alloc, + quantization="fp8", + quantization_config={"weight_block_size": [128, 128]}, + ) + + assert meta.quantization == "fp8" + assert meta.quantization_config == {"weight_block_size": [128, 128]} + + def test_from_megatron_xccl_with_quantization(self): + """from_megatron_xccl preserves quantization fields.""" + from areal.api import ModelAllocation + + alloc = ModelAllocation( + backend="megatron", + name="test", + parallel=ParallelStrategy(), + scheduling_strategy=SchedulingStrategy(), + ) + meta = WeightUpdateMeta.from_megatron_xccl( + gen_allocation=alloc, + quantization="fp8", + ) + + assert meta.quantization == "fp8" + + def test_with_version_preserves_quantization(self): + """with_version() copy preserves quantization fields.""" + from areal.api import ModelAllocation + + alloc = ModelAllocation( + backend="fsdp", + name="test", + parallel=ParallelStrategy(), + scheduling_strategy=SchedulingStrategy(), + ) + meta = WeightUpdateMeta.from_fsdp_xccl( + gen_allocation=alloc, + quantization="fp8", + quantization_config={"weight_block_size": [128, 128]}, + ) + meta_v2 = meta.with_version(2) + + assert meta_v2.quantization == "fp8" + assert meta_v2.quantization_config == {"weight_block_size": [128, 128]} + assert meta_v2.version == 2 + + def test_default_no_quantization(self): + """Default WeightUpdateMeta has no quantization.""" + meta = WeightUpdateMeta(type="xccl") + + assert meta.quantization is None + assert meta.quantization_config is None diff --git a/tests/experimental/weight_update/torchrun/run_fp8_weight_transfer.py b/tests/experimental/weight_update/torchrun/run_fp8_weight_transfer.py new file mode 100644 index 0000000000..0f8461b686 --- /dev/null +++ b/tests/experimental/weight_update/torchrun/run_fp8_weight_transfer.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os + +import torch +import torch.distributed as dist + +from tests.experimental.weight_update.torchrun.dist_utils import ( + print_rank0, + write_result, +) + +from areal.experimental.weight_update.nccl_group import init_weights_update_group +from areal.infra.platforms import current_platform +from areal.utils.kernel.fp8_kernel import scaled_fp8_blockwise + + +def run_fp8_weight_transfer(output=None): + """Test: FP8 block-wise quantized weight transfer from training to inference via NCCL. + + Rank 0 (training side) quantizes a BF16 weight to FP8 using block-wise + quantization, then broadcasts both the FP8 weight and the per-block scale + tensor to rank 1 (inference side) over a custom NCCL process group. + A non-quantized 1D norm weight is also broadcast and verified. + """ + rank = dist.get_rank() + world_size = dist.get_world_size() + + print_rank0("=== FP8 Weight Transfer Test ===") + + # Use a different port from the main group to avoid conflicts + master_addr = os.environ.get("MASTER_ADDR", "localhost") + from areal.utils.network import find_free_ports + + if rank == 0: + ports = find_free_ports(1) + port_tensor = torch.tensor(ports, dtype=torch.long, device=f"cuda:{rank}") + else: + port_tensor = torch.zeros(1, dtype=torch.long, device=f"cuda:{rank}") + dist.broadcast(port_tensor, src=0) + master_port = int(port_tensor[0].item()) + + # For this test: rank 0 = training, rank 1 = inference + is_inference = rank == 1 + + try: + group = init_weights_update_group( + master_address=master_addr, + master_port=master_port, + rank=rank, + world_size=world_size, + group_name="awex_test_fp8_transfer", + backend="nccl", + role="inference" if is_inference else "training", + ) + print_rank0(f" Group created successfully with {world_size} ranks") + + device = torch.device(f"cuda:{current_platform.current_device()}") + + # ------------------------------------------------------------------- + # Prepare tensors + # ------------------------------------------------------------------- + weight_shape = (256, 512) + block_size = [128, 128] + scale_shape = (2, 4) # ceil(256/128) x ceil(512/128) + norm_shape = (256,) + + if not is_inference: + # Training side: create deterministic BF16 weights and quantize + torch.manual_seed(42) + + # 2D weight tensor -> quantize to FP8 + q_proj_weight_bf16 = torch.randn( + weight_shape, dtype=torch.bfloat16, device=device + ) + fp8_weight, scale_inv = scaled_fp8_blockwise( + q_proj_weight_bf16, weight_block_size=block_size + ) + + # 1D norm weight -> NOT quantized, stays BF16 + norm_weight_bf16 = torch.randn( + norm_shape, dtype=torch.bfloat16, device=device + ) + + tensors_to_send = { + "layers.0.q_proj.weight": fp8_weight, + "layers.0.q_proj.weight_scale_inv": scale_inv, + "layers.0.norm.weight": norm_weight_bf16, + } + else: + # Inference side: create receive buffers with matching shapes/dtypes + tensors_to_send = { + "layers.0.q_proj.weight": torch.zeros( + weight_shape, dtype=torch.float8_e4m3fn, device=device + ), + "layers.0.q_proj.weight_scale_inv": torch.zeros( + scale_shape, dtype=torch.float32, device=device + ), + "layers.0.norm.weight": torch.zeros( + norm_shape, dtype=torch.bfloat16, device=device + ), + } + + # ------------------------------------------------------------------- + # Broadcast from rank 0 (training) to all other ranks (inference) + # ------------------------------------------------------------------- + for name in sorted(tensors_to_send.keys()): + tensor = tensors_to_send[name] + dist.broadcast(tensor, src=0, group=group) + + current_platform.synchronize() + dist.barrier(group=group) + + # ------------------------------------------------------------------- + # Verify: inference side checks received data matches expected + # ------------------------------------------------------------------- + success = True + if is_inference: + # Re-create expected values on rank 1 to compare + torch.manual_seed(42) + expected_q_proj = torch.randn( + weight_shape, dtype=torch.bfloat16, device=device + ) + expected_fp8, expected_scale = scaled_fp8_blockwise( + expected_q_proj, weight_block_size=block_size + ) + expected_norm = torch.randn(norm_shape, dtype=torch.bfloat16, device=device) + + # Verify FP8 weight + received_fp8 = tensors_to_send["layers.0.q_proj.weight"] + if not torch.equal(received_fp8, expected_fp8): + print_rank0( + " MISMATCH layers.0.q_proj.weight: FP8 weight does not match" + ) + success = False + else: + print_rank0( + f" OK layers.0.q_proj.weight: shape={list(received_fp8.shape)}, dtype={received_fp8.dtype}" + ) + + # Verify scale + received_scale = tensors_to_send["layers.0.q_proj.weight_scale_inv"] + if not torch.equal(received_scale, expected_scale): + print_rank0( + " MISMATCH layers.0.q_proj.weight_scale_inv: scale does not match" + ) + success = False + else: + print_rank0( + f" OK layers.0.q_proj.weight_scale_inv: shape={list(received_scale.shape)}, dtype={received_scale.dtype}" + ) + + # Verify norm weight (1D, BF16, not quantized) + received_norm = tensors_to_send["layers.0.norm.weight"] + if not torch.equal(received_norm, expected_norm): + print_rank0( + " MISMATCH layers.0.norm.weight: norm weight does not match" + ) + success = False + else: + print_rank0( + f" OK layers.0.norm.weight: shape={list(received_norm.shape)}, dtype={received_norm.dtype}" + ) + + print_rank0( + f" Rank {rank} verification: {'PASSED' if success else 'FAILED'}" + ) + + # All-reduce success flag so all ranks agree + success_tensor = torch.tensor( + [1 if success else 0], dtype=torch.int, device=device + ) + dist.all_reduce(success_tensor, op=dist.ReduceOp.MIN, group=group) + success = bool(success_tensor.item()) + + dist.destroy_process_group(group) + print_rank0(f" Overall: {'PASSED' if success else 'FAILED'}") + + except Exception as e: + print_rank0(f" FAILED: {e}") + import traceback + + traceback.print_exc() + success = False + + dist.barrier() + if rank == 0 and output: + write_result(output, success) + return success + + +TEST_REGISTRY = { + "fp8_weight_transfer": run_fp8_weight_transfer, +} + + +def main(): + parser = argparse.ArgumentParser(description="FP8 NCCL Weight Transfer Tests") + parser.add_argument( + "--test_type", + type=str, + required=True, + choices=list(TEST_REGISTRY.keys()), + ) + parser.add_argument("--output", type=str, default=None) + args = parser.parse_args() + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + torch.cuda.set_device(rank) + + print_rank0("=" * 60) + print_rank0(f"Running: {args.test_type}") + print_rank0("=" * 60) + + try: + test_fn = TEST_REGISTRY[args.test_type] + success = test_fn(args.output) + + dist.barrier() + if success: + print_rank0(f"\n{args.test_type}: PASSED") + else: + print_rank0(f"\n{args.test_type}: FAILED") + if rank == 0 and args.output: + write_result(args.output, False) + except Exception as e: + print(f"Rank {rank} failed: {e}") + import traceback + + traceback.print_exc() + if rank == 0 and args.output: + write_result(args.output, False) + raise + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/tests/utils/kernel/test_fp8_kernel.py b/tests/utils/kernel/test_fp8_kernel.py new file mode 100644 index 0000000000..85f7c82e1d --- /dev/null +++ b/tests/utils/kernel/test_fp8_kernel.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the FP8 block-wise quantization kernel.""" + +from __future__ import annotations + +import importlib.util +import os +import sys +import types + +import pytest +import torch + +# Force PyTorch fallback for Triton-incompatible GPUs (e.g. SM86). +os.environ["DISABLE_TRITON_FP8"] = "1" + +# --------------------------------------------------------------------------- +# Mock areal.utils.math before loading the kernel module +# --------------------------------------------------------------------------- +math_mod = types.ModuleType("areal.utils.math") +math_mod.ceil_div = lambda x, y: (x + y - 1) // y +sys.modules["areal"] = types.ModuleType("areal") +sys.modules["areal.utils"] = types.ModuleType("areal.utils") +sys.modules["areal.utils"].__path__ = [] +sys.modules["areal.utils.math"] = math_mod + +spec = importlib.util.spec_from_file_location( + "fp8_kernel", + "/F00120250029/lixiang_share/zengziyi_share/zengziyi/Research/Areal_sub/areal/utils/kernel/fp8_kernel.py", +) +fp8_mod = importlib.util.module_from_spec(spec) +spec.loader.exec_module(fp8_mod) + +scaled_fp8_blockwise = fp8_mod.scaled_fp8_blockwise +should_quantize_param = fp8_mod.should_quantize_param +FP8_MAX = fp8_mod.FP8_MAX + + +# --------------------------------------------------------------------------- +# TestShouldQuantizeParam +# --------------------------------------------------------------------------- +class TestShouldQuantizeParam: + """Tests for should_quantize_param().""" + + @pytest.mark.parametrize( + "param_name", + [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + "model.layers.0.self_attn.o_proj.weight", + "model.layers.0.mlp.gate_proj.weight", + "model.layers.0.mlp.up_proj.weight", + "model.layers.0.mlp.down_proj.weight", + ], + ) + def test_quantize_linear_layers(self, param_name: str) -> None: + """Linear projection weights should be quantized.""" + assert should_quantize_param(param_name) is True + + def test_skip_embedding(self) -> None: + """Embedding token weights should NOT be quantized.""" + assert should_quantize_param("model.embed_tokens.weight") is False + + def test_skip_lm_head(self) -> None: + """LM head weights should NOT be quantized.""" + assert should_quantize_param("lm_head.weight") is False + + @pytest.mark.parametrize( + "param_name", + [ + "model.layers.0.input_layernorm.weight", + "model.layers.0.post_attention_layernorm.weight", + "model.norm.weight", + ], + ) + def test_skip_norm(self, param_name: str) -> None: + """Normalization layer weights should NOT be quantized.""" + assert should_quantize_param(param_name) is False + + def test_skip_bias(self) -> None: + """Bias parameters (not .weight) should NOT be quantized.""" + assert should_quantize_param("model.layers.0.self_attn.q_proj.bias") is False + + def test_skip_moe_router(self) -> None: + """MoE router gate weights should NOT be quantized.""" + assert should_quantize_param("model.layers.0.mlp.gate.weight") is False + + +# --------------------------------------------------------------------------- +# TestScaledFp8Blockwise +# --------------------------------------------------------------------------- +class TestScaledFp8Blockwise: + """Tests for scaled_fp8_blockwise().""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_basic_quantization(self) -> None: + """256x512 BF16 -> fp8 shape (256,512), scale shape (2,4).""" + data = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") + fp8_data, scale = scaled_fp8_blockwise(data, weight_block_size=[128, 128]) + + assert fp8_data.shape == (256, 512) + assert fp8_data.dtype == torch.float8_e4m3fn + assert scale.shape == (2, 4) + assert scale.dtype == torch.float32 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_roundtrip_dequant_approximate(self) -> None: + """Quant then dequant; relative error should be < 5%.""" + torch.manual_seed(42) + data = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda") + fp8_data, scale = scaled_fp8_blockwise(data, weight_block_size=[128, 128]) + + # Dequantize: fp8_data * scale (per-block) + blk_m, blk_n = scale.shape + block_m = data.shape[0] // blk_m + block_n = data.shape[1] // blk_n + + dequant = torch.zeros_like(data, dtype=torch.float32) + fp8_f32 = fp8_data.to(torch.float32) + for i in range(blk_m): + for j in range(blk_n): + row_start = i * block_m + row_end = row_start + block_m + col_start = j * block_n + col_end = col_start + block_n + dequant[row_start:row_end, col_start:col_end] = ( + fp8_f32[row_start:row_end, col_start:col_end] * scale[i, j] + ) + + data_f32 = data.to(torch.float32) + rel_err = (dequant - data_f32).abs().mean() / data_f32.abs().mean() + assert rel_err < 0.05, f"Mean relative error {rel_err} >= 0.05" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_non_multiple_dimensions(self) -> None: + """100x300 (not multiple of 128) -> fp8 shape (100,300), scale shape (1,3).""" + data = torch.randn(100, 300, dtype=torch.bfloat16, device="cuda") + fp8_data, scale = scaled_fp8_blockwise(data, weight_block_size=[128, 128]) + + assert fp8_data.shape == (100, 300) + assert fp8_data.dtype == torch.float8_e4m3fn + assert scale.shape == (1, 3) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + def test_all_zeros(self) -> None: + """128x128 zeros -> scale should be 1.0.""" + data = torch.zeros(128, 128, dtype=torch.bfloat16, device="cuda") + fp8_data, scale = scaled_fp8_blockwise(data, weight_block_size=[128, 128]) + + assert fp8_data.shape == (128, 128) + assert scale.shape == (1, 1) + assert scale.item() == pytest.approx(1.0, abs=1e-5) + + def test_cpu_fallback(self) -> None: + """CPU BF16 tensor should work via PyTorch fallback.""" + data = torch.randn(128, 128, dtype=torch.bfloat16, device="cpu") + fp8_data, scale = scaled_fp8_blockwise(data, weight_block_size=[128, 128]) + + assert fp8_data.shape == (128, 128) + assert fp8_data.dtype == torch.float8_e4m3fn + assert scale.shape == (1, 1) + assert scale.dtype == torch.float32