From 586dd318959f1489048caa7634668d1df595075e Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Tue, 17 Mar 2026 15:36:28 +0000 Subject: [PATCH 1/8] Added entropy normalization and different threshold for clipping. --- .../backend_patch/verl/trainer/ppo/wmc_erc.py | 163 ++++--- opentinker/client/alfworld_rl.py | 1 + opentinker/client/android_world_rl.py | 1 + .../client/client_config/alfworld_param.yaml | 4 +- .../client_config/alfworld_wmc_erc_param.yaml | 14 +- .../alfworld_wmc_erc_param_ppo.yaml | 93 ++++ opentinker/client/geo3k_rl.py | 1 + opentinker/client/geo3k_tool_rl.py | 1 + opentinker/client/gomoku_rl.py | 1 + opentinker/client/math_rl.py | 1 + opentinker/client/math_tool_rl.py | 1 + .../client/utils/http_training_client.py | 18 +- opentinker/environment/__init__.py | 35 +- opentinker/environment/base_game_server.py | 126 +++--- opentinker/scheduler/job_scheduler.py | 16 + opentinker/scripts/run_alfworld.sh | 97 ++++ opentinker/server/http_training_server.py | 18 +- opentinker/tests/test_wmc_erc.py | 413 +++--------------- 18 files changed, 520 insertions(+), 484 deletions(-) create mode 100644 opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml create mode 100755 opentinker/scripts/run_alfworld.sh diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py index 2ea6182a..c06bcf02 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -155,60 +155,52 @@ def compute_h_wm( def compute_dynamic_mask( s_star_per_sample: List[List[torch.Tensor]], h_wm_per_sample: List[List[torch.Tensor]], - mu_base: float = 1.0, - lambda_wm: float = 1.0, + mu_base: float, + mu_exp: float, + lambda_wm: float, + s_bar: float, + sigma: float, + h_bar: float, ) -> List[List[float]]: """Compute per-turn dynamic entropy clipping mask. - m_t = 1 if |S_*^t - S_bar| <= mu_base * (1 + lambda * H_WM^t) * sigma - 0 otherwise - Logic: - - WM confident (H_WM→0): threshold tightens → overconfident policy blocked - (prevents overfitting in well-understood regions) - - WM uncertain (H_WM large): threshold widens → exploration encouraged - (allows agent to gather data in unknown regions to train the WM) - - H_WM is detached — it only acts as a scalar gate, never participates in - policy gradient backpropagation. + - Normalization: H_WM is normalized by h_bar (batch or global). + - Threshold: threshold = mu * (1 + lambda * H_WM_norm) * sigma Args: s_star_per_sample: per-sample, per-turn S_* tensors h_wm_per_sample: per-sample, per-turn H_WM tensors - mu_base: base clipping coefficient + mu_base: clipping coefficient for collapsing side + mu_exp: clipping coefficient for exploration side lambda_wm: WM uncertainty weight + s_bar: mean of S_* (batch or global) + sigma: std of S_* (batch or global) + h_bar: mean of H_WM (batch or global) Returns: List of lists of floats (0.0 or 1.0), one mask per turn per sample. """ - # Flatten all S_* for batch statistics - all_s = [] - for turns in s_star_per_sample: - for s in turns: - all_s.append(s.detach()) - - if len(all_s) == 0: - return [[] for _ in s_star_per_sample] - - all_s_tensor = torch.stack(all_s) - s_bar = all_s_tensor.mean() - - # Guard for single-element: std is 0, threshold = mu_base * (1 + lambda * h_wm) * 0 - # → everything would be masked. Use 1.0 as default sigma for single element. - if len(all_s) <= 1: - sigma = torch.tensor(1.0, device=all_s_tensor.device) - else: - sigma = all_s_tensor.std(unbiased=False) + 1e-8 - mask_per_sample = [] for i in range(len(s_star_per_sample)): masks = [] for t in range(len(s_star_per_sample[i])): - s_t = s_star_per_sample[i][t].detach() - h_t = h_wm_per_sample[i][t].detach() - - threshold = mu_base * (1.0 + lambda_wm * h_t) * sigma - m_t = 1.0 if torch.abs(s_t - s_bar) <= threshold else 0.0 + s_t = s_star_per_sample[i][t].detach().item() + h_t = h_wm_per_sample[i][t].detach().item() + + # Normalize H_WM + h_t_norm = h_t / (h_bar + 1e-8) + + # Asymmetric threshold calculation + if s_t > s_bar: + # Collapsing side + threshold = mu_base * (1.0 + lambda_wm * h_t_norm) * sigma + m_t = 1.0 if (s_t - s_bar) <= threshold else 0.0 + else: + # Exploration side + threshold = mu_exp * (1.0 + lambda_wm * h_t_norm) * sigma + m_t = 1.0 if (s_bar - s_t) <= threshold else 0.0 + masks.append(m_t) mask_per_sample.append(masks) @@ -219,35 +211,28 @@ def apply_wmc_erc( batch, entropys: torch.Tensor, wmc_erc_config, + running_stats: Dict[str, float], ) -> Tuple[object, Dict[str, float]]: """Apply WMC-ERC dynamic entropy clipping to batch advantages. - Pipeline: - 1. Compute turn boundaries from response_mask - 2. Compute S_* (policy blind confidence) per turn - 3. Compute H_WM (world model uncertainty) per turn from env token entropys - 4. Compute dynamic mask m_t per turn - 5. Apply mask to advantages: A_masked = A * m_t (broadcast to tokens) - 6. Return metrics for logging - Args: - batch: DataProto or compatible object with batch dict containing - advantages, response_mask, old_log_probs, attention_mask - entropys: (batch_size, response_length) stored before pop in train_step - wmc_erc_config: OmegaConf DictConfig or dict with mu_base, lambda_wm, enable + batch: DataProto or compatible object + entropys: (batch_size, response_length) + wmc_erc_config: OmegaConf DictConfig or dict + running_stats: Dictionary for global running statistics Returns: - (batch, metrics) where batch has masked advantages and metrics dict + (batch, metrics) """ enable = wmc_erc_config.get("enable", True) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'enable', True) if not enable: return batch, {} + clipping_type = wmc_erc_config.get("clipping_type", "batch") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_type', "batch") + response_mask = batch.batch["response_mask"] old_log_probs = batch.batch["old_log_probs"] advantages = batch.batch["advantages"] - - # Compute attention mask for response region response_length = advantages.shape[1] attention_mask = batch.batch["attention_mask"] attention_mask_response = attention_mask[:, -response_length:] @@ -255,18 +240,56 @@ def apply_wmc_erc( # 1. Turn boundaries turn_boundaries = compute_turn_boundaries(response_mask) - # 2. S_* per turn + # 2. Compute S_* and H_WM per turn s_star = compute_s_star(old_log_probs, entropys, response_mask, turn_boundaries) - - # 3. H_WM per turn h_wm = compute_h_wm(entropys, response_mask, attention_mask_response, turn_boundaries) + # Calculate batch statistics + all_s = [s.item() for turns in s_star for s in turns] + all_h = [h.item() for turns in h_wm for h in turns] + + if not all_s: + return batch, {} + + batch_s_bar = np.mean(all_s) + batch_s_std = np.std(all_s) + 1e-8 + batch_h_bar = np.mean(all_h) + 1e-8 + + # Update global statistics + momentum = wmc_erc_config.get("momentum", 0.9) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'momentum', 0.9) + if not running_stats.get("initialized", False): + running_stats["s_bar"] = batch_s_bar + running_stats["s_std"] = batch_s_std + running_stats["h_bar"] = batch_h_bar + running_stats["initialized"] = True + else: + running_stats["s_bar"] = (1 - momentum) * batch_s_bar + momentum * running_stats["s_bar"] + running_stats["s_std"] = (1 - momentum) * batch_s_std + momentum * running_stats["s_std"] + running_stats["h_bar"] = (1 - momentum) * batch_h_bar + momentum * running_stats["h_bar"] + + # Select statistics for masking + if clipping_type == "global": + use_s_bar = running_stats["s_bar"] + use_s_std = running_stats["s_std"] + use_h_bar = running_stats["h_bar"] + else: + use_s_bar = batch_s_bar + use_s_std = batch_s_std + use_h_bar = batch_h_bar + # 4. Dynamic mask mu_base = float(wmc_erc_config.get("mu_base", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_base', 1.0)) + mu_exp = float(wmc_erc_config.get("mu_exp", 2.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_exp', 2.0)) lambda_wm = float(wmc_erc_config.get("lambda_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'lambda_wm', 1.0)) - mask = compute_dynamic_mask(s_star, h_wm, mu_base, lambda_wm) - - # 5. Apply mask to advantages (in-place) + + mask = compute_dynamic_mask( + s_star, h_wm, mu_base, mu_exp, lambda_wm, + s_bar=use_s_bar, + sigma=use_s_std, + h_bar=use_h_bar + ) + + # 5. Apply mask to advantages batch_size = advantages.shape[0] for i in range(batch_size): for t, (start, end) in enumerate(turn_boundaries[i]): @@ -275,21 +298,33 @@ def apply_wmc_erc( batch.batch["advantages"] = advantages # 6. Metrics - all_s = [s.item() for turns in s_star for s in turns] - all_h = [h.item() for turns in h_wm for h in turns] all_m = [m for turns in mask for m in turns] + + num_collapsing_masked = 0 + num_exploration_masked = 0 + for i in range(len(s_star)): + for t in range(len(s_star[i])): + if mask[i][t] == 0.0: + if s_star[i][t].item() > use_s_bar: + num_collapsing_masked += 1 + else: + num_exploration_masked += 1 - # WM NLL (monitoring only — not in backward pass for this prototype) env_mask = attention_mask_response * (1.0 - response_mask) env_count = env_mask.sum() wm_nll = (-(old_log_probs * env_mask).sum() / (env_count + 1e-8)).item() if env_count > 0 else 0.0 metrics = { - "wmc_erc/s_star_mean": float(np.mean(all_s)) if all_s else 0.0, - "wmc_erc/s_star_std": float(np.std(all_s)) if all_s else 0.0, - "wmc_erc/h_wm_mean": float(np.mean(all_h)) if all_h else 0.0, + "wmc_erc/batch_s_bar": float(batch_s_bar), + "wmc_erc/batch_s_std": float(batch_s_std), + "wmc_erc/batch_h_bar": float(batch_h_bar), + "wmc_erc/running_s_bar": float(running_stats["s_bar"]), + "wmc_erc/running_s_std": float(running_stats["s_std"]), + "wmc_erc/running_h_bar": float(running_stats["h_bar"]), "wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0, "wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0), + "wmc_erc/num_collapsing_masked": num_collapsing_masked, + "wmc_erc/num_exploration_masked": num_exploration_masked, "wmc_erc/total_turns": len(all_m), "wmc_erc/wm_nll": wm_nll, } diff --git a/opentinker/client/alfworld_rl.py b/opentinker/client/alfworld_rl.py index 0810cc5d..df02c10a 100644 --- a/opentinker/client/alfworld_rl.py +++ b/opentinker/client/alfworld_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/android_world_rl.py b/opentinker/client/android_world_rl.py index 122b02f1..8af8e265 100644 --- a/opentinker/client/android_world_rl.py +++ b/opentinker/client/android_world_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/client_config/alfworld_param.yaml b/opentinker/client/client_config/alfworld_param.yaml index 0f183186..c3fac911 100644 --- a/opentinker/client/client_config/alfworld_param.yaml +++ b/opentinker/client/client_config/alfworld_param.yaml @@ -60,7 +60,7 @@ interaction: env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} # If you run the ALFWorld env server in sharded mode (--shards N), # set env_shards=N. The client will route each instance_id to a stable shard. - env_shards: 32 + env_shards: 8 max_steps: 20 # ALFWorld episodes max steps max_total_steps: 20 # Max environment step calls (controls rollout turns) observation_template: "{observation}" @@ -80,4 +80,4 @@ scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa # GPU settings -num_gpus: 8 +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml index c6cf5c19..a788ce24 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -45,17 +45,23 @@ max_prompt_tokens: 2048 algorithm: "agent_loop" # Use per-step advantage for multi-turn credit assignment -adv_estimator: "grpo_per_step" +adv_estimator: "grpo" rollout_n: 8 # WMC-ERC: Dynamic Entropy Clipping -# - mu_base: base clipping coefficient (controls tightness of the gate) -# - lambda_wm: how much WM uncertainty widens the gate (higher = more tolerant in unknown regions) +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") # - enable: master switch wmc_erc: enable: true mu_base: 1.0 + mu_exp: 2.0 lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 # Interaction configuration interaction: @@ -83,4 +89,4 @@ scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: null # GPU settings -num_gpus: 8 +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml new file mode 100644 index 00000000..94cf216e --- /dev/null +++ b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml @@ -0,0 +1,93 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc_ppo + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "gae" +rollout_n: 1 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + mu_exp: 2.0 + lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 1 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 +agent_num_workers: 4 diff --git a/opentinker/client/geo3k_rl.py b/opentinker/client/geo3k_rl.py index f45f5f8f..f51b7a0a 100644 --- a/opentinker/client/geo3k_rl.py +++ b/opentinker/client/geo3k_rl.py @@ -70,6 +70,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/geo3k_tool_rl.py b/opentinker/client/geo3k_tool_rl.py index e812132f..9f08d9a7 100644 --- a/opentinker/client/geo3k_tool_rl.py +++ b/opentinker/client/geo3k_tool_rl.py @@ -93,6 +93,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/gomoku_rl.py b/opentinker/client/gomoku_rl.py index bf42ca8e..cb4b2608 100755 --- a/opentinker/client/gomoku_rl.py +++ b/opentinker/client/gomoku_rl.py @@ -114,6 +114,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/math_rl.py b/opentinker/client/math_rl.py index 5bffb2df..804cc545 100755 --- a/opentinker/client/math_rl.py +++ b/opentinker/client/math_rl.py @@ -74,6 +74,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/math_tool_rl.py b/opentinker/client/math_tool_rl.py index 1c68b06d..3ef1d3ce 100755 --- a/opentinker/client/math_tool_rl.py +++ b/opentinker/client/math_tool_rl.py @@ -73,6 +73,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 08f65e9b..958ba2d2 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -579,6 +579,7 @@ def __init__( project_name: Optional[str] = None, experiment_name: Optional[str] = None, logger_backends: Optional[List[str]] = None, + config: Optional[Any] = None, **client_kwargs, ): self.client = HTTPTrainingClient(server_url, **client_kwargs) @@ -588,11 +589,23 @@ def __init__( if logger_backends and project_name and experiment_name: from verl.utils.tracking import Tracking + # Convert DictConfig to dict if necessary for Tracking + tracking_config = config + if config is not None and not isinstance(config, dict): + from omegaconf import OmegaConf + + tracking_config = OmegaConf.to_container(config, resolve=True) + + # Ensure 'trainer' key exists to avoid KeyError in verl.utils.tracking + if tracking_config is not None: + if "trainer" not in tracking_config: + tracking_config["trainer"] = {} + self.tracker = Tracking( project_name=project_name, experiment_name=experiment_name, default_backend=logger_backends, - config=None, # Can pass config if needed + config=tracking_config, ) logger.info(f"Initialized tracking with backends: {logger_backends}") @@ -807,10 +820,11 @@ def fit( # Update progress bar if verbose and progress_bar: # Show key metrics in progress bar (filter game/ metrics except win_rate) + # Added wmc_erc/mask_ratio to monitor dynamic entropy clipping display_metrics = { k: v for k, v in last_metrics.items() - if not k.startswith("game/") or k == "game/win_rate" + if not k.startswith("game/") or k == "game/win_rate" or k == "wmc_erc/mask_ratio" } metrics_str = ", ".join( [ diff --git a/opentinker/environment/__init__.py b/opentinker/environment/__init__.py index 6d096c35..7bc88a03 100755 --- a/opentinker/environment/__init__.py +++ b/opentinker/environment/__init__.py @@ -32,14 +32,33 @@ run_game_server, ) -from opentinker.environment.inference_pipeline import ( - InferencePipeline, - InferenceResult, - RemoteEnvironmentClient, - run_inference, - load_samples, - generate_samples, -) +# Lazy import for InferencePipeline to avoid heavy dependencies (like vllm) +# when only the game server is needed. +def __getattr__(name): + if name in [ + "InferencePipeline", + "InferenceResult", + "RemoteEnvironmentClient", + "run_inference", + "load_samples", + "generate_samples", + ]: + from opentinker.environment.inference_pipeline import ( + InferencePipeline, + InferenceResult, + RemoteEnvironmentClient, + run_inference, + load_samples, + generate_samples, + ) + globals()["InferencePipeline"] = InferencePipeline + globals()["InferenceResult"] = InferenceResult + globals()["RemoteEnvironmentClient"] = RemoteEnvironmentClient + globals()["run_inference"] = run_inference + globals()["load_samples"] = load_samples + globals()["generate_samples"] = generate_samples + return globals()[name] + raise AttributeError(f"module {__name__} has no attribute {name}") __all__ = [ # Base diff --git a/opentinker/environment/base_game_server.py b/opentinker/environment/base_game_server.py index 668add10..77c2a59c 100755 --- a/opentinker/environment/base_game_server.py +++ b/opentinker/environment/base_game_server.py @@ -383,42 +383,48 @@ async def health_check(): @app.post("/reset") async def reset(request: ResetRequest): """Reset/create a game instance.""" - instance_id = request.instance_id - job_id = request.job_id - # Extract extra fields for game reset (exclude instance_id and job_id) - reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) + try: + instance_id = request.instance_id + job_id = request.job_id + # Extract extra fields for game reset (exclude instance_id and job_id) + reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) - with games_lock: - # Reuse existing game instance if available (avoids re-initialization) - if instance_id in games: - game = games[instance_id] - else: - game = game_class(**game_kwargs) - games[instance_id] = game + with games_lock: + # Reuse existing game instance if available (avoids re-initialization) + if instance_id in games: + game = games[instance_id] + else: + game = game_class(**game_kwargs) + games[instance_id] = game - # Reset the game (this is the slow part) - observation = game.reset(**reset_kwargs) + # Reset the game (this is the slow part) + observation = game.reset(**reset_kwargs) - # Track that this game has started (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.register_game_start(instance_id, job_id) + # Track that this game has started (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.register_game_start(instance_id, job_id) - system_prompt = game.get_system_prompt() - initial_message = game.get_initial_user_message() - full_observation = ( - f"{initial_message}\n\n{observation}" if observation else initial_message - ) + system_prompt = game.get_system_prompt() + initial_message = game.get_initial_user_message() + full_observation = ( + f"{initial_message}\n\n{observation}" if observation else initial_message + ) - response = { - "observation": full_observation, - "system_prompt": system_prompt, - } + response = { + "observation": full_observation, + "system_prompt": system_prompt, + } - state = game.get_state() - if state: - response["game_state"] = state + state = game.get_state() + if state: + response["game_state"] = state - return response + return response + except Exception as e: + import traceback + error_msg = f"Reset failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + raise HTTPException(status_code=500, detail=error_msg) @app.post("/finalize") async def finalize(request: ResetRequest): @@ -433,37 +439,45 @@ async def finalize(request: ResetRequest): @app.post("/step") async def step(request: StepRequest): """Execute a step in the game.""" - instance_id = request.instance_id - job_id = request.job_id - action = request.action + try: + instance_id = request.instance_id + job_id = request.job_id + action = request.action + + if instance_id not in games: + raise HTTPException( + status_code=404, + detail=f"Instance {instance_id} not found. Call /reset first.", + ) - if instance_id not in games: - raise HTTPException( - status_code=404, - detail=f"Instance {instance_id} not found. Call /reset first.", - ) + game = games[instance_id] + result = game.step(action) - game = games[instance_id] - result = game.step(action) + # Record statistics with instance_id for per-game tracking (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.record_game_result( + result.info, result.reward, result.done, instance_id, job_id + ) - # Record statistics with instance_id for per-game tracking (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.record_game_result( - result.info, result.reward, result.done, instance_id, job_id - ) + # Clean up finished games + if result.done: + with games_lock: + if instance_id in games: + del games[instance_id] - # Clean up finished games - if result.done: - with games_lock: - if instance_id in games: - del games[instance_id] - - return { - "observation": result.observation, - "reward": result.reward, - "done": result.done, - "info": result.info, - } + return { + "observation": result.observation, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + except Exception as e: + import traceback + error_msg = f"Step failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=error_msg) @app.get("/stats") async def get_stats(job_id: str = "default"): diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 2a1824d4..abc5b42b 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -1035,6 +1035,22 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: Popen process object """ env = os.environ.copy() + + # [FIX] Check for broken libcuda.so.1 (size 0) - common issue in some environments + libcuda_path = "/usr/lib/x86_64-linux-gnu/libcuda.so.1" + if os.path.exists(libcuda_path) and os.path.getsize(libcuda_path) == 0: + compat_path = "/usr/local/cuda-12.4/compat" + if os.path.isdir(compat_path): + current_ld_path = env.get("LD_LIBRARY_PATH", "") + env["LD_LIBRARY_PATH"] = ( + f"{compat_path}:{current_ld_path}".strip(":") + if current_ld_path + else compat_path + ) + logger.info( + f"Job {job.job_id}: 🛠 Fixed broken libcuda.so.1 by adding {compat_path} to LD_LIBRARY_PATH" + ) + # Set CUDA_VISIBLE_DEVICES to comma-separated list of GPU IDs env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, job.gpu_ids)) # Pass job_id to agent loop for per-client trace subdirectory isolation diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh new file mode 100755 index 00000000..97ab8252 --- /dev/null +++ b/opentinker/scripts/run_alfworld.sh @@ -0,0 +1,97 @@ +#!/bin/bash +# ALFWorld Training Script (Multi-Turn) +# +# This script runs ALFWorld RL training with OpenTinker. +# You need to run these steps in SEPARATE terminals. +# +# For Training (3 terminals): +# Terminal 1: bash run_alfworld.sh scheduler +# Terminal 2: bash run_alfworld.sh env +# Terminal 3: bash run_alfworld.sh client +# +# Prerequisites: +# - pip install alfworld +# - alfworld-download +# - See docs/alfworld_multiturn.md for environment setup + +# ============================================================================= +# Configuration +# ============================================================================= +SCHEDULER_PORT="${SCHEDULER_PORT:-9780}" +ENV_PORT="${ENV_PORT:-1234}" +GPUS="${GPUS:-[0,1,2,3]}" +MODEL_PATH="/inspire/hdd/project/robot-reasoning/xuyue-p-xuyue/ziyu/.cache/huggingface/hub/models--Qwen--Qwen2.5-3B-Instruct" + +# OpenTinker root (relative to this script: opentinker/scripts/run_alfworld.sh) +OPENTINKER_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" + +export VLLM_DISABLE_SLEEP_MODE=1 +export HF_HUB_OFFLINE=1 +export WANDB_MODE=offline + +# Activate conda environment (adjust to your setup if needed) +if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/anaconda3/etc/profile.d/conda.sh" +elif [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +fi +# conda activate opentinker + +# Change to OpenTinker directory +cd "$OPENTINKER_ROOT" + +# Get current host IP for communication between components +# Use 127.0.0.1 if running everything on the same machine +HOST_IP="${HOST_IP:-127.0.0.1}" + +# ============================================================================= +# Step Selection +# ============================================================================= +case "$1" in + scheduler|1) + echo "========================================" + echo "Step 1: Starting Scheduler on port $SCHEDULER_PORT" + echo "========================================" + bash opentinker/scripts/launch_scheduler.sh \ + --scheduler-port $SCHEDULER_PORT \ + --gpus "$GPUS" + ;; + + env|2) + echo "========================================" + echo "Step 2: Starting ALFWorld Environment Server on port $ENV_PORT" + echo "========================================" + python -m opentinker.environment.alfworld.alfworld_server \ + --port "$ENV_PORT" \ + --max_steps 50 \ + --split train \ + --num_games -1 \ + ;; + + client|3) + echo "========================================" + echo "Step 3: Starting ALFWorld RL Client" + echo "========================================" + python opentinker/client/alfworld_rl.py \ + --config-name alfworld_wmc_erc_param \ + tokenizer_path="$MODEL_PATH" \ + batch_size=4 \ + val_batch_size=50 \ + num_steps=1000 \ + save_freq=2000 \ + test_freq=10 \ + scheduler_url="http://$HOST_IP:$SCHEDULER_PORT" \ + interaction.config.env_port="$ENV_PORT" \ + interaction.config.env_host="$HOST_IP" + ;; + + *) + echo "Usage: $0 {scheduler|env|client}" + echo "" + echo "Example (separate terminals):" + echo " Terminal 1: bash $0 scheduler" + echo " Terminal 2: bash $0 env" + echo " Terminal 3: bash $0 client" + exit 1 + ;; +esac diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index c4fe1ce0..9ef27d05 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -641,6 +641,15 @@ def __init__( self.is_initialized = False self.global_steps = 0 + # WMC-ERC running statistics (S_bar, Sigma, H_bar) + self.wmc_erc_stats = { + "s_bar": 0.0, + "s_std": 1.0, + "h_bar": 1.0, + "momentum": 0.9, # EMA momentum + "initialized": False, + } + # Generation config (can be overridden by client) self.generation_config = { "do_sample": True, # CRITICAL: Enable sampling by default for PPO training @@ -1174,13 +1183,14 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: ) batch, wmc_metrics = apply_wmc_erc( - batch, _wmc_erc_entropys, wmc_erc_cfg + batch, _wmc_erc_entropys, wmc_erc_cfg, self.wmc_erc_stats ) metrics.update(wmc_metrics) + clipping_mode = wmc_erc_cfg.get("clipping_type", "batch") logger.info( - f"[WMC-ERC] mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " - f"s_star={wmc_metrics.get('wmc_erc/s_star_mean', 'N/A'):.4f}, " - f"h_wm={wmc_metrics.get('wmc_erc/h_wm_mean', 'N/A'):.4f}" + f"[WMC-ERC] mode={clipping_mode}, mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " + f"s_star={wmc_metrics.get('wmc_erc/batch_s_bar', 'N/A'):.4f}, " + f"h_wm={wmc_metrics.get('wmc_erc/batch_h_bar', 'N/A'):.4f}" ) # 10. Update critic diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py index 3d53a335..f083277a 100644 --- a/opentinker/tests/test_wmc_erc.py +++ b/opentinker/tests/test_wmc_erc.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests for WMC-ERC dynamic entropy clipping module. +Tests for WMC-ERC dynamic entropy clipping module using unittest. -Run with: pytest opentinker/tests/test_wmc_erc.py -v +Run with: python opentinker/tests/test_wmc_erc.py """ +import unittest import numpy as np -import pytest import torch +from unittest.mock import MagicMock from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( compute_turn_boundaries, @@ -32,12 +33,11 @@ ) -class TestComputeSStar: - """Test S_* (policy blind confidence) computation.""" +class TestWmcErc(unittest.TestCase): + """Test WMC-ERC components and orchestration.""" - def test_single_turn(self): + def test_single_turn_s_star(self): """S_* for a single turn = mean of p_k * (H + log p_k) over action tokens.""" - # 1 sample, 6 positions: 4 action + 2 padding p = torch.tensor([0.8, 0.6, 0.9, 0.7]) old_log_probs = torch.zeros(1, 6) old_log_probs[0, :4] = torch.log(p) @@ -46,214 +46,34 @@ def test_single_turn(self): boundaries = compute_turn_boundaries(response_mask) result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 1 # 1 sample - assert len(result[0]) == 1 # 1 turn + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 1) - # Manual: p_k * (H + log(p_k)) for each token, then mean H = torch.tensor([1.0, 1.5, 0.5, 1.2]) expected = (p * (H + torch.log(p))).mean().item() - assert abs(result[0][0].item() - expected) < 1e-5 - - def test_multi_turn(self): - """Two turns should produce two S_* values.""" - old_log_probs = torch.log( - torch.tensor([[0.8, 0.6, 0.5, 0.5, 0.9, 0.7, 0.5, 0.5]]) - ) - entropys = torch.tensor([[1.0, 1.5, 0.0, 0.0, 0.5, 1.2, 0.0, 0.0]]) - response_mask = torch.tensor( - [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 - ) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result[0]) == 2 # 2 turns - # Turn 0 and Turn 1 should have different values - assert result[0][0].item() != result[0][1].item() - - def test_batch(self): - """Batch of 2 samples.""" - old_log_probs = torch.log( - torch.tensor( - [ - [0.8, 0.6, 0.5, 0.5], - [0.5, 0.9, 0.5, 0.5], - ] - ) - ) - entropys = torch.tensor( - [ - [1.0, 1.5, 0.0, 0.0], - [2.0, 0.3, 0.0, 0.0], - ] - ) - response_mask = torch.tensor( - [ - [1, 1, 0, 0], - [1, 1, 0, 0], - ], - dtype=torch.float32, - ) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 2 - - def test_empty_turns(self): - """Sample with no action tokens should produce empty list.""" - old_log_probs = torch.zeros(1, 4) - entropys = torch.zeros(1, 4) - response_mask = torch.zeros(1, 4, dtype=torch.float32) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 1 - assert len(result[0]) == 0 - - -class TestComputeHWM: - """Test H_WM (world model uncertainty) computation.""" - - def test_single_turn_with_env_tokens(self): - """H_WM for turn 0 = mean entropy at env positions after turn 0.""" - # Sequence: [action, action, env, env, pad, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.0, 0.0]]) - response_mask = torch.tensor([[1, 1, 0, 0, 0, 0]], dtype=torch.float32) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2)]] # 1 turn: action at [0,2) - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result) == 1 - assert len(result[0]) == 1 - # Env tokens at [2,3] → mean entropy = (3.0 + 4.0) / 2 = 3.5 - assert abs(result[0][0].item() - 3.5) < 1e-5 - - def test_two_turns(self): - """Two turns: H_WM_0 from env between turns, H_WM_1 from env after turn 1.""" - # [act, act, env, env, act, act, env, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.5, 1.2, 2.0, 0.0]]) - response_mask = torch.tensor( - [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 - ) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2), (4, 6)]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result[0]) == 2 - # Turn 0 env: positions [2,4) → (3.0+4.0)/2 = 3.5 - assert abs(result[0][0].item() - 3.5) < 1e-5 - # Turn 1 env: positions [6,8) but attn_mask=[1,0] → only pos 6 → 2.0 - assert abs(result[0][1].item() - 2.0) < 1e-5 - - def test_no_env_after_last_turn(self): - """Last turn has no env tokens → H_WM = 0.""" - # [act, act, env, act, act, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 0.5, 1.2, 0.0]]) - response_mask = torch.tensor([[1, 1, 0, 1, 1, 0]], dtype=torch.float32) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 1, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2), (3, 5)]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - # Turn 0 env: positions [2,3) → 3.0 - assert abs(result[0][0].item() - 3.0) < 1e-5 - # Turn 1 env: positions [5,6) but attn_mask=[0] → H_WM = 0 - assert result[0][1].item() == 0.0 - - def test_empty_turns(self): - """No turns → empty H_WM list.""" - entropys = torch.zeros(1, 4) - response_mask = torch.zeros(1, 4, dtype=torch.float32) - attention_mask_response = torch.ones(1, 4, dtype=torch.float32) - boundaries = [[]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result[0]) == 0 - - -class TestComputeDynamicMask: - """Test dynamic entropy clipping mask.""" - - def test_all_pass(self): - """When all S_* are close to mean, all masks = 1.""" - s_star = [[torch.tensor(1.0), torch.tensor(1.1)]] - h_wm = [[torch.tensor(0.5), torch.tensor(0.5)]] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=3.0, lambda_wm=1.0) - assert mask == [[1.0, 1.0]] - - def test_outlier_blocked_low_hwm(self): - """High S_* outlier with low H_WM (known env) → blocked.""" - s_star = [ - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(10.0)], # outlier - ] - h_wm = [ - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], # known env → tight threshold - ] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # The outlier (10.0) should be blocked when threshold is tight - assert mask[3] == [0.0] - # Normal ones should pass - assert mask[0] == [1.0] - - def test_outlier_allowed_high_hwm(self): - """High S_* outlier with high H_WM (unknown env) → allowed.""" - s_star = [ - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(10.0)], # outlier - ] - h_wm = [ - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(100.0)], # very uncertain → wide threshold - ] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # The outlier should be allowed because H_WM is very high - assert mask[3] == [1.0] - - def test_empty(self): - """Empty input should return empty.""" - mask = compute_dynamic_mask([], [], mu_base=1.0, lambda_wm=1.0) - assert mask == [] - - def test_single_element(self): - """Single S_* value should not produce nan (std guard).""" - s_star = [[torch.tensor(5.0)]] - h_wm = [[torch.tensor(1.0)]] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # Single element: |5.0 - 5.0| = 0 <= threshold → should pass - assert mask == [[1.0]] - - -class TestApplyWmcErc: - """Test full WMC-ERC orchestration.""" - - def _make_batch( - self, advantages, response_mask, old_log_probs, attention_mask - ): - """Create a minimal mock batch with required fields.""" - from unittest.mock import MagicMock - + self.assertAlmostEqual(result[0][0].item(), expected, places=5) + + def test_asymmetric_behavior(self): + """Test that mu_base and mu_exp act differently using compute_dynamic_mask.""" + s_star = [[torch.tensor(15.0)], [torch.tensor(5.0)]] # Mean=10 + h_wm = [[torch.tensor(1.0)]] * 2 + s_bar = 10.0 + sigma = 5.0 + h_bar = 1.0 + + # 1. mu_base=0.1 (block high), mu_exp=10.0 (allow low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=0.1, mu_exp=10.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, h_bar=h_bar) + self.assertEqual(mask[0], [0.0]) # High blocked + self.assertEqual(mask[1], [1.0]) # Low allowed + + # 2. mu_base=10.0 (allow high), mu_exp=0.1 (block low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=10.0, mu_exp=0.1, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, h_bar=h_bar) + self.assertEqual(mask[0], [1.0]) # High allowed + self.assertEqual(mask[1], [0.0]) # Low blocked + + def _make_batch(self, advantages, response_mask, old_log_probs, attention_mask): batch = MagicMock() batch.batch = { "advantages": advantages.clone(), @@ -263,141 +83,46 @@ def _make_batch( } return batch - def test_masking_zeros_advantage(self): - """When a turn is masked, its advantages become zero.""" - # 4 samples, 1 turn each (4 action tokens + 0 env tokens) - # Sample 3 has extremely different S_* pattern + def test_clipping_type_batch(self): + """Verify that 'batch' mode uses current batch statistics.""" response_mask = torch.ones(4, 4, dtype=torch.float32) advantages = torch.ones(4, 4) * 2.0 - # Make sample 3 very overconfident (high p_k, low H) - old_log_probs = torch.tensor( - [ - [np.log(0.3)] * 4, - [np.log(0.3)] * 4, - [np.log(0.3)] * 4, - [np.log(0.99)] * 4, # very high confidence - ] - ) - entropys = torch.tensor( - [ - [2.0, 2.0, 2.0, 2.0], - [2.0, 2.0, 2.0, 2.0], - [2.0, 2.0, 2.0, 2.0], - [0.01, 0.01, 0.01, 0.01], # very low entropy - ] - ) + # Sample 3 is an outlier in this batch + old_log_probs = torch.tensor([[np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.9)]*4]) + entropys = torch.tensor([[2.0]*4, [2.0]*4, [2.0]*4, [3.0]*4]) attention_mask = torch.ones(4, 4, dtype=torch.float32) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - - # Check metrics exist - assert "wmc_erc/mask_ratio" in metrics - assert "wmc_erc/s_star_mean" in metrics - assert "wmc_erc/h_wm_mean" in metrics - assert "wmc_erc/total_turns" in metrics - assert metrics["wmc_erc/total_turns"] == 4 - - # Verify masking behavior: - # Samples 0-2: S_* = 0.3*(2.0+log(0.3)) ≈ 0.239 (normal) - # Sample 3: S_* = 0.99*(0.01+log(0.99)) ≈ 0 (outlier in opposite direction) - # H_WM = 0 for all (no env tokens) → tight threshold - # Sample 3 should be masked (|S_3 - S_bar| > threshold) - adv = batch.batch["advantages"] - assert metrics["wmc_erc/num_masked_turns"] >= 1 - assert (adv[3] == 0).all(), "Sample 3 (overconfident outlier) should have zero advantages" - assert (adv[:3] == 2.0).all(), "Samples 0-2 (normal) should keep original advantages" - - def test_disabled(self): - """When enable=False, advantages unchanged.""" - response_mask = torch.ones(2, 4, dtype=torch.float32) - advantages = torch.ones(2, 4) * 5.0 - old_log_probs = torch.full((2, 4), np.log(0.5)) - attention_mask = torch.ones(2, 4, dtype=torch.float32) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - entropys = torch.ones(2, 4) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": False} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert (batch.batch["advantages"] == 5.0).all() - assert metrics == {} - - def test_multi_turn_selective_masking(self): - """Multi-turn: only overconfident turns in known env get masked.""" - # 2 samples, 2 turns each: [act, act, env, env, act, act, env, pad] - response_mask = torch.tensor( - [ - [1, 1, 0, 0, 1, 1, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0], - ], - dtype=torch.float32, - ) - advantages = torch.ones(2, 8) * 3.0 - # Sample 0: normal confidence - # Sample 1: overconfident on both turns - old_log_probs = torch.tensor( - [ - [np.log(0.3), np.log(0.3), 0, 0, np.log(0.3), np.log(0.3), 0, 0], - [np.log(0.99), np.log(0.99), 0, 0, np.log(0.99), np.log(0.99), 0, 0], - ] - ) - entropys = torch.tensor( - [ - [2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 0.0], - [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.0], - ] - ) - attention_mask = torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 0], - ], - dtype=torch.float32, - ) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert metrics["wmc_erc/total_turns"] == 4 - # With low H_WM on sample 1, its overconfident turns should be blocked - # This is a statistical test — the exact outcome depends on batch stats - - def test_returns_wm_nll_metric(self): - """WM NLL metric should be computed from env token log probs.""" - response_mask = torch.tensor( - [[1, 1, 0, 0, 0, 0]], dtype=torch.float32 - ) - advantages = torch.ones(1, 6) - old_log_probs = torch.tensor( - [[np.log(0.5), np.log(0.5), np.log(0.3), np.log(0.4), 0, 0]] - ) - entropys = torch.tensor([[1.0, 1.0, 2.0, 3.0, 0.0, 0.0]]) - attention_mask = torch.tensor( - [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 - ) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 5.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert "wmc_erc/wm_nll" in metrics - # WM NLL = -mean(log_prob at env positions [2,3]) - # = -(log(0.3) + log(0.4)) / 2 - expected_nll = -(np.log(0.3) + np.log(0.4)) / 2.0 - assert abs(metrics["wmc_erc/wm_nll"] - expected_nll) < 1e-4 + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + config = {"mu_base": 0.1, "mu_exp": 10.0, "lambda_wm": 1.0, "enable": True, "clipping_type": "batch"} + running_stats = {"s_bar": 100.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} # Very different global stats + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # In 'batch' mode, Sample 3 should be blocked based on BATCH mean, not GLOBAL mean + # If it used global mean (100), sample 3 (S*~2.6) would be an exploration outlier and allowed by mu_exp=10 + # But in batch mode (S_bar~0.8), sample 3 is a collapsing outlier and blocked by mu_base=0.1 + self.assertTrue((batch.batch["advantages"][3] == 0).all()) + + def test_clipping_type_global(self): + """Verify that 'global' mode uses running statistics.""" + response_mask = torch.ones(4, 4, dtype=torch.float32) + advantages = torch.ones(4, 4) * 2.0 + old_log_probs = torch.tensor([[np.log(0.3)]*4]*4) + entropys = torch.tensor([[2.0]*4]*4) + # S* for all samples will be ~0.24 + + attention_mask = torch.ones(4, 4, dtype=torch.float32) + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Global s_bar is far away (10.0), so batch S* (0.24) looks like a huge exploration outlier + config = {"mu_base": 10.0, "mu_exp": 0.1, "lambda_wm": 0.0, "enable": True, "clipping_type": "global"} + running_stats = {"s_bar": 10.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # Should be blocked because mu_exp=0.1 is very tight relative to global s_bar + self.assertTrue((batch.batch["advantages"] == 0).all()) if __name__ == "__main__": - pytest.main([__file__, "-v"]) + unittest.main() From 5562471be5e4448538b53c6fe92324f1e3c80f2c Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Thu, 19 Mar 2026 06:43:54 +0000 Subject: [PATCH 2/8] update --- .../backend_patch/verl/trainer/ppo/wmc_erc.py | 66 ++++++++----- .../verl/trainer/ppo/world_model_loss.py | 50 ++++++++++ .../alfworld_wm_loss_clip_param.yaml | 97 +++++++++++++++++++ .../client_config/alfworld_wm_loss_param.yaml | 86 ++++++++++++++++ .../client_config/alfworld_wmc_erc_param.yaml | 8 +- .../alfworld_wmc_erc_param_ppo.yaml | 6 +- .../client/utils/http_training_client.py | 13 +++ opentinker/scheduler/job_scheduler.py | 33 +------ opentinker/scripts/run_alfworld.sh | 1 + opentinker/server/http_training_server.py | 76 ++++++--------- opentinker/tests/test_wmc_erc.py | 30 +++++- 11 files changed, 357 insertions(+), 109 deletions(-) create mode 100644 opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py create mode 100644 opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml create mode 100644 opentinker/client/client_config/alfworld_wm_loss_param.yaml diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py index c06bcf02..e617dbef 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -157,29 +157,33 @@ def compute_dynamic_mask( h_wm_per_sample: List[List[torch.Tensor]], mu_base: float, mu_exp: float, + eta_wm: float, lambda_wm: float, s_bar: float, sigma: float, - h_bar: float, + clipping_method: str = "mask", ) -> List[List[float]]: - """Compute per-turn dynamic entropy clipping mask. + """Compute per-turn dynamic entropy clipping mask or coefficient. Logic: - - Normalization: H_WM is normalized by h_bar (batch or global). - - Threshold: threshold = mu * (1 + lambda * H_WM_norm) * sigma + - WM uncertainty signal: f(H_WM) = eta_wm * exp(-lambda_wm * H_WM) + - Threshold: threshold = mu * f(H_WM) * sigma + - Masking: m_t = 0.0 if violation, 1.0 otherwise + - Clipping: m_t = threshold / deviation if violation, 1.0 otherwise (PPO-style) Args: s_star_per_sample: per-sample, per-turn S_* tensors h_wm_per_sample: per-sample, per-turn H_WM tensors mu_base: clipping coefficient for collapsing side mu_exp: clipping coefficient for exploration side - lambda_wm: WM uncertainty weight + eta_wm: base multiplier for WM uncertainty signal + lambda_wm: exponential decay factor for WM uncertainty s_bar: mean of S_* (batch or global) sigma: std of S_* (batch or global) - h_bar: mean of H_WM (batch or global) + clipping_method: "mask" or "clip" Returns: - List of lists of floats (0.0 or 1.0), one mask per turn per sample. + List of lists of floats, one mask/coeff per turn per sample. """ mask_per_sample = [] for i in range(len(s_star_per_sample)): @@ -188,18 +192,28 @@ def compute_dynamic_mask( s_t = s_star_per_sample[i][t].detach().item() h_t = h_wm_per_sample[i][t].detach().item() - # Normalize H_WM - h_t_norm = h_t / (h_bar + 1e-8) + # WM uncertainty signal: f(H_WM) = eta_wm * exp(-lambda_wm * H_WM) + h_factor = eta_wm * np.exp(-lambda_wm * h_t) # Asymmetric threshold calculation if s_t > s_bar: # Collapsing side - threshold = mu_base * (1.0 + lambda_wm * h_t_norm) * sigma - m_t = 1.0 if (s_t - s_bar) <= threshold else 0.0 + threshold = mu_base * h_factor * sigma + diff = s_t - s_bar + if clipping_method == "mask": + m_t = 1.0 if diff <= threshold else 0.0 + else: # PPO-style clipping + # If diff > threshold, we scale the advantage by threshold/diff + # such that the effective update is capped at threshold + m_t = min(1.0, threshold / (diff + 1e-8)) else: # Exploration side - threshold = mu_exp * (1.0 + lambda_wm * h_t_norm) * sigma - m_t = 1.0 if (s_bar - s_t) <= threshold else 0.0 + threshold = mu_exp * h_factor * sigma + diff = s_bar - s_t + if clipping_method == "mask": + m_t = 1.0 if diff <= threshold else 0.0 + else: # PPO-style clipping + m_t = min(1.0, threshold / (diff + 1e-8)) masks.append(m_t) mask_per_sample.append(masks) @@ -229,6 +243,7 @@ def apply_wmc_erc( return batch, {} clipping_type = wmc_erc_config.get("clipping_type", "batch") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_type', "batch") + clipping_method = wmc_erc_config.get("clipping_method", "mask") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_method', "mask") response_mask = batch.batch["response_mask"] old_log_probs = batch.batch["old_log_probs"] @@ -277,19 +292,20 @@ def apply_wmc_erc( use_s_std = batch_s_std use_h_bar = batch_h_bar - # 4. Dynamic mask + # 4. Dynamic mask/clip mu_base = float(wmc_erc_config.get("mu_base", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_base', 1.0)) mu_exp = float(wmc_erc_config.get("mu_exp", 2.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_exp', 2.0)) + eta_wm = float(wmc_erc_config.get("eta_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'eta_wm', 1.0)) lambda_wm = float(wmc_erc_config.get("lambda_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'lambda_wm', 1.0)) mask = compute_dynamic_mask( - s_star, h_wm, mu_base, mu_exp, lambda_wm, + s_star, h_wm, mu_base, mu_exp, eta_wm, lambda_wm, s_bar=use_s_bar, sigma=use_s_std, - h_bar=use_h_bar + clipping_method=clipping_method ) - # 5. Apply mask to advantages + # 5. Apply mask/coeff to advantages batch_size = advantages.shape[0] for i in range(batch_size): for t, (start, end) in enumerate(turn_boundaries[i]): @@ -300,15 +316,15 @@ def apply_wmc_erc( # 6. Metrics all_m = [m for turns in mask for m in turns] - num_collapsing_masked = 0 - num_exploration_masked = 0 + num_collapsing_violated = 0 + num_exploration_violated = 0 for i in range(len(s_star)): for t in range(len(s_star[i])): - if mask[i][t] == 0.0: + if mask[i][t] < 1.0: if s_star[i][t].item() > use_s_bar: - num_collapsing_masked += 1 + num_collapsing_violated += 1 else: - num_exploration_masked += 1 + num_exploration_violated += 1 env_mask = attention_mask_response * (1.0 - response_mask) env_count = env_mask.sum() @@ -322,9 +338,9 @@ def apply_wmc_erc( "wmc_erc/running_s_std": float(running_stats["s_std"]), "wmc_erc/running_h_bar": float(running_stats["h_bar"]), "wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0, - "wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0), - "wmc_erc/num_collapsing_masked": num_collapsing_masked, - "wmc_erc/num_exploration_masked": num_exploration_masked, + "wmc_erc/num_violated_turns": sum(1 for m in all_m if m < 1.0), + "wmc_erc/num_collapsing_violated": num_collapsing_violated, + "wmc_erc/num_exploration_violated": num_exploration_violated, "wmc_erc/total_turns": len(all_m), "wmc_erc/wm_nll": wm_nll, } diff --git a/opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py b/opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py new file mode 100644 index 00000000..d5fea5c4 --- /dev/null +++ b/opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py @@ -0,0 +1,50 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +World Model SFT Loss — trains the policy to also predict observation tokens. + +Joint loss = ppo_loss(action tokens) + wm_coeff * sft_loss(observation tokens) + +Implementation: + The WM SFT loss is computed inside dp_actor.update_policy() (verl modification). + It is gated by config.world_model_coeff > 0 AND observation_mask being present + in the batch. + + observation_mask is computed in http_training_server.py before update_actor: + obs_mask = attention_mask[:, -resp_len:] & ~response_mask + + To enable, set in your config yaml: + actor_rollout_ref: + actor: + world_model_coeff: 0.1 +""" + +import torch + + +def compute_observation_mask(batch) -> torch.Tensor: + """Compute observation_mask from attention_mask and response_mask. + + observation tokens = real tokens in the response portion that are NOT + action (LLM-generated) tokens. + + Args: + batch: DataProto with batch["attention_mask"] and batch["response_mask"] + + Returns: + observation_mask: (batch_size, response_length) float tensor + """ + resp_len = batch.batch["response_mask"].shape[1] + attn_response = batch.batch["attention_mask"][:, -resp_len:] + return (attn_response.bool() & ~batch.batch["response_mask"].bool()).float() \ No newline at end of file diff --git a/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml b/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml new file mode 100644 index 00000000..2405f3e1 --- /dev/null +++ b/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml @@ -0,0 +1,97 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc_wm_loss_0.0001 + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "grpo" +rollout_n: 8 + +# World model SFT loss: predict environment observation tokens as auxiliary task +world_model_coeff: 0.0001 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - eta_wm: base multiplier for WM uncertainty signal +# - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + mu_exp: 2.0 + eta_wm: 3.0 + lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 1 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc_wm_loss" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wm_loss_param.yaml b/opentinker/client/client_config/alfworld_wm_loss_param.yaml new file mode 100644 index 00000000..74cb07eb --- /dev/null +++ b/opentinker/client/client_config/alfworld_wm_loss_param.yaml @@ -0,0 +1,86 @@ +# ALFWorld Training Configuration +# Use with: python alfworld_rl.py + +# Project settings +project_name: opentinker +experiment_name: alfworld_wm_loss_only_0.0001 + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: 2ed6f8544ac3e30d5c08879166cc10d9c6232448 + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +# Training duration - set ONE of these (num_steps takes precedence if both set) +num_epochs: null # Number of epochs (null = use num_steps) +num_steps: 1000 # Total training steps (null = use num_epochs) +save_freq: 200 +test_freq: 50 # Validation frequency (every N steps) + +# Validation parameters +val_batch_size: 50 # Total validation samples (null = 50) + +# Model parameters +# Generation parameters +temperature: 1 # Lower temperature for more focused responses +top_p: 1 +max_new_tokens: 4096 # TOTAL response budget for entire multi-turn trajectory (NOT per-turn!) +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# RL Algorithm settings (passed to server via scheduler) +# adv_estimator options: +# - "grpo" : Standard GRPO (outcome-only advantage) +# - "grpo_per_step" : Per-step GRPO with return-based advantages (for multi-turn tasks) +# - "gae" : Generalized Advantage Estimation (for PPO, requires critic) +adv_estimator: "grpo" +# rollout_n: number of samples per prompt for GRPO/grpo_per_step +# For PPO (gae), rollout_n is typically 1 +rollout_n: 8 + +# World model SFT loss: predict environment observation tokens as auxiliary task +world_model_coeff: 0.0001 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + # If you run the ALFWorld env server in sharded mode (--shards N), + # set env_shards=N. The client will route each instance_id to a stable shard. + env_shards: 8 + max_steps: 20 # ALFWorld episodes max steps + max_total_steps: 20 # Max environment step calls (controls rollout turns) + observation_template: "{observation}" + # ALFWorld specific settings + split: train # train, eval_in_distribution, eval_out_of_distribution + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 # Per-turn response limit (optional, null for no limit) + # Weave tracing (optional - runs on SERVER side) + weave_project: "zsqzz/alfworld-env-test" + experiment_name: "wm_loss_only" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa + +# GPU settings +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml index a788ce24..21a089fc 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -9,7 +9,7 @@ # Project settings project_name: opentinker -experiment_name: alfworld_wmc_erc +experiment_name: alfworld_wmc_erc_part_mask # Logging logger_backends: ["console", "wandb"] @@ -51,16 +51,20 @@ rollout_n: 8 # WMC-ERC: Dynamic Entropy Clipping # - mu_base: clipping coefficient for collapsing side (S_* > S_bar) # - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) -# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - eta_wm: base multiplier for WM uncertainty signal +# - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) # - clipping_type: "batch" or "global" (global uses running statistics) +# - clipping_method: "mask" (0/1 gating) or "clip" (PPO-style soft clipping) # - momentum: momentum for running statistics (only used if clipping_type is "global") # - enable: master switch wmc_erc: enable: true mu_base: 1.0 mu_exp: 2.0 + eta_wm: 3.0 lambda_wm: 1.0 clipping_type: "global" + clipping_method: "mask" momentum: 0.9 # Interaction configuration diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml index 94cf216e..d52f5c88 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml @@ -51,7 +51,8 @@ rollout_n: 1 # WMC-ERC: Dynamic Entropy Clipping # - mu_base: clipping coefficient for collapsing side (S_* > S_bar) # - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) -# - lambda_wm: how much WM uncertainty (normalized) widens the gate +# - eta_wm: base multiplier for WM uncertainty signal +# - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) # - clipping_type: "batch" or "global" (global uses running statistics) # - momentum: momentum for running statistics (only used if clipping_type is "global") # - enable: master switch @@ -59,6 +60,7 @@ wmc_erc: enable: true mu_base: 1.0 mu_exp: 2.0 + eta_wm: 1.0 lambda_wm: 1.0 clipping_type: "global" momentum: 0.9 @@ -82,7 +84,7 @@ multi_turn: max_assistant_turns: ${interaction.config.max_total_steps} max_tokens_per_turn: 512 weave_project: null - experiment_name: "alfworld_wmc_erc" + experiment_name: "alfworld_wmc_erc_ppo" # Scheduler settings scheduler_url: "http://0.0.0.0:8780" diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 958ba2d2..1040f725 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -637,6 +637,19 @@ def set_config(self, args: DictConfig, env=None): }, } ) + + # Optional world model SFT coefficient for joint PPO + WM training. + world_model_coeff = args.get("world_model_coeff", None) + if world_model_coeff is not None: + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create( + {"algorithm": {"world_model_coeff": float(world_model_coeff)}} + ), + ) + print( + f"[ServiceClient] Forwarding algorithm.world_model_coeff={world_model_coeff}" + ) # Add multi_turn config if present in args if hasattr(args, "multi_turn") and args.multi_turn: diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index abc5b42b..dbe9c96e 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -284,7 +284,7 @@ def check_gpu_available(gpu_id: int) -> bool: return True # Fail open # Thresholds for considering a GPU "idle" - MAX_MEMORY_MB = 10 # Allow up to 100 MB (some baseline CUDA overhead) + MAX_MEMORY_MB = 2000 # Allow up to 2 GB (root processes may use ~1 GB) MAX_UTILIZATION = 1000 # Allow up to 5% utilization if memory_used_mb > MAX_MEMORY_MB or utilization_percent > MAX_UTILIZATION: @@ -294,35 +294,8 @@ def check_gpu_available(gpu_id: int) -> bool: ) return False - # Check 2: Look for running processes on this GPU - pmon_result = subprocess.run( - ["nvidia-smi", "pmon", "-c", "1", "-s", "um"], - capture_output=True, - text=True, - timeout=5, - ) - - if pmon_result.returncode == 0: - # Parse pmon output to check for processes on this GPU - # Format: "# gpu pid type sm mem enc dec command" - # " 0 12345 C 50 500 0 0 python" - lines = pmon_result.stdout.strip().split("\n") - for line in lines: - if line.startswith("#") or not line.strip(): - continue - parts = line.split() - if len(parts) >= 2: - try: - gpu_idx = int(parts[0].strip()) - if gpu_idx == gpu_id and parts[1].strip() != "-": - # Found a process on this GPU - pid = parts[1].strip() - logger.warning( - f"GPU {gpu_id}: ⚠️ OCCUPIED - Process {pid} detected via pmon" - ) - return False - except (ValueError, IndexError): - continue + # pmon check disabled — root/system processes cause false positives + # when sharing GPUs across users. Memory threshold check above is sufficient. # All checks passed - GPU is idle logger.debug( diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh index 97ab8252..d4911a92 100755 --- a/opentinker/scripts/run_alfworld.sh +++ b/opentinker/scripts/run_alfworld.sh @@ -25,6 +25,7 @@ MODEL_PATH="/inspire/hdd/project/robot-reasoning/xuyue-p-xuyue/ziyu/.cache/huggi # OpenTinker root (relative to this script: opentinker/scripts/run_alfworld.sh) OPENTINKER_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" +export NCCL_CUMEM_ENABLE=0 export VLLM_DISABLE_SLEEP_MODE=1 export HF_HUB_OFFLINE=1 export WANDB_MODE=offline diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 9ef27d05..3636ce35 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -27,7 +27,6 @@ import asyncio import base64 -import gc import logging import signal import sys @@ -85,6 +84,7 @@ ) + # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -640,15 +640,7 @@ def __init__( # Server state self.is_initialized = False self.global_steps = 0 - - # WMC-ERC running statistics (S_bar, Sigma, H_bar) - self.wmc_erc_stats = { - "s_bar": 0.0, - "s_std": 1.0, - "h_bar": 1.0, - "momentum": 0.9, # EMA momentum - "initialized": False, - } + self.wm_coeff = 0.0 # Generation config (can be overridden by client) self.generation_config = { @@ -687,6 +679,15 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]: try: # optimizer needs parameter: total_steps self.trainer.post_init(total_steps) + + # Forward algorithm.world_model_coeff → actor config so dp_actor can read it + algo_wm_coeff = self.config.algorithm.get("world_model_coeff", 0.0) + if algo_wm_coeff > 0: + from omegaconf import open_dict + with open_dict(self.config): + self.config.actor_rollout_ref.actor.world_model_coeff = algo_wm_coeff + logger.info(f"Forwarded world_model_coeff={algo_wm_coeff} to actor config") + logger.info("Initializing workers...") # Check async rollout mode @@ -717,6 +718,11 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]: if self.async_rollout_mode: self.async_rollout_manager = self.trainer.async_rollout_manager + # World model SFT loss coeff (actual loss computed in dp_actor via config.world_model_coeff) + self.wm_coeff = self.config.actor_rollout_ref.actor.get("world_model_coeff", 0.0) + if self.wm_coeff > 0: + logger.info(f"World model SFT loss enabled with coeff={self.wm_coeff}") + self.is_initialized = True logger.info("Workers initialized successfully") return {"status": "success", "message": "Workers initialized"} @@ -953,9 +959,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # episodes into individual per-turn training samples, the gen_batch_output # batch size is larger than the original batch. We need to expand the # original batch to match using the expansion index. - expansion_index = gen_batch_output.meta_info.pop( - "per_turn_expansion_index", None - ) + expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) if expansion_index is not None: logger.info( f"[Per-turn training] Expanding original batch from {len(batch)} to " @@ -968,7 +972,6 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict - batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) # Expand non-tensor batch expanded_non_tensor = {} @@ -1093,7 +1096,6 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # ===== DEBUG LOGGING END ===== metrics.update(old_log_prob_metrics) - _wmc_erc_entropys = entropys # Preserve for WMC-ERC before pop old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) logger.info( @@ -1175,24 +1177,6 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: config=self.config.algorithm, ) - # 9.5 WMC-ERC: Dynamic entropy clipping - wmc_erc_cfg = OmegaConf.select(self.config, "wmc_erc", default=None) - if wmc_erc_cfg and wmc_erc_cfg.get("enable", False): - from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( - apply_wmc_erc, - ) - - batch, wmc_metrics = apply_wmc_erc( - batch, _wmc_erc_entropys, wmc_erc_cfg, self.wmc_erc_stats - ) - metrics.update(wmc_metrics) - clipping_mode = wmc_erc_cfg.get("clipping_type", "batch") - logger.info( - f"[WMC-ERC] mode={clipping_mode}, mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " - f"s_star={wmc_metrics.get('wmc_erc/batch_s_bar', 'N/A'):.4f}, " - f"h_wm={wmc_metrics.get('wmc_erc/batch_h_bar', 'N/A'):.4f}" - ) - # 10. Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): @@ -1202,6 +1186,14 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: ) metrics.update(critic_output_metrics) + # 10.5 Compute observation_mask for world model loss + if self.wm_coeff > 0: + resp_len = batch.batch["response_mask"].shape[1] + attn_response = batch.batch["attention_mask"][:, -resp_len:] + batch.batch["observation_mask"] = ( + attn_response.bool() & ~batch.batch["response_mask"].bool() + ).float() + # 11. Update actor (check critic warmup) if self.config.trainer.critic_warmup <= self.global_steps: with marked_timer("update_actor", timing_raw, color="red"): @@ -1290,13 +1282,6 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info(f"Training step {self.global_steps} completed successfully") - # Free large intermediates and force garbage collection to prevent OOM - del batch, gen_batch, gen_batch_output, reward_tensor - if "_wmc_erc_entropys" in dir(): - del _wmc_erc_entropys - gc.collect() - torch.cuda.empty_cache() - return { "status": "success", "metrics": metrics, @@ -1574,9 +1559,7 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: # 6. Merge original batch and generated output # Per-turn training expansion: expand batch if gen output is larger - expansion_index = gen_batch_output.meta_info.pop( - "per_turn_expansion_index", None - ) + expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) if expansion_index is not None: logger.info( f"[Per-turn training] Validation: Expanding original batch from {len(batch)} to " @@ -1588,7 +1571,6 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict - batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) expanded_non_tensor = {} for k, v in batch.non_tensor_batch.items(): @@ -2146,13 +2128,14 @@ def run_fastapi_server(): ) ray.init( namespace=_server_cfg.ray.namespace, - num_gpus=_server_cfg.trainer.n_gpus_per_node, # Explicitly specify number of GPUs + num_gpus=_server_cfg.trainer.n_gpus_per_node, ignore_reinit_error=True, runtime_env={ "env_vars": { "NCCL_CUMEM_ENABLE": "0", "VLLM_DISABLE_SLEEP_MODE": "1", "RAY_memory_usage_threshold": "0.99", + "VLLM_GPU_MEMORY_UTILIZATION": "0.15", }, }, ) @@ -2168,6 +2151,7 @@ def run_fastapi_server(): "NCCL_CUMEM_ENABLE": "0", "VLLM_DISABLE_SLEEP_MODE": "1", "RAY_memory_usage_threshold": "0.99", + "VLLM_GPU_MEMORY_UTILIZATION": "0.15", }, }, ) @@ -2304,4 +2288,4 @@ def run_fastapi_server(): if __name__ == "__main__": # Example usage print("HTTP Training Server") - print("Use launch_server() to start the server") + print("Use launch_server() to start the server") \ No newline at end of file diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py index f083277a..cc0a9dc6 100644 --- a/opentinker/tests/test_wmc_erc.py +++ b/opentinker/tests/test_wmc_erc.py @@ -62,14 +62,14 @@ def test_asymmetric_behavior(self): h_bar = 1.0 # 1. mu_base=0.1 (block high), mu_exp=10.0 (allow low) - mask = compute_dynamic_mask(s_star, h_wm, mu_base=0.1, mu_exp=10.0, lambda_wm=0.0, - s_bar=s_bar, sigma=sigma, h_bar=h_bar) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=0.1, mu_exp=10.0, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma) self.assertEqual(mask[0], [0.0]) # High blocked self.assertEqual(mask[1], [1.0]) # Low allowed # 2. mu_base=10.0 (allow high), mu_exp=0.1 (block low) - mask = compute_dynamic_mask(s_star, h_wm, mu_base=10.0, mu_exp=0.1, lambda_wm=0.0, - s_bar=s_bar, sigma=sigma, h_bar=h_bar) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=10.0, mu_exp=0.1, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma) self.assertEqual(mask[0], [1.0]) # High allowed self.assertEqual(mask[1], [0.0]) # Low blocked @@ -123,6 +123,28 @@ def test_clipping_type_global(self): # Should be blocked because mu_exp=0.1 is very tight relative to global s_bar self.assertTrue((batch.batch["advantages"] == 0).all()) + def test_clipping_method_clip(self): + """Verify that 'clip' method scales advantages instead of zeroing them.""" + # Setup data such that we have a violation + # S* = 15.0, s_bar = 10.0, sigma = 1.0, mu_base = 1.0, h_factor = 1.0 + # diff = 5.0, threshold = 1.0 + # m_t should be 1.0 / 5.0 = 0.2 + + s_star = [[torch.tensor(15.0)]] + h_wm = [[torch.tensor(0.0)]] # lambda_wm=0, eta_wm=1 -> h_factor=1 + s_bar = 10.0 + sigma = 1.0 + + # 1. Test clip + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, mu_exp=1.0, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, clipping_method="clip") + self.assertAlmostEqual(mask[0][0], 0.2, places=5) + + # 2. Test mask (for comparison) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, mu_exp=1.0, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, clipping_method="mask") + self.assertEqual(mask[0][0], 0.0) + if __name__ == "__main__": unittest.main() From 86a6da5fbfc69a8c18499d53a45e8e9f60602f74 Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Wed, 25 Mar 2026 15:02:41 +0000 Subject: [PATCH 3/8] Clip by scaling instead of masking. --- .../backend_patch/verl/trainer/ppo/wmc_erc.py | 41 ++++++++-- .../client/client_config/alfworld_param.yaml | 2 +- .../alfworld_wm_loss_clip_param.yaml | 20 +++-- .../client_config/alfworld_wm_loss_param.yaml | 8 +- .../client_config/alfworld_wmc_erc_param.yaml | 18 ++++- .../alfworld_wmc_erc_param_ppo.yaml | 1 + .../client/utils/http_training_client.py | 18 ++++- opentinker/scripts/run_alfworld.sh | 2 +- opentinker/server/http_training_server.py | 60 +++++++++++++- opentinker/server/launch_http_server.py | 8 +- opentinker/tests/test_wmc_erc.py | 81 +++++++++++++++++++ 11 files changed, 234 insertions(+), 25 deletions(-) diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py index e617dbef..daf01e34 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -82,6 +82,8 @@ def compute_s_star( if count > 0: p_k = torch.exp(log_p) s_token = p_k * (H + log_p) + # Handle potential NaN due to 0 * -inf when p_k is 0 and log_p is -inf + s_token = torch.nan_to_num(s_token, nan=0.0) s_t = (s_token * mask).sum() / count else: s_t = torch.tensor(0.0, device=device) @@ -205,7 +207,8 @@ def compute_dynamic_mask( else: # PPO-style clipping # If diff > threshold, we scale the advantage by threshold/diff # such that the effective update is capped at threshold - m_t = min(1.0, threshold / (diff + 1e-8)) + u_t = min(1.0, diff / (threshold + 1e-8)) + m_t = 1.0 / (1.0 + 0.5 * u_t) else: # Exploration side threshold = mu_exp * h_factor * sigma @@ -213,7 +216,7 @@ def compute_dynamic_mask( if clipping_method == "mask": m_t = 1.0 if diff <= threshold else 0.0 else: # PPO-style clipping - m_t = min(1.0, threshold / (diff + 1e-8)) + m_t = 1.0 masks.append(m_t) mask_per_sample.append(masks) @@ -244,6 +247,8 @@ def apply_wmc_erc( clipping_type = wmc_erc_config.get("clipping_type", "batch") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_type', "batch") clipping_method = wmc_erc_config.get("clipping_method", "mask") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_method', "mask") + clip_positive_only = wmc_erc_config.get("clip_positive_only", False) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clip_positive_only', False) + inverse_sft_mask = wmc_erc_config.get("inverse_sft_mask", False) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'inverse_sft_mask', False) response_mask = batch.batch["response_mask"] old_log_probs = batch.batch["old_log_probs"] @@ -272,11 +277,10 @@ def apply_wmc_erc( # Update global statistics momentum = wmc_erc_config.get("momentum", 0.9) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'momentum', 0.9) - if not running_stats.get("initialized", False): + if len(running_stats.keys()) == 0: running_stats["s_bar"] = batch_s_bar running_stats["s_std"] = batch_s_std running_stats["h_bar"] = batch_h_bar - running_stats["initialized"] = True else: running_stats["s_bar"] = (1 - momentum) * batch_s_bar + momentum * running_stats["s_bar"] running_stats["s_std"] = (1 - momentum) * batch_s_std + momentum * running_stats["s_std"] @@ -307,11 +311,38 @@ def apply_wmc_erc( # 5. Apply mask/coeff to advantages batch_size = advantages.shape[0] + + if inverse_sft_mask: + sft_weights = torch.zeros_like(advantages) + env_mask = attention_mask_response * (1.0 - response_mask) + for i in range(batch_size): for t, (start, end) in enumerate(turn_boundaries[i]): if t < len(mask[i]): - advantages[i, start:end] *= mask[i][t] + m_t = mask[i][t] + + if inverse_sft_mask: + # Env tokens after this turn: [end, next_turn_start) or [end, seq_len) + if t + 1 < len(turn_boundaries[i]): + env_end = turn_boundaries[i][t + 1][0] + else: + env_end = response_length + + sft_weight = 1.0 - m_t + region_mask = env_mask[i, end:env_end] + sft_weights[i, end:env_end] = region_mask * sft_weight + + if m_t < 1.0: + if clip_positive_only: + # Only apply scaling where advantages > 0 + turn_adv = advantages[i, start:end] + advantages[i, start:end] = torch.where(turn_adv > 0, turn_adv * m_t, turn_adv) + else: + advantages[i, start:end] *= m_t batch.batch["advantages"] = advantages + + if inverse_sft_mask: + batch.batch["sft_weights"] = sft_weights # 6. Metrics all_m = [m for turns in mask for m in turns] diff --git a/opentinker/client/client_config/alfworld_param.yaml b/opentinker/client/client_config/alfworld_param.yaml index c3fac911..381ed652 100644 --- a/opentinker/client/client_config/alfworld_param.yaml +++ b/opentinker/client/client_config/alfworld_param.yaml @@ -3,7 +3,7 @@ # Project settings project_name: opentinker -experiment_name: alfworld_training +experiment_name: alfworld_training_baseline_grpo # Logging logger_backends: ["console", "wandb"] diff --git a/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml b/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml index 2405f3e1..304a17e7 100644 --- a/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml +++ b/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml @@ -9,7 +9,7 @@ # Project settings project_name: opentinker -experiment_name: alfworld_wmc_erc_wm_loss_0.0001 +experiment_name: alfworld_wmc_erc_clip_simple_wm_loss_0.001 # Logging logger_backends: ["console", "wandb"] @@ -23,6 +23,7 @@ wandb_key: null # Model and tokenizer tokenizer_path: null +# enable_sleep_mode: true # Training parameters batch_size: 8 @@ -48,8 +49,11 @@ algorithm: "agent_loop" adv_estimator: "grpo" rollout_n: 8 -# World model SFT loss: predict environment observation tokens as auxiliary task -world_model_coeff: 0.0001 +# World Model SFT loss: predict environment observation tokens as auxiliary task +world_model_loss: + world_model_coeff: 0.001 + world_model_annealing_steps: 0 + world_model_annealing_end_factor: 0.0 # WMC-ERC: Dynamic Entropy Clipping # - mu_base: clipping coefficient for collapsing side (S_* > S_bar) @@ -57,15 +61,21 @@ world_model_coeff: 0.0001 # - eta_wm: base multiplier for WM uncertainty signal # - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) # - clipping_type: "batch" or "global" (global uses running statistics) +# - clipping_method: "mask" (0/1 gating) or "clip" (PPO-style soft clipping) +# - clip_positive_only: if true, only apply clipping to tokens with positive advantages +# - inverse_sft_mask: if true, use (1 - m_t) as weights for SFT loss on subsequent env tokens # - momentum: momentum for running statistics (only used if clipping_type is "global") # - enable: master switch wmc_erc: enable: true mu_base: 1.0 - mu_exp: 2.0 - eta_wm: 3.0 + mu_exp: 3.0 + eta_wm: 2.0 lambda_wm: 1.0 clipping_type: "global" + clipping_method: "clip" + clip_positive_only: false + inverse_sft_mask: false momentum: 0.9 # Interaction configuration diff --git a/opentinker/client/client_config/alfworld_wm_loss_param.yaml b/opentinker/client/client_config/alfworld_wm_loss_param.yaml index 74cb07eb..5ca6eb44 100644 --- a/opentinker/client/client_config/alfworld_wm_loss_param.yaml +++ b/opentinker/client/client_config/alfworld_wm_loss_param.yaml @@ -17,6 +17,7 @@ wandb_key: 2ed6f8544ac3e30d5c08879166cc10d9c6232448 # Model and tokenizer tokenizer_path: null +enable_sleep_mode: false # Training parameters batch_size: 8 @@ -50,8 +51,11 @@ adv_estimator: "grpo" # For PPO (gae), rollout_n is typically 1 rollout_n: 8 -# World model SFT loss: predict environment observation tokens as auxiliary task -world_model_coeff: 0.0001 +# World Model SFT loss: predict environment observation tokens as auxiliary task +world_model_loss: + world_model_coeff: 0.0001 + world_model_annealing_steps: 0 + world_model_annealing_end_factor: 0.0 # Interaction configuration interaction: diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml index 21a089fc..5566309d 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -9,7 +9,7 @@ # Project settings project_name: opentinker -experiment_name: alfworld_wmc_erc_part_mask +experiment_name: alfworld_wmc_erc_clip # Logging logger_backends: ["console", "wandb"] @@ -23,6 +23,7 @@ wandb_key: null # Model and tokenizer tokenizer_path: null +enable_sleep_mode: false # Training parameters batch_size: 8 @@ -55,18 +56,27 @@ rollout_n: 8 # - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) # - clipping_type: "batch" or "global" (global uses running statistics) # - clipping_method: "mask" (0/1 gating) or "clip" (PPO-style soft clipping) +# - clip_positive_only: if true, only apply clipping to tokens with positive advantages +# - inverse_sft_mask: if true, use (1 - m_t) as weights for SFT loss on subsequent env tokens # - momentum: momentum for running statistics (only used if clipping_type is "global") # - enable: master switch wmc_erc: enable: true mu_base: 1.0 - mu_exp: 2.0 - eta_wm: 3.0 + mu_exp: 3.0 + eta_wm: 2.0 lambda_wm: 1.0 clipping_type: "global" - clipping_method: "mask" + clipping_method: "clip" + clip_positive_only: false + inverse_sft_mask: false momentum: 0.9 +# World Model SFT loss: predict environment observation tokens as auxiliary task +world_model_coeff: 0.0 +world_model_annealing_steps: 0 +world_model_annealing_end_factor: 1.0 + # Interaction configuration interaction: name: alfworld diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml index d52f5c88..612ddea1 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml @@ -23,6 +23,7 @@ wandb_key: null # Model and tokenizer tokenizer_path: null +enable_sleep_mode: false # Training parameters batch_size: 8 diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 1040f725..36fb7d37 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -638,17 +638,27 @@ def set_config(self, args: DictConfig, env=None): } ) + # Pass WMC-ERC config to server if present + wmc_erc_cfg = getattr(args, "wmc_erc", None) + if wmc_erc_cfg is not None: + wmc_erc_dict = OmegaConf.to_container(wmc_erc_cfg, resolve=True) + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create({"wmc_erc": wmc_erc_dict}), + ) + print(f"[ServiceClient] Passing WMC-ERC config to server: {wmc_erc_dict}") + # Optional world model SFT coefficient for joint PPO + WM training. - world_model_coeff = args.get("world_model_coeff", None) - if world_model_coeff is not None: + if hasattr(args, "world_model_loss") and args.world_model_loss: + world_model_loss_cfg = OmegaConf.to_container(args.world_model_loss, resolve=True) server_cfg = OmegaConf.merge( server_cfg, OmegaConf.create( - {"algorithm": {"world_model_coeff": float(world_model_coeff)}} + {"algorithm": {"world_model_loss": world_model_loss_cfg}} ), ) print( - f"[ServiceClient] Forwarding algorithm.world_model_coeff={world_model_coeff}" + f"[ServiceClient] Passing world_model_loss config to server: {world_model_loss_cfg}" ) # Add multi_turn config if present in args diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh index d4911a92..c11dd0e9 100755 --- a/opentinker/scripts/run_alfworld.sh +++ b/opentinker/scripts/run_alfworld.sh @@ -74,7 +74,7 @@ case "$1" in echo "Step 3: Starting ALFWorld RL Client" echo "========================================" python opentinker/client/alfworld_rl.py \ - --config-name alfworld_wmc_erc_param \ + --config-name alfworld_wm_loss_clip_param \ tokenizer_path="$MODEL_PATH" \ batch_size=4 \ val_batch_size=50 \ diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 3636ce35..b52641e3 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -27,6 +27,7 @@ import asyncio import base64 +import gc import logging import signal import sys @@ -620,6 +621,7 @@ def __init__( self.resource_pool_manager = resource_pool_manager self.reward_fn = reward_fn self.val_reward_fn = val_reward_fn + self.running_stats: Dict[str, float] = dict() # Initialize trainer (without dataloader - client provides custom dataloader) from verl.single_controller.ray import RayWorkerGroup @@ -677,16 +679,31 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]: Status dict """ try: + # Reset server state for new training session + self.global_steps = 0 + self.running_stats.clear() + logger.info("Reset server state (global_steps and running_stats) for new worker initialization") + # optimizer needs parameter: total_steps self.trainer.post_init(total_steps) # Forward algorithm.world_model_coeff → actor config so dp_actor can read it - algo_wm_coeff = self.config.algorithm.get("world_model_coeff", 0.0) + if hasattr(self.config.algorithm, "world_model_loss"): + algo_wm_coeff = self.config.algorithm.world_model_loss.get("world_model_coeff", 0.0) + algo_wm_annealing_steps = self.config.algorithm.world_model_loss.get("world_model_annealing_steps", 0) + algo_wm_annealing_end_factor = self.config.algorithm.world_model_loss.get("world_model_annealing_end_factor", 1.0) + else: + algo_wm_coeff = 0 + algo_wm_annealing_steps = 0 + algo_wm_annealing_end_factor = 0 + if algo_wm_coeff > 0: from omegaconf import open_dict with open_dict(self.config): self.config.actor_rollout_ref.actor.world_model_coeff = algo_wm_coeff - logger.info(f"Forwarded world_model_coeff={algo_wm_coeff} to actor config") + self.config.actor_rollout_ref.actor.world_model_annealing_steps = algo_wm_annealing_steps + self.config.actor_rollout_ref.actor.world_model_annealing_end_factor = algo_wm_annealing_end_factor + logger.info(f"Forwarded world_model_coeff={algo_wm_coeff} (annealing_steps={algo_wm_annealing_steps}) to actor config") logger.info("Initializing workers...") @@ -1096,6 +1113,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # ===== DEBUG LOGGING END ===== metrics.update(old_log_prob_metrics) + _wmc_erc_entropys = entropys # Preserve for WMC-ERC before pop old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) logger.info( @@ -1177,6 +1195,24 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: config=self.config.algorithm, ) + # 9.5 WMC-ERC: Dynamic entropy clipping + wmc_erc_cfg = OmegaConf.select(self.config, "wmc_erc", default=None) + if wmc_erc_cfg and wmc_erc_cfg.get("enable", False): + from opentinker.backend_patch.verl.trainer.ppo.wmc_erc import ( + apply_wmc_erc, + ) + + batch, wmc_metrics = apply_wmc_erc( + batch, _wmc_erc_entropys, wmc_erc_cfg, self.running_stats + ) + metrics.update(wmc_metrics) + logger.info( + f"[WMC-ERC] mask_ratio={float(wmc_metrics.get('wmc_erc/mask_ratio', float('nan'))):.3f}, " + f"s_star={float(wmc_metrics.get('wmc_erc/s_star_mean', float('nan'))):.4f}, " + f"h_wm={float(wmc_metrics.get('wmc_erc/h_wm_mean', float('nan'))):.4f}" + ) + + # 10. Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): @@ -1282,6 +1318,26 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info(f"Training step {self.global_steps} completed successfully") + # Free large intermediates and force garbage collection to prevent OOM + # Explicitly delete all potential large objects in local scope + del batch, gen_batch, gen_batch_output, reward_tensor + if "old_log_prob" in locals(): + del old_log_prob + if "ref_log_prob" in locals(): + del ref_log_prob + if "values" in locals(): + del values + if "gen_baseline_output" in locals(): + del gen_baseline_output + if "gen_baseline_batch" in locals(): + del gen_baseline_batch + + if "_wmc_erc_entropys" in locals(): + del _wmc_erc_entropys + + gc.collect() + torch.cuda.empty_cache() + return { "status": "success", "metrics": metrics, diff --git a/opentinker/server/launch_http_server.py b/opentinker/server/launch_http_server.py index 830af82c..57843350 100644 --- a/opentinker/server/launch_http_server.py +++ b/opentinker/server/launch_http_server.py @@ -64,7 +64,13 @@ def main(cfg): cfg.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 cfg.actor_rollout_ref.rollout.name = "vllm" - cfg.actor_rollout_ref.rollout.gpu_memory_utilization = 0.6 + cfg.actor_rollout_ref.rollout.gpu_memory_utilization = 0.4 + # # vLLM sleep mode setting + # if "enable_sleep_mode" in cfg: + # cfg.actor_rollout_ref.rollout.free_cache_engine = cfg.enable_sleep_mode + # else: + # # Default to False if not specified (per user request to ensure it doesn't start) + # cfg.actor_rollout_ref.rollout.free_cache_engine = False # GRPO/GRPO-per-step 特定配置 # grpo_per_step uses the same training framework as grpo, just with different advantage estimation diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py index cc0a9dc6..df534bc4 100644 --- a/opentinker/tests/test_wmc_erc.py +++ b/opentinker/tests/test_wmc_erc.py @@ -145,6 +145,87 @@ def test_clipping_method_clip(self): s_bar=s_bar, sigma=sigma, clipping_method="mask") self.assertEqual(mask[0][0], 0.0) + def test_clip_positive_only(self): + """Verify that clip_positive_only only affects positive advantages.""" + # Two turns to create variance so s_t != s_bar + response_mask = torch.tensor([[1, 0, 1, 0]], dtype=torch.float32) + attention_mask = torch.ones(1, 4, dtype=torch.float32) + # Turn 1 adv = 10, Turn 2 adv = -10 + advantages = torch.tensor([[10.0, 0.0, -10.0, 0.0]]) + # Make turn 1 very confident (p=0.9), turn 2 very uncertain (p=0.1) + old_log_probs = torch.tensor([[np.log(0.9), 0.0, np.log(0.1), 0.0]]) + entropys = torch.tensor([[0.1, 0.0, 2.0, 0.0]]) + + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Force a mask (m_t = 0.0) by setting very tight mu + config = { + "mu_base": 0.0001, + "mu_exp": 0.0001, + "lambda_wm": 0.0, + "enable": True, + "clipping_type": "batch", + "clip_positive_only": True + } + running_stats = {"initialized": False} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # If masked, turn 1 (positive adv) should be 0.0, turn 2 (negative adv) should remain -10.0 + self.assertEqual(batch.batch["advantages"][0, 0].item(), 0.0) + self.assertEqual(batch.batch["advantages"][0, 2].item(), -10.0) + + def test_inverse_sft_mask(self): + """Verify that inverse_sft_mask correctly computes sft_weights on env tokens.""" + # Seq: [Action1, Env1, Action2, Env2] -> masks: response_mask=[1, 0, 1, 0], attention=[1, 1, 1, 1] + response_mask = torch.tensor([[1, 0, 1, 0]], dtype=torch.float32) + attention_mask = torch.ones(1, 4, dtype=torch.float32) + advantages = torch.tensor([[2.0, 2.0, 2.0, 2.0]]) + old_log_probs = torch.tensor([[np.log(0.5)] * 4]) + entropys = torch.tensor([[1.0] * 4]) + + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Force m_t = 0.0 for turn 0, m_t = 1.0 for turn 1 + # Turn 0 S* = 0.5 * (1.0 + log(0.5)) ~ 0.15 + # We can just let compute_dynamic_mask do its thing. + # Actually, let's just test that the weights are assigned correctly based on whatever mask is generated. + config = { + "mu_base": 0.001, # Force masking + "mu_exp": 0.001, + "lambda_wm": 0.0, + "enable": True, + "clipping_type": "batch", + "inverse_sft_mask": True + } + running_stats = {"initialized": False} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + self.assertIn("sft_weights", batch.batch) + sft_weights = batch.batch["sft_weights"] + + # Check shapes + self.assertEqual(sft_weights.shape, advantages.shape) + + # m_t is applied to advantages. Action1 (idx 0), Action2 (idx 2) + # Env1 (idx 1), Env2 (idx 3) + # If Action1 was masked (m_t=0), Env1 should have weight 1.0 + # If Action1 was not masked (m_t=1), Env1 should have weight 0.0 + + adv_action1 = batch.batch["advantages"][0, 0].item() + m_t_1 = adv_action1 / 2.0 # original advantage was 2.0 + weight_env1 = sft_weights[0, 1].item() + self.assertAlmostEqual(weight_env1, 1.0 - m_t_1) + + adv_action2 = batch.batch["advantages"][0, 2].item() + m_t_2 = adv_action2 / 2.0 + weight_env2 = sft_weights[0, 3].item() + self.assertAlmostEqual(weight_env2, 1.0 - m_t_2) + + # Ensure action tokens have 0 sft weight + self.assertEqual(sft_weights[0, 0].item(), 0.0) + self.assertEqual(sft_weights[0, 2].item(), 0.0) if __name__ == "__main__": unittest.main() From e413ef4b03e90eee9cf7a71bec200fee0dbb2d13 Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Fri, 3 Apr 2026 10:27:48 +0000 Subject: [PATCH 4/8] feat: implement RWML (Reinforcement World Model Learning) with separated GRPO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement RWML from arxiv:2602.05842 as a separate GRPO training pass for world model learning, decoupled from policy training. Core: EmbeddingSimilarityReward class computes per-turn binary rewards based on text embedding cosine similarity between model predictions (argmax at observation positions) and actual environment observations. Training flow per step: 1. compute_log_prob extracts predicted_ids (argmax on logits) 2. RWML GRPO: compute embedding similarity rewards → GRPO advantages → update_actor 3. Policy GRPO: compute task rewards → GRPO advantages → update_actor Key files: - world_model_rl.py: EmbeddingSimilarityReward, decode_per_turn_texts, compute_rwml_turn_rewards - dp_actor.py: return_predicted_ids in forward pass - http_training_server.py: separate RWML GRPO update step - generic_agent_loop.py: store per-turn observation texts Co-Authored-By: Claude Opus 4.6 (1M context) --- .../verl/trainer/ppo/world_model_rl.py | 237 ++++++++++++++++++ .../client_config/alfworld_wmc_erc_param.yaml | 14 ++ opentinker/server/generic_agent_loop.py | 5 + opentinker/server/http_training_server.py | 81 ++++++ progress.md | 86 +++++++ verl | 2 +- 6 files changed, 424 insertions(+), 1 deletion(-) create mode 100644 opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py create mode 100644 progress.md diff --git a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py new file mode 100644 index 00000000..2b09d985 --- /dev/null +++ b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py @@ -0,0 +1,237 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +RWML: Reinforcement World Model Learning (arxiv:2602.05842). + +Trains the LLM's world-modeling ability by rewarding accurate next-state +predictions using text-level embedding similarity with an external model. + +For each action turn t in a multi-turn trajectory: + - predicted_obs: model's argmax-decoded tokens at observation positions + - actual_obs: ground-truth observation tokens from the environment + - d(pred, actual) = 1 - cos(E(pred), E(actual)) [external embedding model] + - r^WM = 1.0 if d < tau_d else 0.0 [binary reward] + +The rewards are added to per-turn turn_scores and flow through GRPO per-step +advantage computation. +""" + +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) + + +class EmbeddingSimilarityReward: + """Loads an external text embedding model and computes RWML rewards. + + Uses a HuggingFace sentence-transformers or compatible model to encode + text strings into embeddings, then measures cosine similarity. + """ + + def __init__(self, model_name_or_path: str, device: str = "cpu"): + from sentence_transformers import SentenceTransformer + + self.model = SentenceTransformer(model_name_or_path, device=device) + self.device = device + + @torch.no_grad() + def encode(self, texts: List[str]) -> torch.Tensor: + """Encode texts into L2-normalized embeddings. + + Args: + texts: List of text strings. + + Returns: + Tensor of shape (N, embed_dim), L2-normalized. + """ + if not texts: + return torch.empty(0) + embeddings = self.model.encode( + texts, convert_to_tensor=True, show_progress_bar=False + ) + return F.normalize(embeddings, p=2, dim=1) + + def compute_similarity( + self, texts_a: List[str], texts_b: List[str] + ) -> List[float]: + """Compute pairwise cosine similarity between text pairs. + + Args: + texts_a: First list of texts. + texts_b: Second list of texts (same length). + + Returns: + List of cosine similarity values in [-1, 1]. + """ + assert len(texts_a) == len(texts_b) + if not texts_a: + return [] + emb_a = self.encode(texts_a) + emb_b = self.encode(texts_b) + sims = (emb_a * emb_b).sum(dim=1) + return sims.cpu().tolist() + + def compute_reward( + self, predicted: List[str], actual: List[str], tau_d: float = 0.2 + ) -> Tuple[List[float], List[float]]: + """Compute binary RWML rewards per the paper. + + r^WM = 1.0 if d(pred, actual) < tau_d else 0.0 + where d = 1 - cos_sim + + Args: + predicted: Predicted observation texts. + actual: Actual observation texts. + tau_d: Distance threshold (default 0.2 per paper). + + Returns: + (rewards, similarities): Lists of binary rewards and raw similarities. + """ + similarities = self.compute_similarity(predicted, actual) + rewards = [1.0 if (1.0 - sim) < tau_d else 0.0 for sim in similarities] + return rewards, similarities + + +def decode_per_turn_texts( + token_ids: torch.Tensor, + response_mask: torch.Tensor, + attention_mask: torch.Tensor, + tokenizer, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[str]]: + """Decode per-turn observation texts from token IDs. + + For each action turn t (defined by turn_boundaries), observation tokens + are at positions [end_t, start_{t+1}) in the response portion. This + function decodes those token spans to text strings. + + Args: + token_ids: (batch_size, response_length) token IDs (predicted or actual). + response_mask: (batch_size, response_length) 1=action, 0=env/pad. + attention_mask: (batch_size, full_seq_length) 1=real, 0=pad. + tokenizer: Tokenizer for decoding. + turn_boundaries: Per-sample list of (start, end) for action turns. + + Returns: + List of lists of decoded observation text strings per turn per sample. + """ + batch_size = token_ids.shape[0] + seq_len = token_ids.shape[1] + resp_len = response_mask.shape[1] + attn_resp = attention_mask[:, -resp_len:] + env_mask = attn_resp * (1.0 - response_mask.float()) # 1=env token, 0=action/pad + + all_texts = [] + for i in range(batch_size): + boundaries = turn_boundaries[i] + sample_texts = [] + for t, (start, end) in enumerate(boundaries): + # Observation region: [end, next_turn_start) or [end, seq_len) + if t + 1 < len(boundaries): + env_end = boundaries[t + 1][0] + else: + env_end = seq_len + + region_mask = env_mask[i, end:env_end] + region_ids = token_ids[i, end:env_end] + + # Extract valid (non-padding) observation token IDs + valid_positions = region_mask.bool() + if valid_positions.sum() == 0: + sample_texts.append("") + continue + + valid_ids = region_ids[valid_positions].cpu().tolist() + text = tokenizer.decode(valid_ids, skip_special_tokens=True).strip() + sample_texts.append(text) + all_texts.append(sample_texts) + + return all_texts + + +def compute_rwml_turn_rewards( + predicted_observations: List[List[str]], + actual_observations: List[List[str]], + similarity_reward: EmbeddingSimilarityReward, + tau_d: float = 0.2, +) -> Tuple[np.ndarray, Dict[str, float]]: + """Compute per-turn RWML rewards for a batch. + + Flattens all per-turn pairs, computes embedding similarity rewards in one + batch, then reshapes back to per-sample per-turn structure. + + Args: + predicted_observations: Per-sample, per-turn predicted obs texts. + actual_observations: Per-sample, per-turn actual obs texts. + similarity_reward: EmbeddingSimilarityReward instance. + tau_d: Distance threshold for binary reward. + + Returns: + rwml_turn_rewards: np.ndarray(batch_size, dtype=object), each element + is a list of per-turn reward floats. + metrics: Dict with diagnostic metrics. + """ + batch_size = len(predicted_observations) + + # Flatten all (predicted, actual) pairs with valid text + flat_pred = [] + flat_actual = [] + indices = [] # (sample_idx, turn_idx) + for i in range(batch_size): + n_turns = min(len(predicted_observations[i]), len(actual_observations[i])) + for t in range(n_turns): + pred = predicted_observations[i][t] + actual = actual_observations[i][t] + if pred and actual: # skip empty + flat_pred.append(pred) + flat_actual.append(actual) + indices.append((i, t)) + + # Initialize per-turn rewards with 0.0 + rwml_rewards = np.empty(batch_size, dtype=object) + for i in range(batch_size): + n_turns = len(actual_observations[i]) + rwml_rewards[i] = [0.0] * n_turns + + if not flat_pred: + return rwml_rewards, { + "rwml/mean_reward": 0.0, + "rwml/mean_similarity": 0.0, + "rwml/num_valid_pairs": 0, + "rwml/total_turns": sum(len(a) for a in actual_observations), + } + + # Batch compute rewards + rewards, similarities = similarity_reward.compute_reward( + flat_pred, flat_actual, tau_d + ) + + # Scatter back to per-sample per-turn structure + for idx, (i, t) in enumerate(indices): + rwml_rewards[i][t] = rewards[idx] + + metrics = { + "rwml/mean_reward": float(np.mean(rewards)), + "rwml/mean_similarity": float(np.mean(similarities)), + "rwml/num_valid_pairs": len(flat_pred), + "rwml/total_turns": sum(len(a) for a in actual_observations), + "rwml/reward_rate": float(np.mean(rewards)) if rewards else 0.0, + } + return rwml_rewards, metrics diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml index 5566309d..570ff180 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -77,6 +77,20 @@ world_model_coeff: 0.0 world_model_annealing_steps: 0 world_model_annealing_end_factor: 1.0 +# RWML: Reinforcement World Model Learning (arxiv:2602.05842) +# Computes per-turn rewards based on text embedding similarity between the +# model's predicted next observation (argmax at obs token positions) and the +# actual environment observation. Binary reward: 1.0 if (1-cos_sim) < tau_d. +# - enable: master switch +# - embedding_model: HuggingFace model for text embeddings +# - tau_d: distance threshold (paper uses 0.2 for ALFWorld) +# - coeff: weight when combining RWML reward with task success reward +rwml: + enable: false + embedding_model: "Alibaba-NLP/gte-large-en-v1.5" + tau_d: 0.2 + coeff: 1.0 + # Interaction configuration interaction: name: alfworld diff --git a/opentinker/server/generic_agent_loop.py b/opentinker/server/generic_agent_loop.py index bc56696d..43b1c5ab 100755 --- a/opentinker/server/generic_agent_loop.py +++ b/opentinker/server/generic_agent_loop.py @@ -638,6 +638,11 @@ async def _handle_interacting_state( if reward is not None: agent_data.turn_scores.append(reward) + # Store per-turn observation text for RWML reward computation + if "turn_observations" not in agent_data.extra_fields: + agent_data.extra_fields["turn_observations"] = [] + agent_data.extra_fields["turn_observations"].append(observation) + # Store environment info under a SINGLE key to ensure consistent structure # across all samples (avoids DataProto.concat assertion errors when different # samples return different info keys) diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index b52641e3..540ddb5b 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -740,6 +740,19 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]: if self.wm_coeff > 0: logger.info(f"World model SFT loss enabled with coeff={self.wm_coeff}") + # RWML: Reinforcement World Model Learning (per-turn embedding similarity reward) + rwml_cfg = OmegaConf.select(self.config, "rwml", default=None) + self.rwml_enabled = rwml_cfg is not None and rwml_cfg.get("enable", False) + if self.rwml_enabled: + from opentinker.backend_patch.verl.trainer.ppo.world_model_rl import EmbeddingSimilarityReward + self.rwml_reward_fn = EmbeddingSimilarityReward( + model_name_or_path=rwml_cfg.embedding_model, + device="cuda:0" if torch.cuda.is_available() else "cpu", + ) + self.rwml_tau_d = rwml_cfg.get("tau_d", 0.2) + self.rwml_coeff = rwml_cfg.get("coeff", 1.0) + logger.info(f"RWML enabled: model={rwml_cfg.embedding_model}, tau_d={self.rwml_tau_d}, coeff={self.rwml_coeff}") + self.is_initialized = True logger.info("Workers initialized successfully") return {"status": "success", "message": "Workers initialized"} @@ -1081,6 +1094,10 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info("=" * 80) # ===== DEBUG LOGGING END ===== + # Pass RWML flag so actor returns predicted token IDs + if self.rwml_enabled: + batch.meta_info["rwml_enabled"] = True + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) logger.info( f"DEBUG: old_log_prob keys: {list(old_log_prob.batch.keys())}" @@ -1126,6 +1143,70 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: metrics.update(calculate_debug_metrics(batch)) + # 6.5 RWML: Separate world model GRPO update (before policy training) + if self.rwml_enabled and "predicted_ids" in batch.batch: + with marked_timer("rwml_update", timing_raw, color="magenta"): + from opentinker.backend_patch.verl.trainer.ppo.world_model_rl import ( + decode_per_turn_texts, compute_rwml_turn_rewards, + ) + from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, compute_grpo_per_step_advantage, + ) + + predicted_ids = batch.batch.pop("predicted_ids") + response_mask = batch.batch["response_mask"] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + turn_boundaries = compute_turn_boundaries(response_mask) + + # Decode predicted and actual observation texts per turn + predicted_obs = decode_per_turn_texts( + predicted_ids, response_mask, attention_mask, + self.tokenizer, turn_boundaries, + ) + actual_obs = decode_per_turn_texts( + responses, response_mask, attention_mask, + self.tokenizer, turn_boundaries, + ) + + # Compute RWML rewards (binary, per turn) + rwml_rewards, rwml_metrics = compute_rwml_turn_rewards( + predicted_obs, actual_obs, self.rwml_reward_fn, self.rwml_tau_d + ) + metrics.update(rwml_metrics) + logger.info( + f"[RWML] mean_sim={rwml_metrics.get('rwml/mean_similarity', 0):.4f}, " + f"mean_reward={rwml_metrics.get('rwml/mean_reward', 0):.4f}, " + f"valid_pairs={rwml_metrics.get('rwml/num_valid_pairs', 0)}" + ) + + # Compute RWML advantages via GRPO per-step (separate from policy) + norm_adv = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + rwml_advantages, _ = compute_grpo_per_step_advantage( + token_level_rewards=torch.zeros_like(response_mask.float()), + response_mask=response_mask, + index=batch.non_tensor_batch["uid"], + turn_scores=rwml_rewards, + gamma=self.config.algorithm.gamma, + norm_adv_by_std_in_grpo=norm_adv, + ) + + # Run separate RWML GRPO actor update + batch.batch["advantages"] = rwml_advantages + batch.meta_info["multi_turn"] = ( + self.config.actor_rollout_ref.rollout.multi_turn.enable + ) + rwml_actor_output = self.actor_rollout_wg.update_actor(batch) + rwml_actor_metrics = reduce_metrics( + rwml_actor_output.meta_info["metrics"] + ) + metrics.update( + {f"rwml/{k}": v for k, v in rwml_actor_metrics.items()} + ) + + # Remove RWML advantages so policy training computes its own + del batch.batch["advantages"] + # 7. Compute ref_log_prob if needed if self.use_reference_policy: with marked_timer("ref_log_prob", timing_raw, color="olive"): diff --git a/progress.md b/progress.md new file mode 100644 index 00000000..0b6b1ee3 --- /dev/null +++ b/progress.md @@ -0,0 +1,86 @@ +# World Model Learning — Implementation Progress + +## Method 1: WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution) +**Status: Complete** +- File: `opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py` +- Uses the LLM's prediction entropy at env token positions as a World Model uncertainty signal +- Dynamically gates policy gradient updates to prevent entropy collapse + +## Method 2: World Model SFT Loss +**Status: Complete** +- File: `opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py` +- Auxiliary SFT loss on observation tokens (next-token prediction) +- Gated by `world_model_coeff` in actor config + +## Method 3: RWML — Reinforcement World Model Learning (arxiv:2602.05842) +**Status: Complete** + +### What it does +Per-turn RL reward based on text-level embedding similarity between model's predicted +next observation and actual environment observation, using an external embedding model. + +``` +d(pred, actual) = 1 - cos(E(pred), E(actual)) [external embedding model] +r^WM = 1.0 if d < tau_d else 0.0 [binary reward, paper default tau_d=0.2] +``` + +### How predictions are obtained +During `compute_log_prob`, logits at observation token positions are argmax-decoded +to get the model's predicted token IDs. These are decoded to text and compared with +actual observation tokens using an external sentence-transformers embedding model. + +### Files modified +- [x] `opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py` — Core RWML module + - `EmbeddingSimilarityReward`: loads embedding model, computes cosine sim + binary reward + - `decode_per_turn_texts()`: extracts observation texts from token IDs per turn + - `compute_rwml_turn_rewards()`: batch reward computation with metrics +- [x] `verl/verl/workers/actor/dp_actor.py` — Predicted token ID extraction + - `_forward_micro_batch`: `return_predicted_ids` → argmax on logits + - `compute_log_prob`: returns `(log_probs, entropys, predicted_ids)` +- [x] `verl/verl/workers/fsdp_workers.py` — DataProto wrapping + - Includes `predicted_ids` in returned DataProto when RWML is enabled +- [x] `opentinker/server/generic_agent_loop.py` — Observation text storage + - Stores per-turn observation texts in `extra_fields["turn_observations"]` +- [x] `opentinker/server/http_training_server.py` — RWML integration + - Initializes embedding model at startup + - Computes RWML rewards from predicted_ids after compute_log_prob + - Adds RWML rewards to turn_scores before advantage computation +- [x] `opentinker/client/client_config/alfworld_wmc_erc_param.yaml` — Config + +### Configuration +```yaml +rwml: + enable: false + embedding_model: "Alibaba-NLP/gte-large-en-v1.5" + tau_d: 0.2 + coeff: 1.0 +``` + +### Data flow (separated from policy training) +``` +Rollout → turn_observations stored in extra_fields + ↓ +compute_log_prob (rwml_enabled=True) + → argmax on logits → predicted_ids + ↓ +RWML reward computation: + → decode predicted obs texts from predicted_ids + → decode actual obs texts from response tokens + → cos_sim(E(predicted), E(actual)) via embedding model + → binary reward: 1 if (1-sim) < tau_d + ↓ +RWML GRPO update (SEPARATE from policy): + → compute_grpo_per_step_advantage(turn_scores=rwml_rewards) + → update_actor(advantages=rwml_advantages) ← world model GRPO + ↓ +Policy GRPO update (normal): + → compute_advantage(turn_scores=task_rewards) + → update_actor(advantages=policy_advantages) ← policy GRPO +``` + +### Note on previous hidden-state approach +The initial WM-RL implementation (hidden-state cosine similarity as auxiliary loss) +was superseded by this RWML implementation which correctly follows the paper: +- Text-level (not hidden-state) embedding similarity +- External embedding model (not training model's own representations) +- Per-turn RL reward (not differentiable auxiliary loss) diff --git a/verl b/verl index 4bf4bd32..e2423e73 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 4bf4bd32d049be648867ade8c72ee3f5c27ebfcf +Subproject commit e2423e73eb5f09e1bdb091a605cd6385088502b3 From 419dce3cf3ac1426cefa588154f301c55c8d4a43 Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Fri, 3 Apr 2026 13:49:13 +0000 Subject: [PATCH 5/8] fix: forward RWML config from client to server The rwml config block was defined in the YAML but never extracted and sent to the training server by the client, unlike wmc_erc which has explicit extraction logic. Without this, OmegaConf.select(self.config, "rwml") always returned None on the server side. Co-Authored-By: Claude Opus 4.6 (1M context) --- opentinker/client/utils/http_training_client.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 36fb7d37..2f349940 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -648,6 +648,16 @@ def set_config(self, args: DictConfig, env=None): ) print(f"[ServiceClient] Passing WMC-ERC config to server: {wmc_erc_dict}") + # Pass RWML config to server if present + rwml_cfg = getattr(args, "rwml", None) + if rwml_cfg is not None: + rwml_dict = OmegaConf.to_container(rwml_cfg, resolve=True) + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create({"rwml": rwml_dict}), + ) + print(f"[ServiceClient] Passing RWML config to server: {rwml_dict}") + # Optional world model SFT coefficient for joint PPO + WM training. if hasattr(args, "world_model_loss") and args.world_model_loss: world_model_loss_cfg = OmegaConf.to_container(args.world_model_loss, resolve=True) From 28de904eca53747afedeb75c9e039a03823ed68e Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Fri, 3 Apr 2026 15:00:52 +0000 Subject: [PATCH 6/8] fix: pass trust_remote_code=True to SentenceTransformer Alibaba-NLP/gte-large-en-v1.5 uses custom code that requires explicit trust_remote_code=True to load. Co-Authored-By: Claude Opus 4.6 (1M context) --- opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py index 2b09d985..19d3d4ed 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py +++ b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py @@ -48,7 +48,7 @@ class EmbeddingSimilarityReward: def __init__(self, model_name_or_path: str, device: str = "cpu"): from sentence_transformers import SentenceTransformer - self.model = SentenceTransformer(model_name_or_path, device=device) + self.model = SentenceTransformer(model_name_or_path, device=device, trust_remote_code=True) self.device = device @torch.no_grad() From fb8f555c5b0d194edadf83dad7c70710bb1d076f Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Fri, 3 Apr 2026 15:07:30 +0000 Subject: [PATCH 7/8] fix: load embedding model in offline mode for air-gapped environments Set HF_HUB_OFFLINE=1 during SentenceTransformer init so it uses the local cache without trying to connect to huggingface.co. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../backend_patch/verl/trainer/ppo/world_model_rl.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py index 19d3d4ed..63ec3eac 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py +++ b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py @@ -46,9 +46,19 @@ class EmbeddingSimilarityReward: """ def __init__(self, model_name_or_path: str, device: str = "cpu"): + import os from sentence_transformers import SentenceTransformer - self.model = SentenceTransformer(model_name_or_path, device=device, trust_remote_code=True) + # Use local cache without network access (offline-friendly) + prev_offline = os.environ.get("HF_HUB_OFFLINE") + os.environ["HF_HUB_OFFLINE"] = "1" + try: + self.model = SentenceTransformer(model_name_or_path, device=device, trust_remote_code=True) + finally: + if prev_offline is None: + os.environ.pop("HF_HUB_OFFLINE", None) + else: + os.environ["HF_HUB_OFFLINE"] = prev_offline self.device = device @torch.no_grad() From 6a00afc7ddf894a00a45cc6ea0556341d2578abf Mon Sep 17 00:00:00 2001 From: PolarisDane <2488721971@qq.com> Date: Fri, 3 Apr 2026 15:11:23 +0000 Subject: [PATCH 8/8] fix: resolve embedding model from HF cache to avoid network access Instead of passing the model ID to SentenceTransformer (which tries to reach huggingface.co even when cached), resolve the local snapshot path from the HF hub cache first. Falls back to the original path if not found in cache or if it's already a local directory. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../verl/trainer/ppo/world_model_rl.py | 44 ++++++++++++++----- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py index 63ec3eac..2fc2f5f9 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py +++ b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py @@ -49,18 +49,42 @@ def __init__(self, model_name_or_path: str, device: str = "cpu"): import os from sentence_transformers import SentenceTransformer - # Use local cache without network access (offline-friendly) - prev_offline = os.environ.get("HF_HUB_OFFLINE") - os.environ["HF_HUB_OFFLINE"] = "1" - try: - self.model = SentenceTransformer(model_name_or_path, device=device, trust_remote_code=True) - finally: - if prev_offline is None: - os.environ.pop("HF_HUB_OFFLINE", None) - else: - os.environ["HF_HUB_OFFLINE"] = prev_offline + local_path = self._resolve_local_path(model_name_or_path) + self.model = SentenceTransformer(local_path, device=device, trust_remote_code=True) self.device = device + @staticmethod + def _resolve_local_path(model_name_or_path: str) -> str: + """Resolve a HF model ID to a local cache path if available. + + If model_name_or_path is already a local directory, return it as-is. + Otherwise, look up the HF hub cache for a cached snapshot and return + its path to avoid any network access. + """ + import os + + if os.path.isdir(model_name_or_path): + return model_name_or_path + + # Try to resolve from HF cache: ~/.cache/huggingface/hub/models--{org}--{name}/snapshots/{hash} + try: + from huggingface_hub import scan_cache_dir + + cache_info = scan_cache_dir() + for repo in cache_info.repos: + if repo.repo_id == model_name_or_path: + # Pick the most recent revision + revisions = sorted(repo.revisions, key=lambda r: r.last_modified, reverse=True) + if revisions: + local = str(revisions[0].snapshot_path) + print(f"[RWML] Resolved {model_name_or_path} to local cache: {local}") + return local + except Exception: + pass + + # Fallback: return as-is and let SentenceTransformer handle it + return model_name_or_path + @torch.no_grad() def encode(self, texts: List[str]) -> torch.Tensor: """Encode texts into L2-normalized embeddings.