From d142c21444807647e913d89248c4b98146138c59 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 21:40:40 +0800 Subject: [PATCH 01/12] feat(engine): lora --- areal/engine/fsdp_engine.py | 106 +++++- tests/test_lora_disk_sync.py | 485 +++++++++++++++++++++++++++ tests/test_lora_disk_sync_e2e.py | 139 ++++++++ tests/torchrun/run_lora_disk_sync.py | 293 ++++++++++++++++ 4 files changed, 1016 insertions(+), 7 deletions(-) create mode 100644 tests/test_lora_disk_sync.py create mode 100644 tests/test_lora_disk_sync_e2e.py create mode 100644 tests/torchrun/run_lora_disk_sync.py diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 0e9682bc0d..67f967b1a0 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -1630,7 +1630,20 @@ def _save_model_to_hf( tokenizer: PreTrainedTokenizerFast | None, processor: ProcessorMixin | None, ): - """Save model in HuggingFace format.""" + """Save model in HuggingFace format. + + For full models (``use_lora=False``), the entire ``state_dict`` is + written via :meth:`PreTrainedModel.save_pretrained`. + + For LoRA models (``use_lora=True``), only the trainable LoRA + adapter parameters are written, in the standard PEFT layout + (``adapter_model.safetensors`` + ``adapter_config.json``). This + is the format that SGLang's ``/load_lora_adapter`` endpoint + expects, so the existing disk-based weight update path + (``meta.use_lora=True`` -> ``build_disk_weight_update_requests`` + -> ``/load_lora_adapter``) becomes end-to-end functional without + any additional dispatch logic. 
+ """ if self.model is None: raise RuntimeError("Model not initialized") os.makedirs(path, exist_ok=True) @@ -1643,14 +1656,93 @@ def _save_model_to_hf( # save huggingface model on rank 0 if dist.get_rank() == 0: os.makedirs(path, exist_ok=True) - self.model.save_pretrained(path, state_dict=state_dict) - self.model_config.save_pretrained(path) - if tokenizer is not None and not self.config.use_lora: - tokenizer.save_pretrained(path) - if processor is not None and not self.config.use_lora: - processor.save_pretrained(path) + if self.config.use_lora: + self._save_lora_adapter_to_hf(path, state_dict) + else: + self.model.save_pretrained(path, state_dict=state_dict) + self.model_config.save_pretrained(path) + if tokenizer is not None: + tokenizer.save_pretrained(path) + if processor is not None: + processor.save_pretrained(path) dist.barrier(group=self.cpu_group) + def _save_lora_adapter_to_hf( + self, + path: str, + state_dict: dict[str, torch.Tensor], + ): + """Save only LoRA adapter weights in standard PEFT format. + + Filters ``state_dict`` for LoRA adapter tensors, strips the active + adapter name segment (``.default.``) from each key so the result + matches the layout produced by ``PeftModel.save_pretrained`` / + ``get_peft_model_state_dict`` (and hence what SGLang's + ``/load_lora_adapter`` expects), and writes: + + * ``adapter_model.safetensors`` -- adapter tensor file + * ``adapter_config.json`` -- PEFT-compatible LoRA config + + Called on rank 0 only. + """ + import json + + from safetensors.torch import save_file as safetensors_save_file + + # PEFT's named_parameters() / state_dict include the active adapter + # name in keys, e.g. "...lora_A.default.weight". The standard PEFT + # adapter file format (as produced by get_peft_model_state_dict / + # save_pretrained) strips this adapter name, yielding + # "...lora_A.weight". SGLang's /load_lora_adapter expects the + # stripped format. 
+ adapter_name = "default" # PEFT default adapter name + lora_keywords = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") + + adapter_state_dict: dict[str, torch.Tensor] = {} + for name, tensor in state_dict.items(): + if not any(kw in name for kw in lora_keywords): + continue + stripped = name.replace(f".{adapter_name}.", ".") + adapter_state_dict[stripped] = tensor.contiguous().cpu() + + if not adapter_state_dict: + raise RuntimeError( + "use_lora=True but no LoRA adapter parameters were found in " + "the model state_dict; check that PEFT wrapping was applied." + ) + + safetensors_save_file( + adapter_state_dict, os.path.join(path, "adapter_model.safetensors") + ) + + # Build a PEFT-compatible adapter_config.json + config = self.config + if not config.target_modules or config.target_modules == ["all-linear"]: + target_modules_val = "all-linear" + else: + target_modules_val = config.target_modules + + adapter_config = { + "peft_type": "LORA", + "auto_mapping": None, + "base_model_name_or_path": getattr(config, "path", "") or "", + "bias": "none", + "fan_in_fan_out": False, + "inference_mode": True, + "init_lora_weights": True, + "layers_to_transform": None, + "layers_pattern": None, + "lora_alpha": config.lora_alpha, + "lora_dropout": 0.0, + "modules_to_save": None, + "r": config.lora_rank, + "revision": None, + "target_modules": target_modules_val, + "task_type": "CAUSAL_LM", + } + with open(os.path.join(path, "adapter_config.json"), "w") as f: + json.dump(adapter_config, f, indent=2) + def _load_model_from_hf(self, path: str): """Load model from HuggingFace format.""" if dist.get_rank() == 0: diff --git a/tests/test_lora_disk_sync.py b/tests/test_lora_disk_sync.py new file mode 100644 index 0000000000..8ba91072f9 --- /dev/null +++ b/tests/test_lora_disk_sync.py @@ -0,0 +1,485 @@ +"""Unit tests for LoRA disk-based weight synchronization. 
+ +The disk-mode LoRA sync flow on FSDP + SGLang is: + +* Training side (FSDP): ``FSDPEngine._save_model_to_hf`` branches on + ``self.config.use_lora``. When ``use_lora=True`` it calls + ``_save_lora_adapter_to_hf`` which: + - filters the full state_dict for ``lora_A`` / ``lora_B`` / + ``lora_embedding_A`` / ``lora_embedding_B`` keys, + - strips the active-adapter segment ``.default.`` so the layout + matches what ``peft.PeftModel.save_pretrained`` would produce + (and what SGLang's ``/load_lora_adapter`` expects), + - writes ``adapter_model.safetensors`` + ``adapter_config.json``. + +* Inference side (SGLang): ``SGLangBackend.build_disk_weight_update_requests`` + routes ``meta.use_lora=True`` to ``HttpRequest("/load_lora_adapter", ...)`` + and the standard full-model branch to ``/update_weights_from_disk``. + +These unit tests exercise (a) the LoRA filtering / key normalisation +logic, (b) the ``WeightUpdateMeta`` schema, (c) the SGLang +request-building dispatch, and (d) the ``get_versioned_lora_name`` +utility. They are CPU-only and do not require any GPU or running +SGLang server. +""" + +import copy +import json +import os +from dataclasses import asdict + +import pytest +import torch + +from areal.api import ParamSpec, WeightUpdateMeta +from areal.api.cli_args import TrainEngineConfig +from areal.api.io_struct import get_versioned_lora_name + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +_LORA_KEYWORDS = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") + + +def _make_dummy_model_params() -> dict[str, torch.Tensor]: + """Return a dict simulating ``named_parameters`` of a LoRA-wrapped model. + + Keys follow the PEFT naming convention produced by + ``peft.get_peft_model`` on a HuggingFace transformer (i.e. they + contain the active adapter name segment ``.default.``). 
+ """ + return { + # Base model weights (non-LoRA) + "base_model.model.model.embed_tokens.weight": torch.randn(1000, 64), + "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.randn( + 64, 64 + ), + "base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight": torch.randn( + 64, 64 + ), + "base_model.model.lm_head.weight": torch.randn(1000, 64), + # LoRA adapter weights (with PEFT ".default." segment) + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.randn( + 8, 64 + ), + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight": torch.randn( + 64, 8 + ), + "base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight": torch.randn( + 8, 64 + ), + "base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight": torch.randn( + 64, 8 + ), + } + + +def _is_lora_param(name: str) -> bool: + return any(kw in name for kw in _LORA_KEYWORDS) + + +def _filter_lora_adapter_state( + params: dict[str, torch.Tensor], adapter_name: str = "default" +) -> dict[str, torch.Tensor]: + """Replicate ``FSDPEngine._save_lora_adapter_to_hf`` filtering logic. + + Selects only LoRA tensors and strips the ``.<adapter_name>.`` segment + so the resulting keys match the standard PEFT adapter file layout + (e.g. ``...lora_A.weight``), which is what SGLang's + ``/load_lora_adapter`` expects. 
+ """ + out: dict[str, torch.Tensor] = {} + for name, tensor in params.items(): + if not _is_lora_param(name): + continue + stripped = name.replace(f".{adapter_name}.", ".") + out[stripped] = tensor + return out + + +# --------------------------------------------------------------------------- +# Test: LoRA adapter filtering / key normalisation +# --------------------------------------------------------------------------- + + +class TestLoRAAdapterFiltering: + """Mirrors ``FSDPEngine._save_lora_adapter_to_hf`` selection logic.""" + + def test_filter_returns_only_lora(self): + params = _make_dummy_model_params() + adapter = _filter_lora_adapter_state(params) + assert len(adapter) == 4 + for name in adapter: + assert _is_lora_param(name) + + def test_filter_excludes_base(self): + params = _make_dummy_model_params() + adapter = _filter_lora_adapter_state(params) + for name in adapter: + assert "base_layer" not in name + + def test_filter_strips_default_segment(self): + """After filtering, ``.default.`` must be removed (PEFT format).""" + params = _make_dummy_model_params() + adapter = _filter_lora_adapter_state(params) + for name in adapter: + assert ".default." 
not in name + assert _is_lora_param(name) + assert name.endswith(".weight") + + def test_filter_empty_when_no_lora(self): + params = { + "model.embed_tokens.weight": torch.randn(100, 64), + "model.layers.0.self_attn.q_proj.weight": torch.randn(64, 64), + } + assert _filter_lora_adapter_state(params) == {} + + def test_lora_param_shapes_match_rank(self): + """LoRA A has shape (rank, in_features); B has (out_features, rank).""" + params = _make_dummy_model_params() + adapter = _filter_lora_adapter_state(params) + lora_rank = 8 + for name, tensor in adapter.items(): + if "lora_A" in name: + assert tensor.shape[0] == lora_rank + elif "lora_B" in name: + assert tensor.shape[1] == lora_rank + + def test_adapter_keys_match_load_lora_adapter_format(self): + """The stripped keys are exactly what SGLang's + ``/load_lora_adapter`` endpoint consumes when it reads the + ``adapter_model.safetensors`` next to ``adapter_config.json``. + """ + params = _make_dummy_model_params() + adapter = _filter_lora_adapter_state(params) + for k in adapter: + assert k.endswith(".lora_A.weight") or k.endswith(".lora_B.weight"), k + + +# --------------------------------------------------------------------------- +# Test: WeightUpdateMeta schema for LoRA disk sync +# --------------------------------------------------------------------------- + + +class TestWeightUpdateMetaForLoRA: + """Validate the LoRA-relevant fields on WeightUpdateMeta.""" + + def test_default_flags(self): + meta = WeightUpdateMeta(type="disk") + assert meta.use_lora is False + assert meta.lora_name == "" + assert meta.version is None + + def test_construct_lora_meta(self): + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="my-lora", + lora_int_id=1, + peft_config={"r": 16, "lora_alpha": 16}, + path="/tmp/weight_update", + ) + assert meta.use_lora is True + assert meta.lora_name == "my-lora" + assert meta.peft_config["r"] == 16 + + def test_asdict_round_trip(self): + meta = WeightUpdateMeta( + type="disk", 
+ use_lora=True, + lora_name="my-lora", + lora_int_id=1, + peft_config={"r": 16, "lora_alpha": 16}, + ) + d = asdict(meta) + assert d["type"] == "disk" + assert d["use_lora"] is True + assert d["lora_name"] == "my-lora" + assert d["peft_config"]["r"] == 16 + + def test_with_version_updates_path_and_version(self): + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="my-lora", + path="/tmp/checkpoints/weight_update", + ) + v3 = meta.with_version(3) + assert v3.version == 3 + assert v3.path is not None + assert v3.path.endswith("weight_update_v3") + assert meta.version is None + assert v3.use_lora is True + assert v3.lora_name == "my-lora" + + def test_with_version_rejects_negative(self): + meta = WeightUpdateMeta(type="disk", path="/tmp/wu") + with pytest.raises(ValueError, match="non-negative"): + meta.with_version(-1) + + def test_copy_preserves_lora_fields(self): + meta = WeightUpdateMeta( + type="disk", use_lora=True, lora_name="L", lora_int_id=42 + ) + meta_copy = copy.copy(meta) + assert meta_copy.use_lora is True + assert meta_copy.lora_name == "L" + assert meta_copy.lora_int_id == 42 + + +# --------------------------------------------------------------------------- +# Test: TrainEngineConfig +# --------------------------------------------------------------------------- + + +class TestTrainEngineConfigLoRA: + """LoRA fields on TrainEngineConfig. + + Enabling LoRA disk sync only requires ``use_lora=True`` and + ``weight_update_mode='disk'`` -- no additional flag is needed. 
+ """ + + def test_default_use_lora_false(self): + config = TrainEngineConfig( + experiment_name="test", trial_name="t", backend="fsdp:d1" + ) + assert config.use_lora is False + + def test_enable_lora_with_disk_mode(self): + config = TrainEngineConfig( + experiment_name="test", + trial_name="t", + backend="fsdp:d1", + use_lora=True, + lora_rank=16, + lora_alpha=32, + weight_update_mode="disk", + ) + assert config.use_lora is True + assert config.lora_rank == 16 + assert config.lora_alpha == 32 + assert config.weight_update_mode == "disk" + + +# --------------------------------------------------------------------------- +# Test: SGLang request-building dispatch +# --------------------------------------------------------------------------- + + +class TestSGLangBackendDispatch: + """Verify ``SGLangBackend`` builds the correct HTTP requests for each + ``WeightUpdateMeta`` shape. + """ + + def test_disk_lora_routes_to_load_lora_adapter(self): + from areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="test-lora", + version=2, + path="/tmp/lora_weights", + ) + requests = backend.build_disk_weight_update_requests(meta) + assert len(requests.requests) == 1 + req = requests.requests[0] + assert req.endpoint == "/load_lora_adapter" + assert req.payload["lora_name"] == "test-lora-v2" + assert req.payload["lora_path"] == "/tmp/lora_weights" + + def test_disk_full_model_routes_to_update_weights_from_disk(self): + from areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + meta = WeightUpdateMeta(type="disk", use_lora=False, path="/tmp/full_model") + requests = backend.build_disk_weight_update_requests(meta) + assert len(requests.requests) == 1 + req = requests.requests[0] + assert req.endpoint == "/update_weights_from_disk" + assert req.payload["model_path"] == "/tmp/full_model" + + def test_disk_lora_requires_lora_name(self): + from 
areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + meta = WeightUpdateMeta(type="disk", use_lora=True, version=0, path="/tmp") + with pytest.raises(ValueError, match="LoRA name"): + backend.build_disk_weight_update_requests(meta) + + def test_disk_lora_requires_version(self): + from areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="L", + version=None, + path="/tmp", + ) + with pytest.raises(ValueError, match="Version"): + backend.build_disk_weight_update_requests(meta) + + def test_distributed_rejects_lora(self): + from areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + meta = WeightUpdateMeta(type="xccl", use_lora=True) + with pytest.raises(ValueError, match="does not support LoRA"): + backend.build_distributed_weight_update_requests(meta, []) + + def test_distributed_full_model(self): + from areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + meta = WeightUpdateMeta( + type="xccl", use_lora=False, nccl_group_name="test_group" + ) + specs = [ + ParamSpec( + name="model.embed_tokens.weight", + shape=(1000, 64), + dtype="bfloat16", + ), + ] + requests = backend.build_distributed_weight_update_requests(meta, specs) + assert len(requests.requests) == 1 + req = requests.requests[0] + assert req.endpoint == "/update_weights_from_distributed" + assert "model.embed_tokens.weight" in req.payload["names"] + assert req.payload["group_name"] == "test_group" + + def test_generation_request_injects_lora_path(self): + from areal.api import ModelRequest + from areal.api.cli_args import GenerationHyperparameters + from areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + gconfig = GenerationHyperparameters(max_new_tokens=8, lora_name="my-lora") + req = ModelRequest(input_ids=[1, 2, 3], gconfig=gconfig) + http_req = backend.build_generation_request(req, with_lora=True, 
version=5) + assert http_req.endpoint == "/generate" + assert http_req.payload["lora_path"] == "my-lora-v5" + + def test_generation_request_without_lora(self): + from areal.api import ModelRequest + from areal.api.cli_args import GenerationHyperparameters + from areal.engine.sglang_remote import SGLangBackend + + backend = SGLangBackend() + gconfig = GenerationHyperparameters(max_new_tokens=8) + req = ModelRequest(input_ids=[1, 2, 3], gconfig=gconfig) + http_req = backend.build_generation_request( + req, with_lora=False, version=0 + ) + assert "lora_path" not in http_req.payload + + +# --------------------------------------------------------------------------- +# Test: get_versioned_lora_name utility +# --------------------------------------------------------------------------- + + +class TestGetVersionedLoraName: + def test_basic(self): + assert get_versioned_lora_name("lora-gsm8k", 1) == "lora-gsm8k-v1" + + def test_version_zero(self): + assert get_versioned_lora_name("my-lora", 0) == "my-lora-v0" + + def test_version_large(self): + assert get_versioned_lora_name("adapter", 999) == "adapter-v999" + + +# --------------------------------------------------------------------------- +# Test: ParamSpec +# --------------------------------------------------------------------------- + + +class TestParamSpecForLoRA: + def test_construct_from_lora_tensor(self): + spec = ParamSpec( + name="model.layers.0.self_attn.q_proj.lora_A.weight", + shape=(8, 64), + dtype="bfloat16", + ) + assert spec.name.endswith("lora_A.weight") + assert spec.shape == (8, 64) + + def test_size_bfloat16(self): + spec = ParamSpec(name="t", shape=(8, 64), dtype="bfloat16") + assert spec.size == 1024 # 2 bytes * 512 elements + + def test_size_float32(self): + spec = ParamSpec(name="t", shape=(16, 32), dtype="float32") + assert spec.size == 2048 # 4 bytes * 512 elements + + +# --------------------------------------------------------------------------- +# Test: end-to-end disk-sync handshake (offline 
simulation) +# --------------------------------------------------------------------------- + + +class TestDiskSyncHandshake: + """Simulate the end-to-end disk-mode LoRA sync without networking. + + The training side writes ``adapter_model.safetensors`` + + ``adapter_config.json`` under ``meta.path``, and the inference side + builds an HTTP request whose ``lora_path`` points to that same + directory. + """ + + def test_meta_path_is_carried_into_load_lora_adapter_payload(self, tmp_path): + from areal.engine.sglang_remote import SGLangBackend + + # 1. Training side writes the (mock) adapter files. + adapter_dir = tmp_path / "weight_update_v1" + adapter_dir.mkdir() + (adapter_dir / "adapter_model.safetensors").write_bytes(b"") + with open(adapter_dir / "adapter_config.json", "w") as f: + json.dump({"peft_type": "LORA", "r": 8, "lora_alpha": 16}, f) + + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="my-lora", + version=1, + path=str(adapter_dir), + ) + + # 2. Inference side translates the meta into an HTTP request. + requests = SGLangBackend().build_disk_weight_update_requests(meta) + + # 3. Verify the request points at the directory the training side + # just wrote. 
+ req = requests.requests[0] + assert req.endpoint == "/load_lora_adapter" + assert req.payload["lora_path"] == str(adapter_dir) + assert os.path.exists( + os.path.join(req.payload["lora_path"], "adapter_config.json") + ) + assert os.path.exists( + os.path.join(req.payload["lora_path"], "adapter_model.safetensors") + ) + assert req.payload["lora_name"] == "my-lora-v1" + + def test_with_version_changes_path_and_lora_name_consistently(self, tmp_path): + """Across versions, the path and the lora_name must stay aligned.""" + from areal.engine.sglang_remote import SGLangBackend + + base_meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="my-lora", + path=str(tmp_path / "weight_update"), + ) + for v in [0, 1, 2, 7]: + m = base_meta.with_version(v) + req = SGLangBackend().build_disk_weight_update_requests(m).requests[0] + assert req.payload["lora_name"] == f"my-lora-v{v}" + assert req.payload["lora_path"].endswith(f"weight_update_v{v}") diff --git a/tests/test_lora_disk_sync_e2e.py b/tests/test_lora_disk_sync_e2e.py new file mode 100644 index 0000000000..4115ad5c00 --- /dev/null +++ b/tests/test_lora_disk_sync_e2e.py @@ -0,0 +1,139 @@ +"""End-to-end test for LoRA disk-based weight synchronization. + +This test launches ``torchrun`` with the helper script +``tests/torchrun/run_lora_disk_sync.py`` which: + +1. Creates a small Qwen3-0.6B model with FSDP + LoRA + ``weight_update_mode="disk"``. +2. Writes a PEFT-format adapter checkpoint via + ``FSDPEngine._save_model_to_hf`` and verifies the on-disk artefacts. +3. Verifies the model can still run a forward pass after the save. + +The test requires GPUs and is marked ``@pytest.mark.slow`` and +``@pytest.mark.sglang`` following AReaL test conventions. 
+ +Usage: + pytest tests/test_lora_disk_sync_e2e.py -v +""" + +import subprocess + +import pytest + +from areal.api.alloc_mode import ModelAllocation +from areal.infra.platforms import current_platform +from areal.utils.network import find_free_ports + + +def _run_torchrun_test(alloc_mode: str, output: str, n_gpus: int | None = None): + """Launch the torchrun helper script and check the result. + + Parameters + ---------- + alloc_mode : str + Backend allocation string, e.g. ``"fsdp:d1t1"``. + output : str + Path to the result file that the torchrun script writes. + n_gpus : int, optional + Override number of GPUs. If ``None``, it is derived from + ``alloc_mode``. + """ + port = find_free_ports(1)[0] + if n_gpus is None: + n_gpus = ModelAllocation.from_str(alloc_mode).parallel.world_size + + cmd = [ + "torchrun", + f"--nproc_per_node={n_gpus}", + "--nnodes=1", + "--master-addr=localhost", + f"--master_port={port}", + "tests/torchrun/run_lora_disk_sync.py", + f"--backend={alloc_mode}", + f"--output={output}", + ] + + try: + result = subprocess.run( + cmd, + check=True, + capture_output=True, + text=True, + timeout=300, + ) + print(result.stdout) + except subprocess.CalledProcessError as e: + pytest.fail( + f"torchrun failed with exit code {e.returncode}.\n" + f"stdout:\n{e.stdout}\n" + f"stderr:\n{e.stderr}" + ) + except subprocess.TimeoutExpired: + pytest.fail("torchrun timed out after 300 seconds.") + + with open(output) as f: + result_text = f.read().strip() + assert result_text == "Passed", f"Test failed: {result_text}" + + +# --------------------------------------------------------------------------- +# Single-GPU test (dp=1, tp=1) +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +@pytest.mark.sglang +def test_lora_disk_sync_single_gpu(tmp_path_factory): + """Test LoRA disk sync with 1 GPU using FSDP engine.""" + if current_platform.device_count() < 1: + pytest.skip("Test requires at least 1 GPU") + + output = 
tmp_path_factory.mktemp("test_output") / "lora_disk_sync_1gpu.out" + _run_torchrun_test("fsdp:d1t1", str(output), n_gpus=1) + + +# --------------------------------------------------------------------------- +# Multi-GPU test (dp=2, tp=1) +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +@pytest.mark.sglang +def test_lora_disk_sync_multi_gpu_dp2(tmp_path_factory): + """Test LoRA disk sync with 2 GPUs (data parallel = 2).""" + if current_platform.device_count() < 2: + pytest.skip("Test requires at least 2 GPUs") + + output = tmp_path_factory.mktemp("test_output") / "lora_disk_sync_dp2.out" + _run_torchrun_test("fsdp:d2t1", str(output), n_gpus=2) + + +# --------------------------------------------------------------------------- +# Multi-GPU test (dp=4, tp=1) +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +@pytest.mark.sglang +def test_lora_disk_sync_multi_gpu_dp4(tmp_path_factory): + """Test LoRA disk sync with 4 GPUs (data parallel = 4).""" + if current_platform.device_count() < 4: + pytest.skip("Test requires at least 4 GPUs") + + output = tmp_path_factory.mktemp("test_output") / "lora_disk_sync_dp4.out" + _run_torchrun_test("fsdp:d4t1", str(output), n_gpus=4) + + +# --------------------------------------------------------------------------- +# Multi-GPU test (dp=2, tp=2) -- 4 GPUs with tensor parallel +# --------------------------------------------------------------------------- + + +@pytest.mark.slow +@pytest.mark.sglang +def test_lora_disk_sync_multi_gpu_dp2_tp2(tmp_path_factory): + """Test LoRA disk sync with 4 GPUs (dp=2, tp=2).""" + if current_platform.device_count() < 4: + pytest.skip("Test requires at least 4 GPUs") + + output = tmp_path_factory.mktemp("test_output") / "lora_disk_sync_dp2_tp2.out" + _run_torchrun_test("fsdp:d2t2", str(output), n_gpus=4) diff --git a/tests/torchrun/run_lora_disk_sync.py b/tests/torchrun/run_lora_disk_sync.py new file 
mode 100644 index 0000000000..c1f950fe60 --- /dev/null +++ b/tests/torchrun/run_lora_disk_sync.py @@ -0,0 +1,293 @@ +"""Torchrun script for LoRA disk-sync end-to-end validation. + +This script is launched via ``torchrun`` from the e2e test +(``tests/test_lora_disk_sync_e2e.py``). It: + +1. Creates a small Qwen3-0.6B model with LoRA adapters on the FSDP engine. +2. Calls ``FSDPEngine._save_model_to_hf`` (the production code path used + by ``_update_weights_from_disk``) under ``use_lora=True``. +3. Validates the on-disk artefacts: + * ``adapter_model.safetensors`` exists and is non-empty. + * ``adapter_config.json`` exists, parses, and has the PEFT-required + fields (``peft_type='LORA'``, ``r``, ``lora_alpha``, + ``target_modules``). + * Every adapter tensor key matches the PEFT layout consumed by + SGLang's ``/load_lora_adapter`` (i.e. contains a LoRA keyword + and does NOT contain the active-adapter ``.default.`` segment). +4. Verifies the in-memory model still runs a forward pass. +5. Writes ``"Passed"`` / ``"Failed"`` to an output file (rank 0 only). + +Note: the new disk-mode LoRA sync path does not require any NCCL +process group on the inference side; SGLang loads the adapter directly +from disk via its existing ``/load_lora_adapter`` endpoint. This +script therefore exercises the *training-side* contract end-to-end and +leaves the SGLang HTTP request-building to the unit tests in +``tests/test_lora_disk_sync.py``. 
+ +Usage (invoked by the e2e test, not directly): + torchrun --nproc_per_node=N tests/torchrun/run_lora_disk_sync.py \ + --backend fsdp:d1t1 --output /tmp/result.out +""" + +import argparse +import json +import os +import tempfile + +import torch +import torch.distributed as dist +from safetensors.torch import load_file as safetensors_load_file + +from tests.utils import get_model_path + +from areal.api import FinetuneSpec +from areal.api.alloc_mode import ModelAllocation +from areal.api.cli_args import ( + FSDPEngineConfig, + MicroBatchSpec, + OptimizerConfig, + TrainEngineConfig, +) +from areal.engine import FSDPEngine +from areal.infra.platforms import current_platform + +MODEL_PATH = get_model_path( + "/storage/openpsi/models/Qwen__Qwen3-0.6B/", "Qwen/Qwen3-0.6B" +) + + +_LORA_KEYWORDS = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") + + +def write_result(path: str, success: bool) -> None: + with open(path, "w") as f: + f.write("Passed" if success else "Failed") + + +def make_fsdp_engine_with_lora(backend: str) -> FSDPEngine: + """Create an FSDPEngine with LoRA + disk-mode weight update.""" + config = TrainEngineConfig( + backend=backend, + experiment_name="test_lora_disk_sync", + trial_name="test", + mb_spec=MicroBatchSpec(max_tokens_per_mb=256), + path=MODEL_PATH, + optimizer=OptimizerConfig(), + fsdp=FSDPEngineConfig(memory_efficient_load=True), + # LoRA config + use_lora=True, + lora_rank=8, + lora_alpha=16, + peft_type="lora", + # Disk-based weight update mode (LoRA disk sync is implicit + # whenever use_lora=True and weight_update_mode="disk"). 
+ weight_update_mode="disk", + ) + alloc_mode = ModelAllocation.from_str(backend) + engine = FSDPEngine(config) + ft_spec = FinetuneSpec(total_train_epochs=1, dataset_size=128, train_batch_size=8) + engine.create_process_group(parallel_strategy=alloc_mode.parallel) + engine.initialize(None, ft_spec) + return engine + + +def count_lora_and_base_params(engine: FSDPEngine): + """Count LoRA vs base parameters in the engine model.""" + lora_count = 0 + base_count = 0 + for name, _param in engine.model.named_parameters(): + if any(kw in name for kw in _LORA_KEYWORDS): + lora_count += 1 + else: + base_count += 1 + return lora_count, base_count + + +def verify_forward_pass(engine: FSDPEngine) -> bool: + """Run a simple forward pass to verify the model is still functional.""" + try: + with torch.no_grad(): + engine.eval() + input_ids = torch.randint( + 100, 1000, (2, 32), dtype=torch.long, device=engine.device + ) + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + outputs = engine.model( + input_ids=input_ids, + attention_mask=attention_mask, + use_cache=False, + ) + logits = outputs.logits + return logits is not None and logits.shape[0] == 2 + except Exception as e: + print(f"Forward pass failed: {e}", flush=True) + return False + + +def verify_adapter_artifacts(adapter_dir: str, *, lora_rank: int, lora_alpha: int) -> bool: + """Validate the PEFT-format files written by ``_save_model_to_hf``.""" + safetensors_path = os.path.join(adapter_dir, "adapter_model.safetensors") + config_path = os.path.join(adapter_dir, "adapter_config.json") + + if not os.path.exists(safetensors_path): + print(f"ERROR: missing {safetensors_path}", flush=True) + return False + if os.path.getsize(safetensors_path) == 0: + print(f"ERROR: empty {safetensors_path}", flush=True) + return False + if not os.path.exists(config_path): + print(f"ERROR: missing {config_path}", flush=True) + return False + + # Validate adapter_config.json content. 
+ with open(config_path) as f: + cfg = json.load(f) + if cfg.get("peft_type") != "LORA": + print(f"ERROR: peft_type != LORA: {cfg.get('peft_type')}", flush=True) + return False + if cfg.get("r") != lora_rank: + print(f"ERROR: r mismatch: {cfg.get('r')} vs {lora_rank}", flush=True) + return False + if cfg.get("lora_alpha") != lora_alpha: + print( + f"ERROR: lora_alpha mismatch: {cfg.get('lora_alpha')} vs {lora_alpha}", + flush=True, + ) + return False + if "target_modules" not in cfg: + print("ERROR: target_modules missing from adapter_config.json", flush=True) + return False + + # Validate adapter_model.safetensors keys: every key must be a LoRA + # tensor with the ``.default.`` segment stripped. + state = safetensors_load_file(safetensors_path) + if not state: + print("ERROR: adapter_model.safetensors is empty", flush=True) + return False + for k in state: + if not any(kw in k for kw in _LORA_KEYWORDS): + print(f"ERROR: non-LoRA key found in adapter file: {k}", flush=True) + return False + if ".default." in k: + print(f"ERROR: '.default.' was not stripped from key: {k}", flush=True) + return False + return True + + +def test_lora_disk_sync(backend: str, output: str | None = None) -> None: + """Main test logic for LoRA disk sync.""" + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + success = True + + alloc_mode = ModelAllocation.from_str(backend) + dp = alloc_mode.parallel.dp + tp = alloc_mode.parallel.tp + + print( + f"[Rank {rank}] Starting LoRA disk sync test | " + f"backend={backend} | dp={dp} tp={tp} | world_size={world_size}", + flush=True, + ) + + # Step 1: Create engine with LoRA + disk mode. + print( + f"[Rank {rank}] Creating FSDP engine | dp={dp} tp={tp} " + f"use_lora=True weight_update_mode=disk", + flush=True, + ) + engine = make_fsdp_engine_with_lora(backend) + + # Step 2: Verify LoRA parameters exist on the in-memory model. 
+ lora_count, base_count = count_lora_and_base_params(engine) + print( + f"[Rank {rank}] Model has {lora_count} LoRA params and " + f"{base_count} base params", + flush=True, + ) + if lora_count == 0: + print(f"[Rank {rank}] ERROR: No LoRA parameters found!", flush=True) + success = False + + # Step 3: Verify config flags are correct. + if not engine.config.use_lora: + print(f"[Rank {rank}] ERROR: use_lora not set!", flush=True) + success = False + if engine.config.weight_update_mode != "disk": + print( + f"[Rank {rank}] ERROR: weight_update_mode != 'disk' " + f"(got {engine.config.weight_update_mode})", + flush=True, + ) + success = False + + # Step 4: Trigger the actual save path used by + # ``_update_weights_from_disk`` (no inference server required). + with tempfile.TemporaryDirectory() as tmpdir: + adapter_dir = os.path.join(tmpdir, "weight_update_v0") + os.makedirs(adapter_dir, exist_ok=True) + print( + f"[Rank {rank}] Calling _save_model_to_hf -> {adapter_dir}", flush=True + ) + engine._save_model_to_hf( + adapter_dir, + tokenizer=None, + processor=None, + ) + + # Only rank 0 writes the artefacts; verify there. + if rank == 0: + ok = verify_adapter_artifacts( + adapter_dir, + lora_rank=engine.config.lora_rank, + lora_alpha=engine.config.lora_alpha, + ) + if not ok: + success = False + else: + print( + f"[Rank {rank}] PEFT adapter artefacts validated", flush=True + ) + + # Step 5: Verify forward pass still works after the save. + print(f"[Rank {rank}] Verifying forward pass", flush=True) + if not verify_forward_pass(engine): + print(f"[Rank {rank}] ERROR: Forward pass failed!", flush=True) + success = False + else: + print(f"[Rank {rank}] Forward pass OK", flush=True) + + # Cleanup. 
+ current_platform.synchronize() + dist.barrier() + engine.destroy() + + if rank == 0 and output: + write_result(output, success) + + status = "PASSED" if success else "FAILED" + print(f"[Rank {rank}] LoRA disk sync test {status}", flush=True) + + +def main(): + parser = argparse.ArgumentParser( + description="Torchrun script for LoRA disk sync e2e test" + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Path to save test result (Passed/Failed)", + ) + parser.add_argument( + "--backend", + type=str, + default="fsdp:d1t1", + help="Backend allocation string (e.g., 'fsdp:d1t1')", + ) + args = parser.parse_args() + test_lora_disk_sync(backend=args.backend, output=args.output) + + +if __name__ == "__main__": + main() From c2ec2d4eec746729c7e575d16085ac0dbe7a8e79 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 22:01:20 +0800 Subject: [PATCH 02/12] feat(metrics): log size --- areal/engine/fsdp_engine.py | 47 +++++++++++++++++++++++++++++++++ areal/engine/sglang_remote.py | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 67f967b1a0..9abbc5fadc 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -1665,8 +1665,55 @@ def _save_model_to_hf( tokenizer.save_pretrained(path) if processor is not None: processor.save_pretrained(path) + # Record on-disk artefact size as a wandb metric so the + # full-model -> adapter-only size collapse across steps is + # visible. For LoRA disk sync the first checkpoint may + # still be a full-model warmup save; subsequent adapter-only + # saves are expected to be ~10-100x smaller. + self._log_disk_save_size(path) dist.barrier(group=self.cpu_group) + def _log_disk_save_size(self, path: str) -> None: + """Record the on-disk size of a checkpoint directory as a wandb metric. 
+ + Sums the bytes of every regular file under ``path`` (recursively) + and reports the total under ``weight_update_disk/*``. When + ``self.config.use_lora`` is True the size is also reported under + a dedicated ``lora_bytes`` key so the LoRA-only series is easy + to plot side-by-side with the full-model series. + + Called on rank 0 only; failures are swallowed because metric + emission must never break the training loop. + """ + try: + total_bytes = 0 + for root, _dirs, files in os.walk(path): + for fname in files: + fpath = os.path.join(root, fname) + try: + total_bytes += os.path.getsize(fpath) + except OSError: + # File may have been racily deleted; skip it. + continue + scope = "weight_update_disk" + if self.config.use_lora: + stats_tracker.get(scope).scalar( + lora_bytes=float(total_bytes), + bytes=float(total_bytes), + ) + else: + stats_tracker.get(scope).scalar( + full_bytes=float(total_bytes), + bytes=float(total_bytes), + ) + self.logger.info( + f"[weight_update_disk] saved {total_bytes} bytes to {path} " + f"(use_lora={self.config.use_lora})" + ) + except Exception as e: + # Metric emission must never break training. 
+ self.logger.warning(f"Failed to record disk-save size metric: {e}") + def _save_lora_adapter_to_hf( self, path: str, diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 52039df326..3eec92658f 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -143,9 +143,11 @@ def build_disk_weight_update_requests( payload={"lora_name": lora_name, "lora_path": str(meta.path)}, ) ] + self._log_disk_send_size(meta, use_lora=True) return WeightUpdateRequests(requests=requests) else: # Full model update + self._log_disk_send_size(meta, use_lora=False) return WeightUpdateRequests( requests=[ HttpRequest( @@ -158,6 +160,53 @@ def build_disk_weight_update_requests( ] ) + @staticmethod + def _log_disk_send_size(meta: WeightUpdateMeta, *, use_lora: bool) -> None: + """Record the size of weights the inference side will pull from disk. + + For disk-mode weight updates, the HTTP payload itself is just a + pointer (``model_path`` or ``lora_path``) -- the actual bytes + SGLang loads come from that on-disk directory. We surface the + aggregate file size here so wandb can show: + + * ``weight_update_send/lora_bytes`` for adapter-only sends, + * ``weight_update_send/full_bytes`` for full-model sends, + * ``weight_update_send/bytes`` for either, in a single + unified series. + + Failures are swallowed because metric emission must never break + the weight-update path. 
+ """ + path = meta.path + if path is None: + return + try: + total_bytes = 0 + if os.path.isdir(path): + for root, _dirs, files in os.walk(path): + for fname in files: + fpath = os.path.join(root, fname) + try: + total_bytes += os.path.getsize(fpath) + except OSError: + continue + elif os.path.isfile(path): + total_bytes = os.path.getsize(path) + scope = "weight_update_send" + if use_lora: + stats_tracker.get(scope).scalar( + lora_bytes=float(total_bytes), + bytes=float(total_bytes), + ) + else: + stats_tracker.get(scope).scalar( + full_bytes=float(total_bytes), + bytes=float(total_bytes), + ) + except Exception: + # Metric emission must never break the weight-update path. + pass + def build_distributed_weight_update_requests( self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> WeightUpdateRequests: From 23b28b0c195f08da11a8dcb4f15677ba6eb4c918 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 22:19:28 +0800 Subject: [PATCH 03/12] fix(network): ipv6 --- areal/utils/network.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/areal/utils/network.py b/areal/utils/network.py index 3481016039..0720200e00 100644 --- a/areal/utils/network.py +++ b/areal/utils/network.py @@ -23,6 +23,22 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: Raises: RuntimeError: If no suitable address can be determined """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.connect((probe_host, probe_port)) + return sock.getsockname()[0] + except OSError: + pass + + try: + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: + sock.connect(("2001:4860:4860::8888", probe_port)) + ip6 = sock.getsockname()[0] + if ip6 and ip6 != "::1": + return ip6 + except OSError: + pass + try: hostname = socket.gethostname() infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_DGRAM) @@ -38,19 +54,7 @@ def gethostip(probe_host: str = "8.8.8.8", 
probe_port: int = 80) -> str: except socket.gaierror: pass - try: - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: - sock.connect((probe_host, probe_port)) - return sock.getsockname()[0] - except OSError as e: - try: - with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: - sock.connect(("2001:4860:4860::8888", probe_port)) - ip6 = sock.getsockname()[0] - if ip6 and ip6 != "::1": - return ip6 - except OSError: - raise RuntimeError("Could not determine host IP") from e + raise RuntimeError("Could not determine host IP") def get_loopback_ip() -> str: From e41b1a7a52ab41c9674f17b376d08954eaa9a8ad Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 23:10:06 +0800 Subject: [PATCH 04/12] feat(weight): fix metric --- areal/engine/fsdp_engine.py | 74 ++++++++++++++++++++++++----------- areal/engine/sglang_remote.py | 53 +++++++++++++++---------- 2 files changed, 85 insertions(+), 42 deletions(-) diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 9abbc5fadc..d09ceb65b7 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -1605,7 +1605,23 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta): fut = self.rollout_engine.update_weights_from_disk(meta) assert meta.path is not None - self._save_model_to_hf(meta.path, self.tokenizer, self.processor) + # LoRA disk-sync incremental strategy: + # * First weight update (version == 1): save the FULL base+LoRA + # merged model so SGLang can ingest via /update_weights_from_disk. + # * Subsequent updates (version >= 2): save adapter-only PEFT + # artefact so SGLang can ingest via /load_lora_adapter. + # The non-LoRA path is unchanged: always full model save. 
+ save_full_for_lora_first_step = ( + self.config.use_lora + and meta.version is not None + and meta.version <= 1 + ) + self._save_model_to_hf( + meta.path, + self.tokenizer, + self.processor, + force_full_model=save_full_for_lora_first_step, + ) # dist.barrier() are called when _save_model_to_hf finished if dist.get_rank() == 0: @@ -1629,6 +1645,7 @@ def _save_model_to_hf( path: str, tokenizer: PreTrainedTokenizerFast | None, processor: ProcessorMixin | None, + force_full_model: bool = False, ): """Save model in HuggingFace format. @@ -1643,6 +1660,13 @@ def _save_model_to_hf( (``meta.use_lora=True`` -> ``build_disk_weight_update_requests`` -> ``/load_lora_adapter``) becomes end-to-end functional without any additional dispatch logic. + + ``force_full_model=True`` overrides the LoRA branch and always + writes the full base+LoRA-merged HF model. This is used by the + very first weight update so SGLang can do a + ``/update_weights_from_disk`` full warm-load before switching to + adapter-only ``/load_lora_adapter`` deltas on subsequent + updates. """ if self.model is None: raise RuntimeError("Model not initialized") @@ -1656,7 +1680,8 @@ def _save_model_to_hf( # save huggingface model on rank 0 if dist.get_rank() == 0: os.makedirs(path, exist_ok=True) - if self.config.use_lora: + saved_as_lora_adapter = self.config.use_lora and not force_full_model + if saved_as_lora_adapter: self._save_lora_adapter_to_hf(path, state_dict) else: self.model.save_pretrained(path, state_dict=state_dict) @@ -1667,20 +1692,27 @@ def _save_model_to_hf( processor.save_pretrained(path) # Record on-disk artefact size as a wandb metric so the # full-model -> adapter-only size collapse across steps is - # visible. For LoRA disk sync the first checkpoint may - # still be a full-model warmup save; subsequent adapter-only - # saves are expected to be ~10-100x smaller. - self._log_disk_save_size(path) + # visible. 
For the first LoRA disk sync we save the full + # model warmup; subsequent adapter-only saves should be + # ~10-100x smaller. + self._log_disk_save_size(path, saved_as_lora_adapter=saved_as_lora_adapter) dist.barrier(group=self.cpu_group) - def _log_disk_save_size(self, path: str) -> None: + def _log_disk_save_size(self, path: str, saved_as_lora_adapter: bool = False) -> None: """Record the on-disk size of a checkpoint directory as a wandb metric. Sums the bytes of every regular file under ``path`` (recursively) - and reports the total under ``weight_update_disk/*``. When - ``self.config.use_lora`` is True the size is also reported under - a dedicated ``lora_bytes`` key so the LoRA-only series is easy - to plot side-by-side with the full-model series. + and reports the total under flat top-level keys + ``weight_update_disk_bytes`` (always populated) plus a + format-specific key (``weight_update_disk_lora_bytes`` for an + adapter-only PEFT save, ``weight_update_disk_full_bytes`` for a + full HF model save). + + The metrics are written via the **default** ``stats_tracker`` so + they are exported through the same tracker that already drives + wandb panels for ``ppo_actor`` etc. Using a top-level (no + ``/``) key avoids creating a separate wandb panel group that + the user might overlook. Called on rank 0 only; failures are swallowed because metric emission must never break the training loop. @@ -1695,20 +1727,18 @@ def _log_disk_save_size(self, path: str) -> None: except OSError: # File may have been racily deleted; skip it. 
continue - scope = "weight_update_disk" - if self.config.use_lora: - stats_tracker.get(scope).scalar( - lora_bytes=float(total_bytes), - bytes=float(total_bytes), - ) + kwargs: dict[str, float] = { + "weight_update_disk_bytes": float(total_bytes), + } + if saved_as_lora_adapter: + kwargs["weight_update_disk_lora_bytes"] = float(total_bytes) else: - stats_tracker.get(scope).scalar( - full_bytes=float(total_bytes), - bytes=float(total_bytes), - ) + kwargs["weight_update_disk_full_bytes"] = float(total_bytes) + stats_tracker.scalar(**kwargs) self.logger.info( f"[weight_update_disk] saved {total_bytes} bytes to {path} " - f"(use_lora={self.config.use_lora})" + f"(use_lora={self.config.use_lora}, " + f"saved_as_lora_adapter={saved_as_lora_adapter})" ) except Exception as e: # Metric emission must never break training. diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 3eec92658f..0858967fed 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -130,7 +130,20 @@ def build_disk_weight_update_requests( self, meta: WeightUpdateMeta ) -> WeightUpdateRequests: """Build SGLang disk weight update requests.""" - if meta.use_lora: + # LoRA disk-sync incremental strategy: + # * version <= 1 (first sync) and use_lora=True: + # train side wrote the FULL base+LoRA-merged HF model; + # hit /update_weights_from_disk to do a full warm-load. + # * version >= 2 and use_lora=True: + # train side wrote an adapter-only PEFT artefact; + # hit /load_lora_adapter so SGLang only ingests deltas. + # * use_lora=False: always /update_weights_from_disk. 
+ use_lora_adapter_endpoint = ( + meta.use_lora + and meta.version is not None + and meta.version > 1 + ) + if use_lora_adapter_endpoint: if not meta.lora_name: raise ValueError("LoRA name is required for LoRA update.") if meta.version is None: @@ -143,11 +156,12 @@ def build_disk_weight_update_requests( payload={"lora_name": lora_name, "lora_path": str(meta.path)}, ) ] - self._log_disk_send_size(meta, use_lora=True) + self._log_disk_send_size(meta, sent_as_lora_adapter=True) return WeightUpdateRequests(requests=requests) else: - # Full model update - self._log_disk_send_size(meta, use_lora=False) + # Full model update (covers non-LoRA always, and the + # version<=1 first-sync warm-load for LoRA). + self._log_disk_send_size(meta, sent_as_lora_adapter=False) return WeightUpdateRequests( requests=[ HttpRequest( @@ -161,18 +175,20 @@ def build_disk_weight_update_requests( ) @staticmethod - def _log_disk_send_size(meta: WeightUpdateMeta, *, use_lora: bool) -> None: + def _log_disk_send_size(meta: WeightUpdateMeta, *, sent_as_lora_adapter: bool) -> None: """Record the size of weights the inference side will pull from disk. For disk-mode weight updates, the HTTP payload itself is just a pointer (``model_path`` or ``lora_path``) -- the actual bytes SGLang loads come from that on-disk directory. We surface the - aggregate file size here so wandb can show: + aggregate file size via the default ``stats_tracker`` under flat + top-level keys (no scope prefix) so the values land in the + same wandb panel group as ``ppo_actor/*`` and don't get hidden + in a separate auto-generated panel: - * ``weight_update_send/lora_bytes`` for adapter-only sends, - * ``weight_update_send/full_bytes`` for full-model sends, - * ``weight_update_send/bytes`` for either, in a single - unified series. + * ``weight_update_send_bytes`` -- always populated, + * ``weight_update_send_lora_bytes`` -- adapter-only sends, + * ``weight_update_send_full_bytes`` -- full-model sends. 
Failures are swallowed because metric emission must never break the weight-update path. @@ -192,17 +208,14 @@ def _log_disk_send_size(meta: WeightUpdateMeta, *, use_lora: bool) -> None: continue elif os.path.isfile(path): total_bytes = os.path.getsize(path) - scope = "weight_update_send" - if use_lora: - stats_tracker.get(scope).scalar( - lora_bytes=float(total_bytes), - bytes=float(total_bytes), - ) + kwargs: dict[str, float] = { + "weight_update_send_bytes": float(total_bytes), + } + if sent_as_lora_adapter: + kwargs["weight_update_send_lora_bytes"] = float(total_bytes) else: - stats_tracker.get(scope).scalar( - full_bytes=float(total_bytes), - bytes=float(total_bytes), - ) + kwargs["weight_update_send_full_bytes"] = float(total_bytes) + stats_tracker.scalar(**kwargs) except Exception: # Metric emission must never break the weight-update path. pass From b09a9be333d09a86e3998f8b74bc02db5ca8c650 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 23:29:11 +0800 Subject: [PATCH 05/12] refactor(engine): lora --- areal/engine/fsdp_engine.py | 61 +++++++++++------------------------ areal/engine/sglang_remote.py | 30 +++++------------ 2 files changed, 26 insertions(+), 65 deletions(-) diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index d09ceb65b7..40f343748c 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -1605,23 +1605,7 @@ def _update_weights_from_disk(self, meta: WeightUpdateMeta): fut = self.rollout_engine.update_weights_from_disk(meta) assert meta.path is not None - # LoRA disk-sync incremental strategy: - # * First weight update (version == 1): save the FULL base+LoRA - # merged model so SGLang can ingest via /update_weights_from_disk. - # * Subsequent updates (version >= 2): save adapter-only PEFT - # artefact so SGLang can ingest via /load_lora_adapter. - # The non-LoRA path is unchanged: always full model save. 
- save_full_for_lora_first_step = ( - self.config.use_lora - and meta.version is not None - and meta.version <= 1 - ) - self._save_model_to_hf( - meta.path, - self.tokenizer, - self.processor, - force_full_model=save_full_for_lora_first_step, - ) + self._save_model_to_hf(meta.path, self.tokenizer, self.processor) # dist.barrier() are called when _save_model_to_hf finished if dist.get_rank() == 0: @@ -1645,7 +1629,6 @@ def _save_model_to_hf( path: str, tokenizer: PreTrainedTokenizerFast | None, processor: ProcessorMixin | None, - force_full_model: bool = False, ): """Save model in HuggingFace format. @@ -1660,13 +1643,6 @@ def _save_model_to_hf( (``meta.use_lora=True`` -> ``build_disk_weight_update_requests`` -> ``/load_lora_adapter``) becomes end-to-end functional without any additional dispatch logic. - - ``force_full_model=True`` overrides the LoRA branch and always - writes the full base+LoRA-merged HF model. This is used by the - very first weight update so SGLang can do a - ``/update_weights_from_disk`` full warm-load before switching to - adapter-only ``/load_lora_adapter`` deltas on subsequent - updates. """ if self.model is None: raise RuntimeError("Model not initialized") @@ -1680,8 +1656,7 @@ def _save_model_to_hf( # save huggingface model on rank 0 if dist.get_rank() == 0: os.makedirs(path, exist_ok=True) - saved_as_lora_adapter = self.config.use_lora and not force_full_model - if saved_as_lora_adapter: + if self.config.use_lora: self._save_lora_adapter_to_hf(path, state_dict) else: self.model.save_pretrained(path, state_dict=state_dict) @@ -1691,28 +1666,29 @@ def _save_model_to_hf( if processor is not None: processor.save_pretrained(path) # Record on-disk artefact size as a wandb metric so the - # full-model -> adapter-only size collapse across steps is - # visible. For the first LoRA disk sync we save the full - # model warmup; subsequent adapter-only saves should be - # ~10-100x smaller. 
- self._log_disk_save_size(path, saved_as_lora_adapter=saved_as_lora_adapter) + # weight-sync data volume is observable. LoRA disk sync + # uses adapter-only PEFT files (~ tens of MB), full-model + # sync uses the full HF directory (~ hundreds of MB to GBs); + # plotting these side-by-side makes the bandwidth saving + # explicit. + self._log_disk_save_size(path) dist.barrier(group=self.cpu_group) - def _log_disk_save_size(self, path: str, saved_as_lora_adapter: bool = False) -> None: + def _log_disk_save_size(self, path: str) -> None: """Record the on-disk size of a checkpoint directory as a wandb metric. Sums the bytes of every regular file under ``path`` (recursively) and reports the total under flat top-level keys ``weight_update_disk_bytes`` (always populated) plus a - format-specific key (``weight_update_disk_lora_bytes`` for an - adapter-only PEFT save, ``weight_update_disk_full_bytes`` for a - full HF model save). + format-specific key (``weight_update_disk_lora_bytes`` when + ``self.config.use_lora`` is True, otherwise + ``weight_update_disk_full_bytes``). The metrics are written via the **default** ``stats_tracker`` so - they are exported through the same tracker that already drives - wandb panels for ``ppo_actor`` etc. Using a top-level (no - ``/``) key avoids creating a separate wandb panel group that - the user might overlook. + they flow through the same tracker that already drives wandb + panels for ``ppo_actor`` etc. Using a top-level (no ``/``) key + avoids creating a separate wandb panel group that the user + might overlook. Called on rank 0 only; failures are swallowed because metric emission must never break the training loop. 
@@ -1730,15 +1706,14 @@ def _log_disk_save_size(self, path: str, saved_as_lora_adapter: bool = False) -> kwargs: dict[str, float] = { "weight_update_disk_bytes": float(total_bytes), } - if saved_as_lora_adapter: + if self.config.use_lora: kwargs["weight_update_disk_lora_bytes"] = float(total_bytes) else: kwargs["weight_update_disk_full_bytes"] = float(total_bytes) stats_tracker.scalar(**kwargs) self.logger.info( f"[weight_update_disk] saved {total_bytes} bytes to {path} " - f"(use_lora={self.config.use_lora}, " - f"saved_as_lora_adapter={saved_as_lora_adapter})" + f"(use_lora={self.config.use_lora})" ) except Exception as e: # Metric emission must never break training. diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 0858967fed..6584e89242 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -130,20 +130,7 @@ def build_disk_weight_update_requests( self, meta: WeightUpdateMeta ) -> WeightUpdateRequests: """Build SGLang disk weight update requests.""" - # LoRA disk-sync incremental strategy: - # * version <= 1 (first sync) and use_lora=True: - # train side wrote the FULL base+LoRA-merged HF model; - # hit /update_weights_from_disk to do a full warm-load. - # * version >= 2 and use_lora=True: - # train side wrote an adapter-only PEFT artefact; - # hit /load_lora_adapter so SGLang only ingests deltas. - # * use_lora=False: always /update_weights_from_disk. 
- use_lora_adapter_endpoint = ( - meta.use_lora - and meta.version is not None - and meta.version > 1 - ) - if use_lora_adapter_endpoint: + if meta.use_lora: if not meta.lora_name: raise ValueError("LoRA name is required for LoRA update.") if meta.version is None: @@ -156,12 +143,11 @@ def build_disk_weight_update_requests( payload={"lora_name": lora_name, "lora_path": str(meta.path)}, ) ] - self._log_disk_send_size(meta, sent_as_lora_adapter=True) + self._log_disk_send_size(meta, use_lora=True) return WeightUpdateRequests(requests=requests) else: - # Full model update (covers non-LoRA always, and the - # version<=1 first-sync warm-load for LoRA). - self._log_disk_send_size(meta, sent_as_lora_adapter=False) + # Full model update + self._log_disk_send_size(meta, use_lora=False) return WeightUpdateRequests( requests=[ HttpRequest( @@ -175,14 +161,14 @@ def build_disk_weight_update_requests( ) @staticmethod - def _log_disk_send_size(meta: WeightUpdateMeta, *, sent_as_lora_adapter: bool) -> None: + def _log_disk_send_size(meta: WeightUpdateMeta, *, use_lora: bool) -> None: """Record the size of weights the inference side will pull from disk. For disk-mode weight updates, the HTTP payload itself is just a pointer (``model_path`` or ``lora_path``) -- the actual bytes SGLang loads come from that on-disk directory. 
We surface the - aggregate file size via the default ``stats_tracker`` under flat - top-level keys (no scope prefix) so the values land in the + aggregate file size via the **default** ``stats_tracker`` under + flat top-level keys (no scope prefix) so the values land in the same wandb panel group as ``ppo_actor/*`` and don't get hidden in a separate auto-generated panel: @@ -211,7 +197,7 @@ def _log_disk_send_size(meta: WeightUpdateMeta, *, sent_as_lora_adapter: bool) - kwargs: dict[str, float] = { "weight_update_send_bytes": float(total_bytes), } - if sent_as_lora_adapter: + if use_lora: kwargs["weight_update_send_lora_bytes"] = float(total_bytes) else: kwargs["weight_update_send_full_bytes"] = float(total_bytes) From d4e0ce00fbfa10b5364e99274af32267c00d301e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 11 May 2026 00:15:27 +0800 Subject: [PATCH 06/12] feat(test): fix --- tests/test_lora_adapter_save.py | 271 +++++++++++++++++++++++ tests/test_lora_disk_size_metrics.py | 311 +++++++++++++++++++++++++++ tests/test_lora_disk_sync.py | 46 ++++ tests/torchrun/run_lora_disk_sync.py | 44 ++++ 4 files changed, 672 insertions(+) create mode 100644 tests/test_lora_adapter_save.py create mode 100644 tests/test_lora_disk_size_metrics.py diff --git a/tests/test_lora_adapter_save.py b/tests/test_lora_adapter_save.py new file mode 100644 index 0000000000..b5ddfc6896 --- /dev/null +++ b/tests/test_lora_adapter_save.py @@ -0,0 +1,271 @@ +"""Unit tests for ``FSDPEngine._save_lora_adapter_to_hf``. + +These tests invoke the real production method (no helper copy) on a +lightweight stub object that quacks like an FSDPEngine. We feed it a +synthetic state_dict matching the exact key shape produced by +``peft.get_peft_model`` on a HuggingFace transformer (i.e. 
with the +``base_model.model.`` prefix and the ``.default.`` adapter segment), +let the method write to a tmp directory, then assert that: + +* ``adapter_model.safetensors`` exists, parses, and contains ONLY + LoRA tensors with ``.default.`` stripped (the format SGLang's + ``/load_lora_adapter`` consumes); +* ``adapter_config.json`` is well-formed PEFT JSON with the required + fields populated from the engine's ``TrainEngineConfig``. + +The tests are CPU-only and require no FSDP / GPU. +""" + +from __future__ import annotations + +import json +import os +from types import SimpleNamespace + +import pytest +import torch +from safetensors.torch import load_file as safetensors_load_file + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_lora_state_dict() -> dict[str, torch.Tensor]: + """Synthetic state_dict in PEFT layout (``base_model.model.`` + + ``.default.``). Mixes base, lora_A, lora_B, and an embedding LoRA + pair so all four LoRA keywords are exercised. + """ + return { + # Base weights -- must be filtered out. + "base_model.model.model.embed_tokens.weight": torch.zeros(10, 4), + "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros( + 4, 4 + ), + "base_model.model.lm_head.weight": torch.zeros(10, 4), + # LoRA tensors -- must be kept and have ``.default.`` stripped. + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.ones( + 8, 4 + ), + "base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight": torch.ones( + 4, 8 + ), + "base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight": torch.ones( + 8, 4 + ), + "base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight": torch.ones( + 4, 8 + ), + # Embedding LoRA pair. 
+ "base_model.model.model.embed_tokens.lora_embedding_A.default.weight": torch.ones( + 8, 10 + ), + "base_model.model.model.embed_tokens.lora_embedding_B.default.weight": torch.ones( + 4, 8 + ), + } + + +def _make_engine_stub( + *, + lora_rank: int = 8, + lora_alpha: int = 16, + target_modules=None, + base_path: str = "/storage/dummy/base", +) -> SimpleNamespace: + """Build a SimpleNamespace exposing only the attributes that + ``_save_lora_adapter_to_hf`` reads off ``self``. + """ + return SimpleNamespace( + config=SimpleNamespace( + lora_rank=lora_rank, + lora_alpha=lora_alpha, + target_modules=target_modules if target_modules is not None else [], + path=base_path, + ), + ) + + +def _invoke(engine_stub, path, state_dict): + from areal.engine.fsdp_engine import FSDPEngine + + return FSDPEngine._save_lora_adapter_to_hf(engine_stub, path, state_dict) + + +# --------------------------------------------------------------------------- +# Tests: the safetensors output +# --------------------------------------------------------------------------- + + +class TestAdapterSafetensors: + def test_only_lora_keys_are_written(self, tmp_path): + engine = _make_engine_stub() + d = tmp_path / "weight_update_v0" + d.mkdir() + _invoke(engine, str(d), _make_lora_state_dict()) + + f = d / "adapter_model.safetensors" + assert f.exists() + assert f.stat().st_size > 0 + + loaded = safetensors_load_file(str(f)) + # 4 lora_A/B tensors (q_proj + v_proj) + 2 lora_embedding_A/B = 6 tensors + assert len(loaded) == 6 + # No base / lm_head keys leaked through. + for k in loaded: + assert "base_layer" not in k + assert "lm_head" not in k + + def test_default_segment_is_stripped(self, tmp_path): + engine = _make_engine_stub() + d = tmp_path / "wu" + d.mkdir() + _invoke(engine, str(d), _make_lora_state_dict()) + + loaded = safetensors_load_file(str(d / "adapter_model.safetensors")) + for k in loaded: + # The PEFT adapter file format never carries the active + # adapter name segment. 
SGLang's loader assumes it's gone. + assert ".default." not in k + # Each remaining key must end in `.weight` and contain a + # LoRA keyword somewhere (covering both the linear and + # embedding cases). + assert k.endswith(".weight"), k + assert any( + kw in k + for kw in ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") + ), k + + def test_tensor_values_are_preserved(self, tmp_path): + """Filtering must not mutate tensor values.""" + engine = _make_engine_stub() + sd = _make_lora_state_dict() + d = tmp_path / "wu" + d.mkdir() + _invoke(engine, str(d), sd) + + loaded = safetensors_load_file(str(d / "adapter_model.safetensors")) + sample_key = ( + "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" + ) + assert sample_key in loaded + torch.testing.assert_close( + loaded[sample_key], torch.ones(8, 4), check_dtype=False + ) + + def test_missing_lora_raises(self, tmp_path): + """Calling with a state_dict that has no LoRA params must error + out -- this is the fail-fast guard against forgetting to wrap + the model with PEFT. 
+ """ + engine = _make_engine_stub() + d = tmp_path / "wu" + d.mkdir() + bare_state = { + "model.embed_tokens.weight": torch.zeros(10, 4), + "model.layers.0.self_attn.q_proj.weight": torch.zeros(4, 4), + } + with pytest.raises(RuntimeError, match="no LoRA adapter parameters"): + _invoke(engine, str(d), bare_state) + + +# --------------------------------------------------------------------------- +# Tests: adapter_config.json +# --------------------------------------------------------------------------- + + +class TestAdapterConfigJson: + def _read(self, d) -> dict: + with open(os.path.join(str(d), "adapter_config.json")) as f: + return json.load(f) + + def test_required_fields(self, tmp_path): + engine = _make_engine_stub(lora_rank=16, lora_alpha=32) + d = tmp_path / "wu" + d.mkdir() + _invoke(engine, str(d), _make_lora_state_dict()) + + cfg = self._read(d) + # Mandatory PEFT fields: + assert cfg["peft_type"] == "LORA" + assert cfg["task_type"] == "CAUSAL_LM" + assert cfg["r"] == 16 + assert cfg["lora_alpha"] == 32 + assert cfg["bias"] == "none" + assert cfg["lora_dropout"] == 0.0 + assert cfg["inference_mode"] is True + assert "target_modules" in cfg + + def test_target_modules_default_to_all_linear(self, tmp_path): + """Empty list / ``["all-linear"]`` must serialize as the string + ``"all-linear"`` -- which is what PEFT and SGLang both expect. 
+ """ + engine = _make_engine_stub(target_modules=[]) + d = tmp_path / "wu" + d.mkdir() + _invoke(engine, str(d), _make_lora_state_dict()) + assert self._read(d)["target_modules"] == "all-linear" + + engine2 = _make_engine_stub(target_modules=["all-linear"]) + d2 = tmp_path / "wu2" + d2.mkdir() + _invoke(engine2, str(d2), _make_lora_state_dict()) + assert self._read(d2)["target_modules"] == "all-linear" + + def test_target_modules_explicit_list_is_preserved(self, tmp_path): + engine = _make_engine_stub(target_modules=["q_proj", "v_proj"]) + d = tmp_path / "wu" + d.mkdir() + _invoke(engine, str(d), _make_lora_state_dict()) + cfg = self._read(d) + assert cfg["target_modules"] == ["q_proj", "v_proj"] + + def test_base_model_path_is_carried(self, tmp_path): + engine = _make_engine_stub(base_path="/some/where/qwen3-0.6b") + d = tmp_path / "wu" + d.mkdir() + _invoke(engine, str(d), _make_lora_state_dict()) + cfg = self._read(d) + assert cfg["base_model_name_or_path"] == "/some/where/qwen3-0.6b" + + +# --------------------------------------------------------------------------- +# Tests: directory layout matches what /load_lora_adapter expects. +# --------------------------------------------------------------------------- + + +class TestLoadLoraAdapterContract: + """The on-disk layout produced by ``_save_lora_adapter_to_hf`` MUST be + exactly what SGLang's ``/load_lora_adapter`` reads. This test pins + the contract so a future refactor cannot silently break the + inference side. + """ + + def test_two_files_present(self, tmp_path): + engine = _make_engine_stub() + d = tmp_path / "weight_update_v7" + d.mkdir() + _invoke(engine, str(d), _make_lora_state_dict()) + files = sorted(os.listdir(str(d))) + assert files == ["adapter_config.json", "adapter_model.safetensors"] + + def test_round_trip_size_is_stable(self, tmp_path): + """Re-saving the same state must yield byte-identical safetensors + (modulo timestamps, which safetensors does not embed). 
This + catches accidental ordering nondeterminism in the filter loop. + """ + engine = _make_engine_stub() + sd = _make_lora_state_dict() + + d1 = tmp_path / "a" + d1.mkdir() + _invoke(engine, str(d1), sd) + size1 = (d1 / "adapter_model.safetensors").stat().st_size + + d2 = tmp_path / "b" + d2.mkdir() + _invoke(engine, str(d2), sd) + size2 = (d2 / "adapter_model.safetensors").stat().st_size + + assert size1 == size2 diff --git a/tests/test_lora_disk_size_metrics.py b/tests/test_lora_disk_size_metrics.py new file mode 100644 index 0000000000..321b30fab4 --- /dev/null +++ b/tests/test_lora_disk_size_metrics.py @@ -0,0 +1,311 @@ +"""Unit tests for the wandb size-metric helpers added to LoRA disk sync. + +Covers two production paths introduced alongside LoRA disk sync: + +* ``FSDPEngine._log_disk_save_size`` -- after the FSDP engine writes a + checkpoint directory (full HF or PEFT adapter), it reports the + on-disk byte count via the **default** ``stats_tracker`` under flat + top-level keys: + + - ``weight_update_disk_bytes`` (always populated) + - ``weight_update_disk_lora_bytes`` (when ``use_lora=True``) + - ``weight_update_disk_full_bytes`` (when ``use_lora=False``) + +* ``SGLangBackend._log_disk_send_size`` -- the inference side records + how many bytes the SGLang process will pull from disk for the disk + weight-update HTTP call. Same default-tracker, flat-key contract: + + - ``weight_update_send_bytes`` (always populated) + - ``weight_update_send_lora_bytes`` (LoRA branch) + - ``weight_update_send_full_bytes`` (full-model branch) + +These tests are CPU-only, do not touch FSDP / SGLang / GPUs, and run +without a process group. Failures of either helper must NEVER bubble +up to the caller (metric emission must not break training); these +tests assert that contract too. 
+""" + +from __future__ import annotations + +import os +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from areal.api import WeightUpdateMeta + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _write_files_with_total_size(directory: str, total_bytes: int) -> None: + """Create a couple of files under ``directory`` summing to ``total_bytes``. + + Splits the requested size across two files so the recursive + ``os.walk`` aggregation is genuinely exercised (not just a single + ``os.path.getsize``). + """ + os.makedirs(directory, exist_ok=True) + if total_bytes <= 0: + # Touch one empty file so the walk still has something to see. + with open(os.path.join(directory, "empty.bin"), "wb") as f: + f.write(b"") + return + half = total_bytes // 2 + rest = total_bytes - half + with open(os.path.join(directory, "a.bin"), "wb") as f: + f.write(b"\x00" * half) + with open(os.path.join(directory, "b.bin"), "wb") as f: + f.write(b"\x00" * rest) + + +# --------------------------------------------------------------------------- +# FSDPEngine._log_disk_save_size +# --------------------------------------------------------------------------- + + +class TestLogDiskSaveSize: + """Exercise the unbound method on a stub object so we don't need FSDP.""" + + @staticmethod + def _make_engine_stub(*, use_lora: bool): + """Build a SimpleNamespace that quacks like an FSDPEngine for the + narrow surface ``_log_disk_save_size`` actually touches. + """ + return SimpleNamespace( + config=SimpleNamespace(use_lora=use_lora), + logger=MagicMock(), + ) + + def _invoke(self, engine_stub, path): + from areal.engine.fsdp_engine import FSDPEngine + + # Bind the unbound method to the stub. 
+ return FSDPEngine._log_disk_save_size(engine_stub, path) + + def test_lora_branch_writes_bytes_and_lora_bytes(self, tmp_path): + adapter_dir = tmp_path / "weight_update_v1" + _write_files_with_total_size(str(adapter_dir), 12345) + + engine = self._make_engine_stub(use_lora=True) + with patch( + "areal.engine.fsdp_engine.stats_tracker.scalar" + ) as mock_scalar: + self._invoke(engine, str(adapter_dir)) + + mock_scalar.assert_called_once() + kwargs = mock_scalar.call_args.kwargs + assert kwargs["weight_update_disk_bytes"] == 12345.0 + assert kwargs["weight_update_disk_lora_bytes"] == 12345.0 + assert "weight_update_disk_full_bytes" not in kwargs + # Logger should have emitted a `[weight_update_disk]` info line. + assert engine.logger.info.called + msg = engine.logger.info.call_args.args[0] + assert "[weight_update_disk]" in msg + assert "use_lora=True" in msg + + def test_full_branch_writes_bytes_and_full_bytes(self, tmp_path): + full_dir = tmp_path / "full_model" + _write_files_with_total_size(str(full_dir), 7777) + + engine = self._make_engine_stub(use_lora=False) + with patch( + "areal.engine.fsdp_engine.stats_tracker.scalar" + ) as mock_scalar: + self._invoke(engine, str(full_dir)) + + kwargs = mock_scalar.call_args.kwargs + assert kwargs["weight_update_disk_bytes"] == 7777.0 + assert kwargs["weight_update_disk_full_bytes"] == 7777.0 + assert "weight_update_disk_lora_bytes" not in kwargs + + def test_recursive_walk_sums_subdirectories(self, tmp_path): + root = tmp_path / "weight_update_v9" + sub = root / "nested" + os.makedirs(sub, exist_ok=True) + # 200 bytes top-level + 300 bytes nested = 500 total + with open(root / "top.bin", "wb") as f: + f.write(b"\x00" * 200) + with open(sub / "deep.bin", "wb") as f: + f.write(b"\x00" * 300) + + engine = self._make_engine_stub(use_lora=True) + with patch( + "areal.engine.fsdp_engine.stats_tracker.scalar" + ) as mock_scalar: + self._invoke(engine, str(root)) + + kwargs = mock_scalar.call_args.kwargs + assert 
kwargs["weight_update_disk_bytes"] == 500.0 + assert kwargs["weight_update_disk_lora_bytes"] == 500.0 + + def test_nonexistent_path_does_not_raise(self, tmp_path): + """Reporting a non-existent path must not break the training loop; + ``os.walk`` simply yields nothing and the recorded size is 0. + """ + bad = tmp_path / "does_not_exist" + engine = self._make_engine_stub(use_lora=True) + with patch( + "areal.engine.fsdp_engine.stats_tracker.scalar" + ) as mock_scalar: + # Must not raise. + self._invoke(engine, str(bad)) + # Either reported as zero (preferred) or skipped silently. + if mock_scalar.called: + kwargs = mock_scalar.call_args.kwargs + assert kwargs["weight_update_disk_bytes"] == 0.0 + + def test_scalar_failure_is_swallowed(self, tmp_path): + """If ``stats_tracker.scalar`` raises, the helper must catch it + and log a warning rather than propagate. + """ + d = tmp_path / "wu" + _write_files_with_total_size(str(d), 100) + engine = self._make_engine_stub(use_lora=True) + with patch( + "areal.engine.fsdp_engine.stats_tracker.scalar", + side_effect=RuntimeError("boom"), + ): + # Must not raise. + self._invoke(engine, str(d)) + # The warning path should have been hit. 
+ assert engine.logger.warning.called + + +# --------------------------------------------------------------------------- +# SGLangBackend._log_disk_send_size +# --------------------------------------------------------------------------- + + +class TestLogDiskSendSize: + def test_lora_meta_records_bytes_and_lora_bytes(self, tmp_path): + from areal.engine.sglang_remote import SGLangBackend + + d = tmp_path / "weight_update_v3" + _write_files_with_total_size(str(d), 4096) + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="L", + version=3, + path=str(d), + ) + + with patch( + "areal.engine.sglang_remote.stats_tracker.scalar" + ) as mock_scalar: + SGLangBackend._log_disk_send_size(meta, use_lora=True) + + mock_scalar.assert_called_once() + kwargs = mock_scalar.call_args.kwargs + assert kwargs["weight_update_send_bytes"] == 4096.0 + assert kwargs["weight_update_send_lora_bytes"] == 4096.0 + assert "weight_update_send_full_bytes" not in kwargs + + def test_full_meta_records_bytes_and_full_bytes(self, tmp_path): + from areal.engine.sglang_remote import SGLangBackend + + d = tmp_path / "full" + _write_files_with_total_size(str(d), 999) + meta = WeightUpdateMeta(type="disk", use_lora=False, path=str(d)) + + with patch( + "areal.engine.sglang_remote.stats_tracker.scalar" + ) as mock_scalar: + SGLangBackend._log_disk_send_size(meta, use_lora=False) + + kwargs = mock_scalar.call_args.kwargs + assert kwargs["weight_update_send_bytes"] == 999.0 + assert kwargs["weight_update_send_full_bytes"] == 999.0 + assert "weight_update_send_lora_bytes" not in kwargs + + def test_meta_path_is_a_file_not_dir(self, tmp_path): + """If ``meta.path`` happens to point at a single file rather than + a directory, the helper must still record its size. 
+ """ + from areal.engine.sglang_remote import SGLangBackend + + f = tmp_path / "alone.safetensors" + f.write_bytes(b"\x00" * 555) + meta = WeightUpdateMeta( + type="disk", use_lora=True, lora_name="L", version=0, path=str(f) + ) + + with patch( + "areal.engine.sglang_remote.stats_tracker.scalar" + ) as mock_scalar: + SGLangBackend._log_disk_send_size(meta, use_lora=True) + + kwargs = mock_scalar.call_args.kwargs + assert kwargs["weight_update_send_bytes"] == 555.0 + assert kwargs["weight_update_send_lora_bytes"] == 555.0 + + def test_none_path_is_a_noop(self): + from areal.engine.sglang_remote import SGLangBackend + + meta = WeightUpdateMeta(type="disk", use_lora=True, path=None) + with patch( + "areal.engine.sglang_remote.stats_tracker.scalar" + ) as mock_scalar: + SGLangBackend._log_disk_send_size(meta, use_lora=True) + # No path -> nothing to size -> no scalar call. + mock_scalar.assert_not_called() + + def test_scalar_failure_is_swallowed(self, tmp_path): + from areal.engine.sglang_remote import SGLangBackend + + d = tmp_path / "wu" + _write_files_with_total_size(str(d), 10) + meta = WeightUpdateMeta( + type="disk", use_lora=True, lora_name="L", version=0, path=str(d) + ) + with patch( + "areal.engine.sglang_remote.stats_tracker.scalar", + side_effect=RuntimeError("boom"), + ): + # Must not raise. + SGLangBackend._log_disk_send_size(meta, use_lora=True) + + +# --------------------------------------------------------------------------- +# build_disk_weight_update_requests integration: dispatch must call +# _log_disk_send_size on both branches. 
+# --------------------------------------------------------------------------- + + +class TestBuildDiskRequestsCallsSendSizeMetric: + def test_lora_branch_invokes_log_disk_send_size(self, tmp_path): + from areal.engine.sglang_remote import SGLangBackend + + d = tmp_path / "weight_update_v1" + _write_files_with_total_size(str(d), 10) + backend = SGLangBackend() + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="my-lora", + version=1, + path=str(d), + ) + + with patch.object(SGLangBackend, "_log_disk_send_size") as mock_metric: + backend.build_disk_weight_update_requests(meta) + mock_metric.assert_called_once() + # use_lora kwarg must be True for the LoRA branch. + assert mock_metric.call_args.kwargs["use_lora"] is True + + def test_full_branch_invokes_log_disk_send_size(self, tmp_path): + from areal.engine.sglang_remote import SGLangBackend + + d = tmp_path / "full" + _write_files_with_total_size(str(d), 10) + backend = SGLangBackend() + meta = WeightUpdateMeta(type="disk", use_lora=False, path=str(d)) + + with patch.object(SGLangBackend, "_log_disk_send_size") as mock_metric: + backend.build_disk_weight_update_requests(meta) + mock_metric.assert_called_once() + assert mock_metric.call_args.kwargs["use_lora"] is False diff --git a/tests/test_lora_disk_sync.py b/tests/test_lora_disk_sync.py index 8ba91072f9..7339145faa 100644 --- a/tests/test_lora_disk_sync.py +++ b/tests/test_lora_disk_sync.py @@ -41,6 +41,11 @@ # --------------------------------------------------------------------------- +# NOTE: This keyword tuple MUST stay in sync with the one inside +# ``FSDPEngine._save_lora_adapter_to_hf``. The companion test +# ``tests/test_lora_adapter_save.py`` exercises the production method +# directly; the helper-based tests in this file are a fast smoke layer +# that does not require importing the engine. 
_LORA_KEYWORDS = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") @@ -171,7 +176,12 @@ def test_default_flags(self): meta = WeightUpdateMeta(type="disk") assert meta.use_lora is False assert meta.lora_name == "" + assert meta.lora_int_id == 0 assert meta.version is None + # Default peft_config must be an empty dict (not None) so + # downstream callers can do a plain ``meta.peft_config.get(...)`` + # without a None-check. + assert meta.peft_config == {} def test_construct_lora_meta(self): meta = WeightUpdateMeta( @@ -483,3 +493,39 @@ def test_with_version_changes_path_and_lora_name_consistently(self, tmp_path): req = SGLangBackend().build_disk_weight_update_requests(m).requests[0] assert req.payload["lora_name"] == f"my-lora-v{v}" assert req.payload["lora_path"].endswith(f"weight_update_v{v}") + + def test_disk_lora_dispatch_emits_send_size_metric(self, tmp_path): + """Both branches of ``build_disk_weight_update_requests`` MUST + record the on-disk byte count via the default ``stats_tracker`` + under flat top-level keys (no scope prefix). The detailed + contract is exercised in ``tests/test_lora_disk_size_metrics.py``; + this is a lightweight integration smoke that wires the dispatch + and the metric emitter together. 
+ """ + from unittest.mock import patch + + from areal.engine.sglang_remote import SGLangBackend + + adapter_dir = tmp_path / "weight_update_v0" + adapter_dir.mkdir() + (adapter_dir / "adapter_model.safetensors").write_bytes(b"\x00" * 16) + meta = WeightUpdateMeta( + type="disk", + use_lora=True, + lora_name="L", + version=0, + path=str(adapter_dir), + ) + with patch( + "areal.engine.sglang_remote.stats_tracker.scalar" + ) as mock_scalar: + SGLangBackend().build_disk_weight_update_requests(meta) + assert mock_scalar.called, ( + "build_disk_weight_update_requests must record send-size metrics" + ) + kwargs = mock_scalar.call_args.kwargs + assert "weight_update_send_bytes" in kwargs + assert "weight_update_send_lora_bytes" in kwargs + # Flat top-level keys -- never scoped under a subgroup. + for k in kwargs: + assert "/" not in k, f"metric key must be flat, got {k}" diff --git a/tests/torchrun/run_lora_disk_sync.py b/tests/torchrun/run_lora_disk_sync.py index c1f950fe60..0b37c850e3 100644 --- a/tests/torchrun/run_lora_disk_sync.py +++ b/tests/torchrun/run_lora_disk_sync.py @@ -145,6 +145,9 @@ def verify_adapter_artifacts(adapter_dir: str, *, lora_rank: int, lora_alpha: in if cfg.get("peft_type") != "LORA": print(f"ERROR: peft_type != LORA: {cfg.get('peft_type')}", flush=True) return False + if cfg.get("task_type") != "CAUSAL_LM": + print(f"ERROR: task_type != CAUSAL_LM: {cfg.get('task_type')}", flush=True) + return False if cfg.get("r") != lora_rank: print(f"ERROR: r mismatch: {cfg.get('r')} vs {lora_rank}", flush=True) return False @@ -157,6 +160,15 @@ def verify_adapter_artifacts(adapter_dir: str, *, lora_rank: int, lora_alpha: in if "target_modules" not in cfg: print("ERROR: target_modules missing from adapter_config.json", flush=True) return False + # ``base_model_name_or_path`` is required by SGLang's + # /load_lora_adapter when it has to materialize the adapter on the + # base model side. 
+ if "base_model_name_or_path" not in cfg: + print( + "ERROR: base_model_name_or_path missing from adapter_config.json", + flush=True, + ) + return False # Validate adapter_model.safetensors keys: every key must be a LoRA # tensor with the ``.default.`` segment stripped. @@ -171,6 +183,38 @@ def verify_adapter_artifacts(adapter_dir: str, *, lora_rank: int, lora_alpha: in if ".default." in k: print(f"ERROR: '.default.' was not stripped from key: {k}", flush=True) return False + if not k.endswith(".weight"): + print(f"ERROR: adapter key must end with '.weight': {k}", flush=True) + return False + + # Adapter-only saves should be tens of MB at most (Qwen3-0.6B + r=8 + # is around 19MB). If the artefact is GB-scale the engine almost + # certainly fell back to a full-model save -- the very bug that + # this whole disk-sync path exists to fix. Cap at 200 MB to leave + # plenty of headroom while still flagging a regression. + total_bytes = 0 + for root, _dirs, files in os.walk(adapter_dir): + for fname in files: + try: + total_bytes += os.path.getsize(os.path.join(root, fname)) + except OSError: + continue + if total_bytes <= 0: + print(f"ERROR: adapter directory total size is zero: {adapter_dir}", flush=True) + return False + if total_bytes > 200 * 1024 * 1024: + print( + f"ERROR: adapter directory size {total_bytes} bytes exceeds 200 MB ceiling -- " + f"this likely means the engine fell back to a full-model save instead of " + f"saving only the LoRA adapter.", + flush=True, + ) + return False + print( + f"[verify] adapter_dir={adapter_dir} total_bytes={total_bytes} " + f"(adapter-only, well under full-model size)", + flush=True, + ) return True From fa1ad50c1369049aae68d9fd6cf0cd5796d151ae Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 11 May 2026 00:46:44 +0800 Subject: [PATCH 07/12] test(lora_disk_sync): fix --- tests/test_lora_disk_sync_e2e.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_lora_disk_sync_e2e.py 
b/tests/test_lora_disk_sync_e2e.py index 4115ad5c00..5aee3de665 100644 --- a/tests/test_lora_disk_sync_e2e.py +++ b/tests/test_lora_disk_sync_e2e.py @@ -15,6 +15,7 @@ pytest tests/test_lora_disk_sync_e2e.py -v """ +import os import subprocess import pytest @@ -51,21 +52,20 @@ def _run_torchrun_test(alloc_mode: str, output: str, n_gpus: int | None = None): f"--backend={alloc_mode}", f"--output={output}", ] + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" try: - result = subprocess.run( + subprocess.run( cmd, check=True, - capture_output=True, - text=True, timeout=300, + env=env, ) - print(result.stdout) except subprocess.CalledProcessError as e: pytest.fail( f"torchrun failed with exit code {e.returncode}.\n" - f"stdout:\n{e.stdout}\n" - f"stderr:\n{e.stderr}" + "See the live torchrun output above for details." ) except subprocess.TimeoutExpired: pytest.fail("torchrun timed out after 300 seconds.") From 8bcd7b56ddf95df29f0e9db78d5eed0eb45a47aa Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 11 May 2026 00:50:32 +0800 Subject: [PATCH 08/12] test(lora_disk_sync): add path --- tests/test_lora_disk_sync_e2e.py | 3 +++ tests/torchrun/run_lora_disk_sync.py | 22 ++++++++++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/tests/test_lora_disk_sync_e2e.py b/tests/test_lora_disk_sync_e2e.py index 5aee3de665..408eec7dd9 100644 --- a/tests/test_lora_disk_sync_e2e.py +++ b/tests/test_lora_disk_sync_e2e.py @@ -24,6 +24,8 @@ from areal.infra.platforms import current_platform from areal.utils.network import find_free_ports +MODEL_PATH = "/workspace/models/Qwen3-0.6B" + def _run_torchrun_test(alloc_mode: str, output: str, n_gpus: int | None = None): """Launch the torchrun helper script and check the result. 
@@ -51,6 +53,7 @@ def _run_torchrun_test(alloc_mode: str, output: str, n_gpus: int | None = None): "tests/torchrun/run_lora_disk_sync.py", f"--backend={alloc_mode}", f"--output={output}", + f"--model-path={MODEL_PATH}", ] env = os.environ.copy() env["PYTHONUNBUFFERED"] = "1" diff --git a/tests/torchrun/run_lora_disk_sync.py b/tests/torchrun/run_lora_disk_sync.py index 0b37c850e3..bca45f594e 100644 --- a/tests/torchrun/run_lora_disk_sync.py +++ b/tests/torchrun/run_lora_disk_sync.py @@ -51,7 +51,7 @@ from areal.engine import FSDPEngine from areal.infra.platforms import current_platform -MODEL_PATH = get_model_path( +DEFAULT_MODEL_PATH = get_model_path( "/storage/openpsi/models/Qwen__Qwen3-0.6B/", "Qwen/Qwen3-0.6B" ) @@ -64,14 +64,14 @@ def write_result(path: str, success: bool) -> None: f.write("Passed" if success else "Failed") -def make_fsdp_engine_with_lora(backend: str) -> FSDPEngine: +def make_fsdp_engine_with_lora(backend: str, model_path: str) -> FSDPEngine: """Create an FSDPEngine with LoRA + disk-mode weight update.""" config = TrainEngineConfig( backend=backend, experiment_name="test_lora_disk_sync", trial_name="test", mb_spec=MicroBatchSpec(max_tokens_per_mb=256), - path=MODEL_PATH, + path=model_path, optimizer=OptimizerConfig(), fsdp=FSDPEngineConfig(memory_efficient_load=True), # LoRA config @@ -218,7 +218,9 @@ def verify_adapter_artifacts(adapter_dir: str, *, lora_rank: int, lora_alpha: in return True -def test_lora_disk_sync(backend: str, output: str | None = None) -> None: +def test_lora_disk_sync( + backend: str, output: str | None = None, model_path: str = DEFAULT_MODEL_PATH +) -> None: """Main test logic for LoRA disk sync.""" rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -240,7 +242,7 @@ def test_lora_disk_sync(backend: str, output: str | None = None) -> None: f"use_lora=True weight_update_mode=disk", flush=True, ) - engine = make_fsdp_engine_with_lora(backend) + engine = 
make_fsdp_engine_with_lora(backend, model_path) # Step 2: Verify LoRA parameters exist on the in-memory model. lora_count, base_count = count_lora_and_base_params(engine) @@ -329,8 +331,16 @@ def main(): default="fsdp:d1t1", help="Backend allocation string (e.g., 'fsdp:d1t1')", ) + parser.add_argument( + "--model-path", + type=str, + default=DEFAULT_MODEL_PATH, + help="Local model path used to initialize the FSDP engine.", + ) args = parser.parse_args() - test_lora_disk_sync(backend=args.backend, output=args.output) + test_lora_disk_sync( + backend=args.backend, output=args.output, model_path=args.model_path + ) if __name__ == "__main__": From eb8fe4c8f54d36039a6bef67884bd77b5cb89dd6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 11 May 2026 00:53:34 +0800 Subject: [PATCH 09/12] refactor(tests): fix --- tests/test_lora_disk_sync_e2e.py | 4 +++- tests/torchrun/run_lora_disk_sync.py | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_lora_disk_sync_e2e.py b/tests/test_lora_disk_sync_e2e.py index 408eec7dd9..be6a743e42 100644 --- a/tests/test_lora_disk_sync_e2e.py +++ b/tests/test_lora_disk_sync_e2e.py @@ -24,7 +24,9 @@ from areal.infra.platforms import current_platform from areal.utils.network import find_free_ports -MODEL_PATH = "/workspace/models/Qwen3-0.6B" +MODEL_PATH = os.environ.get( + "AREAL_LORA_DISK_SYNC_MODEL_PATH", "/workspace/models/Qwen3-0.6B" +) def _run_torchrun_test(alloc_mode: str, output: str, n_gpus: int | None = None): diff --git a/tests/torchrun/run_lora_disk_sync.py b/tests/torchrun/run_lora_disk_sync.py index bca45f594e..ad0a23bd1a 100644 --- a/tests/torchrun/run_lora_disk_sync.py +++ b/tests/torchrun/run_lora_disk_sync.py @@ -38,8 +38,6 @@ import torch.distributed as dist from safetensors.torch import load_file as safetensors_load_file -from tests.utils import get_model_path - from areal.api import FinetuneSpec from areal.api.alloc_mode import ModelAllocation from areal.api.cli_args import ( @@ -51,8 +49,8 
@@ from areal.engine import FSDPEngine from areal.infra.platforms import current_platform -DEFAULT_MODEL_PATH = get_model_path( - "/storage/openpsi/models/Qwen__Qwen3-0.6B/", "Qwen/Qwen3-0.6B" +DEFAULT_MODEL_PATH = os.environ.get( + "AREAL_LORA_DISK_SYNC_MODEL_PATH", "/workspace/models/Qwen3-0.6B" ) From d9db7f10ee612fb7e110d08e7b59483063dd075c Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 11 May 2026 00:55:26 +0800 Subject: [PATCH 10/12] refactor(tests): fix --- tests/torchrun/run_lora_disk_sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/torchrun/run_lora_disk_sync.py b/tests/torchrun/run_lora_disk_sync.py index ad0a23bd1a..994e24e8a6 100644 --- a/tests/torchrun/run_lora_disk_sync.py +++ b/tests/torchrun/run_lora_disk_sync.py @@ -225,8 +225,8 @@ def test_lora_disk_sync( success = True alloc_mode = ModelAllocation.from_str(backend) - dp = alloc_mode.parallel.dp - tp = alloc_mode.parallel.tp + dp = alloc_mode.parallel.dp_size + tp = alloc_mode.parallel.tp_size print( f"[Rank {rank}] Starting LoRA disk sync test | " From 7f1a9a04abb47a4a6948c45d625490f37e698511 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Mon, 11 May 2026 01:06:34 +0800 Subject: [PATCH 11/12] refactor(test): lora fix --- areal/engine/fsdp_engine.py | 85 +------- areal/engine/sglang_remote.py | 48 ----- tests/test_lora_adapter_save.py | 54 +---- tests/test_lora_disk_size_metrics.py | 311 --------------------------- tests/test_lora_disk_sync.py | 72 +------ tests/test_lora_disk_sync_e2e.py | 30 +-- tests/torchrun/run_lora_disk_sync.py | 37 +--- 7 files changed, 17 insertions(+), 620 deletions(-) delete mode 100644 tests/test_lora_disk_size_metrics.py diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 40f343748c..cf6a56fdcf 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -1632,17 +1632,9 @@ def _save_model_to_hf( ): """Save model in HuggingFace format. 
- For full models (``use_lora=False``), the entire ``state_dict`` is - written via :meth:`PreTrainedModel.save_pretrained`. - - For LoRA models (``use_lora=True``), only the trainable LoRA - adapter parameters are written, in the standard PEFT layout - (``adapter_model.safetensors`` + ``adapter_config.json``). This - is the format that SGLang's ``/load_lora_adapter`` endpoint - expects, so the existing disk-based weight update path - (``meta.use_lora=True`` -> ``build_disk_weight_update_requests`` - -> ``/load_lora_adapter``) becomes end-to-end functional without - any additional dispatch logic. + LoRA mode writes adapter-only PEFT files for SGLang's + ``/load_lora_adapter`` endpoint; non-LoRA mode writes a full HF + model directory. """ if self.model is None: raise RuntimeError("Model not initialized") @@ -1665,60 +1657,8 @@ def _save_model_to_hf( tokenizer.save_pretrained(path) if processor is not None: processor.save_pretrained(path) - # Record on-disk artefact size as a wandb metric so the - # weight-sync data volume is observable. LoRA disk sync - # uses adapter-only PEFT files (~ tens of MB), full-model - # sync uses the full HF directory (~ hundreds of MB to GBs); - # plotting these side-by-side makes the bandwidth saving - # explicit. - self._log_disk_save_size(path) dist.barrier(group=self.cpu_group) - def _log_disk_save_size(self, path: str) -> None: - """Record the on-disk size of a checkpoint directory as a wandb metric. - - Sums the bytes of every regular file under ``path`` (recursively) - and reports the total under flat top-level keys - ``weight_update_disk_bytes`` (always populated) plus a - format-specific key (``weight_update_disk_lora_bytes`` when - ``self.config.use_lora`` is True, otherwise - ``weight_update_disk_full_bytes``). - - The metrics are written via the **default** ``stats_tracker`` so - they flow through the same tracker that already drives wandb - panels for ``ppo_actor`` etc. 
Using a top-level (no ``/``) key - avoids creating a separate wandb panel group that the user - might overlook. - - Called on rank 0 only; failures are swallowed because metric - emission must never break the training loop. - """ - try: - total_bytes = 0 - for root, _dirs, files in os.walk(path): - for fname in files: - fpath = os.path.join(root, fname) - try: - total_bytes += os.path.getsize(fpath) - except OSError: - # File may have been racily deleted; skip it. - continue - kwargs: dict[str, float] = { - "weight_update_disk_bytes": float(total_bytes), - } - if self.config.use_lora: - kwargs["weight_update_disk_lora_bytes"] = float(total_bytes) - else: - kwargs["weight_update_disk_full_bytes"] = float(total_bytes) - stats_tracker.scalar(**kwargs) - self.logger.info( - f"[weight_update_disk] saved {total_bytes} bytes to {path} " - f"(use_lora={self.config.use_lora})" - ) - except Exception as e: - # Metric emission must never break training. - self.logger.warning(f"Failed to record disk-save size metric: {e}") - def _save_lora_adapter_to_hf( self, path: str, @@ -1726,27 +1666,14 @@ def _save_lora_adapter_to_hf( ): """Save only LoRA adapter weights in standard PEFT format. - Filters ``state_dict`` for LoRA adapter tensors, strips the active - adapter name segment (``.default.``) from each key so the result - matches the layout produced by ``PeftModel.save_pretrained`` / - ``get_peft_model_state_dict`` (and hence what SGLang's - ``/load_lora_adapter`` expects), and writes: - - * ``adapter_model.safetensors`` -- adapter tensor file - * ``adapter_config.json`` -- PEFT-compatible LoRA config - - Called on rank 0 only. + The saved layout matches what SGLang's ``/load_lora_adapter`` + endpoint expects. """ import json from safetensors.torch import save_file as safetensors_save_file - # PEFT's named_parameters() / state_dict include the active adapter - # name in keys, e.g. "...lora_A.default.weight". 
The standard PEFT - # adapter file format (as produced by get_peft_model_state_dict / - # save_pretrained) strips this adapter name, yielding - # "...lora_A.weight". SGLang's /load_lora_adapter expects the - # stripped format. + # PEFT adapter files omit the active adapter segment, e.g. ".default.". adapter_name = "default" # PEFT default adapter name lora_keywords = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 6584e89242..52039df326 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -143,11 +143,9 @@ def build_disk_weight_update_requests( payload={"lora_name": lora_name, "lora_path": str(meta.path)}, ) ] - self._log_disk_send_size(meta, use_lora=True) return WeightUpdateRequests(requests=requests) else: # Full model update - self._log_disk_send_size(meta, use_lora=False) return WeightUpdateRequests( requests=[ HttpRequest( @@ -160,52 +158,6 @@ def build_disk_weight_update_requests( ] ) - @staticmethod - def _log_disk_send_size(meta: WeightUpdateMeta, *, use_lora: bool) -> None: - """Record the size of weights the inference side will pull from disk. - - For disk-mode weight updates, the HTTP payload itself is just a - pointer (``model_path`` or ``lora_path``) -- the actual bytes - SGLang loads come from that on-disk directory. We surface the - aggregate file size via the **default** ``stats_tracker`` under - flat top-level keys (no scope prefix) so the values land in the - same wandb panel group as ``ppo_actor/*`` and don't get hidden - in a separate auto-generated panel: - - * ``weight_update_send_bytes`` -- always populated, - * ``weight_update_send_lora_bytes`` -- adapter-only sends, - * ``weight_update_send_full_bytes`` -- full-model sends. - - Failures are swallowed because metric emission must never break - the weight-update path. 
- """ - path = meta.path - if path is None: - return - try: - total_bytes = 0 - if os.path.isdir(path): - for root, _dirs, files in os.walk(path): - for fname in files: - fpath = os.path.join(root, fname) - try: - total_bytes += os.path.getsize(fpath) - except OSError: - continue - elif os.path.isfile(path): - total_bytes = os.path.getsize(path) - kwargs: dict[str, float] = { - "weight_update_send_bytes": float(total_bytes), - } - if use_lora: - kwargs["weight_update_send_lora_bytes"] = float(total_bytes) - else: - kwargs["weight_update_send_full_bytes"] = float(total_bytes) - stats_tracker.scalar(**kwargs) - except Exception: - # Metric emission must never break the weight-update path. - pass - def build_distributed_weight_update_requests( self, meta: WeightUpdateMeta, param_specs: list[ParamSpec] ) -> WeightUpdateRequests: diff --git a/tests/test_lora_adapter_save.py b/tests/test_lora_adapter_save.py index b5ddfc6896..10ca1ef5dd 100644 --- a/tests/test_lora_adapter_save.py +++ b/tests/test_lora_adapter_save.py @@ -1,19 +1,7 @@ """Unit tests for ``FSDPEngine._save_lora_adapter_to_hf``. -These tests invoke the real production method (no helper copy) on a -lightweight stub object that quacks like an FSDPEngine. We feed it a -synthetic state_dict matching the exact key shape produced by -``peft.get_peft_model`` on a HuggingFace transformer (i.e. with the -``base_model.model.`` prefix and the ``.default.`` adapter segment), -let the method write to a tmp directory, then assert that: - -* ``adapter_model.safetensors`` exists, parses, and contains ONLY - LoRA tensors with ``.default.`` stripped (the format SGLang's - ``/load_lora_adapter`` consumes); -* ``adapter_config.json`` is well-formed PEFT JSON with the required - fields populated from the engine's ``TrainEngineConfig``. - -The tests are CPU-only and require no FSDP / GPU. +The saved adapter must contain only LoRA tensors and PEFT metadata that +SGLang's ``/load_lora_adapter`` can consume. 
""" from __future__ import annotations @@ -27,24 +15,19 @@ from safetensors.torch import load_file as safetensors_load_file -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - def _make_lora_state_dict() -> dict[str, torch.Tensor]: """Synthetic state_dict in PEFT layout (``base_model.model.`` + ``.default.``). Mixes base, lora_A, lora_B, and an embedding LoRA pair so all four LoRA keywords are exercised. """ return { - # Base weights -- must be filtered out. + # Base weights must be filtered out. "base_model.model.model.embed_tokens.weight": torch.zeros(10, 4), "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": torch.zeros( 4, 4 ), "base_model.model.lm_head.weight": torch.zeros(10, 4), - # LoRA tensors -- must be kept and have ``.default.`` stripped. + # LoRA tensors must be kept with ``.default.`` stripped. "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": torch.ones( 8, 4 ), @@ -93,11 +76,6 @@ def _invoke(engine_stub, path, state_dict): return FSDPEngine._save_lora_adapter_to_hf(engine_stub, path, state_dict) -# --------------------------------------------------------------------------- -# Tests: the safetensors output -# --------------------------------------------------------------------------- - - class TestAdapterSafetensors: def test_only_lora_keys_are_written(self, tmp_path): engine = _make_engine_stub() @@ -110,9 +88,7 @@ def test_only_lora_keys_are_written(self, tmp_path): assert f.stat().st_size > 0 loaded = safetensors_load_file(str(f)) - # 4 lora_A/B layer pairs + 2 lora_embedding_A/B = 6 tensors assert len(loaded) == 6 - # No base / lm_head keys leaked through. 
for k in loaded: assert "base_layer" not in k assert "lm_head" not in k @@ -125,12 +101,7 @@ def test_default_segment_is_stripped(self, tmp_path): loaded = safetensors_load_file(str(d / "adapter_model.safetensors")) for k in loaded: - # The PEFT adapter file format never carries the active - # adapter name segment. SGLang's loader assumes it's gone. assert ".default." not in k - # Each remaining key must end in `.weight` and contain a - # LoRA keyword somewhere (covering both the linear and - # embedding cases). assert k.endswith(".weight"), k assert any( kw in k @@ -170,11 +141,6 @@ def test_missing_lora_raises(self, tmp_path): _invoke(engine, str(d), bare_state) -# --------------------------------------------------------------------------- -# Tests: adapter_config.json -# --------------------------------------------------------------------------- - - class TestAdapterConfigJson: def _read(self, d) -> dict: with open(os.path.join(str(d), "adapter_config.json")) as f: @@ -187,7 +153,6 @@ def test_required_fields(self, tmp_path): _invoke(engine, str(d), _make_lora_state_dict()) cfg = self._read(d) - # Mandatory PEFT fields: assert cfg["peft_type"] == "LORA" assert cfg["task_type"] == "CAUSAL_LM" assert cfg["r"] == 16 @@ -230,17 +195,8 @@ def test_base_model_path_is_carried(self, tmp_path): assert cfg["base_model_name_or_path"] == "/some/where/qwen3-0.6b" -# --------------------------------------------------------------------------- -# Tests: directory layout matches what /load_lora_adapter expects. -# --------------------------------------------------------------------------- - - class TestLoadLoraAdapterContract: - """The on-disk layout produced by ``_save_lora_adapter_to_hf`` MUST be - exactly what SGLang's ``/load_lora_adapter`` reads. This test pins - the contract so a future refactor cannot silently break the - inference side. 
- """ + """Pin the PEFT layout consumed by SGLang's LoRA loader.""" def test_two_files_present(self, tmp_path): engine = _make_engine_stub() diff --git a/tests/test_lora_disk_size_metrics.py b/tests/test_lora_disk_size_metrics.py deleted file mode 100644 index 321b30fab4..0000000000 --- a/tests/test_lora_disk_size_metrics.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Unit tests for the wandb size-metric helpers added to LoRA disk sync. - -Covers two production paths introduced alongside LoRA disk sync: - -* ``FSDPEngine._log_disk_save_size`` -- after the FSDP engine writes a - checkpoint directory (full HF or PEFT adapter), it reports the - on-disk byte count via the **default** ``stats_tracker`` under flat - top-level keys: - - - ``weight_update_disk_bytes`` (always populated) - - ``weight_update_disk_lora_bytes`` (when ``use_lora=True``) - - ``weight_update_disk_full_bytes`` (when ``use_lora=False``) - -* ``SGLangBackend._log_disk_send_size`` -- the inference side records - how many bytes the SGLang process will pull from disk for the disk - weight-update HTTP call. Same default-tracker, flat-key contract: - - - ``weight_update_send_bytes`` (always populated) - - ``weight_update_send_lora_bytes`` (LoRA branch) - - ``weight_update_send_full_bytes`` (full-model branch) - -These tests are CPU-only, do not touch FSDP / SGLang / GPUs, and run -without a process group. Failures of either helper must NEVER bubble -up to the caller (metric emission must not break training); these -tests assert that contract too. 
-""" - -from __future__ import annotations - -import os -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest - -from areal.api import WeightUpdateMeta - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _write_files_with_total_size(directory: str, total_bytes: int) -> None: - """Create a couple of files under ``directory`` summing to ``total_bytes``. - - Splits the requested size across two files so the recursive - ``os.walk`` aggregation is genuinely exercised (not just a single - ``os.path.getsize``). - """ - os.makedirs(directory, exist_ok=True) - if total_bytes <= 0: - # Touch one empty file so the walk still has something to see. - with open(os.path.join(directory, "empty.bin"), "wb") as f: - f.write(b"") - return - half = total_bytes // 2 - rest = total_bytes - half - with open(os.path.join(directory, "a.bin"), "wb") as f: - f.write(b"\x00" * half) - with open(os.path.join(directory, "b.bin"), "wb") as f: - f.write(b"\x00" * rest) - - -# --------------------------------------------------------------------------- -# FSDPEngine._log_disk_save_size -# --------------------------------------------------------------------------- - - -class TestLogDiskSaveSize: - """Exercise the unbound method on a stub object so we don't need FSDP.""" - - @staticmethod - def _make_engine_stub(*, use_lora: bool): - """Build a SimpleNamespace that quacks like an FSDPEngine for the - narrow surface ``_log_disk_save_size`` actually touches. - """ - return SimpleNamespace( - config=SimpleNamespace(use_lora=use_lora), - logger=MagicMock(), - ) - - def _invoke(self, engine_stub, path): - from areal.engine.fsdp_engine import FSDPEngine - - # Bind the unbound method to the stub. 
- return FSDPEngine._log_disk_save_size(engine_stub, path) - - def test_lora_branch_writes_bytes_and_lora_bytes(self, tmp_path): - adapter_dir = tmp_path / "weight_update_v1" - _write_files_with_total_size(str(adapter_dir), 12345) - - engine = self._make_engine_stub(use_lora=True) - with patch( - "areal.engine.fsdp_engine.stats_tracker.scalar" - ) as mock_scalar: - self._invoke(engine, str(adapter_dir)) - - mock_scalar.assert_called_once() - kwargs = mock_scalar.call_args.kwargs - assert kwargs["weight_update_disk_bytes"] == 12345.0 - assert kwargs["weight_update_disk_lora_bytes"] == 12345.0 - assert "weight_update_disk_full_bytes" not in kwargs - # Logger should have emitted a `[weight_update_disk]` info line. - assert engine.logger.info.called - msg = engine.logger.info.call_args.args[0] - assert "[weight_update_disk]" in msg - assert "use_lora=True" in msg - - def test_full_branch_writes_bytes_and_full_bytes(self, tmp_path): - full_dir = tmp_path / "full_model" - _write_files_with_total_size(str(full_dir), 7777) - - engine = self._make_engine_stub(use_lora=False) - with patch( - "areal.engine.fsdp_engine.stats_tracker.scalar" - ) as mock_scalar: - self._invoke(engine, str(full_dir)) - - kwargs = mock_scalar.call_args.kwargs - assert kwargs["weight_update_disk_bytes"] == 7777.0 - assert kwargs["weight_update_disk_full_bytes"] == 7777.0 - assert "weight_update_disk_lora_bytes" not in kwargs - - def test_recursive_walk_sums_subdirectories(self, tmp_path): - root = tmp_path / "weight_update_v9" - sub = root / "nested" - os.makedirs(sub, exist_ok=True) - # 200 bytes top-level + 300 bytes nested = 500 total - with open(root / "top.bin", "wb") as f: - f.write(b"\x00" * 200) - with open(sub / "deep.bin", "wb") as f: - f.write(b"\x00" * 300) - - engine = self._make_engine_stub(use_lora=True) - with patch( - "areal.engine.fsdp_engine.stats_tracker.scalar" - ) as mock_scalar: - self._invoke(engine, str(root)) - - kwargs = mock_scalar.call_args.kwargs - assert 
kwargs["weight_update_disk_bytes"] == 500.0 - assert kwargs["weight_update_disk_lora_bytes"] == 500.0 - - def test_nonexistent_path_does_not_raise(self, tmp_path): - """Reporting a non-existent path must not break the training loop; - ``os.walk`` simply yields nothing and the recorded size is 0. - """ - bad = tmp_path / "does_not_exist" - engine = self._make_engine_stub(use_lora=True) - with patch( - "areal.engine.fsdp_engine.stats_tracker.scalar" - ) as mock_scalar: - # Must not raise. - self._invoke(engine, str(bad)) - # Either reported as zero (preferred) or skipped silently. - if mock_scalar.called: - kwargs = mock_scalar.call_args.kwargs - assert kwargs["weight_update_disk_bytes"] == 0.0 - - def test_scalar_failure_is_swallowed(self, tmp_path): - """If ``stats_tracker.scalar`` raises, the helper must catch it - and log a warning rather than propagate. - """ - d = tmp_path / "wu" - _write_files_with_total_size(str(d), 100) - engine = self._make_engine_stub(use_lora=True) - with patch( - "areal.engine.fsdp_engine.stats_tracker.scalar", - side_effect=RuntimeError("boom"), - ): - # Must not raise. - self._invoke(engine, str(d)) - # The warning path should have been hit. 
- assert engine.logger.warning.called - - -# --------------------------------------------------------------------------- -# SGLangBackend._log_disk_send_size -# --------------------------------------------------------------------------- - - -class TestLogDiskSendSize: - def test_lora_meta_records_bytes_and_lora_bytes(self, tmp_path): - from areal.engine.sglang_remote import SGLangBackend - - d = tmp_path / "weight_update_v3" - _write_files_with_total_size(str(d), 4096) - meta = WeightUpdateMeta( - type="disk", - use_lora=True, - lora_name="L", - version=3, - path=str(d), - ) - - with patch( - "areal.engine.sglang_remote.stats_tracker.scalar" - ) as mock_scalar: - SGLangBackend._log_disk_send_size(meta, use_lora=True) - - mock_scalar.assert_called_once() - kwargs = mock_scalar.call_args.kwargs - assert kwargs["weight_update_send_bytes"] == 4096.0 - assert kwargs["weight_update_send_lora_bytes"] == 4096.0 - assert "weight_update_send_full_bytes" not in kwargs - - def test_full_meta_records_bytes_and_full_bytes(self, tmp_path): - from areal.engine.sglang_remote import SGLangBackend - - d = tmp_path / "full" - _write_files_with_total_size(str(d), 999) - meta = WeightUpdateMeta(type="disk", use_lora=False, path=str(d)) - - with patch( - "areal.engine.sglang_remote.stats_tracker.scalar" - ) as mock_scalar: - SGLangBackend._log_disk_send_size(meta, use_lora=False) - - kwargs = mock_scalar.call_args.kwargs - assert kwargs["weight_update_send_bytes"] == 999.0 - assert kwargs["weight_update_send_full_bytes"] == 999.0 - assert "weight_update_send_lora_bytes" not in kwargs - - def test_meta_path_is_a_file_not_dir(self, tmp_path): - """If ``meta.path`` happens to point at a single file rather than - a directory, the helper must still record its size. 
- """ - from areal.engine.sglang_remote import SGLangBackend - - f = tmp_path / "alone.safetensors" - f.write_bytes(b"\x00" * 555) - meta = WeightUpdateMeta( - type="disk", use_lora=True, lora_name="L", version=0, path=str(f) - ) - - with patch( - "areal.engine.sglang_remote.stats_tracker.scalar" - ) as mock_scalar: - SGLangBackend._log_disk_send_size(meta, use_lora=True) - - kwargs = mock_scalar.call_args.kwargs - assert kwargs["weight_update_send_bytes"] == 555.0 - assert kwargs["weight_update_send_lora_bytes"] == 555.0 - - def test_none_path_is_a_noop(self): - from areal.engine.sglang_remote import SGLangBackend - - meta = WeightUpdateMeta(type="disk", use_lora=True, path=None) - with patch( - "areal.engine.sglang_remote.stats_tracker.scalar" - ) as mock_scalar: - SGLangBackend._log_disk_send_size(meta, use_lora=True) - # No path -> nothing to size -> no scalar call. - mock_scalar.assert_not_called() - - def test_scalar_failure_is_swallowed(self, tmp_path): - from areal.engine.sglang_remote import SGLangBackend - - d = tmp_path / "wu" - _write_files_with_total_size(str(d), 10) - meta = WeightUpdateMeta( - type="disk", use_lora=True, lora_name="L", version=0, path=str(d) - ) - with patch( - "areal.engine.sglang_remote.stats_tracker.scalar", - side_effect=RuntimeError("boom"), - ): - # Must not raise. - SGLangBackend._log_disk_send_size(meta, use_lora=True) - - -# --------------------------------------------------------------------------- -# build_disk_weight_update_requests integration: dispatch must call -# _log_disk_send_size on both branches. 
-# --------------------------------------------------------------------------- - - -class TestBuildDiskRequestsCallsSendSizeMetric: - def test_lora_branch_invokes_log_disk_send_size(self, tmp_path): - from areal.engine.sglang_remote import SGLangBackend - - d = tmp_path / "weight_update_v1" - _write_files_with_total_size(str(d), 10) - backend = SGLangBackend() - meta = WeightUpdateMeta( - type="disk", - use_lora=True, - lora_name="my-lora", - version=1, - path=str(d), - ) - - with patch.object(SGLangBackend, "_log_disk_send_size") as mock_metric: - backend.build_disk_weight_update_requests(meta) - mock_metric.assert_called_once() - # use_lora kwarg must be True for the LoRA branch. - assert mock_metric.call_args.kwargs["use_lora"] is True - - def test_full_branch_invokes_log_disk_send_size(self, tmp_path): - from areal.engine.sglang_remote import SGLangBackend - - d = tmp_path / "full" - _write_files_with_total_size(str(d), 10) - backend = SGLangBackend() - meta = WeightUpdateMeta(type="disk", use_lora=False, path=str(d)) - - with patch.object(SGLangBackend, "_log_disk_send_size") as mock_metric: - backend.build_disk_weight_update_requests(meta) - mock_metric.assert_called_once() - assert mock_metric.call_args.kwargs["use_lora"] is False diff --git a/tests/test_lora_disk_sync.py b/tests/test_lora_disk_sync.py index 7339145faa..4a2f03cc9c 100644 --- a/tests/test_lora_disk_sync.py +++ b/tests/test_lora_disk_sync.py @@ -1,27 +1,4 @@ -"""Unit tests for LoRA disk-based weight synchronization. - -The disk-mode LoRA sync flow on FSDP + SGLang is: - -* Training side (FSDP): ``FSDPEngine._save_model_to_hf`` branches on - ``self.config.use_lora``. 
When ``use_lora=True`` it calls - ``_save_lora_adapter_to_hf`` which: - - filters the full state_dict for ``lora_A`` / ``lora_B`` / - ``lora_embedding_A`` / ``lora_embedding_B`` keys, - - strips the active-adapter segment ``.default.`` so the layout - matches what ``peft.PeftModel.save_pretrained`` would produce - (and what SGLang's ``/load_lora_adapter`` expects), - - writes ``adapter_model.safetensors`` + ``adapter_config.json``. - -* Inference side (SGLang): ``SGLangBackend.build_disk_weight_update_requests`` - routes ``meta.use_lora=True`` to ``HttpRequest("/load_lora_adapter", ...)`` - and the standard full-model branch to ``/update_weights_from_disk``. - -These unit tests exercise (a) the LoRA filtering / key normalisation -logic, (b) the ``WeightUpdateMeta`` schema, (c) the SGLang -request-building dispatch, and (d) the ``get_versioned_lora_name`` -utility. They are CPU-only and do not require any GPU or running -SGLang server. -""" +"""Unit tests for LoRA disk-based weight synchronization.""" import copy import json @@ -36,16 +13,7 @@ from areal.api.io_struct import get_versioned_lora_name -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -# NOTE: This keyword tuple MUST stay in sync with the one inside -# ``FSDPEngine._save_lora_adapter_to_hf``. The companion test -# ``tests/test_lora_adapter_save.py`` exercises the production method -# directly; the helper-based tests in this file are a fast smoke layer -# that does not require importing the engine. +# Keep this in sync with ``FSDPEngine._save_lora_adapter_to_hf``. 
_LORA_KEYWORDS = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") @@ -493,39 +461,3 @@ def test_with_version_changes_path_and_lora_name_consistently(self, tmp_path): req = SGLangBackend().build_disk_weight_update_requests(m).requests[0] assert req.payload["lora_name"] == f"my-lora-v{v}" assert req.payload["lora_path"].endswith(f"weight_update_v{v}") - - def test_disk_lora_dispatch_emits_send_size_metric(self, tmp_path): - """Both branches of ``build_disk_weight_update_requests`` MUST - record the on-disk byte count via the default ``stats_tracker`` - under flat top-level keys (no scope prefix). The detailed - contract is exercised in ``tests/test_lora_disk_size_metrics.py``; - this is a lightweight integration smoke that wires the dispatch - and the metric emitter together. - """ - from unittest.mock import patch - - from areal.engine.sglang_remote import SGLangBackend - - adapter_dir = tmp_path / "weight_update_v0" - adapter_dir.mkdir() - (adapter_dir / "adapter_model.safetensors").write_bytes(b"\x00" * 16) - meta = WeightUpdateMeta( - type="disk", - use_lora=True, - lora_name="L", - version=0, - path=str(adapter_dir), - ) - with patch( - "areal.engine.sglang_remote.stats_tracker.scalar" - ) as mock_scalar: - SGLangBackend().build_disk_weight_update_requests(meta) - assert mock_scalar.called, ( - "build_disk_weight_update_requests must record send-size metrics" - ) - kwargs = mock_scalar.call_args.kwargs - assert "weight_update_send_bytes" in kwargs - assert "weight_update_send_lora_bytes" in kwargs - # Flat top-level keys -- never scoped under a subgroup. - for k in kwargs: - assert "/" not in k, f"metric key must be flat, got {k}" diff --git a/tests/test_lora_disk_sync_e2e.py b/tests/test_lora_disk_sync_e2e.py index be6a743e42..95ed9fdc79 100644 --- a/tests/test_lora_disk_sync_e2e.py +++ b/tests/test_lora_disk_sync_e2e.py @@ -1,19 +1,4 @@ -"""End-to-end test for LoRA disk-based weight synchronization. 
- -This test launches ``torchrun`` with the helper script -``tests/torchrun/run_lora_disk_sync.py`` which: - -1. Creates a small Qwen3-0.6B model with FSDP + LoRA + ``weight_update_mode="disk"``. -2. Writes a PEFT-format adapter checkpoint via - ``FSDPEngine._save_model_to_hf`` and verifies the on-disk artefacts. -3. Verifies the model can still run a forward pass after the save. - -The test requires GPUs and is marked ``@pytest.mark.slow`` and -``@pytest.mark.sglang`` following AReaL test conventions. - -Usage: - pytest tests/test_lora_disk_sync_e2e.py -v -""" +"""End-to-end test for FSDP LoRA adapter-only disk saves.""" import os import subprocess @@ -30,18 +15,7 @@ def _run_torchrun_test(alloc_mode: str, output: str, n_gpus: int | None = None): - """Launch the torchrun helper script and check the result. - - Parameters - ---------- - alloc_mode : str - Backend allocation string, e.g. ``"fsdp:d1t1"``. - output : str - Path to the result file that the torchrun script writes. - n_gpus : int, optional - Override number of GPUs. If ``None``, it is derived from - ``alloc_mode``. - """ + """Launch the torchrun helper script and check the result.""" port = find_free_ports(1)[0] if n_gpus is None: n_gpus = ModelAllocation.from_str(alloc_mode).parallel.world_size diff --git a/tests/torchrun/run_lora_disk_sync.py b/tests/torchrun/run_lora_disk_sync.py index 994e24e8a6..1a26d44fb3 100644 --- a/tests/torchrun/run_lora_disk_sync.py +++ b/tests/torchrun/run_lora_disk_sync.py @@ -1,33 +1,4 @@ -"""Torchrun script for LoRA disk-sync end-to-end validation. - -This script is launched via ``torchrun`` from the e2e test -(``tests/test_lora_disk_sync_e2e.py``). It: - -1. Creates a small Qwen3-0.6B model with LoRA adapters on the FSDP engine. -2. Calls ``FSDPEngine._save_model_to_hf`` (the production code path used - by ``_update_weights_from_disk``) under ``use_lora=True``. -3. Validates the on-disk artefacts: - * ``adapter_model.safetensors`` exists and is non-empty. 
- * ``adapter_config.json`` exists, parses, and has the PEFT-required - fields (``peft_type='LORA'``, ``r``, ``lora_alpha``, - ``target_modules``). - * Every adapter tensor key matches the PEFT layout consumed by - SGLang's ``/load_lora_adapter`` (i.e. contains a LoRA keyword - and does NOT contain the active-adapter ``.default.`` segment). -4. Verifies the in-memory model still runs a forward pass. -5. Writes ``"Passed"`` / ``"Failed"`` to an output file (rank 0 only). - -Note: the new disk-mode LoRA sync path does not require any NCCL -process group on the inference side; SGLang loads the adapter directly -from disk via its existing ``/load_lora_adapter`` endpoint. This -script therefore exercises the *training-side* contract end-to-end and -leaves the SGLang HTTP request-building to the unit tests in -``tests/test_lora_disk_sync.py``. - -Usage (invoked by the e2e test, not directly): - torchrun --nproc_per_node=N tests/torchrun/run_lora_disk_sync.py \ - --backend fsdp:d1t1 --output /tmp/result.out -""" +"""Torchrun E2E validation for FSDP LoRA adapter-only disk saves.""" import argparse import json @@ -185,11 +156,7 @@ def verify_adapter_artifacts(adapter_dir: str, *, lora_rank: int, lora_alpha: in print(f"ERROR: adapter key must end with '.weight': {k}", flush=True) return False - # Adapter-only saves should be tens of MB at most (Qwen3-0.6B + r=8 - # is around 19MB). If the artefact is GB-scale the engine almost - # certainly fell back to a full-model save -- the very bug that - # this whole disk-sync path exists to fix. Cap at 200 MB to leave - # plenty of headroom while still flagging a regression. + # Catch accidental full-model saves. 
total_bytes = 0 for root, _dirs, files in os.walk(adapter_dir): for fname in files: From 9138a3d05dbd95e8ce99883f12a446fe8be09c62 Mon Sep 17 00:00:00 2001 From: TaoZex <2633363995@qq.com> Date: Thu, 14 May 2026 23:27:11 +0800 Subject: [PATCH 12/12] feat: precommit --- .../openai/proxy/proxy_rollout_server.py | 4 +--- tests/test_lora_adapter_save.py | 4 +--- tests/test_lora_disk_sync.py | 5 +---- tests/torchrun/run_lora_disk_sync.py | 12 +++++------- 4 files changed, 8 insertions(+), 17 deletions(-) diff --git a/areal/experimental/openai/proxy/proxy_rollout_server.py b/areal/experimental/openai/proxy/proxy_rollout_server.py index 2eb11591f8..98fe756ad6 100644 --- a/areal/experimental/openai/proxy/proxy_rollout_server.py +++ b/areal/experimental/openai/proxy/proxy_rollout_server.py @@ -283,9 +283,7 @@ def _setup_openai_client(): # any attacker who can reach this port can call admin endpoints # (grant_capacity, start_session, export_trajectories, ...). loopback_hosts = {"127.0.0.1", "::1", "localhost"} - allow_override = ( - os.environ.get("AREAL_ALLOW_DEFAULT_ADMIN_KEY", "0") == "1" - ) + allow_override = os.environ.get("AREAL_ALLOW_DEFAULT_ADMIN_KEY", "0") == "1" if _server_host in loopback_hosts or allow_override: logger.warning( "Using default admin API key. 
Change 'admin_api_key' in " diff --git a/tests/test_lora_adapter_save.py b/tests/test_lora_adapter_save.py index 10ca1ef5dd..c92ddc20f6 100644 --- a/tests/test_lora_adapter_save.py +++ b/tests/test_lora_adapter_save.py @@ -117,9 +117,7 @@ def test_tensor_values_are_preserved(self, tmp_path): _invoke(engine, str(d), sd) loaded = safetensors_load_file(str(d / "adapter_model.safetensors")) - sample_key = ( - "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" - ) + sample_key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight" assert sample_key in loaded torch.testing.assert_close( loaded[sample_key], torch.ones(8, 4), check_dtype=False diff --git a/tests/test_lora_disk_sync.py b/tests/test_lora_disk_sync.py index 4a2f03cc9c..c50b6db5af 100644 --- a/tests/test_lora_disk_sync.py +++ b/tests/test_lora_disk_sync.py @@ -12,7 +12,6 @@ from areal.api.cli_args import TrainEngineConfig from areal.api.io_struct import get_versioned_lora_name - # Keep this in sync with ``FSDPEngine._save_lora_adapter_to_hf``. 
_LORA_KEYWORDS = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") @@ -352,9 +351,7 @@ def test_generation_request_without_lora(self): backend = SGLangBackend() gconfig = GenerationHyperparameters(max_new_tokens=8) req = ModelRequest(input_ids=[1, 2, 3], gconfig=gconfig) - http_req = backend.build_generation_request( - req, with_lora=False, version=0 - ) + http_req = backend.build_generation_request(req, with_lora=False, version=0) assert "lora_path" not in http_req.payload diff --git a/tests/torchrun/run_lora_disk_sync.py b/tests/torchrun/run_lora_disk_sync.py index 1a26d44fb3..0904be3153 100644 --- a/tests/torchrun/run_lora_disk_sync.py +++ b/tests/torchrun/run_lora_disk_sync.py @@ -93,7 +93,9 @@ def verify_forward_pass(engine: FSDPEngine) -> bool: return False -def verify_adapter_artifacts(adapter_dir: str, *, lora_rank: int, lora_alpha: int) -> bool: +def verify_adapter_artifacts( + adapter_dir: str, *, lora_rank: int, lora_alpha: int +) -> bool: """Validate the PEFT-format files written by ``_save_model_to_hf``.""" safetensors_path = os.path.join(adapter_dir, "adapter_model.safetensors") config_path = os.path.join(adapter_dir, "adapter_config.json") @@ -237,9 +239,7 @@ def test_lora_disk_sync( with tempfile.TemporaryDirectory() as tmpdir: adapter_dir = os.path.join(tmpdir, "weight_update_v0") os.makedirs(adapter_dir, exist_ok=True) - print( - f"[Rank {rank}] Calling _save_model_to_hf -> {adapter_dir}", flush=True - ) + print(f"[Rank {rank}] Calling _save_model_to_hf -> {adapter_dir}", flush=True) engine._save_model_to_hf( adapter_dir, tokenizer=None, @@ -256,9 +256,7 @@ def test_lora_disk_sync( if not ok: success = False else: - print( - f"[Rank {rank}] PEFT adapter artefacts validated", flush=True - ) + print(f"[Rank {rank}] PEFT adapter artefacts validated", flush=True) # Step 5: Verify forward pass still works after the save. print(f"[Rank {rank}] Verifying forward pass", flush=True)