def _compute_wm_beta_token(
    self, batch: "DataProto", wm_config: dict
) -> "tuple[DataProto, dict]":
    """Compute a per-token entropy coefficient from world-model (WM) uncertainty.

    The coefficient is derived from ``old_log_probs`` at the trainer level, so
    the same ``beta_token`` is attached to the batch before the actor update
    instead of being recomputed inside the actor.

    For every sample, and for every turn ``t`` that owns observation tokens::

        u_t      = mean(-old_log_prob) over the non-zero obs tokens of turn t
        baseline = mean(u_t) over the turns of that sample
        d_t      = (u_t - baseline) / baseline      # 0 when baseline <= 1e-6

        dense  : beta_t = per_turn_budget * (1 + fluctuation * tanh(gamma * d_t))
        sparse : beta_t = per_turn_budget if d_t > sparsity_threshold else 0

    ``beta_t`` is clamped to be non-negative and written to the action tokens
    of turn ``t`` (same turn id, not an observation token, inside the response
    mask).  Every other position keeps the neutral value ``per_turn_budget``.

    Args:
        batch: Object whose ``batch`` mapping holds ``old_log_probs`` and a 2-D
            ``response_mask``, and whose ``non_tensor_batch`` mapping may hold
            per-sample ``observation_mask`` (0/1 vectors) and ``turn_ids``.
        wm_config: Dict or config object with the optional keys
            ``per_turn_budget`` (0.002), ``gamma`` (1.0), ``fluctuation`` (0.5),
            ``sparse_mode`` (False) and ``sparsity_threshold`` (0.3).  Legacy
            ``beta_0`` / ``beta_1`` keys are ignored.

    Returns:
        ``(batch, metrics)`` where ``batch.batch["beta_token"]`` has the shape
        of ``response_mask`` and ``metrics`` maps ``wm_entropy/*`` names to
        plain Python numbers.
    """
    import numpy as np
    import torch

    def get_config_value(cfg, key, default):
        """Read ``key`` from a dict, an attribute-style config or a ``.get`` mapping."""
        if isinstance(cfg, dict):
            return cfg.get(key, default)
        if hasattr(cfg, key):
            return getattr(cfg, key, default)
        if hasattr(cfg, "get"):
            return cfg.get(key, default)
        return default

    metrics = {}

    per_turn_budget = get_config_value(wm_config, "per_turn_budget", 0.002)
    gamma = get_config_value(wm_config, "gamma", 1.0)  # sensitivity
    fluctuation = get_config_value(wm_config, "fluctuation", 0.5)  # +/- range
    sparse_mode = get_config_value(wm_config, "sparse_mode", False)
    sparsity_threshold = get_config_value(wm_config, "sparsity_threshold", 0.3)

    print(f"[WM Dynamic Entropy] DEBUG: wm_config = {wm_config}")
    print(
        f"[WM Dynamic Entropy] DEBUG: per_turn_budget={per_turn_budget}, fluctuation={fluctuation}, sparse_mode={sparse_mode}, sparsity_threshold={sparsity_threshold}"
    )

    old_log_probs = batch.batch["old_log_probs"]
    response_mask = batch.batch["response_mask"]
    batch_size, response_length = response_mask.shape
    device = response_mask.device

    observation_mask_raw = batch.non_tensor_batch.get("observation_mask", None)
    turn_ids_raw = batch.non_tensor_batch.get("turn_ids", None)

    print(
        f"[WM Dynamic Entropy] DEBUG: non_tensor_batch keys = {list(batch.non_tensor_batch.keys())}"
    )
    print(
        f"[WM Dynamic Entropy] DEBUG: observation_mask_raw is None = {observation_mask_raw is None}"
    )
    print(
        f"[WM Dynamic Entropy] DEBUG: turn_ids_raw is None = {turn_ids_raw is None}"
    )

    # Neutral default: tokens of turns without observations (e.g. the first
    # turn) have no measurable WM uncertainty and keep the plain budget.
    beta_token = torch.ones(batch_size, response_length, device=device) * per_turn_budget

    if observation_mask_raw is None or turn_ids_raw is None:
        print(
            "[WM Dynamic Entropy] WARNING: observation_mask or turn_ids not found, using per_turn_budget as neutral"
        )
        batch.batch["beta_token"] = beta_token
        metrics["wm_entropy/beta_token_mean"] = per_turn_budget
        metrics["wm_entropy/per_turn_budget"] = per_turn_budget
        return batch, metrics

    def padded_row(values, dtype):
        """Truncate ``values`` to ``response_length`` and zero-pad the tail."""
        row = torch.tensor(values[:response_length], device=device, dtype=dtype)
        missing = response_length - row.numel()
        if missing > 0:
            row = torch.cat(
                [row, torch.zeros(missing, device=device, dtype=dtype)]
            )
        return row

    obs_mask_tensor = torch.zeros(
        batch_size, response_length, device=device, dtype=torch.bool
    )
    turn_ids_tensor = torch.zeros(
        batch_size, response_length, device=device, dtype=torch.long
    )

    for b in range(batch_size):
        obs_mask = observation_mask_raw[b]
        if obs_mask is not None and len(obs_mask) > 0:
            obs_arr = np.array(obs_mask[:response_length])
            # A 0/1 vector is expected; values above 1 suggest a list of indices.
            if len(obs_arr) > 0 and obs_arr.max() > 1:
                print(
                    f"[WM Dynamic Entropy] WARNING: observation_mask looks like index list (max={obs_arr.max()}), skipping sample {b}"
                )
                continue
            obs_mask_tensor[b] = padded_row(obs_arr, torch.float32).bool()

        t_ids = turn_ids_raw[b]
        if t_ids is not None and len(t_ids) > 0:
            turn_ids_tensor[b] = padded_row(t_ids, torch.long)

    if len(observation_mask_raw) > 0:
        print(
            f"[WM Dynamic Entropy] DEBUG: First sample observation_mask length = {len(observation_mask_raw[0]) if observation_mask_raw[0] is not None else 'None'}"
        )
        print(
            f"[WM Dynamic Entropy] DEBUG: First sample turn_ids length = {len(turn_ids_raw[0]) if turn_ids_raw[0] is not None else 'None'}"
        )

    all_uncertainties = []
    all_betas = []
    valid_samples = 0

    for b in range(batch_size):
        sample_obs_mask = obs_mask_tensor[b]
        sample_turn_ids = turn_ids_tensor[b]
        sample_log_prob = old_log_probs[b]

        num_turns = (
            int(sample_turn_ids.max().item()) + 1
            if sample_turn_ids.numel() > 0
            else 1
        )

        # (turn id, uncertainty) for every turn that has usable obs tokens.
        turn_uncertainties = []
        for t in range(num_turns):
            turn_obs_mask = (sample_turn_ids == t) & sample_obs_mask
            if not turn_obs_mask.any():
                continue
            obs_log_probs = sample_log_prob[turn_obs_mask]
            # Exact zeros are treated as "no log-prob available" and dropped.
            valid_log_probs = obs_log_probs[obs_log_probs != 0]
            if valid_log_probs.numel() > 0:
                uncertainty = -valid_log_probs.mean().item()
                turn_uncertainties.append((t, uncertainty))
                all_uncertainties.append(uncertainty)

        if not turn_uncertainties:
            continue
        valid_samples += 1

        # Dynamic baseline: mean uncertainty over this sample's turns.
        sample_baseline = float(np.mean([u for _, u in turn_uncertainties]))
        sample_action_mask = (~sample_obs_mask) & response_mask[b].bool()

        for t, u in turn_uncertainties:
            if sample_baseline > 1e-6:
                normalized_diff = (u - sample_baseline) / sample_baseline
            else:
                normalized_diff = 0.0

            if sparse_mode:
                # Only turns clearly above the baseline receive any budget.
                beta_t = per_turn_budget if normalized_diff > sparsity_threshold else 0.0
            else:
                # tanh bounds the multiplier to (1 - fluctuation, 1 + fluctuation).
                multiplier = 1.0 + fluctuation * np.tanh(gamma * normalized_diff)
                beta_t = per_turn_budget * multiplier

            beta_t = max(0.0, float(beta_t))
            all_betas.append(beta_t)

            # Only the action tokens of this turn get the coefficient.
            beta_token[b, (sample_turn_ids == t) & sample_action_mask] = beta_t

    batch.batch["beta_token"] = beta_token

    metrics["wm_entropy/per_turn_budget"] = per_turn_budget
    metrics["wm_entropy/fluctuation"] = fluctuation
    metrics["wm_entropy/sparse_mode"] = 1.0 if sparse_mode else 0.0
    metrics["wm_entropy/sparsity_threshold"] = (
        sparsity_threshold if sparse_mode else 0.0
    )

    turns_with_entropy = sum(1 for beta in all_betas if beta > 0)
    metrics["wm_entropy/turns_with_entropy"] = turns_with_entropy
    metrics["wm_entropy/entropy_sparsity"] = turns_with_entropy / max(
        len(all_betas), 1
    )

    if all_uncertainties:
        metrics["wm_entropy/uncertainty_mean"] = float(np.mean(all_uncertainties))
        metrics["wm_entropy/uncertainty_std"] = (
            float(np.std(all_uncertainties)) if len(all_uncertainties) > 1 else 0.0
        )
        metrics["wm_entropy/uncertainty_min"] = float(np.min(all_uncertainties))
        metrics["wm_entropy/uncertainty_max"] = float(np.max(all_uncertainties))
    if all_betas:
        beta_mean = float(np.mean(all_betas))
        metrics["wm_entropy/beta_mean"] = beta_mean
        metrics["wm_entropy/beta_std"] = (
            float(np.std(all_betas)) if len(all_betas) > 1 else 0.0
        )
        metrics["wm_entropy/beta_min"] = float(np.min(all_betas))
        metrics["wm_entropy/beta_max"] = float(np.max(all_betas))
        metrics["wm_entropy/beta_mean_ratio"] = (
            beta_mean / per_turn_budget if per_turn_budget > 0 else 1.0
        )
    metrics["wm_entropy/valid_samples"] = valid_samples
    metrics["wm_entropy/total_samples"] = batch_size

    action_betas = beta_token[response_mask.bool() & (~obs_mask_tensor)]
    if action_betas.numel() > 0:
        metrics["wm_entropy/beta_token_mean"] = action_betas.mean().item()
        # The sample std of a single element is NaN; report 0.0 instead.
        metrics["wm_entropy/beta_token_std"] = (
            action_betas.std().item() if action_betas.numel() > 1 else 0.0
        )

    return batch, metrics
# Check if turn-level temperature is enabled + # When enabled, rollout_log_probs already contains log μ_T (behavior policy log prob) + # which should be used directly as old_log_probs for correct IS correction + turn_temp_config = ( + self.config.actor_rollout_ref.rollout.multi_turn.get( + "turn_level_temperature", {} + ) + ) + turn_temp_enabled = turn_temp_config.get("enabled", False) + has_rollout_logprobs = "rollout_log_probs" in batch.batch + use_rollout_logprobs_as_old = ( + turn_temp_enabled and has_rollout_logprobs + ) + + # CRITICAL: Warn if turn-level temp is enabled but rollout_log_probs not available + if turn_temp_enabled and not has_rollout_logprobs: + print( + "[TurnLevelTemp] WARNING: turn_level_temperature is enabled but rollout_log_probs " + "is not in batch! This means calculate_log_probs=False in rollout config. " + "Set actor_rollout_ref.rollout.calculate_log_probs=True for correct IS correction." + ) + + # Log turn-level temperature statistics if available + has_token_temps = "token_temperatures" in batch.non_tensor_batch + if has_token_temps: + token_temps_raw = batch.non_tensor_batch.get( + "token_temperatures", None + ) + if token_temps_raw is not None: + all_temps = [] + for temps in token_temps_raw: + if temps is not None and len(temps) > 0: + all_temps.extend(temps) + if all_temps: + import numpy as np + + metrics["turn_temp/mean"] = np.mean(all_temps) + metrics["turn_temp/std"] = ( + np.std(all_temps) if len(all_temps) > 1 else 0.0 + ) + metrics["turn_temp/min"] = np.min(all_temps) + metrics["turn_temp/max"] = np.max(all_temps) + # Count tokens with T != 1.0 (adjusted temperatures) + adjusted_count = sum( + 1 for t in all_temps if abs(t - 1.0) > 0.01 + ) + metrics["turn_temp/adjusted_ratio"] = ( + adjusted_count / len(all_temps) + ) + print( + f"[TurnLevelTemp] Batch stats: mean={metrics['turn_temp/mean']:.3f}, " + f"range=[{metrics['turn_temp/min']:.3f}, {metrics['turn_temp/max']:.3f}], " + 
f"adjusted_ratio={metrics['turn_temp/adjusted_ratio']:.2%}" + ) + + # Also log turn-level uncertainties if available + if "turn_uncertainties" in batch.non_tensor_batch: + turn_u_raw = batch.non_tensor_batch.get( + "turn_uncertainties", None + ) + if turn_u_raw is not None: + all_u = [] + for u_list in turn_u_raw: + if u_list is not None and len(u_list) > 0: + all_u.extend(u_list) + if all_u: + import numpy as np + + metrics["turn_temp/uncertainty_mean"] = np.mean( + all_u + ) + metrics["turn_temp/uncertainty_std"] = ( + np.std(all_u) if len(all_u) > 1 else 0.0 + ) + + # Always compute log_prob with T=1 to get entropy for entropy bonus + # (entropy should be computed on the learned policy, not the behavior policy) + actor_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + + entropys = actor_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] loss_agg_mode = ( self.config.actor_rollout_ref.actor.loss_agg_mode @@ -1532,8 +1915,43 @@ def fit(self): # ===== DEBUG LOGGING END ===== metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) + actor_log_prob.batch.pop("entropys") + + # CRITICAL FIX: Use correct behavior policy log_prob for IS correction + # When turn-level temperature is enabled: + # - rollout_log_probs = log μ_T = log softmax(z/T) (from sglang/vLLM during sampling) + # - This is the TRUE behavior policy log_prob and should be used as old_log_probs + # - ratio = π_θ_new / μ_T = exp(log π_θ_new - log μ_T) + # When turn-level temperature is NOT enabled (T=1 everywhere): + # - Use actor_log_prob["old_log_probs"] as before (both are log π) + if use_rollout_logprobs_as_old: + print( + "[TurnLevelTemp] Using rollout_log_probs as old_log_probs for correct IS correction" + ) + # rollout_log_probs is already log μ_T from the behavior policy + batch.batch["old_log_probs"] = batch.batch[ + "rollout_log_probs" + ] + + # Log the difference between actor (T=1) and rollout (T≠1) log_probs + 
actor_lp = actor_log_prob.batch["old_log_probs"] + rollout_lp = batch.batch["rollout_log_probs"] + diff = (actor_lp - rollout_lp).abs() + masked_diff = diff * response_masks + if response_masks.sum() > 0: + mean_diff = masked_diff.sum() / response_masks.sum() + max_diff = masked_diff.max() + metrics["turn_temp/logprob_diff_mean"] = ( + mean_diff.item() + ) + metrics["turn_temp/logprob_diff_max"] = max_diff.item() + print( + f"[TurnLevelTemp] log_prob diff (actor vs rollout): " + f"mean={mean_diff.item():.4f}, max={max_diff.item():.4f}" + ) + else: + # Standard case: no temperature variation, actor log_prob = behavior log_prob + batch = batch.union(actor_log_prob) if "rollout_log_probs" in batch.batch.keys(): # TODO: we may want to add diff of probs too. @@ -1618,6 +2036,52 @@ def fit(self): config=self.config.algorithm, ) + # ===================================================================== + # WM-Guided Dynamic Entropy: Compute beta_token at trainer level + # This ensures stable weights across PPO epochs (using lagged θ^-) + # ===================================================================== + actor_config = self.config.actor_rollout_ref.actor + # Handle both dict and dataclass/OmegaConf access patterns + if hasattr(actor_config, "get"): + wm_dynamic_entropy_config = actor_config.get( + "wm_dynamic_entropy", {} + ) + elif hasattr(actor_config, "wm_dynamic_entropy"): + wm_dynamic_entropy_config = actor_config.wm_dynamic_entropy + # Convert to dict if it's a config object + if hasattr(wm_dynamic_entropy_config, "__dict__"): + wm_dynamic_entropy_config = dict( + wm_dynamic_entropy_config + ) + elif hasattr(wm_dynamic_entropy_config, "items"): + wm_dynamic_entropy_config = dict( + wm_dynamic_entropy_config + ) + else: + wm_dynamic_entropy_config = {} + + # Check if enabled + is_enabled = False + if isinstance(wm_dynamic_entropy_config, dict): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + elif hasattr(wm_dynamic_entropy_config, "enabled"): + 
is_enabled = wm_dynamic_entropy_config.enabled + elif hasattr(wm_dynamic_entropy_config, "get"): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + + print( + f"[WM Dynamic Entropy] DEBUG: is_enabled = {is_enabled}, config type = {type(wm_dynamic_entropy_config)}" + ) + + if is_enabled: + with marked_timer( + "wm_beta_token", timing_raw, color="magenta" + ): + batch, wm_metrics = self._compute_wm_beta_token( + batch, wm_dynamic_entropy_config + ) + metrics.update(wm_metrics) + # update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): diff --git a/opentinker/backend_patch/verl/workers/config/__init__.py b/opentinker/backend_patch/verl/workers/config/__init__.py new file mode 100644 index 00000000..1aa65308 --- /dev/null +++ b/opentinker/backend_patch/verl/workers/config/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@dataclass
class TurnLevelTemperatureConfig(BaseConfig):
    """Configuration for turn-level temperature (WM-guided exploration during rollout).

    The fields mirror the ``multi_turn.turn_level_temperature`` section of the
    client YAML configs.  The defaults declared here are more conservative than
    the values used in ``alfworld_param.yaml`` (which sets ``kappa: 0.5``,
    ``max_temperature: 2.0`` and ``use_accurate_uncertainty: false``).
    """

    # Master switch; the feature is off unless explicitly enabled.
    enabled: bool = False
    # Temperature used when the uncertainty is average (per the YAML commentary).
    base_temperature: float = 1.0
    kappa: float = 0.3  # Temperature adjustment range
    # Lower / upper bounds on the per-turn temperature.
    min_temperature: float = 0.5
    max_temperature: float = 1.5
    # EMA decay for uncertainty normalization (per the YAML commentary).
    ema_decay: float = 0.9
    # NOTE(review): presumably toggles importance-sampling correction for the
    # temperature-scaled behaviour policy; the consumer is not visible here.
    enable_is_correction: bool = True
    # Use accurate uncertainty (T=1 re-generation) vs heuristic (scaled log μ_T)
    use_accurate_uncertainty: bool = True
#!/usr/bin/env python3
r"""ALFWorld Inference Script.

Evaluate a trained ALFWorld model through a scheduler-managed vLLM server.

Usage:
    # Start ALFWorld environment server first (in another terminal):
    python -m opentinker.environment.alfworld.alfworld_server --port 8091 --split eval_in_distribution

    # Run inference with scheduler:
    python alfworld_inference.py \
        model_path=/path/to/checkpoint \
        scheduler_url=http://localhost:8089 \
        data_path=/path/to/eval_data.jsonl
"""

import hydra

from utils.http_training_client import InferenceSchedulerClient
from utils.scheduler_client_lifecycle import get_lifecycle_manager
from opentinker.environment.inference_pipeline import run_inference
from opentinker.environment.alfworld import ALFWorldGame
from opentinker.environment.game_stats_client import GameStatsClient


@hydra.main(
    config_path="client_config",
    config_name="alfworld_inference_config.yaml",
    version_base=None,
)
def main(args):
    """Submit an inference job, run the ALFWorld evaluation and print a summary.

    Steps: request a vLLM server from the scheduler, register the job with the
    lifecycle manager, optionally connect to the environment's stats endpoint,
    run ``run_inference`` against the remote server, then report the stats.

    Raises:
        ValueError: If ``args.model_path`` is empty.
    """
    lifecycle = get_lifecycle_manager()
    rule = "=" * 60

    print(rule)
    print("ALFWorld Inference with Scheduler")
    print(rule)

    if not args.model_path:
        raise ValueError("model_path is required")

    # Step 1: ask the scheduler for a vLLM server hosting the checkpoint.
    scheduler_client = InferenceSchedulerClient(
        scheduler_url=args.get("scheduler_url", "http://localhost:8089"),
        api_key=args.get("scheduler_api_key"),
    )

    print(f"\nModel: {args.model_path}")
    print(f"Scheduler: {args.scheduler_url}")
    print(f"Environment: {args.env_endpoint}")
    print(f"Split: {args.split}")

    print("\nSubmitting inference job to scheduler...")
    job_spec = {
        "model_path": args.model_path,
        "tokenizer_path": args.get("tokenizer_path"),
        "tensor_parallel_size": args.get("tensor_parallel_size", 1),
        "num_gpus": args.get("num_gpus"),
        "gpu_memory_utilization": args.get("gpu_memory_utilization", 0.9),
        "max_model_len": args.get("max_model_len"),
        "trust_remote_code": args.get("trust_remote_code", True),
    }
    submission = scheduler_client.submit_inference_job(**job_spec)
    job_id = submission["job_id"]
    vllm_server_url = submission["vllm_server_url"]

    # The lifecycle manager is handed the job so it can be cleaned up later.
    lifecycle.register_job(scheduler_client, job_id)

    print(f"✓ Inference job {job_id} started at {vllm_server_url}")

    # Step 2: per-step metrics client, keyed by job_id; optional.
    game_stats = GameStatsClient(args.env_endpoint, job_id=job_id)
    if not game_stats.health_check():
        print(
            f"⚠ ALFWorld server not available at {args.env_endpoint}, continuing without stats"
        )
        game_stats = None
    else:
        print(f"✓ Connected to ALFWorld server at {args.env_endpoint}")
        game_stats.reset_all()  # start this job from clean counters

    # Step 3: run the evaluation against the remote vLLM server.
    data_path = args.get("data_path")
    print(
        f"Running inference on {data_path}..."
        if data_path
        else f"Running inference on ALFWorld {args.split} split..."
    )

    results = run_inference(
        model_path=None,  # unused when vllm_server_url is given
        vllm_server_url=vllm_server_url,
        tokenizer_path=args.get("tokenizer_path") or args.model_path,
        data_path=data_path,
        game_class=ALFWorldGame,
        env_endpoint=args.env_endpoint,
        job_id=job_id,  # keeps stats isolated per job
        output_path=args.get("output_path"),
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_new_tokens,
        max_samples=args.get("max_samples"),
        max_user_turns=args.multi_turn.max_user_turns,
        max_assistant_turns=args.multi_turn.max_assistant_turns,
    )

    # Step 4: summary.
    print("\n" + rule)
    print("Inference Results")
    print(rule)

    if game_stats:
        summary = game_stats.get_all_stats()
        print(f"\nALFWorld Evaluation Stats (job_id={job_id}):")
        print(f"  Total episodes: {summary.get('total_games', 0)}")
        print(f"  Successes: {summary.get('total_wins', 0)}")
        print(f"  Failures: {summary.get('total_losses', 0)}")
        success_rate = summary.get("cumulative_win_rate", 0)
        print(f"  Success rate: {success_rate:.1%}")
        print(f"  Mean reward: {summary.get('mean_final_reward', 0):.4f}")
        print(f"  Mean steps: {summary.get('mean_steps', 0):.2f}")

    if results:
        print(f"\nProcessed {len(results)} samples")

    if args.get("output_path"):
        print(f"Results saved to: {args.output_path}")

    print(f"\n{rule}")
    print("Inference completed! vLLM server will be automatically cleaned up.")
    print(f"{rule}")


if __name__ == "__main__":
    main()
@@ experiment_name: alfworld_training # Logging logger_backends: ["console", "wandb"] -# Tracing (optional) -enable_tracing: true +# Tracing (optional) - DISABLED to prevent disk space issues +enable_tracing: false weave_project: null # WandB (optional) @@ -24,8 +24,8 @@ num_workers: 4 # Training duration - set ONE of these (num_steps takes precedence if both set) num_epochs: null # Number of epochs (null = use num_steps) num_steps: 1000 # Total training steps (null = use num_epochs) -save_freq: 20000 -test_freq: 10 # Validation frequency (every N steps) +save_freq: 10000 +test_freq: 100 # Validation frequency (every N steps) # Validation parameters val_batch_size: 50 # Total validation samples (null = 50) @@ -75,9 +75,52 @@ multi_turn: weave_project: "zsqzz/alfworld-env-test" experiment_name: "alfworld_interaction" + # Turn-level Temperature (WM-guided exploration) + # Adjusts sampling temperature per turn based on previous observation's WM uncertainty + # High uncertainty -> higher temperature (more exploration) + # Low uncertainty -> lower temperature (more exploitation) + # + # IMPORTANT: Requires actor_rollout_ref.rollout.calculate_log_probs=True + # to ensure rollout_log_probs is saved for correct IS correction. 
+ turn_level_temperature: + enabled: true # Set to true to enable dynamic temperature per turn + base_temperature: 1.0 # Base temperature when uncertainty is average + kappa: 0.5 # Sensitivity: T = T_base + kappa * tanh(normalized_uncertainty) + min_temperature: 0.5 # Minimum allowed temperature + max_temperature: 2.0 # Maximum allowed temperature + ema_decay: 0.9 # EMA decay for uncertainty normalization + # Uncertainty computation method: + # - true: Use accurate method (T=1 re-generation to get log π), SLOW - extra model call + # - false: Use heuristic (direct log μ_T), fast and works with vLLM + # NOTE: vLLM doesn't support the extra generate call, so use false with vLLM + use_accurate_uncertainty: false + # Scheduler settings scheduler_url: "http://0.0.0.0:8780" scheduler_api_key: otk_98b8db24ccd64c92e1fdd9a232e209fa # GPU settings num_gpus: 4 + +# Actor settings (passed to server) +actor: + # World model loss: predict environment observations as auxiliary task + # 训练模型预测环境观察,提供 WM 不确定性信号 + use_world_model_loss: true + world_model_loss_coef: 0.01 # 用小系数避免干扰 policy + + # Turn-wise Dynamic Entropy Coefficient (WM-guided) + # 根据每个 turn 的 WM uncertainty 调整 entropy bonus + # 高 uncertainty turn -> β > per_turn_budget -> 更多探索 + # 低 uncertainty turn -> β < per_turn_budget -> 更稳定执行 + wm_dynamic_entropy: + enabled: false + # Per-turn budget design: + # β_t = per_turn_budget * (1 + fluctuation * tanh(γ * (u - baseline) / baseline)) + # - per_turn_budget: 每个 turn 的基础 entropy budget + # - baseline: 动态计算 = 该 sample 内所有 turn 的 uncertainty 均值 + # - fluctuation: 浮动范围 (0.5 = ±50%) + # - gamma: 敏感度 (越大对 uncertainty 差异越敏感) + per_turn_budget: 0.002 # 每个 turn 的基础 budget + fluctuation: 0.5 # ±50% 浮动 (beta 范围: [0.001, 0.003]) + gamma: 1.0 # 敏感度 diff --git a/opentinker/client/client_config/llm_user_param.yaml b/opentinker/client/client_config/llm_user_param.yaml new file mode 100644 index 00000000..b05a9937 --- /dev/null +++ b/opentinker/client/client_config/llm_user_param.yaml @@ 
-0,0 +1,57 @@ +# LLM User Simulator Training Configuration +# Train a conversational agent with LLM-based user simulation + +# Project settings +project_name: opentinker +experiment_name: llm_user_training + +# Logging +logger_backends: ["console", "wandb"] +enable_tracing: false +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 8 +num_workers: 4 +num_steps: 1000 +save_freq: 500 +test_freq: 10 +val_batch_size: 20 + +# Generation parameters +temperature: 0.8 +top_p: 0.95 +max_new_tokens: 4096 +max_prompt_tokens: 2048 + +# Algorithm +algorithm: "agent_loop" +adv_estimator: "grpo" +rollout_n: 4 + +# Interaction configuration +interaction: + name: llm_user_simulator + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8100 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 8 + max_steps: 10 + observation_template: "{observation}" + +multi_turn: + max_user_turns: 10 + max_assistant_turns: 10 + max_tokens_per_turn: 512 + +# Scheduler settings +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +# GPU settings +num_gpus: 4 diff --git a/opentinker/client/llm_user_rl.py b/opentinker/client/llm_user_rl.py new file mode 100644 index 00000000..7bf1800f --- /dev/null +++ b/opentinker/client/llm_user_rl.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +"""LLM User Simulator RL Training Client. + +Train a conversational agent with LLM-based user simulation. 
+ +Usage: + # Start the LLM user simulator server first: + python -m opentinker.environment.llm_user_simulator.llm_user_server --port 8100 --shards 8 + + # Run training: + python llm_user_rl.py scheduler_url=http://localhost:8780 num_gpus=4 +""" + +from omegaconf import OmegaConf +import hydra + +from utils.http_training_client import ServiceClient, SchedulerClient +from opentinker.environment.base_game_environment import GameEnvironment +from opentinker.environment.llm_user_simulator import LLMUserGame +from opentinker.environment.game_stats_client import GameStatsClient +from utils.utils import resolve_paths_in_config +from utils.scheduler_client_lifecycle import get_lifecycle_manager + + +@hydra.main(config_path="client_config", config_name="llm_user_param.yaml") +def main(args): + args = resolve_paths_in_config(args) + lifecycle = get_lifecycle_manager() + + print("=" * 60) + print("Training with LLM User Simulator") + print("=" * 60) + + # Connect to scheduler + scheduler_url = args.get("scheduler_url", "http://localhost:8780") + scheduler_api_key = args.get("scheduler_api_key", None) + + print(f"\nConnecting to scheduler at {scheduler_url}") + scheduler_client = SchedulerClient( + scheduler_url=scheduler_url, api_key=scheduler_api_key + ) + + # Submit job + print("\nSubmitting training job...") + job_result = scheduler_client.submit_job( + config=OmegaConf.to_container(args, resolve=True), + enable_agent_loop=True, + wandb_key=args.get("wandb_key"), + num_gpus=args.get("num_gpus"), + ) + + job_id = job_result["job_id"] + server_url = job_result["server_url"] + lifecycle.register_job(scheduler_client, job_id) + + print(f"\n✓ Job {job_id} allocated!") + print(f" Server URL: {server_url}") + print("=" * 60) + + # Setup GameEnvironment + interaction_config = args.interaction.config + game_kwargs = { + "max_turns": interaction_config.get("max_steps", 10), + } + + env = GameEnvironment( + game_class=LLMUserGame, + config=args, + game_kwargs=game_kwargs, + 
job_id=job_id, + ) + + # Setup stats client + env_endpoint = interaction_config.env_endpoint + game_stats = GameStatsClient(env_endpoint, job_id=env.job_id) + if game_stats.health_check(): + print(f"✓ Connected to LLM user simulator at {env_endpoint}") + game_stats.reset_all() + else: + print(f"⚠ Server at {env_endpoint} not responding") + game_stats = None + + # Connect to training server + print(f"\nConnecting to server at {server_url}") + client = ServiceClient( + server_url=server_url, + project_name=args.project_name, + experiment_name=args.experiment_name, + logger_backends=args.logger_backends, + ) + + client.set_config(args, env) + + # Train + num_steps = args.get("num_steps", 1000) + print(f"\nStarting training for {num_steps} steps...") + print("=" * 60) + + try: + final_metrics = client.fit( + env=env, + num_steps=num_steps, + save_freq=args.save_freq, + test_freq=args.test_freq, + verbose=True, + validate_before_training=True, + game_stats_client=game_stats, + ) + + print("\n" + "=" * 60) + print("Training completed!") + print(f"Final metrics: {final_metrics}") + print("=" * 60) + + finally: + env.cleanup() + + +if __name__ == "__main__": + main() diff --git a/opentinker/environment/alfworld/alfworld_server.py b/opentinker/environment/alfworld/alfworld_server.py index 69d59f12..2d2197a7 100644 --- a/opentinker/environment/alfworld/alfworld_server.py +++ b/opentinker/environment/alfworld/alfworld_server.py @@ -49,7 +49,7 @@ def main(): parser.add_argument( "--split", type=str, - default="train", + default="eval_in_distribution", choices=["train", "eval_in_distribution", "eval_out_of_distribution"], help="Dataset split to use", ) diff --git a/opentinker/environment/llm_user_simulator/__init__.py b/opentinker/environment/llm_user_simulator/__init__.py new file mode 100644 index 00000000..0e07e666 --- /dev/null +++ b/opentinker/environment/llm_user_simulator/__init__.py @@ -0,0 +1,9 @@ +"""LLM User Simulator Environment. 
+ +This module provides an environment where an LLM simulates a user, +enabling training of conversational agents. +""" + +from opentinker.environment.llm_user_simulator.llm_user_game import LLMUserGame + +__all__ = ["LLMUserGame"] diff --git a/opentinker/environment/llm_user_simulator/llm_user_game.py b/opentinker/environment/llm_user_simulator/llm_user_game.py new file mode 100644 index 00000000..00f0b15f --- /dev/null +++ b/opentinker/environment/llm_user_simulator/llm_user_game.py @@ -0,0 +1,590 @@ +#!/usr/bin/env python3 +"""LLM User Simulator Game Implementation. + +This module provides an environment where an LLM acts as a user simulator, +enabling training of conversational agents through self-play or cross-play. + +Example: + from llm_user_game import LLMUserGame + + game = LLMUserGame( + simulator_model="gpt-4o-mini", + task_prompt="You are a customer trying to book a flight.", + ) + obs = game.reset() + result = game.step("Hello! How can I help you today?") +""" + +import os +import random +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from opentinker.environment.base_game import AbstractGame, StepResult + +# Try to import LLM clients +try: + from openai import OpenAI + + OPENAI_AVAILABLE = True +except ImportError: + OPENAI_AVAILABLE = False + +try: + import anthropic + + ANTHROPIC_AVAILABLE = True +except ImportError: + ANTHROPIC_AVAILABLE = False + + +@dataclass +class ConversationTurn: + """A single turn in the conversation.""" + + role: str # "agent" or "user" + content: str + + +class LLMUserGame(AbstractGame): + """LLM-based user simulator environment. + + The agent (being trained) plays the role of an assistant/agent, + while an LLM simulates the user providing requests and feedback. 
+ + Attributes: + simulator_model: Model name for the user simulator (e.g., "gpt-4o-mini") + task_prompt: System prompt defining the user's persona and task + max_turns: Maximum conversation turns before episode ends + success_keywords: Keywords that indicate task success + """ + + # Reward constants + REWARD_SUCCESS = 10.0 + REWARD_FAILURE = -1.0 + REWARD_STEP = -0.01 + REWARD_USER_SATISFIED = 5.0 + + # Default max turns + DEFAULT_MAX_TURNS = 10 + + # LLM Judge prompt + JUDGE_PROMPT = """You are an expert evaluator assessing conversational AI quality. + +Evaluate the following conversation between an AI assistant and a user. + +## Evaluation Criteria (score 1-10 each): + +1. **Helpfulness**: Did the assistant understand and address the user's needs? +2. **Clarity**: Were the responses clear and easy to understand? +3. **Problem Resolution**: Was the user's issue/request successfully resolved? +4. **Professionalism**: Was the tone appropriate and professional? +5. **Efficiency**: Did the assistant resolve the issue without unnecessary back-and-forth? + +## Conversation: +{conversation} + +## Your Evaluation: +Provide scores and a brief explanation in this exact JSON format: +```json +{ + "helpfulness": <1-10>, + "clarity": <1-10>, + "problem_resolution": <1-10>, + "professionalism": <1-10>, + "efficiency": <1-10>, + "overall_score": <1-10>, + "success": , + "explanation": "" +} +``` +""" + + # Default task prompts for various scenarios + TASK_PROMPTS = { + "customer_service": """You are a customer contacting customer service. +You have a specific problem that needs to be resolved. +Be realistic - ask clarifying questions, express frustration if not helped properly. +When your issue is resolved satisfactorily, say "Thank you, that resolves my issue." +If the agent is unhelpful after several attempts, say "This is not helpful, goodbye." +""", + "booking_assistant": """You are a user trying to book a reservation. 
+You have specific preferences (date, time, number of people, etc.). +Ask questions about availability and options. +When you successfully complete a booking, say "Great, the booking is confirmed." +If unable to book, say "I'll try elsewhere, thanks." +""", + "tech_support": """You are a user with a technical problem. +Describe your issue and provide details when asked. +If the solution works, say "That fixed it, thank you!" +If multiple attempts fail, say "This still doesn't work." +""", + "information_seeking": """You are a user looking for specific information. +Ask questions to get the information you need. +When you get satisfactory answers, say "That's exactly what I needed, thanks!" +If answers are unclear or wrong, say "That's not what I was looking for." +""", + } + + def __init__( + self, + simulator_model: str = "gpt-4o-mini", + simulator_api_key: Optional[str] = None, + simulator_base_url: Optional[str] = None, + task_prompt: Optional[str] = None, + task_type: str = "customer_service", + max_turns: int = DEFAULT_MAX_TURNS, + success_keywords: Optional[List[str]] = None, + failure_keywords: Optional[List[str]] = None, + temperature: float = 0.7, + seed: Optional[int] = None, + use_llm_judge: bool = True, + judge_model: Optional[str] = None, + ): + """Initialize LLM User Simulator. 
+ + Args: + simulator_model: Model name for user simulation + simulator_api_key: API key (defaults to env var) + simulator_base_url: Custom API base URL (for local models) + task_prompt: Custom system prompt for user persona + task_type: Predefined task type if task_prompt not provided + max_turns: Maximum conversation turns + success_keywords: Phrases indicating success (fallback if LLM judge disabled) + failure_keywords: Phrases indicating failure (fallback if LLM judge disabled) + temperature: Sampling temperature for user LLM + seed: Random seed for reproducibility + use_llm_judge: Use LLM-as-a-Judge for evaluation (recommended) + judge_model: Model for judging (defaults to simulator_model) + """ + self.simulator_model = simulator_model + self.max_turns = max_turns + self.temperature = temperature + + # Set API key + self.api_key = simulator_api_key or os.environ.get("OPENAI_API_KEY") + self.base_url = simulator_base_url + + # Initialize LLM client + self._init_llm_client() + + # Set task prompt + if task_prompt: + self.task_prompt = task_prompt + else: + self.task_prompt = self.TASK_PROMPTS.get( + task_type, self.TASK_PROMPTS["customer_service"] + ) + + # LLM-as-a-Judge settings + self.use_llm_judge = use_llm_judge + self.judge_model = judge_model or simulator_model + + # Success/failure detection (fallback when LLM judge is disabled) + self.success_keywords = success_keywords or [ + "thank you", + "that resolves", + "that fixed it", + "exactly what I needed", + "booking is confirmed", + "issue is resolved", + "problem solved", + ] + self.failure_keywords = failure_keywords or [ + "not helpful", + "goodbye", + "doesn't work", + "not what I was looking for", + "try elsewhere", + "give up", + "frustrated", + ] + + # Judge evaluation result (populated at end of episode) + self._judge_result: Optional[Dict[str, Any]] = None + + # Game state + self._conversation: List[ConversationTurn] = [] + self._turn_count = 0 + self._done = False + self._success = False + 
self._current_task = "" + + if seed is not None: + random.seed(seed) + + def _init_llm_client(self): + """Initialize the LLM client for user simulation.""" + if not OPENAI_AVAILABLE: + raise ImportError( + "openai package not installed. Install with: pip install openai" + ) + + client_kwargs = {"api_key": self.api_key} + if self.base_url: + client_kwargs["base_url"] = self.base_url + + self._client = OpenAI(**client_kwargs) + + def _generate_user_response(self, agent_message: str) -> str: + """Generate user response using the simulator LLM.""" + # Build conversation history for the user LLM + messages = [ + { + "role": "system", + "content": self.task_prompt + "\n\n" + self._current_task, + } + ] + + # Add conversation history + for turn in self._conversation: + if turn.role == "agent": + # Agent messages appear as "assistant" to the user simulator + messages.append({"role": "user", "content": turn.content}) + else: + # User's own previous messages + messages.append({"role": "assistant", "content": turn.content}) + + # Add the latest agent message + messages.append({"role": "user", "content": agent_message}) + + # Generate user response + response = self._client.chat.completions.create( + model=self.simulator_model, + messages=messages, + temperature=self.temperature, + max_tokens=500, + ) + + return response.choices[0].message.content + + def _generate_initial_user_message(self) -> str: + """Generate the initial user message to start conversation.""" + messages = [ + { + "role": "system", + "content": self.task_prompt + "\n\n" + self._current_task, + }, + { + "role": "user", + "content": "Start the conversation by stating your request or problem.", + }, + ] + + response = self._client.chat.completions.create( + model=self.simulator_model, + messages=messages, + temperature=self.temperature, + max_tokens=300, + ) + + return response.choices[0].message.content + + def _check_success(self, text: str) -> bool: + """Check if conversation indicates success.""" + 
text_lower = text.lower() + return any(kw.lower() in text_lower for kw in self.success_keywords) + + def _check_failure(self, text: str) -> bool: + """Check if conversation indicates failure.""" + text_lower = text.lower() + return any(kw.lower() in text_lower for kw in self.failure_keywords) + + def _evaluate_with_llm_judge(self) -> Dict[str, Any]: + """Evaluate the conversation using LLM-as-a-Judge. + + Returns: + Dictionary with scores and evaluation details. + """ + import json + + # Format conversation for judge + conv_text = "" + for turn in self._conversation: + role_label = "Assistant" if turn.role == "agent" else "User" + conv_text += f"{role_label}: {turn.content}\n\n" + + prompt = self.JUDGE_PROMPT.format(conversation=conv_text) + + try: + response = self._client.chat.completions.create( + model=self.judge_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1, # Low temperature for consistent evaluation + max_tokens=500, + ) + + result_text = response.choices[0].message.content + + # Extract JSON from response + json_match = re.search(r"```json\s*(.*?)\s*```", result_text, re.DOTALL) + if json_match: + result = json.loads(json_match.group(1)) + else: + # Try to parse the whole response as JSON + result = json.loads(result_text) + + # Validate and normalize scores + for key in [ + "helpfulness", + "clarity", + "problem_resolution", + "professionalism", + "efficiency", + "overall_score", + ]: + if key in result: + result[key] = max(1, min(10, int(result[key]))) + + return result + + except Exception as e: + # Fallback to keyword-based evaluation + print(f"[LLM Judge] Evaluation failed: {e}, using fallback") + last_user_msg = self._conversation[-1].content if self._conversation else "" + success = self._check_success(last_user_msg) + return { + "helpfulness": 7 if success else 3, + "clarity": 5, + "problem_resolution": 8 if success else 2, + "professionalism": 5, + "efficiency": 5, + "overall_score": 7 if success else 3, + "success": 
success, + "explanation": "Fallback evaluation (LLM judge failed)", + } + + def _calculate_reward_from_judge(self, judge_result: Dict[str, Any]) -> float: + """Calculate reward from LLM judge evaluation. + + Maps the overall_score (1-10) to reward range. + """ + overall = judge_result.get("overall_score", 5) + success = judge_result.get("success", False) + + if success: + # Success: reward based on quality (5-10 range) + # Map overall_score 1-10 to reward 5-10 + return 5.0 + (overall / 10.0) * 5.0 + else: + # Failure: negative reward based on how bad + # Map overall_score 1-10 to reward -5 to 0 + return (overall / 10.0) * 5.0 - 5.0 + + def reset( + self, task_prompt: Optional[str] = None, seed: Optional[int] = None, **kwargs + ) -> str: + """Reset the game to start a new conversation. + + Args: + task_prompt: Override task prompt for this episode + seed: Random seed + **kwargs: Additional arguments + + Returns: + Initial observation (user's first message) + """ + if seed is not None: + random.seed(seed) + + # Update task prompt if provided + if task_prompt: + self._current_task = task_prompt + else: + # Generate a random specific task + self._current_task = self._generate_random_task() + + # Reset state + self._conversation = [] + self._turn_count = 0 + self._done = False + self._success = False + self._judge_result = None + + # Generate initial user message + initial_message = self._generate_initial_user_message() + self._conversation.append( + ConversationTurn(role="user", content=initial_message) + ) + + return self._format_observation(initial_message) + + def _generate_random_task(self) -> str: + """Generate a random specific task for variety.""" + tasks = [ + "Your flight was cancelled and you need to rebook.", + "You received a defective product and want a refund.", + "You need to change your hotel reservation dates.", + "Your internet connection is not working.", + "You want to upgrade your subscription plan.", + "You're looking for recommendations for a 
restaurant.", + "You need help resetting your password.", + "You want to cancel your order.", + ] + return random.choice(tasks) + + def _format_observation(self, message: str) -> str: + """Format observation for the agent.""" + obs = f"=== User Message ===\n{message}\n" + obs += f"\n=== Conversation Turn: {self._turn_count + 1}/{self.max_turns} ===" + return obs + + def step(self, action: str) -> StepResult: + """Execute agent action and get user response. + + Args: + action: Agent's response to the user + + Returns: + StepResult with user's response, reward, done flag, and info + """ + if self._done: + return StepResult( + observation="Conversation has ended.", + reward=0.0, + done=True, + info={"error": "conversation_ended"}, + ) + + self._turn_count += 1 + + # Parse agent's action + parsed_action = self._parse_action(action) + + # Add agent message to conversation + self._conversation.append(ConversationTurn(role="agent", content=parsed_action)) + + # Generate user response + user_response = self._generate_user_response(parsed_action) + self._conversation.append(ConversationTurn(role="user", content=user_response)) + + # Check for episode end conditions + episode_ended = False + end_reason = "" + + if self._check_success(user_response): + episode_ended = True + end_reason = "success_keyword" + elif self._check_failure(user_response): + episode_ended = True + end_reason = "failure_keyword" + elif self._turn_count >= self.max_turns: + episode_ended = True + end_reason = "timeout" + + # Calculate reward + if episode_ended: + self._done = True + + if self.use_llm_judge: + # Use LLM-as-a-Judge for evaluation + self._judge_result = self._evaluate_with_llm_judge() + reward = self._calculate_reward_from_judge(self._judge_result) + self._success = self._judge_result.get("success", False) + + # Add evaluation summary to response + judge_summary = ( + f"\n\n=== LLM Judge Evaluation ===\n" + f"Overall Score: {self._judge_result.get('overall_score', 'N/A')}/10\n" + 
f"Success: {self._success}\n" + f"Explanation: {self._judge_result.get('explanation', 'N/A')}" + ) + user_response = f"{user_response}{judge_summary}" + else: + # Fallback to keyword-based evaluation + if end_reason == "success_keyword": + self._success = True + reward = self.REWARD_SUCCESS + else: + self._success = False + reward = self.REWARD_FAILURE + + # Add end reason prefix + if end_reason == "timeout": + user_response = f"TIMEOUT: Maximum turns reached.\n\n{user_response}" + elif self._success: + user_response = f"SUCCESS: {user_response}" + else: + user_response = f"FAILURE: {user_response}" + else: + reward = self.REWARD_STEP + + # Build info dict + info = { + "turn": self._turn_count, + "success": self._success, + "agent_message": parsed_action, + "user_message": user_response, + } + + # Add judge evaluation if available + if self._judge_result: + info["judge_evaluation"] = self._judge_result + + return StepResult( + observation=self._format_observation(user_response), + reward=reward, + done=self._done, + info=info, + ) + + def _parse_action(self, raw_action: str) -> str: + """Parse action from LLM output.""" + # Try to extract from tags + match = re.search( + r"\s*(.*?)\s*", raw_action, re.IGNORECASE | re.DOTALL + ) + if match: + return match.group(1).strip() + + # Otherwise use the whole output + return raw_action.strip() + + def get_system_prompt(self) -> str: + """Return the system prompt for the agent.""" + return ( + "You are a helpful assistant engaging in a conversation with a user.\n" + "Your goal is to understand the user's needs and help them effectively.\n\n" + "IMPORTANT: Respond naturally and helpfully. Be concise but thorough.\n" + "If you need more information, ask clarifying questions.\n" + "If you can help, provide clear solutions or information.\n\n" + "Wrap your response in tags.\n\n" + "Example:\n" + "I understand you're having trouble with your order. " + "Could you please provide your order number so I can look into this?" 
+ ) + + def get_initial_user_message(self) -> str: + """Return context for the agent.""" + return "You are helping a user. Respond to their message." + + def get_state(self) -> Dict[str, Any]: + """Return current game state.""" + state = { + "turn_count": self._turn_count, + "max_turns": self.max_turns, + "done": self._done, + "success": self._success, + "conversation_length": len(self._conversation), + "use_llm_judge": self.use_llm_judge, + } + if self._judge_result: + state["judge_evaluation"] = self._judge_result + return state + + def generate_initial_state(self) -> Dict[str, Any]: + """Generate random initial state for training.""" + return { + "seed": random.randint(0, 1000000), + } + + def get_user_message_with_state(self, **kwargs) -> str: + """Generate user message with state for prompt.""" + self.reset(**kwargs) + initial_obs = self._format_observation(self._conversation[0].content) + return f"{initial_obs}\n\nRespond to the user." + + def get_interaction_name(self) -> str: + """Return interaction name.""" + return "llm_user_simulator" diff --git a/opentinker/environment/llm_user_simulator/llm_user_server.py b/opentinker/environment/llm_user_simulator/llm_user_server.py new file mode 100644 index 00000000..0fdb6874 --- /dev/null +++ b/opentinker/environment/llm_user_simulator/llm_user_server.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +"""LLM User Simulator Server. + +This script starts an LLM user simulator server. 
+ +Usage: + python llm_user_server.py --port 8100 --shards 8 + + # With custom model: + python llm_user_server.py --port 8100 --simulator_model gpt-4o-mini +""" + +import argparse +import os +import subprocess +import sys +import time + + +def main(): + parser = argparse.ArgumentParser(description="LLM User Simulator Server") + parser.add_argument("--host", default="0.0.0.0", help="Server host") + parser.add_argument("--port", type=int, default=8100, help="Server port") + parser.add_argument( + "--shards", + type=int, + default=4, + help="Number of independent server processes on consecutive ports.", + ) + parser.add_argument( + "--simulator_model", + type=str, + default="gpt-4o-mini", + help="Model name for user simulation (e.g., gpt-4o-mini, gpt-4o)", + ) + parser.add_argument( + "--simulator_base_url", + type=str, + default=None, + help="Custom API base URL (for local models like vLLM)", + ) + parser.add_argument( + "--task_type", + type=str, + default="customer_service", + choices=[ + "customer_service", + "booking_assistant", + "tech_support", + "information_seeking", + ], + help="Type of user simulation task", + ) + parser.add_argument( + "--max_turns", + type=int, + default=10, + help="Maximum conversation turns per episode", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.7, + help="Sampling temperature for user simulator", + ) + parser.add_argument( + "--use_llm_judge", + action="store_true", + default=True, + help="Use LLM-as-a-Judge for evaluation (default: True)", + ) + parser.add_argument( + "--no_llm_judge", + action="store_true", + help="Disable LLM-as-a-Judge, use keyword matching instead", + ) + parser.add_argument( + "--judge_model", + type=str, + default=None, + help="Model for LLM judge (defaults to simulator_model)", + ) + args = parser.parse_args() + + # Handle --no_llm_judge flag + if args.no_llm_judge: + args.use_llm_judge = False + + from opentinker.environment.llm_user_simulator.llm_user_game import LLMUserGame + + 
print("\nLLM User Simulator Configuration:") + print(f" Simulator model: {args.simulator_model}") + print(f" Base URL: {args.simulator_base_url or 'default (OpenAI)'}") + print(f" Task type: {args.task_type}") + print(f" Max turns: {args.max_turns}") + print(f" Temperature: {args.temperature}") + print(f" Shards: {args.shards}") + print(f" LLM-as-a-Judge: {'enabled' if args.use_llm_judge else 'disabled'}") + if args.use_llm_judge: + print(f" Judge model: {args.judge_model or args.simulator_model}") + print("\nReward structure:") + if args.use_llm_judge: + print(" Using LLM Judge scoring (1-10 scale mapped to rewards)") + else: + print(f" Success: +{LLMUserGame.REWARD_SUCCESS}") + print(f" Failure: {LLMUserGame.REWARD_FAILURE}") + print(f" Step penalty: {LLMUserGame.REWARD_STEP}") + + # Sharded mode + if args.shards and args.shards > 1: + print( + f"\nStarting sharded mode: {args.shards} shards on ports {args.port}..{args.port + args.shards - 1}" + ) + + children: list[subprocess.Popen] = [] + try: + for i in range(args.shards): + port_i = args.port + i + cmd = [ + sys.executable, + os.path.abspath(__file__), + "--host", + args.host, + "--port", + str(port_i), + "--shards", + "1", + "--simulator_model", + args.simulator_model, + "--task_type", + args.task_type, + "--max_turns", + str(args.max_turns), + "--temperature", + str(args.temperature), + ] + if args.simulator_base_url: + cmd.extend(["--simulator_base_url", args.simulator_base_url]) + if not args.use_llm_judge: + cmd.append("--no_llm_judge") + if args.judge_model: + cmd.extend(["--judge_model", args.judge_model]) + + children.append(subprocess.Popen(cmd)) + time.sleep(0.1) + + print("Shards started. 
Press Ctrl+C to stop all shards.") + while True: + for p in children: + rc = p.poll() + if rc is not None: + raise RuntimeError( + f"Shard exited early: pid={p.pid}, code={rc}" + ) + time.sleep(1.0) + except KeyboardInterrupt: + pass + finally: + for p in children: + try: + p.terminate() + except Exception: + pass + for p in children: + try: + p.wait(timeout=5) + except Exception: + try: + p.kill() + except Exception: + pass + return + + # Single shard mode + from opentinker.environment.base_game_server import run_game_server + + run_game_server( + game_class=LLMUserGame, + host=args.host, + port=args.port, + stats_class=None, + simulator_model=args.simulator_model, + simulator_base_url=args.simulator_base_url, + task_type=args.task_type, + max_turns=args.max_turns, + temperature=args.temperature, + use_llm_judge=args.use_llm_judge, + judge_model=args.judge_model, + ) + + +if __name__ == "__main__": + main() diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 36800531..7bcbe469 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -1111,6 +1111,40 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: f"Job {job.job_id}: ✓ LoRA enabled: rank={lora_rank}, alpha={lora_alpha}, target_modules={target_modules}" ) + # Forward world model loss settings if specified + actor_config = job.config.get("actor", {}) + if actor_config.get("use_world_model_loss"): + cmd.append("actor_rollout_ref.actor.use_world_model_loss=true") + wm_coef = actor_config.get("world_model_loss_coef", 0.1) + cmd.append(f"actor_rollout_ref.actor.world_model_loss_coef={wm_coef}") + logger.info(f"Job {job.job_id}: ✓ World Model Loss enabled: coef={wm_coef}") + + # Forward WM active sampling settings if specified + if actor_config.get("wm_active_sampling"): + cmd.append("actor_rollout_ref.actor.wm_active_sampling=true") + wm_active_coef = actor_config.get("wm_active_sampling_coef", 0.5) + cmd.append( + 
f"actor_rollout_ref.actor.wm_active_sampling_coef={wm_active_coef}" + ) + logger.info( + f"Job {job.job_id}: ✓ WM Active Sampling enabled: coef={wm_active_coef}" + ) + + # Forward WM dynamic entropy settings if specified + # Use + prefix to add new config keys that may not exist in the base schema + wm_dynamic_entropy = actor_config.get("wm_dynamic_entropy", {}) + if wm_dynamic_entropy.get("enabled"): + cmd.append("+actor_rollout_ref.actor.wm_dynamic_entropy.enabled=true") + beta_0 = wm_dynamic_entropy.get("beta_0", 0.001) + beta_1 = wm_dynamic_entropy.get("beta_1", 0.01) + gamma = wm_dynamic_entropy.get("gamma", 1.0) + cmd.append(f"+actor_rollout_ref.actor.wm_dynamic_entropy.beta_0={beta_0}") + cmd.append(f"+actor_rollout_ref.actor.wm_dynamic_entropy.beta_1={beta_1}") + cmd.append(f"+actor_rollout_ref.actor.wm_dynamic_entropy.gamma={gamma}") + logger.info( + f"Job {job.job_id}: ✓ WM Dynamic Entropy enabled: beta_0={beta_0}, beta_1={beta_1}, gamma={gamma}" + ) + logger.info(f"Job {job.job_id}: Launching server with command: {' '.join(cmd)}") # Create log files for stdout and stderr with human-readable timestamp diff --git a/opentinker/scripts/launch_scheduler.sh b/opentinker/scripts/launch_scheduler.sh index 581e7cfd..5dff4df5 100755 --- a/opentinker/scripts/launch_scheduler.sh +++ b/opentinker/scripts/launch_scheduler.sh @@ -6,11 +6,19 @@ export CUDA_HOME=$HOME/local/cuda-12.8 export PATH=$CUDA_HOME/bin:$PATH export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH -export ROLLOUT_TRACE_DIR="/home/haofeiy2/OpenTinker/traces" +# DISABLED - causes disk space issues +# export ROLLOUT_TRACE_DIR="/home/haofeiy2/OpenTinker/traces" export NVCC_EXECUTABLE=$CUDA_HOME/bin/nvcc export TORCH_CUDA_ARCH_LIST="9.0" export FLASHINFER_HOMOGENEOUS_MS=1 +# Disable sleep mode to avoid cumem allocator CUDA errors (V1 required for async engine) +export VLLM_DISABLE_SLEEP_MODE=1 + +# Limit Ray object store to prevent disk space issues +# Default 200GB is too large and causes spilling to 
disk +export RAY_object_store_memory=30000000000 # 50GB max + # Default configuration AVAILABLE_GPUS="[0,1,2,3,4,5,6,7,8,9]" PORT_RANGE="null" # Set to null for auto-detection diff --git a/opentinker/scripts/run_alfworld.sh b/opentinker/scripts/run_alfworld.sh new file mode 100755 index 00000000..7f4168cc --- /dev/null +++ b/opentinker/scripts/run_alfworld.sh @@ -0,0 +1,131 @@ +#!/bin/bash +# ALFWorld Training & Inference Script +# +# This script runs ALFWorld RL training or inference with OpenTinker. +# You need to run these steps in SEPARATE terminals. +# +# For Training (3 terminals): +# Terminal 1: bash run_alfworld.sh scheduler +# Terminal 2: bash run_alfworld.sh env +# Terminal 3: bash run_alfworld.sh client +# +# For Inference/Evaluation (3 terminals): +# Terminal 1: bash run_alfworld.sh scheduler +# Terminal 2: bash run_alfworld.sh env-eval +# Terminal 3: bash run_alfworld.sh inference model_path=/path/to/checkpoint + +# ============================================================================= +# Configuration +# ============================================================================= +SCHEDULER_PORT=8089 +ENV_PORT=8091 +GPUS='[0,1,2,3,4,5,6,7,8,9]' +NUM_GPUS=4 + +# Fix vLLM v1 cumem allocator issue (V1 is required for async engine) +# Disable sleep mode to avoid cumem allocator CUDA errors +export VLLM_DISABLE_SLEEP_MODE=1 + +# Activate conda environment +source ~/anaconda3/etc/profile.d/conda.sh +conda activate opentinker + +# Change to OpenTinker directory +cd /home/haofeiy2/OpenTinker + +# ============================================================================= +# Step Selection +# ============================================================================= +case "$1" in + scheduler|1) + echo "========================================" + echo "Step 1: Starting Scheduler on port $SCHEDULER_PORT" + echo "========================================" + bash opentinker/scripts/launch_scheduler.sh \ + --scheduler-port $SCHEDULER_PORT \ + --gpus 
"$GPUS" + ;; + + env|2) + echo "========================================" + echo "Step 2: Starting ALFWorld Environment Server on ports $ENV_PORT-$((ENV_PORT+7)) (8 shards)" + echo "========================================" + python opentinker/environment/alfworld/alfworld_server.py \ + --port $ENV_PORT \ + --shards 8 \ + --split train \ + --max_steps 50 + ;; + + env-eval) + echo "========================================" + echo "Step 2 (Eval): Starting ALFWorld Environment Server for Evaluation" + echo "========================================" + python opentinker/environment/alfworld/alfworld_server.py \ + --port $ENV_PORT \ + --shards 1 \ + --split eval_in_distribution \ + --max_steps 50 + ;; + + client|3) + echo "========================================" + echo "Step 3: Running ALFWorld RL Client" + echo "========================================" + python opentinker/client/alfworld_rl.py \ + tokenizer_path=Qwen/Qwen2.5-3B-Instruct \ + batch_size=16 \ + val_batch_size=32 \ + num_epochs=5 \ + save_freq=1000 \ + test_freq=100 \ + num_gpus=$NUM_GPUS \ + scheduler_url=http://0.0.0.0:$SCHEDULER_PORT \ + interaction.config.env_port=$ENV_PORT \ + interaction.config.env_host=0.0.0.0 \ + interaction.config.env_shards=8 + ;; + + inference|4) + echo "========================================" + echo "Step 3 (Inference): Running ALFWorld Evaluation" + echo "========================================" + # Pass remaining arguments (e.g., model_path=/path/to/checkpoint) + shift # Remove 'inference' from args + python opentinker/client/alfworld_inference.py \ + num_gpus=$NUM_GPUS \ + scheduler_url=http://0.0.0.0:$SCHEDULER_PORT \ + env_endpoint=http://0.0.0.0:$ENV_PORT \ + split=eval_in_distribution \ + "$@" + ;; + + *) + echo "ALFWorld Training & Inference Script" + echo "" + echo "Usage: $0 {scheduler|env|env-eval|client|inference}" + echo " $0 {1|2|3|4}" + echo "" + echo "=== For Training (3 terminals) ===" + echo " Terminal 1: $0 scheduler # Start scheduler (port 
$SCHEDULER_PORT)" + echo " Terminal 2: $0 env # Start environment server (train split)" + echo " Terminal 3: $0 client # Start RL training client" + echo "" + echo "=== For Inference/Evaluation (3 terminals) ===" + echo " Terminal 1: $0 scheduler # Start scheduler (port $SCHEDULER_PORT)" + echo " Terminal 2: $0 env-eval # Start environment server (eval split)" + echo " Terminal 3: $0 inference model_path=/path/to/checkpoint" + echo "" + echo "Inference options:" + echo " model_path=... # Path to trained checkpoint (REQUIRED)" + echo " max_samples=N # Limit evaluation samples" + echo " output_path=... # Save results to file" + echo " split=... # eval_in_distribution (default) or eval_out_of_distribution" + echo "" + echo "Configuration:" + echo " SCHEDULER_PORT=$SCHEDULER_PORT" + echo " ENV_PORT=$ENV_PORT" + echo " GPUS=$GPUS" + echo " NUM_GPUS=$NUM_GPUS" + ;; +esac diff --git a/opentinker/server/config/actor/actor.yaml b/opentinker/server/config/actor/actor.yaml index 43f576aa..9f92a1ba 100755 --- a/opentinker/server/config/actor/actor.yaml +++ b/opentinker/server/config/actor/actor.yaml @@ -92,6 +92,35 @@ ppo_epochs: 1 # Shuffle training data across PPO epochs shuffle: false +# World model loss: auxiliary SFT loss for predicting environment observations +# This helps the model learn a world model of the environment in multi-turn agentic tasks +use_world_model_loss: false + +# Coefficient for world model loss +world_model_loss_coef: 0.1 + +# WM Active Sampling: use WM uncertainty to weight policy gradient +# High uncertainty samples get higher advantage -> agent learns more from uncertain states +# Low uncertainty (redundant) samples get lower weight -> less gradient +# Uses OLD log_prob (lagged theta^-) for stable weights across PPO epochs +# Weights computed at MINI-BATCH level for consistency across gradient accumulation +# Auto-detects observation_mask format (bool mask vs index list) +wm_active_sampling: false + +# Alpha coefficient for WM active sampling: 
weight = exp(alpha * z_score) +# Higher alpha = stronger emphasis on uncertainty differences +wm_active_sampling_coef: 0.5 + +# Minimum weight for WM active sampling (allows down-weighting redundant samples) +wm_active_wmin: 0.5 + +# Maximum weight for WM active sampling (caps up-weighting uncertain samples) +wm_active_wmax: 2.0 + +# Only apply WM weights to positive advantages (avoid penalizing exploration) +# Useful when reward is sparse and early exploration often fails +wm_active_positive_only: false + # checkpoint configs checkpoint: # Target dataclass for this configuration diff --git a/opentinker/server/config/ppo_trainer.yaml b/opentinker/server/config/ppo_trainer.yaml index 53d1c53b..ec651062 100755 --- a/opentinker/server/config/ppo_trainer.yaml +++ b/opentinker/server/config/ppo_trainer.yaml @@ -221,7 +221,7 @@ trainer: del_local_ckpt_after_load: False # Default local directory for saving checkpoints - default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} + default_local_dir: /mnt/disk3_from_server2/haofeiy2/opentinker_checkpoints/${trainer.project_name}/${trainer.experiment_name} # Maximum number of actor checkpoints to keep max_actor_ckpt_to_keep: null diff --git a/opentinker/server/config/rollout/rollout.yaml b/opentinker/server/config/rollout/rollout.yaml index 33cff32e..9a53e557 100755 --- a/opentinker/server/config/rollout/rollout.yaml +++ b/opentinker/server/config/rollout/rollout.yaml @@ -77,6 +77,10 @@ enable_prefix_caching: True # safetensors (for huge model, and set use_shm=True); dummy: randomly init model weight load_format: dummy +# Whether to compute and store log_probs during rollout (required for turn-level temperature IS correction) +# Set to True if using turn_level_temperature.enabled=True +calculate_log_probs: true + # [Will be deprecated, use log_prob_micro_batch_size_per_gpu] The batch size for one forward pass in the computation of log_prob. Global batch size. 
log_prob_micro_batch_size: null @@ -189,9 +193,19 @@ multi_turn: # Number of repeat rollouts for each interaction num_repeat_rollouts: null -# support logging rollout prob for debugging purpose -# "Truncated importance sampling" requires rollout log probs, set to True when turning on Truncated importance sampling -calculate_log_probs: False + # Turn-level temperature (WM-guided exploration) + # Adjusts sampling temperature per turn based on previous observation's WM uncertainty + turn_level_temperature: + _target_: opentinker.backend_patch.verl.workers.config.rollout.TurnLevelTemperatureConfig + enabled: false + base_temperature: 1.0 + kappa: 0.3 + min_temperature: 0.5 + max_temperature: 1.5 + ema_decay: 0.9 + enable_is_correction: true + # Use accurate uncertainty (T=1 re-generation) vs heuristic (scaled log μ_T) + use_accurate_uncertainty: true # [Experimental] agent loop based rollout configs agent: diff --git a/opentinker/server/generic_agent_loop.py b/opentinker/server/generic_agent_loop.py index 3930bbc4..49534d1d 100755 --- a/opentinker/server/generic_agent_loop.py +++ b/opentinker/server/generic_agent_loop.py @@ -117,6 +117,16 @@ def __init__( self.response_mask: list[int] = [] self.response_logprobs: list[float] = [] + # Observation mask for world model loss + # observation_mask=1 for environment observation tokens (used for world model SFT loss) + # observation_mask=0 for LLM-generated action tokens + self.observation_mask: list[int] = [] + + # Turn index for each token (used for turn-wise dynamic entropy coefficient) + # turn_ids[i] = which turn token i belongs to (0-indexed) + # This allows computing per-turn WM uncertainty and applying different entropy weights + self.turn_ids: list[int] = [] + # Turn tracking self.user_turns = 0 self.assistant_turns = 0 @@ -124,6 +134,18 @@ def __init__( # Reward tracking (for turn-level rewards, accumulated for final reward) self.turn_scores: list[float] = [] + # Turn-level temperature tracking (for WM-guided 
exploration) + # Stores uncertainty from each turn's observation tokens + self.turn_uncertainties: list[float] = [] + # Temperature used for each turn's generation (for logging) + self.turn_temperatures: list[float] = [] + + # Token-level temperature for IS correction + # token_temperatures[i] = temperature used when sampling token i + # This is needed for proper importance sampling correction: + # ratio = π_θ(a|s) / μ_T(a|s) where μ_T is the behavior policy with temperature T + self.token_temperatures: list[float] = [] + # Extra fields for additional data self.extra_fields: dict[str, Any] = {} @@ -179,6 +201,43 @@ def init_class(cls, config, tokenizer, processor, **kwargs): "max_tokens_per_turn", None ) + # Turn-level temperature configuration (WM-guided exploration) + # T_t = T_base + kappa * normalize(u_{t-1}) + # where u_{t-1} is the uncertainty from the previous turn's observation + cls.turn_level_temperature = config.actor_rollout_ref.rollout.multi_turn.get( + "turn_level_temperature", {} + ) + cls.turn_temp_enabled = cls.turn_level_temperature.get("enabled", False) + cls.turn_temp_base = cls.turn_level_temperature.get("base_temperature", 1.0) + cls.turn_temp_kappa = cls.turn_level_temperature.get("kappa", 0.5) + cls.turn_temp_min = cls.turn_level_temperature.get("min_temperature", 0.5) + cls.turn_temp_max = cls.turn_level_temperature.get("max_temperature", 2.0) + # EMA for uncertainty normalization + cls.turn_temp_ema_decay = cls.turn_level_temperature.get("ema_decay", 0.9) + cls._uncertainty_ema_mean = None # Will be initialized on first observation + cls._uncertainty_ema_std = None + # IS correction: when True, token_temperatures are recorded and used in PPO loss + # to correct ratio = exp((log π_θ - log π_old) / T) + cls.enable_is_correction = cls.turn_level_temperature.get( + "enable_is_correction", True + ) + + # CRITICAL: Validate configuration + # If turn_level_temperature is enabled but IS correction is disabled, PPO assumptions are violated! 
+ # The old_log_prob would be computed with T=1, but sampling used T≠1 + if cls.turn_temp_enabled and not cls.enable_is_correction: + import warnings + + warnings.warn( + "\n⚠️ CONFIGURATION WARNING ⚠️\n" + "turn_level_temperature.enabled=True but enable_is_correction=False!\n" + "This violates PPO's on-policy assumption and will cause ENTROPY COLLAPSE!\n" + "Either:\n" + " 1. Set enable_is_correction=True (recommended), or\n" + " 2. Set turn_level_temperature.enabled=False\n", + RuntimeWarning, + ) + # Pre-compute system prompt tokens for later stripping cls.system_prompt = tokenizer.apply_chat_template( [{}], @@ -453,6 +512,26 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu # Ensure env_info exists for all samples (even if empty) for consistent DataProto.concat output.extra_fields["env_info"] = agent_data.extra_fields.get("env_info", []) output.extra_fields["turn_scores"] = agent_data.turn_scores + # Add observation_mask for world model loss (marks environment feedback tokens) + output.extra_fields["observation_mask"] = agent_data.observation_mask[ + : self.response_length + ] + # Add turn_ids for turn-wise dynamic entropy coefficient + # turn_ids[i] = which turn token i belongs to (0-indexed) + output.extra_fields["turn_ids"] = agent_data.turn_ids[: self.response_length] + # Add turn-level temperature info for logging + output.extra_fields["turn_temperatures"] = agent_data.turn_temperatures + output.extra_fields["turn_uncertainties"] = agent_data.turn_uncertainties + # Add token-level temperatures for IS correction in PPO (only if enabled) + # token_temperatures[i] = temperature used when sampling token i + if ( + self.turn_temp_enabled + and self.enable_is_correction + and agent_data.token_temperatures + ): + output.extra_fields["token_temperatures"] = agent_data.token_temperatures[ + : self.response_length + ] # Add any other extra fields (except the ones we already set) for key, value in agent_data.extra_fields.items(): if 
key not in output.extra_fields: @@ -503,6 +582,7 @@ async def _handle_generating_state( """Handle the generating state: generate LLM response. The generated tokens are marked with mask=1 (included in loss computation). + Turn IDs are recorded for turn-wise dynamic entropy coefficient. """ import time @@ -530,10 +610,69 @@ async def _handle_generating_state( agent_data.response_ids = [eos_token_id] agent_data.prompt_ids.append(eos_token_id) agent_data.response_mask.append(1) + agent_data.observation_mask.append(0) # EOS is LLM-generated + agent_data.turn_ids.append( + agent_data.assistant_turns + ) # Current turn + if self.turn_temp_enabled and self.enable_is_correction: + agent_data.token_temperatures.append(1.0) # Default temperature return GenericAgentState.TERMINATED + # Current turn index (0-indexed, based on assistant turns) + current_turn = agent_data.assistant_turns + + # Compute turn-level temperature based on previous turn's uncertainty + actual_sampling_params = sampling_params.copy() + turn_temperature = sampling_params.get("temperature", self.turn_temp_base) + print( - f"[GenericAgentLoop DEBUG] _handle_generating_state START: request_id={agent_data.request_id}, prompt_len={len(agent_data.prompt_ids)}" + f"[TurnLevelTemp DEBUG] Turn {agent_data.assistant_turns}: " + f"turn_temp_enabled={self.turn_temp_enabled}, " + f"num_uncertainties={len(agent_data.turn_uncertainties)}" + ) + + if self.turn_temp_enabled and len(agent_data.turn_uncertainties) > 0: + # Use the most recent uncertainty (from previous turn's observation) + prev_uncertainty = agent_data.turn_uncertainties[-1] + + # Normalize uncertainty using EMA statistics + if ( + self._uncertainty_ema_mean is not None + and self._uncertainty_ema_std is not None + ): + if self._uncertainty_ema_std > 1e-6: + normalized_u = ( + prev_uncertainty - self._uncertainty_ema_mean + ) / self._uncertainty_ema_std + else: + normalized_u = 0.0 + else: + # First observation, use raw uncertainty (will be normalized 
later) + normalized_u = 0.0 + + # T_t = T_base + kappa * tanh(normalized_u) + # tanh bounds the effect to [-1, 1], so temperature is in [T_base - kappa, T_base + kappa] + import math + + temp_adjustment = self.turn_temp_kappa * math.tanh(normalized_u) + turn_temperature = self.turn_temp_base + temp_adjustment + + # Clamp to min/max + turn_temperature = max( + self.turn_temp_min, min(self.turn_temp_max, turn_temperature) + ) + + actual_sampling_params["temperature"] = turn_temperature + print( + f"[TurnLevelTemp] Turn {current_turn}: prev_uncertainty={prev_uncertainty:.3f}, " + f"normalized={normalized_u:.3f}, temperature={turn_temperature:.3f}" + ) + + # Record the temperature used for this turn + agent_data.turn_temperatures.append(turn_temperature) + + print( + f"[GenericAgentLoop DEBUG] _handle_generating_state START: request_id={agent_data.request_id}, prompt_len={len(agent_data.prompt_ids)}, turn={current_turn}" ) start_time = time.time() with simple_timer("generate_sequences", agent_data.metrics): @@ -541,10 +680,11 @@ async def _handle_generating_state( f"[GenericAgentLoop DEBUG] Calling server_manager.generate() with image_data={agent_data.image_data is not None}..." 
) # CRITICAL: Pass image_data to vLLM for VL model inference + # Use actual_sampling_params which may have adjusted temperature output = await self.server_manager.generate( request_id=agent_data.request_id, prompt_ids=agent_data.prompt_ids, - sampling_params=sampling_params, + sampling_params=actual_sampling_params, image_data=agent_data.image_data, ) elapsed = time.time() - start_time @@ -573,6 +713,20 @@ async def _handle_generating_state( agent_data.response_mask += [1] * len( agent_data.response_ids ) # mask=1 for LLM tokens + agent_data.observation_mask += [0] * len( + agent_data.response_ids + ) # observation_mask=0 for LLM-generated actions + + # Record turn ID for each token (used for turn-wise dynamic entropy coefficient) + # current_turn was captured BEFORE incrementing assistant_turns + agent_data.turn_ids += [current_turn] * len(agent_data.response_ids) + + # Record token-level temperature for IS correction (only if enabled) + # Each LLM-generated token was sampled with turn_temperature + if self.turn_temp_enabled and self.enable_is_correction: + agent_data.token_temperatures += [turn_temperature] * len( + agent_data.response_ids + ) if response_log_probs: agent_data.response_logprobs += response_log_probs @@ -672,19 +826,242 @@ async def _handle_interacting_state( return GenericAgentState.TERMINATED # Update prompt_ids and response_mask - # mask=0 for environment observation tokens (not included in loss) + # mask=0 for environment observation tokens (not included in policy loss) agent_data.prompt_ids += response_ids agent_data.response_mask += [0] * len(response_ids) + # observation_mask=1 for environment observation tokens (used for world model SFT loss) + agent_data.observation_mask += [1] * len(response_ids) + # turn_ids: observation belongs to the previous turn (action that caused this observation) + # Use (assistant_turns - 1) since assistant_turns was already incremented in _handle_generating_state + obs_turn = max(0, agent_data.assistant_turns - 
1) + agent_data.turn_ids += [obs_turn] * len(response_ids) if agent_data.response_logprobs: # Pad logprobs with 0.0 for observation tokens agent_data.response_logprobs += [0.0] * len(response_ids) + # Pad token_temperatures with 1.0 for observation tokens (only if IS correction enabled) + # These tokens have response_mask=0 so they don't affect IS correction + if self.turn_temp_enabled and self.enable_is_correction: + agent_data.token_temperatures += [1.0] * len(response_ids) + + # Compute turn-level uncertainty for the next turn's temperature adjustment + if self.turn_temp_enabled: + # Get the action token IDs from the last generation + # Find the last contiguous block of response_mask=1 tokens + last_action_start = -1 + for i in range(len(agent_data.response_mask) - 1, -1, -1): + if agent_data.response_mask[i] == 1: + last_action_start = i + elif last_action_start != -1: + break + + if last_action_start != -1: + action_token_ids = agent_data.prompt_ids[last_action_start:] + else: + action_token_ids = ( + agent_data.response_ids + if hasattr(agent_data, "response_ids") + else [] + ) + + # Use accurate method (T=1 re-generation) or heuristic based on config + use_accurate = self.turn_level_temperature.get( + "use_accurate_uncertainty", True + ) + + if use_accurate: + turn_uncertainty = await self._compute_action_uncertainty_accurate( + agent_data, action_token_ids + ) + else: + turn_uncertainty = self._compute_action_uncertainty_heuristic( + agent_data + ) + + agent_data.turn_uncertainties.append(turn_uncertainty) + print( + f"[TurnLevelTemp DEBUG] After observation: turn_uncertainties={agent_data.turn_uncertainties}, " + f"EMA_mean={self._uncertainty_ema_mean}, EMA_std={self._uncertainty_ema_std}" + ) + + # Update EMA statistics for normalization + if self._uncertainty_ema_mean is None: + GenericAgentLoop._uncertainty_ema_mean = turn_uncertainty + GenericAgentLoop._uncertainty_ema_std = 1.0 # Initial std + else: + # EMA update + decay = self.turn_temp_ema_decay + 
GenericAgentLoop._uncertainty_ema_mean = ( + decay * self._uncertainty_ema_mean + (1 - decay) * turn_uncertainty + ) + # Approximate std update using running variance + diff = turn_uncertainty - self._uncertainty_ema_mean + GenericAgentLoop._uncertainty_ema_std = ( + decay * self._uncertainty_ema_std + (1 - decay) * abs(diff) + ) + if should_terminate: return GenericAgentState.TERMINATED else: return GenericAgentState.GENERATING + async def _compute_action_uncertainty_accurate( + self, agent_data: GenericAgentData, action_token_ids: list[int] + ) -> float: + """Compute ACCURATE uncertainty by re-generating with T=1. + + This method calls the model with T=1 to get log π (true model distribution), + which is independent of the sampling temperature used during rollout. + + Uncertainty = -mean(log π) = cross-entropy / perplexity + + This is more accurate than using log μ_T from rollout, but requires an + extra generation call per turn. + + Args: + agent_data: Current agent state + action_token_ids: Token IDs of the action (response) to evaluate + + Returns: + float: Uncertainty score based on log π (T=1) + """ + import numpy as np + + if not action_token_ids: + print("[TurnLevelTemp] No action tokens, using default uncertainty 1.0") + return 1.0 + + try: + # Build prompt: everything before the action + prompt_before_action = agent_data.prompt_ids[: -len(action_token_ids)] + + # We'll generate the same action tokens with T=1 to get log π + # Use a trick: set the prompt to include the action, generate 1 token + # and look at the returned log_probs + + # Actually, a simpler approach: generate with T=1 and logprobs=True + # The log_probs returned will be log π (since T=1) + sampling_params = { + "temperature": 1.0, # T=1 to get log π + "max_new_tokens": len(action_token_ids), # Generate same length + "logprobs": True, + "top_p": 1.0, # No top-p filtering + "top_k": -1, # No top-k filtering + } + + output = await self.server_manager.generate( + 
request_id=f"{agent_data.request_id}_uncertainty_t1", + prompt_ids=prompt_before_action, + sampling_params=sampling_params, + image_data=agent_data.image_data, + ) + + # Get log_probs from the T=1 generation + if output.log_probs and len(output.log_probs) > 0: + valid_logprobs = [ + lp for lp in output.log_probs if lp is not None and lp != 0.0 + ] + if valid_logprobs: + mean_logprob_pi = np.mean(valid_logprobs) + uncertainty = -mean_logprob_pi + uncertainty = max(0.1, min(10.0, uncertainty)) + print( + f"[TurnLevelTemp] Accurate uncertainty (T=1): {uncertainty:.3f} " + f"(mean_log_π={mean_logprob_pi:.3f}, num_tokens={len(valid_logprobs)})" + ) + return uncertainty + + print( + "[TurnLevelTemp] No log_probs from T=1 generation, falling back to heuristic" + ) + return self._compute_action_uncertainty_heuristic(agent_data) + + except Exception as e: + print(f"[TurnLevelTemp] Failed to compute accurate uncertainty: {e}") + return self._compute_action_uncertainty_heuristic(agent_data) + + def _compute_action_uncertainty_heuristic( + self, agent_data: GenericAgentData + ) -> float: + """Compute uncertainty using log μ_T with temperature compensation. + + Uses the log_probs from rollout (log μ_T) as a proxy for uncertainty, + with compensation for the temperature used during sampling. + + Problem: log μ_T is affected by the sampling temperature T: + - High T → log μ_T more negative (flatter distribution) + - Low T → log μ_T closer to 0 (sharper distribution) + + This creates a positive feedback loop if not compensated: + - High uncertainty → high T → log μ_T more negative → higher uncertainty → ... + - Low uncertainty → low T → log μ_T closer to 0 → lower uncertainty → ... + + Compensation: We divide by T to approximate temperature-invariant uncertainty. + While log μ_T = log softmax(z/T) is not exactly log π / T, dividing by T + provides a reasonable approximation that breaks the feedback loop. 
+ + uncertainty ≈ -mean(log μ_T) / T + + This way: + - If T was high and log μ_T was -0.3, uncertainty ≈ 0.3/1.5 = 0.2 + - If T was low and log μ_T was -0.05, uncertainty ≈ 0.05/0.7 = 0.07 + + Both reflect similar underlying model confidence. + + Args: + agent_data: Current agent state with response_logprobs + + Returns: + float: Uncertainty score, clipped to [0.1, 3.0] + """ + import numpy as np + + if not agent_data.response_logprobs or len(agent_data.response_logprobs) == 0: + print("[TurnLevelTemp] No response_logprobs, using default uncertainty 1.0") + return 1.0 + + # Find action log_probs (response_mask == 1) + action_logprobs = [] + for mask, logprob in zip( + agent_data.response_mask[-len(agent_data.response_logprobs) :], + agent_data.response_logprobs, + ): + if mask == 1 and logprob != 0.0 and logprob is not None: + action_logprobs.append(logprob) + + if not action_logprobs: + action_logprobs = [ + lp + for lp in agent_data.response_logprobs + if lp != 0.0 and lp is not None + ] + + if not action_logprobs: + print("[TurnLevelTemp] No valid log_probs, using default uncertainty 1.0") + return 1.0 + + # Get temperature used for this turn + current_turn_T = ( + agent_data.turn_temperatures[-1] if agent_data.turn_temperatures else 1.0 + ) + + mean_logprob_mu_T = np.mean(action_logprobs) + + # Temperature-compensated uncertainty: divide by T to approximate log π + # This breaks the positive feedback loop where T affects log μ_T which affects T + raw_uncertainty = -mean_logprob_mu_T / current_turn_T + + # Clip to reasonable range [0.1, 3.0] + # Using tighter range to prevent extreme temperature swings + uncertainty = max(0.1, min(3.0, raw_uncertainty)) + + print( + f"[TurnLevelTemp] Heuristic uncertainty: {uncertainty:.3f} " + f"(raw={raw_uncertainty:.3f}, mean_log_μT={mean_logprob_mu_T:.3f}, T={current_turn_T:.2f})" + ) + return uncertainty + async def _save_debug_images(self, image_data: list, request_id: str): """Save debug images to disk when 
SAVE_DEBUG_IMAGES env var is set. diff --git a/opentinker/server/http_training_server.py b/opentinker/server/http_training_server.py index 809be9a0..40f05dc2 100755 --- a/opentinker/server/http_training_server.py +++ b/opentinker/server/http_training_server.py @@ -975,7 +975,7 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: batch, self.reward_fn ) - # 6. Compute old_log_probs + # 6. Compute old_log_probs (behavior policy log probabilities) with marked_timer("old_log_prob", timing_raw, color="blue"): # ===== DEBUG LOGGING START ===== logger.info("=" * 80) @@ -1007,11 +1007,76 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: logger.info("=" * 80) # ===== DEBUG LOGGING END ===== - old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + # Check if turn-level temperature is enabled + # When enabled, rollout_log_probs already contains log μ_T (behavior policy log prob) + # which should be used directly as old_log_probs for correct IS correction + turn_temp_config = self.config.actor_rollout_ref.rollout.multi_turn.get( + "turn_level_temperature", {} + ) + turn_temp_enabled = turn_temp_config.get("enabled", False) + has_rollout_logprobs = "rollout_log_probs" in batch.batch + use_rollout_logprobs_as_old = turn_temp_enabled and has_rollout_logprobs + + # CRITICAL: Warn if turn-level temp is enabled but rollout_log_probs not available + if turn_temp_enabled and not has_rollout_logprobs: + logger.warning( + "[TurnLevelTemp] WARNING: turn_level_temperature is enabled but rollout_log_probs " + "is not in batch! This means calculate_log_probs=False in rollout config. " + "Set actor_rollout_ref.rollout.calculate_log_probs=True for correct IS correction." 
+ ) + + # Log turn-level temperature statistics if available + has_token_temps = "token_temperatures" in batch.non_tensor_batch + if has_token_temps: + token_temps_raw = batch.non_tensor_batch.get( + "token_temperatures", None + ) + if token_temps_raw is not None: + all_temps = [] + for temps in token_temps_raw: + if temps is not None and len(temps) > 0: + all_temps.extend(temps) + if all_temps: + metrics["turn_temp/mean"] = np.mean(all_temps) + metrics["turn_temp/std"] = ( + np.std(all_temps) if len(all_temps) > 1 else 0.0 + ) + metrics["turn_temp/min"] = np.min(all_temps) + metrics["turn_temp/max"] = np.max(all_temps) + adjusted_count = sum( + 1 for t in all_temps if abs(t - 1.0) > 0.01 + ) + metrics["turn_temp/adjusted_ratio"] = adjusted_count / len( + all_temps + ) + logger.info( + f"[TurnLevelTemp] Batch stats: mean={metrics['turn_temp/mean']:.3f}, " + f"range=[{metrics['turn_temp/min']:.3f}, {metrics['turn_temp/max']:.3f}], " + f"adjusted_ratio={metrics['turn_temp/adjusted_ratio']:.2%}" + ) + + # Also log turn-level uncertainties if available + if "turn_uncertainties" in batch.non_tensor_batch: + turn_u_raw = batch.non_tensor_batch.get("turn_uncertainties", None) + if turn_u_raw is not None: + all_u = [] + for u_list in turn_u_raw: + if u_list is not None and len(u_list) > 0: + all_u.extend(u_list) + if all_u: + metrics["turn_temp/uncertainty_mean"] = np.mean(all_u) + metrics["turn_temp/uncertainty_std"] = ( + np.std(all_u) if len(all_u) > 1 else 0.0 + ) + + # Always compute log_prob with T=1 to get entropy for entropy bonus + # (entropy should be computed on the learned policy, not the behavior policy) + actor_log_prob = self.actor_rollout_wg.compute_log_prob(batch) logger.info( - f"DEBUG: old_log_prob keys: {list(old_log_prob.batch.keys())}" + f"DEBUG: actor_log_prob keys: {list(actor_log_prob.batch.keys())}" ) - entropys = old_log_prob.batch["entropys"] + + entropys = actor_log_prob.batch["entropys"] response_masks = batch.batch["response_mask"] from 
verl.trainer.ppo.core_algos import agg_loss @@ -1039,10 +1104,42 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: # ===== DEBUG LOGGING END ===== metrics.update(old_log_prob_metrics) - old_log_prob.batch.pop("entropys") - batch = batch.union(old_log_prob) + actor_log_prob.batch.pop("entropys") + + # CRITICAL FIX: Use correct behavior policy log_prob for IS correction + # When turn-level temperature is enabled: + # - rollout_log_probs = log μ_T = log softmax(z/T) (from sglang/vLLM during sampling) + # - This is the TRUE behavior policy log_prob and should be used as old_log_probs + # - ratio = π_θ_new / μ_T = exp(log π_θ_new - log μ_T) + # When turn-level temperature is NOT enabled (T=1 everywhere): + # - Use actor_log_prob["old_log_probs"] as before (both are log π) + if use_rollout_logprobs_as_old: + logger.info( + "[TurnLevelTemp] Using rollout_log_probs as old_log_probs for correct IS correction" + ) + # rollout_log_probs is already log μ_T from the behavior policy + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + + # Log the difference between actor (T=1) and rollout (T≠1) log_probs + actor_lp = actor_log_prob.batch["old_log_probs"] + rollout_lp = batch.batch["rollout_log_probs"] + diff = (actor_lp - rollout_lp).abs() + masked_diff = diff * response_masks + if response_masks.sum() > 0: + mean_diff = masked_diff.sum() / response_masks.sum() + max_diff = masked_diff.max() + metrics["turn_temp/logprob_diff_mean"] = mean_diff.item() + metrics["turn_temp/logprob_diff_max"] = max_diff.item() + logger.info( + f"[TurnLevelTemp] log_prob diff (actor vs rollout): " + f"mean={mean_diff.item():.4f}, max={max_diff.item():.4f}" + ) + else: + # Standard case: no temperature variation, actor log_prob = behavior log_prob + batch = batch.union(actor_log_prob) + logger.info( - f"DEBUG: batch keys after old_log_prob union: {list(batch.batch.keys())}" + f"DEBUG: batch keys after old_log_prob processing: {list(batch.batch.keys())}" ) # Calculate debug 
metrics for rollout vs actor log probs mismatch @@ -1120,6 +1217,41 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: config=self.config.algorithm, ) + # ===================================================================== + # WM-Guided Dynamic Entropy: Compute beta_token at trainer level + # This ensures stable weights across PPO epochs (using lagged θ^-) + # ===================================================================== + actor_config = self.config.actor_rollout_ref.actor + # Handle both dict and dataclass/OmegaConf access patterns + if hasattr(actor_config, "get"): + wm_dynamic_entropy_config = actor_config.get( + "wm_dynamic_entropy", {} + ) + elif hasattr(actor_config, "wm_dynamic_entropy"): + wm_dynamic_entropy_config = actor_config.wm_dynamic_entropy + else: + wm_dynamic_entropy_config = {} + + # Check if enabled + is_enabled = False + if isinstance(wm_dynamic_entropy_config, dict): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + elif hasattr(wm_dynamic_entropy_config, "enabled"): + is_enabled = wm_dynamic_entropy_config.enabled + elif hasattr(wm_dynamic_entropy_config, "get"): + is_enabled = wm_dynamic_entropy_config.get("enabled", False) + + print( + f"[WM Dynamic Entropy] DEBUG: is_enabled = {is_enabled}, config type = {type(wm_dynamic_entropy_config)}" + ) + + if is_enabled: + with marked_timer("wm_beta_token", timing_raw, color="magenta"): + batch, wm_metrics = self.trainer._compute_wm_beta_token( + batch, wm_dynamic_entropy_config + ) + metrics.update(wm_metrics) + # 10. 
Update critic if self.use_critic: with marked_timer("update_critic", timing_raw, color="pink"): @@ -1135,6 +1267,10 @@ def train_step(self, batch: DataProto) -> Dict[str, Any]: batch.meta_info["multi_turn"] = ( self.config.actor_rollout_ref.rollout.multi_turn.enable ) + # Required by dp_actor.update_policy + batch.meta_info["temperature"] = ( + self.config.actor_rollout_ref.rollout.temperature + ) actor_output = self.actor_rollout_wg.update_actor(batch) actor_output_metrics = reduce_metrics( actor_output.meta_info["metrics"] diff --git a/opentinker/server/launch_http_server.py b/opentinker/server/launch_http_server.py index 6c2561de..928e7671 100755 --- a/opentinker/server/launch_http_server.py +++ b/opentinker/server/launch_http_server.py @@ -22,6 +22,8 @@ def main(cfg): os.environ["RAY_DEDUP_LOGS"] = "0" os.environ["HYDRA_FULL_ERROR"] = "1" + # Disable sleep mode to avoid cumem allocator issues (CUDA Error: invalid argument) + os.environ["VLLM_DISABLE_SLEEP_MODE"] = "1" from omegaconf import open_dict import logging @@ -123,7 +125,9 @@ def main(cfg): cfg.trainer.save_freq = 500 cfg.trainer.test_freq = 500 cfg.trainer.total_epochs = 15 - cfg.trainer.default_local_dir = "/workspace/verl/verl/ckpts" + cfg.trainer.default_local_dir = os.path.expanduser( + "/mnt/disk1_from_server2/haofeiy2/opentinker_checkpoints" + ) # --------------------------------------------------------- # Agent Loop Configuration @@ -138,8 +142,11 @@ def main(cfg): logger.info("Agent Loop Mode Enabled") logger.info("=" * 60) + # Async engine requires V1, so force it. VLLM_DISABLE_SLEEP_MODE=1 handles cumem issues. os.environ["VLLM_USE_V1"] = "1" - logger.info("Set VLLM_USE_V1=1 for async rollout") + logger.info( + "VLLM_USE_V1=1 for async rollout (sleep mode disabled to avoid cumem issues)" + ) # Increase Ray's memory threshold to avoid premature OOM kills # Default is 0.95 (95%), we increase to 0.98 (98%)