diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py new file mode 100644 index 00000000..be20e9f1 --- /dev/null +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -0,0 +1,414 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +WMC-ERC: World Model-Conditioned Entropy Regularized Co-evolution. + +Dynamic entropy clipping for multi-turn agentic RL. Uses the LLM's own +prediction entropy at environment token positions as a World Model uncertainty +signal (H_WM) to dynamically gate policy gradient updates, preventing +entropy collapse in well-understood regions while permitting exploration +in uncertain ones. + +Core idea: +- S_* measures per-turn "blind confidence" of the policy (entropy collapse momentum) +- H_WM measures per-turn World Model uncertainty (prediction entropy at env tokens) +- Dynamic mask m_t: when WM is confident but policy is overconfident → block update + when WM is uncertain → allow exploration regardless of confidence +- Entropy floor: when per-turn action entropy falls below a threshold, force mask + to prevent entropy collapse even when z-score gating degrades + +Reference: World Model-Conditioned Entropy Regularized Co-evolution (WMC-ERC) +""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) + + +def compute_s_star( + old_log_probs: torch.Tensor, + entropys: torch.Tensor, + response_mask: torch.Tensor, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[torch.Tensor]]: + """Compute per-turn policy blind confidence S_*. + + S_*^t = mean over action tokens in turn t of: p_k * (H + log p_k) + + where p_k is the probability of the chosen action token and H is the + full-distribution entropy. This quantity measures how strongly the + policy's token-level probability mass is driving entropy downward: + high S_* means the policy is aggressively collapsing toward a single + action, creating momentum for entropy collapse. + + Based on first-order Taylor expansion of the entropy discriminator + (see WMC-ERC algorithm specification). + + Args: + old_log_probs: (batch_size, response_length) log probs of chosen tokens + entropys: (batch_size, response_length) entropy of policy distribution + response_mask: (batch_size, response_length) 1=action, 0=env/pad + turn_boundaries: per-sample list of (start, end) tuples for action turns + + Returns: + List of lists of scalar tensors, one S_* per turn per sample. + """ + batch_size = old_log_probs.shape[0] + device = old_log_probs.device + s_star_per_sample = [] + + for i in range(batch_size): + s_star_turns = [] + for start, end in turn_boundaries[i]: + log_p = old_log_probs[i, start:end] + H = entropys[i, start:end] + mask = response_mask[i, start:end] + count = mask.sum() + + if count > 0: + p_k = torch.exp(log_p) + s_token = p_k * (H + log_p) + s_t = (s_token * mask).sum() / count + else: + s_t = torch.tensor(0.0, device=device) + + s_star_turns.append(s_t) + s_star_per_sample.append(s_star_turns) + + return s_star_per_sample + + +def compute_per_turn_entropy( + entropys: torch.Tensor, + response_mask: torch.Tensor, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[torch.Tensor]]: + """Compute mean action-token entropy per turn. + + Args: + entropys: (batch_size, response_length) entropy at all positions + response_mask: (batch_size, response_length) 1=action, 0=env/pad + turn_boundaries: per-sample list of (start, end) for action turns + + Returns: + List of lists of scalar tensors, one mean entropy per turn per sample. + """ + batch_size = entropys.shape[0] + device = entropys.device + ent_per_sample = [] + + for i in range(batch_size): + ent_turns = [] + for start, end in turn_boundaries[i]: + H = entropys[i, start:end] + mask = response_mask[i, start:end] + count = mask.sum() + if count > 0: + ent_t = (H * mask).sum() / count + else: + ent_t = torch.tensor(0.0, device=device) + ent_turns.append(ent_t) + ent_per_sample.append(ent_turns) + + return ent_per_sample + + +def compute_h_wm( + entropys: torch.Tensor, + response_mask: torch.Tensor, + attention_mask_response: torch.Tensor, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[torch.Tensor]]: + """Compute per-turn World Model uncertainty H_WM. + + H_WM^t = mean prediction entropy at env token positions following action turn t. + + Env tokens after turn t represent the environment's response to action a_t. + The model's entropy at these positions measures how uncertain it is about + predicting the next state — i.e., the World Model's "cognitive blind spot". + + Higher H_WM → model doesn't understand this environment transition well. + Lower H_WM → model has seen similar transitions and is confident. + + Args: + entropys: (batch_size, response_length) entropy at all positions + response_mask: (batch_size, response_length) 1=action, 0=env/pad + attention_mask_response: (batch_size, response_length) 1=real, 0=padding + turn_boundaries: per-sample list of (start, end) for action turns + + Returns: + List of lists of scalar tensors, one H_WM per turn per sample. + """ + batch_size = entropys.shape[0] + seq_len = entropys.shape[1] + device = entropys.device + env_mask = attention_mask_response * (1.0 - response_mask) + + h_wm_per_sample = [] + + for i in range(batch_size): + boundaries = turn_boundaries[i] + h_wm_turns = [] + + for t, (start, end) in enumerate(boundaries): + # Env tokens after this turn: [end, next_turn_start) or [end, seq_len) + if t + 1 < len(boundaries): + env_end = boundaries[t + 1][0] + else: + env_end = seq_len + + region_mask = env_mask[i, end:env_end] + region_entropy = entropys[i, end:env_end] + count = region_mask.sum() + + if count > 0: + h_wm_t = (region_entropy * region_mask).sum() / count + else: + h_wm_t = torch.tensor(0.0, device=device) + + h_wm_turns.append(h_wm_t) + + h_wm_per_sample.append(h_wm_turns) + + return h_wm_per_sample + + +def compute_dynamic_mask( + s_star_per_sample: List[List[torch.Tensor]], + h_wm_per_sample: List[List[torch.Tensor]], + mu_base: float = 1.0, + lambda_wm: float = 1.0, + per_turn_entropy: Optional[List[List[torch.Tensor]]] = None, + entropy_floor: float = 0.0, +) -> List[List[float]]: + """Compute per-turn dynamic entropy clipping mask. + + Two-stage gating: + + Stage 1 (z-score): m_t = 1 if |S_*^t - S_bar| <= mu * (1 + lambda * H_WM^t) * sigma + 0 otherwise + + Stage 2 (entropy floor): if per-turn action entropy < entropy_floor → m_t = 0 + This prevents the degenerate case where z-score gating fails because all + S_* values cluster near zero during entropy collapse. + + Args: + s_star_per_sample: per-sample, per-turn S_* tensors + h_wm_per_sample: per-sample, per-turn H_WM tensors + mu_base: base clipping coefficient + lambda_wm: WM uncertainty weight + per_turn_entropy: per-sample, per-turn mean action entropy (optional) + entropy_floor: minimum allowed per-turn entropy; turns below are masked + + Returns: + List of lists of floats (0.0 or 1.0), one mask per turn per sample. + """ + # Flatten all S_* for batch statistics + all_s = [] + for turns in s_star_per_sample: + for s in turns: + all_s.append(s.detach()) + + if len(all_s) == 0: + return [[] for _ in s_star_per_sample] + + all_s_tensor = torch.stack(all_s) + s_bar = all_s_tensor.mean() + + # Guard for single-element: std is 0, threshold = mu_base * (1 + lambda * h_wm) * 0 + # → everything would be masked. Use 1.0 as default sigma for single element. + if len(all_s) <= 1: + sigma = torch.tensor(1.0, device=all_s_tensor.device) + else: + sigma = all_s_tensor.std(unbiased=False) + 1e-8 + + mask_per_sample = [] + n_zscore_masked = 0 + n_entropy_floor_masked = 0 + + for i in range(len(s_star_per_sample)): + masks = [] + for t in range(len(s_star_per_sample[i])): + s_t = s_star_per_sample[i][t].detach() + h_t = h_wm_per_sample[i][t].detach() + + # Stage 1: z-score gating + threshold = mu_base * (1.0 + lambda_wm * h_t) * sigma + if torch.abs(s_t - s_bar) > threshold: + m_t = 0.0 + n_zscore_masked += 1 + # Stage 2: entropy floor gating + elif ( + per_turn_entropy is not None + and entropy_floor > 0 + and i < len(per_turn_entropy) + and t < len(per_turn_entropy[i]) + and per_turn_entropy[i][t].detach().item() < entropy_floor + ): + m_t = 0.0 + n_entropy_floor_masked += 1 + else: + m_t = 1.0 + + masks.append(m_t) + mask_per_sample.append(masks) + + return mask_per_sample, n_zscore_masked, n_entropy_floor_masked + + +def apply_wmc_erc( + batch, + entropys: torch.Tensor, + wmc_erc_config, +) -> Tuple[object, Dict[str, float]]: + """Apply WMC-ERC dynamic entropy clipping to batch advantages. + + Pipeline: + 1. Compute turn boundaries from response_mask + 2. Compute S_* (policy blind confidence) per turn + 3. Compute per-turn action entropy + 4. Compute H_WM (world model uncertainty) per turn from env token entropys + 5. Compute dynamic mask m_t per turn (z-score + entropy floor) + 6. Apply mask to advantages: A_masked = A * m_t (broadcast to tokens) + 7. Return metrics for logging + + Args: + batch: DataProto or compatible object with batch dict containing + advantages, response_mask, old_log_probs, attention_mask + entropys: (batch_size, response_length) stored before pop in train_step + wmc_erc_config: OmegaConf DictConfig or dict with mu_base, lambda_wm, enable + + Returns: + (batch, metrics) where batch has masked advantages and metrics dict + """ + enable = ( + wmc_erc_config.get("enable", True) + if hasattr(wmc_erc_config, "get") + else getattr(wmc_erc_config, "enable", True) + ) + if not enable: + return batch, {} + + response_mask = batch.batch["response_mask"] + old_log_probs = batch.batch["old_log_probs"] + advantages = batch.batch["advantages"] + + # Compute attention mask for response region + response_length = advantages.shape[1] + attention_mask = batch.batch["attention_mask"] + attention_mask_response = attention_mask[:, -response_length:] + + # 1. Turn boundaries + turn_boundaries = compute_turn_boundaries(response_mask) + + # 2. S_* per turn + s_star = compute_s_star(old_log_probs, entropys, response_mask, turn_boundaries) + + # 3. Per-turn action entropy + turn_entropy = compute_per_turn_entropy(entropys, response_mask, turn_boundaries) + + # 4. H_WM per turn + h_wm = compute_h_wm( + entropys, response_mask, attention_mask_response, turn_boundaries + ) + + # 5. Dynamic mask (z-score + entropy floor) + _get = lambda key, default: ( + wmc_erc_config.get(key, default) + if hasattr(wmc_erc_config, "get") + else getattr(wmc_erc_config, key, default) + ) + mu_base = float(_get("mu_base", 1.0)) + lambda_wm = float(_get("lambda_wm", 1.0)) + entropy_floor = float(_get("entropy_floor", 0.0)) + + mask, n_zscore_masked, n_entropy_floor_masked = compute_dynamic_mask( + s_star, + h_wm, + mu_base, + lambda_wm, + per_turn_entropy=turn_entropy, + entropy_floor=entropy_floor, + ) + + # 6. Apply mask to advantages (in-place) + batch_size = advantages.shape[0] + for i in range(batch_size): + for t, (start, end) in enumerate(turn_boundaries[i]): + if t < len(mask[i]): + advantages[i, start:end] *= mask[i][t] + batch.batch["advantages"] = advantages + + # 7. Adaptive entropy control via beta_token + # Instead of fixed entropy_coeff, use per-token beta based on turn entropy. + # - turns with entropy < target → beta = entropy_coeff (encourage exploration) + # - turns with entropy >= target → beta = 0 (don't push entropy higher) + # This prevents both entropy collapse AND entropy explosion. + entropy_target = float(_get("entropy_target", 0.0)) + base_entropy_coeff = float(_get("base_entropy_coeff", 0.0)) + + if entropy_target > 0 and base_entropy_coeff > 0: + beta_token = torch.zeros_like(advantages) + for i in range(batch_size): + for t, (start, end) in enumerate(turn_boundaries[i]): + if t < len(turn_entropy[i]): + turn_ent = turn_entropy[i][t].detach().item() + if turn_ent < entropy_target: + # Linear scaling: full coeff at entropy=0, zero at target + scale = max(0.0, 1.0 - turn_ent / entropy_target) + beta_token[i, start:end] = base_entropy_coeff * scale + # else: beta stays 0 (no entropy bonus for high-entropy turns) + batch.batch["beta_token"] = beta_token + + # 8. Metrics + all_s = [s.item() for turns in s_star for s in turns] + all_h = [h.item() for turns in h_wm for h in turns] + all_m = [m for turns in mask for m in turns] + all_te = [e.item() for turns in turn_entropy for e in turns] + + # WM NLL (monitoring only — not in backward pass for this prototype) + env_mask = attention_mask_response * (1.0 - response_mask) + env_count = env_mask.sum() + wm_nll = ( + (-(old_log_probs * env_mask).sum() / (env_count + 1e-8)).item() + if env_count > 0 + else 0.0 + ) + + # Beta token stats for logging + beta_mean = 0.0 + if entropy_target > 0 and base_entropy_coeff > 0: + active_beta = beta_token[response_mask.bool()] + beta_mean = active_beta.mean().item() if active_beta.numel() > 0 else 0.0 + + metrics = { + "wmc_erc/s_star_mean": float(np.mean(all_s)) if all_s else 0.0, + "wmc_erc/s_star_std": float(np.std(all_s)) if all_s else 0.0, + "wmc_erc/h_wm_mean": float(np.mean(all_h)) if all_h else 0.0, + "wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0, + "wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0), + "wmc_erc/total_turns": len(all_m), + "wmc_erc/wm_nll": wm_nll, + "wmc_erc/turn_entropy_mean": float(np.mean(all_te)) if all_te else 0.0, + "wmc_erc/n_zscore_masked": n_zscore_masked, + "wmc_erc/n_entropy_floor_masked": n_entropy_floor_masked, + "wmc_erc/adaptive_beta_mean": beta_mean, + } + + return batch, metrics diff --git a/opentinker/client/client_config/alfworld_param.yaml b/opentinker/client/client_config/alfworld_param.yaml index 0f183186..1fdad611 100644 --- a/opentinker/client/client_config/alfworld_param.yaml +++ b/opentinker/client/client_config/alfworld_param.yaml @@ -45,10 +45,10 @@ algorithm: "agent_loop" # - "grpo" : Standard GRPO (outcome-only advantage) # - "grpo_per_step" : Per-step GRPO with return-based advantages (for multi-turn tasks) # - "gae" : Generalized Advantage Estimation (for PPO, requires critic) -adv_estimator: "grpo" +adv_estimator: "gae" # rollout_n: number of samples per prompt for GRPO/grpo_per_step # For PPO (gae), rollout_n is typically 1 -rollout_n: 8 +rollout_n: 1 # Interaction configuration interaction: @@ -60,7 +60,7 @@ interaction: env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} # If you run the ALFWorld env server in sharded mode (--shards N), # set env_shards=N. The client will route each instance_id to a stable shard. - env_shards: 32 + env_shards: 8 max_steps: 20 # ALFWorld episodes max steps max_total_steps: 20 # Max environment step calls (controls rollout turns) observation_template: "{observation}" diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml new file mode 100644 index 00000000..50b7fef5 --- /dev/null +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -0,0 +1,95 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# PPO with GAE advantage estimation (requires critic) +adv_estimator: "gae" +rollout_n: 1 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: base clipping coefficient (controls tightness of the gate) +# - lambda_wm: how much WM uncertainty widens the gate (higher = more tolerant in unknown regions) +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + lambda_wm: 10.0 + entropy_floor: 0.1 + # Adaptive entropy control via beta_token (replaces fixed entropy_coeff) + # beta = base_entropy_coeff * max(0, 1 - turn_entropy/target) per turn + # -> full bonus at entropy=0, zero bonus at entropy=target, never pushes above target + entropy_target: 2.0 + base_entropy_coeff: 0.02 + +# Disable fixed entropy_coeff since adaptive beta_token handles it +entropy_coeff: 0.0 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 8 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 8 diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 08f65e9b..9e7242be 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -644,13 +644,38 @@ def set_config(self, args: DictConfig, env=None): server_cfg = OmegaConf.merge( server_cfg, OmegaConf.create( - {"actor_rollout_ref": {"rollout": {"agent": {"num_workers": agent_num_workers}}}} + { + "actor_rollout_ref": { + "rollout": {"agent": {"num_workers": agent_num_workers}} + } + } ), ) print( f"[ServiceClient] Overriding agent num_workers to: {agent_num_workers}" ) + # Pass WMC-ERC config to server if present + wmc_erc_cfg = getattr(args, "wmc_erc", None) + if wmc_erc_cfg is not None: + wmc_erc_dict = OmegaConf.to_container(wmc_erc_cfg, resolve=True) + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create({"wmc_erc": wmc_erc_dict}), + ) + print(f"[ServiceClient] Passing WMC-ERC config to server: {wmc_erc_dict}") + + # Pass entropy_coeff to server actor config if present + entropy_coeff = getattr(args, "entropy_coeff", None) + if entropy_coeff is not None: + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create( + {"actor_rollout_ref": {"actor": {"entropy_coeff": entropy_coeff}}} + ), + ) + print(f"[ServiceClient] Setting entropy_coeff: {entropy_coeff}") + generation_config = { "temperature": args.temperature, "top_p": args.top_p, diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 2a1824d4..b162cf34 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -284,7 +284,7 @@ def check_gpu_available(gpu_id: int) -> bool: return True # Fail open # Thresholds for considering a GPU "idle" - MAX_MEMORY_MB = 10 # Allow up to 100 MB (some baseline CUDA overhead) + MAX_MEMORY_MB = 1000 # Allow up to 1000 MB (some overhead from Ray/CUDA init) MAX_UTILIZATION = 1000 # Allow up to 5% utilization if memory_used_mb > MAX_MEMORY_MB or utilization_percent > MAX_UTILIZATION: @@ -294,35 +294,9 @@ def check_gpu_available(gpu_id: int) -> bool: ) return False - # Check 2: Look for running processes on this GPU - pmon_result = subprocess.run( - ["nvidia-smi", "pmon", "-c", "1", "-s", "um"], - capture_output=True, - text=True, - timeout=5, - ) - - if pmon_result.returncode == 0: - # Parse pmon output to check for processes on this GPU - # Format: "# gpu pid type sm mem enc dec command" - # " 0 12345 C 50 500 0 0 python" - lines = pmon_result.stdout.strip().split("\n") - for line in lines: - if line.startswith("#") or not line.strip(): - continue - parts = line.split() - if len(parts) >= 2: - try: - gpu_idx = int(parts[0].strip()) - if gpu_idx == gpu_id and parts[1].strip() != "-": - # Found a process on this GPU - pid = parts[1].strip() - logger.warning( - f"GPU {gpu_id}: ⚠️ OCCUPIED - Process {pid} detected via pmon" - ) - return False - except (ValueError, IndexError): - continue + # Check 2: pmon process check - SKIPPED to allow GPU sharing for small models + # When using small models (e.g. 0.5B), GPU sharing is safe as long as + # total memory fits. The memory threshold above handles this. # All checks passed - GPU is idle logger.debug( @@ -1085,12 +1059,18 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: if kl_config: use_kl_in_reward = kl_config.get("use_kl_in_reward") if use_kl_in_reward is not None: - cmd.append(f"algorithm.use_kl_in_reward={str(use_kl_in_reward).lower()}") - logger.info(f"Job {job.job_id}: ✓ KL use_kl_in_reward={use_kl_in_reward}") + cmd.append( + f"algorithm.use_kl_in_reward={str(use_kl_in_reward).lower()}" + ) + logger.info( + f"Job {job.job_id}: ✓ KL use_kl_in_reward={use_kl_in_reward}" + ) use_kl_loss = kl_config.get("use_kl_loss") if use_kl_loss is not None: - cmd.append(f"actor_rollout_ref.actor.use_kl_loss={str(use_kl_loss).lower()}") + cmd.append( + f"actor_rollout_ref.actor.use_kl_loss={str(use_kl_loss).lower()}" + ) logger.info(f"Job {job.job_id}: ✓ KL use_kl_loss={use_kl_loss}") kl_loss_coef = kl_config.get("kl_loss_coef") diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 8d67b106..13c1e850 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -27,6 +27,7 @@ import asyncio import base64 +import gc import logging import signal import sys @@ -943,7 +944,9 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # episodes into individual per-turn training samples, the gen_batch_output # batch size is larger than the original batch. We need to expand the # original batch to match using the expansion index. - expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + expansion_index = gen_batch_output.meta_info.pop( + "per_turn_expansion_index", None + ) if expansion_index is not None: logger.info( f"[Per-turn training] Expanding original batch from {len(batch)} to " @@ -956,6 +959,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) # Expand non-tensor batch expanded_non_tensor = {} @@ -1080,6 +1084,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # ===== DEBUG LOGGING END ===== metrics.update(old_log_prob_metrics) + _wmc_erc_entropys = entropys # Preserve for WMC-ERC before pop old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) logger.info( @@ -1161,6 +1166,23 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: config=self.config.algorithm, ) + # 9.5 WMC-ERC: Dynamic entropy clipping + wmc_erc_cfg = OmegaConf.select(self.config, "wmc_erc", default=None) + if wmc_erc_cfg and wmc_erc_cfg.get("enable", False): + from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( + apply_wmc_erc, + ) + + batch, wmc_metrics = apply_wmc_erc( + batch, _wmc_erc_entropys, wmc_erc_cfg + ) + metrics.update(wmc_metrics) + logger.info( + f"[WMC-ERC] mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " + f"s_star={wmc_metrics.get('wmc_erc/s_star_mean', 'N/A'):.4f}, " + f"h_wm={wmc_metrics.get('wmc_erc/h_wm_mean', 'N/A'):.4f}" + ) + # 10. Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): @@ -1258,6 +1280,13 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info(f"Training step {self.global_steps} completed successfully") + # Free large intermediates and force garbage collection to prevent OOM + del batch, gen_batch, gen_batch_output, reward_tensor + if "_wmc_erc_entropys" in dir(): + del _wmc_erc_entropys + gc.collect() + torch.cuda.empty_cache() + return { "status": "success", "metrics": metrics, @@ -1535,7 +1564,9 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: # 6. Merge original batch and generated output # Per-turn training expansion: expand batch if gen output is larger - expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + expansion_index = gen_batch_output.meta_info.pop( + "per_turn_expansion_index", None + ) if expansion_index is not None: logger.info( f"[Per-turn training] Validation: Expanding original batch from {len(batch)} to " @@ -1547,6 +1578,7 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) expanded_non_tensor = {} for k, v in batch.non_tensor_batch.items(): @@ -2106,6 +2138,14 @@ def run_fastapi_server(): namespace=_server_cfg.ray.namespace, num_gpus=_server_cfg.trainer.n_gpus_per_node, # Explicitly specify number of GPUs ignore_reinit_error=True, + runtime_env={ + "env_vars": { + "NCCL_CUMEM_ENABLE": "0", + "VLLM_DISABLE_SLEEP_MODE": "1", + "RAY_memory_usage_threshold": "0.99", + "VLLM_GPU_MEMORY_UTILIZATION": "0.15", + }, + }, ) else: # Connect to existing Ray cluster at specific address @@ -2114,6 +2154,14 @@ def run_fastapi_server(): address=_server_cfg.ray.address, namespace=_server_cfg.ray.namespace, ignore_reinit_error=True, + runtime_env={ + "env_vars": { + "NCCL_CUMEM_ENABLE": "0", + "VLLM_DISABLE_SLEEP_MODE": "1", + "RAY_memory_usage_threshold": "0.99", + "VLLM_GPU_MEMORY_UTILIZATION": "0.15", + }, + }, ) # Verify GPU availability diff --git a/opentinker/server/launch_http_server.py b/opentinker/server/launch_http_server.py index 830af82c..4e05751b 100644 --- a/opentinker/server/launch_http_server.py +++ b/opentinker/server/launch_http_server.py @@ -64,7 +64,7 @@ def main(cfg): cfg.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 cfg.actor_rollout_ref.rollout.name = "vllm" - cfg.actor_rollout_ref.rollout.gpu_memory_utilization = 0.6 + cfg.actor_rollout_ref.rollout.gpu_memory_utilization = 0.15 # GRPO/GRPO-per-step 特定配置 # grpo_per_step uses the same training framework as grpo, just with different advantage estimation diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py new file mode 100644 index 00000000..3d53a335 --- /dev/null +++ b/opentinker/tests/test_wmc_erc.py @@ -0,0 +1,403 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for WMC-ERC dynamic entropy clipping module. + +Run with: pytest opentinker/tests/test_wmc_erc.py -v +""" + +import numpy as np +import pytest +import torch + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) +from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( + apply_wmc_erc, + compute_dynamic_mask, + compute_h_wm, + compute_s_star, +) + + +class TestComputeSStar: + """Test S_* (policy blind confidence) computation.""" + + def test_single_turn(self): + """S_* for a single turn = mean of p_k * (H + log p_k) over action tokens.""" + # 1 sample, 6 positions: 4 action + 2 padding + p = torch.tensor([0.8, 0.6, 0.9, 0.7]) + old_log_probs = torch.zeros(1, 6) + old_log_probs[0, :4] = torch.log(p) + entropys = torch.tensor([[1.0, 1.5, 0.5, 1.2, 0.0, 0.0]]) + response_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.float32) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result) == 1 # 1 sample + assert len(result[0]) == 1 # 1 turn + + # Manual: p_k * (H + log(p_k)) for each token, then mean + H = torch.tensor([1.0, 1.5, 0.5, 1.2]) + expected = (p * (H + torch.log(p))).mean().item() + assert abs(result[0][0].item() - expected) < 1e-5 + + def test_multi_turn(self): + """Two turns should produce two S_* values.""" + old_log_probs = torch.log( + torch.tensor([[0.8, 0.6, 0.5, 0.5, 0.9, 0.7, 0.5, 0.5]]) + ) + entropys = torch.tensor([[1.0, 1.5, 0.0, 0.0, 0.5, 1.2, 0.0, 0.0]]) + response_mask = torch.tensor( + [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 + ) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result[0]) == 2 # 2 turns + # Turn 0 and Turn 1 should have different values + assert result[0][0].item() != result[0][1].item() + + def test_batch(self): + """Batch of 2 samples.""" + old_log_probs = torch.log( + torch.tensor( + [ + [0.8, 0.6, 0.5, 0.5], + [0.5, 0.9, 0.5, 0.5], + ] + ) + ) + entropys = torch.tensor( + [ + [1.0, 1.5, 0.0, 0.0], + [2.0, 0.3, 0.0, 0.0], + ] + ) + response_mask = torch.tensor( + [ + [1, 1, 0, 0], + [1, 1, 0, 0], + ], + dtype=torch.float32, + ) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result) == 2 + + def test_empty_turns(self): + """Sample with no action tokens should produce empty list.""" + old_log_probs = torch.zeros(1, 4) + entropys = torch.zeros(1, 4) + response_mask = torch.zeros(1, 4, dtype=torch.float32) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result) == 1 + assert len(result[0]) == 0 + + +class TestComputeHWM: + """Test H_WM (world model uncertainty) computation.""" + + def test_single_turn_with_env_tokens(self): + """H_WM for turn 0 = mean entropy at env positions after turn 0.""" + # Sequence: [action, action, env, env, pad, pad] + entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.0, 0.0]]) + response_mask = torch.tensor([[1, 1, 0, 0, 0, 0]], dtype=torch.float32) + attention_mask_response = torch.tensor( + [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 + ) + boundaries = [[(0, 2)]] # 1 turn: action at [0,2) + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + assert len(result) == 1 + assert len(result[0]) == 1 + # Env tokens at [2,3] → mean entropy = (3.0 + 4.0) / 2 = 3.5 + assert abs(result[0][0].item() - 3.5) < 1e-5 + + def test_two_turns(self): + """Two turns: H_WM_0 from env between turns, H_WM_1 from env after turn 1.""" + # [act, act, env, env, act, act, env, pad] + entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.5, 1.2, 2.0, 0.0]]) + response_mask = torch.tensor( + [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 + ) + attention_mask_response = torch.tensor( + [[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.float32 + ) + boundaries = [[(0, 2), (4, 6)]] + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + assert len(result[0]) == 2 + # Turn 0 env: positions [2,4) → (3.0+4.0)/2 = 3.5 + assert abs(result[0][0].item() - 3.5) < 1e-5 + # Turn 1 env: positions [6,8) but attn_mask=[1,0] → only pos 6 → 2.0 + assert abs(result[0][1].item() - 2.0) < 1e-5 + + def test_no_env_after_last_turn(self): + """Last turn has no env tokens → H_WM = 0.""" + # [act, act, env, act, act, pad] + entropys = torch.tensor([[1.0, 1.5, 3.0, 0.5, 1.2, 0.0]]) + response_mask = torch.tensor([[1, 1, 0, 1, 1, 0]], dtype=torch.float32) + attention_mask_response = torch.tensor( + [[1, 1, 1, 1, 1, 0]], dtype=torch.float32 + ) + boundaries = [[(0, 2), (3, 5)]] + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + # Turn 0 env: positions [2,3) → 3.0 + assert abs(result[0][0].item() - 3.0) < 1e-5 + # Turn 1 env: positions [5,6) but attn_mask=[0] → H_WM = 0 + assert result[0][1].item() == 0.0 + + def test_empty_turns(self): + """No turns → empty H_WM list.""" + entropys = torch.zeros(1, 4) + response_mask = torch.zeros(1, 4, dtype=torch.float32) + attention_mask_response = torch.ones(1, 4, dtype=torch.float32) + boundaries = [[]] + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + assert len(result[0]) == 0 + + +class TestComputeDynamicMask: + """Test dynamic entropy clipping mask.""" + + def test_all_pass(self): + """When all S_* are close to mean, all masks = 1.""" + s_star = [[torch.tensor(1.0), torch.tensor(1.1)]] + h_wm = [[torch.tensor(0.5), torch.tensor(0.5)]] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=3.0, lambda_wm=1.0) + assert mask == [[1.0, 1.0]] + + def test_outlier_blocked_low_hwm(self): + """High S_* outlier with low H_WM (known env) → blocked.""" + s_star = [ + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(10.0)], # outlier + ] + h_wm = [ + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(0.0)], # known env → tight threshold + ] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) + # The outlier (10.0) should be blocked when threshold is tight + assert mask[3] == [0.0] + # Normal ones should pass + assert mask[0] == [1.0] + + def test_outlier_allowed_high_hwm(self): + """High S_* outlier with high H_WM (unknown env) → allowed.""" + s_star = [ + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(10.0)], # outlier + ] + h_wm = [ + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(100.0)], # very uncertain → wide threshold + ] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) + # The outlier should be allowed because H_WM is very high + assert mask[3] == [1.0] + + def test_empty(self): + """Empty input should return empty.""" + mask = compute_dynamic_mask([], [], mu_base=1.0, lambda_wm=1.0) + assert mask == [] + + def test_single_element(self): + """Single S_* value should not produce nan (std guard).""" + s_star = [[torch.tensor(5.0)]] + h_wm = [[torch.tensor(1.0)]] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) + # Single element: |5.0 - 5.0| = 0 <= threshold → should pass + assert mask == [[1.0]] + + +class TestApplyWmcErc: + """Test full WMC-ERC orchestration.""" + + def _make_batch( + self, advantages, response_mask, old_log_probs, attention_mask + ): + """Create a minimal mock batch with required fields.""" + from unittest.mock import MagicMock + + batch = MagicMock() + batch.batch = { + "advantages": advantages.clone(), + "response_mask": response_mask, + "old_log_probs": old_log_probs, + "attention_mask": attention_mask, + } + return batch + + def test_masking_zeros_advantage(self): + """When a turn is masked, its advantages become zero.""" + # 4 samples, 1 turn each (4 action tokens + 0 env tokens) + # Sample 3 has extremely different S_* pattern + response_mask = torch.ones(4, 4, dtype=torch.float32) + advantages = torch.ones(4, 4) * 2.0 + # Make sample 3 very overconfident (high p_k, low H) + old_log_probs = torch.tensor( + [ + [np.log(0.3)] * 4, + [np.log(0.3)] * 4, + [np.log(0.3)] * 4, + [np.log(0.99)] * 4, # very high confidence + ] + ) + entropys = torch.tensor( + [ + [2.0, 2.0, 2.0, 2.0], + [2.0, 2.0, 2.0, 2.0], + [2.0, 2.0, 2.0, 2.0], + [0.01, 0.01, 0.01, 0.01], # very low entropy + ] + ) + attention_mask = torch.ones(4, 4, dtype=torch.float32) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} + + _, metrics = apply_wmc_erc(batch, entropys, config) + + # Check metrics exist + assert "wmc_erc/mask_ratio" in metrics + assert "wmc_erc/s_star_mean" in metrics + assert "wmc_erc/h_wm_mean" in metrics + assert "wmc_erc/total_turns" in metrics + assert metrics["wmc_erc/total_turns"] == 4 + + # Verify masking behavior: + # Samples 0-2: S_* = 0.3*(2.0+log(0.3)) ≈ 0.239 (normal) + # Sample 3: S_* = 0.99*(0.01+log(0.99)) ≈ 0 (outlier in opposite direction) + # H_WM = 0 for all (no env tokens) → tight threshold + # Sample 3 should be masked (|S_3 - S_bar| > threshold) + adv = batch.batch["advantages"] + assert metrics["wmc_erc/num_masked_turns"] >= 1 + assert (adv[3] == 0).all(), "Sample 3 (overconfident outlier) should have zero advantages" + assert (adv[:3] == 2.0).all(), "Samples 0-2 (normal) should keep original advantages" + + def test_disabled(self): + """When enable=False, advantages unchanged.""" + response_mask = torch.ones(2, 4, dtype=torch.float32) + advantages = torch.ones(2, 4) * 5.0 + old_log_probs = torch.full((2, 4), np.log(0.5)) + attention_mask = torch.ones(2, 4, dtype=torch.float32) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + entropys = torch.ones(2, 4) + config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": False} + + _, metrics = apply_wmc_erc(batch, entropys, config) + assert (batch.batch["advantages"] == 5.0).all() + assert metrics == {} + + def test_multi_turn_selective_masking(self): + """Multi-turn: only overconfident turns in known env get masked.""" + # 2 samples, 2 turns each: [act, act, env, env, act, act, env, pad] + response_mask = torch.tensor( + [ + [1, 1, 0, 0, 1, 1, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 0], + ], + dtype=torch.float32, + ) + advantages = torch.ones(2, 8) * 3.0 + # Sample 0: normal confidence + # Sample 1: overconfident on both turns + old_log_probs = torch.tensor( + [ + [np.log(0.3), np.log(0.3), 0, 0, np.log(0.3), np.log(0.3), 0, 0], + [np.log(0.99), np.log(0.99), 0, 0, np.log(0.99), np.log(0.99), 0, 0], + ] + ) + entropys = torch.tensor( + [ + [2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 0.0], + [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.0], + ] + ) + attention_mask = torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + ], + dtype=torch.float32, + ) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} + + _, metrics = apply_wmc_erc(batch, entropys, config) + assert metrics["wmc_erc/total_turns"] == 4 + # With low H_WM on sample 1, its overconfident turns should be blocked + # This is a statistical test — the exact outcome depends on batch stats + + def test_returns_wm_nll_metric(self): + """WM NLL metric should be computed from env token log probs.""" + response_mask = torch.tensor( + [[1, 1, 0, 0, 0, 0]], dtype=torch.float32 + ) + advantages = torch.ones(1, 6) + old_log_probs = torch.tensor( + [[np.log(0.5), np.log(0.5), np.log(0.3), np.log(0.4), 0, 0]] + ) + entropys = torch.tensor([[1.0, 1.0, 2.0, 3.0, 0.0, 0.0]]) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 + ) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + config = {"mu_base": 5.0, "lambda_wm": 1.0, "enable": True} + + _, metrics = apply_wmc_erc(batch, entropys, config) + assert "wmc_erc/wm_nll" in metrics + # WM NLL = -mean(log_prob at env positions [2,3]) + # = -(log(0.3) + log(0.4)) / 2 + expected_nll = -(np.log(0.3) + np.log(0.4)) / 2.0 + assert abs(metrics["wmc_erc/wm_nll"] - expected_nll) < 1e-4 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])