From 2c1445b934e10e7c4dc681d984e71eaa77320dba Mon Sep 17 00:00:00 2001 From: Haofei Yu <1125027232@qq.com> Date: Fri, 13 Mar 2026 04:21:49 +0800 Subject: [PATCH 1/3] update --- .../backend_patch/verl/trainer/ppo/wmc_erc.py | 297 +++++++++++++ .../client_config/alfworld_wmc_erc_param.yaml | 86 ++++ opentinker/server/http_training_server.py | 11 + opentinker/tests/test_wmc_erc.py | 403 ++++++++++++++++++ 4 files changed, 797 insertions(+) create mode 100644 opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py create mode 100644 opentinker/client/client_config/alfworld_wmc_erc_param.yaml create mode 100644 opentinker/tests/test_wmc_erc.py diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py new file mode 100644 index 00000000..2ea6182a --- /dev/null +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -0,0 +1,297 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +WMC-ERC: World Model-Conditioned Entropy Regularized Co-evolution. + +Dynamic entropy clipping for multi-turn agentic RL. Uses the LLM's own +prediction entropy at environment token positions as a World Model uncertainty +signal (H_WM) to dynamically gate policy gradient updates, preventing +entropy collapse in well-understood regions while permitting exploration +in uncertain ones. + +Core idea: +- S_* measures per-turn "blind confidence" of the policy (entropy collapse momentum) +- H_WM measures per-turn World Model uncertainty (prediction entropy at env tokens) +- Dynamic mask m_t: when WM is confident but policy is overconfident → block update + when WM is uncertain → allow exploration regardless of confidence + +Reference: World Model-Conditioned Entropy Regularized Co-evolution (WMC-ERC) +""" + +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) + + +def compute_s_star( + old_log_probs: torch.Tensor, + entropys: torch.Tensor, + response_mask: torch.Tensor, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[torch.Tensor]]: + """Compute per-turn policy blind confidence S_*. + + S_*^t = mean over action tokens in turn t of: p_k * (H + log p_k) + + where p_k is the probability of the chosen action token and H is the + full-distribution entropy. This quantity measures how strongly the + policy's token-level probability mass is driving entropy downward: + high S_* means the policy is aggressively collapsing toward a single + action, creating momentum for entropy collapse. + + Based on first-order Taylor expansion of the entropy discriminator + (see WMC-ERC algorithm specification). + + Args: + old_log_probs: (batch_size, response_length) log probs of chosen tokens + entropys: (batch_size, response_length) entropy of policy distribution + response_mask: (batch_size, response_length) 1=action, 0=env/pad + turn_boundaries: per-sample list of (start, end) tuples for action turns + + Returns: + List of lists of scalar tensors, one S_* per turn per sample. + """ + batch_size = old_log_probs.shape[0] + device = old_log_probs.device + s_star_per_sample = [] + + for i in range(batch_size): + s_star_turns = [] + for start, end in turn_boundaries[i]: + log_p = old_log_probs[i, start:end] + H = entropys[i, start:end] + mask = response_mask[i, start:end] + count = mask.sum() + + if count > 0: + p_k = torch.exp(log_p) + s_token = p_k * (H + log_p) + s_t = (s_token * mask).sum() / count + else: + s_t = torch.tensor(0.0, device=device) + + s_star_turns.append(s_t) + s_star_per_sample.append(s_star_turns) + + return s_star_per_sample + + +def compute_h_wm( + entropys: torch.Tensor, + response_mask: torch.Tensor, + attention_mask_response: torch.Tensor, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[torch.Tensor]]: + """Compute per-turn World Model uncertainty H_WM. + + H_WM^t = mean prediction entropy at env token positions following action turn t. + + Env tokens after turn t represent the environment's response to action a_t. + The model's entropy at these positions measures how uncertain it is about + predicting the next state — i.e., the World Model's "cognitive blind spot". + + Higher H_WM → model doesn't understand this environment transition well. + Lower H_WM → model has seen similar transitions and is confident. + + Args: + entropys: (batch_size, response_length) entropy at all positions + response_mask: (batch_size, response_length) 1=action, 0=env/pad + attention_mask_response: (batch_size, response_length) 1=real, 0=padding + turn_boundaries: per-sample list of (start, end) for action turns + + Returns: + List of lists of scalar tensors, one H_WM per turn per sample. + """ + batch_size = entropys.shape[0] + seq_len = entropys.shape[1] + device = entropys.device + env_mask = attention_mask_response * (1.0 - response_mask) + + h_wm_per_sample = [] + + for i in range(batch_size): + boundaries = turn_boundaries[i] + h_wm_turns = [] + + for t, (start, end) in enumerate(boundaries): + # Env tokens after this turn: [end, next_turn_start) or [end, seq_len) + if t + 1 < len(boundaries): + env_end = boundaries[t + 1][0] + else: + env_end = seq_len + + region_mask = env_mask[i, end:env_end] + region_entropy = entropys[i, end:env_end] + count = region_mask.sum() + + if count > 0: + h_wm_t = (region_entropy * region_mask).sum() / count + else: + h_wm_t = torch.tensor(0.0, device=device) + + h_wm_turns.append(h_wm_t) + + h_wm_per_sample.append(h_wm_turns) + + return h_wm_per_sample + + +def compute_dynamic_mask( + s_star_per_sample: List[List[torch.Tensor]], + h_wm_per_sample: List[List[torch.Tensor]], + mu_base: float = 1.0, + lambda_wm: float = 1.0, +) -> List[List[float]]: + """Compute per-turn dynamic entropy clipping mask. + + m_t = 1 if |S_*^t - S_bar| <= mu_base * (1 + lambda * H_WM^t) * sigma + 0 otherwise + + Logic: + - WM confident (H_WM→0): threshold tightens → overconfident policy blocked + (prevents overfitting in well-understood regions) + - WM uncertain (H_WM large): threshold widens → exploration encouraged + (allows agent to gather data in unknown regions to train the WM) + + H_WM is detached — it only acts as a scalar gate, never participates in + policy gradient backpropagation. + + Args: + s_star_per_sample: per-sample, per-turn S_* tensors + h_wm_per_sample: per-sample, per-turn H_WM tensors + mu_base: base clipping coefficient + lambda_wm: WM uncertainty weight + + Returns: + List of lists of floats (0.0 or 1.0), one mask per turn per sample. + """ + # Flatten all S_* for batch statistics + all_s = [] + for turns in s_star_per_sample: + for s in turns: + all_s.append(s.detach()) + + if len(all_s) == 0: + return [[] for _ in s_star_per_sample] + + all_s_tensor = torch.stack(all_s) + s_bar = all_s_tensor.mean() + + # Guard for single-element: std is 0, threshold = mu_base * (1 + lambda * h_wm) * 0 + # → everything would be masked. Use 1.0 as default sigma for single element. + if len(all_s) <= 1: + sigma = torch.tensor(1.0, device=all_s_tensor.device) + else: + sigma = all_s_tensor.std(unbiased=False) + 1e-8 + + mask_per_sample = [] + for i in range(len(s_star_per_sample)): + masks = [] + for t in range(len(s_star_per_sample[i])): + s_t = s_star_per_sample[i][t].detach() + h_t = h_wm_per_sample[i][t].detach() + + threshold = mu_base * (1.0 + lambda_wm * h_t) * sigma + m_t = 1.0 if torch.abs(s_t - s_bar) <= threshold else 0.0 + masks.append(m_t) + mask_per_sample.append(masks) + + return mask_per_sample + + +def apply_wmc_erc( + batch, + entropys: torch.Tensor, + wmc_erc_config, +) -> Tuple[object, Dict[str, float]]: + """Apply WMC-ERC dynamic entropy clipping to batch advantages. + + Pipeline: + 1. Compute turn boundaries from response_mask + 2. Compute S_* (policy blind confidence) per turn + 3. Compute H_WM (world model uncertainty) per turn from env token entropys + 4. Compute dynamic mask m_t per turn + 5. Apply mask to advantages: A_masked = A * m_t (broadcast to tokens) + 6. Return metrics for logging + + Args: + batch: DataProto or compatible object with batch dict containing + advantages, response_mask, old_log_probs, attention_mask + entropys: (batch_size, response_length) stored before pop in train_step + wmc_erc_config: OmegaConf DictConfig or dict with mu_base, lambda_wm, enable + + Returns: + (batch, metrics) where batch has masked advantages and metrics dict + """ + enable = wmc_erc_config.get("enable", True) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'enable', True) + if not enable: + return batch, {} + + response_mask = batch.batch["response_mask"] + old_log_probs = batch.batch["old_log_probs"] + advantages = batch.batch["advantages"] + + # Compute attention mask for response region + response_length = advantages.shape[1] + attention_mask = batch.batch["attention_mask"] + attention_mask_response = attention_mask[:, -response_length:] + + # 1. Turn boundaries + turn_boundaries = compute_turn_boundaries(response_mask) + + # 2. S_* per turn + s_star = compute_s_star(old_log_probs, entropys, response_mask, turn_boundaries) + + # 3. H_WM per turn + h_wm = compute_h_wm(entropys, response_mask, attention_mask_response, turn_boundaries) + + # 4. Dynamic mask + mu_base = float(wmc_erc_config.get("mu_base", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_base', 1.0)) + lambda_wm = float(wmc_erc_config.get("lambda_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'lambda_wm', 1.0)) + mask = compute_dynamic_mask(s_star, h_wm, mu_base, lambda_wm) + + # 5. Apply mask to advantages (in-place) + batch_size = advantages.shape[0] + for i in range(batch_size): + for t, (start, end) in enumerate(turn_boundaries[i]): + if t < len(mask[i]): + advantages[i, start:end] *= mask[i][t] + batch.batch["advantages"] = advantages + + # 6. Metrics + all_s = [s.item() for turns in s_star for s in turns] + all_h = [h.item() for turns in h_wm for h in turns] + all_m = [m for turns in mask for m in turns] + + # WM NLL (monitoring only — not in backward pass for this prototype) + env_mask = attention_mask_response * (1.0 - response_mask) + env_count = env_mask.sum() + wm_nll = (-(old_log_probs * env_mask).sum() / (env_count + 1e-8)).item() if env_count > 0 else 0.0 + + metrics = { + "wmc_erc/s_star_mean": float(np.mean(all_s)) if all_s else 0.0, + "wmc_erc/s_star_std": float(np.std(all_s)) if all_s else 0.0, + "wmc_erc/h_wm_mean": float(np.mean(all_h)) if all_h else 0.0, + "wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0, + "wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0), + "wmc_erc/total_turns": len(all_m), + "wmc_erc/wm_nll": wm_nll, + } + + return batch, metrics diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml new file mode 100644 index 00000000..7c7c569e --- /dev/null +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -0,0 +1,86 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "grpo_per_step" +rollout_n: 8 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: base clipping coefficient (controls tightness of the gate) +# - lambda_wm: how much WM uncertainty widens the gate (higher = more tolerant in unknown regions) +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + lambda_wm: 1.0 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 32 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 8 diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 8d67b106..e5aec372 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -1080,6 +1080,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # ===== DEBUG LOGGING END ===== metrics.update(old_log_prob_metrics) + _wmc_erc_entropys = entropys # Preserve for WMC-ERC before pop old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) logger.info( @@ -1161,6 +1162,16 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: config=self.config.algorithm, ) + # 9.5 WMC-ERC: Dynamic entropy clipping + wmc_erc_cfg = OmegaConf.select(self.config, "wmc_erc", default=None) + if wmc_erc_cfg and wmc_erc_cfg.get("enable", False): + from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import apply_wmc_erc + batch, wmc_metrics = apply_wmc_erc(batch, _wmc_erc_entropys, wmc_erc_cfg) + metrics.update(wmc_metrics) + logger.info(f"[WMC-ERC] mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " + f"s_star={wmc_metrics.get('wmc_erc/s_star_mean', 'N/A'):.4f}, " + f"h_wm={wmc_metrics.get('wmc_erc/h_wm_mean', 'N/A'):.4f}") + # 10. Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py new file mode 100644 index 00000000..3d53a335 --- /dev/null +++ b/opentinker/tests/test_wmc_erc.py @@ -0,0 +1,403 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for WMC-ERC dynamic entropy clipping module. + +Run with: pytest opentinker/tests/test_wmc_erc.py -v +""" + +import numpy as np +import pytest +import torch + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) +from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( + apply_wmc_erc, + compute_dynamic_mask, + compute_h_wm, + compute_s_star, +) + + +class TestComputeSStar: + """Test S_* (policy blind confidence) computation.""" + + def test_single_turn(self): + """S_* for a single turn = mean of p_k * (H + log p_k) over action tokens.""" + # 1 sample, 6 positions: 4 action + 2 padding + p = torch.tensor([0.8, 0.6, 0.9, 0.7]) + old_log_probs = torch.zeros(1, 6) + old_log_probs[0, :4] = torch.log(p) + entropys = torch.tensor([[1.0, 1.5, 0.5, 1.2, 0.0, 0.0]]) + response_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.float32) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result) == 1 # 1 sample + assert len(result[0]) == 1 # 1 turn + + # Manual: p_k * (H + log(p_k)) for each token, then mean + H = torch.tensor([1.0, 1.5, 0.5, 1.2]) + expected = (p * (H + torch.log(p))).mean().item() + assert abs(result[0][0].item() - expected) < 1e-5 + + def test_multi_turn(self): + """Two turns should produce two S_* values.""" + old_log_probs = torch.log( + torch.tensor([[0.8, 0.6, 0.5, 0.5, 0.9, 0.7, 0.5, 0.5]]) + ) + entropys = torch.tensor([[1.0, 1.5, 0.0, 0.0, 0.5, 1.2, 0.0, 0.0]]) + response_mask = torch.tensor( + [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 + ) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result[0]) == 2 # 2 turns + # Turn 0 and Turn 1 should have different values + assert result[0][0].item() != result[0][1].item() + + def test_batch(self): + """Batch of 2 samples.""" + old_log_probs = torch.log( + torch.tensor( + [ + [0.8, 0.6, 0.5, 0.5], + [0.5, 0.9, 0.5, 0.5], + ] + ) + ) + entropys = torch.tensor( + [ + [1.0, 1.5, 0.0, 0.0], + [2.0, 0.3, 0.0, 0.0], + ] + ) + response_mask = torch.tensor( + [ + [1, 1, 0, 0], + [1, 1, 0, 0], + ], + dtype=torch.float32, + ) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result) == 2 + + def test_empty_turns(self): + """Sample with no action tokens should produce empty list.""" + old_log_probs = torch.zeros(1, 4) + entropys = torch.zeros(1, 4) + response_mask = torch.zeros(1, 4, dtype=torch.float32) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + assert len(result) == 1 + assert len(result[0]) == 0 + + +class TestComputeHWM: + """Test H_WM (world model uncertainty) computation.""" + + def test_single_turn_with_env_tokens(self): + """H_WM for turn 0 = mean entropy at env positions after turn 0.""" + # Sequence: [action, action, env, env, pad, pad] + entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.0, 0.0]]) + response_mask = torch.tensor([[1, 1, 0, 0, 0, 0]], dtype=torch.float32) + attention_mask_response = torch.tensor( + [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 + ) + boundaries = [[(0, 2)]] # 1 turn: action at [0,2) + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + assert len(result) == 1 + assert len(result[0]) == 1 + # Env tokens at [2,3] → mean entropy = (3.0 + 4.0) / 2 = 3.5 + assert abs(result[0][0].item() - 3.5) < 1e-5 + + def test_two_turns(self): + """Two turns: H_WM_0 from env between turns, H_WM_1 from env after turn 1.""" + # [act, act, env, env, act, act, env, pad] + entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.5, 1.2, 2.0, 0.0]]) + response_mask = torch.tensor( + [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 + ) + attention_mask_response = torch.tensor( + [[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.float32 + ) + boundaries = [[(0, 2), (4, 6)]] + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + assert len(result[0]) == 2 + # Turn 0 env: positions [2,4) → (3.0+4.0)/2 = 3.5 + assert abs(result[0][0].item() - 3.5) < 1e-5 + # Turn 1 env: positions [6,8) but attn_mask=[1,0] → only pos 6 → 2.0 + assert abs(result[0][1].item() - 2.0) < 1e-5 + + def test_no_env_after_last_turn(self): + """Last turn has no env tokens → H_WM = 0.""" + # [act, act, env, act, act, pad] + entropys = torch.tensor([[1.0, 1.5, 3.0, 0.5, 1.2, 0.0]]) + response_mask = torch.tensor([[1, 1, 0, 1, 1, 0]], dtype=torch.float32) + attention_mask_response = torch.tensor( + [[1, 1, 1, 1, 1, 0]], dtype=torch.float32 + ) + boundaries = [[(0, 2), (3, 5)]] + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + # Turn 0 env: positions [2,3) → 3.0 + assert abs(result[0][0].item() - 3.0) < 1e-5 + # Turn 1 env: positions [5,6) but attn_mask=[0] → H_WM = 0 + assert result[0][1].item() == 0.0 + + def test_empty_turns(self): + """No turns → empty H_WM list.""" + entropys = torch.zeros(1, 4) + response_mask = torch.zeros(1, 4, dtype=torch.float32) + attention_mask_response = torch.ones(1, 4, dtype=torch.float32) + boundaries = [[]] + + result = compute_h_wm( + entropys, response_mask, attention_mask_response, boundaries + ) + assert len(result[0]) == 0 + + +class TestComputeDynamicMask: + """Test dynamic entropy clipping mask.""" + + def test_all_pass(self): + """When all S_* are close to mean, all masks = 1.""" + s_star = [[torch.tensor(1.0), torch.tensor(1.1)]] + h_wm = [[torch.tensor(0.5), torch.tensor(0.5)]] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=3.0, lambda_wm=1.0) + assert mask == [[1.0, 1.0]] + + def test_outlier_blocked_low_hwm(self): + """High S_* outlier with low H_WM (known env) → blocked.""" + s_star = [ + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(10.0)], # outlier + ] + h_wm = [ + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(0.0)], # known env → tight threshold + ] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) + # The outlier (10.0) should be blocked when threshold is tight + assert mask[3] == [0.0] + # Normal ones should pass + assert mask[0] == [1.0] + + def test_outlier_allowed_high_hwm(self): + """High S_* outlier with high H_WM (unknown env) → allowed.""" + s_star = [ + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(0.5)], + [torch.tensor(10.0)], # outlier + ] + h_wm = [ + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(0.0)], + [torch.tensor(100.0)], # very uncertain → wide threshold + ] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) + # The outlier should be allowed because H_WM is very high + assert mask[3] == [1.0] + + def test_empty(self): + """Empty input should return empty.""" + mask = compute_dynamic_mask([], [], mu_base=1.0, lambda_wm=1.0) + assert mask == [] + + def test_single_element(self): + """Single S_* value should not produce nan (std guard).""" + s_star = [[torch.tensor(5.0)]] + h_wm = [[torch.tensor(1.0)]] + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) + # Single element: |5.0 - 5.0| = 0 <= threshold → should pass + assert mask == [[1.0]] + + +class TestApplyWmcErc: + """Test full WMC-ERC orchestration.""" + + def _make_batch( + self, advantages, response_mask, old_log_probs, attention_mask + ): + """Create a minimal mock batch with required fields.""" + from unittest.mock import MagicMock + + batch = MagicMock() + batch.batch = { + "advantages": advantages.clone(), + "response_mask": response_mask, + "old_log_probs": old_log_probs, + "attention_mask": attention_mask, + } + return batch + + def test_masking_zeros_advantage(self): + """When a turn is masked, its advantages become zero.""" + # 4 samples, 1 turn each (4 action tokens + 0 env tokens) + # Sample 3 has extremely different S_* pattern + response_mask = torch.ones(4, 4, dtype=torch.float32) + advantages = torch.ones(4, 4) * 2.0 + # Make sample 3 very overconfident (high p_k, low H) + old_log_probs = torch.tensor( + [ + [np.log(0.3)] * 4, + [np.log(0.3)] * 4, + [np.log(0.3)] * 4, + [np.log(0.99)] * 4, # very high confidence + ] + ) + entropys = torch.tensor( + [ + [2.0, 2.0, 2.0, 2.0], + [2.0, 2.0, 2.0, 2.0], + [2.0, 2.0, 2.0, 2.0], + [0.01, 0.01, 0.01, 0.01], # very low entropy + ] + ) + attention_mask = torch.ones(4, 4, dtype=torch.float32) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} + + _, metrics = apply_wmc_erc(batch, entropys, config) + + # Check metrics exist + assert "wmc_erc/mask_ratio" in metrics + assert "wmc_erc/s_star_mean" in metrics + assert "wmc_erc/h_wm_mean" in metrics + assert "wmc_erc/total_turns" in metrics + assert metrics["wmc_erc/total_turns"] == 4 + + # Verify masking behavior: + # Samples 0-2: S_* = 0.3*(2.0+log(0.3)) ≈ 0.239 (normal) + # Sample 3: S_* = 0.99*(0.01+log(0.99)) ≈ 0 (outlier in opposite direction) + # H_WM = 0 for all (no env tokens) → tight threshold + # Sample 3 should be masked (|S_3 - S_bar| > threshold) + adv = batch.batch["advantages"] + assert metrics["wmc_erc/num_masked_turns"] >= 1 + assert (adv[3] == 0).all(), "Sample 3 (overconfident outlier) should have zero advantages" + assert (adv[:3] == 2.0).all(), "Samples 0-2 (normal) should keep original advantages" + + def test_disabled(self): + """When enable=False, advantages unchanged.""" + response_mask = torch.ones(2, 4, dtype=torch.float32) + advantages = torch.ones(2, 4) * 5.0 + old_log_probs = torch.full((2, 4), np.log(0.5)) + attention_mask = torch.ones(2, 4, dtype=torch.float32) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + entropys = torch.ones(2, 4) + config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": False} + + _, metrics = apply_wmc_erc(batch, entropys, config) + assert (batch.batch["advantages"] == 5.0).all() + assert metrics == {} + + def test_multi_turn_selective_masking(self): + """Multi-turn: only overconfident turns in known env get masked.""" + # 2 samples, 2 turns each: [act, act, env, env, act, act, env, pad] + response_mask = torch.tensor( + [ + [1, 1, 0, 0, 1, 1, 0, 0], + [1, 1, 0, 0, 1, 1, 0, 0], + ], + dtype=torch.float32, + ) + advantages = torch.ones(2, 8) * 3.0 + # Sample 0: normal confidence + # Sample 1: overconfident on both turns + old_log_probs = torch.tensor( + [ + [np.log(0.3), np.log(0.3), 0, 0, np.log(0.3), np.log(0.3), 0, 0], + [np.log(0.99), np.log(0.99), 0, 0, np.log(0.99), np.log(0.99), 0, 0], + ] + ) + entropys = torch.tensor( + [ + [2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 0.0], + [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.0], + ] + ) + attention_mask = torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1, 1, 0], + ], + dtype=torch.float32, + ) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} + + _, metrics = apply_wmc_erc(batch, entropys, config) + assert metrics["wmc_erc/total_turns"] == 4 + # With low H_WM on sample 1, its overconfident turns should be blocked + # This is a statistical test — the exact outcome depends on batch stats + + def test_returns_wm_nll_metric(self): + """WM NLL metric should be computed from env token log probs.""" + response_mask = torch.tensor( + [[1, 1, 0, 0, 0, 0]], dtype=torch.float32 + ) + advantages = torch.ones(1, 6) + old_log_probs = torch.tensor( + [[np.log(0.5), np.log(0.5), np.log(0.3), np.log(0.4), 0, 0]] + ) + entropys = torch.tensor([[1.0, 1.0, 2.0, 3.0, 0.0, 0.0]]) + attention_mask = torch.tensor( + [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 + ) + + batch = self._make_batch( + advantages, response_mask, old_log_probs, attention_mask + ) + config = {"mu_base": 5.0, "lambda_wm": 1.0, "enable": True} + + _, metrics = apply_wmc_erc(batch, entropys, config) + assert "wmc_erc/wm_nll" in metrics + # WM NLL = -mean(log_prob at env positions [2,3]) + # = -(log(0.3) + log(0.4)) / 2 + expected_nll = -(np.log(0.3) + np.log(0.4)) / 2.0 + assert abs(metrics["wmc_erc/wm_nll"] - expected_nll) < 1e-4 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From db7dee46f6288f5f0f6536f7744c884811472bbd Mon Sep 17 00:00:00 2001 From: lwaekfjlk <1125027232@qq.com> Date: Sat, 14 Mar 2026 16:25:33 +0000 Subject: [PATCH 2/3] support co-evolve v2 --- .../client_config/alfworld_wmc_erc_param.yaml | 2 +- opentinker/server/http_training_server.py | 49 ++++++++++++++++--- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml index 7c7c569e..c6cf5c19 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -65,7 +65,7 @@ interaction: env_host: 0.0.0.0 env_port: 8092 env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} - env_shards: 32 + env_shards: 1 max_steps: 20 max_total_steps: 20 observation_template: "{observation}" diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index e5aec372..c4fe1ce0 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -27,6 +27,7 @@ import asyncio import base64 +import gc import logging import signal import sys @@ -943,7 +944,9 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # episodes into individual per-turn training samples, the gen_batch_output # batch size is larger than the original batch. We need to expand the # original batch to match using the expansion index. - expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + expansion_index = gen_batch_output.meta_info.pop( + "per_turn_expansion_index", None + ) if expansion_index is not None: logger.info( f"[Per-turn training] Expanding original batch from {len(batch)} to " @@ -956,6 +959,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) # Expand non-tensor batch expanded_non_tensor = {} @@ -1165,12 +1169,19 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # 9.5 WMC-ERC: Dynamic entropy clipping wmc_erc_cfg = OmegaConf.select(self.config, "wmc_erc", default=None) if wmc_erc_cfg and wmc_erc_cfg.get("enable", False): - from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import apply_wmc_erc - batch, wmc_metrics = apply_wmc_erc(batch, _wmc_erc_entropys, wmc_erc_cfg) + from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( + apply_wmc_erc, + ) + + batch, wmc_metrics = apply_wmc_erc( + batch, _wmc_erc_entropys, wmc_erc_cfg + ) metrics.update(wmc_metrics) - logger.info(f"[WMC-ERC] mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " - f"s_star={wmc_metrics.get('wmc_erc/s_star_mean', 'N/A'):.4f}, " - f"h_wm={wmc_metrics.get('wmc_erc/h_wm_mean', 'N/A'):.4f}") + logger.info( + f"[WMC-ERC] mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " + f"s_star={wmc_metrics.get('wmc_erc/s_star_mean', 'N/A'):.4f}, " + f"h_wm={wmc_metrics.get('wmc_erc/h_wm_mean', 'N/A'):.4f}" + ) # 10. Update critic if self.use_critic: @@ -1269,6 +1280,13 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info(f"Training step {self.global_steps} completed successfully") + # Free large intermediates and force garbage collection to prevent OOM + del batch, gen_batch, gen_batch_output, reward_tensor + if "_wmc_erc_entropys" in dir(): + del _wmc_erc_entropys + gc.collect() + torch.cuda.empty_cache() + return { "status": "success", "metrics": metrics, @@ -1546,7 +1564,9 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: # 6. Merge original batch and generated output # Per-turn training expansion: expand batch if gen output is larger - expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + expansion_index = gen_batch_output.meta_info.pop( + "per_turn_expansion_index", None + ) if expansion_index is not None: logger.info( f"[Per-turn training] Validation: Expanding original batch from {len(batch)} to " @@ -1558,6 +1578,7 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) expanded_non_tensor = {} for k, v in batch.non_tensor_batch.items(): @@ -2117,6 +2138,13 @@ def run_fastapi_server(): namespace=_server_cfg.ray.namespace, num_gpus=_server_cfg.trainer.n_gpus_per_node, # Explicitly specify number of GPUs ignore_reinit_error=True, + runtime_env={ + "env_vars": { + "NCCL_CUMEM_ENABLE": "0", + "VLLM_DISABLE_SLEEP_MODE": "1", + "RAY_memory_usage_threshold": "0.99", + }, + }, ) else: # Connect to existing Ray cluster at specific address @@ -2125,6 +2153,13 @@ def run_fastapi_server(): address=_server_cfg.ray.address, namespace=_server_cfg.ray.namespace, ignore_reinit_error=True, + runtime_env={ + "env_vars": { + "NCCL_CUMEM_ENABLE": "0", + "VLLM_DISABLE_SLEEP_MODE": "1", + "RAY_memory_usage_threshold": "0.99", + }, + }, ) # Verify GPU availability From 586dd318959f1489048caa7634668d1df595075e Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Tue, 17 Mar 2026 15:36:28 +0000 Subject: [PATCH 3/3] Added entropy normalization and different threshold for clipping. --- .../backend_patch/verl/trainer/ppo/wmc_erc.py | 163 ++++--- opentinker/client/alfworld_rl.py | 1 + opentinker/client/android_world_rl.py | 1 + .../client/client_config/alfworld_param.yaml | 4 +- .../client_config/alfworld_wmc_erc_param.yaml | 14 +- .../alfworld_wmc_erc_param_ppo.yaml | 93 ++++ opentinker/client/geo3k_rl.py | 1 + opentinker/client/geo3k_tool_rl.py | 1 + opentinker/client/gomoku_rl.py | 1 + opentinker/client/math_rl.py | 1 + opentinker/client/math_tool_rl.py | 1 + .../client/utils/http_training_client.py | 18 +- opentinker/environment/__init__.py | 35 +- opentinker/environment/base_game_server.py | 126 +++--- opentinker/scheduler/job_scheduler.py | 16 + opentinker/scripts/run_alfworld.sh | 97 ++++ opentinker/server/http_training_server.py | 18 +- opentinker/tests/test_wmc_erc.py | 413 +++--------------- 18 files changed, 520 insertions(+), 484 deletions(-) create mode 100644 opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml create mode 100755 opentinker/scripts/run_alfworld.sh diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py index 2ea6182a..c06bcf02 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -155,60 +155,52 @@ def compute_h_wm( def compute_dynamic_mask( s_star_per_sample: List[List[torch.Tensor]], h_wm_per_sample: List[List[torch.Tensor]], - mu_base: float = 1.0, - lambda_wm: float = 1.0, + mu_base: float, + mu_exp: float, + lambda_wm: float, + s_bar: float, + sigma: float, + h_bar: float, ) -> List[List[float]]: """Compute per-turn dynamic entropy clipping mask. - m_t = 1 if |S_*^t - S_bar| <= mu_base * (1 + lambda * H_WM^t) * sigma - 0 otherwise - Logic: - - WM confident (H_WM→0): threshold tightens → overconfident policy blocked - (prevents overfitting in well-understood regions) - - WM uncertain (H_WM large): threshold widens → exploration encouraged - (allows agent to gather data in unknown regions to train the WM) - - H_WM is detached — it only acts as a scalar gate, never participates in - policy gradient backpropagation. + - Normalization: H_WM is normalized by h_bar (batch or global). + - Threshold: threshold = mu * (1 + lambda * H_WM_norm) * sigma Args: s_star_per_sample: per-sample, per-turn S_* tensors h_wm_per_sample: per-sample, per-turn H_WM tensors - mu_base: base clipping coefficient + mu_base: clipping coefficient for collapsing side + mu_exp: clipping coefficient for exploration side lambda_wm: WM uncertainty weight + s_bar: mean of S_* (batch or global) + sigma: std of S_* (batch or global) + h_bar: mean of H_WM (batch or global) Returns: List of lists of floats (0.0 or 1.0), one mask per turn per sample. """ - # Flatten all S_* for batch statistics - all_s = [] - for turns in s_star_per_sample: - for s in turns: - all_s.append(s.detach()) - - if len(all_s) == 0: - return [[] for _ in s_star_per_sample] - - all_s_tensor = torch.stack(all_s) - s_bar = all_s_tensor.mean() - - # Guard for single-element: std is 0, threshold = mu_base * (1 + lambda * h_wm) * 0 - # → everything would be masked. Use 1.0 as default sigma for single element. - if len(all_s) <= 1: - sigma = torch.tensor(1.0, device=all_s_tensor.device) - else: - sigma = all_s_tensor.std(unbiased=False) + 1e-8 - mask_per_sample = [] for i in range(len(s_star_per_sample)): masks = [] for t in range(len(s_star_per_sample[i])): - s_t = s_star_per_sample[i][t].detach() - h_t = h_wm_per_sample[i][t].detach() - - threshold = mu_base * (1.0 + lambda_wm * h_t) * sigma - m_t = 1.0 if torch.abs(s_t - s_bar) <= threshold else 0.0 + s_t = s_star_per_sample[i][t].detach().item() + h_t = h_wm_per_sample[i][t].detach().item() + + # Normalize H_WM + h_t_norm = h_t / (h_bar + 1e-8) + + # Asymmetric threshold calculation + if s_t > s_bar: + # Collapsing side + threshold = mu_base * (1.0 + lambda_wm * h_t_norm) * sigma + m_t = 1.0 if (s_t - s_bar) <= threshold else 0.0 + else: + # Exploration side + threshold = mu_exp * (1.0 + lambda_wm * h_t_norm) * sigma + m_t = 1.0 if (s_bar - s_t) <= threshold else 0.0 + masks.append(m_t) mask_per_sample.append(masks) @@ -219,35 +211,28 @@ def apply_wmc_erc( batch, entropys: torch.Tensor, wmc_erc_config, + running_stats: Dict[str, float], ) -> Tuple[object, Dict[str, float]]: """Apply WMC-ERC dynamic entropy clipping to batch advantages. - Pipeline: - 1. Compute turn boundaries from response_mask - 2. Compute S_* (policy blind confidence) per turn - 3. Compute H_WM (world model uncertainty) per turn from env token entropys - 4. Compute dynamic mask m_t per turn - 5. Apply mask to advantages: A_masked = A * m_t (broadcast to tokens) - 6. Return metrics for logging - Args: - batch: DataProto or compatible object with batch dict containing - advantages, response_mask, old_log_probs, attention_mask - entropys: (batch_size, response_length) stored before pop in train_step - wmc_erc_config: OmegaConf DictConfig or dict with mu_base, lambda_wm, enable + batch: DataProto or compatible object + entropys: (batch_size, response_length) + wmc_erc_config: OmegaConf DictConfig or dict + running_stats: Dictionary for global running statistics Returns: - (batch, metrics) where batch has masked advantages and metrics dict + (batch, metrics) """ enable = wmc_erc_config.get("enable", True) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'enable', True) if not enable: return batch, {} + clipping_type = wmc_erc_config.get("clipping_type", "batch") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_type', "batch") + response_mask = batch.batch["response_mask"] old_log_probs = batch.batch["old_log_probs"] advantages = batch.batch["advantages"] - - # Compute attention mask for response region response_length = advantages.shape[1] attention_mask = batch.batch["attention_mask"] attention_mask_response = attention_mask[:, -response_length:] @@ -255,18 +240,56 @@ def apply_wmc_erc( # 1. Turn boundaries turn_boundaries = compute_turn_boundaries(response_mask) - # 2. S_* per turn + # 2. Compute S_* and H_WM per turn s_star = compute_s_star(old_log_probs, entropys, response_mask, turn_boundaries) - - # 3. H_WM per turn h_wm = compute_h_wm(entropys, response_mask, attention_mask_response, turn_boundaries) + # Calculate batch statistics + all_s = [s.item() for turns in s_star for s in turns] + all_h = [h.item() for turns in h_wm for h in turns] + + if not all_s: + return batch, {} + + batch_s_bar = np.mean(all_s) + batch_s_std = np.std(all_s) + 1e-8 + batch_h_bar = np.mean(all_h) + 1e-8 + + # Update global statistics + momentum = wmc_erc_config.get("momentum", 0.9) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'momentum', 0.9) + if not running_stats.get("initialized", False): + running_stats["s_bar"] = batch_s_bar + running_stats["s_std"] = batch_s_std + running_stats["h_bar"] = batch_h_bar + running_stats["initialized"] = True + else: + running_stats["s_bar"] = (1 - momentum) * batch_s_bar + momentum * running_stats["s_bar"] + running_stats["s_std"] = (1 - momentum) * batch_s_std + momentum * running_stats["s_std"] + running_stats["h_bar"] = (1 - momentum) * batch_h_bar + momentum * running_stats["h_bar"] + + # Select statistics for masking + if clipping_type == "global": + use_s_bar = running_stats["s_bar"] + use_s_std = running_stats["s_std"] + use_h_bar = running_stats["h_bar"] + else: + use_s_bar = batch_s_bar + use_s_std = batch_s_std + use_h_bar = batch_h_bar + # 4. Dynamic mask mu_base = float(wmc_erc_config.get("mu_base", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_base', 1.0)) + mu_exp = float(wmc_erc_config.get("mu_exp", 2.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_exp', 2.0)) lambda_wm = float(wmc_erc_config.get("lambda_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'lambda_wm', 1.0)) - mask = compute_dynamic_mask(s_star, h_wm, mu_base, lambda_wm) - - # 5. Apply mask to advantages (in-place) + + mask = compute_dynamic_mask( + s_star, h_wm, mu_base, mu_exp, lambda_wm, + s_bar=use_s_bar, + sigma=use_s_std, + h_bar=use_h_bar + ) + + # 5. Apply mask to advantages batch_size = advantages.shape[0] for i in range(batch_size): for t, (start, end) in enumerate(turn_boundaries[i]): @@ -275,21 +298,33 @@ def apply_wmc_erc( batch.batch["advantages"] = advantages # 6. Metrics - all_s = [s.item() for turns in s_star for s in turns] - all_h = [h.item() for turns in h_wm for h in turns] all_m = [m for turns in mask for m in turns] + + num_collapsing_masked = 0 + num_exploration_masked = 0 + for i in range(len(s_star)): + for t in range(len(s_star[i])): + if mask[i][t] == 0.0: + if s_star[i][t].item() > use_s_bar: + num_collapsing_masked += 1 + else: + num_exploration_masked += 1 - # WM NLL (monitoring only — not in backward pass for this prototype) env_mask = attention_mask_response * (1.0 - response_mask) env_count = env_mask.sum() wm_nll = (-(old_log_probs * env_mask).sum() / (env_count + 1e-8)).item() if env_count > 0 else 0.0 metrics = { - "wmc_erc/s_star_mean": float(np.mean(all_s)) if all_s else 0.0, - "wmc_erc/s_star_std": float(np.std(all_s)) if all_s else 0.0, - "wmc_erc/h_wm_mean": float(np.mean(all_h)) if all_h else 0.0, + "wmc_erc/batch_s_bar": float(batch_s_bar), + "wmc_erc/batch_s_std": float(batch_s_std), + "wmc_erc/batch_h_bar": float(batch_h_bar), + "wmc_erc/running_s_bar": float(running_stats["s_bar"]), + "wmc_erc/running_s_std": float(running_stats["s_std"]), + "wmc_erc/running_h_bar": float(running_stats["h_bar"]), "wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0, "wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0), + "wmc_erc/num_collapsing_masked": num_collapsing_masked, + "wmc_erc/num_exploration_masked": num_exploration_masked, "wmc_erc/total_turns": len(all_m), "wmc_erc/wm_nll": wm_nll, } diff --git a/opentinker/client/alfworld_rl.py b/opentinker/client/alfworld_rl.py index 0810cc5d..df02c10a 100644 --- a/opentinker/client/alfworld_rl.py +++ b/opentinker/client/alfworld_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/android_world_rl.py b/opentinker/client/android_world_rl.py index 122b02f1..8af8e265 100644 --- a/opentinker/client/android_world_rl.py +++ b/opentinker/client/android_world_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/client_config/alfworld_param.yaml b/opentinker/client/client_config/alfworld_param.yaml index 0f183186..c3fac911 100644 --- a/opentinker/client/client_config/alfworld_param.yaml +++ b/opentinker/client/client_config/alfworld_param.yaml @@ -60,7 +60,7 @@ interaction: env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} # If you run the ALFWorld env server in sharded mode (--shards N), # set env_shards=N. The client will route each instance_id to a stable shard. - env_shards: 32 + env_shards: 8 max_steps: 20 # ALFWorld episodes max steps max_total_steps: 20 # Max environment step calls (controls rollout turns) observation_template: "{observation}" @@ -80,4 +80,4 @@ scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa # GPU settings -num_gpus: 8 +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml index c6cf5c19..a788ce24 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -45,17 +45,23 @@ max_prompt_tokens: 2048 algorithm: "agent_loop" # Use per-step advantage for multi-turn credit assignment -adv_estimator: "grpo_per_step" +adv_estimator: "grpo" rollout_n: 8 # WMC-ERC: Dynamic Entropy Clipping -# - mu_base: base clipping coefficient (controls tightness of the gate) -# - lambda_wm: how much WM uncertainty widens the gate (higher = more tolerant in unknown regions) +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") # - enable: master switch wmc_erc: enable: true mu_base: 1.0 + mu_exp: 2.0 lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 # Interaction configuration interaction: @@ -83,4 +89,4 @@ scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: null # GPU settings -num_gpus: 8 +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml new file mode 100644 index 00000000..94cf216e --- /dev/null +++ b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml @@ -0,0 +1,93 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc_ppo + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "gae" +rollout_n: 1 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + mu_exp: 2.0 + lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 1 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 +agent_num_workers: 4 diff --git a/opentinker/client/geo3k_rl.py b/opentinker/client/geo3k_rl.py index f45f5f8f..f51b7a0a 100644 --- a/opentinker/client/geo3k_rl.py +++ b/opentinker/client/geo3k_rl.py @@ -70,6 +70,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/geo3k_tool_rl.py b/opentinker/client/geo3k_tool_rl.py index e812132f..9f08d9a7 100644 --- a/opentinker/client/geo3k_tool_rl.py +++ b/opentinker/client/geo3k_tool_rl.py @@ -93,6 +93,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/gomoku_rl.py b/opentinker/client/gomoku_rl.py index bf42ca8e..cb4b2608 100755 --- a/opentinker/client/gomoku_rl.py +++ b/opentinker/client/gomoku_rl.py @@ -114,6 +114,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/math_rl.py b/opentinker/client/math_rl.py index 5bffb2df..804cc545 100755 --- a/opentinker/client/math_rl.py +++ b/opentinker/client/math_rl.py @@ -74,6 +74,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/math_tool_rl.py b/opentinker/client/math_tool_rl.py index 1c68b06d..3ef1d3ce 100755 --- a/opentinker/client/math_tool_rl.py +++ b/opentinker/client/math_tool_rl.py @@ -73,6 +73,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 08f65e9b..958ba2d2 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -579,6 +579,7 @@ def __init__( project_name: Optional[str] = None, experiment_name: Optional[str] = None, logger_backends: Optional[List[str]] = None, + config: Optional[Any] = None, **client_kwargs, ): self.client = HTTPTrainingClient(server_url, **client_kwargs) @@ -588,11 +589,23 @@ def __init__( if logger_backends and project_name and experiment_name: from verl.utils.tracking import Tracking + # Convert DictConfig to dict if necessary for Tracking + tracking_config = config + if config is not None and not isinstance(config, dict): + from omegaconf import OmegaConf + + tracking_config = OmegaConf.to_container(config, resolve=True) + + # Ensure 'trainer' key exists to avoid KeyError in verl.utils.tracking + if tracking_config is not None: + if "trainer" not in tracking_config: + tracking_config["trainer"] = {} + self.tracker = Tracking( project_name=project_name, experiment_name=experiment_name, default_backend=logger_backends, - config=None, # Can pass config if needed + config=tracking_config, ) logger.info(f"Initialized tracking with backends: {logger_backends}") @@ -807,10 +820,11 @@ def fit( # Update progress bar if verbose and progress_bar: # Show key metrics in progress bar (filter game/ metrics except win_rate) + # Added wmc_erc/mask_ratio to monitor dynamic entropy clipping display_metrics = { k: v for k, v in last_metrics.items() - if not k.startswith("game/") or k == "game/win_rate" + if not k.startswith("game/") or k == "game/win_rate" or k == "wmc_erc/mask_ratio" } metrics_str = ", ".join( [ diff --git a/opentinker/environment/__init__.py b/opentinker/environment/__init__.py index 6d096c35..7bc88a03 100755 --- a/opentinker/environment/__init__.py +++ b/opentinker/environment/__init__.py @@ -32,14 +32,33 @@ run_game_server, ) -from opentinker.environment.inference_pipeline import ( - InferencePipeline, - InferenceResult, - RemoteEnvironmentClient, - run_inference, - load_samples, - generate_samples, -) +# Lazy import for InferencePipeline to avoid heavy dependencies (like vllm) +# when only the game server is needed. +def __getattr__(name): + if name in [ + "InferencePipeline", + "InferenceResult", + "RemoteEnvironmentClient", + "run_inference", + "load_samples", + "generate_samples", + ]: + from opentinker.environment.inference_pipeline import ( + InferencePipeline, + InferenceResult, + RemoteEnvironmentClient, + run_inference, + load_samples, + generate_samples, + ) + globals()["InferencePipeline"] = InferencePipeline + globals()["InferenceResult"] = InferenceResult + globals()["RemoteEnvironmentClient"] = RemoteEnvironmentClient + globals()["run_inference"] = run_inference + globals()["load_samples"] = load_samples + globals()["generate_samples"] = generate_samples + return globals()[name] + raise AttributeError(f"module {__name__} has no attribute {name}") __all__ = [ # Base diff --git a/opentinker/environment/base_game_server.py b/opentinker/environment/base_game_server.py index 668add10..77c2a59c 100755 --- a/opentinker/environment/base_game_server.py +++ b/opentinker/environment/base_game_server.py @@ -383,42 +383,48 @@ async def health_check(): @app.post("/reset") async def reset(request: ResetRequest): """Reset/create a game instance.""" - instance_id = request.instance_id - job_id = request.job_id - # Extract extra fields for game reset (exclude instance_id and job_id) - reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) + try: + instance_id = request.instance_id + job_id = request.job_id + # Extract extra fields for game reset (exclude instance_id and job_id) + reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) - with games_lock: - # Reuse existing game instance if available (avoids re-initialization) - if instance_id in games: - game = games[instance_id] - else: - game = game_class(**game_kwargs) - games[instance_id] = game + with games_lock: + # Reuse existing game instance if available (avoids re-initialization) + if instance_id in games: + game = games[instance_id] + else: + game = game_class(**game_kwargs) + games[instance_id] = game - # Reset the game (this is the slow part) - observation = game.reset(**reset_kwargs) + # Reset the game (this is the slow part) + observation = game.reset(**reset_kwargs) - # Track that this game has started (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.register_game_start(instance_id, job_id) + # Track that this game has started (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.register_game_start(instance_id, job_id) - system_prompt = game.get_system_prompt() - initial_message = game.get_initial_user_message() - full_observation = ( - f"{initial_message}\n\n{observation}" if observation else initial_message - ) + system_prompt = game.get_system_prompt() + initial_message = game.get_initial_user_message() + full_observation = ( + f"{initial_message}\n\n{observation}" if observation else initial_message + ) - response = { - "observation": full_observation, - "system_prompt": system_prompt, - } + response = { + "observation": full_observation, + "system_prompt": system_prompt, + } - state = game.get_state() - if state: - response["game_state"] = state + state = game.get_state() + if state: + response["game_state"] = state - return response + return response + except Exception as e: + import traceback + error_msg = f"Reset failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + raise HTTPException(status_code=500, detail=error_msg) @app.post("/finalize") async def finalize(request: ResetRequest): @@ -433,37 +439,45 @@ async def finalize(request: ResetRequest): @app.post("/step") async def step(request: StepRequest): """Execute a step in the game.""" - instance_id = request.instance_id - job_id = request.job_id - action = request.action + try: + instance_id = request.instance_id + job_id = request.job_id + action = request.action + + if instance_id not in games: + raise HTTPException( + status_code=404, + detail=f"Instance {instance_id} not found. Call /reset first.", + ) - if instance_id not in games: - raise HTTPException( - status_code=404, - detail=f"Instance {instance_id} not found. Call /reset first.", - ) + game = games[instance_id] + result = game.step(action) - game = games[instance_id] - result = game.step(action) + # Record statistics with instance_id for per-game tracking (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.record_game_result( + result.info, result.reward, result.done, instance_id, job_id + ) - # Record statistics with instance_id for per-game tracking (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.record_game_result( - result.info, result.reward, result.done, instance_id, job_id - ) + # Clean up finished games + if result.done: + with games_lock: + if instance_id in games: + del games[instance_id] - # Clean up finished games - if result.done: - with games_lock: - if instance_id in games: - del games[instance_id] - - return { - "observation": result.observation, - "reward": result.reward, - "done": result.done, - "info": result.info, - } + return { + "observation": result.observation, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + except Exception as e: + import traceback + error_msg = f"Step failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=error_msg) @app.get("/stats") async def get_stats(job_id: str = "default"): diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 2a1824d4..abc5b42b 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -1035,6 +1035,22 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: Popen process object """ env = os.environ.copy() + + # [FIX] Check for broken libcuda.so.1 (size 0) - common issue in some environments + libcuda_path = "/usr/lib/x86_64-linux-gnu/libcuda.so.1" + if os.path.exists(libcuda_path) and os.path.getsize(libcuda_path) == 0: + compat_path = "/usr/local/cuda-12.4/compat" + if os.path.isdir(compat_path): + current_ld_path = env.get("LD_LIBRARY_PATH", "") + env["LD_LIBRARY_PATH"] = ( + f"{compat_path}:{current_ld_path}".strip(":") + if current_ld_path + else compat_path + ) + logger.info( + f"Job {job.job_id}: 🛠 Fixed broken libcuda.so.1 by adding {compat_path} to LD_LIBRARY_PATH" + ) + # Set CUDA_VISIBLE_DEVICES to comma-separated list of GPU IDs env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, job.gpu_ids)) # Pass job_id to agent loop for per-client trace subdirectory isolation diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh new file mode 100755 index 00000000..97ab8252 --- /dev/null +++ b/opentinker/scripts/run_alfworld.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# ALFWorld Training Script (Multi-Turn) +# +# This script runs ALFWorld RL training with OpenTinker. +# You need to run these steps in SEPARATE terminals. +# +# For Training (3 terminals): +# Terminal 1: bash run_alfworld.sh scheduler +# Terminal 2: bash run_alfworld.sh env +# Terminal 3: bash run_alfworld.sh client +# +# Prerequisites: +# - pip install alfworld +# - alfworld-download +# - See docs/alfworld_multiturn.md for environment setup + +# ============================================================================= +# Configuration +# ============================================================================= +SCHEDULER_PORT="${SCHEDULER_PORT:-9780}" +ENV_PORT="${ENV_PORT:-1234}" +GPUS="${GPUS:-[0,1,2,3]}" +MODEL_PATH="/inspire/hdd/project/robot-reasoning/xuyue-p-xuyue/ziyu/.cache/huggingface/hub/models--Qwen--Qwen2.5-3B-Instruct" + +# OpenTinker root (relative to this script: opentinker/scripts/run_alfworld.sh) +OPENTINKER_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" + +export VLLM_DISABLE_SLEEP_MODE=1 +export HF_HUB_OFFLINE=1 +export WANDB_MODE=offline + +# Activate conda environment (adjust to your setup if needed) +if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/anaconda3/etc/profile.d/conda.sh" +elif [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +fi +# conda activate opentinker + +# Change to OpenTinker directory +cd "$OPENTINKER_ROOT" + +# Get current host IP for communication between components +# Use 127.0.0.1 if running everything on the same machine +HOST_IP="${HOST_IP:-127.0.0.1}" + +# ============================================================================= +# Step Selection +# ============================================================================= +case "$1" in + scheduler|1) + echo "========================================" + echo "Step 1: Starting Scheduler on port $SCHEDULER_PORT" + echo "========================================" + bash opentinker/scripts/launch_scheduler.sh \ + --scheduler-port $SCHEDULER_PORT \ + --gpus "$GPUS" + ;; + + env|2) + echo "========================================" + echo "Step 2: Starting ALFWorld Environment Server on port $ENV_PORT" + echo "========================================" + python -m opentinker.environment.alfworld.alfworld_server \ + --port "$ENV_PORT" \ + --max_steps 50 \ + --split train \ + --num_games -1 \ + ;; + + client|3) + echo "========================================" + echo "Step 3: Starting ALFWorld RL Client" + echo "========================================" + python opentinker/client/alfworld_rl.py \ + --config-name alfworld_wmc_erc_param \ + tokenizer_path="$MODEL_PATH" \ + batch_size=4 \ + val_batch_size=50 \ + num_steps=1000 \ + save_freq=2000 \ + test_freq=10 \ + scheduler_url="http://$HOST_IP:$SCHEDULER_PORT" \ + interaction.config.env_port="$ENV_PORT" \ + interaction.config.env_host="$HOST_IP" + ;; + + *) + echo "Usage: $0 {scheduler|env|client}" + echo "" + echo "Example (separate terminals):" + echo " Terminal 1: bash $0 scheduler" + echo " Terminal 2: bash $0 env" + echo " Terminal 3: bash $0 client" + exit 1 + ;; +esac diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index c4fe1ce0..9ef27d05 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -641,6 +641,15 @@ def __init__( self.is_initialized = False self.global_steps = 0 + # WMC-ERC running statistics (S_bar, Sigma, H_bar) + self.wmc_erc_stats = { + "s_bar": 0.0, + "s_std": 1.0, + "h_bar": 1.0, + "momentum": 0.9, # EMA momentum + "initialized": False, + } + # Generation config (can be overridden by client) self.generation_config = { "do_sample": True, # CRITICAL: Enable sampling by default for PPO training @@ -1174,13 +1183,14 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: ) batch, wmc_metrics = apply_wmc_erc( - batch, _wmc_erc_entropys, wmc_erc_cfg + batch, _wmc_erc_entropys, wmc_erc_cfg, self.wmc_erc_stats ) metrics.update(wmc_metrics) + clipping_mode = wmc_erc_cfg.get("clipping_type", "batch") logger.info( - f"[WMC-ERC] mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " - f"s_star={wmc_metrics.get('wmc_erc/s_star_mean', 'N/A'):.4f}, " - f"h_wm={wmc_metrics.get('wmc_erc/h_wm_mean', 'N/A'):.4f}" + f"[WMC-ERC] mode={clipping_mode}, mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " + f"s_star={wmc_metrics.get('wmc_erc/batch_s_bar', 'N/A'):.4f}, " + f"h_wm={wmc_metrics.get('wmc_erc/batch_h_bar', 'N/A'):.4f}" ) # 10. Update critic diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py index 3d53a335..f083277a 100644 --- a/opentinker/tests/test_wmc_erc.py +++ b/opentinker/tests/test_wmc_erc.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests for WMC-ERC dynamic entropy clipping module. +Tests for WMC-ERC dynamic entropy clipping module using unittest. -Run with: pytest opentinker/tests/test_wmc_erc.py -v +Run with: python opentinker/tests/test_wmc_erc.py """ +import unittest import numpy as np -import pytest import torch +from unittest.mock import MagicMock from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( compute_turn_boundaries, @@ -32,12 +33,11 @@ ) -class TestComputeSStar: - """Test S_* (policy blind confidence) computation.""" +class TestWmcErc(unittest.TestCase): + """Test WMC-ERC components and orchestration.""" - def test_single_turn(self): + def test_single_turn_s_star(self): """S_* for a single turn = mean of p_k * (H + log p_k) over action tokens.""" - # 1 sample, 6 positions: 4 action + 2 padding p = torch.tensor([0.8, 0.6, 0.9, 0.7]) old_log_probs = torch.zeros(1, 6) old_log_probs[0, :4] = torch.log(p) @@ -46,214 +46,34 @@ def test_single_turn(self): boundaries = compute_turn_boundaries(response_mask) result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 1 # 1 sample - assert len(result[0]) == 1 # 1 turn + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 1) - # Manual: p_k * (H + log(p_k)) for each token, then mean H = torch.tensor([1.0, 1.5, 0.5, 1.2]) expected = (p * (H + torch.log(p))).mean().item() - assert abs(result[0][0].item() - expected) < 1e-5 - - def test_multi_turn(self): - """Two turns should produce two S_* values.""" - old_log_probs = torch.log( - torch.tensor([[0.8, 0.6, 0.5, 0.5, 0.9, 0.7, 0.5, 0.5]]) - ) - entropys = torch.tensor([[1.0, 1.5, 0.0, 0.0, 0.5, 1.2, 0.0, 0.0]]) - response_mask = torch.tensor( - [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 - ) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result[0]) == 2 # 2 turns - # Turn 0 and Turn 1 should have different values - assert result[0][0].item() != result[0][1].item() - - def test_batch(self): - """Batch of 2 samples.""" - old_log_probs = torch.log( - torch.tensor( - [ - [0.8, 0.6, 0.5, 0.5], - [0.5, 0.9, 0.5, 0.5], - ] - ) - ) - entropys = torch.tensor( - [ - [1.0, 1.5, 0.0, 0.0], - [2.0, 0.3, 0.0, 0.0], - ] - ) - response_mask = torch.tensor( - [ - [1, 1, 0, 0], - [1, 1, 0, 0], - ], - dtype=torch.float32, - ) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 2 - - def test_empty_turns(self): - """Sample with no action tokens should produce empty list.""" - old_log_probs = torch.zeros(1, 4) - entropys = torch.zeros(1, 4) - response_mask = torch.zeros(1, 4, dtype=torch.float32) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 1 - assert len(result[0]) == 0 - - -class TestComputeHWM: - """Test H_WM (world model uncertainty) computation.""" - - def test_single_turn_with_env_tokens(self): - """H_WM for turn 0 = mean entropy at env positions after turn 0.""" - # Sequence: [action, action, env, env, pad, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.0, 0.0]]) - response_mask = torch.tensor([[1, 1, 0, 0, 0, 0]], dtype=torch.float32) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2)]] # 1 turn: action at [0,2) - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result) == 1 - assert len(result[0]) == 1 - # Env tokens at [2,3] → mean entropy = (3.0 + 4.0) / 2 = 3.5 - assert abs(result[0][0].item() - 3.5) < 1e-5 - - def test_two_turns(self): - """Two turns: H_WM_0 from env between turns, H_WM_1 from env after turn 1.""" - # [act, act, env, env, act, act, env, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.5, 1.2, 2.0, 0.0]]) - response_mask = torch.tensor( - [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 - ) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2), (4, 6)]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result[0]) == 2 - # Turn 0 env: positions [2,4) → (3.0+4.0)/2 = 3.5 - assert abs(result[0][0].item() - 3.5) < 1e-5 - # Turn 1 env: positions [6,8) but attn_mask=[1,0] → only pos 6 → 2.0 - assert abs(result[0][1].item() - 2.0) < 1e-5 - - def test_no_env_after_last_turn(self): - """Last turn has no env tokens → H_WM = 0.""" - # [act, act, env, act, act, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 0.5, 1.2, 0.0]]) - response_mask = torch.tensor([[1, 1, 0, 1, 1, 0]], dtype=torch.float32) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 1, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2), (3, 5)]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - # Turn 0 env: positions [2,3) → 3.0 - assert abs(result[0][0].item() - 3.0) < 1e-5 - # Turn 1 env: positions [5,6) but attn_mask=[0] → H_WM = 0 - assert result[0][1].item() == 0.0 - - def test_empty_turns(self): - """No turns → empty H_WM list.""" - entropys = torch.zeros(1, 4) - response_mask = torch.zeros(1, 4, dtype=torch.float32) - attention_mask_response = torch.ones(1, 4, dtype=torch.float32) - boundaries = [[]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result[0]) == 0 - - -class TestComputeDynamicMask: - """Test dynamic entropy clipping mask.""" - - def test_all_pass(self): - """When all S_* are close to mean, all masks = 1.""" - s_star = [[torch.tensor(1.0), torch.tensor(1.1)]] - h_wm = [[torch.tensor(0.5), torch.tensor(0.5)]] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=3.0, lambda_wm=1.0) - assert mask == [[1.0, 1.0]] - - def test_outlier_blocked_low_hwm(self): - """High S_* outlier with low H_WM (known env) → blocked.""" - s_star = [ - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(10.0)], # outlier - ] - h_wm = [ - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], # known env → tight threshold - ] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # The outlier (10.0) should be blocked when threshold is tight - assert mask[3] == [0.0] - # Normal ones should pass - assert mask[0] == [1.0] - - def test_outlier_allowed_high_hwm(self): - """High S_* outlier with high H_WM (unknown env) → allowed.""" - s_star = [ - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(10.0)], # outlier - ] - h_wm = [ - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(100.0)], # very uncertain → wide threshold - ] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # The outlier should be allowed because H_WM is very high - assert mask[3] == [1.0] - - def test_empty(self): - """Empty input should return empty.""" - mask = compute_dynamic_mask([], [], mu_base=1.0, lambda_wm=1.0) - assert mask == [] - - def test_single_element(self): - """Single S_* value should not produce nan (std guard).""" - s_star = [[torch.tensor(5.0)]] - h_wm = [[torch.tensor(1.0)]] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # Single element: |5.0 - 5.0| = 0 <= threshold → should pass - assert mask == [[1.0]] - - -class TestApplyWmcErc: - """Test full WMC-ERC orchestration.""" - - def _make_batch( - self, advantages, response_mask, old_log_probs, attention_mask - ): - """Create a minimal mock batch with required fields.""" - from unittest.mock import MagicMock - + self.assertAlmostEqual(result[0][0].item(), expected, places=5) + + def test_asymmetric_behavior(self): + """Test that mu_base and mu_exp act differently using compute_dynamic_mask.""" + s_star = [[torch.tensor(15.0)], [torch.tensor(5.0)]] # Mean=10 + h_wm = [[torch.tensor(1.0)]] * 2 + s_bar = 10.0 + sigma = 5.0 + h_bar = 1.0 + + # 1. mu_base=0.1 (block high), mu_exp=10.0 (allow low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=0.1, mu_exp=10.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, h_bar=h_bar) + self.assertEqual(mask[0], [0.0]) # High blocked + self.assertEqual(mask[1], [1.0]) # Low allowed + + # 2. mu_base=10.0 (allow high), mu_exp=0.1 (block low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=10.0, mu_exp=0.1, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, h_bar=h_bar) + self.assertEqual(mask[0], [1.0]) # High allowed + self.assertEqual(mask[1], [0.0]) # Low blocked + + def _make_batch(self, advantages, response_mask, old_log_probs, attention_mask): batch = MagicMock() batch.batch = { "advantages": advantages.clone(), @@ -263,141 +83,46 @@ def _make_batch( } return batch - def test_masking_zeros_advantage(self): - """When a turn is masked, its advantages become zero.""" - # 4 samples, 1 turn each (4 action tokens + 0 env tokens) - # Sample 3 has extremely different S_* pattern + def test_clipping_type_batch(self): + """Verify that 'batch' mode uses current batch statistics.""" response_mask = torch.ones(4, 4, dtype=torch.float32) advantages = torch.ones(4, 4) * 2.0 - # Make sample 3 very overconfident (high p_k, low H) - old_log_probs = torch.tensor( - [ - [np.log(0.3)] * 4, - [np.log(0.3)] * 4, - [np.log(0.3)] * 4, - [np.log(0.99)] * 4, # very high confidence - ] - ) - entropys = torch.tensor( - [ - [2.0, 2.0, 2.0, 2.0], - [2.0, 2.0, 2.0, 2.0], - [2.0, 2.0, 2.0, 2.0], - [0.01, 0.01, 0.01, 0.01], # very low entropy - ] - ) + # Sample 3 is an outlier in this batch + old_log_probs = torch.tensor([[np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.9)]*4]) + entropys = torch.tensor([[2.0]*4, [2.0]*4, [2.0]*4, [3.0]*4]) attention_mask = torch.ones(4, 4, dtype=torch.float32) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - - # Check metrics exist - assert "wmc_erc/mask_ratio" in metrics - assert "wmc_erc/s_star_mean" in metrics - assert "wmc_erc/h_wm_mean" in metrics - assert "wmc_erc/total_turns" in metrics - assert metrics["wmc_erc/total_turns"] == 4 - - # Verify masking behavior: - # Samples 0-2: S_* = 0.3*(2.0+log(0.3)) ≈ 0.239 (normal) - # Sample 3: S_* = 0.99*(0.01+log(0.99)) ≈ 0 (outlier in opposite direction) - # H_WM = 0 for all (no env tokens) → tight threshold - # Sample 3 should be masked (|S_3 - S_bar| > threshold) - adv = batch.batch["advantages"] - assert metrics["wmc_erc/num_masked_turns"] >= 1 - assert (adv[3] == 0).all(), "Sample 3 (overconfident outlier) should have zero advantages" - assert (adv[:3] == 2.0).all(), "Samples 0-2 (normal) should keep original advantages" - - def test_disabled(self): - """When enable=False, advantages unchanged.""" - response_mask = torch.ones(2, 4, dtype=torch.float32) - advantages = torch.ones(2, 4) * 5.0 - old_log_probs = torch.full((2, 4), np.log(0.5)) - attention_mask = torch.ones(2, 4, dtype=torch.float32) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - entropys = torch.ones(2, 4) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": False} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert (batch.batch["advantages"] == 5.0).all() - assert metrics == {} - - def test_multi_turn_selective_masking(self): - """Multi-turn: only overconfident turns in known env get masked.""" - # 2 samples, 2 turns each: [act, act, env, env, act, act, env, pad] - response_mask = torch.tensor( - [ - [1, 1, 0, 0, 1, 1, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0], - ], - dtype=torch.float32, - ) - advantages = torch.ones(2, 8) * 3.0 - # Sample 0: normal confidence - # Sample 1: overconfident on both turns - old_log_probs = torch.tensor( - [ - [np.log(0.3), np.log(0.3), 0, 0, np.log(0.3), np.log(0.3), 0, 0], - [np.log(0.99), np.log(0.99), 0, 0, np.log(0.99), np.log(0.99), 0, 0], - ] - ) - entropys = torch.tensor( - [ - [2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 0.0], - [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.0], - ] - ) - attention_mask = torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 0], - ], - dtype=torch.float32, - ) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert metrics["wmc_erc/total_turns"] == 4 - # With low H_WM on sample 1, its overconfident turns should be blocked - # This is a statistical test — the exact outcome depends on batch stats - - def test_returns_wm_nll_metric(self): - """WM NLL metric should be computed from env token log probs.""" - response_mask = torch.tensor( - [[1, 1, 0, 0, 0, 0]], dtype=torch.float32 - ) - advantages = torch.ones(1, 6) - old_log_probs = torch.tensor( - [[np.log(0.5), np.log(0.5), np.log(0.3), np.log(0.4), 0, 0]] - ) - entropys = torch.tensor([[1.0, 1.0, 2.0, 3.0, 0.0, 0.0]]) - attention_mask = torch.tensor( - [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 - ) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 5.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert "wmc_erc/wm_nll" in metrics - # WM NLL = -mean(log_prob at env positions [2,3]) - # = -(log(0.3) + log(0.4)) / 2 - expected_nll = -(np.log(0.3) + np.log(0.4)) / 2.0 - assert abs(metrics["wmc_erc/wm_nll"] - expected_nll) < 1e-4 + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + config = {"mu_base": 0.1, "mu_exp": 10.0, "lambda_wm": 1.0, "enable": True, "clipping_type": "batch"} + running_stats = {"s_bar": 100.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} # Very different global stats + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # In 'batch' mode, Sample 3 should be blocked based on BATCH mean, not GLOBAL mean + # If it used global mean (100), sample 3 (S*~2.6) would be an exploration outlier and allowed by mu_exp=10 + # But in batch mode (S_bar~0.8), sample 3 is a collapsing outlier and blocked by mu_base=0.1 + self.assertTrue((batch.batch["advantages"][3] == 0).all()) + + def test_clipping_type_global(self): + """Verify that 'global' mode uses running statistics.""" + response_mask = torch.ones(4, 4, dtype=torch.float32) + advantages = torch.ones(4, 4) * 2.0 + old_log_probs = torch.tensor([[np.log(0.3)]*4]*4) + entropys = torch.tensor([[2.0]*4]*4) + # S* for all samples will be ~0.24 + + attention_mask = torch.ones(4, 4, dtype=torch.float32) + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Global s_bar is far away (10.0), so batch S* (0.24) looks like a huge exploration outlier + config = {"mu_base": 10.0, "mu_exp": 0.1, "lambda_wm": 0.0, "enable": True, "clipping_type": "global"} + running_stats = {"s_bar": 10.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # Should be blocked because mu_exp=0.1 is very tight relative to global s_bar + self.assertTrue((batch.batch["advantages"] == 0).all()) if __name__ == "__main__": - pytest.main([__file__, "-v"]) + unittest.main()