diff --git a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py index 2ea6182a..daf01e34 100644 --- a/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py +++ b/opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py @@ -82,6 +82,8 @@ def compute_s_star( if count > 0: p_k = torch.exp(log_p) s_token = p_k * (H + log_p) + # Handle potential NaN due to 0 * -inf when p_k is 0 and log_p is -inf + s_token = torch.nan_to_num(s_token, nan=0.0) s_t = (s_token * mask).sum() / count else: s_t = torch.tensor(0.0, device=device) @@ -155,60 +157,67 @@ def compute_h_wm( def compute_dynamic_mask( s_star_per_sample: List[List[torch.Tensor]], h_wm_per_sample: List[List[torch.Tensor]], - mu_base: float = 1.0, - lambda_wm: float = 1.0, + mu_base: float, + mu_exp: float, + eta_wm: float, + lambda_wm: float, + s_bar: float, + sigma: float, + clipping_method: str = "mask", ) -> List[List[float]]: - """Compute per-turn dynamic entropy clipping mask. - - m_t = 1 if |S_*^t - S_bar| <= mu_base * (1 + lambda * H_WM^t) * sigma - 0 otherwise + """Compute per-turn dynamic entropy clipping mask or coefficient. Logic: - - WM confident (H_WM→0): threshold tightens → overconfident policy blocked - (prevents overfitting in well-understood regions) - - WM uncertain (H_WM large): threshold widens → exploration encouraged - (allows agent to gather data in unknown regions to train the WM) - - H_WM is detached — it only acts as a scalar gate, never participates in - policy gradient backpropagation. + - WM uncertainty signal: f(H_WM) = eta_wm * exp(-lambda_wm * H_WM) + - Threshold: threshold = mu * f(H_WM) * sigma + - Masking: m_t = 0.0 if violation, 1.0 otherwise + - Clipping: m_t = threshold / deviation if violation, 1.0 otherwise (PPO-style) Args: s_star_per_sample: per-sample, per-turn S_* tensors h_wm_per_sample: per-sample, per-turn H_WM tensors - mu_base: base clipping coefficient - lambda_wm: WM uncertainty weight + mu_base: clipping coefficient for collapsing side + mu_exp: clipping coefficient for exploration side + eta_wm: base multiplier for WM uncertainty signal + lambda_wm: exponential decay factor for WM uncertainty + s_bar: mean of S_* (batch or global) + sigma: std of S_* (batch or global) + clipping_method: "mask" or "clip" Returns: - List of lists of floats (0.0 or 1.0), one mask per turn per sample. + List of lists of floats, one mask/coeff per turn per sample. """ - # Flatten all S_* for batch statistics - all_s = [] - for turns in s_star_per_sample: - for s in turns: - all_s.append(s.detach()) - - if len(all_s) == 0: - return [[] for _ in s_star_per_sample] - - all_s_tensor = torch.stack(all_s) - s_bar = all_s_tensor.mean() - - # Guard for single-element: std is 0, threshold = mu_base * (1 + lambda * h_wm) * 0 - # → everything would be masked. Use 1.0 as default sigma for single element. - if len(all_s) <= 1: - sigma = torch.tensor(1.0, device=all_s_tensor.device) - else: - sigma = all_s_tensor.std(unbiased=False) + 1e-8 - mask_per_sample = [] for i in range(len(s_star_per_sample)): masks = [] for t in range(len(s_star_per_sample[i])): - s_t = s_star_per_sample[i][t].detach() - h_t = h_wm_per_sample[i][t].detach() - - threshold = mu_base * (1.0 + lambda_wm * h_t) * sigma - m_t = 1.0 if torch.abs(s_t - s_bar) <= threshold else 0.0 + s_t = s_star_per_sample[i][t].detach().item() + h_t = h_wm_per_sample[i][t].detach().item() + + # WM uncertainty signal: f(H_WM) = eta_wm * exp(-lambda_wm * H_WM) + h_factor = eta_wm * np.exp(-lambda_wm * h_t) + + # Asymmetric threshold calculation + if s_t > s_bar: + # Collapsing side + threshold = mu_base * h_factor * sigma + diff = s_t - s_bar + if clipping_method == "mask": + m_t = 1.0 if diff <= threshold else 0.0 + else: # PPO-style clipping + # If diff > threshold, we scale the advantage by threshold/diff + # such that the effective update is capped at threshold + u_t = min(1.0, diff / (threshold + 1e-8)) + m_t = 1.0 / (1.0 + 0.5 * u_t) + else: + # Exploration side + threshold = mu_exp * h_factor * sigma + diff = s_bar - s_t + if clipping_method == "mask": + m_t = 1.0 if diff <= threshold else 0.0 + else: # PPO-style clipping + m_t = 1.0 + masks.append(m_t) mask_per_sample.append(masks) @@ -219,35 +228,31 @@ def apply_wmc_erc( batch, entropys: torch.Tensor, wmc_erc_config, + running_stats: Dict[str, float], ) -> Tuple[object, Dict[str, float]]: """Apply WMC-ERC dynamic entropy clipping to batch advantages. - Pipeline: - 1. Compute turn boundaries from response_mask - 2. Compute S_* (policy blind confidence) per turn - 3. Compute H_WM (world model uncertainty) per turn from env token entropys - 4. Compute dynamic mask m_t per turn - 5. Apply mask to advantages: A_masked = A * m_t (broadcast to tokens) - 6. Return metrics for logging - Args: - batch: DataProto or compatible object with batch dict containing - advantages, response_mask, old_log_probs, attention_mask - entropys: (batch_size, response_length) stored before pop in train_step - wmc_erc_config: OmegaConf DictConfig or dict with mu_base, lambda_wm, enable + batch: DataProto or compatible object + entropys: (batch_size, response_length) + wmc_erc_config: OmegaConf DictConfig or dict + running_stats: Dictionary for global running statistics Returns: - (batch, metrics) where batch has masked advantages and metrics dict + (batch, metrics) """ enable = wmc_erc_config.get("enable", True) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'enable', True) if not enable: return batch, {} + clipping_type = wmc_erc_config.get("clipping_type", "batch") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_type', "batch") + clipping_method = wmc_erc_config.get("clipping_method", "mask") if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clipping_method', "mask") + clip_positive_only = wmc_erc_config.get("clip_positive_only", False) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'clip_positive_only', False) + inverse_sft_mask = wmc_erc_config.get("inverse_sft_mask", False) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'inverse_sft_mask', False) + response_mask = batch.batch["response_mask"] old_log_probs = batch.batch["old_log_probs"] advantages = batch.batch["advantages"] - - # Compute attention mask for response region response_length = advantages.shape[1] attention_mask = batch.batch["attention_mask"] attention_mask_response = attention_mask[:, -response_length:] @@ -255,41 +260,118 @@ def apply_wmc_erc( # 1. Turn boundaries turn_boundaries = compute_turn_boundaries(response_mask) - # 2. S_* per turn + # 2. Compute S_* and H_WM per turn s_star = compute_s_star(old_log_probs, entropys, response_mask, turn_boundaries) - - # 3. H_WM per turn h_wm = compute_h_wm(entropys, response_mask, attention_mask_response, turn_boundaries) - # 4. Dynamic mask + # Calculate batch statistics + all_s = [s.item() for turns in s_star for s in turns] + all_h = [h.item() for turns in h_wm for h in turns] + + if not all_s: + return batch, {} + + batch_s_bar = np.mean(all_s) + batch_s_std = np.std(all_s) + 1e-8 + batch_h_bar = np.mean(all_h) + 1e-8 + + # Update global statistics + momentum = wmc_erc_config.get("momentum", 0.9) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'momentum', 0.9) + if len(running_stats.keys()) == 0: + running_stats["s_bar"] = batch_s_bar + running_stats["s_std"] = batch_s_std + running_stats["h_bar"] = batch_h_bar + else: + running_stats["s_bar"] = (1 - momentum) * batch_s_bar + momentum * running_stats["s_bar"] + running_stats["s_std"] = (1 - momentum) * batch_s_std + momentum * running_stats["s_std"] + running_stats["h_bar"] = (1 - momentum) * batch_h_bar + momentum * running_stats["h_bar"] + + # Select statistics for masking + if clipping_type == "global": + use_s_bar = running_stats["s_bar"] + use_s_std = running_stats["s_std"] + use_h_bar = running_stats["h_bar"] + else: + use_s_bar = batch_s_bar + use_s_std = batch_s_std + use_h_bar = batch_h_bar + + # 4. Dynamic mask/clip mu_base = float(wmc_erc_config.get("mu_base", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_base', 1.0)) + mu_exp = float(wmc_erc_config.get("mu_exp", 2.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'mu_exp', 2.0)) + eta_wm = float(wmc_erc_config.get("eta_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'eta_wm', 1.0)) lambda_wm = float(wmc_erc_config.get("lambda_wm", 1.0) if hasattr(wmc_erc_config, 'get') else getattr(wmc_erc_config, 'lambda_wm', 1.0)) - mask = compute_dynamic_mask(s_star, h_wm, mu_base, lambda_wm) - - # 5. Apply mask to advantages (in-place) + + mask = compute_dynamic_mask( + s_star, h_wm, mu_base, mu_exp, eta_wm, lambda_wm, + s_bar=use_s_bar, + sigma=use_s_std, + clipping_method=clipping_method + ) + + # 5. Apply mask/coeff to advantages batch_size = advantages.shape[0] + + if inverse_sft_mask: + sft_weights = torch.zeros_like(advantages) + env_mask = attention_mask_response * (1.0 - response_mask) + for i in range(batch_size): for t, (start, end) in enumerate(turn_boundaries[i]): if t < len(mask[i]): - advantages[i, start:end] *= mask[i][t] + m_t = mask[i][t] + + if inverse_sft_mask: + # Env tokens after this turn: [end, next_turn_start) or [end, seq_len) + if t + 1 < len(turn_boundaries[i]): + env_end = turn_boundaries[i][t + 1][0] + else: + env_end = response_length + + sft_weight = 1.0 - m_t + region_mask = env_mask[i, end:env_end] + sft_weights[i, end:env_end] = region_mask * sft_weight + + if m_t < 1.0: + if clip_positive_only: + # Only apply scaling where advantages > 0 + turn_adv = advantages[i, start:end] + advantages[i, start:end] = torch.where(turn_adv > 0, turn_adv * m_t, turn_adv) + else: + advantages[i, start:end] *= m_t batch.batch["advantages"] = advantages + + if inverse_sft_mask: + batch.batch["sft_weights"] = sft_weights # 6. Metrics - all_s = [s.item() for turns in s_star for s in turns] - all_h = [h.item() for turns in h_wm for h in turns] all_m = [m for turns in mask for m in turns] + + num_collapsing_violated = 0 + num_exploration_violated = 0 + for i in range(len(s_star)): + for t in range(len(s_star[i])): + if mask[i][t] < 1.0: + if s_star[i][t].item() > use_s_bar: + num_collapsing_violated += 1 + else: + num_exploration_violated += 1 - # WM NLL (monitoring only — not in backward pass for this prototype) env_mask = attention_mask_response * (1.0 - response_mask) env_count = env_mask.sum() wm_nll = (-(old_log_probs * env_mask).sum() / (env_count + 1e-8)).item() if env_count > 0 else 0.0 metrics = { - "wmc_erc/s_star_mean": float(np.mean(all_s)) if all_s else 0.0, - "wmc_erc/s_star_std": float(np.std(all_s)) if all_s else 0.0, - "wmc_erc/h_wm_mean": float(np.mean(all_h)) if all_h else 0.0, + "wmc_erc/batch_s_bar": float(batch_s_bar), + "wmc_erc/batch_s_std": float(batch_s_std), + "wmc_erc/batch_h_bar": float(batch_h_bar), + "wmc_erc/running_s_bar": float(running_stats["s_bar"]), + "wmc_erc/running_s_std": float(running_stats["s_std"]), + "wmc_erc/running_h_bar": float(running_stats["h_bar"]), "wmc_erc/mask_ratio": float(np.mean(all_m)) if all_m else 1.0, - "wmc_erc/num_masked_turns": sum(1 for m in all_m if m == 0.0), + "wmc_erc/num_violated_turns": sum(1 for m in all_m if m < 1.0), + "wmc_erc/num_collapsing_violated": num_collapsing_violated, + "wmc_erc/num_exploration_violated": num_exploration_violated, "wmc_erc/total_turns": len(all_m), "wmc_erc/wm_nll": wm_nll, } diff --git a/opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py b/opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py new file mode 100644 index 00000000..d5fea5c4 --- /dev/null +++ b/opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py @@ -0,0 +1,50 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +World Model SFT Loss — trains the policy to also predict observation tokens. + +Joint loss = ppo_loss(action tokens) + wm_coeff * sft_loss(observation tokens) + +Implementation: + The WM SFT loss is computed inside dp_actor.update_policy() (verl modification). + It is gated by config.world_model_coeff > 0 AND observation_mask being present + in the batch. + + observation_mask is computed in http_training_server.py before update_actor: + obs_mask = attention_mask[:, -resp_len:] & ~response_mask + + To enable, set in your config yaml: + actor_rollout_ref: + actor: + world_model_coeff: 0.1 +""" + +import torch + + +def compute_observation_mask(batch) -> torch.Tensor: + """Compute observation_mask from attention_mask and response_mask. + + observation tokens = real tokens in the response portion that are NOT + action (LLM-generated) tokens. + + Args: + batch: DataProto with batch["attention_mask"] and batch["response_mask"] + + Returns: + observation_mask: (batch_size, response_length) float tensor + """ + resp_len = batch.batch["response_mask"].shape[1] + attn_response = batch.batch["attention_mask"][:, -resp_len:] + return (attn_response.bool() & ~batch.batch["response_mask"].bool()).float() \ No newline at end of file diff --git a/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py new file mode 100644 index 00000000..2fc2f5f9 --- /dev/null +++ b/opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py @@ -0,0 +1,271 @@ +# Copyright 2025 OpenTinker +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +RWML: Reinforcement World Model Learning (arxiv:2602.05842). + +Trains the LLM's world-modeling ability by rewarding accurate next-state +predictions using text-level embedding similarity with an external model. + +For each action turn t in a multi-turn trajectory: + - predicted_obs: model's argmax-decoded tokens at observation positions + - actual_obs: ground-truth observation tokens from the environment + - d(pred, actual) = 1 - cos(E(pred), E(actual)) [external embedding model] + - r^WM = 1.0 if d < tau_d else 0.0 [binary reward] + +The rewards are added to per-turn turn_scores and flow through GRPO per-step +advantage computation. +""" + +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, +) + + +class EmbeddingSimilarityReward: + """Loads an external text embedding model and computes RWML rewards. + + Uses a HuggingFace sentence-transformers or compatible model to encode + text strings into embeddings, then measures cosine similarity. + """ + + def __init__(self, model_name_or_path: str, device: str = "cpu"): + import os + from sentence_transformers import SentenceTransformer + + local_path = self._resolve_local_path(model_name_or_path) + self.model = SentenceTransformer(local_path, device=device, trust_remote_code=True) + self.device = device + + @staticmethod + def _resolve_local_path(model_name_or_path: str) -> str: + """Resolve a HF model ID to a local cache path if available. + + If model_name_or_path is already a local directory, return it as-is. + Otherwise, look up the HF hub cache for a cached snapshot and return + its path to avoid any network access. + """ + import os + + if os.path.isdir(model_name_or_path): + return model_name_or_path + + # Try to resolve from HF cache: ~/.cache/huggingface/hub/models--{org}--{name}/snapshots/{hash} + try: + from huggingface_hub import scan_cache_dir + + cache_info = scan_cache_dir() + for repo in cache_info.repos: + if repo.repo_id == model_name_or_path: + # Pick the most recent revision + revisions = sorted(repo.revisions, key=lambda r: r.last_modified, reverse=True) + if revisions: + local = str(revisions[0].snapshot_path) + print(f"[RWML] Resolved {model_name_or_path} to local cache: {local}") + return local + except Exception: + pass + + # Fallback: return as-is and let SentenceTransformer handle it + return model_name_or_path + + @torch.no_grad() + def encode(self, texts: List[str]) -> torch.Tensor: + """Encode texts into L2-normalized embeddings. + + Args: + texts: List of text strings. + + Returns: + Tensor of shape (N, embed_dim), L2-normalized. + """ + if not texts: + return torch.empty(0) + embeddings = self.model.encode( + texts, convert_to_tensor=True, show_progress_bar=False + ) + return F.normalize(embeddings, p=2, dim=1) + + def compute_similarity( + self, texts_a: List[str], texts_b: List[str] + ) -> List[float]: + """Compute pairwise cosine similarity between text pairs. + + Args: + texts_a: First list of texts. + texts_b: Second list of texts (same length). + + Returns: + List of cosine similarity values in [-1, 1]. + """ + assert len(texts_a) == len(texts_b) + if not texts_a: + return [] + emb_a = self.encode(texts_a) + emb_b = self.encode(texts_b) + sims = (emb_a * emb_b).sum(dim=1) + return sims.cpu().tolist() + + def compute_reward( + self, predicted: List[str], actual: List[str], tau_d: float = 0.2 + ) -> Tuple[List[float], List[float]]: + """Compute binary RWML rewards per the paper. + + r^WM = 1.0 if d(pred, actual) < tau_d else 0.0 + where d = 1 - cos_sim + + Args: + predicted: Predicted observation texts. + actual: Actual observation texts. + tau_d: Distance threshold (default 0.2 per paper). + + Returns: + (rewards, similarities): Lists of binary rewards and raw similarities. + """ + similarities = self.compute_similarity(predicted, actual) + rewards = [1.0 if (1.0 - sim) < tau_d else 0.0 for sim in similarities] + return rewards, similarities + + +def decode_per_turn_texts( + token_ids: torch.Tensor, + response_mask: torch.Tensor, + attention_mask: torch.Tensor, + tokenizer, + turn_boundaries: List[List[Tuple[int, int]]], +) -> List[List[str]]: + """Decode per-turn observation texts from token IDs. + + For each action turn t (defined by turn_boundaries), observation tokens + are at positions [end_t, start_{t+1}) in the response portion. This + function decodes those token spans to text strings. + + Args: + token_ids: (batch_size, response_length) token IDs (predicted or actual). + response_mask: (batch_size, response_length) 1=action, 0=env/pad. + attention_mask: (batch_size, full_seq_length) 1=real, 0=pad. + tokenizer: Tokenizer for decoding. + turn_boundaries: Per-sample list of (start, end) for action turns. + + Returns: + List of lists of decoded observation text strings per turn per sample. + """ + batch_size = token_ids.shape[0] + seq_len = token_ids.shape[1] + resp_len = response_mask.shape[1] + attn_resp = attention_mask[:, -resp_len:] + env_mask = attn_resp * (1.0 - response_mask.float()) # 1=env token, 0=action/pad + + all_texts = [] + for i in range(batch_size): + boundaries = turn_boundaries[i] + sample_texts = [] + for t, (start, end) in enumerate(boundaries): + # Observation region: [end, next_turn_start) or [end, seq_len) + if t + 1 < len(boundaries): + env_end = boundaries[t + 1][0] + else: + env_end = seq_len + + region_mask = env_mask[i, end:env_end] + region_ids = token_ids[i, end:env_end] + + # Extract valid (non-padding) observation token IDs + valid_positions = region_mask.bool() + if valid_positions.sum() == 0: + sample_texts.append("") + continue + + valid_ids = region_ids[valid_positions].cpu().tolist() + text = tokenizer.decode(valid_ids, skip_special_tokens=True).strip() + sample_texts.append(text) + all_texts.append(sample_texts) + + return all_texts + + +def compute_rwml_turn_rewards( + predicted_observations: List[List[str]], + actual_observations: List[List[str]], + similarity_reward: EmbeddingSimilarityReward, + tau_d: float = 0.2, +) -> Tuple[np.ndarray, Dict[str, float]]: + """Compute per-turn RWML rewards for a batch. + + Flattens all per-turn pairs, computes embedding similarity rewards in one + batch, then reshapes back to per-sample per-turn structure. + + Args: + predicted_observations: Per-sample, per-turn predicted obs texts. + actual_observations: Per-sample, per-turn actual obs texts. + similarity_reward: EmbeddingSimilarityReward instance. + tau_d: Distance threshold for binary reward. + + Returns: + rwml_turn_rewards: np.ndarray(batch_size, dtype=object), each element + is a list of per-turn reward floats. + metrics: Dict with diagnostic metrics. + """ + batch_size = len(predicted_observations) + + # Flatten all (predicted, actual) pairs with valid text + flat_pred = [] + flat_actual = [] + indices = [] # (sample_idx, turn_idx) + for i in range(batch_size): + n_turns = min(len(predicted_observations[i]), len(actual_observations[i])) + for t in range(n_turns): + pred = predicted_observations[i][t] + actual = actual_observations[i][t] + if pred and actual: # skip empty + flat_pred.append(pred) + flat_actual.append(actual) + indices.append((i, t)) + + # Initialize per-turn rewards with 0.0 + rwml_rewards = np.empty(batch_size, dtype=object) + for i in range(batch_size): + n_turns = len(actual_observations[i]) + rwml_rewards[i] = [0.0] * n_turns + + if not flat_pred: + return rwml_rewards, { + "rwml/mean_reward": 0.0, + "rwml/mean_similarity": 0.0, + "rwml/num_valid_pairs": 0, + "rwml/total_turns": sum(len(a) for a in actual_observations), + } + + # Batch compute rewards + rewards, similarities = similarity_reward.compute_reward( + flat_pred, flat_actual, tau_d + ) + + # Scatter back to per-sample per-turn structure + for idx, (i, t) in enumerate(indices): + rwml_rewards[i][t] = rewards[idx] + + metrics = { + "rwml/mean_reward": float(np.mean(rewards)), + "rwml/mean_similarity": float(np.mean(similarities)), + "rwml/num_valid_pairs": len(flat_pred), + "rwml/total_turns": sum(len(a) for a in actual_observations), + "rwml/reward_rate": float(np.mean(rewards)) if rewards else 0.0, + } + return rwml_rewards, metrics diff --git a/opentinker/client/alfworld_rl.py b/opentinker/client/alfworld_rl.py index 0810cc5d..df02c10a 100644 --- a/opentinker/client/alfworld_rl.py +++ b/opentinker/client/alfworld_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/android_world_rl.py b/opentinker/client/android_world_rl.py index 122b02f1..8af8e265 100644 --- a/opentinker/client/android_world_rl.py +++ b/opentinker/client/android_world_rl.py @@ -127,6 +127,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/client_config/alfworld_param.yaml b/opentinker/client/client_config/alfworld_param.yaml index 0f183186..381ed652 100644 --- a/opentinker/client/client_config/alfworld_param.yaml +++ b/opentinker/client/client_config/alfworld_param.yaml @@ -3,7 +3,7 @@ # Project settings project_name: opentinker -experiment_name: alfworld_training +experiment_name: alfworld_training_baseline_grpo # Logging logger_backends: ["console", "wandb"] @@ -60,7 +60,7 @@ interaction: env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} # If you run the ALFWorld env server in sharded mode (--shards N), # set env_shards=N. The client will route each instance_id to a stable shard. - env_shards: 32 + env_shards: 8 max_steps: 20 # ALFWorld episodes max steps max_total_steps: 20 # Max environment step calls (controls rollout turns) observation_template: "{observation}" @@ -80,4 +80,4 @@ scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa # GPU settings -num_gpus: 8 +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml b/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml new file mode 100644 index 00000000..304a17e7 --- /dev/null +++ b/opentinker/client/client_config/alfworld_wm_loss_clip_param.yaml @@ -0,0 +1,107 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc_clip_simple_wm_loss_0.001 + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null +# enable_sleep_mode: true + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "grpo" +rollout_n: 8 + +# World Model SFT loss: predict environment observation tokens as auxiliary task +world_model_loss: + world_model_coeff: 0.001 + world_model_annealing_steps: 0 + world_model_annealing_end_factor: 0.0 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - eta_wm: base multiplier for WM uncertainty signal +# - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) +# - clipping_type: "batch" or "global" (global uses running statistics) +# - clipping_method: "mask" (0/1 gating) or "clip" (PPO-style soft clipping) +# - clip_positive_only: if true, only apply clipping to tokens with positive advantages +# - inverse_sft_mask: if true, use (1 - m_t) as weights for SFT loss on subsequent env tokens +# - momentum: momentum for running statistics (only used if clipping_type is "global") +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + mu_exp: 3.0 + eta_wm: 2.0 + lambda_wm: 1.0 + clipping_type: "global" + clipping_method: "clip" + clip_positive_only: false + inverse_sft_mask: false + momentum: 0.9 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 1 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc_wm_loss" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wm_loss_param.yaml b/opentinker/client/client_config/alfworld_wm_loss_param.yaml new file mode 100644 index 00000000..5ca6eb44 --- /dev/null +++ b/opentinker/client/client_config/alfworld_wm_loss_param.yaml @@ -0,0 +1,90 @@ +# ALFWorld Training Configuration +# Use with: python alfworld_rl.py + +# Project settings +project_name: opentinker +experiment_name: alfworld_wm_loss_only_0.0001 + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: 2ed6f8544ac3e30d5c08879166cc10d9c6232448 + +# Model and tokenizer +tokenizer_path: null +enable_sleep_mode: false + +# Training parameters +batch_size: 8 +num_workers: 4 +# Training duration - set ONE of these (num_steps takes precedence if both set) +num_epochs: null # Number of epochs (null = use num_steps) +num_steps: 1000 # Total training steps (null = use num_epochs) +save_freq: 200 +test_freq: 50 # Validation frequency (every N steps) + +# Validation parameters +val_batch_size: 50 # Total validation samples (null = 50) + +# Model parameters +# Generation parameters +temperature: 1 # Lower temperature for more focused responses +top_p: 1 +max_new_tokens: 4096 # TOTAL response budget for entire multi-turn trajectory (NOT per-turn!) +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# RL Algorithm settings (passed to server via scheduler) +# adv_estimator options: +# - "grpo" : Standard GRPO (outcome-only advantage) +# - "grpo_per_step" : Per-step GRPO with return-based advantages (for multi-turn tasks) +# - "gae" : Generalized Advantage Estimation (for PPO, requires critic) +adv_estimator: "grpo" +# rollout_n: number of samples per prompt for GRPO/grpo_per_step +# For PPO (gae), rollout_n is typically 1 +rollout_n: 8 + +# World Model SFT loss: predict environment observation tokens as auxiliary task +world_model_loss: + world_model_coeff: 0.0001 + world_model_annealing_steps: 0 + world_model_annealing_end_factor: 0.0 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + # If you run the ALFWorld env server in sharded mode (--shards N), + # set env_shards=N. The client will route each instance_id to a stable shard. + env_shards: 8 + max_steps: 20 # ALFWorld episodes max steps + max_total_steps: 20 # Max environment step calls (controls rollout turns) + observation_template: "{observation}" + # ALFWorld specific settings + split: train # train, eval_in_distribution, eval_out_of_distribution + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 # Per-turn response limit (optional, null for no limit) + # Weave tracing (optional - runs on SERVER side) + weave_project: "zsqzz/alfworld-env-test" + experiment_name: "wm_loss_only" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa + +# GPU settings +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml index c6cf5c19..570ff180 100644 --- a/opentinker/client/client_config/alfworld_wmc_erc_param.yaml +++ b/opentinker/client/client_config/alfworld_wmc_erc_param.yaml @@ -9,7 +9,7 @@ # Project settings project_name: opentinker -experiment_name: alfworld_wmc_erc +experiment_name: alfworld_wmc_erc_clip # Logging logger_backends: ["console", "wandb"] @@ -23,6 +23,7 @@ wandb_key: null # Model and tokenizer tokenizer_path: null +enable_sleep_mode: false # Training parameters batch_size: 8 @@ -45,17 +46,50 @@ max_prompt_tokens: 2048 algorithm: "agent_loop" # Use per-step advantage for multi-turn credit assignment -adv_estimator: "grpo_per_step" +adv_estimator: "grpo" rollout_n: 8 # WMC-ERC: Dynamic Entropy Clipping -# - mu_base: base clipping coefficient (controls tightness of the gate) -# - lambda_wm: how much WM uncertainty widens the gate (higher = more tolerant in unknown regions) +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - eta_wm: base multiplier for WM uncertainty signal +# - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) +# - clipping_type: "batch" or "global" (global uses running statistics) +# - clipping_method: "mask" (0/1 gating) or "clip" (PPO-style soft clipping) +# - clip_positive_only: if true, only apply clipping to tokens with positive advantages +# - inverse_sft_mask: if true, use (1 - m_t) as weights for SFT loss on subsequent env tokens +# - momentum: momentum for running statistics (only used if clipping_type is "global") # - enable: master switch wmc_erc: enable: true mu_base: 1.0 + mu_exp: 3.0 + eta_wm: 2.0 lambda_wm: 1.0 + clipping_type: "global" + clipping_method: "clip" + clip_positive_only: false + inverse_sft_mask: false + momentum: 0.9 + +# World Model SFT loss: predict environment observation tokens as auxiliary task +world_model_coeff: 0.0 +world_model_annealing_steps: 0 +world_model_annealing_end_factor: 1.0 + +# RWML: Reinforcement World Model Learning (arxiv:2602.05842) +# Computes per-turn rewards based on text embedding similarity between the +# model's predicted next observation (argmax at obs token positions) and the +# actual environment observation. Binary reward: 1.0 if (1-cos_sim) < tau_d. +# - enable: master switch +# - embedding_model: HuggingFace model for text embeddings +# - tau_d: distance threshold (paper uses 0.2 for ALFWorld) +# - coeff: weight when combining RWML reward with task success reward +rwml: + enable: false + embedding_model: "Alibaba-NLP/gte-large-en-v1.5" + tau_d: 0.2 + coeff: 1.0 # Interaction configuration interaction: @@ -83,4 +117,4 @@ scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: null # GPU settings -num_gpus: 8 +num_gpus: 4 diff --git a/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml new file mode 100644 index 00000000..612ddea1 --- /dev/null +++ b/opentinker/client/client_config/alfworld_wmc_erc_param_ppo.yaml @@ -0,0 +1,96 @@ +# ALFWorld Training Configuration with WMC-ERC Dynamic Entropy Clipping +# Use with: python alfworld_rl.py --config-name alfworld_wmc_erc_param +# +# WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution): +# Uses the LLM's prediction entropy at env token positions as a World Model +# uncertainty signal (H_WM) to dynamically gate policy gradient updates. +# Prevents entropy collapse in well-understood regions while permitting +# exploration in uncertain ones. + +# Project settings +project_name: opentinker +experiment_name: alfworld_wmc_erc_ppo + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null +enable_sleep_mode: false + +# Training parameters +batch_size: 8 +num_workers: 4 +num_epochs: null +num_steps: 1000 +save_freq: 200 +test_freq: 50 + +# Validation parameters +val_batch_size: 50 + +# Generation parameters +temperature: 1 +top_p: 1 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# Use per-step advantage for multi-turn credit assignment +adv_estimator: "gae" +rollout_n: 1 + +# WMC-ERC: Dynamic Entropy Clipping +# - mu_base: clipping coefficient for collapsing side (S_* > S_bar) +# - mu_exp: clipping coefficient for exploration side (S_* < S_bar, usually > mu_base) +# - eta_wm: base multiplier for WM uncertainty signal +# - lambda_wm: exponential decay factor for WM uncertainty (higher H_WM -> smaller gate) +# - clipping_type: "batch" or "global" (global uses running statistics) +# - momentum: momentum for running statistics (only used if clipping_type is "global") +# - enable: master switch +wmc_erc: + enable: true + mu_base: 1.0 + mu_exp: 2.0 + eta_wm: 1.0 + lambda_wm: 1.0 + clipping_type: "global" + momentum: 0.9 + +# Interaction configuration +interaction: + name: alfworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 1 + max_steps: 20 + max_total_steps: 20 + observation_template: "{observation}" + split: train + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: null + experiment_name: "alfworld_wmc_erc_ppo" + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 +agent_num_workers: 4 diff --git a/opentinker/client/geo3k_rl.py b/opentinker/client/geo3k_rl.py index f45f5f8f..f51b7a0a 100644 --- a/opentinker/client/geo3k_rl.py +++ b/opentinker/client/geo3k_rl.py @@ -70,6 +70,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/geo3k_tool_rl.py b/opentinker/client/geo3k_tool_rl.py index e812132f..9f08d9a7 100644 --- a/opentinker/client/geo3k_tool_rl.py +++ b/opentinker/client/geo3k_tool_rl.py @@ -93,6 +93,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/gomoku_rl.py b/opentinker/client/gomoku_rl.py index bf42ca8e..cb4b2608 100755 --- a/opentinker/client/gomoku_rl.py +++ b/opentinker/client/gomoku_rl.py @@ -114,6 +114,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) # Set configuration on server diff --git a/opentinker/client/math_rl.py b/opentinker/client/math_rl.py index 5bffb2df..804cc545 100755 --- a/opentinker/client/math_rl.py +++ b/opentinker/client/math_rl.py @@ -74,6 +74,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/math_tool_rl.py b/opentinker/client/math_tool_rl.py index 1c68b06d..3ef1d3ce 100755 --- a/opentinker/client/math_tool_rl.py +++ b/opentinker/client/math_tool_rl.py @@ -73,6 +73,7 @@ def main(args): project_name=args.project_name, experiment_name=args.experiment_name, logger_backends=args.logger_backends, + config=args, ) client.set_config(args, env) diff --git a/opentinker/client/utils/http_training_client.py b/opentinker/client/utils/http_training_client.py index 08f65e9b..2f349940 100644 --- a/opentinker/client/utils/http_training_client.py +++ b/opentinker/client/utils/http_training_client.py @@ -579,6 +579,7 @@ def __init__( project_name: Optional[str] = None, experiment_name: Optional[str] = None, logger_backends: Optional[List[str]] = None, + config: Optional[Any] = None, **client_kwargs, ): self.client = HTTPTrainingClient(server_url, **client_kwargs) @@ -588,11 +589,23 @@ def __init__( if logger_backends and project_name and experiment_name: from verl.utils.tracking import Tracking + # Convert DictConfig to dict if necessary for Tracking + tracking_config = config + if config is not None and not isinstance(config, dict): + from omegaconf import OmegaConf + + tracking_config = OmegaConf.to_container(config, resolve=True) + + # Ensure 'trainer' key exists to avoid KeyError in verl.utils.tracking + if tracking_config is not None: + if "trainer" not in tracking_config: + tracking_config["trainer"] = {} + self.tracker = Tracking( project_name=project_name, experiment_name=experiment_name, default_backend=logger_backends, - config=None, # Can pass config if needed + config=tracking_config, ) logger.info(f"Initialized tracking with backends: {logger_backends}") @@ -624,6 +637,39 @@ def set_config(self, args: DictConfig, env=None): }, } ) + + # Pass WMC-ERC config to server if present + wmc_erc_cfg = getattr(args, "wmc_erc", None) + if wmc_erc_cfg is not None: + wmc_erc_dict = OmegaConf.to_container(wmc_erc_cfg, resolve=True) + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create({"wmc_erc": wmc_erc_dict}), + ) + print(f"[ServiceClient] Passing WMC-ERC config to server: {wmc_erc_dict}") + + # Pass RWML config to server if present + rwml_cfg = getattr(args, "rwml", None) + if rwml_cfg is not None: + rwml_dict = OmegaConf.to_container(rwml_cfg, resolve=True) + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create({"rwml": rwml_dict}), + ) + print(f"[ServiceClient] Passing RWML config to server: {rwml_dict}") + + # Optional world model SFT coefficient for joint PPO + WM training. + if hasattr(args, "world_model_loss") and args.world_model_loss: + world_model_loss_cfg = OmegaConf.to_container(args.world_model_loss, resolve=True) + server_cfg = OmegaConf.merge( + server_cfg, + OmegaConf.create( + {"algorithm": {"world_model_loss": world_model_loss_cfg}} + ), + ) + print( + f"[ServiceClient] Passing world_model_loss config to server: {world_model_loss_cfg}" + ) # Add multi_turn config if present in args if hasattr(args, "multi_turn") and args.multi_turn: @@ -807,10 +853,11 @@ def fit( # Update progress bar if verbose and progress_bar: # Show key metrics in progress bar (filter game/ metrics except win_rate) + # Added wmc_erc/mask_ratio to monitor dynamic entropy clipping display_metrics = { k: v for k, v in last_metrics.items() - if not k.startswith("game/") or k == "game/win_rate" + if not k.startswith("game/") or k == "game/win_rate" or k == "wmc_erc/mask_ratio" } metrics_str = ", ".join( [ diff --git a/opentinker/environment/__init__.py b/opentinker/environment/__init__.py index 6d096c35..7bc88a03 100755 --- a/opentinker/environment/__init__.py +++ b/opentinker/environment/__init__.py @@ -32,14 +32,33 @@ run_game_server, ) -from opentinker.environment.inference_pipeline import ( - InferencePipeline, - InferenceResult, - RemoteEnvironmentClient, - run_inference, - load_samples, - generate_samples, -) +# Lazy import for InferencePipeline to avoid heavy dependencies (like vllm) +# when only the game server is needed. +def __getattr__(name): + if name in [ + "InferencePipeline", + "InferenceResult", + "RemoteEnvironmentClient", + "run_inference", + "load_samples", + "generate_samples", + ]: + from opentinker.environment.inference_pipeline import ( + InferencePipeline, + InferenceResult, + RemoteEnvironmentClient, + run_inference, + load_samples, + generate_samples, + ) + globals()["InferencePipeline"] = InferencePipeline + globals()["InferenceResult"] = InferenceResult + globals()["RemoteEnvironmentClient"] = RemoteEnvironmentClient + globals()["run_inference"] = run_inference + globals()["load_samples"] = load_samples + globals()["generate_samples"] = generate_samples + return globals()[name] + raise AttributeError(f"module {__name__} has no attribute {name}") __all__ = [ # Base diff --git a/opentinker/environment/base_game_server.py b/opentinker/environment/base_game_server.py index 668add10..77c2a59c 100755 --- a/opentinker/environment/base_game_server.py +++ b/opentinker/environment/base_game_server.py @@ -383,42 +383,48 @@ async def health_check(): @app.post("/reset") async def reset(request: ResetRequest): """Reset/create a game instance.""" - instance_id = request.instance_id - job_id = request.job_id - # Extract extra fields for game reset (exclude instance_id and job_id) - reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) + try: + instance_id = request.instance_id + job_id = request.job_id + # Extract extra fields for game reset (exclude instance_id and job_id) + reset_kwargs = request.model_dump(exclude={"instance_id", "job_id"}) - with games_lock: - # Reuse existing game instance if available (avoids re-initialization) - if instance_id in games: - game = games[instance_id] - else: - game = game_class(**game_kwargs) - games[instance_id] = game + with games_lock: + # Reuse existing game instance if available (avoids re-initialization) + if instance_id in games: + game = games[instance_id] + else: + game = game_class(**game_kwargs) + games[instance_id] = game - # Reset the game (this is the slow part) - observation = game.reset(**reset_kwargs) + # Reset the game (this is the slow part) + observation = game.reset(**reset_kwargs) - # Track that this game has started (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.register_game_start(instance_id, job_id) + # Track that this game has started (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.register_game_start(instance_id, job_id) - system_prompt = game.get_system_prompt() - initial_message = game.get_initial_user_message() - full_observation = ( - f"{initial_message}\n\n{observation}" if observation else initial_message - ) + system_prompt = game.get_system_prompt() + initial_message = game.get_initial_user_message() + full_observation = ( + f"{initial_message}\n\n{observation}" if observation else initial_message + ) - response = { - "observation": full_observation, - "system_prompt": system_prompt, - } + response = { + "observation": full_observation, + "system_prompt": system_prompt, + } - state = game.get_state() - if state: - response["game_state"] = state + state = game.get_state() + if state: + response["game_state"] = state - return response + return response + except Exception as e: + import traceback + error_msg = f"Reset failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + raise HTTPException(status_code=500, detail=error_msg) @app.post("/finalize") async def finalize(request: ResetRequest): @@ -433,37 +439,45 @@ async def finalize(request: ResetRequest): @app.post("/step") async def step(request: StepRequest): """Execute a step in the game.""" - instance_id = request.instance_id - job_id = request.job_id - action = request.action + try: + instance_id = request.instance_id + job_id = request.job_id + action = request.action + + if instance_id not in games: + raise HTTPException( + status_code=404, + detail=f"Instance {instance_id} not found. Call /reset first.", + ) - if instance_id not in games: - raise HTTPException( - status_code=404, - detail=f"Instance {instance_id} not found. Call /reset first.", - ) + game = games[instance_id] + result = game.step(action) - game = games[instance_id] - result = game.step(action) + # Record statistics with instance_id for per-game tracking (with job isolation) + stats = multi_stats.get_job_stats(job_id) + stats.record_game_result( + result.info, result.reward, result.done, instance_id, job_id + ) - # Record statistics with instance_id for per-game tracking (with job isolation) - stats = multi_stats.get_job_stats(job_id) - stats.record_game_result( - result.info, result.reward, result.done, instance_id, job_id - ) + # Clean up finished games + if result.done: + with games_lock: + if instance_id in games: + del games[instance_id] - # Clean up finished games - if result.done: - with games_lock: - if instance_id in games: - del games[instance_id] - - return { - "observation": result.observation, - "reward": result.reward, - "done": result.done, - "info": result.info, - } + return { + "observation": result.observation, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + except Exception as e: + import traceback + error_msg = f"Step failed for instance {request.instance_id}: {str(e)}\n{traceback.format_exc()}" + print(f"[Error] {error_msg}") + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=error_msg) @app.get("/stats") async def get_stats(job_id: str = "default"): diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 2a1824d4..dbe9c96e 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -284,7 +284,7 @@ def check_gpu_available(gpu_id: int) -> bool: return True # Fail open # Thresholds for considering a GPU "idle" - MAX_MEMORY_MB = 10 # Allow up to 100 MB (some baseline CUDA overhead) + MAX_MEMORY_MB = 2000 # Allow up to 2 GB (root processes may use ~1 GB) MAX_UTILIZATION = 1000 # Allow up to 5% utilization if memory_used_mb > MAX_MEMORY_MB or utilization_percent > MAX_UTILIZATION: @@ -294,35 +294,8 @@ def check_gpu_available(gpu_id: int) -> bool: ) return False - # Check 2: Look for running processes on this GPU - pmon_result = subprocess.run( - ["nvidia-smi", "pmon", "-c", "1", "-s", "um"], - capture_output=True, - text=True, - timeout=5, - ) - - if pmon_result.returncode == 0: - # Parse pmon output to check for processes on this GPU - # Format: "# gpu pid type sm mem enc dec command" - # " 0 12345 C 50 500 0 0 python" - lines = pmon_result.stdout.strip().split("\n") - for line in lines: - if line.startswith("#") or not line.strip(): - continue - parts = line.split() - if len(parts) >= 2: - try: - gpu_idx = int(parts[0].strip()) - if gpu_idx == gpu_id and parts[1].strip() != "-": - # Found a process on this GPU - pid = parts[1].strip() - logger.warning( - f"GPU {gpu_id}: ⚠️ OCCUPIED - Process {pid} detected via pmon" - ) - return False - except (ValueError, IndexError): - continue + # pmon check disabled — root/system processes cause false positives + # when sharing GPUs across users. Memory threshold check above is sufficient. # All checks passed - GPU is idle logger.debug( @@ -1035,6 +1008,22 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: Popen process object """ env = os.environ.copy() + + # [FIX] Check for broken libcuda.so.1 (size 0) - common issue in some environments + libcuda_path = "/usr/lib/x86_64-linux-gnu/libcuda.so.1" + if os.path.exists(libcuda_path) and os.path.getsize(libcuda_path) == 0: + compat_path = "/usr/local/cuda-12.4/compat" + if os.path.isdir(compat_path): + current_ld_path = env.get("LD_LIBRARY_PATH", "") + env["LD_LIBRARY_PATH"] = ( + f"{compat_path}:{current_ld_path}".strip(":") + if current_ld_path + else compat_path + ) + logger.info( + f"Job {job.job_id}: 🛠 Fixed broken libcuda.so.1 by adding {compat_path} to LD_LIBRARY_PATH" + ) + # Set CUDA_VISIBLE_DEVICES to comma-separated list of GPU IDs env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, job.gpu_ids)) # Pass job_id to agent loop for per-client trace subdirectory isolation diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh new file mode 100755 index 00000000..c11dd0e9 --- /dev/null +++ b/opentinker/scripts/run_alfworld.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# ALFWorld Training Script (Multi-Turn) +# +# This script runs ALFWorld RL training with OpenTinker. +# You need to run these steps in SEPARATE terminals. +# +# For Training (3 terminals): +# Terminal 1: bash run_alfworld.sh scheduler +# Terminal 2: bash run_alfworld.sh env +# Terminal 3: bash run_alfworld.sh client +# +# Prerequisites: +# - pip install alfworld +# - alfworld-download +# - See docs/alfworld_multiturn.md for environment setup + +# ============================================================================= +# Configuration +# ============================================================================= +SCHEDULER_PORT="${SCHEDULER_PORT:-9780}" +ENV_PORT="${ENV_PORT:-1234}" +GPUS="${GPUS:-[0,1,2,3]}" +MODEL_PATH="/inspire/hdd/project/robot-reasoning/xuyue-p-xuyue/ziyu/.cache/huggingface/hub/models--Qwen--Qwen2.5-3B-Instruct" + +# OpenTinker root (relative to this script: opentinker/scripts/run_alfworld.sh) +OPENTINKER_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" + +export NCCL_CUMEM_ENABLE=0 +export VLLM_DISABLE_SLEEP_MODE=1 +export HF_HUB_OFFLINE=1 +export WANDB_MODE=offline + +# Activate conda environment (adjust to your setup if needed) +if [ -f "$HOME/anaconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/anaconda3/etc/profile.d/conda.sh" +elif [ -f "$HOME/miniconda3/etc/profile.d/conda.sh" ]; then + source "$HOME/miniconda3/etc/profile.d/conda.sh" +fi +# conda activate opentinker + +# Change to OpenTinker directory +cd "$OPENTINKER_ROOT" + +# Get current host IP for communication between components +# Use 127.0.0.1 if running everything on the same machine +HOST_IP="${HOST_IP:-127.0.0.1}" + +# ============================================================================= +# Step Selection +# ============================================================================= +case "$1" in + scheduler|1) + echo "========================================" + echo "Step 1: Starting Scheduler on port $SCHEDULER_PORT" + echo "========================================" + bash opentinker/scripts/launch_scheduler.sh \ + --scheduler-port $SCHEDULER_PORT \ + --gpus "$GPUS" + ;; + + env|2) + echo "========================================" + echo "Step 2: Starting ALFWorld Environment Server on port $ENV_PORT" + echo "========================================" + python -m opentinker.environment.alfworld.alfworld_server \ + --port "$ENV_PORT" \ + --max_steps 50 \ + --split train \ + --num_games -1 \ + ;; + + client|3) + echo "========================================" + echo "Step 3: Starting ALFWorld RL Client" + echo "========================================" + python opentinker/client/alfworld_rl.py \ + --config-name alfworld_wm_loss_clip_param \ + tokenizer_path="$MODEL_PATH" \ + batch_size=4 \ + val_batch_size=50 \ + num_steps=1000 \ + save_freq=2000 \ + test_freq=10 \ + scheduler_url="http://$HOST_IP:$SCHEDULER_PORT" \ + interaction.config.env_port="$ENV_PORT" \ + interaction.config.env_host="$HOST_IP" + ;; + + *) + echo "Usage: $0 {scheduler|env|client}" + echo "" + echo "Example (separate terminals):" + echo " Terminal 1: bash $0 scheduler" + echo " Terminal 2: bash $0 env" + echo " Terminal 3: bash $0 client" + exit 1 + ;; +esac diff --git a/opentinker/server/generic_agent_loop.py b/opentinker/server/generic_agent_loop.py index bc56696d..43b1c5ab 100755 --- a/opentinker/server/generic_agent_loop.py +++ b/opentinker/server/generic_agent_loop.py @@ -638,6 +638,11 @@ async def _handle_interacting_state( if reward is not None: agent_data.turn_scores.append(reward) + # Store per-turn observation text for RWML reward computation + if "turn_observations" not in agent_data.extra_fields: + agent_data.extra_fields["turn_observations"] = [] + agent_data.extra_fields["turn_observations"].append(observation) + # Store environment info under a SINGLE key to ensure consistent structure # across all samples (avoids DataProto.concat assertion errors when different # samples return different info keys) diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index c4fe1ce0..540ddb5b 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -85,6 +85,7 @@ ) + # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -620,6 +621,7 @@ def __init__( self.resource_pool_manager = resource_pool_manager self.reward_fn = reward_fn self.val_reward_fn = val_reward_fn + self.running_stats: Dict[str, float] = dict() # Initialize trainer (without dataloader - client provides custom dataloader) from verl.single_controller.ray import RayWorkerGroup @@ -640,6 +642,7 @@ def __init__( # Server state self.is_initialized = False self.global_steps = 0 + self.wm_coeff = 0.0 # Generation config (can be overridden by client) self.generation_config = { @@ -676,8 +679,32 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]: Status dict """ try: + # Reset server state for new training session + self.global_steps = 0 + self.running_stats.clear() + logger.info("Reset server state (global_steps and running_stats) for new worker initialization") + # optimizer needs parameter: total_steps self.trainer.post_init(total_steps) + + # Forward algorithm.world_model_coeff → actor config so dp_actor can read it + if hasattr(self.config.algorithm, "world_model_loss"): + algo_wm_coeff = self.config.algorithm.world_model_loss.get("world_model_coeff", 0.0) + algo_wm_annealing_steps = self.config.algorithm.world_model_loss.get("world_model_annealing_steps", 0) + algo_wm_annealing_end_factor = self.config.algorithm.world_model_loss.get("world_model_annealing_end_factor", 1.0) + else: + algo_wm_coeff = 0 + algo_wm_annealing_steps = 0 + algo_wm_annealing_end_factor = 0 + + if algo_wm_coeff > 0: + from omegaconf import open_dict + with open_dict(self.config): + self.config.actor_rollout_ref.actor.world_model_coeff = algo_wm_coeff + self.config.actor_rollout_ref.actor.world_model_annealing_steps = algo_wm_annealing_steps + self.config.actor_rollout_ref.actor.world_model_annealing_end_factor = algo_wm_annealing_end_factor + logger.info(f"Forwarded world_model_coeff={algo_wm_coeff} (annealing_steps={algo_wm_annealing_steps}) to actor config") + logger.info("Initializing workers...") # Check async rollout mode @@ -708,6 +735,24 @@ def init_workers(self, total_steps: int) -> Dict[str, Any]: if self.async_rollout_mode: self.async_rollout_manager = self.trainer.async_rollout_manager + # World model SFT loss coeff (actual loss computed in dp_actor via config.world_model_coeff) + self.wm_coeff = self.config.actor_rollout_ref.actor.get("world_model_coeff", 0.0) + if self.wm_coeff > 0: + logger.info(f"World model SFT loss enabled with coeff={self.wm_coeff}") + + # RWML: Reinforcement World Model Learning (per-turn embedding similarity reward) + rwml_cfg = OmegaConf.select(self.config, "rwml", default=None) + self.rwml_enabled = rwml_cfg is not None and rwml_cfg.get("enable", False) + if self.rwml_enabled: + from opentinker.backend_patch.verl.trainer.ppo.world_model_rl import EmbeddingSimilarityReward + self.rwml_reward_fn = EmbeddingSimilarityReward( + model_name_or_path=rwml_cfg.embedding_model, + device="cuda:0" if torch.cuda.is_available() else "cpu", + ) + self.rwml_tau_d = rwml_cfg.get("tau_d", 0.2) + self.rwml_coeff = rwml_cfg.get("coeff", 1.0) + logger.info(f"RWML enabled: model={rwml_cfg.embedding_model}, tau_d={self.rwml_tau_d}, coeff={self.rwml_coeff}") + self.is_initialized = True logger.info("Workers initialized successfully") return {"status": "success", "message": "Workers initialized"} @@ -944,9 +989,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # episodes into individual per-turn training samples, the gen_batch_output # batch size is larger than the original batch. We need to expand the # original batch to match using the expansion index. - expansion_index = gen_batch_output.meta_info.pop( - "per_turn_expansion_index", None - ) + expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) if expansion_index is not None: logger.info( f"[Per-turn training] Expanding original batch from {len(batch)} to " @@ -959,7 +1002,6 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict - batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) # Expand non-tensor batch expanded_non_tensor = {} @@ -1052,6 +1094,10 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info("=" * 80) # ===== DEBUG LOGGING END ===== + # Pass RWML flag so actor returns predicted token IDs + if self.rwml_enabled: + batch.meta_info["rwml_enabled"] = True + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) logger.info( f"DEBUG: old_log_prob keys: {list(old_log_prob.batch.keys())}" @@ -1097,6 +1143,70 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: metrics.update(calculate_debug_metrics(batch)) + # 6.5 RWML: Separate world model GRPO update (before policy training) + if self.rwml_enabled and "predicted_ids" in batch.batch: + with marked_timer("rwml_update", timing_raw, color="magenta"): + from opentinker.backend_patch.verl.trainer.ppo.world_model_rl import ( + decode_per_turn_texts, compute_rwml_turn_rewards, + ) + from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( + compute_turn_boundaries, compute_grpo_per_step_advantage, + ) + + predicted_ids = batch.batch.pop("predicted_ids") + response_mask = batch.batch["response_mask"] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + turn_boundaries = compute_turn_boundaries(response_mask) + + # Decode predicted and actual observation texts per turn + predicted_obs = decode_per_turn_texts( + predicted_ids, response_mask, attention_mask, + self.tokenizer, turn_boundaries, + ) + actual_obs = decode_per_turn_texts( + responses, response_mask, attention_mask, + self.tokenizer, turn_boundaries, + ) + + # Compute RWML rewards (binary, per turn) + rwml_rewards, rwml_metrics = compute_rwml_turn_rewards( + predicted_obs, actual_obs, self.rwml_reward_fn, self.rwml_tau_d + ) + metrics.update(rwml_metrics) + logger.info( + f"[RWML] mean_sim={rwml_metrics.get('rwml/mean_similarity', 0):.4f}, " + f"mean_reward={rwml_metrics.get('rwml/mean_reward', 0):.4f}, " + f"valid_pairs={rwml_metrics.get('rwml/num_valid_pairs', 0)}" + ) + + # Compute RWML advantages via GRPO per-step (separate from policy) + norm_adv = self.config.algorithm.get("norm_adv_by_std_in_grpo", True) + rwml_advantages, _ = compute_grpo_per_step_advantage( + token_level_rewards=torch.zeros_like(response_mask.float()), + response_mask=response_mask, + index=batch.non_tensor_batch["uid"], + turn_scores=rwml_rewards, + gamma=self.config.algorithm.gamma, + norm_adv_by_std_in_grpo=norm_adv, + ) + + # Run separate RWML GRPO actor update + batch.batch["advantages"] = rwml_advantages + batch.meta_info["multi_turn"] = ( + self.config.actor_rollout_ref.rollout.multi_turn.enable + ) + rwml_actor_output = self.actor_rollout_wg.update_actor(batch) + rwml_actor_metrics = reduce_metrics( + rwml_actor_output.meta_info["metrics"] + ) + metrics.update( + {f"rwml/{k}": v for k, v in rwml_actor_metrics.items()} + ) + + # Remove RWML advantages so policy training computes its own + del batch.batch["advantages"] + # 7. Compute ref_log_prob if needed if self.use_reference_policy: with marked_timer("ref_log_prob", timing_raw, color="olive"): @@ -1174,15 +1284,16 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: ) batch, wmc_metrics = apply_wmc_erc( - batch, _wmc_erc_entropys, wmc_erc_cfg + batch, _wmc_erc_entropys, wmc_erc_cfg, self.running_stats ) metrics.update(wmc_metrics) logger.info( - f"[WMC-ERC] mask_ratio={wmc_metrics.get('wmc_erc/mask_ratio', 'N/A'):.3f}, " - f"s_star={wmc_metrics.get('wmc_erc/s_star_mean', 'N/A'):.4f}, " - f"h_wm={wmc_metrics.get('wmc_erc/h_wm_mean', 'N/A'):.4f}" + f"[WMC-ERC] mask_ratio={float(wmc_metrics.get('wmc_erc/mask_ratio', float('nan'))):.3f}, " + f"s_star={float(wmc_metrics.get('wmc_erc/s_star_mean', float('nan'))):.4f}, " + f"h_wm={float(wmc_metrics.get('wmc_erc/h_wm_mean', float('nan'))):.4f}" ) + # 10. Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): @@ -1192,6 +1303,14 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: ) metrics.update(critic_output_metrics) + # 10.5 Compute observation_mask for world model loss + if self.wm_coeff > 0: + resp_len = batch.batch["response_mask"].shape[1] + attn_response = batch.batch["attention_mask"][:, -resp_len:] + batch.batch["observation_mask"] = ( + attn_response.bool() & ~batch.batch["response_mask"].bool() + ).float() + # 11. Update actor (check critic warmup) if self.config.trainer.critic_warmup <= self.global_steps: with marked_timer("update_actor", timing_raw, color="red"): @@ -1281,9 +1400,22 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info(f"Training step {self.global_steps} completed successfully") # Free large intermediates and force garbage collection to prevent OOM + # Explicitly delete all potential large objects in local scope del batch, gen_batch, gen_batch_output, reward_tensor - if "_wmc_erc_entropys" in dir(): + if "old_log_prob" in locals(): + del old_log_prob + if "ref_log_prob" in locals(): + del ref_log_prob + if "values" in locals(): + del values + if "gen_baseline_output" in locals(): + del gen_baseline_output + if "gen_baseline_batch" in locals(): + del gen_baseline_batch + + if "_wmc_erc_entropys" in locals(): del _wmc_erc_entropys + gc.collect() torch.cuda.empty_cache() @@ -1564,9 +1696,7 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: # 6. Merge original batch and generated output # Per-turn training expansion: expand batch if gen output is larger - expansion_index = gen_batch_output.meta_info.pop( - "per_turn_expansion_index", None - ) + expansion_index = gen_batch_output.meta_info.pop('per_turn_expansion_index', None) if expansion_index is not None: logger.info( f"[Per-turn training] Validation: Expanding original batch from {len(batch)} to " @@ -1578,7 +1708,6 @@ def validate_step(self, batch: DataProto) -> Dict[str, Any]: elif batch.batch is not None: # Empty TensorDict (all keys were popped) — create new one with expanded size from tensordict import TensorDict - batch.batch = TensorDict({}, batch_size=[len(expansion_index)]) expanded_non_tensor = {} for k, v in batch.non_tensor_batch.items(): @@ -2136,13 +2265,14 @@ def run_fastapi_server(): ) ray.init( namespace=_server_cfg.ray.namespace, - num_gpus=_server_cfg.trainer.n_gpus_per_node, # Explicitly specify number of GPUs + num_gpus=_server_cfg.trainer.n_gpus_per_node, ignore_reinit_error=True, runtime_env={ "env_vars": { "NCCL_CUMEM_ENABLE": "0", "VLLM_DISABLE_SLEEP_MODE": "1", "RAY_memory_usage_threshold": "0.99", + "VLLM_GPU_MEMORY_UTILIZATION": "0.15", }, }, ) @@ -2158,6 +2288,7 @@ def run_fastapi_server(): "NCCL_CUMEM_ENABLE": "0", "VLLM_DISABLE_SLEEP_MODE": "1", "RAY_memory_usage_threshold": "0.99", + "VLLM_GPU_MEMORY_UTILIZATION": "0.15", }, }, ) @@ -2294,4 +2425,4 @@ def run_fastapi_server(): if __name__ == "__main__": # Example usage print("HTTP Training Server") - print("Use launch_server() to start the server") + print("Use launch_server() to start the server") \ No newline at end of file diff --git a/opentinker/server/launch_http_server.py b/opentinker/server/launch_http_server.py index 830af82c..57843350 100644 --- a/opentinker/server/launch_http_server.py +++ b/opentinker/server/launch_http_server.py @@ -64,7 +64,13 @@ def main(cfg): cfg.actor_rollout_ref.rollout.tensor_model_parallel_size = 2 cfg.actor_rollout_ref.rollout.name = "vllm" - cfg.actor_rollout_ref.rollout.gpu_memory_utilization = 0.6 + cfg.actor_rollout_ref.rollout.gpu_memory_utilization = 0.4 + # # vLLM sleep mode setting + # if "enable_sleep_mode" in cfg: + # cfg.actor_rollout_ref.rollout.free_cache_engine = cfg.enable_sleep_mode + # else: + # # Default to False if not specified (per user request to ensure it doesn't start) + # cfg.actor_rollout_ref.rollout.free_cache_engine = False # GRPO/GRPO-per-step 特定配置 # grpo_per_step uses the same training framework as grpo, just with different advantage estimation diff --git a/opentinker/tests/test_wmc_erc.py b/opentinker/tests/test_wmc_erc.py index 3d53a335..df534bc4 100644 --- a/opentinker/tests/test_wmc_erc.py +++ b/opentinker/tests/test_wmc_erc.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Tests for WMC-ERC dynamic entropy clipping module. +Tests for WMC-ERC dynamic entropy clipping module using unittest. -Run with: pytest opentinker/tests/test_wmc_erc.py -v +Run with: python opentinker/tests/test_wmc_erc.py """ +import unittest import numpy as np -import pytest import torch +from unittest.mock import MagicMock from opentinker.backend_patch.verl.trainer.ppo.per_step_core_algos import ( compute_turn_boundaries, @@ -32,12 +33,11 @@ ) -class TestComputeSStar: - """Test S_* (policy blind confidence) computation.""" +class TestWmcErc(unittest.TestCase): + """Test WMC-ERC components and orchestration.""" - def test_single_turn(self): + def test_single_turn_s_star(self): """S_* for a single turn = mean of p_k * (H + log p_k) over action tokens.""" - # 1 sample, 6 positions: 4 action + 2 padding p = torch.tensor([0.8, 0.6, 0.9, 0.7]) old_log_probs = torch.zeros(1, 6) old_log_probs[0, :4] = torch.log(p) @@ -46,214 +46,34 @@ def test_single_turn(self): boundaries = compute_turn_boundaries(response_mask) result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 1 # 1 sample - assert len(result[0]) == 1 # 1 turn + self.assertEqual(len(result), 1) + self.assertEqual(len(result[0]), 1) - # Manual: p_k * (H + log(p_k)) for each token, then mean H = torch.tensor([1.0, 1.5, 0.5, 1.2]) expected = (p * (H + torch.log(p))).mean().item() - assert abs(result[0][0].item() - expected) < 1e-5 - - def test_multi_turn(self): - """Two turns should produce two S_* values.""" - old_log_probs = torch.log( - torch.tensor([[0.8, 0.6, 0.5, 0.5, 0.9, 0.7, 0.5, 0.5]]) - ) - entropys = torch.tensor([[1.0, 1.5, 0.0, 0.0, 0.5, 1.2, 0.0, 0.0]]) - response_mask = torch.tensor( - [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 - ) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result[0]) == 2 # 2 turns - # Turn 0 and Turn 1 should have different values - assert result[0][0].item() != result[0][1].item() - - def test_batch(self): - """Batch of 2 samples.""" - old_log_probs = torch.log( - torch.tensor( - [ - [0.8, 0.6, 0.5, 0.5], - [0.5, 0.9, 0.5, 0.5], - ] - ) - ) - entropys = torch.tensor( - [ - [1.0, 1.5, 0.0, 0.0], - [2.0, 0.3, 0.0, 0.0], - ] - ) - response_mask = torch.tensor( - [ - [1, 1, 0, 0], - [1, 1, 0, 0], - ], - dtype=torch.float32, - ) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 2 - - def test_empty_turns(self): - """Sample with no action tokens should produce empty list.""" - old_log_probs = torch.zeros(1, 4) - entropys = torch.zeros(1, 4) - response_mask = torch.zeros(1, 4, dtype=torch.float32) - boundaries = compute_turn_boundaries(response_mask) - - result = compute_s_star(old_log_probs, entropys, response_mask, boundaries) - assert len(result) == 1 - assert len(result[0]) == 0 - - -class TestComputeHWM: - """Test H_WM (world model uncertainty) computation.""" - - def test_single_turn_with_env_tokens(self): - """H_WM for turn 0 = mean entropy at env positions after turn 0.""" - # Sequence: [action, action, env, env, pad, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.0, 0.0]]) - response_mask = torch.tensor([[1, 1, 0, 0, 0, 0]], dtype=torch.float32) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2)]] # 1 turn: action at [0,2) - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result) == 1 - assert len(result[0]) == 1 - # Env tokens at [2,3] → mean entropy = (3.0 + 4.0) / 2 = 3.5 - assert abs(result[0][0].item() - 3.5) < 1e-5 - - def test_two_turns(self): - """Two turns: H_WM_0 from env between turns, H_WM_1 from env after turn 1.""" - # [act, act, env, env, act, act, env, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 4.0, 0.5, 1.2, 2.0, 0.0]]) - response_mask = torch.tensor( - [[1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float32 - ) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2), (4, 6)]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result[0]) == 2 - # Turn 0 env: positions [2,4) → (3.0+4.0)/2 = 3.5 - assert abs(result[0][0].item() - 3.5) < 1e-5 - # Turn 1 env: positions [6,8) but attn_mask=[1,0] → only pos 6 → 2.0 - assert abs(result[0][1].item() - 2.0) < 1e-5 - - def test_no_env_after_last_turn(self): - """Last turn has no env tokens → H_WM = 0.""" - # [act, act, env, act, act, pad] - entropys = torch.tensor([[1.0, 1.5, 3.0, 0.5, 1.2, 0.0]]) - response_mask = torch.tensor([[1, 1, 0, 1, 1, 0]], dtype=torch.float32) - attention_mask_response = torch.tensor( - [[1, 1, 1, 1, 1, 0]], dtype=torch.float32 - ) - boundaries = [[(0, 2), (3, 5)]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - # Turn 0 env: positions [2,3) → 3.0 - assert abs(result[0][0].item() - 3.0) < 1e-5 - # Turn 1 env: positions [5,6) but attn_mask=[0] → H_WM = 0 - assert result[0][1].item() == 0.0 - - def test_empty_turns(self): - """No turns → empty H_WM list.""" - entropys = torch.zeros(1, 4) - response_mask = torch.zeros(1, 4, dtype=torch.float32) - attention_mask_response = torch.ones(1, 4, dtype=torch.float32) - boundaries = [[]] - - result = compute_h_wm( - entropys, response_mask, attention_mask_response, boundaries - ) - assert len(result[0]) == 0 - - -class TestComputeDynamicMask: - """Test dynamic entropy clipping mask.""" - - def test_all_pass(self): - """When all S_* are close to mean, all masks = 1.""" - s_star = [[torch.tensor(1.0), torch.tensor(1.1)]] - h_wm = [[torch.tensor(0.5), torch.tensor(0.5)]] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=3.0, lambda_wm=1.0) - assert mask == [[1.0, 1.0]] - - def test_outlier_blocked_low_hwm(self): - """High S_* outlier with low H_WM (known env) → blocked.""" - s_star = [ - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(10.0)], # outlier - ] - h_wm = [ - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], # known env → tight threshold - ] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # The outlier (10.0) should be blocked when threshold is tight - assert mask[3] == [0.0] - # Normal ones should pass - assert mask[0] == [1.0] - - def test_outlier_allowed_high_hwm(self): - """High S_* outlier with high H_WM (unknown env) → allowed.""" - s_star = [ - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(0.5)], - [torch.tensor(10.0)], # outlier - ] - h_wm = [ - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(0.0)], - [torch.tensor(100.0)], # very uncertain → wide threshold - ] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # The outlier should be allowed because H_WM is very high - assert mask[3] == [1.0] - - def test_empty(self): - """Empty input should return empty.""" - mask = compute_dynamic_mask([], [], mu_base=1.0, lambda_wm=1.0) - assert mask == [] - - def test_single_element(self): - """Single S_* value should not produce nan (std guard).""" - s_star = [[torch.tensor(5.0)]] - h_wm = [[torch.tensor(1.0)]] - mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, lambda_wm=1.0) - # Single element: |5.0 - 5.0| = 0 <= threshold → should pass - assert mask == [[1.0]] - - -class TestApplyWmcErc: - """Test full WMC-ERC orchestration.""" - - def _make_batch( - self, advantages, response_mask, old_log_probs, attention_mask - ): - """Create a minimal mock batch with required fields.""" - from unittest.mock import MagicMock - + self.assertAlmostEqual(result[0][0].item(), expected, places=5) + + def test_asymmetric_behavior(self): + """Test that mu_base and mu_exp act differently using compute_dynamic_mask.""" + s_star = [[torch.tensor(15.0)], [torch.tensor(5.0)]] # Mean=10 + h_wm = [[torch.tensor(1.0)]] * 2 + s_bar = 10.0 + sigma = 5.0 + h_bar = 1.0 + + # 1. mu_base=0.1 (block high), mu_exp=10.0 (allow low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=0.1, mu_exp=10.0, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma) + self.assertEqual(mask[0], [0.0]) # High blocked + self.assertEqual(mask[1], [1.0]) # Low allowed + + # 2. mu_base=10.0 (allow high), mu_exp=0.1 (block low) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=10.0, mu_exp=0.1, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma) + self.assertEqual(mask[0], [1.0]) # High allowed + self.assertEqual(mask[1], [0.0]) # Low blocked + + def _make_batch(self, advantages, response_mask, old_log_probs, attention_mask): batch = MagicMock() batch.batch = { "advantages": advantages.clone(), @@ -263,141 +83,149 @@ def _make_batch( } return batch - def test_masking_zeros_advantage(self): - """When a turn is masked, its advantages become zero.""" - # 4 samples, 1 turn each (4 action tokens + 0 env tokens) - # Sample 3 has extremely different S_* pattern + def test_clipping_type_batch(self): + """Verify that 'batch' mode uses current batch statistics.""" response_mask = torch.ones(4, 4, dtype=torch.float32) advantages = torch.ones(4, 4) * 2.0 - # Make sample 3 very overconfident (high p_k, low H) - old_log_probs = torch.tensor( - [ - [np.log(0.3)] * 4, - [np.log(0.3)] * 4, - [np.log(0.3)] * 4, - [np.log(0.99)] * 4, # very high confidence - ] - ) - entropys = torch.tensor( - [ - [2.0, 2.0, 2.0, 2.0], - [2.0, 2.0, 2.0, 2.0], - [2.0, 2.0, 2.0, 2.0], - [0.01, 0.01, 0.01, 0.01], # very low entropy - ] - ) + # Sample 3 is an outlier in this batch + old_log_probs = torch.tensor([[np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.3)]*4, [np.log(0.9)]*4]) + entropys = torch.tensor([[2.0]*4, [2.0]*4, [2.0]*4, [3.0]*4]) attention_mask = torch.ones(4, 4, dtype=torch.float32) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - - # Check metrics exist - assert "wmc_erc/mask_ratio" in metrics - assert "wmc_erc/s_star_mean" in metrics - assert "wmc_erc/h_wm_mean" in metrics - assert "wmc_erc/total_turns" in metrics - assert metrics["wmc_erc/total_turns"] == 4 - - # Verify masking behavior: - # Samples 0-2: S_* = 0.3*(2.0+log(0.3)) ≈ 0.239 (normal) - # Sample 3: S_* = 0.99*(0.01+log(0.99)) ≈ 0 (outlier in opposite direction) - # H_WM = 0 for all (no env tokens) → tight threshold - # Sample 3 should be masked (|S_3 - S_bar| > threshold) - adv = batch.batch["advantages"] - assert metrics["wmc_erc/num_masked_turns"] >= 1 - assert (adv[3] == 0).all(), "Sample 3 (overconfident outlier) should have zero advantages" - assert (adv[:3] == 2.0).all(), "Samples 0-2 (normal) should keep original advantages" - - def test_disabled(self): - """When enable=False, advantages unchanged.""" - response_mask = torch.ones(2, 4, dtype=torch.float32) - advantages = torch.ones(2, 4) * 5.0 - old_log_probs = torch.full((2, 4), np.log(0.5)) - attention_mask = torch.ones(2, 4, dtype=torch.float32) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - entropys = torch.ones(2, 4) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": False} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert (batch.batch["advantages"] == 5.0).all() - assert metrics == {} - - def test_multi_turn_selective_masking(self): - """Multi-turn: only overconfident turns in known env get masked.""" - # 2 samples, 2 turns each: [act, act, env, env, act, act, env, pad] - response_mask = torch.tensor( - [ - [1, 1, 0, 0, 1, 1, 0, 0], - [1, 1, 0, 0, 1, 1, 0, 0], - ], - dtype=torch.float32, - ) - advantages = torch.ones(2, 8) * 3.0 - # Sample 0: normal confidence - # Sample 1: overconfident on both turns - old_log_probs = torch.tensor( - [ - [np.log(0.3), np.log(0.3), 0, 0, np.log(0.3), np.log(0.3), 0, 0], - [np.log(0.99), np.log(0.99), 0, 0, np.log(0.99), np.log(0.99), 0, 0], - ] - ) - entropys = torch.tensor( - [ - [2.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 0.0], - [0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.0], - ] - ) - attention_mask = torch.tensor( - [ - [1, 1, 1, 1, 1, 1, 1, 0], - [1, 1, 1, 1, 1, 1, 1, 0], - ], - dtype=torch.float32, - ) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 1.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert metrics["wmc_erc/total_turns"] == 4 - # With low H_WM on sample 1, its overconfident turns should be blocked - # This is a statistical test — the exact outcome depends on batch stats - - def test_returns_wm_nll_metric(self): - """WM NLL metric should be computed from env token log probs.""" - response_mask = torch.tensor( - [[1, 1, 0, 0, 0, 0]], dtype=torch.float32 - ) - advantages = torch.ones(1, 6) - old_log_probs = torch.tensor( - [[np.log(0.5), np.log(0.5), np.log(0.3), np.log(0.4), 0, 0]] - ) - entropys = torch.tensor([[1.0, 1.0, 2.0, 3.0, 0.0, 0.0]]) - attention_mask = torch.tensor( - [[1, 1, 1, 1, 0, 0]], dtype=torch.float32 - ) - - batch = self._make_batch( - advantages, response_mask, old_log_probs, attention_mask - ) - config = {"mu_base": 5.0, "lambda_wm": 1.0, "enable": True} - - _, metrics = apply_wmc_erc(batch, entropys, config) - assert "wmc_erc/wm_nll" in metrics - # WM NLL = -mean(log_prob at env positions [2,3]) - # = -(log(0.3) + log(0.4)) / 2 - expected_nll = -(np.log(0.3) + np.log(0.4)) / 2.0 - assert abs(metrics["wmc_erc/wm_nll"] - expected_nll) < 1e-4 - + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + config = {"mu_base": 0.1, "mu_exp": 10.0, "lambda_wm": 1.0, "enable": True, "clipping_type": "batch"} + running_stats = {"s_bar": 100.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} # Very different global stats + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # In 'batch' mode, Sample 3 should be blocked based on BATCH mean, not GLOBAL mean + # If it used global mean (100), sample 3 (S*~2.6) would be an exploration outlier and allowed by mu_exp=10 + # But in batch mode (S_bar~0.8), sample 3 is a collapsing outlier and blocked by mu_base=0.1 + self.assertTrue((batch.batch["advantages"][3] == 0).all()) + + def test_clipping_type_global(self): + """Verify that 'global' mode uses running statistics.""" + response_mask = torch.ones(4, 4, dtype=torch.float32) + advantages = torch.ones(4, 4) * 2.0 + old_log_probs = torch.tensor([[np.log(0.3)]*4]*4) + entropys = torch.tensor([[2.0]*4]*4) + # S* for all samples will be ~0.24 + + attention_mask = torch.ones(4, 4, dtype=torch.float32) + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Global s_bar is far away (10.0), so batch S* (0.24) looks like a huge exploration outlier + config = {"mu_base": 10.0, "mu_exp": 0.1, "lambda_wm": 0.0, "enable": True, "clipping_type": "global"} + running_stats = {"s_bar": 10.0, "s_std": 1.0, "h_bar": 1.0, "initialized": True} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # Should be blocked because mu_exp=0.1 is very tight relative to global s_bar + self.assertTrue((batch.batch["advantages"] == 0).all()) + + def test_clipping_method_clip(self): + """Verify that 'clip' method scales advantages instead of zeroing them.""" + # Setup data such that we have a violation + # S* = 15.0, s_bar = 10.0, sigma = 1.0, mu_base = 1.0, h_factor = 1.0 + # diff = 5.0, threshold = 1.0 + # m_t should be 1.0 / 5.0 = 0.2 + + s_star = [[torch.tensor(15.0)]] + h_wm = [[torch.tensor(0.0)]] # lambda_wm=0, eta_wm=1 -> h_factor=1 + s_bar = 10.0 + sigma = 1.0 + + # 1. Test clip + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, mu_exp=1.0, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, clipping_method="clip") + self.assertAlmostEqual(mask[0][0], 0.2, places=5) + + # 2. Test mask (for comparison) + mask = compute_dynamic_mask(s_star, h_wm, mu_base=1.0, mu_exp=1.0, eta_wm=1.0, lambda_wm=0.0, + s_bar=s_bar, sigma=sigma, clipping_method="mask") + self.assertEqual(mask[0][0], 0.0) + + def test_clip_positive_only(self): + """Verify that clip_positive_only only affects positive advantages.""" + # Two turns to create variance so s_t != s_bar + response_mask = torch.tensor([[1, 0, 1, 0]], dtype=torch.float32) + attention_mask = torch.ones(1, 4, dtype=torch.float32) + # Turn 1 adv = 10, Turn 2 adv = -10 + advantages = torch.tensor([[10.0, 0.0, -10.0, 0.0]]) + # Make turn 1 very confident (p=0.9), turn 2 very uncertain (p=0.1) + old_log_probs = torch.tensor([[np.log(0.9), 0.0, np.log(0.1), 0.0]]) + entropys = torch.tensor([[0.1, 0.0, 2.0, 0.0]]) + + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Force a mask (m_t = 0.0) by setting very tight mu + config = { + "mu_base": 0.0001, + "mu_exp": 0.0001, + "lambda_wm": 0.0, + "enable": True, + "clipping_type": "batch", + "clip_positive_only": True + } + running_stats = {"initialized": False} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + # If masked, turn 1 (positive adv) should be 0.0, turn 2 (negative adv) should remain -10.0 + self.assertEqual(batch.batch["advantages"][0, 0].item(), 0.0) + self.assertEqual(batch.batch["advantages"][0, 2].item(), -10.0) + + def test_inverse_sft_mask(self): + """Verify that inverse_sft_mask correctly computes sft_weights on env tokens.""" + # Seq: [Action1, Env1, Action2, Env2] -> masks: response_mask=[1, 0, 1, 0], attention=[1, 1, 1, 1] + response_mask = torch.tensor([[1, 0, 1, 0]], dtype=torch.float32) + attention_mask = torch.ones(1, 4, dtype=torch.float32) + advantages = torch.tensor([[2.0, 2.0, 2.0, 2.0]]) + old_log_probs = torch.tensor([[np.log(0.5)] * 4]) + entropys = torch.tensor([[1.0] * 4]) + + batch = self._make_batch(advantages, response_mask, old_log_probs, attention_mask) + + # Force m_t = 0.0 for turn 0, m_t = 1.0 for turn 1 + # Turn 0 S* = 0.5 * (1.0 + log(0.5)) ~ 0.15 + # We can just let compute_dynamic_mask do its thing. + # Actually, let's just test that the weights are assigned correctly based on whatever mask is generated. + config = { + "mu_base": 0.001, # Force masking + "mu_exp": 0.001, + "lambda_wm": 0.0, + "enable": True, + "clipping_type": "batch", + "inverse_sft_mask": True + } + running_stats = {"initialized": False} + + _, metrics = apply_wmc_erc(batch, entropys, config, running_stats) + + self.assertIn("sft_weights", batch.batch) + sft_weights = batch.batch["sft_weights"] + + # Check shapes + self.assertEqual(sft_weights.shape, advantages.shape) + + # m_t is applied to advantages. Action1 (idx 0), Action2 (idx 2) + # Env1 (idx 1), Env2 (idx 3) + # If Action1 was masked (m_t=0), Env1 should have weight 1.0 + # If Action1 was not masked (m_t=1), Env1 should have weight 0.0 + + adv_action1 = batch.batch["advantages"][0, 0].item() + m_t_1 = adv_action1 / 2.0 # original advantage was 2.0 + weight_env1 = sft_weights[0, 1].item() + self.assertAlmostEqual(weight_env1, 1.0 - m_t_1) + + adv_action2 = batch.batch["advantages"][0, 2].item() + m_t_2 = adv_action2 / 2.0 + weight_env2 = sft_weights[0, 3].item() + self.assertAlmostEqual(weight_env2, 1.0 - m_t_2) + + # Ensure action tokens have 0 sft weight + self.assertEqual(sft_weights[0, 0].item(), 0.0) + self.assertEqual(sft_weights[0, 2].item(), 0.0) if __name__ == "__main__": - pytest.main([__file__, "-v"]) + unittest.main() diff --git a/progress.md b/progress.md new file mode 100644 index 00000000..0b6b1ee3 --- /dev/null +++ b/progress.md @@ -0,0 +1,86 @@ +# World Model Learning — Implementation Progress + +## Method 1: WMC-ERC (World Model-Conditioned Entropy Regularized Co-evolution) +**Status: Complete** +- File: `opentinker/backend_patch/verl/trainer/ppo/wmc_erc.py` +- Uses the LLM's prediction entropy at env token positions as a World Model uncertainty signal +- Dynamically gates policy gradient updates to prevent entropy collapse + +## Method 2: World Model SFT Loss +**Status: Complete** +- File: `opentinker/backend_patch/verl/trainer/ppo/world_model_loss.py` +- Auxiliary SFT loss on observation tokens (next-token prediction) +- Gated by `world_model_coeff` in actor config + +## Method 3: RWML — Reinforcement World Model Learning (arxiv:2602.05842) +**Status: Complete** + +### What it does +Per-turn RL reward based on text-level embedding similarity between model's predicted +next observation and actual environment observation, using an external embedding model. + +``` +d(pred, actual) = 1 - cos(E(pred), E(actual)) [external embedding model] +r^WM = 1.0 if d < tau_d else 0.0 [binary reward, paper default tau_d=0.2] +``` + +### How predictions are obtained +During `compute_log_prob`, logits at observation token positions are argmax-decoded +to get the model's predicted token IDs. These are decoded to text and compared with +actual observation tokens using an external sentence-transformers embedding model. + +### Files modified +- [x] `opentinker/backend_patch/verl/trainer/ppo/world_model_rl.py` — Core RWML module + - `EmbeddingSimilarityReward`: loads embedding model, computes cosine sim + binary reward + - `decode_per_turn_texts()`: extracts observation texts from token IDs per turn + - `compute_rwml_turn_rewards()`: batch reward computation with metrics +- [x] `verl/verl/workers/actor/dp_actor.py` — Predicted token ID extraction + - `_forward_micro_batch`: `return_predicted_ids` → argmax on logits + - `compute_log_prob`: returns `(log_probs, entropys, predicted_ids)` +- [x] `verl/verl/workers/fsdp_workers.py` — DataProto wrapping + - Includes `predicted_ids` in returned DataProto when RWML is enabled +- [x] `opentinker/server/generic_agent_loop.py` — Observation text storage + - Stores per-turn observation texts in `extra_fields["turn_observations"]` +- [x] `opentinker/server/http_training_server.py` — RWML integration + - Initializes embedding model at startup + - Computes RWML rewards from predicted_ids after compute_log_prob + - Adds RWML rewards to turn_scores before advantage computation +- [x] `opentinker/client/client_config/alfworld_wmc_erc_param.yaml` — Config + +### Configuration +```yaml +rwml: + enable: false + embedding_model: "Alibaba-NLP/gte-large-en-v1.5" + tau_d: 0.2 + coeff: 1.0 +``` + +### Data flow (separated from policy training) +``` +Rollout → turn_observations stored in extra_fields + ↓ +compute_log_prob (rwml_enabled=True) + → argmax on logits → predicted_ids + ↓ +RWML reward computation: + → decode predicted obs texts from predicted_ids + → decode actual obs texts from response tokens + → cos_sim(E(predicted), E(actual)) via embedding model + → binary reward: 1 if (1-sim) < tau_d + ↓ +RWML GRPO update (SEPARATE from policy): + → compute_grpo_per_step_advantage(turn_scores=rwml_rewards) + → update_actor(advantages=rwml_advantages) ← world model GRPO + ↓ +Policy GRPO update (normal): + → compute_advantage(turn_scores=task_rewards) + → update_actor(advantages=policy_advantages) ← policy GRPO +``` + +### Note on previous hidden-state approach +The initial WM-RL implementation (hidden-state cosine similarity as auxiliary loss) +was superseded by this RWML implementation which correctly follows the paper: +- Text-level (not hidden-state) embedding similarity +- External embedding model (not training model's own representations) +- Per-turn RL reward (not differentiable auxiliary loss) diff --git a/verl b/verl index 4bf4bd32..e2423e73 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 4bf4bd32d049be648867ade8c72ee3f5c27ebfcf +Subproject commit e2423e73eb5f09e1bdb091a605cd6385088502b3