diff --git a/README.md b/README.md index 5fad55c5..8c0aceac 100755 --- a/README.md +++ b/README.md @@ -18,17 +18,17 @@ Choose an example below to get started. Each example includes step-by-step instructions for setup, training, and inference. -| Task | Description | Performance | -| ------------------------------------------------ | ------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------- | -| **[LLM Single-Turn Math](docs/math_singleturn.md)** | Mathematical problem solving | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/bwkq1wl8?nw=nwuserzhusq20) | -| **[LLM Multi-Turn Math](docs/math_multiturn.md)** | Multi-turn mathematical problem solving with tool calling | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/f5pt6gcw?nw=nwuserzhusq20) | -| **[LLM Single-LoRA Single-Turn Math](docs/math_lora_singleturn.md)** | Math single-turn Trained With LoRA | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/cl1w5l07?nw=nwuserzhusq20) | -| **[VLM Single-Turn Math](docs/vlm_geo3k_singleturn.md)** | geometry 3k math problem solving | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/aidfc2y1?nw=nwuserzhusq20) | -| **[VLM Multi-Turn Math](docs/vlm_geo3k_multiturn.md)** | geometry 3k math problem solving with tool calling | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/r39htm2o?nw=nwuserzhusq20) | -| **[LLM Gomoku Agent](docs/gomoku_multiturn.md)** | A multi-turn gomoku agent | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/7a7ggkw3?nw=nwuserzhusq20) | -| **[LLM AlfWorld Agent](docs/alfworld_multiturn.md)** | A multi-turn alfworld agent | [wandb](https://wandb.ai/1125027232/opentinker-public/runs/3jrlolk7?nw=nwuser1125027232) | -| **[LLM Android World Agent](docs/android_world_multiturn.md)** | A multi-turn android world agent | | - +| Task | Description | Performance | +| -------------------------------------------------------------------- | --------------------------------------------------------- | ---------------------------------------------------------------------------------------- | +| **[LLM Single-Turn Math](docs/math_singleturn.md)** | Mathematical problem solving | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/bwkq1wl8?nw=nwuserzhusq20) | +| **[LLM Multi-Turn Math](docs/math_multiturn.md)** | Multi-turn mathematical problem solving with tool calling | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/f5pt6gcw?nw=nwuserzhusq20) | +| **[LLM Single-LoRA Single-Turn Math](docs/math_lora_singleturn.md)** | Math single-turn Trained With LoRA | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/cl1w5l07?nw=nwuserzhusq20) | +| **[VLM Single-Turn Math](docs/vlm_geo3k_singleturn.md)** | geometry 3k math problem solving | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/aidfc2y1?nw=nwuserzhusq20) | +| **[VLM Multi-Turn Math](docs/vlm_geo3k_multiturn.md)** | geometry 3k math problem solving with tool calling | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/r39htm2o?nw=nwuserzhusq20) | +| **[LLM Gomoku Agent](docs/gomoku_multiturn.md)** | A multi-turn gomoku agent | [wandb](https://wandb.ai/zsqzz/Open-Tinker/runs/7a7ggkw3?nw=nwuserzhusq20) | +| **[LLM AlfWorld Agent](docs/alfworld_multiturn.md)** | A multi-turn alfworld agent | [wandb](https://wandb.ai/1125027232/opentinker-public/runs/3jrlolk7?nw=nwuser1125027232) | +| **[LLM SciWorld Agent](docs/sciworld_multiturn.md)** | A multi-turn ScienceWorld agent | | +| **[LLM Android World Agent](docs/android_world_multiturn.md)** | A multi-turn android world agent | | ## 📦 Installation @@ -149,12 +149,12 @@ This 2×2 design space enables four distinct paradigms, each suited to different ``` @misc{zhu2026opentinkerseparatingconcernsagentic, - title={OpenTinker: Separating Concerns in Agentic Reinforcement Learning}, + title={OpenTinker: Separating Concerns in Agentic Reinforcement Learning}, author={Siqi Zhu and Jiaxuan You}, year={2026}, eprint={2601.07376}, archivePrefix={arXiv}, primaryClass={cs.AI}, - url={https://arxiv.org/abs/2601.07376}, + url={https://arxiv.org/abs/2601.07376}, } ``` diff --git a/docs/sciworld_multiturn.md b/docs/sciworld_multiturn.md new file mode 100644 index 00000000..cd5c5582 --- /dev/null +++ b/docs/sciworld_multiturn.md @@ -0,0 +1,117 @@ +# LLM Game Agent (ScienceWorld Multi-Turn) + +**Author:** Haofeiy + +This example demonstrates training a language model to complete science tasks in the ScienceWorld text-based environment. + +## Overview + +ScienceWorld is a text-based benchmark of grounded science tasks. Tasks include: + +- Boiling / freezing / melting substances +- Identifying and classifying living things +- Using instruments (thermometer, microscope, etc.) +- Combining materials to produce reactions +- Navigating rooms to find and manipulate objects + +OpenTinker support follows the same pattern as ALFWorld: + +- `SciWorldGame` wraps the benchmark as an `AbstractGame` +- `sciworld_server.py` exposes the environment over the generic FastAPI server +- `sciworld_rl.py` trains against that server through `GameEnvironment` + +## Prerequisites + +1. Complete the [Installation](../README.md#-installation) steps +2. Install ScienceWorld: `pip install scienceworld` +3. Ensure **Java** is available (`java -version`), since ScienceWorld launches a JVM-backed server +4. Get your IP address if client and scheduler run on different machines: `hostname -I` + +## Step 1: Start the Scheduler + +```bash +bash opentinker/scripts/launch_scheduler.sh --scheduler-port +``` + +## Step 2: Start the ScienceWorld Environment + +```bash +python -m opentinker.environment.sciworld.sciworld_server \ + --port \ + --max_steps 30 \ + --split train \ + --shards 8 \ + --threads-per-shard 256 +``` + +Optional task restriction: + +```bash +python -m opentinker.environment.sciworld.sciworld_server \ + --port \ + --split train \ + --task-name boil \ + --task-name find-animal +``` + +Useful server options: + +- `--split`: ScienceWorld split to sample variations from (`train`, `dev`, `test`) +- `--task-name`: Repeat to restrict the task pool +- `--task-id`: Alternative to task names if you prefer numeric task ids +- `--variation`: Repeat to restrict to explicit variation ids +- `--simplification-str`: Pass-through simplification string for `env.load()` +- `--thread-base`: Base ScienceWorld thread number for this server group +- `--threads-per-shard`: Reserved thread-number block per shard + +## Step 3: Run Training + +```bash +python opentinker/client/sciworld_rl.py \ + tokenizer_path=Qwen/Qwen2.5-3B-Instruct \ + batch_size=4 \ + num_steps=1000 \ + test_freq=10 \ + scheduler_url=http://: \ + interaction.config.env_port= \ + interaction.config.env_host= \ + interaction.config.split=train \ + interaction.config.local_thread_base=20000 +``` + +## Notes + +- Keep `num_workers=0` for the local prompt-generation dataloaders unless you + explicitly manage non-overlapping ScienceWorld thread bases per worker. +- Use the same `split`, `task_names`, `task_ids`, and `variation_indices` + settings on both the environment server and the client config so prompt + generation matches the remote environment. +- If you already run ScienceWorld-backed processes on the same machine, move + `--thread-base` and `interaction.config.local_thread_base` to disjoint ranges. + +## Reward Structure + +| Event | Reward | +| ---------------- | ------ | +| Task Success | +10.0 | +| Task Failure | -1.0 | +| Per Step Penalty | -0.01 | +| Invalid Action | -0.1 | + +## Example Actions + +The agent interacts with the environment using text commands: + +- `look around` - Observe the current room +- `open door to kitchen` - Navigate between rooms +- `pick up thermometer` - Pick up an object +- `use thermometer on water` - Use an instrument +- `pour water into beaker` - Combine or transfer materials +- `focus on substance in microscope` - Examine with instruments +- `inventory` - Check held items +- `wait` - Wait one step (e.g. for a reaction) + +## Configuration Reference + +See [`opentinker/client/client_config/sciworld_param.yaml`](../opentinker/client/client_config/sciworld_param.yaml) +for the full configuration. diff --git a/opentinker/client/client_config/sciworld_param.yaml b/opentinker/client/client_config/sciworld_param.yaml new file mode 100644 index 00000000..8c35fdca --- /dev/null +++ b/opentinker/client/client_config/sciworld_param.yaml @@ -0,0 +1,81 @@ +# ScienceWorld Training Configuration +# Use with: python sciworld_rl.py + +# Project settings +project_name: opentinker +experiment_name: sciworld_training + +# Logging +logger_backends: ["console", "wandb"] + +# Tracing (optional) +enable_tracing: true +weave_project: null + +# WandB (optional) +wandb_key: null + +# Model and tokenizer +tokenizer_path: null + +# Training parameters +batch_size: 4 +num_workers: 0 # Keep single-process; ScienceWorld uses per-instance JVM ports +# Training duration - set ONE of these (num_steps takes precedence if both set) +num_epochs: null # Number of epochs (null = use num_steps) +num_steps: 1000 # Total training steps (null = use num_epochs) +save_freq: 200 +test_freq: 50 # Validation frequency (every N steps) + +# Validation parameters +val_batch_size: 32 # Total validation samples (null = 50) + +# Generation parameters +temperature: 1 # Lower temperature for more focused responses +top_p: 1 +max_new_tokens: 4096 # TOTAL response budget for entire multi-turn trajectory (NOT per-turn!) +max_prompt_tokens: 2048 + +# Algorithm (must be agent_loop for multi-turn) +algorithm: "agent_loop" + +# RL Algorithm settings (passed to server via scheduler) +# adv_estimator options: +# - "grpo" : Standard GRPO (outcome-only advantage) +# - "grpo_per_step" : Per-step GRPO with return-based advantages (for multi-turn tasks) +# - "gae" : Generalized Advantage Estimation (for PPO, requires critic) +adv_estimator: "grpo" +# rollout_n: number of samples per prompt for GRPO/grpo_per_step +# For PPO (gae), rollout_n is typically 1 +rollout_n: 8 + +interaction: + name: sciworld + class_path: opentinker.environment.gym_environment_interaction.GymEnvironmentInteraction + config: + env_host: 0.0.0.0 + env_port: 8092 + env_endpoint: http://${interaction.config.env_host}:${interaction.config.env_port} + env_shards: 8 + max_steps: 30 + max_total_steps: 30 + observation_template: "{observation}" + split: train # train, dev, test + task_names: null # e.g. ["boil", "find-animal"] + task_ids: null + variation_indices: null + simplification_str: "" + jar_path: null + local_thread_base: 20000 + +multi_turn: + max_user_turns: ${interaction.config.max_total_steps} + max_assistant_turns: ${interaction.config.max_total_steps} + max_tokens_per_turn: 512 + weave_project: "opentinker/sciworld" + experiment_name: "sciworld_interaction" + +scheduler_url: "http://0.0.0.0:8780" +scheduler_api_key: null + +num_gpus: 8 diff --git a/opentinker/client/sciworld_eval.py b/opentinker/client/sciworld_eval.py new file mode 100644 index 00000000..e2d0c001 --- /dev/null +++ b/opentinker/client/sciworld_eval.py @@ -0,0 +1,731 @@ +#!/usr/bin/env python3 +"""Standalone ScienceWorld evaluation script. + +This script reuses OpenTinker's existing inference pipeline and ScienceWorld game +integration to run evaluation and export: +1. Per-sample JSONL records with full prompt/response and score (0/1) +2. Final aggregated scores per task name and overall +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import re +import time +from pathlib import Path +from typing import Any + +from tqdm import tqdm + +from opentinker.environment.sciworld import SciWorldGame +from opentinker.environment.inference_pipeline import ( + InferencePipeline, + RemoteEnvironmentClient, +) + + +def build_fixed_order_samples( + args: argparse.Namespace, + split: str, +) -> list[dict[str, Any]]: + """Build fixed-order samples from the split's complete episode list.""" + game = SciWorldGame( + max_steps=args.max_steps, + split=split, + task_names=args.task_name, + task_ids=args.task_id, + simplification_str=args.simplification_str, + jar_path=args.jar_path, + thread_base=50000, + ) + + try: + task_pool = game._resolve_task_pool() + pairs: list[tuple[str, int]] = [] + for task_name in sorted(task_pool): + variations = game._resolve_variations_for_task(task_name) + for variation in sorted(variations): + pairs.append((task_name, variation)) + + total_available = len(pairs) + if total_available <= 0: + raise RuntimeError( + f"No available ScienceWorld episodes found for split={split}" + ) + + if args.max_samples is not None and args.max_samples > 0: + effective = min(args.max_samples, total_available) + else: + effective = total_available + + pairs = pairs[:effective] + system_prompt = game.get_system_prompt() + samples: list[dict[str, Any]] = [] + + for idx, (task_name, variation) in enumerate( + tqdm(pairs, desc="Build Fixed Episode List") + ): + user_prompt = game.get_user_message_with_state( + task_name=task_name, variation=variation + ) + samples.append( + { + "prompt": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "env_kwargs": { + "task_name": task_name, + "variation": variation, + }, + "episode_index": idx, + "task_name": task_name, + "variation": variation, + } + ) + + print( + f"Built fixed-order ScienceWorld episodes: " + f"{effective}/{total_available} (split={split})" + ) + return samples + finally: + game.close() + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Evaluate a model on ScienceWorld and export JSONL + summary scores." + ) + + # Model backend + parser.add_argument( + "--model-path", + type=str, + default=None, + help="Model path for offline vLLM mode (required if --vllm-server-url is not set).", + ) + parser.add_argument( + "--vllm-server-url", + type=str, + default=None, + help="Existing vLLM server URL for server mode (e.g., http://127.0.0.1:8000).", + ) + parser.add_argument( + "--tokenizer-path", + type=str, + default=None, + help="Tokenizer path (defaults to model path).", + ) + + # Environment/eval setup + parser.add_argument( + "--env-endpoint", + type=str, + default="http://127.0.0.1:8092", + help="ScienceWorld environment server endpoint for single-split mode.", + ) + parser.add_argument( + "--dev-env-endpoint", + type=str, + default=None, + help="ScienceWorld environment endpoint for dev split.", + ) + parser.add_argument( + "--test-env-endpoint", + type=str, + default=None, + help="ScienceWorld environment endpoint for test split.", + ) + parser.add_argument( + "--job-id", + type=str, + default="sciworld_eval", + help="Job id used for environment stats isolation.", + ) + parser.add_argument( + "--split", + type=str, + default="both", + choices=["train", "dev", "test", "both"], + help="ScienceWorld split. Use 'both' to run dev+test in one command.", + ) + parser.add_argument( + "--max-samples", + type=int, + default=-1, + help=( + "Number of evaluation samples to run. " + "Use <=0 to auto-run all available samples for the split." + ), + ) + parser.add_argument( + "--max-steps", + type=int, + default=30, + help="Maximum environment steps per episode.", + ) + parser.add_argument( + "--task-name", + action="append", + default=None, + help="Restrict to specific task names. Repeat for multiple tasks.", + ) + parser.add_argument( + "--task-id", + action="append", + type=int, + default=None, + help="Restrict to specific task ids. Repeat for multiple ids.", + ) + parser.add_argument( + "--simplification-str", + type=str, + default="", + help="ScienceWorld simplification string.", + ) + parser.add_argument( + "--jar-path", + type=str, + default=None, + help="Optional path to the ScienceWorld JAR.", + ) + + # Generation settings + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument( + "--max-tokens", + type=int, + default=8192, + help="Total generation budget across all assistant turns.", + ) + parser.add_argument( + "--max-tokens-per-turn", + type=int, + default=512, + help="Per-turn generation budget.", + ) + parser.add_argument("--max-user-turns", type=int, default=30) + parser.add_argument("--max-assistant-turns", type=int, default=30) + parser.add_argument("--max-context-length", type=int, default=30000) + parser.add_argument("--tensor-parallel-size", type=int, default=1) + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + + # Output + parser.add_argument( + "--output-jsonl", + type=str, + default=None, + help="Output JSONL path for single-split mode.", + ) + parser.add_argument( + "--dev-output-jsonl", + type=str, + default="outputs/sciworld_eval_dev.jsonl", + help="Output JSONL path for dev split.", + ) + parser.add_argument( + "--test-output-jsonl", + type=str, + default="outputs/sciworld_eval_test.jsonl", + help="Output JSONL path for test split.", + ) + parser.add_argument( + "--summary-json", + type=str, + default=None, + help="Optional output path for summary JSON (default: .summary.json).", + ) + + parser.add_argument( + "--continue-on-error", + action="store_true", + help="Skip failed samples and continue evaluation.", + ) + + args = parser.parse_args() + if not args.model_path and not args.vllm_server_url: + parser.error("Either --model-path or --vllm-server-url must be provided.") + if args.split == "both": + if not args.dev_env_endpoint or not args.test_env_endpoint: + parser.error( + "--split both requires both --dev-env-endpoint and --test-env-endpoint." + ) + else: + if not args.output_jsonl: + parser.error( + "--output-jsonl is required for single-split mode (split != both)." + ) + return args + + +def extract_task_name( + sample: dict[str, Any], env_info_trace: list[dict[str, Any]] +) -> str: + for info in env_info_trace: + task_name = info.get("task_name") + if isinstance(task_name, str) and task_name.strip(): + return task_name.strip() + return sample.get( + "task_name", sample.get("env_kwargs", {}).get("task_name", "unknown") + ) + + +def extract_task_description( + sample: dict[str, Any], env_info_trace: list[dict[str, Any]] +) -> str: + for info in env_info_trace: + task = info.get("task") + if isinstance(task, str) and task.strip(): + return task.strip() + + user_prompt = "" + for msg in sample.get("prompt", []): + if msg.get("role") == "user": + user_prompt = msg.get("content", "") + break + + match = re.search(r"Task:\s*(.+?)(?:\n\n|$)", user_prompt, flags=re.DOTALL) + if match: + return match.group(1).strip() + + return user_prompt.strip() + + +def _coerce_optional_bool(value: Any) -> bool | None: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"true", "1", "yes", "y"}: + return True + if lowered in {"false", "0", "no", "n"}: + return False + if isinstance(value, (list, tuple)): + if not value: + return None + return _coerce_optional_bool(value[0]) + return None + + +def extract_episode_success( + env_info_trace: list[dict[str, Any]], +) -> tuple[bool | None, str]: + for info in reversed(env_info_trace): + if not isinstance(info, dict): + continue + if "success" in info: + success = _coerce_optional_bool(info.get("success")) + if success is not None: + return success, "success" + if "score" in info: + score = info.get("score") + if isinstance(score, (int, float)) and score >= 100.0: + return True, "score" + return None, "reward_fallback" + + +def compute_summary( + records: list[dict[str, Any]], elapsed_sec: float +) -> dict[str, Any]: + stats: dict[str, dict[str, int]] = {} + overall_total = 0 + overall_correct = 0 + errors = 0 + + for rec in records: + if rec.get("error"): + errors += 1 + continue + task_name = rec.get("task_name", "unknown") + score = int(rec.get("score", 0)) + + if task_name not in stats: + stats[task_name] = {"total": 0, "correct": 0} + + stats[task_name]["total"] += 1 + stats[task_name]["correct"] += score + overall_total += 1 + overall_correct += score + + per_task = {} + for name in sorted(stats.keys()): + total = stats[name]["total"] + correct = stats[name]["correct"] + acc = (correct / total) if total > 0 else None + per_task[name] = { + "correct": correct, + "total": total, + "score": acc, + } + + all_score = (overall_correct / overall_total) if overall_total > 0 else None + return { + "elapsed_sec": elapsed_sec, + "num_records": len(records), + "num_evaluated": overall_total, + "num_errors": errors, + "tasks": per_task, + "all": { + "correct": overall_correct, + "total": overall_total, + "score": all_score, + }, + } + + +def print_summary(summary: dict[str, Any]) -> None: + print("\n" + "=" * 64) + print("ScienceWorld Evaluation Summary") + print("=" * 64) + + for name in sorted(summary["tasks"].keys()): + item = summary["tasks"][name] + correct = item["correct"] + total = item["total"] + score = item["score"] + score_str = f"{score * 100:.2f}%" if score is not None else "N/A" + print(f"{name:>40}: {score_str:>8} ({correct}/{total})") + + all_item = summary["all"] + all_score = all_item["score"] + all_score_str = f"{all_score * 100:.2f}%" if all_score is not None else "N/A" + print("-" * 64) + print( + f"{'All':>40}: {all_score_str:>8} ({all_item['correct']}/{all_item['total']})" + ) + print("-" * 64) + print( + f"Samples={summary['num_records']}, Evaluated={summary['num_evaluated']}, " + f"Errors={summary['num_errors']}, Elapsed={summary['elapsed_sec']:.1f}s" + ) + print("=" * 64) + + +def build_combined_split_summary( + dev_summary: dict[str, Any], test_summary: dict[str, Any] +) -> dict[str, Any]: + all_task_names = sorted( + set(dev_summary["tasks"].keys()) | set(test_summary["tasks"].keys()) + ) + + empty = {"correct": 0, "total": 0, "score": None} + tasks: dict[str, Any] = {} + for name in all_task_names: + dev_item = dev_summary["tasks"].get(name, empty) + test_item = test_summary["tasks"].get(name, empty) + combined_correct = dev_item["correct"] + test_item["correct"] + combined_total = dev_item["total"] + test_item["total"] + combined_score = ( + combined_correct / combined_total if combined_total > 0 else None + ) + tasks[name] = { + "dev": dev_item, + "test": test_item, + "combined": { + "correct": combined_correct, + "total": combined_total, + "score": combined_score, + }, + } + + dev_all = dev_summary["all"] + test_all = test_summary["all"] + combined_all_correct = dev_all["correct"] + test_all["correct"] + combined_all_total = dev_all["total"] + test_all["total"] + combined_all_score = ( + combined_all_correct / combined_all_total if combined_all_total > 0 else None + ) + + return { + "tasks": tasks, + "overall": { + "dev": dev_all, + "test": test_all, + "combined": { + "correct": combined_all_correct, + "total": combined_all_total, + "score": combined_all_score, + }, + }, + } + + +def print_dual_split_summary(dual_summary: dict[str, Any]) -> None: + print("\n" + "=" * 100) + print("ScienceWorld Dev/Test Combined Summary") + print("=" * 100) + print(f"{'Task':>40} Dev Test Combined") + print("-" * 100) + + for name in sorted(dual_summary["tasks"].keys()): + item = dual_summary["tasks"][name] + dev = item["dev"] + test = item["test"] + combined = item["combined"] + dev_s = f"{dev['score'] * 100:.2f}%" if dev["score"] is not None else "N/A" + test_s = f"{test['score'] * 100:.2f}%" if test["score"] is not None else "N/A" + combined_s = ( + f"{combined['score'] * 100:.2f}%" + if combined["score"] is not None + else "N/A" + ) + print( + f"{name:>40} {dev_s:>8} ({dev['correct']:>3}/{dev['total']:<3}) " + f"{test_s:>8} ({test['correct']:>3}/{test['total']:<3}) " + f"{combined_s:>8} ({combined['correct']:>3}/{combined['total']:<3})" + ) + + print("-" * 100) + overall = dual_summary["overall"] + dev_all = overall["dev"] + test_all = overall["test"] + combined_all = overall["combined"] + dev_all_s = ( + f"{dev_all['score'] * 100:.2f}%" if dev_all["score"] is not None else "N/A" + ) + test_all_s = ( + f"{test_all['score'] * 100:.2f}%" if test_all["score"] is not None else "N/A" + ) + combined_all_s = ( + f"{combined_all['score'] * 100:.2f}%" + if combined_all["score"] is not None + else "N/A" + ) + print( + f"{'All':>40} {dev_all_s:>8} ({dev_all['correct']:>3}/{dev_all['total']:<3}) " + f"{test_all_s:>8} ({test_all['correct']:>3}/{test_all['total']:<3}) " + f"{combined_all_s:>8} ({combined_all['correct']:>3}/{combined_all['total']:<3})" + ) + print("=" * 100) + + +async def run_single_split_eval( + args: argparse.Namespace, + pipeline: InferencePipeline, + split: str, + env_endpoint: str, + output_jsonl_path: str, +) -> tuple[dict[str, Any], Path]: + samples = build_fixed_order_samples(args, split=split) + + pipeline.env_client = RemoteEnvironmentClient(env_endpoint, job_id=args.job_id) + healthy = await pipeline.env_client.health_check() + if not healthy: + raise RuntimeError(f"ScienceWorld server not available at {env_endpoint}") + print(f"Connected to ScienceWorld server at {env_endpoint} (split={split})") + + output_jsonl = Path(output_jsonl_path) + output_jsonl.parent.mkdir(parents=True, exist_ok=True) + output_jsonl.write_text("") + + start = time.time() + records: list[dict[str, Any]] = [] + + for idx, sample in enumerate(tqdm(samples, desc="ScienceWorld Eval")): + env_kwargs = sample.get("env_kwargs", {}) + task_name_from_sample = sample.get("task_name", "unknown") + + try: + result = await pipeline.run_single_inference( + messages=sample["prompt"], + env_kwargs=env_kwargs, + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.max_tokens, + max_tokens_per_turn=args.max_tokens_per_turn, + ) + + env_info_trace = result.info.get("env_info", []) + task_name = extract_task_name(sample, env_info_trace) + task_description = extract_task_description(sample, env_info_trace) + episode_success, score_source = extract_episode_success(env_info_trace) + if episode_success is None: + episode_success = bool(result.reward > 0) + score_source = "reward_fallback" + score = int(episode_success) + + assistant_turns = [ + m.get("content", "") + for m in result.messages + if m.get("role") == "assistant" + ] + action_trace = [ + x.get("action_taken") for x in env_info_trace if x.get("action_taken") + ] + raw_reward_trace = [ + x.get("raw_reward") for x in env_info_trace if "raw_reward" in x + ] + success_trace = [x.get("success") for x in env_info_trace if "success" in x] + score_trace = [x.get("score") for x in env_info_trace if "score" in x] + + record = { + "index": idx, + "sample_id": result.sample_id, + "split": split, + "task_name": task_name, + "variation": env_kwargs.get("variation"), + "question": { + "system_prompt": sample["prompt"][0]["content"] + if sample.get("prompt") + else "", + "initial_user_prompt": sample["prompt"][1]["content"] + if len(sample.get("prompt", [])) > 1 + else "", + "full_prompt_text": result.prompt_text, + "task_description": task_description, + }, + "model_answer": { + "assistant_turns": assistant_turns, + "full_response_text": result.response_text, + "full_messages": result.messages, + }, + "ground_truth": { + "task_description": task_description, + "task_name": task_name, + "success_definition": "Environment success field (fallback: score >= 100, then reward > 0)", + }, + "score": score, + "success": bool(episode_success), + "score_source": score_source, + "final_reward": result.reward, + "done": result.done, + "num_turns": result.num_turns, + "env_kwargs": env_kwargs, + "episode_index": sample.get("episode_index"), + "action_trace": action_trace, + "raw_reward_trace": raw_reward_trace, + "success_trace": success_trace, + "score_trace": score_trace, + "env_info_trace": env_info_trace, + } + except Exception as e: + if not args.continue_on_error: + raise + record = { + "index": idx, + "split": split, + "task_name": task_name_from_sample, + "variation": env_kwargs.get("variation"), + "score": 0, + "error": str(e), + "env_kwargs": env_kwargs, + "episode_index": sample.get("episode_index"), + } + + records.append(record) + with output_jsonl.open("a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + + elapsed = time.time() - start + summary = compute_summary(records, elapsed_sec=elapsed) + summary.update( + { + "model_path": args.model_path, + "vllm_server_url": args.vllm_server_url, + "tokenizer_path": args.tokenizer_path, + "env_endpoint": env_endpoint, + "split": split, + "max_samples": len(samples), + "max_steps": args.max_steps, + "sample_mode": "fixed_episode_order", + "scoring_rule": "success -> score>=100 -> reward>0 fallback", + } + ) + + summary_path = output_jsonl.with_suffix(output_jsonl.suffix + ".summary.json") + summary_path.parent.mkdir(parents=True, exist_ok=True) + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2)) + + print_summary(summary) + print(f"Per-sample JSONL saved to: {output_jsonl}") + print(f"Summary JSON saved to: {summary_path}") + return summary, summary_path + + +async def run_eval(args: argparse.Namespace) -> None: + if args.split == "both": + split_runs = [ + ("dev", args.dev_env_endpoint, args.dev_output_jsonl), + ("test", args.test_env_endpoint, args.test_output_jsonl), + ] + else: + split_runs = [(args.split, args.env_endpoint, args.output_jsonl)] + + pipeline = InferencePipeline( + model_path=args.model_path, + tokenizer_path=args.tokenizer_path, + vllm_server_url=args.vllm_server_url, + env_endpoint=split_runs[0][1], + job_id=args.job_id, + max_user_turns=args.max_user_turns, + max_assistant_turns=args.max_assistant_turns, + max_context_length=args.max_context_length, + tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=args.gpu_memory_utilization, + ) + + split_summaries: dict[str, dict[str, Any]] = {} + summary_paths: dict[str, str] = {} + for split_name, endpoint, output_jsonl in split_runs: + summary, summary_path = await run_single_split_eval( + args=args, + pipeline=pipeline, + split=split_name, + env_endpoint=endpoint, + output_jsonl_path=output_jsonl, + ) + split_summaries[split_name] = summary + summary_paths[split_name] = str(summary_path) + + if args.split == "both": + dual_summary = build_combined_split_summary( + dev_summary=split_summaries["dev"], + test_summary=split_summaries["test"], + ) + print_dual_split_summary(dual_summary) + + final_summary = { + "model_path": args.model_path, + "vllm_server_url": args.vllm_server_url, + "tokenizer_path": args.tokenizer_path, + "sample_mode": "fixed_episode_order", + "splits": { + "dev": split_summaries["dev"], + "test": split_summaries["test"], + }, + "combined": dual_summary, + "split_summary_paths": summary_paths, + } + summary_path = ( + Path(args.summary_json) + if args.summary_json + else Path("outputs/sciworld_eval_both.summary.json") + ) + summary_path.parent.mkdir(parents=True, exist_ok=True) + summary_path.write_text(json.dumps(final_summary, ensure_ascii=False, indent=2)) + print(f"Combined summary JSON saved to: {summary_path}") + else: + if args.summary_json: + single_summary_path = Path(args.summary_json) + single_summary_path.parent.mkdir(parents=True, exist_ok=True) + single_summary_path.write_text( + json.dumps(split_summaries[args.split], ensure_ascii=False, indent=2) + ) + print(f"Summary JSON saved to: {single_summary_path}") + + +def main() -> None: + args = parse_args() + asyncio.run(run_eval(args)) + + +if __name__ == "__main__": + main() diff --git a/opentinker/client/sciworld_rl.py b/opentinker/client/sciworld_rl.py new file mode 100644 index 00000000..594ebc97 --- /dev/null +++ b/opentinker/client/sciworld_rl.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +"""ScienceWorld RL Training Client. + +This script trains an LLM agent to complete science tasks in ScienceWorld. + +Usage: + # Start ScienceWorld server first (in another terminal): + python -m opentinker.environment.sciworld.sciworld_server --port 8092 + + # Run training: + python sciworld_rl.py scheduler_url=http://localhost:8780 num_gpus=2 +""" + +from omegaconf import OmegaConf +import hydra + +from utils.http_training_client import ServiceClient, SchedulerClient +from opentinker.environment.base_game_environment import GameEnvironment +from opentinker.environment.game_stats_client import GameStatsClient +from opentinker.environment.sciworld import SciWorldGame +from utils.utils import resolve_paths_in_config +from utils.scheduler_client_lifecycle import get_lifecycle_manager + + +@hydra.main(config_path="client_config", config_name="sciworld_param.yaml") +def main(args): + args = resolve_paths_in_config(args) + lifecycle = get_lifecycle_manager() + + enable_tracing = args.get("enable_tracing", False) + if enable_tracing: + try: + from opentinker.utils.rollout_trace_saver import init_weave_tracing + + weave_project = args.get("weave_project", "sciworld-training") + init_weave_tracing( + project_name=weave_project, + experiment_name=args.experiment_name, + token2text=True, + ) + except Exception as e: + print(f"\u26a0 Failed to initialize Weave tracing: {e}") + + print("=" * 60) + print("Training with ScienceWorld Environment") + print("=" * 60) + + scheduler_url = args.get("scheduler_url", "http://localhost:8780") + scheduler_api_key = args.get("scheduler_api_key", None) + + print(f"\nConnecting to scheduler at {scheduler_url}") + if scheduler_api_key: + print("\u2713 Using API key for authentication") + else: + print( + "\u26a0 No API key provided - authentication may fail if scheduler requires it" + ) + + scheduler_client = SchedulerClient( + scheduler_url=scheduler_url, api_key=scheduler_api_key + ) + + print("\nSubmitting training job to scheduler...") + job_result = scheduler_client.submit_job( + config=OmegaConf.to_container(args, resolve=True), + enable_agent_loop=True, + wandb_key=args.get("wandb_key"), + num_gpus=args.get("num_gpus"), + ) + + job_id = job_result["job_id"] + server_url = job_result["server_url"] + lifecycle.register_job(scheduler_client, job_id) + + print(f"\n\u2713 Job {job_id} allocated!") + print(f" Server URL: {server_url}") + print(f" GPUs: {job_result.get('gpu_ids')}") + print(f" Port: {job_result.get('port')}") + print("=" * 60) + + interaction_config = args.interaction.config + game_kwargs = { + "max_steps": interaction_config.get("max_total_steps", 30), + "split": interaction_config.get("split", "train"), + "task_names": interaction_config.get("task_names"), + "task_ids": interaction_config.get("task_ids"), + "variation_indices": interaction_config.get("variation_indices"), + "simplification_str": interaction_config.get("simplification_str", ""), + "jar_path": interaction_config.get("jar_path"), + "thread_base": interaction_config.get( + "local_thread_base", SciWorldGame.DEFAULT_LOCAL_THREAD_BASE + ), + } + + env_endpoint = interaction_config.env_endpoint + + print("\nSetting up GameEnvironment with SciWorldGame...") + print(f" Environment endpoint: {env_endpoint}") + print(f" Max steps: {game_kwargs['max_steps']}") + print(f" Split: {game_kwargs['split']}") + print(f" Job ID for stats: {job_id}") + + env = GameEnvironment( + game_class=SciWorldGame, + config=args, + game_kwargs=game_kwargs, + job_id=job_id, + ) + + print("\u2713 Environment created") + print(f" Interaction config path: {env.get_interaction_config_path()}") + + game_stats = GameStatsClient(env_endpoint, job_id=env.job_id) + if game_stats.health_check(): + print(f"\u2713 Connected to ScienceWorld server for metrics at {env_endpoint}") + game_stats.reset_all() + else: + print( + f"\u26a0 ScienceWorld server at {env_endpoint} not responding - metrics disabled" + ) + game_stats = None + + print(f"\nConnecting to allocated server at {server_url}") + client = ServiceClient( + server_url=server_url, + project_name=args.project_name, + experiment_name=args.experiment_name, + logger_backends=args.logger_backends, + ) + + client.set_config(args, env) + + num_steps = args.get("num_steps", None) + num_epochs = args.get("num_epochs", None) + + if num_steps: + print(f"\nStarting training for {num_steps} steps...") + elif num_epochs: + print(f"\nStarting training for {num_epochs} epochs...") + else: + print("\nStarting training (1 epoch default)...") + + print(f"Checkpoint save frequency: {args.save_freq}") + print(f"Validation frequency: {args.test_freq}") + print("=" * 60) + + try: + final_metrics = client.fit( + env=env, + num_epochs=num_epochs, + num_steps=num_steps, + save_freq=args.save_freq, + test_freq=args.test_freq, + verbose=True, + validate_before_training=True, + game_stats_client=game_stats, + ) + + print("\n" + "=" * 60) + print("Training completed!") + print(f"Final training metrics: {final_metrics}") + + # Display final cumulative game stats + if game_stats: + print("\n" + "-" * 40) + print("Final Game Statistics:") + cumulative = game_stats.get_all_stats() + if cumulative: + print(f" Total episodes: {cumulative.get('total_games', 0):.0f}") + print(f" Success rate: {cumulative.get('cumulative_win_rate', 0):.1%}") + print(f" Total successes: {cumulative.get('total_wins', 0):.0f}") + print(f" Total failures: {cumulative.get('total_losses', 0):.0f}") + print("=" * 60) + finally: + env.cleanup() + + +if __name__ == "__main__": + main() diff --git a/opentinker/environment/base_game.py b/opentinker/environment/base_game.py index 7d8e650d..9cec5640 100755 --- a/opentinker/environment/base_game.py +++ b/opentinker/environment/base_game.py @@ -209,6 +209,14 @@ def parse_action(self, raw_action: str) -> Optional[Any]: """ return raw_action + def close(self) -> None: + """Release any external resources held by this game instance. + + Override for environments that manage subprocesses, sockets, JVMs, or + other resources that should be cleaned up when an episode finishes. + """ + return None + class GameDataGenerator: """Data generator that works with any AbstractGame. @@ -273,3 +281,13 @@ def generate_sample(self, index: int) -> Dict[str, Any]: def get_interaction_name(self) -> str: """Return interaction name from game.""" return self._game.get_interaction_name() + + def close(self) -> None: + """Release resources held by the backing game instance.""" + self._game.close() + + def __del__(self): + try: + self.close() + except Exception: + pass diff --git a/opentinker/environment/base_game_environment.py b/opentinker/environment/base_game_environment.py index 149efd0e..e4547be3 100755 --- a/opentinker/environment/base_game_environment.py +++ b/opentinker/environment/base_game_environment.py @@ -153,6 +153,8 @@ def __init__( self.train_dataloader = None self.val_dataloader = None self._interaction_config_path = None + self._train_generator = None + self._val_generator = None self._setup_dataloader() self._setup_interaction_config() @@ -187,6 +189,7 @@ def _setup_dataloader(self): game_class=self.game_class, game_kwargs=self.game_kwargs, ) + self._train_generator = train_generator print(f"Creating training dataset (virtual_size={virtual_size})") train_dataset = DynamicGameDataset( @@ -216,6 +219,7 @@ def _setup_dataloader(self): game_kwargs=self.game_kwargs, seed=42, ) + self._val_generator = val_generator val_dataset = DynamicGameDataset( data_generator=val_generator, @@ -319,6 +323,17 @@ def cleanup(self): ): os.remove(self._interaction_config_path) print(f"Removed: {self._interaction_config_path}") + for resource in ( + getattr(self, "_train_generator", None), + getattr(self, "_val_generator", None), + getattr(self, "_game_instance", None), + ): + close_fn = getattr(resource, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception as exc: + print(f"[WARN] Failed to clean up {resource}: {exc}") def create_game_environment( diff --git a/opentinker/environment/base_game_server.py b/opentinker/environment/base_game_server.py index 668add10..77991858 100755 --- a/opentinker/environment/base_game_server.py +++ b/opentinker/environment/base_game_server.py @@ -31,8 +31,9 @@ run_game_server(MyGame, port=8081, board_size=9) """ +import asyncio import threading -from collections import defaultdict +from collections import defaultdict, deque from typing import Any, Dict, List, Optional, Type from fastapi import FastAPI, HTTPException @@ -42,6 +43,16 @@ from opentinker.environment.base_game import AbstractGame +def _close_game_instance(game: AbstractGame) -> None: + """Best-effort cleanup hook for game instances.""" + close_fn = getattr(game, "close", None) + if callable(close_fn): + try: + close_fn() + except Exception as exc: + print(f"[WARN] Failed to close {game.__class__.__name__}: {exc}") + + class BaseGameStats: """Thread-safe base statistics tracker for game server. @@ -358,6 +369,13 @@ def create_game_app( Parallelism is achieved via sharding (multiple server processes on consecutive ports). Each shard handles requests independently with its own in-memory instance registry. + Game instances are pooled: when an episode finishes, the instance is returned to a + pool for reuse rather than being destroyed. This avoids expensive re-initialization + (e.g., JVM startup for ScienceWorld). + + Blocking game operations (reset/step) are offloaded to a thread pool so the async + event loop is never blocked and concurrent requests can be processed in parallel. + Args: game_class: The AbstractGame subclass to use stats_class: Optional custom stats class (e.g., GomokuGameStats). @@ -373,6 +391,22 @@ def create_game_app( games: Dict[str, AbstractGame] = {} games_lock = threading.Lock() # Thread-safe access to games dict + # Pool of idle game instances available for reuse (avoids JVM/subprocess churn) + game_pool: deque[AbstractGame] = deque() + pool_lock = threading.Lock() + + def _get_or_create_game() -> AbstractGame: + """Get a game instance from the pool, or create a new one.""" + with pool_lock: + if game_pool: + return game_pool.popleft() + return game_class(**game_kwargs) + + def _return_to_pool(game: AbstractGame) -> None: + """Return a game instance to the pool for reuse.""" + with pool_lock: + game_pool.append(game) + # Use MultiJobGameStats for job isolation multi_stats = MultiJobGameStats(stats_class=stats_class or BaseGameStats) @@ -393,11 +427,11 @@ async def reset(request: ResetRequest): if instance_id in games: game = games[instance_id] else: - game = game_class(**game_kwargs) + game = _get_or_create_game() games[instance_id] = game - # Reset the game (this is the slow part) - observation = game.reset(**reset_kwargs) + # Reset the game in a thread (avoids blocking the event loop for JVM/subprocess envs) + observation = await asyncio.to_thread(game.reset, **reset_kwargs) # Track that this game has started (with job isolation) stats = multi_stats.get_job_stats(job_id) @@ -426,7 +460,8 @@ async def finalize(request: ResetRequest): instance_id = request.instance_id with games_lock: if instance_id in games: - del games[instance_id] + game = games.pop(instance_id) + _close_game_instance(game) return {"message": f"Instance {instance_id} removed"} return {"message": f"Instance {instance_id} not found", "status": "ignored"} @@ -444,7 +479,8 @@ async def step(request: StepRequest): ) game = games[instance_id] - result = game.step(action) + # Run blocking game.step() in a thread to avoid blocking the event loop + result = await asyncio.to_thread(game.step, action) # Record statistics with instance_id for per-game tracking (with job isolation) stats = multi_stats.get_job_stats(job_id) @@ -452,11 +488,13 @@ async def step(request: StepRequest): result.info, result.reward, result.done, instance_id, job_id ) - # Clean up finished games + # Return finished game instances to the pool for reuse instead of destroying them. + # This avoids expensive re-initialization (JVM startup, subprocess launch, etc.). if result.done: with games_lock: if instance_id in games: - del games[instance_id] + finished_game = games.pop(instance_id) + _return_to_pool(finished_game) return { "observation": result.observation, diff --git a/opentinker/environment/sciworld/__init__.py b/opentinker/environment/sciworld/__init__.py new file mode 100644 index 00000000..6af342f3 --- /dev/null +++ b/opentinker/environment/sciworld/__init__.py @@ -0,0 +1,16 @@ +"""ScienceWorld environment module for OpenTinker. + +This module provides ScienceWorld text-based environment integration +for LLM RL training. + +Usage: + from opentinker.environment.sciworld import SciWorldGame + + game = SciWorldGame() + obs = game.reset() + result = game.step("pick up thermometer") +""" + +from opentinker.environment.sciworld.sciworld_game import SciWorldGame + +__all__ = ["SciWorldGame"] diff --git a/opentinker/environment/sciworld/sciworld_game.py b/opentinker/environment/sciworld/sciworld_game.py new file mode 100644 index 00000000..a801ccbd --- /dev/null +++ b/opentinker/environment/sciworld/sciworld_game.py @@ -0,0 +1,706 @@ +#!/usr/bin/env python3 +"""ScienceWorld game implementation for OpenTinker. + +This wrapper follows the same AbstractGame contract as ALFWorld and exposes +ScienceWorld through the generic FastAPI game server + GameEnvironment stack. +""" + +from __future__ import annotations + +import heapq +import logging +import random +import re +import threading +from typing import Any, Dict, Iterable, List, Optional + +from opentinker.environment.base_game import AbstractGame, StepResult + +try: + from scienceworld import ScienceWorldEnv + + SCIWORLD_AVAILABLE = True +except ImportError: + ScienceWorldEnv = None + SCIWORLD_AVAILABLE = False + +logger = logging.getLogger(__name__) + + +class SciWorldGame(AbstractGame): + """ScienceWorld text environment wrapper.""" + + REWARD_SUCCESS = 10.0 + REWARD_FAILURE = -1.0 + REWARD_STEP = -0.01 + REWARD_INVALID_ACTION = -0.1 + + DEFAULT_MAX_STEPS = 30 + DEFAULT_MAX_ACTIONS = 50 + DEFAULT_LOCAL_THREAD_BASE = 20000 + + _thread_lock = threading.Lock() + _free_thread_offsets: list[int] = [] + _next_thread_offset = 0 + + def __init__( + self, + max_steps: int = DEFAULT_MAX_STEPS, + split: str = "train", + task_names: Optional[List[str]] = None, + task_ids: Optional[List[int]] = None, + variation_indices: Optional[List[int]] = None, + simplification_str: str = "", + jar_path: Optional[str] = None, + thread_base: int = 0, + max_action_candidates: int = DEFAULT_MAX_ACTIONS, + ): + if not SCIWORLD_AVAILABLE: + raise ImportError( + "scienceworld package not installed. Install with: pip install scienceworld" + ) + + self.max_steps = max_steps + self.split = (split or "train").strip().lower() + self.task_names = self._normalize_string_list(task_names) + self.task_ids = list(task_ids or []) + self.variation_indices = [int(v) for v in (variation_indices or [])] + self.simplification_str = simplification_str or "" + self.jar_path = jar_path + self.thread_base = int(thread_base or 0) + self.max_action_candidates = max(1, int(max_action_candidates or 1)) + + self._env: Optional[ScienceWorldEnv] = None + self._thread_num: Optional[int] = None + self._all_task_names: Optional[List[str]] = None + self._variation_cache: Dict[str, List[int]] = {} + + self._current_obs = "" + self._task_desc = "" + self._current_task_name = "" + self._current_variation = -1 + self._current_info: Dict[str, Any] = {} + self._admissible_actions: List[str] = [] + self._action_templates: List[str] = [] + self._objects: List[str] = [] + self._done = False + self._step_count = 0 + self._score = 0.0 + + @staticmethod + def _normalize_string_list(values: Optional[Iterable[str]]) -> List[str]: + if values is None: + return [] + if isinstance(values, str): + values = [v.strip() for v in values.split(",")] + return [str(v).strip() for v in values if str(v).strip()] + + @classmethod + def _allocate_thread_offset(cls) -> int: + with cls._thread_lock: + if cls._free_thread_offsets: + return heapq.heappop(cls._free_thread_offsets) + offset = cls._next_thread_offset + cls._next_thread_offset += 1 + return offset + + @classmethod + def _release_thread_offset(cls, offset: int) -> None: + with cls._thread_lock: + heapq.heappush(cls._free_thread_offsets, offset) + + def _ensure_env(self) -> None: + if self._env is not None: + return + + offset = self._allocate_thread_offset() + self._thread_num = self.thread_base + offset + logger.info( + "Initializing ScienceWorldEnv(threadNum=%s, max_steps=%s, split=%s)", + self._thread_num, + self.max_steps, + self.split, + ) + kwargs = {"envStepLimit": self.max_steps} + if self.jar_path: + kwargs["serverPath"] = self.jar_path + self._env = ScienceWorldEnv("", **kwargs) + + def close(self) -> None: + if self._env is not None: + shutdown_fn = getattr(self._env, "shutdown", None) + if callable(shutdown_fn): + try: + shutdown_fn() + except Exception as exc: + logger.warning("Failed to shutdown ScienceWorldEnv: %s", exc) + self._env = None + + if self._thread_num is not None: + offset = self._thread_num - self.thread_base + if offset >= 0: + self._release_thread_offset(offset) + self._thread_num = None + + def __del__(self): + try: + self.close() + except Exception: + pass + + def _get_all_task_names(self) -> List[str]: + self._ensure_env() + if self._all_task_names is None: + task_names = list(self._env.getTaskNames()) + self._all_task_names = [str(name) for name in task_names] + return self._all_task_names + + def _resolve_task_pool(self) -> List[str]: + task_pool = self._get_all_task_names() + if self.task_names: + allowed = set(self.task_names) + filtered = [name for name in task_pool if name in allowed] + if not filtered: + raise ValueError( + f"No ScienceWorld tasks matched task_names={self.task_names!r}" + ) + return filtered + + if self.task_ids: + resolved = [] + for task_id in self.task_ids: + if task_id < 0 or task_id >= len(task_pool): + raise ValueError( + f"ScienceWorld task_id out of range: {task_id} (num_tasks={len(task_pool)})" + ) + resolved.append(task_pool[task_id]) + return resolved + + return task_pool + + @staticmethod + def _extract_variations_from_value( + value: Any, task_name: str, split: str + ) -> Optional[List[int]]: + if value is None: + return None + if isinstance(value, dict): + for key in (task_name, split): + if key in value: + extracted = SciWorldGame._extract_variations_from_value( + value[key], task_name, split + ) + if extracted: + return extracted + return None + if isinstance(value, range): + return list(value) + if isinstance(value, (list, tuple, set)): + result = [] + for item in value: + try: + result.append(int(item)) + except Exception: + return None + return result + if isinstance(value, int): + return list(range(value)) + return None + + def _call_variation_method( + self, method_name: str, task_name: str + ) -> Optional[List[int]]: + method = getattr(self._env, method_name, None) + if not callable(method): + return None + + for args in ((task_name,), tuple()): + try: + value = method(*args) + except TypeError: + continue + except Exception: + return None + + variations = self._extract_variations_from_value( + value=value, + task_name=task_name, + split=self.split, + ) + if variations: + return variations + return None + + def _resolve_variations_for_task(self, task_name: str) -> List[int]: + if task_name in self._variation_cache: + return self._variation_cache[task_name] + + # 1. Get the ground-truth valid variations from the environment for this task/split + split_method_candidates = { + "train": [ + "getVariationsTrain", + "getTrainVariations", + "getTaskTrainVariations", + ], + "dev": [ + "getVariationsDev", + "getDevVariations", + "getTaskDevVariations", + ], + "test": [ + "getVariationsTest", + "getTestVariations", + "getTaskTestVariations", + ], + } + generic_candidates = [ + "getVariationsForTask", + "getTaskVariations", + "getVariations", + "getVariationIndices", + "getNumVariations", + ] + + env_variations = [] + for method_name in split_method_candidates.get(self.split, []): + env_variations = self._call_variation_method(method_name, task_name) or [] + if env_variations: + break + + if not env_variations: + for method_name in generic_candidates: + env_variations = ( + self._call_variation_method(method_name, task_name) or [] + ) + if env_variations: + break + + if not env_variations: + env_variations = [0] + + # 2. If user provided a specific set of indices, take the intersection + if self.variation_indices: + user_indices = set(self.variation_indices) + env_indices = set(env_variations) + variations = sorted(list(user_indices.intersection(env_indices))) + + # If intersection is empty, fall back to environment defaults to prevent crash + if not variations: + logger.warning( + f"Task {task_name!r} has no overlap with provided variation_indices. " + f"Falling back to environment's variations." + ) + variations = sorted(env_variations) + else: + variations = sorted(env_variations) + + self._variation_cache[task_name] = variations + return variations + + def _select_task_and_variation( + self, + task_name: Optional[str] = None, + variation: Optional[int] = None, + ) -> tuple[str, int]: + task_pool = self._resolve_task_pool() + selected_task = task_name or random.choice(task_pool) + if selected_task not in task_pool: + raise ValueError( + f"Task {selected_task!r} is not in the configured ScienceWorld task pool" + ) + + variations = self._resolve_variations_for_task(selected_task) + + # 1. Determine initial selection + if variation is not None: + v_int = int(variation) + if v_int in variations: + selected_variation = v_int + else: + # Fallback to a valid one from the list via modulo + selected_variation = variations[v_int % len(variations)] + else: + selected_variation = random.choice(variations) + + # 2. MANDATORY HARD LIMIT CHECK + # ScienceWorld often has a task-specific maximum variations. + # We try several ways to find the absolute limit to prevent the Java-side error. + limit = 1000000 # Default huge + + # Try to get the count from various scienceworld metadata sources + for method_name in ["getMaxVariations", "getNumVariations", "getVariationCount"]: + try: + m = getattr(self._env, method_name, None) + if callable(m): + # Try with task name, then without + for args in ((selected_task,), ()): + try: + res = m(*args) + if isinstance(res, int) and res > 0: + limit = min(limit, res) + break + except Exception: + continue + except Exception: + pass + + # If we have a list of variations from _resolve, its max is also a limit + if variations: + limit = min(limit, max(variations) + 1) + + # If our selection still exceeds the limit, force it down + if selected_variation >= limit: + old_v = selected_variation + selected_variation = old_v % limit + logger.warning( + f"Forced fix: variation {old_v} exceeds limit {limit} for task {selected_task!r}. " + f"New variation: {selected_variation}" + ) + + return selected_task, selected_variation + + def reset( + self, + task_name: Optional[str] = None, + variation: Optional[int] = None, + seed: Optional[int] = None, + **kwargs, + ) -> str: + del kwargs + self._ensure_env() + + if seed is not None: + random.seed(seed) + + selected_task, selected_variation = self._select_task_and_variation( + task_name=task_name, + variation=variation, + ) + + self._env.load(selected_task, selected_variation, self.simplification_str) + observation, info = self._env.reset() + + self._current_task_name = selected_task + self._current_variation = selected_variation + self._current_obs = str(observation) + self._current_info = info if isinstance(info, dict) else {} + self._task_desc = self._get_task_description() + self._admissible_actions = self._extract_valid_actions_from_info(self._current_info) + self._action_templates = self._extract_action_templates() + self._objects = self._extract_objects() + self._done = False + self._step_count = 0 + self._score = self._extract_score(self._current_info) + + return self._format_observation(self._current_obs) + + def step(self, action: str) -> StepResult: + if self._done: + return StepResult( + observation="Episode already finished.", + reward=0.0, + done=True, + info={"error": "episode_finished"}, + ) + + self._ensure_env() + self._step_count += 1 + parsed_action = self._parse_action(action) + + observation, reward, done, info = self._env.step(parsed_action) + info = info if isinstance(info, dict) else {} + + self._current_obs = str(observation) + self._current_info = info + self._admissible_actions = self._extract_valid_actions_from_info(self._current_info) + self._action_templates = self._extract_action_templates() + self._objects = self._extract_objects() + + # Enforce step limit + if self._step_count >= self.max_steps and not done: + done = True + observation = ( + f"TIMEOUT: Maximum steps ({self.max_steps}) reached.\n\n{observation}" + ) + self._current_obs = str(observation) + + score = self._extract_score(info, default=self._score) + score_delta = score - self._score + success = self._extract_success( + done=done, reward=reward, info=info, score=score + ) + valid_action = self._extract_valid_action( + info=info, observation=self._current_obs + ) + + # Use ScienceWorld's score delta as reward signal (0-100 scale → 0-1) + if score_delta > 0: + final_reward = score_delta / 100.0 + elif valid_action is False: + final_reward = self.REWARD_INVALID_ACTION + else: + final_reward = self.REWARD_STEP + + self._done = bool(done) + self._score = score + + return StepResult( + observation=self._format_observation(self._current_obs), + reward=final_reward, + done=self._done, + info={ + "action_taken": parsed_action, + "task": self._task_desc, + "task_name": self._current_task_name, + "variation": self._current_variation, + "score": score, + "raw_reward": float(reward) + if isinstance(reward, (int, float)) + else reward, + "success": success, + "valid_action": valid_action, + }, + ) + + @staticmethod + def _coerce_optional_bool(value: Any) -> Optional[bool]: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {"true", "1", "yes", "y"}: + return True + if lowered in {"false", "0", "no", "n"}: + return False + return None + + @staticmethod + def _normalize_reward(reward: Any) -> float: + if not isinstance(reward, (int, float)): + return 0.0 + reward = float(reward) + if abs(reward) > 10.0: + reward = reward / 10.0 + return max(min(reward, 10.0), -10.0) + + @staticmethod + def _extract_score(info: Dict[str, Any], default: float = 0.0) -> float: + if not isinstance(info, dict): + return float(default) + for key in ("score", "normalizedScore", "taskScore"): + value = info.get(key) + if isinstance(value, (int, float)): + return float(value) + return float(default) + + def _extract_success( + self, + done: bool, + reward: Any, + info: Dict[str, Any], + score: float, + ) -> bool: + for key in ("success", "taskCompleted", "completed", "isCompleted"): + value = self._coerce_optional_bool(info.get(key)) + if value is not None: + return value + if done and score >= 100.0: + return True + if done and isinstance(reward, (int, float)) and float(reward) >= 10.0: + return True + return False + + def _extract_valid_action( + self, info: Dict[str, Any], observation: str + ) -> Optional[bool]: + for key in ("valid", "action_valid", "validAction"): + value = info.get(key) + # ScienceWorld stores valid action *list* under "valid" — skip it + if isinstance(value, (list, dict)): + continue + coerced = self._coerce_optional_bool(value) + if coerced is not None: + return coerced + + lowered = observation.lower() + invalid_markers = [ + "not a valid action", + "invalid action", + "i don't understand", + "nothing happens", + "you can't do that", + "no known action matches that input", + "ambiguous request", + ] + if any(marker in lowered for marker in invalid_markers): + return False + return None + + def _call_optional_text_method(self, method_name: str) -> str: + self._ensure_env() + method = getattr(self._env, method_name, None) + if not callable(method): + return "" + try: + value = method() + except Exception: + return "" + return str(value).strip() if value is not None else "" + + def _get_task_description(self) -> str: + task_desc = self._call_optional_text_method("getTaskDescription") + if task_desc: + return task_desc + task_desc = self._call_optional_text_method("taskdescription") + if task_desc: + return task_desc + return f"Complete ScienceWorld task {self._current_task_name}." + + def _extract_valid_actions_from_info(self, info: Dict[str, Any]) -> List[str]: + """Extract valid action list from ScienceWorld's info['valid'].""" + valid = info.get("valid", []) + if not isinstance(valid, list): + return [] + actions = [] + seen = set() + for v in valid: + a = str(v).strip() + if a and a not in seen: + seen.add(a) + actions.append(a) + return actions + + def _extract_action_templates(self) -> List[str]: + """Get possible action templates from ScienceWorld (e.g. 'open OBJ').""" + self._ensure_env() + method = getattr(self._env, "getPossibleActions", None) + if not callable(method): + return [] + try: + templates = method() + except Exception: + return [] + if not isinstance(templates, (list, tuple)): + return [] + return [str(t).strip() for t in templates if str(t).strip()] + + def _extract_objects(self) -> List[str]: + """Get possible objects from ScienceWorld (e.g. 'door to kitchen').""" + self._ensure_env() + method = getattr(self._env, "getPossibleObjects", None) + if not callable(method): + return [] + try: + objects = method() + except Exception: + return [] + if not isinstance(objects, (list, tuple)): + return [] + return [str(o).strip() for o in objects if str(o).strip()] + + def _format_observation(self, observation: str) -> str: + parts = [ + "=== Current State ===", + observation.strip() or "(empty observation)", + ] + + score_line = f"Score: {self._score:.2f}" + if self._current_variation >= 0: + score_line += f" | Variation: {self._current_variation}" + parts.extend(["", score_line]) + + inventory = self._call_optional_text_method("inventory") + if inventory: + parts.extend(["", "=== Inventory ===", inventory]) + + look = self._call_optional_text_method("look") + if look and look != observation: + parts.extend(["", "=== Look ===", look]) + + if self._action_templates: + parts.extend(["", "=== Action Templates ==="]) + parts.extend(f"- {template}" for template in self._action_templates) + + if self._objects: + parts.extend(["", "=== Objects ==="]) + parts.extend(f"- {obj}" for obj in self._objects) + + return "\n".join(parts) + + def _parse_action(self, raw_action: str) -> str: + match = re.search( + r"\s*(.*?)\s*", raw_action, re.IGNORECASE | re.DOTALL + ) + if match: + return match.group(1).strip() + + lines = [ + line.strip() for line in raw_action.strip().splitlines() if line.strip() + ] + return lines[-1] if lines else "look around" + + def get_system_prompt(self) -> str: + return ( + "You are an AI assistant playing ScienceWorld, a text-based science environment.\n" + "Your goal is to complete the task by issuing grounded text actions.\n\n" + "IMPORTANT: You MUST respond in the following format:\n" + "1. Think briefly in \n" + "2. Output exactly one environment action in \n\n" + "Examples of actions:\n" + "- look around\n" + "- inventory\n" + "- examine beaker\n" + "- open door to art studio\n" + "- pour cup containing blue paint in glass cup\n\n" + "Example response:\n" + "I should inspect the room before manipulating objects.\n" + "look around" + ) + + def get_initial_user_message(self) -> str: + return ( + f"Task: {self._task_desc}\n\n" + "Interact with the environment and complete the task." + ) + + def get_state(self) -> Dict[str, Any]: + return { + "task": self._task_desc, + "task_name": self._current_task_name, + "variation": self._current_variation, + "score": self._score, + "step_count": self._step_count, + "max_steps": self.max_steps, + "done": self._done, + "candidate_actions": self._admissible_actions, + "action_templates": self._action_templates, + "objects": self._objects, + } + + def generate_initial_state(self) -> Dict[str, Any]: + task_name, variation = self._select_task_and_variation() + return { + "task_name": task_name, + "variation": variation, + "seed": random.randint(0, 1_000_000), + } + + def get_user_message_with_state( + self, + task_name: Optional[str] = None, + variation: Optional[int] = None, + **kwargs, + ) -> str: + self.reset(task_name=task_name, variation=variation, **kwargs) + return ( + f"Task: {self._task_desc}\n\n" + f"{self._format_observation(self._current_obs)}\n\n" + "What will you do next?" + ) + + def get_interaction_name(self) -> str: + return "sciworld" diff --git a/opentinker/environment/sciworld/sciworld_server.py b/opentinker/environment/sciworld/sciworld_server.py new file mode 100644 index 00000000..d9a9ca55 --- /dev/null +++ b/opentinker/environment/sciworld/sciworld_server.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +"""ScienceWorld environment server.""" + +import argparse +import os +import subprocess +import sys +import time + +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +def main(): + parser = argparse.ArgumentParser(description="ScienceWorld Game Server") + parser.add_argument("--host", default="0.0.0.0", help="Server host") + parser.add_argument("--port", type=int, default=8092, help="Server port") + parser.add_argument( + "--shards", + type=int, + default=8, + help="Number of independent server processes to launch on consecutive ports.", + ) + parser.add_argument( + "--max_steps", + type=int, + default=30, + help="Maximum steps per episode.", + ) + parser.add_argument( + "--split", + type=str, + default="train", + choices=["train", "dev", "test"], + help="ScienceWorld variation split to sample from.", + ) + parser.add_argument( + "--task-name", + action="append", + default=None, + help="Restrict to one or more ScienceWorld task names. Repeat for multiple tasks.", + ) + parser.add_argument( + "--task-id", + action="append", + type=int, + default=None, + help="Restrict to one or more ScienceWorld task ids. Repeat for multiple ids.", + ) + parser.add_argument( + "--variation", + action="append", + type=int, + default=None, + help="Restrict to one or more explicit variation indices.", + ) + parser.add_argument( + "--simplification-str", + type=str, + default="", + help="ScienceWorld simplification string passed directly to env.load().", + ) + parser.add_argument( + "--jar-path", + type=str, + default=None, + help="Optional path to the ScienceWorld JAR if not using the packaged default.", + ) + parser.add_argument( + "--thread-base", + type=int, + default=0, + help="Base ScienceWorld thread number for this server process.", + ) + parser.add_argument( + "--threads-per-shard", + type=int, + default=256, + help="Reserved ScienceWorld thread-number block size per shard.", + ) + args = parser.parse_args() + + from opentinker.environment.sciworld.sciworld_game import SciWorldGame + + print("\nScienceWorld Game Configuration:") + print(f" Max steps: {args.max_steps}") + print(f" Split: {args.split}") + print(f" Shards: {args.shards}") + print(f" Task names: {args.task_name or 'all'}") + print(f" Task ids: {args.task_id or 'none'}") + print(f" Variations: {args.variation or 'split default'}") + print(f" Simplification: {args.simplification_str or '(none)'}") + print(f" Thread base: {args.thread_base}") + print(f" Threads per shard: {args.threads_per_shard}") + print("\nReward structure:") + print(f" Success: +{SciWorldGame.REWARD_SUCCESS}") + print(f" Failure: {SciWorldGame.REWARD_FAILURE}") + print(f" Step penalty: {SciWorldGame.REWARD_STEP}") + print(f" Invalid action: {SciWorldGame.REWARD_INVALID_ACTION}") + + if args.shards and args.shards > 1: + print( + f"\nStarting sharded mode: {args.shards} shards on ports " + f"{args.port}..{args.port + args.shards - 1}" + ) + children: list[subprocess.Popen] = [] + try: + for i in range(args.shards): + port_i = args.port + i + thread_base_i = args.thread_base + i * args.threads_per_shard + cmd = [ + sys.executable, + os.path.abspath(__file__), + "--host", + args.host, + "--port", + str(port_i), + "--shards", + "1", + "--max_steps", + str(args.max_steps), + "--split", + args.split, + "--simplification-str", + args.simplification_str, + "--thread-base", + str(thread_base_i), + "--threads-per-shard", + str(args.threads_per_shard), + ] + if args.jar_path is not None: + cmd.extend(["--jar-path", args.jar_path]) + for task_name in args.task_name or []: + cmd.extend(["--task-name", task_name]) + for task_id in args.task_id or []: + cmd.extend(["--task-id", str(task_id)]) + for variation in args.variation or []: + cmd.extend(["--variation", str(variation)]) + + children.append(subprocess.Popen(cmd)) + time.sleep(0.2) + + print("Shards started. Press Ctrl+C to stop all shards.") + while True: + for proc in children: + rc = proc.poll() + if rc is not None: + raise RuntimeError( + f"Shard process exited early with code {rc}: pid={proc.pid}" + ) + time.sleep(1.0) + except KeyboardInterrupt: + pass + finally: + for proc in children: + try: + proc.terminate() + except Exception: + pass + for proc in children: + try: + proc.wait(timeout=5) + except Exception: + try: + proc.kill() + except Exception: + pass + return + + from opentinker.environment.base_game_server import run_game_server + + run_game_server( + game_class=SciWorldGame, + host=args.host, + port=args.port, + max_steps=args.max_steps, + split=args.split, + task_names=args.task_name, + task_ids=args.task_id, + variation_indices=args.variation, + simplification_str=args.simplification_str, + jar_path=args.jar_path, + thread_base=args.thread_base, + ) + + +if __name__ == "__main__": + main() diff --git a/opentinker/scheduler/job_scheduler.py b/opentinker/scheduler/job_scheduler.py index 2a1824d4..dbcd5a48 100755 --- a/opentinker/scheduler/job_scheduler.py +++ b/opentinker/scheduler/job_scheduler.py @@ -10,6 +10,7 @@ import logging import os import subprocess +import sys import threading import time import uuid @@ -238,12 +239,15 @@ def check_gpu_available(gpu_id: int) -> bool: - Scheduler restarts and loses allocation state - Jobs crash without properly releasing resources + NOTE: Currently disabled to allow shared GPU environments. + Args: gpu_id: GPU ID to check Returns: True if GPU is idle and available, False if occupied or check fails """ + return True # Skip GPU check for shared environments try: import subprocess @@ -284,7 +288,7 @@ def check_gpu_available(gpu_id: int) -> bool: return True # Fail open # Thresholds for considering a GPU "idle" - MAX_MEMORY_MB = 10 # Allow up to 100 MB (some baseline CUDA overhead) + MAX_MEMORY_MB = 40000 # Allow up to 40 GB (for shared GPU environments) MAX_UTILIZATION = 1000 # Allow up to 5% utilization if memory_used_mb > MAX_MEMORY_MB or utilization_percent > MAX_UTILIZATION: @@ -294,36 +298,6 @@ def check_gpu_available(gpu_id: int) -> bool: ) return False - # Check 2: Look for running processes on this GPU - pmon_result = subprocess.run( - ["nvidia-smi", "pmon", "-c", "1", "-s", "um"], - capture_output=True, - text=True, - timeout=5, - ) - - if pmon_result.returncode == 0: - # Parse pmon output to check for processes on this GPU - # Format: "# gpu pid type sm mem enc dec command" - # " 0 12345 C 50 500 0 0 python" - lines = pmon_result.stdout.strip().split("\n") - for line in lines: - if line.startswith("#") or not line.strip(): - continue - parts = line.split() - if len(parts) >= 2: - try: - gpu_idx = int(parts[0].strip()) - if gpu_idx == gpu_id and parts[1].strip() != "-": - # Found a process on this GPU - pid = parts[1].strip() - logger.warning( - f"GPU {gpu_id}: ⚠️ OCCUPIED - Process {pid} detected via pmon" - ) - return False - except (ValueError, IndexError): - continue - # All checks passed - GPU is idle logger.debug( f"GPU {gpu_id}: ✅ Available (Memory: {memory_used_mb} MB, Utilization: {utilization_percent}%)" @@ -1039,10 +1013,14 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, job.gpu_ids)) # Pass job_id to agent loop for per-client trace subdirectory isolation env["ROLLOUT_TRACE_JOB_ID"] = job.job_id + # Ensure the current Python env's bin dir is on PATH (for ninja, etc.) + python_bin_dir = os.path.dirname(sys.executable) + if python_bin_dir not in env.get("PATH", "").split(os.pathsep): + env["PATH"] = python_bin_dir + os.pathsep + env.get("PATH", "") # Build command line arguments from config cmd = [ - "python", + sys.executable, self.server_script_path, f"server.port={job.port}", f"job_id={job.job_id}", @@ -1085,12 +1063,18 @@ def _launch_server(self, job: JobInfo) -> subprocess.Popen: if kl_config: use_kl_in_reward = kl_config.get("use_kl_in_reward") if use_kl_in_reward is not None: - cmd.append(f"algorithm.use_kl_in_reward={str(use_kl_in_reward).lower()}") - logger.info(f"Job {job.job_id}: ✓ KL use_kl_in_reward={use_kl_in_reward}") + cmd.append( + f"algorithm.use_kl_in_reward={str(use_kl_in_reward).lower()}" + ) + logger.info( + f"Job {job.job_id}: ✓ KL use_kl_in_reward={use_kl_in_reward}" + ) use_kl_loss = kl_config.get("use_kl_loss") if use_kl_loss is not None: - cmd.append(f"actor_rollout_ref.actor.use_kl_loss={str(use_kl_loss).lower()}") + cmd.append( + f"actor_rollout_ref.actor.use_kl_loss={str(use_kl_loss).lower()}" + ) logger.info(f"Job {job.job_id}: ✓ KL use_kl_loss={use_kl_loss}") kl_loss_coef = kl_config.get("kl_loss_coef") diff --git a/opentinker/scheduler/scheduler_users.db b/opentinker/scheduler/scheduler_users.db new file mode 100644 index 00000000..10adcc2f Binary files /dev/null and b/opentinker/scheduler/scheduler_users.db differ diff --git a/opentinker/scripts/launch_scheduler.sh b/opentinker/scripts/launch_scheduler.sh index 1e1d6028..0777a180 100644 --- a/opentinker/scripts/launch_scheduler.sh +++ b/opentinker/scripts/launch_scheduler.sh @@ -9,9 +9,12 @@ export ROLLOUT_TRACE_DIR="${ROLLOUT_TRACE_DIR:-./traces}" mkdir -p "$ROLLOUT_TRACE_DIR" # export NVCC_EXECUTABLE=$CUDA_HOME/bin/nvcc -export TORCH_CUDA_ARCH_LIST="9.0" +export TORCH_CUDA_ARCH_LIST="8.6" export FLASHINFER_HOMOGENEOUS_MS=1 +# Workaround for vLLM cumem allocator CUDA errors during wake_up/sleep cycles +export VLLM_DISABLE_SLEEP_MODE=1 + # Default configuration AVAILABLE_GPUS="[0,1,2,3]" PORT_RANGE="null" # Set to null for auto-detection