diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py new file mode 100644 index 00000000..c06bcf02 --- /dev/null +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -0,0 +1,332 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +WMC-ERC: World Model-Conditioned Entropy Regularized Co-evolution. + +Dynamic entropy clipping for multi-turn agentic RL. Uses the LLM's own +prediction entropy at environment token positions as a World Model uncertainty +signal (H_WM) to dynamically gate policy gradient updates, preventing +entropy collapse in well-understood regions while permitting exploration +in uncertain ones. + +Core idea: +- S_* measures per-turn "blind confidence" of the policy (entropy collapse momentum) +- H_WM measures per-turn World Model uncertainty (prediction entropy at env tokens) +- Dynamic mask m_t: when WM is confident but policy is overconfident → block update + when WM is uncertain → allow exploration regardless of confidence + +Reference: World Model-Conditioned Entropy Regularized Co-evolution (WMC-ERC) +""" + +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) + + +def compute_s_star( + old_log_probs: torch.Tensor, + entropys: torch.Tensor, + response_mask: torch.Tensor, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[torch.Tensor]]: + """Compute per-turn policy blind confidence S_*. + + S_*^t = mean over action tokens in turn t of: p_k * (H + log p_k) + + where p_k is the probability of the chosen action token and H is the + full-distribution entropy. This quantity measures how strongly the + policy's token-level probability mass is driving entropy downward: + high S_* means the policy is aggressively collapsing toward a single + action, creating momentum for entropy collapse. + + Based on first-order Taylor expansion of the entropy discriminator + (see WMC-ERC algorithm specification). + + Args: + old_log_probs: (batch_size, response_length) log probs of chosen tokens + entropys: (batch_size, response_length) entropy of policy distribution + response_mask: (batch_size, response_length) 1=action, 0=env/pad + turn_boundaries: per-sample list of (start, end) tuples for action turns + + Returns: + List of lists of scalar tensors, one S_* per turn per sample. + """ + batch_size = old_log_probs.shape[0] + device = old_log_probs.device + s_star_per_sample = [] + + for i in range(batch_size): + s_star_turns = [] + for start, end in turn_boundaries[i]: + log_p = old_log_probs[i, start:end] + H = entropys[i, start:end] + mask = response_mask[i, start:end] + count = mask.sum() + + if count > 0: + p_k = torch.exp(log_p) + s_token = p_k * (H + log_p) + s_t = (s_token * mask).sum() / count + else: + s_t = torch.tensor(0.0, device=device) + + s_star_turns.append(s_t) + s_star_per_sample.append(s_star_turns) + + return s_star_per_sample + + +def compute_h_wm( + entropys: torch.Tensor, + response_mask: torch.Tensor, + attention_mask_response: torch.Tensor, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[torch.Tensor]]: + """Compute per-turn World Model uncertainty H_WM. + + H_WM^t = mean prediction entropy at env token positions following action turn t. + + Env tokens after turn t represent the environment's response to action a_t. + The model's entropy at these positions measures how uncertain it is about + predicting the next state — i.e., the World Model's "cognitive blind spot". + + Higher H_WM → model doesn't understand this environment transition well. + Lower H_WM → model has seen similar transitions and is confident. + + Args: + entropys: (batch_size, response_length) entropy at all positions + response_mask: (batch_size, response_length) 1=action, 0=env/pad + attention_mask_response: (batch_size, response_length) 1=real, 0=padding + turn_boundaries: per-sample list of (start, end) for action turns + + Returns: + List of lists of scalar tensors, one H_WM per turn per sample. + """ + batch_size = entropys.shape[0] + seq_len = entropys.shape[1] + device = entropys.device + env_mask = attention_mask_response * (1.0 - response_mask) + + h_wm_per_sample = [] + + for i in range(batch_size): + boundaries = turn_boundaries[i] + h_wm_turns = [] + + for t, (start, end) in enumerate(boundaries): + # Env tokens after this turn: [end, next_turn_start) or [end, seq_len) + if t + 1 < len(boundaries): + env_end = boundaries[t + 1][0] + else: + env_end = seq_len + + region_mask = env_mask[i, end:env_end] + region_entropy = entropys[i, end:env_end] + count = region_mask.sum() + + if count > 0: + h_wm_t = (region_entropy * region_mask).sum() / count + else: + h_wm_t = torch.tensor(0.0, device=device) + + h_wm_turns.append(h_wm_t) + + h_wm_per_sample.append(h_wm_turns) + + return h_wm_per_sample + + +def compute_dynamic_mask( + s_star_per_sample: List[List[torch.Tensor]], + h_wm_per_sample: List[List[torch.Tensor]], + mu_base: float, + mu_exp: float, + lambda_wm: float, + s_bar: float, + sigma: float, + h_bar: float, +) -> List[List[float]]: + """Compute per-turn dynamic entropy clipping mask. + + Logic: + - Normalization: H_WM is normalized by h_bar (batch or global). + - Threshold: threshold = mu * (1 + lambda * H_WM_norm) * sigma + + Args: + s_star_per_sample: per-sample, per-turn S_* tensors + h_wm_per_sample: per-sample, per-turn H_WM tensors + mu_base: clipping coefficient for collapsing side + mu_exp: clipping coefficient for exploration side + lambda_wm: WM uncertainty weight + s_bar: mean of S_* (batch or global) + sigma: std of S_* (batch or global) + h_bar: mean of H_WM (batch or global) + + Returns: + List of lists of floats (0.0 or 1.0), one mask per turn per sample. + """ + mask_per_sample = [] + for i in range(len(s_star_per_sample)): + masks = [] + for t in range(len(s_star_per_sample[i])): + s_t = s_star_per_sample[i][t].detach().item() + h_t = h_wm_per_sample[i][t].detach().item() + + # Normalize H_WM + h_t_norm = h_t / (h_bar + 1e-8) + + # Asymmetric threshold calculation + if s_t > s_bar: + # Collapsing side + threshold = mu_base * (1.0 + lambda_wm * h_t_norm) * sigma + m_t = 1.0 if (s_t - s_bar) <= threshold else 0.0 + else: + # Exploration side + threshold = mu_exp * (1.0 + lambda_wm * h_t_norm) * sigma + m_t = 1.0 if (s_bar - s_t) <= threshold else 0.0 + + masks.append(m_t) + mask_per_sample.append(masks) + + return mask_per_sample + + +def apply_wmc_erc( + batch, + entropys: torch.Tensor, + wmc_erc_config, + running_stats: Dict[str, float], +) -> Tuple[object, Dict[str, float]]: + """Apply WMC-ERC dynamic entropy clipping to batch advantages. + + Args: + batch: DataProto or compatible object + entropys: (batch_size, response_length) + wmc_erc_config: OmegaConf DictConfig or dict + running_stats: Dictionary for global running statistics + + Returns: + (batch, metrics) + """ + enable = wmc_erc_config.get("enable", True) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'enable', True) + if not enable: + return batch, {} + + clipping_type = wmc_erc_config.get("clipping_type", "batch") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_type', "batch") + + response_mask = batch.batch["response_mask"] + old_log_probs = batch.batch["old_log_probs"] + advantages = batch.batch["advantages"] + response_length = advantages.shape[1] + attention_mask = batch.batch["attention_mask"] + attention_mask_response = attention_mask[:, -response_length:] + + # 1. Turn boundaries + turn_boundaries = compute_turn_boundaries(response_mask) + + # 2. Compute S_* and H_WM per turn + s_star = compute_s_star(old_log_probs, entropys, response_mask, turn_boundaries) + h_wm = compute_h_wm(entropys, response_mask, attention_mask_response, turn_boundaries) + + # Calculate batch statistics + all_s = [s.item() for turns in s_star for s in turns] + all_h = [h.item() for turns in h_wm for h in turns] + + if not all_s: + return batch, {} + + batch_s_bar = np.mean(all_s) + batch_s_std = np.std(all_s) + 1e-8 + batch_h_bar = np.mean(all_h) + 1e-8 + + # Update global statistics + momentum = wmc_erc_config.get("momentum", 0.9) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'momentum', 0.9) + if not running_stats.get("initialized", False): + running_stats["s_bar"] = batch_s_bar + running_stats["s_std"] = batch_s_std + running_stats["h_bar"] = batch_h_bar + running_stats["initialized"] = True + else: + running_stats["s_bar"] = (1 - momentum) * batch_s_bar + momentum * running_stats["s_bar"] + running_stats["s_std"] = (1 - momentum) * batch_s_std + momentum * running_stats["s_std"] + running_stats["h_bar"] = (1 - momentum) * batch_h_bar + momentum * running_stats["h_bar"] + + # Select statistics for masking + if clipping_type == "global": + use_s_bar = running_stats["s_bar"] + use_s_std = running_stats["s_std"] + use_h_bar = running_stats["h_bar"] + else: + use_s_bar = batch_s_bar + use_s_std = batch_s_std + use_h_bar = batch_h_bar + + # 4. Dynamic mask + mu_base = float(wmc_erc_config.get("mu_base", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_base', 1.0)) + mu_exp = float(wmc_erc_config.get("mu_exp", 2.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_exp', 2.0)) + lambda_wm = float(wmc_erc_config.get("lambda_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'lambda_wm', 1.0)) + + mask = compute_dynamic_mask( + s_star, h_wm, mu_base, mu_exp, lambda_wm, + s_bar=use_s_bar, + sigma=use_s_std, + h_bar=use_h_bar + ) + + # 5. Apply mask to advantages + batch_size = advantages.shape[0] + for i in range(batch_size): + for t, (start, end) in enumerate(turn_boundaries[i]): + if t < len(mask[i]): + advantages[i, start:end] *= mask[i][t] + batch.batch["advantages"] = advantages + + # 6. Metrics + all_m = [m for turns in mask for m in turns] + + num_collapsing_masked = 0 + num_exploration_masked = 0 + for i in range(len(s_star)): + for t in range(len(s_star[i])): + if mask[i][t] == 0.0: + if s_star[i][t].item() > use_s_bar: + num_collapsing_masked += 1 + else: + num_exploration_masked += 1 + + env_mask = attention_mask_response * (1.0 - response_mask) + env_count = env_mask.sum() + wm_nll = (-(old_log_probs * env_mask).sum() / (env_count + 1e-8)).item() if env_count > 0 else 0.0 + + metrics = { + "wmc_erc/batch_s_bar": float(batch_s_bar), + "wmc_erc/batch_s_std": float(batch_s_std), + "wmc_erc/batch_h_bar": float(batch_h_bar), + "wmc_erc/running_s_bar": float(running_stats["s_bar"]), + "wmc_erc/running_s_std": float(running_stats["s_std"]), + "wmc_erc/running_h_bar": float(running_stats["h_bar"]), + "wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0, + "wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0), + "wmc_erc/num_collapsing_masked": num_collapsing_masked, + "wmc_erc/num_exploration_masked": num_exploration_masked, + "wmc_erc/total_turns": len(all_m), + "wmc_erc/wm_nll": wm_nll, + } + + return batch, metrics diff --git a/opentinker/client/alfworld_rl.py b/opentinker/client/alfworld_rl.py index 0810cc5d..df02c10a 100644 --- a/opentinker/client/alfworld_rl.py +++ b/opentinker/client/alfworld_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/android_world_rl.py b/opentinker/client/android_world_rl.py index 122b02f1..8af8e265 100644 --- a/opentinker/client/android_world_rl.py +++ b/opentinker/client/android_world_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/client_config/alfworld_param.yaml b/opentinker/client/client_config/alfworld_param.yaml index 0f183186..c3fac911 100644 --- a/opentinker/client/client_config/alfworld_param.yaml +++ b/opentinker/client/client_config/alfworld_param.yaml @@ -60,7 +60,7 @@ interaction: env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} # If you run the ALFWorld env server in sharded mode (--shards N), # set env_shards=N. The client will route each instance_id to a stable shard. - env_shards: 32 + env_shards: 8 max_steps: 20 # ALFWorld episodes max steps max_total_steps: 20 # Max environment step calls (controls rollout turns) observation_template: "{observation}" @@ -80,4 +80,4 @@ scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa # GPU settings -num_gpus: 8 +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml new file mode 100644 index 00000000..a788ce24 --- /dev/null +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -0,0 +1,92 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "grpo" +rollout_n: 8 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + mu_exp: 2.0 + lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 1 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml new file mode 100644 index 00000000..94cf216e --- /dev/null +++ b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml @@ -0,0 +1,93 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc_ppo + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "gae" +rollout_n: 1 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + mu_exp: 2.0 + lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 1 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 +agent_num_workers: 4 diff --git a/opentinker/client/geo3k_rl.py b/opentinker/client/geo3k_rl.py index f45f5f8f..f51b7a0a 100644 --- a/opentinker/client/geo3k_rl.py +++ b/opentinker/client/geo3k_rl.py @@ -70,6 +70,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/geo3k_tool_rl.py b/opentinker/client/geo3k_tool_rl.py index e812132f..9f08d9a7 100644 --- a/opentinker/client/geo3k_tool_rl.py +++ b/opentinker/client/geo3k_tool_rl.py @@ -93,6 +93,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/gomoku_rl.py b/opentinker/client/gomoku_rl.py index bf42ca8e..cb4b2608 100755 --- a/opentinker/client/gomoku_rl.py +++ b/opentinker/client/gomoku_rl.py @@ -114,6 +114,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/math_rl.py b/opentinker/client/math_rl.py index 5bffb2df..804cc545 100755 --- a/opentinker/client/math_rl.py +++ b/opentinker/client/math_rl.py @@ -74,6 +74,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/math_tool_rl.py b/opentinker/client/math_tool_rl.py index 1c68b06d..3ef1d3ce 100755 --- a/opentinker/client/math_tool_rl.py +++ b/opentinker/client/math_tool_rl.py @@ -73,6 +73,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 08f65e9b..958ba2d2 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -579,6 +579,7 @@ def __init__( project_name: Optional[str] = None, experiment_name: Optional[str] = None, logger_backends: Optional[List[str]] = None, + config: Optional[Any] = None, **client_kwargs, ): self.client = HTTPTrainingClient(server_url, **client_kwargs) @@ -588,11 +589,23 @@ def __init__( if logger_backends and project_name and experiment_name: from verl.utils.tracking import Tracking + # Convert DictConfig to dict if necessary for Tracking + tracking_config = config + if config is not None and not isinstance(config, dict): + from omegaconf import OmegaConf + + tracking_config = OmegaConf.to_container(config, resolve=True) + + # Ensure 'trainer' key exists to avoid KeyError in verl.utils.tracking + if tracking_config is not None: + if "trainer" not in tracking_config: + tracking_config["trainer"] = {} + self.tracker = Tracking( project_name=project_name, experiment_name=experiment_name, default_backend=logger_backends, - config=None, # Can pass config if needed + config=tracking_config, ) logger.info(f"Initialized tracking with backends: {logger_backends}") @@ -807,10 +820,11 @@ def fit( # Update progress bar if verbose and progress_bar: # Show key metrics in progress bar (filter game/ metrics except win_rate) + # Added wmc_erc/mask_ratio to monitor dynamic entropy clipping display_metrics = { k: v for k, v in last_metrics.items() - if not k.startswith("game/") or k == "game/win_rate" + if not k.startswith("game/") or k == "game/win_rate" or k == "wmc_erc/mask_ratio" } metrics_str = ", ".join( [ diff --git a/opentinker/environment/__init__.py b/opentinker/environment/__init__.py index 6d096c35..7bc88a03 100755 --- a/opentinker/environment/__init__.py +++ b/opentinker/environment/__init__.py @@ -32,14 +32,33 @@ run_game_server, ) -from opentinker.environment.inference_pipeline import ( - InferencePipeline, - InferenceResult, - RemoteEnvironmentClient, - run_inference, - load_samples, - generate_samples, -) +# Lazy import for InferencePipeline to avoid heavy dependencies (like vllm) +# when only the game server is needed. +def __getattr__(name): + if name in [ + "InferencePipeline", + "InferenceResult", + "RemoteEnvironmentClient", + "run_inference", + "load_samples", + "generate_samples", + ]: + from opentinker.environment.inference_pipeline import ( + InferencePipeline, + InferenceResult, + RemoteEnvironmentClient, + run_inference, + load_samples, + generate_samples, + ) + globals()["InferencePipeline"] = InferencePipeline + globals()["InferenceResult"] = InferenceResult + globals()["RemoteEnvironmentClient"] = RemoteEnvironmentClient + globals()["run_inference"] = run_inference + globals()["load_samples"] = load_samples + globals()["generate_samples"] = generate_samples + return globals()[name] + raise AttributeError(f"module {__name__} has no attribute {name}") __all__ = [ # Base diff --git a/opentinker/environment/base_game_server.py b/opentinker/environment/base_game_server.py index 668add10..77c2a59c 100755 --- a/opentinker/environment/base_game_server.py +++ b/opentinker/environment/base_game_server.py @@ -383,42 +383,48 @@ async def health_check(): @app.post("/reset") async def reset(request: ResetRequest): """Reset/create a game instance.""" - instance_id = request.instance_id - job_id = request.job_id - # Extract extra fields for game reset (exclude instance_id and job_id) - reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) + try: + instance_id = request.instance_id + job_id = request.job_id + # Extract extra fields for game reset (exclude instance_id and job_id) + reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) - with games_lock: - # Reuse existing game instance if available (avoids re-initialization) - if instance_id in games: - game = games[instance_id] - else: - game = game_class(**game_kwargs) - games[instance_id] = game + with games_lock: + # Reuse existing game instance if available (avoids re-initialization) + if instance_id in games: + game = games[instance_id] + else: + game = game_class(**game_kwargs) + games[instance_id] = game - # Reset the game (this is the slow part) - observation = game.reset(**reset_kwargs) + # Reset the game (this is the slow part) + observation = game.reset(**reset_kwargs) - # Track that this game has started (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.register_game_start(instance_id, job_id) + # Track that this game has started (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.register_game_start(instance_id, job_id) - system_prompt = game.get_system_prompt() - initial_message = game.get_initial_user_message() - full_observation = ( - f"{initial_message}\n\n{observation}" if observation else initial_message - ) + system_prompt = game.get_system_prompt() + initial_message = game.get_initial_user_message() + full_observation = ( + f"{initial_message}\n\n{observation}" if observation else initial_message + ) - response = { - "observation": full_observation, - "system_prompt": system_prompt, - } + response = { + "observation": full_observation, + "system_prompt": system_prompt, + } - state = game.get_state() - if state: - response["game_state"] = state + state = game.get_state() + if state: + response["game_state"] = state - return response + return response + except Exception as e: + import traceback + error_msg = f"Reset failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + raise HTTPException(status_code=500, detail=error_msg) @app.post("/finalize") async def finalize(request: ResetRequest): @@ -433,37 +439,45 @@ async def finalize(request: ResetRequest): @app.post("/step") async def step(request: StepRequest): """Execute a step in the game.""" - instance_id = request.instance_id - job_id = request.job_id - action = request.action + try: + instance_id = request.instance_id + job_id = request.job_id + action = request.action + + if instance_id not in games: + raise HTTPException( + status_code=404, + detail=f"Instance {instance_id} not found. Call /reset first.", + ) - if instance_id not in games: - raise HTTPException( - status_code=404, - detail=f"Instance {instance_id} not found. Call /reset first.", - ) + game = games[instance_id] + result = game.step(action) - game = games[instance_id] - result = game.step(action) + # Record statistics with instance_id for per-game tracking (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.record_game_result( + result.info, result.reward, result.done, instance_id, job_id + ) - # Record statistics with instance_id for per-game tracking (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.record_game_result( - result.info, result.reward, result.done, instance_id, job_id - ) + # Clean up finished games + if result.done: + with games_lock: + if instance_id in games: + del games[instance_id] - # Clean up finished games - if result.done: - with games_lock: - if instance_id in games: - del games[instance_id] - - return { - "observation": result.observation, - "reward": result.reward, - "done": result.done, - "info": result.info, - } + return { + "observation": result.observation, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + except Exception as e: + import traceback + error_msg = f"Step failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=error_msg) @app.get("/stats") async def get_stats(job_id: str = "default"): diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 2a1824d4..abc5b42b 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -1035,6 +1035,22 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: Popen process object """ env = os.environ.copy() + + # [FIX] Check for broken libcuda.so.1 (size 0) - common issue in some environments + libcuda_path = "/usr/lib/x86_64-linux-gnu/libcuda.so.1" + if os.path.exists(libcuda_path) and os.path.getsize(libcuda_path) == 0: + compat_path = "/usr/local/cuda-12.4/compat" + if os.path.isdir(compat_path): + current_ld_path = env.get("LD_LIBRARY_PATH", "") + env["LD_LIBRARY_PATH"] = ( + f"{compat_path}:{current_ld_path}".strip(":") + if current_ld_path + else compat_path + ) + logger.info( + f"Job {job.job_id}: 🛠 Fixed broken libcuda.so.1 by adding {compat_path} to LD_LIBRARY_PATH" + ) + # Set CUDA_VISIBLE_DEVICES to comma-separated list of GPU IDs env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, job.gpu_ids)) # Pass job_id to agent loop for per-client trace subdirectory isolation diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh new file mode 100755 index 00000000..97ab8252 --- /dev/null +++ b/opentinker/scripts/run_alfworld.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# ALFWorld Training Script (Multi-Turn) +# +# This script runs ALFWorld RL training with OpenTinker. +# You need to run these steps in SEPARATE terminals. +# +# For Training (3 terminals): +# Terminal 1: bash run_alfworld.sh scheduler +# Terminal 2: bash run_alfworld.sh env +# Terminal 3: bash run_alfworld.sh client +# +# Prerequisites: +# - pip install alfworld +# - alfworld-download +# - See docs/alfworld_multiturn.md for environment setup + +# ============================================================================= +# Configuration +# ============================================================================= +SCHEDULER_PORT="${SCHEDULER_PORT:-9780}" +ENV_PORT="${ENV_PORT:-1234}" +GPUS="${GPUS:-[0,1,2,3]}" +MODEL_PATH="/inspire/hdd/project/robot-reasoning/xuyue-p-xuyue/ziyu/.cache/huggingface/hub/models--Qwen--Qwen2.5-3B-Instruct" + +# OpenTinker root (relative to this script: opentinker/scripts/run_alfworld.sh) +OPENTINKER_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" + +export VLLM_DISABLE_SLEEP_MODE=1 +export HF_HUB_OFFLINE=1 +export WANDB_MODE=offline + +# Activate conda environment (adjust to your setup if needed) +if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/anaconda3/etc/profile.d/conda.sh" +elif [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +fi +# conda activate opentinker + +# Change to OpenTinker directory +cd "$OPENTINKER_ROOT" + +# Get current host IP for communication between components +# Use 127.0.0.1 if running everything on the same machine +HOST_IP="${HOST_IP:-127.0.0.1}" + +# ============================================================================= +# Step Selection +# ============================================================================= +case "$1" in + scheduler|1) + echo "========================================" + echo "Step 1: Starting Scheduler on port $SCHEDULER_PORT" + echo "========================================" + bash opentinker/scripts/launch_scheduler.sh \ + --scheduler-port $SCHEDULER_PORT \ + --gpus "$GPUS" + ;; + + env|2) + echo "========================================" + echo "Step 2: Starting ALFWorld Environment Server on port $ENV_PORT" + echo "========================================" + python -m opentinker.environment.alfworld.alfworld_server \ + --port "$ENV_PORT" \ + --max_steps 50 \ + --split train \ + --num_games -1 \ + ;; + + client|3) + echo "========================================" + echo "Step 3: Starting ALFWorld RL Client" + echo "========================================" + python opentinker/client/alfworld_rl.py \ + --config-name alfworld_wmc_erc_param \ + tokenizer_path="$MODEL_PATH" \ + batch_size=4 \ + val_batch_size=50 \ + num_steps=1000 \ + save_freq=2000 \ + test_freq=10 \ + scheduler_url="http://$HOST_IP:$SCHEDULER_PORT" \ + interaction.config.env_port="$ENV_PORT" \ + interaction.config.env_host="$HOST_IP" + ;; + + *) + echo "Usage: $0 {scheduler|env|client}" + echo "" + echo "Example (separate terminals):" + echo " Terminal 1: bash $0 scheduler" + echo " Terminal 2: bash $0 env" + echo " Terminal 3: bash $0 client" + exit 1 + ;; +esac diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 8d67b106..9ef27d05 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -27,6 +27,7 @@ import asyncio import base64 +import gc import logging import signal import sys @@ -640,6 +641,15 @@ def __init__( self.is_initialized = False self.global_steps = 0 + # WMC-ERC running statistics (S_bar, Sigma, H_bar) + self.wmc_erc_stats = { + "s_bar": 0.0, + "s_std": 1.0, + "h_bar": 1.0, + "momentum": 0.9, # EMA momentum + "initialized": False, + } + # Generation config (can be overridden by client) self.generation_config = { "do_sample": True, # CRITICAL: Enable sampling by default for PPO training @@ -943,7 +953,9 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # episodes into individual per-turn training samples, the gen_batch_output # batch size is larger than the original batch. We need to expand the # original batch to match using the expansion index. - expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + expansion_index = gen_batch_output.meta_info.pop( + "per_turn_expansion_index", None + ) if expansion_index is not None: logger.info( f"[Per-turn training] Expanding original batch from {len(batch)} to " @@ -956,6 +968,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) # Expand non-tensor batch expanded_non_tensor = {} @@ -1080,6 +1093,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # ===== DEBUG LOGGING END ===== metrics.update(old_log_prob_metrics) + _wmc_erc_entropys = entropys # Preserve for WMC-ERC before pop old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) logger.info( @@ -1161,6 +1175,24 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: config=self.config.algorithm, ) + # 9.5 WMC-ERC: Dynamic entropy clipping + wmc_erc_cfg = OmegaConf.select(self.config, "wmc_erc", default=None) + if wmc_erc_cfg and wmc_erc_cfg.get("enable", False): + from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( + apply_wmc_erc, + ) + + batch, wmc_metrics = apply_wmc_erc( + batch, _wmc_erc_entropys, wmc_erc_cfg, self.wmc_erc_stats + ) + metrics.update(wmc_metrics) + clipping_mode = wmc_erc_cfg.get("clipping_type", "batch") + logger.info( + f"[WMC-ERC] mode={clipping_mode}, mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " + f"s_star={wmc_metrics.get('wmc_erc/batch_s_bar', 'N/A'):.4f}, " + f"h_wm={wmc_metrics.get('wmc_erc/batch_h_bar', 'N/A'):.4f}" + ) + # 10. Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): @@ -1258,6 +1290,13 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info(f"Training step {self.global_steps} completed successfully") + # Free large intermediates and force garbage collection to prevent OOM + del batch, gen_batch, gen_batch_output, reward_tensor + if "_wmc_erc_entropys" in dir(): + del _wmc_erc_entropys + gc.collect() + torch.cuda.empty_cache() + return { "status": "success", "metrics": metrics, @@ -1535,7 +1574,9 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: # 6. Merge original batch and generated output # Per-turn training expansion: expand batch if gen output is larger - expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) + expansion_index = gen_batch_output.meta_info.pop( + "per_turn_expansion_index", None + ) if expansion_index is not None: logger.info( f"[Per-turn training] Validation: Expanding original batch from {len(batch)} to " @@ -1547,6 +1588,7 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict + batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) expanded_non_tensor = {} for k, v in batch.non_tensor_batch.items(): @@ -2106,6 +2148,13 @@ def run_fastapi_server(): namespace=_server_cfg.ray.namespace, num_gpus=_server_cfg.trainer.n_gpus_per_node, # Explicitly specify number of GPUs ignore_reinit_error=True, + runtime_env={ + "env_vars": { + "NCCL_CUMEM_ENABLE": "0", + "VLLM_DISABLE_SLEEP_MODE": "1", + "RAY_memory_usage_threshold": "0.99", + }, + }, ) else: # Connect to existing Ray cluster at specific address @@ -2114,6 +2163,13 @@ def run_fastapi_server(): address=_server_cfg.ray.address, namespace=_server_cfg.ray.namespace, ignore_reinit_error=True, + runtime_env={ + "env_vars": { + "NCCL_CUMEM_ENABLE": "0", + "VLLM_DISABLE_SLEEP_MODE": "1", + "RAY_memory_usage_threshold": "0.99", + }, + }, ) # Verify GPU availability diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py new file mode 100644 index 00000000..f083277a --- /dev/null +++ b/opentinker/tests/test_wmc_erc.py @@ -0,0 +1,128 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for WMC-ERC dynamic entropy clipping module using unittest. + +Run with: python opentinker/tests/test_wmc_erc.py +""" + +import unittest +import numpy as np +import torch +from unittest.mock import MagicMock + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) +from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( + apply_wmc_erc, + compute_dynamic_mask, + compute_h_wm, + compute_s_star, +) + + +class TestWmcErc(unittest.TestCase): + """Test WMC-ERC components and orchestration.""" + + def test_single_turn_s_star(self): + """S_* for a single turn = mean of p_k * (H + log p_k) over action tokens.""" + p = torch.tensor([0.8, 0.6, 0.9, 0.7]) + old_log_probs = torch.zeros(1, 6) + old_log_probs[0, :4] = torch.log(p) + entropys = torch.tensor([[1.0, 1.5, 0.5, 1.2, 0.0, 0.0]]) + response_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.float32) + boundaries = compute_turn_boundaries(response_mask) + + result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 1) + + H = torch.tensor([1.0, 1.5, 0.5, 1.2]) + expected = (p * (H + torch.log(p))).mean().item() + self.assertAlmostEqual(result[0][0].item(), expected, places=5) + + def test_asymmetric_behavior(self): + """Test that mu_base and mu_exp act differently using compute_dynamic_mask.""" + s_star = [[torch.tensor(15.0)], [torch.tensor(5.0)]] # Mean=10 + h_wm = [[torch.tensor(1.0)]] * 2 + s_bar = 10.0 + sigma = 5.0 + h_bar = 1.0 + + # 1. mu_base=0.1 (block high), mu_exp=10.0 (allow low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=0.1, mu_exp=10.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, h_bar=h_bar) + self.assertEqual(mask[0], [0.0]) # High blocked + self.assertEqual(mask[1], [1.0]) # Low allowed + + # 2. mu_base=10.0 (allow high), mu_exp=0.1 (block low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=10.0, mu_exp=0.1, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, h_bar=h_bar) + self.assertEqual(mask[0], [1.0]) # High allowed + self.assertEqual(mask[1], [0.0]) # Low blocked + + def _make_batch(self, advantages, response_mask, old_log_probs, attention_mask): + batch = MagicMock() + batch.batch = { + "advantages": advantages.clone(), + "response_mask": response_mask, + "old_log_probs": old_log_probs, + "attention_mask": attention_mask, + } + return batch + + def test_clipping_type_batch(self): + """Verify that 'batch' mode uses current batch statistics.""" + response_mask = torch.ones(4, 4, dtype=torch.float32) + advantages = torch.ones(4, 4) * 2.0 + # Sample 3 is an outlier in this batch + old_log_probs = torch.tensor([[np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.9)]*4]) + entropys = torch.tensor([[2.0]*4, [2.0]*4, [2.0]*4, [3.0]*4]) + attention_mask = torch.ones(4, 4, dtype=torch.float32) + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + config = {"mu_base": 0.1, "mu_exp": 10.0, "lambda_wm": 1.0, "enable": True, "clipping_type": "batch"} + running_stats = {"s_bar": 100.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} # Very different global stats + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # In 'batch' mode, Sample 3 should be blocked based on BATCH mean, not GLOBAL mean + # If it used global mean (100), sample 3 (S*~2.6) would be an exploration outlier and allowed by mu_exp=10 + # But in batch mode (S_bar~0.8), sample 3 is a collapsing outlier and blocked by mu_base=0.1 + self.assertTrue((batch.batch["advantages"][3] == 0).all()) + + def test_clipping_type_global(self): + """Verify that 'global' mode uses running statistics.""" + response_mask = torch.ones(4, 4, dtype=torch.float32) + advantages = torch.ones(4, 4) * 2.0 + old_log_probs = torch.tensor([[np.log(0.3)]*4]*4) + entropys = torch.tensor([[2.0]*4]*4) + # S* for all samples will be ~0.24 + + attention_mask = torch.ones(4, 4, dtype=torch.float32) + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Global s_bar is far away (10.0), so batch S* (0.24) looks like a huge exploration outlier + config = {"mu_base": 10.0, "mu_exp": 0.1, "lambda_wm": 0.0, "enable": True, "clipping_type": "global"} + running_stats = {"s_bar": 10.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # Should be blocked because mu_exp=0.1 is very tight relative to global s_bar + self.assertTrue((batch.batch["advantages"] == 0).all()) + + +if __name__ == "__main__": + unittest.main()