diff --git a/examples/deepeyes/env_deepeyes.py b/examples/deepeyes/env_deepeyes.py index dcf9fdd..07763d4 100644 --- a/examples/deepeyes/env_deepeyes.py +++ b/examples/deepeyes/env_deepeyes.py @@ -27,12 +27,16 @@ class DeepeyesEnv(BaseInteractionEnv): MIN_DIMENSION = 28 - def __init__(self, *, max_turns: int | None = None, image=None): + def __init__(self, *, max_turns: int | None = None, image=None, normalize_bbox: bool = True): self.max_turns = max_turns self.turn = 0 self.tool_calls: list[dict[str, Any]] = [] self.current_image = image self.origin_image = image + # Whether to convert bbox coordinates from normalized [0, 1000] to absolute pixels. + # Qwen-VL / Qwen2-VL / Qwen3-VL output 0-1000 normalized coords → set True (default). + # Qwen2.5-VL outputs absolute pixel coords → set False. + self.normalize_bbox = normalize_bbox def reset(self): self.turn = 0 @@ -119,13 +123,21 @@ def _maybe_resize_bbox(self, bbox_2d: list[float]) -> Optional[list[float]]: image_height = self.current_image.height left, top, right, bottom = bbox_2d - # 1. Clamp the initial bounding box to the image dimensions. + # 1. Convert normalized [0, 1000] coordinates to absolute pixel coordinates. + # Qwen-VL / Qwen2-VL / Qwen3-VL use 0-1000 normalized coords; Qwen2.5-VL uses absolute pixels. + if self.normalize_bbox: + left = left / 1000.0 * image_width + top = top / 1000.0 * image_height + right = right / 1000.0 * image_width + bottom = bottom / 1000.0 * image_height + + # 2. Clamp the bounding box to the image dimensions. left = max(0.0, float(left)) top = max(0.0, float(top)) right = min(float(image_width), float(right)) bottom = min(float(image_height), float(bottom)) - # 2. If clamped bbox is invalid, return immediately. + # 3. If clamped bbox is invalid, return immediately. if not self._validate_bbox(left, top, right, bottom): return None @@ -133,7 +145,7 @@ def _maybe_resize_bbox(self, bbox_2d: list[float]) -> Optional[list[float]]: height = bottom - top width = right - left - # 3. If the box is too small, attempt to resize it. + # 4. If the box is too small, attempt to resize it. if height < self.MIN_DIMENSION or width < self.MIN_DIMENSION: logger.info(f"Bbox {width}x{height} is smaller than {self.MIN_DIMENSION}, attempting resize.") center_x = (left + right) / 2.0 @@ -182,7 +194,7 @@ def _maybe_resize_bbox(self, bbox_2d: list[float]) -> Optional[list[float]]: # Use floor and ceil for final integer coordinates. current_bbox = [floor(new_left), floor(new_top), ceil(new_right), ceil(new_bottom)] - # 4. Final validation on the resulting bounding box (either original or resized). + # 5. Final validation on the resulting bounding box (either original or resized). final_left, final_top, final_right, final_bottom = current_bbox if not self._validate_bbox(final_left, final_top, final_right, final_bottom): logger.warning(f"Final bbox is invalid after processing: {current_bbox}") @@ -288,7 +300,8 @@ def build_env(sample: Sample | None = None, args: Any | None = None, **_: Any) - max_turns = args.max_turns if max_turns is None: raise ValueError("max_turns must be set via --custom-config-path in the custom config file.") + normalize_bbox = getattr(args, "normalize_bbox", True) image = _extract_initial_image(sample) if image is None: logger.warning("No image found in sample.multimodal_inputs or metadata.") - return DeepeyesEnv(max_turns=max_turns, image=image) + return DeepeyesEnv(max_turns=max_turns, image=image, normalize_bbox=normalize_bbox) diff --git a/examples/deepeyes/reward_deepeyes.py b/examples/deepeyes/reward_deepeyes.py index 4c43174..47a3648 100644 --- a/examples/deepeyes/reward_deepeyes.py +++ b/examples/deepeyes/reward_deepeyes.py @@ -24,49 +24,49 @@ def get_gpt4_score_ICE(): [Question]: Is the countertop tan or blue? [Standard Answer]: The countertop is tan. [Model_answer] : tan -Judgement: 1 +Judgment: 1 """ # noqa example_2 = """ [Question]: On which side of the picture is the barrier? [Standard Answer]: The barrier is on the left side of the picture. [Model_answer] : left -Judgement: 1 +Judgment: 1 """ # noqa example_3 = """ [Question]: Is the kite brown and large? [Standard Answer]: Yes, the kite is brown and large. [Model_answer] : Yes -Judgement: 1 +Judgment: 1 """ # noqa example_4 = """ [Question]: Are the spots on a giraffe? [Standard Answer]: No, the spots are on a banana. [Model_answer] : no -Judgement: 1 +Judgment: 1 """ # noqa example_5 = """ [Question]: Who is wearing pants? [Standard Answer]: The boy is wearing pants. [Model_answer] : The person in the picture is wearing pants. -Judgement: 1 +Judgment: 1 """ # noqa example_6 = """ [Question]: Is the man phone both blue and closed? [Standard Answer]: Yes, the man phone is both blue and closed. [Model_answer] : No. -Judgement: 0 +Judgment: 0 """ # noqa example_7 = """ [Question]: What color is the towel in the center of the picture? [Standard Answer]: The towel in the center of the picture is blue. [Model_answer] : The towel in the center of the picture is pink. -Judgement: 0 +Judgment: 0 """ # noqa return [example_1, example_2, example_3, example_4, example_5, example_6, example_7] @@ -76,7 +76,7 @@ def get_chat_template(): chat_template = """ Below are two answers to a question. Question is [Question], [Standard Answer] is the standard answer to the question, and [Model_answer] is the answer extracted from a model's output to this question. Determine whether these two answers are consistent. Note that [Model Answer] is consistent with [Standard Answer] whenever they are essentially the same. If the meaning is expressed in the same way, it is considered consistent, for example, 'pink' and 'it is pink'. -If they are consistent, Judement is 1; if they are different, Judement is 0. Just output Judement and don't output anything else.\n\n +If they are consistent, Judgment is 1; if they are different, Judgment is 0. Just output Judgment and don't output anything else.\n\n """ return chat_template @@ -91,7 +91,7 @@ def get_prompt(predict_str, ground_truth, question): [Question]: {question} [Standard Answer]: {ground_truth} [Model_answer] : {predict_str} -Judgement:""" +Judgment:""" full_prompt = f"{demo_prompt}{test_prompt}" return full_prompt @@ -189,8 +189,8 @@ def compute_score(predict_str: str, ground_truth: str, extra_info: dict | None = response = "error" # print(response) - if "Judgement:" in response: - response = response.split("Judgement:")[-1].strip() + if "Judgment:" in response: + response = response.split("Judgment:")[-1].strip() if "1" in response: acc_reward = 1.0 elif "0" in response: diff --git a/examples/deepeyes/run_deepeyes_qwen35_9B_async.sh b/examples/deepeyes/run_deepeyes_qwen35_9B_async.sh new file mode 100755 index 0000000..67870d7 --- /dev/null +++ b/examples/deepeyes/run_deepeyes_qwen35_9B_async.sh @@ -0,0 +1,227 @@ +#!/bin/bash + +# Copyright (c) 2026 Relax Authors. All Rights Reserved. +# +# Qwen3.5-9B 8xGPU single-node fully-async DeepEyes training script. +# +# Resource layout (8 GPUs, fully-async): +# actor: 4 GPUs (TP=4) +# rollout: 2 GPUs (1 engine × 2 GPUs) +# reference: 1 GPU (TP=1, weight-only) +# actor_fwd: 1 GPU +# +# Usage: +# MODEL_DIR=/path/to/models DATA_DIR=/path/to/data SAVE_DIR=/path/to/save \ +# bash examples/deepeyes/run_deepeyes_qwen35_9B_async.sh + +set -ex +set -o pipefail + +############################################################################### +# ENVIRONMENT # +############################################################################### + +TIMESTAMP=$(date "+%Y-%m-%d-%H:%M:%S") + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +# Auto-source local environment when not launched via an external entrypoint +if [ -z "${RELAX_ENTRYPOINT_MODE:-}" ]; then + source "${SCRIPT_DIR}/../../scripts/entrypoint/local.sh" +fi +source "${MODEL_CONFIG_DIR}/qwen35-9B.sh" + +############################################################################### +# DIRS # +############################################################################### + +PROJECT_NAME="${PROJECT_NAME:=Relax/dev/deepeyes}" +EXP_NAME="qwen35-9B-deepeyes-async-${TIMESTAMP}" + +# Require MODEL_DIR, DATA_DIR, SAVE_DIR from environment or set defaults +if [ -z "${MODEL_DIR:-}" ] || [ -z "${DATA_DIR:-}" ] || [ -z "${SAVE_DIR:-}" ]; then + echo "ERROR: MODEL_DIR, DATA_DIR, and SAVE_DIR must be set." + echo "Example: MODEL_DIR=/path/to/models DATA_DIR=/path/to/data SAVE_DIR=/path/to/save bash $0" + exit 1 +fi +mkdir -p ${SAVE_DIR} + +############################################################################### +# JUDGE MODEL API # +############################################################################### + +source "${SCRIPT_DIR}/sglang_judge_service.sh" + +############################################################################### +# MODEL CONFIG # +############################################################################### + +CKPT_ARGS=( + --hf-checkpoint ${MODEL_DIR}/Qwen3.5-9B + --ref-load ${MODEL_DIR}/Qwen3.5-9B + --save ${SAVE_DIR}/Qwen3.5-9B-DeepEyes-Checkpoint + --megatron-to-hf-mode bridge + --save-interval 100 + --max-actor-ckpt-to-keep 1 +) + +############################################################################### +# DATASETS # +############################################################################### + +TRAIN_FILES=( + "'${DATA_DIR}/deepeyes-v1/data_0.1.2_visual_toolbox_v2.parquet@[0:5000]'" + "'${DATA_DIR}/deepeyes-v1/data_v0.8_visual_toolbox_v2.parquet@[0:5000]'" +) +TEST_FILES=("${DATA_DIR}/deepeyes-v1/data_thinklite_reasoning_acc.parquet@[0:256]") +PROMPT_SET="[$(IFS=,; echo "${TRAIN_FILES[*]}")]" + +############################################################################### +# ROLLOUT CONFIG # +############################################################################### + +NUM_ROLLOUT="${NUM_ROLLOUT:=2000}" + +ROLLOUT_ARGS=( + --prompt-data "${PROMPT_SET}" + --input-key prompt + --label-key reward_model + --multimodal-keys '{"image":"images"}' + --reward-key score + --metadata-key extra_info + --apply-chat-template + --custom-generate-function-path examples.deepeyes.rollout.generate + --custom-rm-path examples.deepeyes.reward_deepeyes.reward_func + --custom-config-path examples/deepeyes/deepeyes_config.yaml + --num-rollout ${NUM_ROLLOUT} + --rollout-batch-size 32 + --n-samples-per-prompt 8 + --rollout-max-response-len 2048 + --rollout-max-prompt-len 2048 + --rollout-temperature 1 + --global-batch-size 256 + --use-fault-tolerance + --rollout-shuffle + --use-streaming-dataset +) + +############################################################################### +# EVAL CONFIG # +############################################################################### + +EVAL_ARGS=( + --eval-interval 100 + --eval-prompt-data vstar ${TEST_FILES} + --n-samples-per-eval-prompt 8 + --eval-max-response-len 2048 + --eval-top-p 0.7 +) + +############################################################################### +# ALGORITHM CONFIG # +############################################################################### + +GRPO_ARGS=( + --advantage-estimator grpo + --kl-loss-coef 0.00 + --kl-loss-type low_var_kl + --entropy-coef 0.00 + --eps-clip 0.2 + --eps-clip-high 0.28 + --eps-clip-c 3 + --use-tis +) + +############################################################################### +# OPTIMIZER CONFIG # +############################################################################### + +OPTIMIZER_ARGS=( + --optimizer adam + --lr 1e-6 + --lr-decay-style constant + --weight-decay 0.1 + --adam-beta1 0.9 + --adam-beta2 0.98 + --optimizer-cpu-offload + --overlap-cpu-optimizer-d2h-h2d + --use-precision-aware-optimizer +) + +############################################################################### +# SGLANG CONFIG # +############################################################################### + +SGLANG_ARGS=( + --rollout-num-gpus-per-engine 2 + --sglang-mem-fraction-static 0.6 +) + +############################################################################### +# LOGGING CONFIG # +############################################################################### + +LOG_ARGS=( + --use-clearml + --use-metrics-service + --tb-project-name ${PROJECT_NAME} + --tb-experiment-name ${EXP_NAME} +) + +############################################################################### +# MEGATRON CONFIG # +############################################################################### + +MEGATRON_ARGS=( + --tensor-model-parallel-size 4 + --sequence-parallel + --pipeline-model-parallel-size 1 + --context-parallel-size 1 + --expert-model-parallel-size 1 + --expert-tensor-parallel-size 1 + --recompute-granularity full + --recompute-method uniform + --recompute-num-layers 1 + --use-dynamic-batch-size + --max-tokens-per-gpu 9216 + --no-rope-fusion + --attention-dropout 0.0 + --hidden-dropout 0.0 + --accumulate-allreduce-grads-in-fp32 + --attention-softmax-in-fp32 + --attention-backend flash +) + +############################################################################### +# RESOURCE CONFIG # +############################################################################### + +# Fully-async: actor(4 GPU) + rollout(2 GPU) + reference(1 GPU) + actor_fwd(1 GPU) = 8 GPU +RAY_RESOURCE_ARGS=( + --resource '{"actor": [1, 4], "rollout": [1, 2], "reference": [1, 1], "actor_fwd": [1, 1], "advantages": [1, 0]}' + --max-staleness 2 + --num-data-storage-units 1 + --num-iters-per-train-update 8 + --ref-actor-config '{"tensor_model_parallel_size": 1, "max_tokens_per_gpu": 16384, "sequence_parallel": false, "only_load_weight": true}' + --fully-async + --use-health-check +) + +############################################################################### +# LAUNCH JOB # +############################################################################### + +mkdir -p logs + +ray job submit ${RAY_NO_WAIT:+--no-wait} --address="http://127.0.0.1:8265" \ + -- python3 -m relax.entrypoints.train \ + "${RAY_RESOURCE_ARGS[@]}" \ + "${MODEL_ARGS[@]}" \ + "${CKPT_ARGS[@]}" \ + "${ROLLOUT_ARGS[@]}" \ + "${GRPO_ARGS[@]}" \ + "${OPTIMIZER_ARGS[@]}" \ + "${SGLANG_ARGS[@]}" \ + "${LOG_ARGS[@]}" \ + "${MEGATRON_ARGS[@]}" \ + "${EVAL_ARGS[@]}" \ + 2>&1 | tee logs/${EXP_NAME}.log diff --git a/relax/backends/megatron/__init__.py b/relax/backends/megatron/__init__.py index 4ffd43d..03a19d4 100644 --- a/relax/backends/megatron/__init__.py +++ b/relax/backends/megatron/__init__.py @@ -2,7 +2,7 @@ import logging -import torch +from relax.utils import device as device_utils try: @@ -15,7 +15,7 @@ def new_init(self, *args, **kwargs): if torch_memory_saver._impl is not None: torch_memory_saver._impl._binary_wrapper.cdll.tms_set_interesting_region(False) old_init(self, *args, **kwargs) - torch.cuda.synchronize() + device_utils.synchronize() if torch_memory_saver._impl is not None: torch_memory_saver._impl._binary_wrapper.cdll.tms_set_interesting_region(True) diff --git a/relax/backends/megatron/actor.py b/relax/backends/megatron/actor.py index 4ac23c1..e162244 100644 --- a/relax/backends/megatron/actor.py +++ b/relax/backends/megatron/actor.py @@ -22,6 +22,7 @@ from relax.distributed.checkpoint_service.client.engine import create_client from relax.distributed.ray.train_actor import TrainRayActor +from relax.utils import device as device_utils from relax.utils import tracking_utils from relax.utils.async_utils import run from relax.utils.data.stream_dataloader import ( @@ -921,7 +922,7 @@ def _check_services_health(self) -> tuple[bool, bool]: flags = torch.tensor( [int(rollout_only), int(actor_fwd_only)], dtype=torch.int32, - device=torch.cuda.current_device(), + device=device_utils.make_current_torch_device(), ) dist.all_reduce(flags, op=dist.ReduceOp.MAX, group=get_gloo_group()) rollout_only = bool(flags[0].item()) @@ -1035,7 +1036,7 @@ def all_consumed(self, task_name, rollout_id): status = [run(self.data_system_client.async_check_consumption_status(task_name, f"train_{rollout_id}"))] else: status = [True] - status = torch.tensor(status, device=torch.cuda.current_device()) + status = torch.tensor(status, device=device_utils.make_current_torch_device()) dist.broadcast(status, group=mpu.get_tensor_model_parallel_group(), group_src=0) dist.broadcast(status, group=mpu.get_pipeline_model_parallel_group(), group_src=0) diff --git a/relax/backends/megatron/arguments.py b/relax/backends/megatron/arguments.py index d76acdb..c3e7daf 100644 --- a/relax/backends/megatron/arguments.py +++ b/relax/backends/megatron/arguments.py @@ -5,6 +5,7 @@ from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding from transformers import AutoConfig +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger @@ -17,18 +18,17 @@ def validate_args(args): """Run megatron's own validate_args plus slime-specific megatron validations.""" - import torch - - if not torch.cuda.is_available(): + if not device_utils.is_available(): from unittest.mock import patch - class _CudaProperty: + class _DeviceProperty: major = 9 minor = 0 + device_name = device_utils.get_device_name() with ( - patch("torch.cuda.get_device_properties", return_value=_CudaProperty()), - patch("torch.cuda.get_device_capability", return_value=(9, 0)), + patch(f"torch.{device_name}.get_device_properties", return_value=_DeviceProperty()), + patch(f"torch.{device_name}.get_device_capability", return_value=(9, 0)), ): _megatron_validate_args(args) else: diff --git a/relax/backends/megatron/data.py b/relax/backends/megatron/data.py index 22e21fc..6a931ab 100644 --- a/relax/backends/megatron/data.py +++ b/relax/backends/megatron/data.py @@ -12,6 +12,7 @@ from megatron.core.packed_seq_params import PackedSeqParams from torch.nn.utils.rnn import pad_sequence +from relax.utils import device as device_utils from relax.utils import tracking_utils from relax.utils.data.data import get_minimum_num_micro_batch_size from relax.utils.data.seqlen_balancing import get_seqlen_balanced_partitions @@ -173,7 +174,9 @@ def get_batch( tokens = F.pad(tokens, (0, pad), value=pad_token_id) cu_seqlens_list.append(cu_seqlens_list[-1] + pad) - cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int, device=torch.cuda.current_device()) + cu_seqlens = torch.tensor( + cu_seqlens_list, dtype=torch.int, device=device_utils.make_current_torch_device() + ) tokens = tokens.chunk(cp_size, dim=0)[cp_rank] else: tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens] @@ -191,7 +194,9 @@ def get_batch( cu_seqlens.append(cu_seqlens[-1] + pad) # thd requires the cu_seqlens to be of the origin length - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size + cu_seqlens = ( + torch.tensor(cu_seqlens, dtype=torch.int).to(device_utils.make_current_torch_device()) * cp_size + ) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() packed_seq_params = PackedSeqParams( @@ -440,7 +445,9 @@ def get_data_iterator( # across DP ranks so that all ranks execute the same number of training steps # (required by collective operations in the training loop). if getattr(args, "balance_data", False): - steps_tensor = torch.tensor([num_steps_per_rollout], dtype=torch.int, device=torch.cuda.current_device()) + steps_tensor = torch.tensor( + [num_steps_per_rollout], dtype=torch.int, device=device_utils.make_current_torch_device() + ) dist.all_reduce(steps_tensor, op=dist.ReduceOp.MAX, group=dp_group) num_steps_per_rollout = steps_tensor.item() @@ -472,7 +479,9 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices= get_minimum_num_micro_batch_size(samples[start:end], args.max_tokens_per_gpu * cp_size) ) - num_microbatches = torch.tensor(num_microbatches, dtype=torch.int, device=torch.cuda.current_device()) + num_microbatches = torch.tensor( + num_microbatches, dtype=torch.int, device=device_utils.make_current_torch_device() + ) dist.all_reduce(num_microbatches, op=dist.ReduceOp.MAX, group=dp_group) if vpp_size > 1: diff --git a/relax/backends/megatron/kernels/int4_qat/setup.py b/relax/backends/megatron/kernels/int4_qat/setup.py index b8bfc7d..2fc5ba1 100644 --- a/relax/backends/megatron/kernels/int4_qat/setup.py +++ b/relax/backends/megatron/kernels/int4_qat/setup.py @@ -6,6 +6,8 @@ # Get CUDA arch list +# NOTE: This setup script is CUDA-only as it compiles .cu kernel files via CUDAExtension. +# Non-CUDA backends (NPU, XPU, PPU) should provide their own kernel implementations. arch_list = [] if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): diff --git a/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py b/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py index 7c1b5d2..77eabb4 100644 --- a/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py +++ b/relax/backends/megatron/weight_conversion/processors/quantizer_compressed_tensors.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn +from relax.utils import device as device_utils + try: import fake_int4_quant_cuda @@ -91,7 +93,7 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze awq_linear.bias = linear.bias.clone().half() pack_num = 32 // awq_linear.w_bit - device = torch.device(f"cuda:{torch.cuda.current_device()}") + device = device_utils.make_current_torch_device() repeat_scales = scales.to(device).t().repeat_interleave(group_size, 1) if isinstance(zeros, torch.Tensor): @@ -284,7 +286,7 @@ def quantize_params_compressed_tensors(converted_named_params, quantization_conf qw, s, zp = pack_layer(param, group_size, is_symmetric) qweight_name = name.replace(".weight", ".weight_packed") scale_name = name.replace(".weight", ".weight_scale") - weight_shape = torch.tensor(param.shape, dtype=torch.int32, device="cuda") + weight_shape = torch.tensor(param.shape, dtype=torch.int32, device=device_utils.get_device_name()) weight_shape_name = name.replace(".weight", ".weight_shape") if zp is not None: zp_name = name.replace(".weight", ".weight_zero_point") diff --git a/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py b/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py index 02da465..32a382f 100644 --- a/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py +++ b/relax/backends/megatron/weight_update/hf_weight_iterator_direct.py @@ -7,6 +7,7 @@ from megatron.core import mpu from tqdm import tqdm +from relax.utils import device as device_utils from relax.utils.distributed_utils import get_gloo_group from relax.utils.types import ParamInfo @@ -55,13 +56,15 @@ def _get_megatron_full_params( if dist.get_rank() == info.src_rank: params.append( torch.nn.Parameter( - megatron_local_weights[info.name].to(device=torch.cuda.current_device(), non_blocking=True), + megatron_local_weights[info.name].to( + device=device_utils.make_current_torch_device(), non_blocking=True + ), requires_grad=False, ) ) else: - params.append(torch.empty(info.shape, dtype=info.dtype, device=torch.cuda.current_device())) - torch.cuda.synchronize() + params.append(torch.empty(info.shape, dtype=info.dtype, device=device_utils.make_current_torch_device())) + device_utils.synchronize() # broadcast params across pp ranks if pp_size > 1: diff --git a/relax/backends/megatron/weight_update/update_weight_from_distributed.py b/relax/backends/megatron/weight_update/update_weight_from_distributed.py index 2563c25..a255db0 100644 --- a/relax/backends/megatron/weight_update/update_weight_from_distributed.py +++ b/relax/backends/megatron/weight_update/update_weight_from_distributed.py @@ -12,12 +12,17 @@ from ray.actor import ActorHandle from tqdm import tqdm +from relax.utils import device as device_utils from relax.utils.distributed_utils import get_gloo_group, init_process_group +from relax.utils.logging_utils import get_logger from ..weight_conversion import convert_to_hf from .common import all_gather_param, named_params_and_buffers +logger = get_logger(__name__) + + class UpdateWeightFromDistributed: """Update distributed engines via NCCL. @@ -210,7 +215,7 @@ def _update_expert_bucket_weights_from_distributed( handles = [] for i, (_name, param) in enumerate(named_tensors): params = [ - torch.empty_like(param.data, device=torch.cuda.current_device()) + torch.empty_like(param.data, device=device_utils.make_current_torch_device()) for _ in range(mpu.get_expert_model_parallel_world_size()) ] handle = dist.all_gather(params, param.data, group=mpu.get_expert_model_parallel_group(), async_op=True) @@ -261,6 +266,7 @@ def connect_rollout_engines_from_distributed( group_name: str, rollout_engines: Sequence[ActorHandle], engine_gpu_counts: Sequence[int] | None = None, + max_retries: int = 3, ) -> dist.ProcessGroup: """Create NCCL group: training rank 0 + all engine GPUs. Blocks until joined. @@ -273,37 +279,55 @@ def connect_rollout_engines_from_distributed( engine_gpu_counts = [args.rollout_num_gpus_per_engine] * len(rollout_engines) master_address = ray._private.services.get_node_ip_address() - with socket.socket() as sock: - sock.bind(("", 0)) - master_port = sock.getsockname()[1] world_size = sum(engine_gpu_counts) + 1 # +1 for training rank 0 - # Compute cumulative rank offsets: engine i starts at cumulative[i] + 1. cumulative = [0] for c in engine_gpu_counts: cumulative.append(cumulative[-1] + c) - refs = [ - engine.init_weights_update_group.remote( - master_address, - master_port, - cumulative[i] + 1, - world_size, - group_name, - backend="nccl", - ) - for i, engine in enumerate(rollout_engines) - ] - model_update_groups = init_process_group( - backend="nccl", - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=group_name, - timeout=timedelta(minutes=args.distributed_timeout_minutes), - ) - ray.get(refs) - return model_update_groups + last_error = None + dist_backend = device_utils.get_dist_backend() + for attempt in range(1, max_retries + 1): + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + + refs = [ + engine.init_weights_update_group.remote( + master_address, + master_port, + cumulative[i] + 1, + world_size, + group_name, + backend=dist_backend, + ) + for i, engine in enumerate(rollout_engines) + ] + try: + model_update_groups = init_process_group( + backend=dist_backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=group_name, + timeout=timedelta(minutes=args.distributed_timeout_minutes), + ) + ray.get(refs) + return model_update_groups + except Exception as e: + last_error = e + logger.warning( + f"Failed to connect rollout engines (attempt {attempt}/{max_retries}, port={master_port}): {e}", + exc_info=(attempt == max_retries), + ) + try: + ray.get(refs, timeout=5) + except Exception: + pass + if attempt < max_retries: + time.sleep(5.0 * attempt) + + raise RuntimeError(f"Failed to connect rollout engines after {max_retries} attempts") from last_error def disconnect_rollout_engines_from_distributed(args, group_name, model_update_groups, rollout_engines): @@ -311,6 +335,8 @@ def disconnect_rollout_engines_from_distributed(args, group_name, model_update_g refs = [engine.destroy_weights_update_group.remote(group_name) for engine in rollout_engines] dist.destroy_process_group(model_update_groups) ray.get(refs) + # Wait for NCCL socket ports to be released by the OS + time.sleep(2.0) def update_weights_from_distributed( diff --git a/relax/backends/sglang/routing_replay_patch.py b/relax/backends/sglang/routing_replay_patch.py index 651e7f5..7d51556 100644 --- a/relax/backends/sglang/routing_replay_patch.py +++ b/relax/backends/sglang/routing_replay_patch.py @@ -52,6 +52,8 @@ import torch +from relax.utils import device as device_utils + logger = logging.getLogger(__name__) @@ -95,8 +97,8 @@ def _patched_init(self, *args, **kwargs): self._pinned_loc = torch.zeros(max_batch, dtype=torch.int64, device="cpu", pin_memory=True) # Dedicated copy stream + event. - self._copy_stream = torch.cuda.Stream(device=dev_buf.device) - self._copy_event = torch.cuda.Event() + self._copy_stream = device_utils.Stream(device=dev_buf.device) + self._copy_event = device_utils.Event() # Pending scatter state. self._pending_n = 0 # 0 means nothing pending @@ -142,7 +144,7 @@ def _patched_sync(self, forward_batch, can_run_graph, cuda_graph_batch): # In overlap-scheduler mode this is the *forward_stream*; without # overlap it is the default stream. We need this reference so that # copy_stream can order itself after the GPU→GPU staging copy below. - active_stream = torch.cuda.current_stream(self.device_cache.buffer.device) + active_stream = device_utils.current_stream(self.device_cache.buffer.device) # 1) GPU→GPU snapshot on the active stream — fast, no sync. self._staging_buffer[:n_tok].copy_(self.device_cache.buffer[local_start_pos:local_end_pos]) @@ -150,7 +152,7 @@ def _patched_sync(self, forward_batch, can_run_graph, cuda_graph_batch): # 2) On copy stream: async copies to pinned CPU buffers. # copy_stream waits on active_stream so the staging snapshot # above completes before we start reading it. - with torch.cuda.stream(self._copy_stream): + with device_utils.stream_context(self._copy_stream): self._copy_stream.wait_stream(active_stream) # 2a) Routing data: staging[:n_tok, :, :topk] → pinned_staging self._pinned_staging[:n_tok, :, :topk].copy_(self._staging_buffer[:n_tok, :, :topk], non_blocking=True) diff --git a/relax/backends/sglang/sglang_engine.py b/relax/backends/sglang/sglang_engine.py index 5bc0f40..1014d77 100644 --- a/relax/backends/sglang/sglang_engine.py +++ b/relax/backends/sglang/sglang_engine.py @@ -20,6 +20,7 @@ from relax.distributed.checkpoint_service.client.engine import create_client from relax.distributed.ray.ray_actor import RayActor +from relax.utils import device as device_utils from relax.utils.async_utils import run from relax.utils.http_utils import get_host_info from relax.utils.logging_utils import get_logger @@ -42,10 +43,11 @@ def get_base_gpu_id(args, rank): def _to_local_gpu_id(physical_gpu_id: int) -> int: - cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + visible_env = device_utils.get_visible_devices_env_var() + cvd = os.environ.get(visible_env) if not cvd: return physical_gpu_id # no remapping - # CUDA_VISIBLE_DEVICES can be like "4,5,6,7" + # Visible devices can be like "4,5,6,7" visible = [int(x) for x in cvd.split(",") if x.strip() != ""] # In a remapped process, valid torch device indices are 0..len(visible)-1 if physical_gpu_id in visible: @@ -54,7 +56,7 @@ def _to_local_gpu_id(physical_gpu_id: int) -> int: if 0 <= physical_gpu_id < len(visible): return physical_gpu_id raise RuntimeError( - f"GPU id {physical_gpu_id} is not valid under CUDA_VISIBLE_DEVICES={cvd}. " + f"Device id {physical_gpu_id} is not valid under {visible_env}={cvd}. " f"Expected one of {visible} (physical) or 0..{len(visible) - 1} (local)." ) diff --git a/relax/core/controller.py b/relax/core/controller.py index 0c9b5f6..2e60886 100644 --- a/relax/core/controller.py +++ b/relax/core/controller.py @@ -16,6 +16,7 @@ from relax.core.registry import ALGOS, ROLES, process_role from relax.core.service import Service, create_placement_group from relax.distributed.checkpoint_service.coordinator.service import create_dcs_deployment +from relax.utils import device as device_utils from relax.utils.async_utils import run, shutdown_async_loop from relax.utils.health_system import HealthManager from relax.utils.logging_utils import get_logger @@ -238,7 +239,8 @@ def _validate_gpu_resources(self, roles_to_create, colocate, actor_rollout_pg_ro total_required = sum(num_gpus for _, _, num_gpus, _ in roles_to_create) cluster_resources = ray.cluster_resources() - total_available = int(cluster_resources.get("GPU", 0)) + accel_resource = device_utils.get_ray_accelerator_name() + total_available = int(cluster_resources.get(accel_resource, 0)) logger.info( f"Resource validation: required GPUs={total_required}, cluster GPUs={total_available}, colocate={colocate}" diff --git a/relax/core/service.py b/relax/core/service.py index 52ae3eb..eb9441c 100644 --- a/relax/core/service.py +++ b/relax/core/service.py @@ -12,6 +12,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from relax.distributed.ray.placement_group import InfoActor, sort_key +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.utils import get_serve_url, recovery_load_path @@ -295,7 +296,8 @@ def _ensure_placement_group(self) -> Optional[Any]: def create_placement_group(num_gpus): """Create a placement group with the specified number of GPUs.""" - bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] + accel_resource = device_utils.get_ray_accelerator_name() + bundles = [{accel_resource: 1, "CPU": 1} for _ in range(num_gpus)] pg = placement_group(bundles, strategy="PACK") num_bundles = len(bundles) ray.get(pg.ready()) diff --git a/relax/distributed/checkpoint_service/backends/device_direct.py b/relax/distributed/checkpoint_service/backends/device_direct.py index ba5f6f4..57fa3e8 100644 --- a/relax/distributed/checkpoint_service/backends/device_direct.py +++ b/relax/distributed/checkpoint_service/backends/device_direct.py @@ -39,6 +39,7 @@ from relax.distributed.checkpoint_service.backends.base import CommBackend, TensorFusion from relax.distributed.checkpoint_service.config import BackendType, RoleInfo from relax.distributed.checkpoint_service.utils import load_weight +from relax.utils import device as device_utils from relax.utils.distributed_utils import get_gloo_group, init_process_group from relax.utils.logging_utils import get_logger @@ -98,7 +99,7 @@ def __init__( self.coordinator_url = coordinator_url self.lock = lock self.timeout_seconds = timeout_seconds - self.device = next(model[0].parameters()).device if model else torch.cuda.current_device() + self.device = next(model[0].parameters()).device if model else device_utils.current_device() self._comm_stream: Optional[Any] = None # CUDA stream self._thread_pool = ThreadPoolExecutor(max_workers=4) @@ -114,12 +115,13 @@ def __init__( # Ray actors for rollout communication self.rollout_engines: Dict[int, Any] = {} # rank -> Ray actor handle - torch.cuda.set_device(self.device) + device_utils.set_device(self.device) # Bridge-based HF weight converter (lazy-initialized on first use) self._use_bridge = getattr(args, "megatron_to_hf_mode", None) == "bridge" self._bridge_task_map: Optional[Dict[str, Any]] = None # global_param_name -> WeightConversionTask self._bridge_mapping_registry = None # MegatronMappingRegistry for dynamic lookups + self._bridge_expert_transposes_down: bool = True # set in _init_bridge_tasks def _init_bridge_tasks(self) -> None: """Lazily initialize Bridge conversion tasks and build a lookup table. @@ -203,6 +205,16 @@ def _init_bridge_tasks(self) -> None: inner_tp._detected_type = inner_tp._detect_parallelism_type(task.megatron_module) inner_tp._mapping = inner_tp._get_or_create_mapping(inner_tp._detected_type) + # Detect whether the Bridge's ExpertMLPDownProjMapping applies a + # transpose in megatron_to_hf (Qwen3-VL does, Qwen3.5 does not). + # Used by _convert_to_hf_bridge to decide whether to undo the transpose. + self._bridge_expert_transposes_down = False + for task in self._bridge_task_map.values(): + cls = type(task.mapping) + if cls.__name__ == "ExpertMLPDownProjMapping": + self._bridge_expert_transposes_down = "megatron_to_hf" in cls.__dict__ + break + logger.info(f"Bridge task map initialized with {len(self._bridge_task_map)} local tasks") @staticmethod @@ -390,16 +402,21 @@ def _noop_gather_from_ep_ranks(self_m, megatron_weights, megatron_module, hf_par # ── Post-process expert weights ────────────────────────────────── # Bridge's ExpertMLPGateUpProjMapping and ExpertMLPDownProjMapping - # (used by Qwen3-VL MoE) apply an extra ``.transpose(-1, -2)`` in - # their ``megatron_to_hf`` methods, assuming Megatron stores expert - # weights in column-major order. However, the raw ``convert_to_hf`` - # does NOT transpose expert weights — Megatron's expert weights are - # already in the same layout as HF. We must undo Bridge's transpose - # to match the format that SGLang / ``convert_to_hf`` expects. + # apply transformations that differ by model family: + # + # **Qwen3-VL** (qwen3_vl_bridge.py): + # gate_up_proj: transpose each half then stack → [2, D_out, D_in] + # down_proj: transpose → [D_in, D_out] + # We must undo the transpose. + # + # **Qwen3.5** (qwen35_vl_bridge.py): + # gate_up_proj: cat without transpose → [2*H, D] (2-D) + # down_proj: no transpose (AutoMapping) → [H, D] (2-D) + # No un-transpose needed; just split the fused tensor. # # Additionally, Bridge outputs fused names without expert_id: - # - ``...experts.gate_up_proj`` with shape [2, D_out, D_in] - # - ``...experts.down_proj`` with shape [D_in, D_out] + # - ``...experts.gate_up_proj`` + # - ``...experts.down_proj`` # We split into per-expert format with correct names and shapes: # - ``...experts.{E}.gate_proj.weight`` [H, D] # - ``...experts.{E}.up_proj.weight`` [H, D] @@ -410,19 +427,28 @@ def _noop_gather_from_ep_ranks(self_m, megatron_weights, megatron_module, hf_par postprocessed: list[tuple[str, torch.Tensor]] = [] for hf_name, tensor in converted_named_tensors: if hf_name.endswith(".experts.gate_up_proj"): - # Bridge output: [2, D_out, D_in] (transposed by Bridge) - # Undo transpose on each slice: [D_out, D_in] -> [D_in, D_out] - gate_tensor = tensor[0].transpose(-1, -2).contiguous() - up_tensor = tensor[1].transpose(-1, -2).contiguous() base = hf_name[: -len(".gate_up_proj")] + if tensor.ndim == 3: + # Qwen3-VL style: [2, D_out, D_in] (transposed by Bridge) + # Undo transpose on each slice: [D_out, D_in] -> [D_in, D_out] + gate_tensor = tensor[0].transpose(-1, -2).contiguous() + up_tensor = tensor[1].transpose(-1, -2).contiguous() + else: + # Qwen3.5 style: [2*H, D] (cat, no transpose by Bridge) + # Split along dim 0 into two [H, D] tensors + gate_tensor, up_tensor = tensor.chunk(2, dim=0) postprocessed.append((f"{base}.{expert_id}.gate_proj.weight", gate_tensor)) postprocessed.append((f"{base}.{expert_id}.up_proj.weight", up_tensor)) elif hf_name.endswith(".experts.down_proj"): - # Bridge output: transposed — undo to match raw convert_to_hf base = hf_name[: -len(".down_proj")] - postprocessed.append( - (f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous()) - ) + if tensor.ndim == 2 and not self._bridge_expert_transposes_down: + # Qwen3.5 style: AutoMapping, no transpose — already [H, D] + postprocessed.append((f"{base}.{expert_id}.down_proj.weight", tensor)) + else: + # Qwen3-VL style: transposed — undo to match raw convert_to_hf + postprocessed.append( + (f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous()) + ) else: postprocessed.append((hf_name, tensor)) converted_named_tensors = postprocessed @@ -590,6 +616,8 @@ def init_process_group_for_rollout(self, topology_data: Optional[Dict] = None) - dist.destroy_process_group(self._model_update_groups) ray.get(futures) self._model_update_groups = None + # Wait for NCCL socket ports to be released by the OS + time.sleep(2.0) except Exception as e: logger.warning(f"Error destroying old process group: {e}") self._model_update_groups = None @@ -604,32 +632,59 @@ def init_process_group_for_rollout(self, topology_data: Optional[Dict] = None) - cumulative_offset += gpus_for_node world_size = cumulative_offset - master_port = self._find_free_port_in_range(self._MASTER_PORT_MIN, self._MASTER_PORT_MAX) - - # Prepare init payloads for each rollout node - init_payloads = {} - for rank, role_info in self.rollout_topology.items(): - init_payloads[int(rank)] = { - "master_address": master_address, - "master_port": master_port, - "rank_offset": rank_offsets[int(rank)], - "world_size": world_size, - "group_name": self._group_name, - "backend": self.backend_type, - } + max_retries = 3 + last_error = None + for attempt in range(1, max_retries + 1): + master_port = self._find_free_port_in_range(self._MASTER_PORT_MIN, self._MASTER_PORT_MAX) + + init_payloads = {} + for rank, role_info in self.rollout_topology.items(): + init_payloads[int(rank)] = { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offsets[int(rank)], + "world_size": world_size, + "group_name": self._group_name, + "backend": self.backend_type, + } - logger.info(f"Sending init_weights_update_group to {len(self.rollout_topology)} rollout nodes...") - futures = self._batch_request("/init_weights_update_group", init_payloads, get_rank=True) + logger.info( + f"Sending init_weights_update_group to {len(self.rollout_topology)} rollout nodes " + f"(attempt {attempt}/{max_retries}, port={master_port})..." + ) + futures = self._batch_request("/init_weights_update_group", init_payloads, get_rank=True) - self._model_update_groups = init_process_group( - backend=self.backend_type, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name=self._group_name, - timeout=timedelta(seconds=180), - ) - ray.get(futures) + try: + self._model_update_groups = init_process_group( + backend=self.backend_type, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name=self._group_name, + timeout=timedelta(seconds=180), + ) + ray.get(futures) + last_error = None + break + except Exception as e: + last_error = e + logger.warning( + f"Failed to init process group for rollout (attempt {attempt}/{max_retries}, " + f"port={master_port}): {e}", + exc_info=(attempt == max_retries), + ) + self._model_update_groups = None + try: + ray.get(futures, timeout=5) + except Exception: + pass + if attempt < max_retries: + time.sleep(5.0 * attempt) + + if last_error is not None: + raise RuntimeError( + f"Failed to init process group for rollout after {max_retries} attempts" + ) from last_error def init_process_groups_for_actor_fwd_ref(self, topology_data) -> None: """Initialize process groups used for actor -> actor_fwd weight sync. @@ -804,7 +859,7 @@ def update_weights_for_rollout(self, rollout_only=False, actor_fwd_only=False) - # allocator keeps large reserved blocks that are internally # fragmented, which can cause OOM when the optimizer later tries # to allocate contiguous Adam state buffers. - torch.cuda.empty_cache() + device_utils.empty_cache() def _update_weight_from_distributed( self, diff --git a/relax/distributed/ray/rollout.py b/relax/distributed/ray/rollout.py index 18f2505..dc3d99e 100644 --- a/relax/distributed/ray/rollout.py +++ b/relax/distributed/ray/rollout.py @@ -23,6 +23,7 @@ from relax.backends.sglang.sglang_engine import SGLangEngine from relax.engine.rollout.base_types import call_rollout_fn +from relax.utils import device as device_utils from relax.utils import tracking_utils from relax.utils.health_monitor import RolloutHealthMonitor from relax.utils.http_utils import SLIME_HOST_IP_ENV, _wrap_ipv6, find_available_port, get_host_info, init_http_client @@ -1425,7 +1426,8 @@ async def _scale_out_ray_native(self, request: ScaleOutRequest) -> None: per_replica_pgs = [] for i in range(request.num_replicas): num_gpus = gpus_per_engine - bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_gpus)] + accel_resource = device_utils.get_ray_accelerator_name() + bundles = [{accel_resource: 1, "CPU": 1} for _ in range(num_gpus)] pg = ray.util.placement_group(bundles, strategy="PACK") per_replica_pgs.append(pg) @@ -2065,13 +2067,14 @@ async def _sync_single_engine_weights( ) try: + dist_backend = device_utils.get_dist_backend() init_seed_ref = seed_engine.init_weights_send_group_for_remote_instance.remote( master_address=master_address, ports=ports_str, group_rank=0, world_size=2, group_name=group_name, - backend="nccl", + backend=dist_backend, ) init_new_ref = new_engine.init_weights_send_group_for_remote_instance.remote( master_address=master_address, @@ -2079,7 +2082,7 @@ async def _sync_single_engine_weights( group_rank=1, world_size=2, group_name=group_name, - backend="nccl", + backend=dist_backend, ) init_results = await asyncio.wait_for( asyncio.gather(init_seed_ref, init_new_ref), diff --git a/relax/distributed/ray/train_actor.py b/relax/distributed/ray/train_actor.py index dd225b1..2a3eaed 100644 --- a/relax/distributed/ray/train_actor.py +++ b/relax/distributed/ray/train_actor.py @@ -11,6 +11,7 @@ import relax.utils.training.eval_config from relax.distributed.ray.ray_actor import RayActor +from relax.utils import device as device_utils from relax.utils.distributed_utils import init_gloo_group from relax.utils.logging_utils import get_logger from relax.utils.memory_utils import clear_memory, print_memory @@ -20,7 +21,7 @@ def get_local_gpu_id(): - cvd = os.environ.get("CUDA_VISIBLE_DEVICES", None) + cvd = os.environ.get(device_utils.get_visible_devices_env_var(), None) if cvd is None: return ray.get_gpu_ids()[0] else: @@ -57,7 +58,7 @@ def init(self, args, role, with_ref=False, with_opd_teacher=False): torch.serialization.add_safe_globals([relax.utils.training.eval_config.EvalDatasetConfig]) local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(f"cuda:{local_rank}") + device_utils.set_device(f"{device_utils.get_device_name()}:{local_rank}") backend = args.distributed_backend @@ -70,27 +71,8 @@ def init(self, args, role, with_ref=False, with_opd_teacher=False): args.rank = dist.get_rank() args.world_size = dist.get_world_size() - try: - if torch.version.hip is not None: - logger.info("Detected ROCm/HIP environment, skipping NUMA affinity setup") - # will find the coresponding API to implement ROCm version as below - else: - import pynvml - - pynvml.nvmlInit() - - local_rank = int(os.environ["RANK"]) % args.num_gpus_per_node - - handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank) - pynvml.nvmlDeviceSetCpuAffinity(handle) - - logger.info(f"Set NUMA affinity for GPU {local_rank}") - pynvml.nvmlShutdown() - - except ImportError: - logger.info("Warning: pynvml not available, skipping NUMA affinity setup") - except Exception as e: - logger.info(f"Warning: Failed to set NUMA affinity: {e}") + numa_local_rank = int(os.environ["RANK"]) % args.num_gpus_per_node + device_utils.set_numa_affinity(numa_local_rank) def clear_memory(self): print_memory("before TrainRayActor.clear_memory") diff --git a/relax/distributed/ray/utils.py b/relax/distributed/ray/utils.py index 1918eaa..7f798fa 100644 --- a/relax/distributed/ray/utils.py +++ b/relax/distributed/ray/utils.py @@ -2,9 +2,9 @@ import os import ray -import torch from relax.distributed.ray.ray_actor import RayActor +from relax.utils import device as device_utils # Refer to @@ -31,8 +31,8 @@ def ray_noset_visible_devices(env_vars=os.environ): def get_physical_gpu_id(): - device = torch.cuda.current_device() - props = torch.cuda.get_device_properties(device) + device = device_utils.current_device() + props = device_utils.get_device_properties(device) return str(props.uuid) diff --git a/relax/utils/arguments.py b/relax/utils/arguments.py index 911787b..98ea94f 100644 --- a/relax/utils/arguments.py +++ b/relax/utils/arguments.py @@ -10,6 +10,7 @@ from relax.backends.sglang.arguments import sglang_parse_args from relax.backends.sglang.arguments import validate_args as sglang_validate_args +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.training.eval_config import EvalDatasetConfig, build_eval_dataset_configs, ensure_dataset_list @@ -62,7 +63,7 @@ def add_serve_arguments(parser): parser.add_argument( "--checkpoint-engine-backend", type=str, - default="nccl", + default=device_utils.get_dist_backend(), help=("Backend for checkpoint engine."), ) parser.add_argument( @@ -184,7 +185,7 @@ def add_cluster_arguments(parser): ), ) - reset_arg(parser, "--distributed-backend", type=str, default="nccl") + reset_arg(parser, "--distributed-backend", type=str, default=device_utils.get_dist_backend()) reset_arg(parser, "--distributed-timeout-minutes", type=int, default=30) return parser @@ -1905,6 +1906,16 @@ def add_autoscaler_arguments(parser): default=None, help="Path to the YAML config for custom function arguments.", ) + parser.add_argument( + "--normalize-bbox", + action=argparse.BooleanOptionalAction, + default=True, + help=( + "Convert model-output bbox coordinates from normalized [0, 1000] to absolute pixels. " + "Required for Qwen-VL/Qwen2-VL/Qwen3-VL (default True). " + "Set --no-normalize-bbox for Qwen2.5-VL which outputs absolute pixel coordinates." + ), + ) reset_arg(parser, "--padded-vocab-size", type=int, default=None) return parser diff --git a/relax/utils/checkpoint_write_patch.py b/relax/utils/checkpoint_write_patch.py index c573a7a..379369a 100644 --- a/relax/utils/checkpoint_write_patch.py +++ b/relax/utils/checkpoint_write_patch.py @@ -36,6 +36,7 @@ import torch +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger @@ -245,10 +246,10 @@ def _patched_write_preloaded_data_multiproc( # cause SIGSEGV. Use threaded parallel writes instead — all tensors # are already on CPU so the I/O releases the GIL and threads achieve # real parallelism without duplicating the CUDA context. - cuda_initialised = torch.cuda.is_available() and torch.cuda.is_initialized() + cuda_initialised = device_utils.is_available() and device_utils.is_initialized() if cuda_initialised: _logger.debug( - f"rank: {rank}, CUDA initialised – using threaded parallel " + f"rank: {rank}, device initialised – using threaded parallel " f"(no-fork) checkpoint write for {len(write_buckets)} buckets" ) write_results_or_exc = _write_buckets_threaded(transform_list, use_msc, write_buckets) @@ -285,7 +286,7 @@ def _patched_schedule_async_call(self, async_req): if async_req.async_fn is None: return # nothing to do - cuda_initialised = torch.cuda.is_available() and torch.cuda.is_initialized() + cuda_initialised = device_utils.is_available() and device_utils.is_initialized() if not cuda_initialised: # CUDA not initialised — safe to use the original fork path. return _original_schedule(self, async_req) @@ -301,7 +302,7 @@ def _patched_schedule_async_call(self, async_req): rank = torch.distributed.get_rank() start_sync = time() - torch.cuda.synchronize() + device_utils.synchronize() end_sync = time() _logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ") diff --git a/relax/utils/data/stream_dataloader.py b/relax/utils/data/stream_dataloader.py index 9f62d36..80be728 100644 --- a/relax/utils/data/stream_dataloader.py +++ b/relax/utils/data/stream_dataloader.py @@ -13,6 +13,8 @@ from transfer_queue.dataloader.streaming_dataloader import StreamingDataLoader from transfer_queue.dataloader.streaming_dataset import StreamingDataset +from relax.utils import device as device_utils + logger = logging.getLogger(__name__) @@ -310,9 +312,9 @@ def get_data_from_transfer_queue( # will receive the real data via broadcast. rollout_data = [None, None] - # Use an explicit CUDA device so the communication backend (e.g. NCCL) - # can bind to a known CUDA context. - cuda_dev = torch.device(f"cuda:{torch.cuda.current_device()}") + # Use an explicit device so the communication backend (e.g. NCCL) + # can bind to a known device context. + cuda_dev = device_utils.make_current_torch_device() # --- Extract rollout_routed_experts BEFORE broadcast_object_list --- # broadcast_object_list uses pickle for the entire payload. When @@ -428,7 +430,7 @@ def post_process_rollout_data(args, rollout_data): # code in this module expects lists of sequence tensors for packing) from relax.backends.megatron.cp_utils import slice_log_prob_with_cp - cuda_dev = torch.device(f"cuda:{torch.cuda.current_device()}") + cuda_dev = device_utils.make_current_torch_device() rollout_data["tokens"] = [torch.tensor(t, dtype=torch.long, device=cuda_dev) for t in rollout_data["tokens"]] rollout_data["loss_masks"] = [ torch.tensor(t, dtype=torch.int, device=cuda_dev) for t in rollout_data["loss_masks"] diff --git a/relax/utils/device.py b/relax/utils/device.py new file mode 100644 index 0000000..b53c349 --- /dev/null +++ b/relax/utils/device.py @@ -0,0 +1,446 @@ +# Copyright (c) 2026 Relax Authors. All Rights Reserved. +# +# Multi-hardware backend abstraction layer. +# +# Inspired by verl (https://github.com/verl-project/verl) device.py +# and slime (https://github.com/THUDM/slime) plugin architecture. +# +# This module provides a unified device abstraction that allows Relax to run +# on multiple hardware backends (NVIDIA CUDA, Ascend NPU, AMD ROCm, Kunlunxin XPU, +# PPU, etc.) with minimal code changes throughout the framework. +# +# Usage: +# from relax.utils.device import get_device_name, get_torch_device, ... +# +# The module auto-detects the available accelerator at import time and exposes +# a consistent API regardless of the underlying hardware. + +import os +from enum import Enum +from functools import lru_cache +from typing import Optional + +import torch + +from relax.utils.logging_utils import get_logger + + +logger = get_logger(__name__) + + +# --------------------------------------------------------------------------- +# Accelerator type enum +# --------------------------------------------------------------------------- +class AcceleratorType(str, Enum): + """Supported hardware accelerator types.""" + + CUDA = "cuda" # NVIDIA GPU + NPU = "npu" # Ascend NPU (Huawei) + XPU = "xpu" # Intel / Kunlunxin XPU + PPU = "ppu" # PPU (Enflame / custom) + ROCM = "rocm" # AMD ROCm (uses 'cuda' device in PyTorch but HIP backend) + CPU = "cpu" # CPU fallback + + +# --------------------------------------------------------------------------- +# Detection helpers (cached — hardware won't change at runtime) +# --------------------------------------------------------------------------- +@lru_cache(maxsize=1) +def _detect_accelerator() -> AcceleratorType: + """Detect the available hardware accelerator. + + Detection order follows specificity: NPU > XPU > PPU > CUDA/ROCm > CPU. + Environment variable ``RELAX_DEVICE_TYPE`` can override auto-detection. + """ + # Allow explicit override via environment variable + override = os.environ.get("RELAX_DEVICE_TYPE", "").lower().strip() + if override: + for accel in AcceleratorType: + if override == accel.value: + logger.info(f"Device type overridden by RELAX_DEVICE_TYPE={override}") + return accel + logger.warning(f"Unknown RELAX_DEVICE_TYPE='{override}', falling back to auto-detection") + + # Ascend NPU + if _is_npu_available(): + return AcceleratorType.NPU + + # Kunlunxin / Intel XPU + if _is_xpu_available(): + return AcceleratorType.XPU + + # PPU (Enflame) + if _is_ppu_available(): + return AcceleratorType.PPU + + # NVIDIA CUDA or AMD ROCm (both expose torch.cuda) + if torch.cuda.is_available(): + if _is_rocm(): + return AcceleratorType.ROCM + return AcceleratorType.CUDA + + return AcceleratorType.CPU + + +def _is_npu_available() -> bool: + """Check if Ascend NPU is available.""" + try: + if not hasattr(torch, "npu"): + return False + return torch.npu.is_available() + except (ImportError, AttributeError): + return False + + +def _is_xpu_available() -> bool: + """Check if XPU (Intel / Kunlunxin) is available.""" + try: + if not hasattr(torch, "xpu"): + return False + return torch.xpu.is_available() + except (ImportError, AttributeError): + return False + + +def _is_ppu_available() -> bool: + """Check if PPU is available.""" + try: + if not hasattr(torch, "ppu"): + return False + return torch.ppu.is_available() + except (ImportError, AttributeError): + return False + + +def _is_rocm() -> bool: + """Check if the current CUDA build is actually AMD ROCm/HIP.""" + return getattr(torch.version, "hip", None) is not None + + +# --------------------------------------------------------------------------- +# Public API — device info +# --------------------------------------------------------------------------- +def get_accelerator_type() -> AcceleratorType: + """Return the detected :class:`AcceleratorType`.""" + return _detect_accelerator() + + +def get_device_name() -> str: + """Return the PyTorch device type string (``'cuda'``, ``'npu'``, ``'xpu'``, + etc.). + + For ROCm, returns ``'cuda'`` because PyTorch ROCm uses the CUDA device + namespace. + """ + accel = _detect_accelerator() + if accel == AcceleratorType.ROCM: + return "cuda" # ROCm uses torch.cuda namespace + if accel == AcceleratorType.CPU: + return "cpu" + return accel.value + + +def get_torch_device_module(): + """Return the ``torch.`` module (e.g. ``torch.cuda``, + ``torch.npu``). + + This is the namespace that provides ``current_device()``, + ``synchronize()``, ``empty_cache()``, etc. + """ + name = get_device_name() + try: + return getattr(torch, name) + except AttributeError: + logger.warning(f"torch.{name} not found, falling back to torch.cuda") + return torch.cuda + + +# --------------------------------------------------------------------------- +# Public API — distributed backend +# --------------------------------------------------------------------------- + +# Mapping from accelerator type to the default collective communication backend +_DIST_BACKEND_MAP = { + AcceleratorType.CUDA: "nccl", + AcceleratorType.ROCM: "nccl", # ROCm uses RCCL which is NCCL-compatible + AcceleratorType.NPU: "hccl", + AcceleratorType.XPU: "xccl", + AcceleratorType.PPU: "eccl", + AcceleratorType.CPU: "gloo", +} + + +def get_dist_backend() -> str: + """Return the default distributed communication backend name. + + Returns ``'nccl'`` for NVIDIA/AMD, ``'hccl'`` for Ascend NPU, etc. + """ + return _DIST_BACKEND_MAP.get(_detect_accelerator(), "nccl") + + +# --------------------------------------------------------------------------- +# Public API — environment variables +# --------------------------------------------------------------------------- + +# Mapping from accelerator type to the visible-devices environment variable +_VISIBLE_DEVICES_ENV_MAP = { + AcceleratorType.CUDA: "CUDA_VISIBLE_DEVICES", + AcceleratorType.ROCM: "CUDA_VISIBLE_DEVICES", # ROCm also uses this (or HIP_VISIBLE_DEVICES) + AcceleratorType.NPU: "ASCEND_RT_VISIBLE_DEVICES", + AcceleratorType.XPU: "XPU_VISIBLE_DEVICES", + AcceleratorType.PPU: "PPU_VISIBLE_DEVICES", + AcceleratorType.CPU: "", +} + + +def get_visible_devices_env_var() -> str: + """Return the environment variable name for controlling visible devices. + + E.g. ``'CUDA_VISIBLE_DEVICES'`` for NVIDIA, ``'ASCEND_RT_VISIBLE_DEVICES'`` + for Ascend NPU. + """ + return _VISIBLE_DEVICES_ENV_MAP.get(_detect_accelerator(), "CUDA_VISIBLE_DEVICES") + + +def get_visible_devices() -> Optional[str]: + """Return the value of the visible-devices environment variable, or + None.""" + env_var = get_visible_devices_env_var() + if not env_var: + return None + return os.environ.get(env_var) + + +# --------------------------------------------------------------------------- +# Public API — Ray resource name +# --------------------------------------------------------------------------- + +_RAY_RESOURCE_MAP = { + AcceleratorType.CUDA: "GPU", + AcceleratorType.ROCM: "GPU", + AcceleratorType.NPU: "NPU", + AcceleratorType.XPU: "XPU", + AcceleratorType.PPU: "PPU", + AcceleratorType.CPU: "CPU", +} + + +def get_ray_accelerator_name() -> str: + """Return the Ray resource name for the current accelerator. + + E.g. ``'GPU'`` for NVIDIA/AMD, ``'NPU'`` for Ascend. + """ + return _RAY_RESOURCE_MAP.get(_detect_accelerator(), "GPU") + + +# --------------------------------------------------------------------------- +# Public API — device operations (thin wrappers) +# --------------------------------------------------------------------------- +def current_device() -> int: + """Return the index of the current device.""" + mod = get_torch_device_module() + return mod.current_device() + + +def set_device(device) -> None: + """Set the current device. + + Args: + device: Device index (int) or device string (e.g. ``'cuda:0'``). + """ + mod = get_torch_device_module() + mod.set_device(device) + + +def device_count() -> int: + """Return the number of available accelerator devices.""" + mod = get_torch_device_module() + return mod.device_count() + + +def synchronize(device=None) -> None: + """Synchronize the current (or specified) device.""" + accel = _detect_accelerator() + if accel == AcceleratorType.CPU: + return # no-op for CPU + mod = get_torch_device_module() + if device is not None: + mod.synchronize(device) + else: + mod.synchronize() + + +def empty_cache() -> None: + """Release all unoccupied cached memory.""" + accel = _detect_accelerator() + if accel == AcceleratorType.CPU: + return + mod = get_torch_device_module() + mod.empty_cache() + + +def memory_allocated(device=None) -> int: + """Return the current GPU memory occupied by tensors in bytes.""" + mod = get_torch_device_module() + if device is not None: + return mod.memory_allocated(device) + return mod.memory_allocated() + + +def memory_reserved(device=None) -> int: + """Return the current GPU memory managed by the caching allocator in + bytes.""" + mod = get_torch_device_module() + if device is not None: + return mod.memory_reserved(device) + return mod.memory_reserved() + + +def mem_get_info(device=None): + """Return ``(free, total)`` memory in bytes for the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.mem_get_info(device) + return mod.mem_get_info() + + +def get_device_properties(device=None): + """Return device properties for the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.get_device_properties(device) + return mod.get_device_properties(mod.current_device()) + + +def current_stream(device=None): + """Return the currently selected stream for the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.current_stream(device) + return mod.current_stream() + + +def Stream(device=None, **kwargs): + """Create a new stream on the given device.""" + mod = get_torch_device_module() + if device is not None: + return mod.Stream(device=device, **kwargs) + return mod.Stream(**kwargs) + + +def Event(**kwargs): + """Create a new event.""" + mod = get_torch_device_module() + return mod.Event(**kwargs) + + +def stream_context(stream): + """Return a context manager that sets the given stream as the current + stream. + + Equivalent to ``torch.cuda.stream(s)`` but dispatches to the correct device + backend (e.g. ``torch.npu.stream(s)`` on Ascend NPU). + """ + mod = get_torch_device_module() + return mod.stream(stream) + + +def is_initialized() -> bool: + """Return True if the device backend has been initialized. + + Equivalent to ``torch.cuda.is_initialized()`` but dispatches to the correct + device backend. + """ + mod = get_torch_device_module() + if hasattr(mod, "is_initialized"): + return mod.is_initialized() + # Fallback: if the backend doesn't expose is_initialized, check if + # any device is available (conservative — assumes initialized if available). + return is_available() + + +# --------------------------------------------------------------------------- +# Public API — device string helpers +# --------------------------------------------------------------------------- +def make_device_string(index: Optional[int] = None) -> str: + """Build a device string like ``'cuda:0'`` or ``'npu:2'``. + + Args: + index: Device index. If None, uses :func:`current_device`. + """ + name = get_device_name() + if name == "cpu": + return "cpu" + if index is None: + index = current_device() + return f"{name}:{index}" + + +def make_current_torch_device() -> torch.device: + """Return a ``torch.device`` for the current accelerator and device + index.""" + return torch.device(make_device_string()) + + +# --------------------------------------------------------------------------- +# Public API — NUMA affinity +# --------------------------------------------------------------------------- +def set_numa_affinity(local_rank: int) -> None: + """Set NUMA affinity for the given local rank. + + On NVIDIA GPUs, uses pynvml. On other backends, this is a no-op with a + warning. + """ + accel = _detect_accelerator() + if accel in (AcceleratorType.CUDA,): + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank) + pynvml.nvmlDeviceSetCpuAffinity(handle) + logger.info(f"Set NUMA affinity for GPU {local_rank}") + pynvml.nvmlShutdown() + except ImportError: + logger.info("pynvml not available, skipping NUMA affinity setup") + except Exception as e: + logger.info(f"Failed to set NUMA affinity: {e}") + elif accel == AcceleratorType.ROCM: + logger.info("ROCm/HIP environment detected, skipping NUMA affinity setup") + elif accel == AcceleratorType.NPU: + logger.info("Ascend NPU environment, skipping NUMA affinity setup (not yet supported)") + else: + logger.info(f"NUMA affinity not supported for {accel.value}, skipping") + + +# --------------------------------------------------------------------------- +# Public API — expandable segments (CUDA-specific, no-op on others) +# --------------------------------------------------------------------------- +def set_expandable_segments(enable: bool) -> None: + """Configure CUDA memory allocator expandable segments. + + Only effective on NVIDIA CUDA. No-op on other backends. + """ + if _detect_accelerator() == AcceleratorType.CUDA: + try: + torch.cuda.memory._set_allocator_settings(f"expandable_segments:{enable}") + except Exception as e: + logger.warning(f"Failed to set expandable_segments: {e}") + + +# --------------------------------------------------------------------------- +# Public API — availability check +# --------------------------------------------------------------------------- +def is_available() -> bool: + """Return True if any accelerator device is available (not CPU-only).""" + return _detect_accelerator() != AcceleratorType.CPU + + +# --------------------------------------------------------------------------- +# Convenience: boolean flags (for backward compatibility / quick checks) +# --------------------------------------------------------------------------- +is_cuda_available: bool = torch.cuda.is_available() +is_npu_available: bool = _is_npu_available() +is_xpu_available: bool = _is_xpu_available() +is_ppu_available: bool = _is_ppu_available() +is_rocm: bool = _is_rocm() diff --git a/relax/utils/memory_utils.py b/relax/utils/memory_utils.py index 1c73176..8611288 100644 --- a/relax/utils/memory_utils.py +++ b/relax/utils/memory_utils.py @@ -5,6 +5,7 @@ import torch import torch.distributed as dist +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger @@ -12,23 +13,23 @@ def clear_memory(clear_host_memory: bool = False): - torch.cuda.synchronize() + device_utils.synchronize() gc.collect() - torch.cuda.empty_cache() + device_utils.empty_cache() if clear_host_memory: torch._C._host_emptyCache() def available_memory(): - device = torch.cuda.current_device() - free, total = torch.cuda.mem_get_info(device) + dev = device_utils.current_device() + free, total = device_utils.mem_get_info(dev) return { - "gpu": str(device), + "device": str(dev), "total_GB": _byte_to_gb(total), "free_GB": _byte_to_gb(free), "used_GB": _byte_to_gb(total - free), - "allocated_GB": _byte_to_gb(torch.cuda.memory_allocated(device)), - "reserved_GB": _byte_to_gb(torch.cuda.memory_reserved(device)), + "allocated_GB": _byte_to_gb(device_utils.memory_allocated(dev)), + "reserved_GB": _byte_to_gb(device_utils.memory_reserved(dev)), } diff --git a/relax/utils/profile_utils.py b/relax/utils/profile_utils.py index 72ece27..e7a6a74 100644 --- a/relax/utils/profile_utils.py +++ b/relax/utils/profile_utils.py @@ -6,6 +6,7 @@ import torch +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.memory_utils import print_memory @@ -109,7 +110,16 @@ class _TorchMemoryProfiler(_BaseMemoryProfiler): def start(self): logger.info("Attach OOM dump memory history.") - torch.cuda.memory._record_memory_history( + # Memory snapshot APIs are currently CUDA-specific. + # On non-CUDA backends, log a warning and skip. + device_mod = device_utils.get_torch_device_module() + if not hasattr(device_mod, "memory"): + logger.warning( + f"Memory snapshot profiling is not supported on {device_utils.get_device_name()} backend, skipping." + ) + return + + device_mod.memory._record_memory_history( max_entries=1000000, # record stack information for the trace events # trace_alloc_record_context=True, @@ -121,15 +131,22 @@ def oom_observer(device, alloc, device_alloc, device_free): f"Observe OOM, will dump snapshot to {self._path_dump}. ({device=} {alloc=} {device_alloc=} {device_free=}; stacktrace is as follows)" ) traceback.print_stack() - torch.cuda.memory._dump_snapshot(self._path_dump) + device_mod.memory._dump_snapshot(self._path_dump) print_memory("when oom") - torch._C._cuda_attach_out_of_memory_observer(oom_observer) + if hasattr(torch._C, "_cuda_attach_out_of_memory_observer"): + torch._C._cuda_attach_out_of_memory_observer(oom_observer) def stop(self): logger.info(f"Dump memory snapshot to: {self._path_dump}") - torch.cuda.memory._dump_snapshot(self._path_dump) - torch.cuda.memory._record_memory_history(enabled=None) + device_mod = device_utils.get_torch_device_module() + if not hasattr(device_mod, "memory"): + logger.warning( + f"Memory snapshot profiling is not supported on {device_utils.get_device_name()} backend, skipping." + ) + return + device_mod.memory._dump_snapshot(self._path_dump) + device_mod.memory._record_memory_history(enabled=None) class _MemrayMemoryProfiler(_BaseMemoryProfiler): diff --git a/relax/utils/reloadable_process_group.py b/relax/utils/reloadable_process_group.py index 6016ea6..220cb61 100644 --- a/relax/utils/reloadable_process_group.py +++ b/relax/utils/reloadable_process_group.py @@ -1,10 +1,12 @@ import os +import time from contextlib import contextmanager from datetime import timedelta import torch import torch.distributed as dist +from relax.utils import device as device_utils from relax.utils.logging_utils import get_logger from relax.utils.memory_utils import available_memory, clear_memory, print_memory @@ -149,13 +151,15 @@ def __getattr__(self, name): return getattr(self.group, name) @staticmethod - def destroy_process_groups(): + def destroy_process_groups(post_destroy_delay: float = 2.0): pid = os.getpid() + destroyed_count = 0 for reloadable_group in ReloadableProcessGroup.GROUPS.get(pid, []): if reloadable_group.group is None: continue try: dist.destroy_process_group(reloadable_group.group) + destroyed_count += 1 except ValueError as e: logger.warning( f"Process group already invalid/destroyed; skipping cleanup. Exception: {e}", @@ -165,21 +169,52 @@ def destroy_process_groups(): del reloadable_group.group reloadable_group.group = None + if destroyed_count > 0 and post_destroy_delay > 0: + # Wait for OS to release NCCL socket ports (TCP TIME_WAIT), + # preventing "Address already in use" on subsequent reload. + logger.info( + f"Destroyed {destroyed_count} process groups, waiting {post_destroy_delay}s " + "for NCCL socket port release" + ) + time.sleep(post_destroy_delay) + @staticmethod - def reload_process_groups(timeout_minutes: int = 30): + def reload_process_groups(timeout_minutes: int = 30, max_retries: int = 3, retry_delay: float = 5.0): pid = os.getpid() reloadable_groups = ReloadableProcessGroup.GROUPS.get(pid, []) logger.info(f"Reloading {len(reloadable_groups)} process groups in pid {pid}") old_new_group = old_new_group_dict.get(pid) - for reloadable_group in reloadable_groups: + for idx, reloadable_group in enumerate(reloadable_groups): if reloadable_group.group is not None: continue - group = old_new_group( - ranks=reloadable_group.group_info["ranks"], - backend="nccl", - timeout=timedelta(minutes=timeout_minutes), - ) - reloadable_group.group = group + last_error = None + for attempt in range(1, max_retries + 1): + try: + group = old_new_group( + ranks=reloadable_group.group_info["ranks"], + backend=device_utils.get_dist_backend(), + timeout=timedelta(minutes=timeout_minutes), + ) + reloadable_group.group = group + if attempt > 1: + logger.info(f"Process group {idx} reloaded successfully on attempt {attempt}") + last_error = None + break + except Exception as e: + last_error = e + logger.warning( + f"Failed to reload process group {idx} (attempt {attempt}/{max_retries}): {e}", + exc_info=(attempt == max_retries), + ) + if attempt < max_retries: + sleep_time = retry_delay * attempt + logger.info(f"Retrying in {sleep_time}s...") + time.sleep(sleep_time) + if last_error is not None: + raise RuntimeError( + f"Failed to reload process group {idx} after {max_retries} attempts " + f"(ranks={reloadable_group.group_info['ranks']})" + ) from last_error def rank(self) -> int: return self.group.rank() @@ -293,14 +328,16 @@ def bound_device_id(self, dev): self.group.bound_device_id = dev -def destroy_process_groups(): +def destroy_process_groups(post_destroy_delay: float = 2.0): """Destroy all reloadable process groups.""" - ReloadableProcessGroup.destroy_process_groups() + ReloadableProcessGroup.destroy_process_groups(post_destroy_delay=post_destroy_delay) -def reload_process_groups(timeout_minutes: int = 30): +def reload_process_groups(timeout_minutes: int = 30, max_retries: int = 3, retry_delay: float = 5.0): """Reload all reloadable process groups.""" - ReloadableProcessGroup.reload_process_groups(timeout_minutes=timeout_minutes) + ReloadableProcessGroup.reload_process_groups( + timeout_minutes=timeout_minutes, max_retries=max_retries, retry_delay=retry_delay + ) @contextmanager diff --git a/relax/utils/training/routing_replay.py b/relax/utils/training/routing_replay.py index 096f4e7..9561031 100644 --- a/relax/utils/training/routing_replay.py +++ b/relax/utils/training/routing_replay.py @@ -2,6 +2,8 @@ import torch +from relax.utils import device as device_utils + ROUTING_REPLAY = None @@ -29,12 +31,12 @@ def record(self, top_indices): def pop_forward(self): top_indices = self.top_indices_list[self.forward_index] self.forward_index += 1 - return top_indices.to(torch.cuda.current_device()) + return top_indices.to(device_utils.make_current_torch_device()) def pop_backward(self): top_indices = self.top_indices_list[self.backward_index] self.backward_index += 1 - return top_indices.to(torch.cuda.current_device()) + return top_indices.to(device_utils.make_current_torch_device()) def clear(self): self.forward_index = 0 diff --git a/relax/utils/training/tensor_backper.py b/relax/utils/training/tensor_backper.py index 955975a..e9a0145 100644 --- a/relax/utils/training/tensor_backper.py +++ b/relax/utils/training/tensor_backper.py @@ -4,6 +4,8 @@ import torch +from relax.utils import device as device_utils + _SourceGetter = Callable[[], Iterable[tuple[str, torch.Tensor]]] @@ -59,7 +61,7 @@ def backup(self, tag: str) -> None: if name not in backup_dict: backup_dict[name] = torch.empty_like(param, device=torch.device("cpu"), pin_memory=True) backup_dict[name].copy_(param.detach(), non_blocking=True) - torch.cuda.synchronize() + device_utils.synchronize() @torch.no_grad() def copy(self, *, src_tag: str, dst_tag: str): @@ -72,7 +74,7 @@ def restore(self, tag: str) -> None: for name, param in self._source_getter(): assert name in backup_dict param.copy_(backup_dict[name], non_blocking=True) - torch.cuda.synchronize() + device_utils.synchronize() class _TensorBackuperNoop(TensorBackuper): @@ -95,12 +97,12 @@ def get(self, tag: str): def backup(self, tag: str) -> None: assert tag == self._single_tag self._backup_hash_dict = _compute_hash_dict(dict(self._source_getter())) - torch.cuda.synchronize() + device_utils.synchronize() def restore(self, tag: str) -> None: assert tag == self._single_tag assert _compute_hash_dict(dict(self._source_getter())) == self._backup_hash_dict - torch.cuda.synchronize() + device_utils.synchronize() def _compute_hash_dict(tensors: dict[str, torch.Tensor]): diff --git a/relax/utils/utils.py b/relax/utils/utils.py index fb532de..d9e81ef 100644 --- a/relax/utils/utils.py +++ b/relax/utils/utils.py @@ -58,8 +58,12 @@ def convert_samples_to_train_data(args: Any, samples: list[Sample] | list[list[S train_data["loss_masks"] = loss_masks # overwriting the raw reward - if samples[0].metadata and "raw_reward" in samples[0].metadata: - train_data["raw_reward"] = [sample.metadata["raw_reward"] for sample in samples] + # populate this field for a subset of samples (e.g. SWE but not code). + if any(sample.metadata and "raw_reward" in sample.metadata for sample in samples): + train_data["raw_reward"] = [ + sample.metadata["raw_reward"] if sample.metadata and "raw_reward" in sample.metadata else sample.reward + for sample in samples + ] # For rollout buffer if samples[0].metadata and "round_number" in samples[0].metadata: @@ -75,7 +79,7 @@ def convert_samples_to_train_data(args: Any, samples: list[Sample] | list[list[S if samples[0].train_metadata is not None: train_data["metadata"] = [sample.train_metadata for sample in samples] - if samples[0].multimodal_train_inputs is not None: + if any(sample.multimodal_train_inputs is not None for sample in samples): train_data["multimodal_train_inputs"] = [sample.multimodal_train_inputs for sample in samples] if samples[0].teacher_log_probs is not None: diff --git a/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py b/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py index ffa8f41..4a583a5 100644 --- a/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py +++ b/tests/distributed/checkpoint_service/test_dcs_weight_conversion.py @@ -37,6 +37,12 @@ ExpertMLPDownProjMapping, ExpertMLPGateUpProjMapping, ) +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import ( # noqa: E402 + ExpertMLPDownProjMapping as Qwen35ExpertMLPDownProjMapping, +) +from megatron.bridge.models.qwen_vl.qwen35_vl_bridge import ( # noqa: E402 + ExpertMLPGateUpProjMapping as Qwen35ExpertMLPGateUpProjMapping, +) from relax.backends.megatron.misc_utils import strip_param_name_prefix # noqa: E402 from relax.backends.megatron.weight_conversion.processors import quantize_params, remove_padding # noqa: E402 @@ -77,7 +83,14 @@ def _noop_gather(self_m, megatron_weights, megatron_module, hf_param_name): return {str(hf_param_name): megatron_weights} saved_originals: dict = {} - patched_classes = [MegatronParamMapping, GatedMLPMapping, ExpertMLPGateUpProjMapping, ExpertMLPDownProjMapping] + patched_classes = [ + MegatronParamMapping, + GatedMLPMapping, + ExpertMLPGateUpProjMapping, + ExpertMLPDownProjMapping, + Qwen35ExpertMLPGateUpProjMapping, + Qwen35ExpertMLPDownProjMapping, + ] for cls in patched_classes: if "gather_from_ep_ranks" in cls.__dict__: saved_originals[cls] = cls.__dict__["gather_from_ep_ranks"] @@ -113,15 +126,35 @@ def _make_expert_down_mapping(layer_idx: int, expert_id: int) -> ExpertMLPDownPr return m +def _make_qwen35_expert_gate_up_mapping(layer_idx: int, expert_id: int) -> Qwen35ExpertMLPGateUpProjMapping: + """Create a real Qwen3.5 ExpertMLPGateUpProjMapping for testing.""" + return Qwen35ExpertMLPGateUpProjMapping( + megatron_param=f"language_model.decoder.layers.{layer_idx}.mlp.experts.linear_fc1.weight{expert_id}", + hf_param=f"model.language_model.layers.{layer_idx}.mlp.experts.gate_up_proj", + ) + + +def _make_qwen35_expert_down_mapping(layer_idx: int, expert_id: int) -> Qwen35ExpertMLPDownProjMapping: + """Create a real Qwen3.5 ExpertMLPDownProjMapping with eagerly initialized + inner mapping.""" + m = Qwen35ExpertMLPDownProjMapping( + megatron_param=f"language_model.decoder.layers.{layer_idx}.mlp.experts.linear_fc2.weight{expert_id}", + hf_param=f"model.language_model.layers.{layer_idx}.mlp.experts.down_proj", + ) + m._detected_type = "replicated" + m._mapping = m._get_or_create_mapping("replicated") + return m + + def _apply_expert_postprocessing( converted_dict: Dict[str, torch.Tensor], megatron_param_name: str, + bridge_expert_transposes_down: bool = True, ) -> List[Tuple[str, torch.Tensor]]: """Apply the same expert weight post-processing as ``_convert_to_hf_bridge``. - This calls the real production logic extracted from device_direct.py lines - 399-420. + Mirrors the production logic in device_direct.py. """ converted_named_tensors = list(converted_dict.items()) expert_id_match = re.search(r"weight(\d+)", megatron_param_name) @@ -130,14 +163,22 @@ def _apply_expert_postprocessing( postprocessed: list[tuple[str, torch.Tensor]] = [] for hf_name, tensor in converted_named_tensors: if hf_name.endswith(".experts.gate_up_proj"): - gate_tensor = tensor[0].transpose(-1, -2).contiguous() - up_tensor = tensor[1].transpose(-1, -2).contiguous() base = hf_name[: -len(".gate_up_proj")] + if tensor.ndim == 3: + gate_tensor = tensor[0].transpose(-1, -2).contiguous() + up_tensor = tensor[1].transpose(-1, -2).contiguous() + else: + gate_tensor, up_tensor = tensor.chunk(2, dim=0) postprocessed.append((f"{base}.{expert_id}.gate_proj.weight", gate_tensor)) postprocessed.append((f"{base}.{expert_id}.up_proj.weight", up_tensor)) elif hf_name.endswith(".experts.down_proj"): base = hf_name[: -len(".down_proj")] - postprocessed.append((f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous())) + if tensor.ndim == 2 and not bridge_expert_transposes_down: + postprocessed.append((f"{base}.{expert_id}.down_proj.weight", tensor)) + else: + postprocessed.append( + (f"{base}.{expert_id}.down_proj.weight", tensor.transpose(-1, -2).contiguous()) + ) else: postprocessed.append((hf_name, tensor)) converted_named_tensors = postprocessed @@ -657,3 +698,88 @@ def test_element_count_preserved(self): postprocessed = _apply_expert_postprocessing(bridge_output, "decoder.layers.0.mlp.experts.linear_fc1.weight0") total_numel = sum(t.numel() for _, t in postprocessed) assert total_numel == original_numel + + +# ─── Tests for Qwen3.5 Bridge (2D cat, no transpose) ───────────────────────── + + +class TestQwen35BridgeMappingOutput: + """Test Qwen3.5 Bridge mapping output format (2D cat, no transpose).""" + + def test_qwen35_gate_up_outputs_2d_cat(self): + """Qwen3.5 ExpertMLPGateUpProjMapping outputs 2D [2*H, D] via cat.""" + with _patch_gather_from_ep_ranks(): + m = _make_qwen35_expert_gate_up_mapping(layer_idx=0, expert_id=3) + H, D = 768, 2048 + fused = torch.randn(H * 2, D) + result = m.megatron_to_hf(fused, None) + + key = "model.language_model.layers.0.mlp.experts.gate_up_proj" + assert list(result.keys()) == [key] + tensor = result[key] + assert tensor.ndim == 2 + assert tensor.shape == (H * 2, D) + + def test_qwen35_down_proj_no_transpose(self): + """Qwen3.5 ExpertMLPDownProjMapping does not transpose.""" + with _patch_gather_from_ep_ranks(): + m = _make_qwen35_expert_down_mapping(layer_idx=0, expert_id=3) + D, H = 2048, 768 + param = torch.randn(D, H) + result = m.megatron_to_hf(param, None) + + key = "model.language_model.layers.0.mlp.experts.down_proj" + assert list(result.keys()) == [key] + tensor = result[key] + assert tensor.shape == (D, H) + assert torch.allclose(tensor, param) + + def test_qwen35_expert_transposes_down_detection(self): + """Qwen3.5 ExpertMLPDownProjMapping lacks megatron_to_hf override.""" + assert "megatron_to_hf" not in Qwen35ExpertMLPDownProjMapping.__dict__ + assert "megatron_to_hf" in ExpertMLPDownProjMapping.__dict__ + + +class TestQwen35PostProcessingCorrectness: + """Verify Qwen3.5 Bridge output + post-processing produces correct HF + weights.""" + + def test_qwen35_gate_up_postprocessed(self): + """Qwen3.5 gate_up 2D + post-processing produces correct gate/up.""" + H, D = 768, 2048 + expert_id = 3 + megatron_param = torch.randn(H * 2, D) + expected_gate, expected_up = megatron_param.chunk(2, dim=0) + + with _patch_gather_from_ep_ranks(): + mapping = _make_qwen35_expert_gate_up_mapping(layer_idx=0, expert_id=expert_id) + bridge_output = mapping.megatron_to_hf(megatron_param, None) + + megatron_name = f"language_model.decoder.layers.0.mlp.experts.linear_fc1.weight{expert_id}" + postprocessed = _apply_expert_postprocessing(bridge_output, megatron_name, bridge_expert_transposes_down=False) + + assert len(postprocessed) == 2 + assert postprocessed[0][0].endswith(f".experts.{expert_id}.gate_proj.weight") + assert postprocessed[1][0].endswith(f".experts.{expert_id}.up_proj.weight") + assert postprocessed[0][1].shape == (H, D) + assert postprocessed[1][1].shape == (H, D) + assert torch.allclose(postprocessed[0][1], expected_gate) + assert torch.allclose(postprocessed[1][1], expected_up) + + def test_qwen35_down_proj_postprocessed(self): + """Qwen3.5 down_proj passthrough (no transpose undo).""" + D, H = 2048, 768 + expert_id = 5 + megatron_param = torch.randn(D, H) + + with _patch_gather_from_ep_ranks(): + mapping = _make_qwen35_expert_down_mapping(layer_idx=0, expert_id=expert_id) + bridge_output = mapping.megatron_to_hf(megatron_param, None) + + megatron_name = f"language_model.decoder.layers.0.mlp.experts.linear_fc2.weight{expert_id}" + postprocessed = _apply_expert_postprocessing(bridge_output, megatron_name, bridge_expert_transposes_down=False) + + assert len(postprocessed) == 1 + assert postprocessed[0][0].endswith(f".experts.{expert_id}.down_proj.weight") + assert postprocessed[0][1].shape == (D, H) + assert torch.allclose(postprocessed[0][1], megatron_param)