diff --git a/Justfile b/Justfile index f8f6009..f42808d 100644 --- a/Justfile +++ b/Justfile @@ -70,34 +70,69 @@ install_training: @command -v uv > /dev/null || (echo "uv not found. Please install from https://docs.astral.sh/uv/" && exit 1) cd training && uv sync -# Generate selfplay training data (traditional MCTS, high quality for bootstrapping) -selfplay GAMES="100" PLAYOUTS="20000": install_cargo - mkdir -p training/artifacts - cargo run --release -p selfplay -- {{GAMES}} {{PLAYOUTS}} > training/artifacts/training_data.jsonl - @echo "Generated training data with auxiliary targets (ownership, score_diff)" - -# Generate selfplay training data using neural network guidance (faster, more games) -selfplay_nn GAMES="200" PLAYOUTS="800": install_cargo - mkdir -p training/artifacts - cargo run --release -p selfplay -- {{GAMES}} {{PLAYOUTS}} --nn >> training/artifacts/training_data.jsonl - @echo "Generated NN-guided training data with auxiliary targets" +# Generate one durable policy-v2 selfplay run with the production-equivalent uniform-prior baseline +selfplay GAMES="100" PLAYOUTS="20000" MAX_SAMPLES="100000": install_cargo install_training + cd training && uv run generate_selfplay.py --games {{GAMES}} --playouts {{PLAYOUTS}} --max-samples {{MAX_SAMPLES}} + +# Generate one durable policy-v2 experimental selfplay run using neural network root priors +selfplay_nn GAMES="200" PLAYOUTS="800" MAX_SAMPLES="100000" PRIOR_WEIGHT="0.01": install_cargo install_training + cd training && uv run generate_selfplay.py --games {{GAMES}} --playouts {{PLAYOUTS}} --nn --prior-weight {{PRIOR_WEIGHT}} --max-samples {{MAX_SAMPLES}} + +# Rebuild the active training_data.jsonl replay view from uniform teacher runs only +replay MAX_SAMPLES="100000": install_training + cd training && uv run replay.py build --max-samples {{MAX_SAMPLES}} + +# Rebuild replay with explicitly selected teachers for experiments +replay_experimental TEACHERS="uniform,nn_root" MAX_SAMPLES="100000": install_training + 
cd training && uv run replay.py build --max-samples {{MAX_SAMPLES}} --include-teachers {{TEACHERS}} # Train the model on existing data train_only EPOCHS="20": install_training cd training && uv run train.py --epochs {{EPOCHS}} -# Run one iteration: selfplay + training (traditional MCTS for bootstrapping) +# Evaluate trained policy priors against the uniform-prior production baseline +evaluate_model PAIRS="200" PLAYOUTS="400" MIN_SCORE="0.53": install_cargo + cargo run --release --bin nn_vs_mcts -- {{PAIRS}} {{PLAYOUTS}} --summary-only --min-score {{MIN_SCORE}} + +# Evaluate trained policy priors without enforcing a promotion gate +evaluate_model_report PAIRS="100" PLAYOUTS="400": install_cargo + cargo run --release --bin nn_vs_mcts -- {{PAIRS}} {{PLAYOUTS}} + +# Sweep root prior blend weights without enforcing a promotion gate +evaluate_prior_sweep PAIRS="200" PLAYOUTS="400" WEIGHTS="0.00,0.02,0.05,0.08,0.10,0.15,0.25": install_cargo install_training + cd training && uv run prior_sweep.py --pairs {{PAIRS}} --playouts {{PLAYOUTS}} --weights {{WEIGHTS}} + +# Sanity-check evaluator by comparing uniform priors against uniform priors +evaluate_uniform PAIRS="20" PLAYOUTS="200": install_cargo + cargo run --release --bin nn_vs_mcts -- {{PAIRS}} {{PLAYOUTS}} --uniform-vs-uniform + +# Promote the current ONNX model to the browser-served artifact path after evaluation passes +promote_model PAIRS="200" PLAYOUTS="400" MIN_SCORE="0.53": install_cargo + cargo run --release --bin nn_vs_mcts -- {{PAIRS}} {{PLAYOUTS}} --summary-only --min-score {{MIN_SCORE}} + mkdir -p www/public/models + cp training/artifacts/model.onnx www/public/models/htmf-policy.onnx + @echo "Promoted training/artifacts/model.onnx to www/public/models/htmf-policy.onnx" + +# Run one iteration: selfplay + training + evaluation train GAMES="100" PLAYOUTS="20000" EPOCHS="20": install_cargo install_training @echo "Running selfplay with {{GAMES}} games, {{PLAYOUTS}} playouts..." 
just selfplay {{GAMES}} {{PLAYOUTS}} @echo "Training for {{EPOCHS}} epochs..." just train_only {{EPOCHS}} - @echo "Training iteration complete!" + @echo "Evaluating trained model for promotion..." + just promote_model + @echo "Training iteration complete and model passed the promotion gate." -# Run iterative training loop (AlphaZero-style with NN-guided selfplay) -iterate ITERATIONS="20" GAMES="200" PLAYOUTS="1000" EPOCHS="20": install_cargo install_training - cd training && uv run iterate.py --iterations {{ITERATIONS}} --games {{GAMES}} --playouts {{PLAYOUTS}} --epochs {{EPOCHS}} +# Run iterative training loop with promotion gate +iterate ITERATIONS="20" GAMES="200" PLAYOUTS="1000" EPOCHS="20" EVAL_PAIRS="100" EVAL_PLAYOUTS="400": install_cargo install_training + cd training && uv run iterate.py --iterations {{ITERATIONS}} --games {{GAMES}} --playouts {{PLAYOUTS}} --epochs {{EPOCHS}} --eval-pairs {{EVAL_PAIRS}} --eval-playouts {{EVAL_PLAYOUTS}} # Create blank models for debugging blank_models: install_training cd training && uv run create_blank_models.py + +# Tiny end-to-end ML smoke test +ml_smoke: install_cargo install_training + cd training && uv run generate_selfplay.py --games 1 --playouts 20 --max-samples 1000 + cd training && uv run train.py --epochs 1 --num-filters 8 --num-blocks 1 --batch-size 16 + cargo run --release --bin nn_vs_mcts -- 1 20 diff --git a/bots/src/bin/debug_modes.rs b/bots/src/bin/debug_modes.rs index fc18e04..1319b44 100644 --- a/bots/src/bin/debug_modes.rs +++ b/bots/src/bin/debug_modes.rs @@ -1,11 +1,11 @@ -//! Debug benchmark: Compare MCTS modes +//! Debug benchmark: Compare policy-prior sources //! //! This tests: -//! 1. Pure MCTS (UCB1) vs PUCT with uniform priors - PUCT should be similar or better +//! 1. Production baseline vs itself //! 2. PUCT with NN priors vs PUCT with uniform priors - shows if NN is helping //! -//! The PUCT mode uses random rollouts for evaluation (same as Pure MCTS), -//! 
but uses the PUCT selection formula which outperforms UCB1. +//! All modes use the same PUCT search and random rollout evaluation. The only +//! thing that changes is the policy prior source. use std::sync::Arc; @@ -135,12 +135,11 @@ fn main() { println!("Playouts per move: {}", PLAYOUTS); println!("Games per comparison: {}", NUM_GAMES); - // Test 1: Pure MCTS (UCB1) vs PUCT with uniform priors - // PUCT should be similar or better than UCB1 + // Test 1: Baseline implementation paths should be equivalent. run_comparison( - "Test 1: Pure MCTS (UCB1) vs PUCT (uniform priors)", - "Pure", - "PUCT", + "Test 1: Production baseline vs explicit uniform priors", + "Baseline", + "Uniform", |game, player| MCTSBot::new(game, player), |game, player| MCTSBot::with_neural_net(game, player, None), ); @@ -161,12 +160,12 @@ fn main() { |game, player| MCTSBot::with_neural_net(game, player, None), ); - // Test 3: PUCT with NN priors vs Pure MCTS + // Test 3: PUCT with NN priors vs production baseline let nn_clone2 = nn.clone(); run_comparison( - "Test 3: PUCT (NN priors) vs Pure MCTS (UCB1)", + "Test 3: PUCT (NN priors) vs production baseline", "NN", - "Pure", + "Baseline", move |game, player| MCTSBot::with_neural_net(game, player, Some(nn_clone2.clone())), |game, player| MCTSBot::new(game, player), ); @@ -179,7 +178,7 @@ fn main() { println!("\n\nInterpretation Guide:"); println!("====================="); - println!("- PUCT should be similar or better than Pure MCTS"); + println!("- Baseline vs Uniform should look roughly even"); println!("- If NN >> Uniform: Neural network policy priors are helping"); println!("- If Uniform >> NN: Neural network priors are hurting (bad training)"); } diff --git a/bots/src/bin/nn_vs_mcts.rs b/bots/src/bin/nn_vs_mcts.rs index e38d423..4a44b04 100644 --- a/bots/src/bin/nn_vs_mcts.rs +++ b/bots/src/bin/nn_vs_mcts.rs @@ -1,11 +1,11 @@ -//! Benchmark: Neural Network guided MCTS vs Traditional MCTS +//! 
Evaluate trained policy priors against the production uniform-prior baseline. //! -//! The NN bot uses PUCT selection with NN policy priors and random rollouts. -//! This provides a strong baseline that can be incrementally improved through training. -//! -//! Usage: nn_vs_mcts [num_games] [num_playouts] [--uniform] -//! --uniform: Use uniform priors instead of trained NN (for baseline comparison) +//! Usage: nn_vs_mcts [num_pairs] [num_playouts] [--uniform-vs-uniform] [--min-score SCORE] +//! num_pairs: each seed is played twice with colors swapped +//! --uniform-vs-uniform: sanity-check the evaluator against itself +//! --min-score: return a failing exit code if model score is below this value +use std::process::ExitCode; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -15,183 +15,258 @@ use htmf_bots::{MCTSBot, NeuralNet}; use rand::prelude::*; use rayon::prelude::*; -fn main() { +struct GameResult { + game_num: usize, + model_player: usize, + model_score: usize, + baseline_player: usize, + baseline_score: usize, + move_count: usize, + result: usize, // 0=model_win, 1=baseline_win, 2=draw + model_time: Duration, + baseline_time: Duration, +} + +fn play_game( + game_num: usize, + seed: u64, + model_player: usize, + num_playouts: usize, + nn: Option>, +) -> GameResult { + let baseline_player = 1 - model_player; + + let mut rng = StdRng::seed_from_u64(seed); + let mut game = GameState::new_two_player(&mut rng); + + let mut model_bot = MCTSBot::with_neural_net(game.clone(), Player { id: model_player }, nn); + let mut baseline_bot = MCTSBot::new( + game.clone(), + Player { + id: baseline_player, + }, + ); + + let mut move_count = 0; + let mut model_time = Duration::ZERO; + let mut baseline_time = Duration::ZERO; + + while let Some(p) = game.active_player() { + let action = if p.id == model_player { + let start = Instant::now(); + for _ in 0..num_playouts { + model_bot.playout(); + } + let action = model_bot.take_action(); + model_time += start.elapsed(); + 
action + } else { + let start = Instant::now(); + for _ in 0..num_playouts { + baseline_bot.playout(); + } + let action = baseline_bot.take_action(); + baseline_time += start.elapsed(); + action + }; + + game.apply_action(&action).unwrap(); + model_bot.update(&game); + baseline_bot.update(&game); + move_count += 1; + } + + let scores = game.get_scores(); + let model_score = scores[model_player]; + let baseline_score = scores[baseline_player]; + let result = if model_score > baseline_score { + 0 + } else if baseline_score > model_score { + 1 + } else { + 2 + }; + + GameResult { + game_num, + model_player, + model_score, + baseline_player, + baseline_score, + move_count, + result, + model_time, + baseline_time, + } +} + +fn main() -> ExitCode { let args: Vec = std::env::args().collect(); + let mut uniform_vs_uniform = false; + let mut min_score: Option = None; + let mut numeric_args = Vec::new(); - // Check for --uniform flag - let use_uniform = args.iter().any(|a| a == "--uniform"); - let numeric_args: Vec<&String> = args - .iter() - .skip(1) - .filter(|a| !a.starts_with("--")) - .collect(); + let mut i = 1; + while i < args.len() { + match args[i].as_str() { + "--uniform-vs-uniform" => { + uniform_vs_uniform = true; + i += 1; + } + "--min-score" => { + let Some(value) = args.get(i + 1) else { + eprintln!("--min-score requires a numeric value"); + return ExitCode::from(2); + }; + match value.parse::() { + Ok(value) => min_score = Some(value), + Err(_) => { + eprintln!("Invalid --min-score value: {value}"); + return ExitCode::from(2); + } + } + i += 2; + } + value if value.starts_with("--") => { + eprintln!("Unknown argument: {value}"); + return ExitCode::from(2); + } + value => { + numeric_args.push(value.to_owned()); + i += 1; + } + } + } - let num_games: usize = numeric_args + let num_pairs: usize = numeric_args .first() - .and_then(|s| s.parse().ok()) - .unwrap_or(40); + .and_then(|s| s.parse::().ok()) + .unwrap_or(100); let num_playouts: usize = numeric_args 
.get(1) - .and_then(|s| s.parse().ok()) + .and_then(|s| s.parse::().ok()) .unwrap_or(400); - // Load neural network unless uniform mode - let nn: Option> = if use_uniform { - eprintln!("Using uniform priors (PUCT without trained NN)"); + let nn: Option> = if uniform_vs_uniform { + eprintln!("Running uniform-prior baseline against itself"); None } else { - eprintln!("Loading neural network..."); + eprintln!("Loading neural network from training/artifacts/model.onnx..."); match NeuralNet::load("training/artifacts/model.onnx") { - Ok(model) => { - eprintln!("Neural network loaded successfully"); - Some(Arc::new(model)) - } + Ok(model) => Some(Arc::new(model)), Err(e) => { - eprintln!("Failed to load neural network: {:?}", e); - eprintln!("Falling back to uniform priors"); - None + eprintln!("Failed to load neural network: {e:?}"); + return ExitCode::from(1); } } }; - let mode_name = if nn.is_some() { - "NN-guided PUCT" + let mode_name = if uniform_vs_uniform { + "Uniform priors" } else { - "PUCT (uniform priors)" + "Trained priors" }; - println!("{} vs Traditional MCTS", mode_name); + println!("{} vs Uniform-prior baseline", mode_name); println!("========================================"); - println!("Games: {}", num_games); + println!("Pairs: {}", num_pairs); + println!("Games: {}", num_pairs * 2); println!("Playouts per move: {}", num_playouts); println!(); - // Result type including timing - struct GameResult { - game_num: usize, - nn_player: usize, - nn_score: usize, - mcts_player: usize, - mcts_score: usize, - move_count: usize, - result: usize, // 0=nn_win, 1=mcts_win, 2=draw - nn_time: Duration, - mcts_time: Duration, - } - - // Run games in parallel - let results: Vec = (0..num_games) + let results: Vec = (0..num_pairs) .into_par_iter() - .map(|game_num| { - // Alternate who goes first - let nn_player = game_num % 2; - let mcts_player = 1 - nn_player; - - let mut rng = StdRng::seed_from_u64(game_num as u64); - let mut game = GameState::new_two_player(&mut 
rng); - - let mut nn_bot = MCTSBot::with_neural_net( - game.clone(), - Player { id: nn_player }, - nn.clone(), - ); - let mut mcts_bot = MCTSBot::new(game.clone(), Player { id: mcts_player }); - - let mut move_count = 0; - let mut nn_time = Duration::ZERO; - let mut mcts_time = Duration::ZERO; - - while let Some(p) = game.active_player() { - let action = if p.id == nn_player { - let start = Instant::now(); - for _ in 0..num_playouts { - nn_bot.playout(); - } - let result = nn_bot.take_action(); - nn_time += start.elapsed(); - result - } else { - let start = Instant::now(); - for _ in 0..num_playouts { - mcts_bot.playout(); - } - let result = mcts_bot.take_action(); - mcts_time += start.elapsed(); - result - }; - - game.apply_action(&action).unwrap(); - nn_bot.update(&game); - mcts_bot.update(&game); - move_count += 1; - } - - let scores = game.get_scores(); - let nn_score = scores[nn_player]; - let mcts_score = scores[mcts_player]; - - let result = if nn_score > mcts_score { - 0 // NN wins - } else if mcts_score > nn_score { - 1 // MCTS wins - } else { - 2 // Draw - }; - - GameResult { - game_num, - nn_player, - nn_score, - mcts_player, - mcts_score, - move_count, - result, - nn_time, - mcts_time, - } + .flat_map_iter(|pair| { + let seed = pair as u64; + [ + play_game(pair * 2, seed, 0, num_playouts, nn.clone()), + play_game(pair * 2 + 1, seed, 1, num_playouts, nn.clone()), + ] }) .collect(); - // Print results in order and tally - let mut nn_wins = 0; - let mut mcts_wins = 0; - let mut draws = 0; - let mut total_nn_time = Duration::ZERO; - let mut total_mcts_time = Duration::ZERO; - let mut sorted_results = results; sorted_results.sort_by_key(|r| r.game_num); + let mut model_wins = 0; + let mut baseline_wins = 0; + let mut draws = 0; + let mut total_model_time = Duration::ZERO; + let mut total_baseline_time = Duration::ZERO; + for r in &sorted_results { let result_str = match r.result { - 0 => { nn_wins += 1; "NN wins" } - 1 => { mcts_wins += 1; "MCTS wins" } - _ 
=> { draws += 1; "Draw" } + 0 => { + model_wins += 1; + "Model wins" + } + 1 => { + baseline_wins += 1; + "Baseline wins" + } + _ => { + draws += 1; + "Draw" + } }; - total_nn_time += r.nn_time; - total_mcts_time += r.mcts_time; + total_model_time += r.model_time; + total_baseline_time += r.baseline_time; println!( - "Game {:2}: NN(P{})={:2} vs MCTS(P{})={:2} in {:2} moves - {}", + "Game {:3}: Model(P{})={:2} vs Baseline(P{})={:2} in {:2} moves - {}", r.game_num + 1, - r.nn_player, - r.nn_score, - r.mcts_player, - r.mcts_score, + r.model_player, + r.model_score, + r.baseline_player, + r.baseline_score, r.move_count, result_str ); } + let total_games = sorted_results.len(); + let decisive_games = model_wins + baseline_wins; + let score = (model_wins as f64 + 0.5 * draws as f64) / total_games as f64; + println!(); println!("Results:"); - println!(" NN wins: {} ({:.1}%)", nn_wins, 100.0 * nn_wins as f64 / num_games as f64); - println!(" MCTS wins: {} ({:.1}%)", mcts_wins, 100.0 * mcts_wins as f64 / num_games as f64); - println!(" Draws: {} ({:.1}%)", draws, 100.0 * draws as f64 / num_games as f64); + println!( + " Model wins: {} ({:.1}%)", + model_wins, + 100.0 * model_wins as f64 / total_games as f64 + ); + println!( + " Baseline wins: {} ({:.1}%)", + baseline_wins, + 100.0 * baseline_wins as f64 / total_games as f64 + ); + println!( + " Draws: {} ({:.1}%)", + draws, + 100.0 * draws as f64 / total_games as f64 + ); + println!(" Score: {:.3}", score); + println!(" Decisive: {}", decisive_games); + if let Some(min_score) = min_score { + let gate = if score >= min_score { "PASS" } else { "FAIL" }; + println!(" Gate: {} (min score {:.3})", gate, min_score); + } println!(); println!("Thinking time:"); - println!(" NN total: {:.2}s", total_nn_time.as_secs_f64()); - println!(" MCTS total: {:.2}s", total_mcts_time.as_secs_f64()); - println!(" NN/MCTS ratio: {:.2}x", total_nn_time.as_secs_f64() / total_mcts_time.as_secs_f64()); + println!(" Model total: {:.2}s", 
total_model_time.as_secs_f64()); + println!( + " Baseline total: {:.2}s", + total_baseline_time.as_secs_f64() + ); + println!( + " Model/Baseline ratio: {:.2}x", + total_model_time.as_secs_f64() / total_baseline_time.as_secs_f64() + ); + + if min_score.is_some_and(|min_score| score < min_score) { + ExitCode::from(1) + } else { + ExitCode::SUCCESS + } } diff --git a/bots/src/lib.rs b/bots/src/lib.rs index 1f13a24..a9c1620 100644 --- a/bots/src/lib.rs +++ b/bots/src/lib.rs @@ -1,9 +1,13 @@ pub mod mctsbot; pub mod minimaxbot; pub mod neuralnet; +pub mod policy; pub mod randombot; -pub use mctsbot::{MCTSBot, MCTSMode}; +pub use mctsbot::{MCTSBot, MCTSMode, PriorProvider, UniformPriorProvider}; pub use minimaxbot::MinimaxBot; pub use neuralnet::NeuralNet; +pub use policy::{ + move_to_direction_distance, move_to_policy_index, MOVEMENT_POLICY_SIZE, POLICY_VERSION, +}; pub use randombot::RandomBot; diff --git a/bots/src/mctsbot.rs b/bots/src/mctsbot.rs index e60b590..9437372 100644 --- a/bots/src/mctsbot.rs +++ b/bots/src/mctsbot.rs @@ -1,22 +1,83 @@ use crate::neuralnet::NeuralNet; -use htmf::board::Board; -use htmf::hex::Cube; +use crate::policy::{move_to_policy_index, policy_size}; use rand::prelude::*; use std::cell::OnceCell; +use std::fmt; use std::sync::Arc; const NUM_PLAYERS: usize = 2; -const NUM_DIRECTIONS: usize = 6; -const MAX_DISTANCE: usize = 7; - /// Exploration constant for PUCT formula (AlphaZero uses ~1.0-2.0) const C_PUCT: f32 = 1.0; +/// Keep the production uniform prior as an anchor while the learned policy is weak. 
+const DEFAULT_MODEL_PRIOR_WEIGHT: f32 = 0.05; + +fn model_prior_weight() -> f32 { + std::env::var("HTMF_MODEL_PRIOR_WEIGHT") + .ok() + .and_then(|value| value.parse::().ok()) + .filter(|value| (0.0..=1.0).contains(value)) + .unwrap_or(DEFAULT_MODEL_PRIOR_WEIGHT) +} + +#[derive(Debug)] +pub struct PriorError { + message: String, +} + +impl PriorError { + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} + +impl fmt::Display for PriorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl std::error::Error for PriorError {} + +pub trait PriorProvider: Send + Sync { + fn policy_logits( + &self, + game: &htmf::game::GameState, + current_player: usize, + ) -> Result, PriorError>; +} + +#[derive(Debug, Default)] +pub struct UniformPriorProvider; + +impl PriorProvider for UniformPriorProvider { + fn policy_logits( + &self, + game: &htmf::game::GameState, + _current_player: usize, + ) -> Result, PriorError> { + Ok(vec![0.0; policy_size(!game.finished_drafting())]) + } +} + +impl PriorProvider for NeuralNet { + fn policy_logits( + &self, + game: &htmf::game::GameState, + current_player: usize, + ) -> Result, PriorError> { + self.predict(game, current_player) + .map(|output| output.policy_logits) + .map_err(|err| PriorError::new(format!("ONNX prior inference failed: {err}"))) + } +} /** * Games are connected to each other via Moves. 
*/ -#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum Move { Place(u8), Move((u8, u8)), @@ -116,7 +177,7 @@ impl TreeNode { &mut self, game: &Game, policy_logits: &[f32], - current_player: usize, + _current_player: usize, node_count: &mut usize, ) { if self.children.get().is_some() { @@ -126,31 +187,29 @@ impl TreeNode { let is_drafting = !game.state.finished_drafting(); let moves: Vec = game.available_moves().collect(); - // Get sorted penguin list for consistent indexing (only needed for movement) - let mut penguins: Vec = game.state.board.penguins[current_player] - .into_iter() - .collect(); - penguins.sort(); - // Convert logits to probabilities with softmax over legal moves only let mut max_logit = f32::NEG_INFINITY; for m in &moves { - let idx = move_to_policy_index(m, is_drafting, &penguins); - max_logit = max_logit.max(policy_logits[idx]); + let idx = move_to_policy_index(m, is_drafting); + max_logit = max_logit.max(policy_logits.get(idx).copied().unwrap_or(0.0)); } let mut sum_exp = 0.0f32; let mut priors: Vec = Vec::with_capacity(moves.len()); for m in &moves { - let idx = move_to_policy_index(m, is_drafting, &penguins); - let exp_val = (policy_logits[idx] - max_logit).exp(); + let idx = move_to_policy_index(m, is_drafting); + let exp_val = (policy_logits.get(idx).copied().unwrap_or(0.0) - max_logit).exp(); priors.push(exp_val); sum_exp += exp_val; } - // Normalize + // Normalize and blend with uniform so weak models cannot completely + // dominate the production baseline search. 
+ let model_prior_weight = model_prior_weight(); + let uniform_prior = 1.0 / priors.len() as f32; for p in &mut priors { - *p /= sum_exp; + *p = model_prior_weight * (*p / sum_exp) + + (1.0 - model_prior_weight) * uniform_prior; } let children: Vec<_> = moves @@ -182,89 +241,6 @@ impl TreeNode { } } -/// Convert a move from (src, dst) to (direction, distance) -/// Direction is 0-5 based on Cube::neighbors() order -/// Distance is 1-7 (number of cells traveled) -fn move_to_direction_distance(src: u8, dst: u8) -> Option<(usize, usize)> { - let src_hex = Board::index_to_evenr(src); - let dst_hex = Board::index_to_evenr(dst); - let src_cube = Cube::from_evenr(&src_hex); - let dst_cube = Cube::from_evenr(&dst_hex); - - // Calculate the delta in cube coordinates - let dx = dst_cube.x - src_cube.x; - let dy = dst_cube.y - src_cube.y; - let dz = dst_cube.z - src_cube.z; - - // Determine direction based on which axis is constant (the other two change) - // Direction 0: (+x, -y, 0z) East - // Direction 1: (+x, 0y, -z) Northeast - // Direction 2: (0x, +y, -z) Northwest - // Direction 3: (-x, +y, 0z) West - // Direction 4: (-x, 0y, +z) Southwest - // Direction 5: (0x, -y, +z) Southeast - - let direction = if dz == 0 { - // z constant: East (0) or West (3) - if dx > 0 { - 0 - } else { - 3 - } - } else if dy == 0 { - // y constant: Northeast (1) or Southwest (4) - if dx > 0 { - 1 - } else { - 4 - } - } else if dx == 0 { - // x constant: Northwest (2) or Southeast (5) - if dy > 0 { - 2 - } else { - 5 - } - } else { - // Not a valid hex line move - return None; - }; - - // Distance is the absolute delta on any non-zero axis - let distance = dx.abs().max(dy.abs()).max(dz.abs()) as usize; - - if distance == 0 || distance > MAX_DISTANCE { - return None; - } - - Some((direction, distance)) -} - -/// Convert a move to its index in the policy output -/// For movement phase, this uses the compressed format: penguin_idx * 42 + direction * 7 + (distance - 1) -fn move_to_policy_index(m: 
&Move, is_drafting: bool, penguins: &[u8]) -> usize { - match m { - Move::Place(dst) => { - debug_assert!(is_drafting); - *dst as usize - } - Move::Move((src, dst)) => { - debug_assert!(!is_drafting); - // Find penguin index - let penguin_idx = penguins.iter().position(|&p| p == *src).unwrap_or(0); - // Get direction and distance - if let Some((direction, distance)) = move_to_direction_distance(*src, *dst) { - penguin_idx * (NUM_DIRECTIONS * MAX_DISTANCE) - + direction * MAX_DISTANCE - + (distance - 1) - } else { - // Fallback - should not happen with valid moves - 0 - } - } - } -} - #[derive(Default, Debug)] pub struct RewardsVisits { rewards: f32, @@ -399,7 +375,7 @@ fn get_reward(game: &htmf::game::GameState, p: usize) -> f32 { fn playout_puct( root: &mut TreeNode, root_game: &Game, - nn: &Option>, + prior_provider: &Arc, node_count: &mut usize, ) -> (Vec, Game) { let rng = &mut rand::rng(); @@ -418,18 +394,11 @@ fn playout_puct( // At a leaf - expand with priors (from NN if available, otherwise uniform) if !game.state.game_over() { let current_player = game.current_player().id; - if let Some(nn) = nn { - let output = nn - .predict(&game.state, current_player) - .expect("NN prediction failed"); - expand_node.expand_with_priors( - &game, - &output.policy_logits, - current_player, - node_count, - ); - } else { - expand_node.expand_with_uniform_priors(&game, node_count); + match prior_provider.policy_logits(&game.state, current_player) { + Ok(policy_logits) => { + expand_node.expand_with_priors(&game, &policy_logits, current_player, node_count); + } + Err(_) => expand_node.expand_with_uniform_priors(&game, node_count), } // Select one child to expand into (using PUCT) @@ -477,12 +446,10 @@ pub struct UpdateStats { /// Mode of operation for MCTS #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum MCTSMode { - /// Pure MCTS with UCB1 selection and random rollouts + /// Production baseline: PUCT with uniform priors and random rollouts. 
Pure, /// PUCT selection with random rollouts. - /// Uses neural network policy priors if available, otherwise uniform priors. - /// This mode outperforms Pure MCTS and provides a foundation for incremental - /// improvement through training - each generation can improve on the last. + /// Uses neural network policy priors when a model is provided. NeuralNet, } @@ -491,32 +458,13 @@ pub struct MCTSBot { pub root_game: Game, pub me: htmf::board::Player, pub num_nodes: usize, - /// Optional neural network for guided search - nn: Option>, - /// Mode of operation - mode: MCTSMode, + /// Source of policy priors used by PUCT. + prior_provider: Arc, } impl MCTSBot { pub fn new(game: htmf::game::GameState, me: htmf::board::Player) -> Self { - let root_game = Game { - state: game.clone(), - }; - let mut bot = MCTSBot { - root: TreeNode::new(), - root_game, - me, - num_nodes: 1, - nn: None, - mode: MCTSMode::Pure, - }; - - if game.active_player().is_some() { - bot.root - .expand_with_uniform_priors(&bot.root_game.clone(), &mut bot.num_nodes); - } - - bot + Self::with_prior_provider(game, me, Arc::new(UniformPriorProvider), MCTSMode::Pure) } /// Create an MCTS bot that uses PUCT selection with random rollouts. @@ -525,12 +473,24 @@ impl MCTSBot { /// (guiding which moves to explore first). Otherwise, uniform priors are used. /// /// This mode uses random rollouts for leaf evaluation (not NN value prediction), - /// which provides a strong baseline that outperforms pure UCB1-based MCTS. - /// Training can then incrementally improve the policy priors. + /// matching the uniform-prior production baseline. Training can then + /// incrementally improve the policy priors. 
pub fn with_neural_net( game: htmf::game::GameState, me: htmf::board::Player, nn: Option>, + ) -> Self { + let prior_provider: Arc = nn + .map(|nn| nn as Arc) + .unwrap_or_else(|| Arc::new(UniformPriorProvider)); + Self::with_prior_provider(game, me, prior_provider, MCTSMode::NeuralNet) + } + + pub fn with_prior_provider( + game: htmf::game::GameState, + me: htmf::board::Player, + prior_provider: Arc, + _mode: MCTSMode, ) -> Self { let mut bot = MCTSBot { root: TreeNode::new(), @@ -539,11 +499,10 @@ impl MCTSBot { }, me, num_nodes: 1, - nn, - mode: MCTSMode::NeuralNet, + prior_provider, }; - // Initialize root with priors (from NN if available, otherwise uniform) + // Initialize root with priors. The uniform provider is the production baseline. if let Some(p) = game.active_player() { bot.expand_node_with_priors(&bot.root_game.clone(), p.id); } @@ -553,19 +512,22 @@ impl MCTSBot { /// Expand a node with priors (from NN if available, otherwise uniform) fn expand_node_with_priors(&mut self, game: &Game, current_player: usize) { - if let Some(nn) = &self.nn { - let output = nn - .predict(&game.state, current_player) - .expect("NN prediction failed"); - self.root.expand_with_priors( - game, - &output.policy_logits, - current_player, - &mut self.num_nodes, - ); - } else { - self.root - .expand_with_uniform_priors(game, &mut self.num_nodes); + match self + .prior_provider + .policy_logits(&game.state, current_player) + { + Ok(policy_logits) => { + self.root.expand_with_priors( + game, + &policy_logits, + current_player, + &mut self.num_nodes, + ); + } + Err(_) => { + self.root + .expand_with_uniform_priors(game, &mut self.num_nodes); + } } } @@ -593,20 +555,8 @@ impl MCTSBot { // New state not found in tree - start fresh self.root = dummy; - // Initialize based on mode - match self.mode { - MCTSMode::NeuralNet => { - if let Some(p) = game_state.active_player() { - self.expand_node_with_priors(&new_game, p.id); - } - } - MCTSMode::Pure => { - // Pure mode also uses PUCT, 
so pre-expand with uniform priors - if game_state.active_player().is_some() { - self.root - .expand_with_uniform_priors(&new_game, &mut self.num_nodes); - } - } + if let Some(p) = game_state.active_player() { + self.expand_node_with_priors(&new_game, p.id); } } @@ -618,7 +568,7 @@ impl MCTSBot { let (path, game) = playout_puct( &mut self.root, &self.root_game, - &self.nn, + &self.prior_provider, &mut self.num_nodes, ); backprop(&mut self.root, &self.root_game, path, game); @@ -703,6 +653,42 @@ impl MCTSBot { } } + pub fn update_root_priors_from_logits(&mut self, policy_logits: &[f32]) { + if self.root_game.state.active_player().is_none() { + return; + } + let Some(children) = self.root.children.get_mut() else { + return; + }; + + let is_drafting = !self.root_game.state.finished_drafting(); + let mut max_logit = f32::NEG_INFINITY; + for (m, _) in children.iter() { + let idx = move_to_policy_index(m, is_drafting); + max_logit = max_logit.max(policy_logits.get(idx).copied().unwrap_or(0.0)); + } + + let mut sum_exp = 0.0f32; + let mut priors = Vec::with_capacity(children.len()); + for (m, _) in children.iter() { + let idx = move_to_policy_index(m, is_drafting); + let exp_val = (policy_logits.get(idx).copied().unwrap_or(0.0) - max_logit).exp(); + priors.push(exp_val); + sum_exp += exp_val; + } + + if sum_exp <= 0.0 || !sum_exp.is_finite() { + return; + } + + let model_prior_weight = model_prior_weight(); + let uniform_prior = 1.0 / priors.len() as f32; + for ((_, child), prior) in children.iter_mut().zip(priors) { + child.prior = model_prior_weight * (prior / sum_exp) + + (1.0 - model_prior_weight) * uniform_prior; + } + } + pub fn tree_size(&self) -> usize { self.num_nodes } @@ -798,6 +784,31 @@ fn test_tree_size_optimization() { assert_eq!(bot.tree_size(), bot.calculate_tree_size()); } +#[test] +fn test_uniform_prior_paths_match_at_root() { + use htmf::board::Player; + use htmf::game::GameState; + + let game = GameState::new_two_player::(&mut 
SeedableRng::seed_from_u64(7)); + let baseline = MCTSBot::new(game.clone(), Player { id: 0 }); + let explicit_uniform = MCTSBot::with_neural_net(game, Player { id: 0 }, None); + + let baseline_children = baseline.root.children.get().unwrap(); + let explicit_children = explicit_uniform.root.children.get().unwrap(); + + assert_eq!(baseline_children.len(), explicit_children.len()); + for ((baseline_move, baseline_child), (uniform_move, uniform_child)) in + baseline_children.iter().zip(explicit_children) + { + assert_eq!(baseline_move, uniform_move); + assert_eq!(baseline_child.prior, uniform_child.prior); + assert_eq!( + baseline_child.rewards_visits.get(), + uniform_child.rewards_visits.get() + ); + } +} + #[test] fn test_memory_usage() { use htmf::board::Player; @@ -829,8 +840,7 @@ fn test_neural_network_guided_game() { // Load neural network let nn = Arc::new( - NeuralNet::load("../training/artifacts/model.onnx") - .expect("Failed to load neural network"), + NeuralNet::load("../training/artifacts/model.onnx").expect("Failed to load neural network"), ); let mut game = GameState::new_two_player::(&mut SeedableRng::seed_from_u64(42)); diff --git a/bots/src/neuralnet.rs b/bots/src/neuralnet.rs index 170d4ae..3e65d45 100644 --- a/bots/src/neuralnet.rs +++ b/bots/src/neuralnet.rs @@ -1,3 +1,4 @@ +use crate::policy::policy_size; use htmf::NUM_CELLS; use tract_onnx::prelude::*; @@ -12,7 +13,7 @@ pub struct NeuralNet { /// Output from neural network inference pub struct NeuralNetOutput { - /// Policy logits (60 for drafting, 3600 for movement) + /// Policy logits (60 for drafting, 2520 for movement) pub policy_logits: Vec, /// Value estimate (win probability for current player) pub value: f32, @@ -23,7 +24,7 @@ impl NeuralNet { /// /// The model should have: /// - Input: features (1, 480) - /// - Outputs: drafting_policy (1, 60), movement_policy (1, 168), value (1, 1) + /// - Outputs: drafting_policy (1, 60), movement_policy (1, 2520), value (1, 1) pub fn load(model_path: 
&str) -> TractResult { let model = tract_onnx::onnx() .model_for_path(model_path)? @@ -52,8 +53,8 @@ impl NeuralNet { let features = extract_features(game, current_player); let is_drafting = !game.finished_drafting(); - let input: Tensor = tract_ndarray::Array2::from_shape_vec((1, NUM_FEATURES), features)? - .into(); + let input: Tensor = + tract_ndarray::Array2::from_shape_vec((1, NUM_FEATURES), features)?.into(); let outputs = self.model.run(tvec!(input.into()))?; @@ -65,6 +66,19 @@ impl NeuralNet { .iter() .copied() .collect(); + let expected_policy_size = policy_size(is_drafting); + if policy_logits.len() != expected_policy_size { + return Err(TractError::msg(format!( + "model output {} has {} logits, expected {} for policy encoding v2", + if is_drafting { + "drafting_policy" + } else { + "movement_policy" + }, + policy_logits.len(), + expected_policy_size + ))); + } // Convert tanh output [-1, 1] to probability [0, 1] let raw_value: f32 = outputs[2].to_array_view::()?[[0, 0]]; diff --git a/bots/src/policy.rs b/bots/src/policy.rs new file mode 100644 index 0000000..4c4322f --- /dev/null +++ b/bots/src/policy.rs @@ -0,0 +1,119 @@ +use htmf::NUM_CELLS; + +use crate::mctsbot::Move; + +pub const NUM_DIRECTIONS: usize = 6; +pub const MAX_DISTANCE: usize = 7; +pub const POLICY_VERSION: u8 = 2; +pub const MOVEMENT_POLICY_SIZE: usize = NUM_CELLS * NUM_DIRECTIONS * MAX_DISTANCE; + +/// Convert a move from (src, dst) to (direction, distance). +/// +/// Direction is 0-5 using the same cube-axis convention everywhere policy +/// targets and priors are encoded. 
+pub fn move_to_direction_distance(src: u8, dst: u8) -> Option<(usize, usize)> { + let src_hex = htmf::board::Board::index_to_evenr(src); + let dst_hex = htmf::board::Board::index_to_evenr(dst); + let src_cube = htmf::hex::Cube::from_evenr(&src_hex); + let dst_cube = htmf::hex::Cube::from_evenr(&dst_hex); + + let dx = dst_cube.x - src_cube.x; + let dy = dst_cube.y - src_cube.y; + let dz = dst_cube.z - src_cube.z; + + let direction = if dz == 0 { + if dx > 0 { + 0 + } else { + 3 + } + } else if dy == 0 { + if dx > 0 { + 1 + } else { + 4 + } + } else if dx == 0 { + if dy > 0 { + 2 + } else { + 5 + } + } else { + return None; + }; + + let distance = dx.abs().max(dy.abs()).max(dz.abs()) as usize; + if distance == 0 || distance > MAX_DISTANCE { + return None; + } + + Some((direction, distance)) +} + +/// Convert a game move to the model policy output index. +/// +/// Drafting uses one logit per board cell. Movement uses the absolute +/// `src_cell * 42 + direction * 7 + (distance - 1)` encoding. 
+pub fn move_to_policy_index(m: &Move, is_drafting: bool) -> usize { + match m { + Move::Place(dst) => { + debug_assert!(is_drafting); + *dst as usize + } + Move::Move((src, dst)) => { + debug_assert!(!is_drafting); + if let Some((direction, distance)) = move_to_direction_distance(*src, *dst) { + *src as usize * (NUM_DIRECTIONS * MAX_DISTANCE) + + direction * MAX_DISTANCE + + (distance - 1) + } else { + 0 + } + } + } +} + +pub fn policy_size(is_drafting: bool) -> usize { + if is_drafting { + NUM_CELLS + } else { + MOVEMENT_POLICY_SIZE + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn placement_policy_index_is_cell_index() { + assert_eq!(move_to_policy_index(&Move::Place(17), true), 17); + } + + #[test] + fn movement_policy_index_uses_source_cell() { + let (direction, distance) = move_to_direction_distance(11, 12).unwrap(); + let expected = 11 * (NUM_DIRECTIONS * MAX_DISTANCE) + + direction * MAX_DISTANCE + + (distance - 1); + + assert_eq!( + move_to_policy_index(&Move::Move((11, 12)), false), + expected + ); + } + + #[test] + fn movement_policy_index_stays_in_range_for_all_board_moves() { + let mut rng = rand::rng(); + let board = htmf::board::Board::new(&mut rng); + + for src in 0..NUM_CELLS as u8 { + for dst in board.moves(src) { + let idx = move_to_policy_index(&Move::Move((src, dst)), false); + assert!(idx < MOVEMENT_POLICY_SIZE); + } + } + } +} diff --git a/bun.lock b/bun.lock index 17e81bd..34c212c 100644 --- a/bun.lock +++ b/bun.lock @@ -9,13 +9,14 @@ }, "wasm/pkg": { "name": "htmf-wasm", - "version": "0.0.0", + "version": "0.1.0", }, "www": { "name": "htmf", "version": "0.5.0", "dependencies": { "htmf-wasm": "*", + "onnxruntime-web": "^1.22.0", "react": "19.2.4", "react-dom": "19.2.4", }, @@ -39,7 +40,7 @@ "prettier": "3.8.1", "typescript": "6.0.2", "typescript-eslint": "8.58.0", - "vite": "8.0.3", + "vite": "8.0.5", "vitest": "4.1.2", "vitest-browser-react": "2.2.0", }, @@ -134,6 +135,26 @@ "@polka/url": ["@polka/url@1.0.0-next.29", "", 
{}, "sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww=="], + "@protobufjs/aspromise": ["@protobufjs/aspromise@1.1.2", "", {}, "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ=="], + + "@protobufjs/base64": ["@protobufjs/base64@1.1.2", "", {}, "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg=="], + + "@protobufjs/codegen": ["@protobufjs/codegen@2.0.5", "", {}, "sha512-zgXFLzW3Ap33e6d0Wlj4MGIm6Ce8O89n/apUaGNB/jx+hw+ruWEp7EwGUshdLKVRCxZW12fp9r40E1mQrf/34g=="], + + "@protobufjs/eventemitter": ["@protobufjs/eventemitter@1.1.0", "", {}, "sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q=="], + + "@protobufjs/fetch": ["@protobufjs/fetch@1.1.0", "", { "dependencies": { "@protobufjs/aspromise": "^1.1.1", "@protobufjs/inquire": "^1.1.0" } }, "sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ=="], + + "@protobufjs/float": ["@protobufjs/float@1.0.2", "", {}, "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ=="], + + "@protobufjs/inquire": ["@protobufjs/inquire@1.1.1", "", {}, "sha512-mnzgDV26ueAvk7rsbt9L7bE0SuAoqyuys/sMMrmVcN5x9VsxpcG3rqAUSgDyLp0UZlmNfIbQ4fHfCtreVBk8Ew=="], + + "@protobufjs/path": ["@protobufjs/path@1.1.2", "", {}, "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA=="], + + "@protobufjs/pool": ["@protobufjs/pool@1.1.0", "", {}, "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw=="], + + "@protobufjs/utf8": ["@protobufjs/utf8@1.1.1", "", {}, "sha512-oOAWABowe8EAbMyWKM0tYDKi8Yaox52D+HWZhAIJqQXbqe0xI/GV7FhLWqlEKreMkfDjshR5FKgi3mnle0h6Eg=="], + "@rolldown/binding-android-arm64": ["@rolldown/binding-android-arm64@1.0.0-rc.12", "", { "os": "android", "cpu": "arm64" }, 
"sha512-pv1y2Fv0JybcykuiiD3qBOBdz6RteYojRFY1d+b95WVuzx211CRh+ytI/+9iVyWQ6koTh5dawe4S/yRfOFjgaA=="], "@rolldown/binding-darwin-arm64": ["@rolldown/binding-darwin-arm64@1.0.0-rc.12", "", { "os": "darwin", "cpu": "arm64" }, "sha512-cFYr6zTG/3PXXF3pUO+umXxt1wkRK/0AYT8lDwuqvRC+LuKYWSAQAQZjCWDQpAH172ZV6ieYrNnFzVVcnSflAg=="], @@ -350,6 +371,8 @@ "flat-cache": ["flat-cache@4.0.1", "", { "dependencies": { "flatted": "^3.2.9", "keyv": "^4.5.4" } }, "sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw=="], + "flatbuffers": ["flatbuffers@25.9.23", "", {}, "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ=="], + "flatted": ["flatted@3.3.3", "", {}, "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg=="], "fs-extra": ["fs-extra@11.3.2", "", { "dependencies": { "graceful-fs": "^4.2.0", "jsonfile": "^6.0.1", "universalify": "^2.0.0" } }, "sha512-Xr9F6z6up6Ws+NjzMCZc6WXg2YFRlrLP9NQDO3VQrWrfiojdhS56TzueT88ze0uBdCTwEIhQ3ptnmKeWGFAe0A=="], @@ -368,6 +391,8 @@ "graceful-fs": ["graceful-fs@4.2.11", "", {}, "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ=="], + "guid-typescript": ["guid-typescript@1.0.9", "", {}, "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ=="], + "has-flag": ["has-flag@4.0.0", "", {}, "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ=="], "hermes-estree": ["hermes-estree@0.25.1", "", {}, "sha512-0wUoCcLp+5Ev5pDW2OriHC2MJCbwLwuRx+gAqMTOkGKJJiBCLjtrvy4PWUGn6MIVefecRpzoOZ/UV6iGdOr+Cw=="], @@ -440,6 +465,8 @@ "lodash.merge": ["lodash.merge@4.6.2", "", {}, "sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ=="], + "long": ["long@5.3.2", "", {}, "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA=="], + "lru-cache": 
["lru-cache@5.1.1", "", { "dependencies": { "yallist": "^3.0.2" } }, "sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w=="], "magic-string": ["magic-string@0.30.21", "", { "dependencies": { "@jridgewell/sourcemap-codec": "^1.5.5" } }, "sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ=="], @@ -464,6 +491,10 @@ "obug": ["obug@2.1.1", "", {}, "sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ=="], + "onnxruntime-common": ["onnxruntime-common@1.24.3", "", {}, "sha512-GeuPZO6U/LBJXvwdaqHbuUmoXiEdeCjWi/EG7Y1HNnDwJYuk6WUbNXpF6luSUY8yASul3cmUlLGrCCL1ZgVXqA=="], + + "onnxruntime-web": ["onnxruntime-web@1.24.3", "", { "dependencies": { "flatbuffers": "^25.1.24", "guid-typescript": "^1.0.9", "long": "^5.2.3", "onnxruntime-common": "1.24.3", "platform": "^1.3.6", "protobufjs": "^7.2.4" } }, "sha512-41dDq7fxtTm0XzGE7N0d6m8FcOY8EWtUA65GkOixJPB/G7DGzBmiDAnVVXHznRw9bgUZpb+4/1lQK/PNxGpbrQ=="], + "optionator": ["optionator@0.9.4", "", { "dependencies": { "deep-is": "^0.1.3", "fast-levenshtein": "^2.0.6", "levn": "^0.4.1", "prelude-ls": "^1.2.1", "type-check": "^0.4.0", "word-wrap": "^1.2.5" } }, "sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g=="], "p-limit": ["p-limit@3.1.0", "", { "dependencies": { "yocto-queue": "^0.1.0" } }, "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ=="], @@ -488,6 +519,8 @@ "pkg-dir": ["pkg-dir@4.2.0", "", { "dependencies": { "find-up": "^4.0.0" } }, "sha512-HRDzbaKjC+AOWVXxAU/x54COGeIv9eb+6CkDSQoNTt4XyWoIJvuPsXizxu/Fr23EiekbtZwmh1IcIG/l/a10GQ=="], + "platform": ["platform@1.3.6", "", {}, "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg=="], + "playwright": ["playwright@1.59.1", "", { "dependencies": { "playwright-core": "1.59.1" }, "optionalDependencies": { "fsevents": "2.3.2" }, "bin": { 
"playwright": "cli.js" } }, "sha512-C8oWjPR3F81yljW9o5OxcWzfh6avkVwDD2VYdwIGqTkl+OGFISgypqzfu7dOe4QNLL2aqcWBmI3PMtLIK233lw=="], "playwright-core": ["playwright-core@1.59.1", "", { "bin": { "playwright-core": "cli.js" } }, "sha512-HBV/RJg81z5BiiZ9yPzIiClYV/QMsDCKUyogwH9p3MCP6IYjUFu/MActgYAvK0oWyV9NlwM3GLBjADyWgydVyg=="], @@ -500,6 +533,8 @@ "prettier": ["prettier@3.8.1", "", { "bin": { "prettier": "bin/prettier.cjs" } }, "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg=="], + "protobufjs": ["protobufjs@7.5.5", "", { "dependencies": { "@protobufjs/aspromise": "^1.1.2", "@protobufjs/base64": "^1.1.2", "@protobufjs/codegen": "^2.0.4", "@protobufjs/eventemitter": "^1.1.0", "@protobufjs/fetch": "^1.1.0", "@protobufjs/float": "^1.0.2", "@protobufjs/inquire": "^1.1.0", "@protobufjs/path": "^1.1.2", "@protobufjs/pool": "^1.1.0", "@protobufjs/utf8": "^1.1.0", "@types/node": ">=13.7.0", "long": "^5.0.0" } }, "sha512-3wY1AxV+VBNW8Yypfd1yQY9pXnqTAN+KwQxL8iYm3/BjKYMNg4i0owhEe26PWDOMaIrzeeF98Lqd5NGz4omiIg=="], + "punycode": ["punycode@2.3.1", "", {}, "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg=="], "queue-microtask": ["queue-microtask@1.2.3", "", {}, "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A=="], @@ -574,7 +609,7 @@ "uri-js": ["uri-js@4.4.1", "", { "dependencies": { "punycode": "^2.1.0" } }, "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg=="], - "vite": ["vite@8.0.3", "", { "dependencies": { "lightningcss": "^1.32.0", "picomatch": "^4.0.4", "postcss": "^8.5.8", "rolldown": "1.0.0-rc.12", "tinyglobby": "^0.2.15" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^20.19.0 || >=22.12.0", "@vitejs/devtools": "^0.1.0", "esbuild": "^0.27.0", "jiti": ">=1.21.0", "less": "^4.0.0", "sass": "^1.70.0", "sass-embedded": "^1.70.0", "stylus": ">=0.54.8", 
"sugarss": "^5.0.0", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "@vitejs/devtools", "esbuild", "jiti", "less", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-B9ifbFudT1TFhfltfaIPgjo9Z3mDynBTJSUYxTjOQruf/zHH+ezCQKcoqO+h7a9Pw9Nm/OtlXAiGT1axBgwqrQ=="], + "vite": ["vite@8.0.5", "", { "dependencies": { "lightningcss": "^1.32.0", "picomatch": "^4.0.4", "postcss": "^8.5.8", "rolldown": "1.0.0-rc.12", "tinyglobby": "^0.2.15" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^20.19.0 || >=22.12.0", "@vitejs/devtools": "^0.1.0", "esbuild": "^0.27.0 || ^0.28.0", "jiti": ">=1.21.0", "less": "^4.0.0", "sass": "^1.70.0", "sass-embedded": "^1.70.0", "stylus": ">=0.54.8", "sugarss": "^5.0.0", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "@vitejs/devtools", "esbuild", "jiti", "less", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-nmu43Qvq9UopTRfMx2jOYW5l16pb3iDC1JH6yMuPkpVbzK0k+L7dfsEDH4jRgYFmsg0sTAqkojoZgzLMlwHsCQ=="], "vitest": ["vitest@4.1.2", "", { "dependencies": { "@vitest/expect": "4.1.2", "@vitest/mocker": "4.1.2", "@vitest/pretty-format": "4.1.2", "@vitest/runner": "4.1.2", "@vitest/snapshot": "4.1.2", "@vitest/spy": "4.1.2", "@vitest/utils": "4.1.2", "es-module-lexer": "^2.0.0", "expect-type": "^1.3.0", "magic-string": "^0.30.21", "obug": "^2.1.1", "pathe": "^2.0.3", "picomatch": "^4.0.3", "std-env": "^4.0.0-rc.1", "tinybench": "^2.9.0", "tinyexec": "^1.0.2", "tinyglobby": "^0.2.15", "tinyrainbow": "^3.1.0", "vite": "^6.0.0 || ^7.0.0 || ^8.0.0", "why-is-node-running": "^2.3.0" }, "peerDependencies": { "@edge-runtime/vm": "*", "@opentelemetry/api": "^1.9.0", "@types/node": "^20.0.0 || ^22.0.0 || >=24.0.0", "@vitest/browser-playwright": "4.1.2", "@vitest/browser-preview": "4.1.2", 
"@vitest/browser-webdriverio": "4.1.2", "@vitest/ui": "4.1.2", "happy-dom": "*", "jsdom": "*" }, "optionalPeers": ["@edge-runtime/vm", "@opentelemetry/api", "@types/node", "@vitest/browser-playwright", "@vitest/browser-preview", "@vitest/browser-webdriverio", "@vitest/ui", "happy-dom", "jsdom"], "bin": { "vitest": "vitest.mjs" } }, "sha512-xjR1dMTVHlFLh98JE3i/f/WePqJsah4A0FK9cc8Ehp9Udk0AZk6ccpIZhh1qJ/yxVWRZ+Q54ocnD8TXmkhspGg=="], @@ -636,6 +671,8 @@ "vitest/picomatch": ["picomatch@4.0.3", "", {}, "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q=="], + "vitest/vite": ["vite@8.0.3", "", { "dependencies": { "lightningcss": "^1.32.0", "picomatch": "^4.0.4", "postcss": "^8.5.8", "rolldown": "1.0.0-rc.12", "tinyglobby": "^0.2.15" }, "optionalDependencies": { "fsevents": "~2.3.3" }, "peerDependencies": { "@types/node": "^20.19.0 || >=22.12.0", "@vitejs/devtools": "^0.1.0", "esbuild": "^0.27.0", "jiti": ">=1.21.0", "less": "^4.0.0", "sass": "^1.70.0", "sass-embedded": "^1.70.0", "stylus": ">=0.54.8", "sugarss": "^5.0.0", "terser": "^5.16.0", "tsx": "^4.8.1", "yaml": "^2.4.2" }, "optionalPeers": ["@types/node", "@vitejs/devtools", "esbuild", "jiti", "less", "sass", "sass-embedded", "stylus", "sugarss", "terser", "tsx", "yaml"], "bin": { "vite": "bin/vite.js" } }, "sha512-B9ifbFudT1TFhfltfaIPgjo9Z3mDynBTJSUYxTjOQruf/zHH+ezCQKcoqO+h7a9Pw9Nm/OtlXAiGT1axBgwqrQ=="], + "@eslint/eslintrc/espree/acorn": ["acorn@8.15.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg=="], "@eslint/eslintrc/espree/eslint-visitor-keys": ["eslint-visitor-keys@4.2.1", "", {}, "sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ=="], @@ -670,6 +707,10 @@ "pkg-dir/find-up/locate-path": ["locate-path@5.0.0", "", { "dependencies": { "p-locate": "^4.1.0" } }, 
"sha512-t7hw9pI+WvuwNJXwk5zVHpyhIqzg2qTlklJOf0mVxGSbe3Fp2VieZcduNYjaLDoy6p9uGpQEGWG87WpMKlNq8g=="], + "vitest/vite/fsevents": ["fsevents@2.3.3", "", { "os": "darwin" }, "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw=="], + + "vitest/vite/picomatch": ["picomatch@4.0.4", "", {}, "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A=="], + "@eslint/eslintrc/minimatch/brace-expansion/balanced-match": ["balanced-match@1.0.2", "", {}, "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw=="], "eslint-plugin-react-hooks/eslint/@eslint/config-array/@eslint/object-schema": ["@eslint/object-schema@2.1.7", "", {}, "sha512-VtAOaymWVfZcmZbp6E2mympDIHvyjXs/12LqWYjVw6qjrfF+VK+fyG33kChz3nnK+SU5/NeHOqrTEHS8sXO3OA=="], diff --git a/engine/benches/board_bench.rs b/engine/benches/board_bench.rs index b963a6a..3223340 100644 --- a/engine/benches/board_bench.rs +++ b/engine/benches/board_bench.rs @@ -18,9 +18,9 @@ fn benchmark_board_moves(c: &mut Criterion) { } } - c.bench_function("board_moves", |b| b.iter(|| { - black_box(board.moves(black_box(src_idx))) - })); + c.bench_function("board_moves", |b| { + b.iter(|| black_box(board.moves(black_box(src_idx)))) + }); } criterion_group!(benches, benchmark_board_moves); diff --git a/engine/src/board.rs b/engine/src/board.rs index c9e36ff..1ac8468 100644 --- a/engine/src/board.rs +++ b/engine/src/board.rs @@ -214,14 +214,14 @@ impl Board { // For each of the 6 directions for dir in 0..6 { let ray = RAY_MASKS[cell_idx as usize][dir]; - + // If we intersect with occupied cells: let blockers = ray & occupied; if blockers != 0 { // There is a blocker. // If the direction is increasing index (forward), the first blocker is the one with the smallest index (trailing_zeros). // If the direction is decreasing index (backward), the first blocker is the one with the largest index (leading_zeros). 
- + // We need to know which directions are "positive" and "negative" in terms of bit index. // 0: East (+1) -> Positive // 1: NorthEast (depends on row parity, but generally -row_len) -> Negative @@ -229,7 +229,7 @@ impl Board { // 3: West (-1) -> Negative // 4: SouthWest -> Positive // 5: SouthEast -> Positive - + if dir == 0 || dir == 4 || dir == 5 { // Positive direction: we want everything strictly less than the smallest blocker let first_blocker_idx = blockers.trailing_zeros(); @@ -244,7 +244,7 @@ impl Board { // !((1 << (idx + 1)) - 1) // Or simpler: !((1 << (first_blocker_idx + 1)).wrapping_sub(1)) // Be careful with overflow if idx is 63. - + let mask = if first_blocker_idx == 63 { 0 } else { @@ -257,7 +257,7 @@ impl Board { moves |= ray; } } - + CellSet { data: moves } } @@ -710,8 +710,12 @@ mod tests { let mut board = Board::new::(&mut SeedableRng::seed_from_u64(0)); // Overwrite claimed sets with random data - board.claimed[0] = CellSet { data: claimed_mask_0 }; - board.claimed[1] = CellSet { data: claimed_mask_1 }; + board.claimed[0] = CellSet { + data: claimed_mask_0, + }; + board.claimed[1] = CellSet { + data: claimed_mask_1, + }; // Ensure src is not claimed (moves() might assume it? No, but usually we move from our own penguin) // But moves() signature is just (cell_idx). diff --git a/scripts/train.ts b/scripts/train.ts deleted file mode 100644 index 76391d8..0000000 --- a/scripts/train.ts +++ /dev/null @@ -1,253 +0,0 @@ -#!/usr/bin/env bun -/** - * HTMF Training Script - * - * Runs one iteration of selfplay and training, then evaluates the new model - * against the baseline (uniform prior). If the new model is at least as strong, - * it's promoted; otherwise rolled back. - * - * Key insight: Only use NN-guided selfplay once the model has "graduated" - * (beaten the uniform prior). Until then, use traditional MCTS selfplay - * to generate high-quality training data. 
- */ - -import { spawn } from "bun"; -import { existsSync, mkdirSync, copyFileSync, unlinkSync, renameSync, writeFileSync } from "fs"; -import { join } from "path"; - -// Configuration -const SELFPLAY_GAMES = 50; -const SELFPLAY_PLAYOUTS = 1000; -const TRAIN_EPOCHS = 10; -const EVAL_GAMES = 20; -const EVAL_PLAYOUTS = 200; - -// Paths (relative to project root) -const ARTIFACTS_DIR = "training/artifacts"; -const MODEL_DRAFTING = join(ARTIFACTS_DIR, "model_drafting.onnx"); -const MODEL_MOVEMENT = join(ARTIFACTS_DIR, "model_movement.onnx"); -const MODEL_FINAL = join(ARTIFACTS_DIR, "model_final.pt"); -const PREV_DRAFTING = join(ARTIFACTS_DIR, "prev_model_drafting.onnx"); -const PREV_MOVEMENT = join(ARTIFACTS_DIR, "prev_model_movement.onnx"); -const PREV_FINAL = join(ARTIFACTS_DIR, "prev_model_final.pt"); -const BLANK_DRAFTING = join(ARTIFACTS_DIR, "blank_model_drafting.onnx"); -const BLANK_MOVEMENT = join(ARTIFACTS_DIR, "blank_model_movement.onnx"); -const TRAINING_DATA = join(ARTIFACTS_DIR, "training_data.jsonl"); -// Marker file indicating the model has graduated (beaten the baseline) -const GRADUATED_MARKER = join(ARTIFACTS_DIR, ".graduated"); - -async function run(cmd: string[], options?: { cwd?: string; stdout?: "pipe" | "inherit" }): Promise<{ exitCode: number; stdout?: string }> { - const proc = spawn({ - cmd, - cwd: options?.cwd, - stdout: options?.stdout ?? 
"inherit", - stderr: "inherit", - }); - - const exitCode = await proc.exited; - - if (options?.stdout === "pipe") { - const stdout = await new Response(proc.stdout).text(); - return { exitCode, stdout }; - } - - return { exitCode }; -} - -async function runWithOutput(cmd: string[], cwd?: string): Promise { - const proc = spawn({ - cmd, - cwd, - stdout: "pipe", - stderr: "pipe", - }); - - const [stdout, stderr] = await Promise.all([ - new Response(proc.stdout).text(), - new Response(proc.stderr).text(), - ]); - - await proc.exited; - return stdout + stderr; -} - -function countLines(filePath: string): number { - const file = Bun.file(filePath); - const text = file.size > 0 ? require("fs").readFileSync(filePath, "utf-8") : ""; - return text.split("\n").filter((line: string) => line.trim()).length; -} - -async function mergeTrainingData(selfplayFile: string, maxSamples = 100000): Promise { - const newData = await Bun.file(selfplayFile).text(); - const newLines = newData.split("\n").filter(line => line.trim()); - - let allLines: string[]; - - if (existsSync(TRAINING_DATA)) { - const existingData = await Bun.file(TRAINING_DATA).text(); - const existingLines = existingData.split("\n").filter(line => line.trim()); - allLines = [...existingLines, ...newLines]; - } else { - allLines = newLines; - } - - // Keep only the most recent samples - if (allLines.length > maxSamples) { - allLines = allLines.slice(-maxSamples); - } - - await Bun.write(TRAINING_DATA, allLines.join("\n") + "\n"); - return allLines.length; -} - -function parseEvalResults(output: string): { nnWins: number; mctsWins: number; draws: number } { - const nnMatch = output.match(/NN wins:\s*(\d+)/); - const mctsMatch = output.match(/MCTS wins:\s*(\d+)/); - const drawsMatch = output.match(/Draws:\s*(\d+)/); - - return { - nnWins: nnMatch ? parseInt(nnMatch[1], 10) : 0, - mctsWins: mctsMatch ? parseInt(mctsMatch[1], 10) : 0, - draws: drawsMatch ? 
parseInt(drawsMatch[1], 10) : 0, - }; -} - -async function main(): Promise { - console.log("=============================================="); - console.log("HTMF Training Iteration"); - console.log("=============================================="); - - // Ensure artifacts directory exists - if (!existsSync(ARTIFACTS_DIR)) { - mkdirSync(ARTIFACTS_DIR, { recursive: true }); - } - - // Determine if this is the first training run (no model exists) - const isFirstRun = !existsSync(MODEL_DRAFTING) || !existsSync(MODEL_MOVEMENT); - - if (isFirstRun) { - console.log("First training run detected - will compare against uniform prior"); - - // Ensure blank models exist - if (!existsSync(BLANK_DRAFTING) || !existsSync(BLANK_MOVEMENT)) { - console.log("Creating blank models for uniform prior baseline..."); - const result = await run(["uv", "run", "create_blank_models.py"], { cwd: "training" }); - if (result.exitCode !== 0) { - console.error("Failed to create blank models"); - return 1; - } - } - } else { - // Save current model as "previous" (for comparison after training) - console.log("\nBacking up current model..."); - copyFileSync(MODEL_DRAFTING, PREV_DRAFTING); - copyFileSync(MODEL_MOVEMENT, PREV_MOVEMENT); - copyFileSync(MODEL_FINAL, PREV_FINAL); - } - - // Check if model has graduated (beaten baseline before) - const hasGraduated = existsSync(GRADUATED_MARKER); - - // Step 1: Generate selfplay data - // Only use NN-guided selfplay if the model has graduated (beaten baseline) - // Otherwise, use traditional MCTS to generate high-quality training data - const useNn = hasGraduated; - const selfplayMode = useNn ? 
"NN-guided MCTS" : "traditional MCTS (uniform prior)"; - console.log(`\nStep 1: Generating ${SELFPLAY_GAMES} selfplay games (${SELFPLAY_PLAYOUTS} playouts/move)...`); - console.log(`Using ${selfplayMode} for selfplay`); - - const timestamp = new Date().toISOString().replace(/[-:T.]/g, "").slice(0, 15); - const selfplayFile = join(ARTIFACTS_DIR, `selfplay_${timestamp}.jsonl`); - - const selfplayArgs = ["cargo", "run", "--release", "-p", "selfplay", "--", String(SELFPLAY_GAMES), String(SELFPLAY_PLAYOUTS)]; - if (useNn) { - selfplayArgs.push("--nn"); - } - - const selfplayProc = spawn({ - cmd: selfplayArgs, - stdout: "pipe", - stderr: "inherit", - }); - - const selfplayOutput = await new Response(selfplayProc.stdout).text(); - await Bun.write(selfplayFile, selfplayOutput); - - const selfplayExit = await selfplayProc.exited; - if (selfplayExit !== 0) { - console.error("Selfplay failed"); - return 1; - } - - const numSamples = countLines(selfplayFile); - console.log(`Generated ${numSamples} training samples`); - - // Merge with existing training data - const totalSamples = await mergeTrainingData(selfplayFile); - console.log(`Total training samples: ${totalSamples}`); - - // Clean up selfplay file - unlinkSync(selfplayFile); - - // Step 2: Train the model - console.log(`\nStep 2: Training for ${TRAIN_EPOCHS} epochs...`); - const trainResult = await run(["uv", "run", "train.py", "--epochs", String(TRAIN_EPOCHS)], { cwd: "training" }); - if (trainResult.exitCode !== 0) { - console.error("Training failed"); - return 1; - } - - // Step 3: Evaluate new model vs baseline - console.log(`\nStep 3: Evaluating new model (${EVAL_GAMES} games, ${EVAL_PLAYOUTS} playouts/move)...`); - - const evalOutput = await runWithOutput([ - "cargo", "run", "--release", "--bin", "nn_vs_mcts", "--", - String(EVAL_GAMES), String(EVAL_PLAYOUTS) - ]); - - console.log(evalOutput); - - const { nnWins, mctsWins, draws } = parseEvalResults(evalOutput); - - 
console.log("\n=============================================="); - console.log(`Results: NN=${nnWins}, MCTS=${mctsWins}, Draws=${draws}`); - - // Determine if new model is at least as strong - if (nnWins >= mctsWins) { - console.log("New model is at least as strong as baseline!"); - console.log("Model promoted successfully."); - - // Mark model as graduated (can now use NN-guided selfplay) - if (!hasGraduated) { - writeFileSync(GRADUATED_MARKER, new Date().toISOString()); - console.log("Model has graduated! Future selfplay will use NN-guided MCTS."); - } - - // Clean up previous model backup - if (existsSync(PREV_DRAFTING)) unlinkSync(PREV_DRAFTING); - if (existsSync(PREV_MOVEMENT)) unlinkSync(PREV_MOVEMENT); - if (existsSync(PREV_FINAL)) unlinkSync(PREV_FINAL); - - console.log("=============================================="); - return 0; - } else { - console.log("New model is weaker than baseline."); - - if (isFirstRun) { - console.log("First model not strong enough yet - keeping it for next iteration."); - console.log("(More training data may help)"); - } else { - console.log("Rolling back to previous model..."); - renameSync(PREV_DRAFTING, MODEL_DRAFTING); - renameSync(PREV_MOVEMENT, MODEL_MOVEMENT); - renameSync(PREV_FINAL, MODEL_FINAL); - console.log("Previous model restored."); - } - - console.log("=============================================="); - return 1; - } -} - -const exitCode = await main(); -process.exit(exitCode); diff --git a/selfplay/src/main.rs b/selfplay/src/main.rs index b5637c5..ce47876 100644 --- a/selfplay/src/main.rs +++ b/selfplay/src/main.rs @@ -7,28 +7,42 @@ extern crate htmf; extern crate htmf_bots; use std::sync::{mpsc, Arc}; +use std::time::{SystemTime, UNIX_EPOCH}; use rayon::prelude::*; use serde::Serialize; use htmf::board::*; use htmf::game::*; -use htmf::hex::Cube; use htmf::NUM_CELLS; use htmf_bots::mctsbot::*; +use htmf_bots::policy::{move_to_policy_index, MOVEMENT_POLICY_SIZE, POLICY_VERSION}; use htmf_bots::NeuralNet; -// 
Compressed movement policy: 4 penguins × 6 directions × 7 max distances = 168 values -const NUM_PENGUINS: usize = 4; -const NUM_DIRECTIONS: usize = 6; -const MAX_DISTANCE: usize = 7; -pub const MOVEMENT_POLICY_SIZE: usize = NUM_PENGUINS * NUM_DIRECTIONS * MAX_DISTANCE; // 168 - const NUM_PLAYERS: usize = 2; +#[derive(Debug, Clone, Serialize)] +pub struct TeacherMetadata { + pub teacher: String, + pub teacher_playouts: usize, + pub model_prior_weight: f32, + pub run_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub teacher_model: Option, +} + /// A training sample containing the game state, MCTS policy, and eventual outcome #[derive(Debug, Clone, Serialize)] pub struct TrainingSample { + /// Policy encoding version. Version 2 uses absolute movement source-cell encoding. + pub policy_version: u8, + /// Teacher/search configuration that generated this policy target. + pub teacher: String, + pub teacher_playouts: usize, + pub model_prior_weight: f32, + pub run_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub teacher_model: Option, /// Board features as a flat array /// Layout: 8 channels x 60 cells = 480 values /// Channels: @@ -43,13 +57,15 @@ pub struct TrainingSample { pub features: Vec, /// MCTS visit distribution over all possible moves (policy target) /// For placement: 60 values (one per cell) - /// For movement: 168 values (4 penguins × 6 directions × 7 distances) - /// Index = penguin_idx * 42 + direction * 7 + (distance - 1) + /// For movement: 2520 values (60 source cells × 6 directions × 7 distances) + /// Index = source_cell * 42 + direction * 7 + (distance - 1) pub policy: Vec, /// Game outcome from current player's perspective: 1.0 = win, 0.5 = draw, 0.0 = loss pub value: f32, /// Current player (0 or 1) pub player: usize, + /// Zero-based move number when this position was searched. 
+ pub turn: usize, /// Whether this is drafting phase pub is_drafting: bool, /// Ownership prediction target: which player owns each cell at game end @@ -69,7 +85,11 @@ fn main() { // Check for --nn flag let use_nn = args.iter().any(|a| a == "--nn"); - let numeric_args: Vec<&String> = args.iter().skip(1).filter(|a| !a.starts_with("--")).collect(); + let numeric_args: Vec<&String> = args + .iter() + .skip(1) + .filter(|a| !a.starts_with("--")) + .collect(); let ntrials: usize = numeric_args .first() @@ -90,7 +110,7 @@ fn main() { } Err(e) => { eprintln!("Failed to load neural network: {:?}", e); - eprintln!("Falling back to traditional MCTS"); + eprintln!("Falling back to uniform priors"); None } } @@ -98,7 +118,13 @@ fn main() { None }; - let mode = if nn.is_some() { "NN-guided MCTS" } else { "Traditional MCTS" }; + let metadata = Arc::new(build_teacher_metadata(nplayouts, nn.is_some(), use_nn)); + + let mode = if nn.is_some() { + "NN-guided priors" + } else { + "uniform-prior baseline" + }; eprintln!( "Running {} games with {} playouts per move ({})", ntrials, nplayouts, mode @@ -124,7 +150,10 @@ fn main() { let results: Vec<(usize, usize)> = (0..ntrials) .into_par_iter() - .map_with(nn.clone(), |nn, _| play_game(nplayouts, nn.clone())) + .map_with((nn.clone(), metadata.clone()), |state, _| { + let (nn, metadata) = state; + play_game(nplayouts, nn.clone(), metadata.as_ref()) + }) .map_with(mpsc::Sender::clone(&logger_tx), |logger_tx, result| { let _ = logger_tx.send(Ok(result.samples)); (result.winner, result.num_moves) @@ -147,6 +176,53 @@ fn main() { eprintln!("Total training samples: {}", total_samples); } +fn build_teacher_metadata( + nplayouts: usize, + nn_loaded: bool, + nn_requested: bool, +) -> TeacherMetadata { + let requested_teacher = std::env::var("HTMF_SELFPLAY_TEACHER").unwrap_or_else(|_| { + if nn_requested { + "nn_root".to_owned() + } else { + "uniform".to_owned() + } + }); + let teacher = if nn_loaded && requested_teacher == "nn_root" { + 
"nn_root".to_owned() + } else { + "uniform".to_owned() + }; + let model_prior_weight = if teacher == "nn_root" { + std::env::var("HTMF_MODEL_PRIOR_WEIGHT") + .ok() + .and_then(|value| value.parse::().ok()) + .unwrap_or(0.05) + } else { + 0.0 + }; + let run_id = std::env::var("HTMF_SELFPLAY_RUN_ID").unwrap_or_else(|_| { + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or_default(); + format!("manual_{teacher}_{timestamp}") + }); + let teacher_model = if teacher == "nn_root" { + std::env::var("HTMF_TEACHER_MODEL").ok() + } else { + None + }; + + TeacherMetadata { + teacher, + teacher_playouts: nplayouts, + model_prior_weight, + run_id, + teacher_model, + } +} + struct GameResult { winner: usize, samples: Vec, @@ -203,54 +279,8 @@ fn extract_features(game: &GameState, current_player: usize) -> Vec { features } -/// Convert a move from (src, dst) to (direction, distance) -/// Direction is 0-5 based on Cube::neighbors() order -/// Distance is 1-7 (number of cells traveled) -fn move_to_direction_distance(src: u8, dst: u8) -> Option<(usize, usize)> { - let src_hex = Board::index_to_evenr(src); - let dst_hex = Board::index_to_evenr(dst); - let src_cube = Cube::from_evenr(&src_hex); - let dst_cube = Cube::from_evenr(&dst_hex); - - // Calculate the delta in cube coordinates - let dx = dst_cube.x - src_cube.x; - let dy = dst_cube.y - src_cube.y; - let dz = dst_cube.z - src_cube.z; - - // Determine direction based on which axis is constant (the other two change) - // Direction 0: (+x, -y, 0z) East - // Direction 1: (+x, 0y, -z) Northeast - // Direction 2: (0x, +y, -z) Northwest - // Direction 3: (-x, +y, 0z) West - // Direction 4: (-x, 0y, +z) Southwest - // Direction 5: (0x, -y, +z) Southeast - - let direction = if dz == 0 { - // z constant: East (0) or West (3) - if dx > 0 { 0 } else { 3 } - } else if dy == 0 { - // y constant: Northeast (1) or Southwest (4) - if dx > 0 { 1 } else { 4 } - } else if dx == 0 { - 
// x constant: Northwest (2) or Southeast (5) - if dy > 0 { 2 } else { 5 } - } else { - // Not a valid hex line move - return None; - }; - - // Distance is the absolute delta on any non-zero axis - let distance = dx.abs().max(dy.abs()).max(dz.abs()) as usize; - - if distance == 0 || distance > MAX_DISTANCE { - return None; - } - - Some((direction, distance)) -} - /// Extract MCTS policy from tree node visit counts -fn extract_policy(mcts: &MCTSBot, game: &GameState, current_player: usize) -> Vec { +fn extract_policy(mcts: &MCTSBot, game: &GameState, _current_player: usize) -> Vec { let is_drafting = !game.finished_drafting(); if is_drafting { @@ -277,30 +307,17 @@ fn extract_policy(mcts: &MCTSBot, game: &GameState, current_player: usize) -> Ve policy } else { - // Compressed movement policy: 4 penguins × 6 directions × 7 distances = 168 values + // Absolute movement policy: 60 source cells × 6 directions × 7 distances = 2520 values let mut policy = vec![0.0f32; MOVEMENT_POLICY_SIZE]; let mut total_visits = 0u32; - // Get current player's penguins in sorted order for consistent indexing - let mut penguins: Vec = game.board.penguins[current_player].into_iter().collect(); - penguins.sort(); - if let Some(children) = mcts.root.children.get() { for (mov, child) in children { if let Move::Move((src, dst)) = mov { - // Find which penguin index this is - let penguin_idx = penguins.iter().position(|&p| p == *src); - if let Some(penguin_idx) = penguin_idx { - if let Some((direction, distance)) = move_to_direction_distance(*src, *dst) { - let (_, visits) = child.rewards_visits.get(); - // Index = penguin_idx * 42 + direction * 7 + (distance - 1) - let idx = penguin_idx * (NUM_DIRECTIONS * MAX_DISTANCE) - + direction * MAX_DISTANCE - + (distance - 1); - policy[idx] = visits as f32; - total_visits += visits; - } - } + let idx = move_to_policy_index(&Move::Move((*src, *dst)), false); + let (_, visits) = child.rewards_visits.get(); + policy[idx] = visits as f32; + total_visits += 
visits; } } } @@ -329,13 +346,17 @@ fn get_temperature(move_num: usize) -> f32 { } } -fn play_game(nplayouts: usize, nn: Option>) -> GameResult { +fn play_game( + nplayouts: usize, + nn: Option>, + metadata: &TeacherMetadata, +) -> GameResult { let mut game = GameState::new_two_player(&mut rand::rng()); let mut bots: Vec = (0..NUM_PLAYERS) .map(|i| { - // Always use PUCT mode (with_neural_net) - it outperforms pure UCB1 - // Pass the NN if available for policy priors, otherwise uses uniform priors - MCTSBot::with_neural_net(game.clone(), Player { id: i }, nn.clone()) + // Always use the production PUCT path. If an NN is present, apply it + // only to the root before each search, matching browser serving. + MCTSBot::new(game.clone(), Player { id: i }) }) .collect(); @@ -344,12 +365,20 @@ fn play_game(nplayouts: usize, nn: Option>) -> GameResult { features: Vec, policy: Vec, player: usize, + turn: usize, is_drafting: bool, } let mut pending_samples: Vec = vec![]; let mut num_moves = 0; while let Some(p) = game.active_player() { + if let Some(nn) = &nn { + match nn.predict(&game, p.id) { + Ok(output) => bots[p.id].update_root_priors_from_logits(&output.policy_logits), + Err(err) => eprintln!("Model prior inference failed; using uniform root: {err}"), + } + } + // Run MCTS playouts for _ in 0..nplayouts { bots[p.id].playout(); @@ -364,6 +393,7 @@ fn play_game(nplayouts: usize, nn: Option>) -> GameResult { features, policy, player: p.id, + turn: num_moves, is_drafting, }); @@ -433,10 +463,17 @@ fn play_game(nplayouts: usize, nn: Option>) -> GameResult { .map(|s| { let score_diff_bin = (score_diffs[s.player] + 92) as u8; TrainingSample { + policy_version: POLICY_VERSION, + teacher: metadata.teacher.clone(), + teacher_playouts: metadata.teacher_playouts, + model_prior_weight: metadata.model_prior_weight, + run_id: metadata.run_id.clone(), + teacher_model: metadata.teacher_model.clone(), features: s.features, policy: s.policy, value: values[s.player], player: s.player, + 
turn: s.turn, is_drafting: s.is_drafting, ownership: ownership.clone(), score_diff: score_diff_bin, @@ -454,7 +491,14 @@ fn play_game(nplayouts: usize, nn: Option>) -> GameResult { #[test] fn test_selfplay_generates_samples() { let nplayouts = 50; - let result = play_game(nplayouts, None); + let metadata = TeacherMetadata { + teacher: "uniform".to_owned(), + teacher_playouts: nplayouts, + model_prior_weight: 0.0, + run_id: "test_uniform".to_owned(), + teacher_model: None, + }; + let result = play_game(nplayouts, None, &metadata); // Should have samples for each move in the game assert!( @@ -470,18 +514,29 @@ fn test_selfplay_generates_samples() { // Check first sample has correct feature dimensions let first = &result.samples[0]; + assert_eq!( + first.policy_version, POLICY_VERSION, + "Samples should declare the current policy encoding version" + ); + assert_eq!(first.teacher, "uniform"); + assert_eq!(first.teacher_playouts, nplayouts); + assert_eq!(first.model_prior_weight, 0.0); + assert_eq!(first.run_id, "test_uniform"); + assert!(first.teacher_model.is_none()); assert_eq!( first.features.len(), 8 * NUM_CELLS, "Features should be 8 channels x 60 cells" ); + assert_eq!(first.turn, 0, "First sample should be turn 0"); // First 8 moves are drafting - for sample in result.samples.iter().take(8) { + for (turn, sample) in result.samples.iter().take(8).enumerate() { assert!( sample.is_drafting, "First 8 samples should be drafting phase" ); + assert_eq!(sample.turn, turn, "Sample turn should match move number"); assert_eq!( sample.policy.len(), NUM_CELLS, @@ -489,7 +544,7 @@ fn test_selfplay_generates_samples() { ); } - // After drafting, policy should be movement (168 values - compressed) + // After drafting, policy should be movement (2520 values - absolute source-cell encoding) if result.samples.len() > 8 { let movement_sample = &result.samples[8]; assert!( @@ -499,8 +554,9 @@ fn test_selfplay_generates_samples() { assert_eq!( movement_sample.policy.len(), 
MOVEMENT_POLICY_SIZE, - "Movement policy should be 168 values (4 penguins × 6 directions × 7 distances)" + "Movement policy should be 2520 values (60 cells × 6 directions × 7 distances)" ); + assert_eq!(movement_sample.turn, 8, "First movement sample should be turn 8"); } // Values should be valid (0.0, 0.5, or 1.0) diff --git a/training/artifacts/model.onnx b/training/artifacts/model.onnx index 182e992..47f8808 100644 Binary files a/training/artifacts/model.onnx and b/training/artifacts/model.onnx differ diff --git a/training/artifacts/model_final.pt b/training/artifacts/model_final.pt index 06a472b..7ebfb20 100644 Binary files a/training/artifacts/model_final.pt and b/training/artifacts/model_final.pt differ diff --git a/training/create_blank_models.py b/training/create_blank_models.py index b0e1831..9f3e75d 100644 --- a/training/create_blank_models.py +++ b/training/create_blank_models.py @@ -6,27 +6,14 @@ - Policy: All zeros → softmax gives 1/n for each legal move (uniform) - Value: Always 0 (tanh) → converted to 0.5 in Rust (neutral) -With these models, the AlphaZero bot should explore uniformly but still -use the PUCT formula. Comparing this to pure MCTS helps isolate bugs. +With these models, the bot should explore uniformly through the same PUCT +path used by the production uniform-prior baseline. -IMPORTANT: Even with these "blank" models, AlphaZero will NOT behave -identically to pure MCTS because: +IMPORTANT: Even with these "blank" models, only the policy output is used by +search; the neutral value head is intentionally ignored by the production bot. -1. PUCT vs UCB1 selection: - - MCTS UCB1: unvisited nodes get INFINITY (always explore first) - - AlphaZero PUCT: unvisited nodes get Q=0.5 + exploration term (finite) - -2. Leaf evaluation: - - MCTS: Random rollout to game end → actual win/loss/draw - - Blank model: Always returns value=0.5 → Q-values stay at 0.5 - -3. 
This means with blank models: - - All nodes will have Q ≈ 0.5 (no learning of which moves are good) - - Selection driven purely by exploration term: C_PUCT * P * sqrt(N) / (1+n) - - With uniform priors, exploration favors less-visited nodes - -The UniformPriorRollout mode in Rust is a better debugging tool because -it uses random rollouts (like MCTS) but with PUCT selection. +This is useful for verifying that ONNX inference and the native +UniformPriorProvider produce the same legal-move priors. """ from pathlib import Path @@ -38,7 +25,10 @@ NUM_CELLS = 60 NUM_CHANNELS = 8 NUM_FEATURES = NUM_CHANNELS * NUM_CELLS # 480 -MOVEMENT_POLICY_SIZE = 4 * 6 * 7 # 168 (penguins × directions × distances) +NUM_DIRECTIONS = 6 +MAX_DISTANCE = 7 +POLICY_VERSION = 2 +MOVEMENT_POLICY_SIZE = NUM_CELLS * NUM_DIRECTIONS * MAX_DISTANCE # 2520 ARTIFACTS_DIR = Path("./artifacts") @@ -127,6 +117,9 @@ def main(): print("NOTE: This blank model outputs uniform policy (all zeros -> 1/n after softmax)") print("and neutral value (0 tanh -> 0.5 probability) for both drafting and movement.") print() + print(f"Policy encoding version: {POLICY_VERSION}") + print("Movement policy: source_cell * 42 + direction * 7 + distance_minus_one") + print() print("The PUCT mode uses random rollouts for leaf evaluation, so the value output") print("from this model is NOT used. 
Only the policy priors are used to guide") print("which moves to explore first.") diff --git a/training/iterate.py b/training/iterate.py index 2566773..cfef8c2 100644 --- a/training/iterate.py +++ b/training/iterate.py @@ -20,12 +20,15 @@ from datetime import datetime from pathlib import Path +import replay + # Paths ARTIFACTS_DIR = Path("./artifacts") TRAINING_DATA = ARTIFACTS_DIR / "training_data.jsonl" MODEL_FINAL = ARTIFACTS_DIR / "model_final.pt" ONNX_MODEL = ARTIFACTS_DIR / "model.onnx" ITERATIONS_DIR = ARTIFACTS_DIR / "iterations" +BROWSER_MODEL = Path("../www/public/models/htmf-policy.onnx") def run_command(cmd: list[str], cwd: str | None = None) -> subprocess.CompletedProcess: @@ -39,8 +42,12 @@ def run_command(cmd: list[str], cwd: str | None = None) -> subprocess.CompletedP def generate_selfplay_data(num_games: int, num_playouts: int, use_nn: bool) -> Path: """Generate self-play training data.""" + replay.SELFPLAY_DIR.mkdir(parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_file = ARTIFACTS_DIR / f"selfplay_{timestamp}.jsonl" + mode = "nn" if use_nn else "uniform" + output_file = replay.SELFPLAY_DIR / ( + f"selfplay_{timestamp}_{mode}_g{num_games}_p{num_playouts}.jsonl" + ) cmd = [ "cargo", @@ -72,84 +79,89 @@ def generate_selfplay_data(num_games: int, num_playouts: int, use_nn: bool) -> P def merge_training_data(new_data: Path, max_samples: int = 100_000): - """Merge new data into the main training file, keeping recent samples.""" - # Read existing data - existing_samples = [] - if TRAINING_DATA.exists(): - with open(TRAINING_DATA) as f: - existing_samples = f.readlines() - - # Read new data + """Rebuild the active training file from durable policy-v2 selfplay runs.""" with open(new_data) as f: - new_samples = f.readlines() - - # Combine, keeping most recent samples - all_samples = existing_samples + new_samples - if len(all_samples) > max_samples: - # Keep the most recent samples (at the end) - all_samples = 
all_samples[-max_samples:] - - # Write back - with open(TRAINING_DATA, "w") as f: - f.writelines(all_samples) - - print(f"Training data: {len(all_samples)} samples (added {len(new_samples)} new)") + new_samples = sum(1 for _ in f) + total_samples = replay.build_replay(max_samples=max_samples) + print(f"Training data: {total_samples} samples (added {new_samples} new)") -def train_model(epochs: int, learning_rate: float = 0.001, num_filters: int = 64, num_blocks: int = 4) -> bool: +def train_model( + epochs: int, + learning_rate: float = 0.001, + num_filters: int | None = None, + num_blocks: int | None = None, +) -> bool: """Train the model and return True if it improved.""" print(f"\nTraining for {epochs} epochs...") - result = run_command( - [ - "uv", - "run", - "train.py", - "--epochs", - str(epochs), - "--lr", - str(learning_rate), - "--num-filters", - str(num_filters), - "--num-blocks", - str(num_blocks), - ], - cwd=str(Path(__file__).parent), - ) + cmd = [ + "uv", + "run", + "train.py", + "--epochs", + str(epochs), + "--lr", + str(learning_rate), + ] + if num_filters is not None: + cmd.extend(["--num-filters", str(num_filters)]) + if num_blocks is not None: + cmd.extend(["--num-blocks", str(num_blocks)]) + + result = run_command(cmd, cwd=str(Path(__file__).parent)) return result.returncode == 0 def evaluate_models( - num_games: int = 20, num_playouts: int = 100 -) -> tuple[int, int, int]: + num_pairs: int = 100, num_playouts: int = 400, uniform_vs_uniform: bool = False +) -> tuple[int, int, int, float]: """ - Evaluate new model vs old model. - Returns (new_wins, old_wins, draws) + Evaluate the model against the production uniform-prior baseline. 
+ Returns (model_wins, baseline_wins, draws, score) """ - print(f"\nEvaluating new model ({num_games} games, {num_playouts} playouts)...") + print(f"\nEvaluating model ({num_pairs} pairs, {num_playouts} playouts)...") + cmd = [ + "cargo", + "run", + "--release", + "--bin", + "nn_vs_mcts", + "--", + str(num_pairs), + str(num_playouts), + ] + if uniform_vs_uniform: + cmd.append("--uniform-vs-uniform") result = subprocess.run( - ["cargo", "run", "--release", "--bin", "nn_vs_mcts", "--"], + cmd, capture_output=True, text=True, cwd=str(Path(__file__).parent.parent), ) + if result.returncode != 0: + print(result.stdout) + print(result.stderr) + raise RuntimeError(f"evaluation failed with exit code {result.returncode}") # Parse results from output - # Looking for lines like "NN wins: X (Y%)" - nn_wins = 0 - mcts_wins = 0 + model_wins = 0 + baseline_wins = 0 draws = 0 + score = 0.0 for line in result.stdout.split("\n") + result.stderr.split("\n"): - if "NN wins:" in line: - nn_wins = int(line.split(":")[1].split("(")[0].strip()) - elif "MCTS wins:" in line: - mcts_wins = int(line.split(":")[1].split("(")[0].strip()) + if "Model wins:" in line: + model_wins = int(line.split(":")[1].split("(")[0].strip()) + elif "Baseline wins:" in line: + baseline_wins = int(line.split(":")[1].split("(")[0].strip()) elif "Draws:" in line: draws = int(line.split(":")[1].split("(")[0].strip()) + elif "Score:" in line: + score = float(line.split(":")[1].strip()) - print(f"Results: NN={nn_wins}, MCTS={mcts_wins}, Draws={draws}") - return nn_wins, mcts_wins, draws + print(f"Results: model={model_wins}, baseline={baseline_wins}, draws={draws}, score={score:.3f}") + return model_wins, baseline_wins, draws, score def save_iteration(iteration: int): @@ -164,6 +176,36 @@ def save_iteration(iteration: int): print(f"Saved iteration {iteration} to {iter_dir}") +def promote_browser_model(): + """Copy the current ONNX model to the browser-served artifact path.""" + BROWSER_MODEL.parent.mkdir(parents=True, 
exist_ok=True) + shutil.copy(ONNX_MODEL, BROWSER_MODEL) + print(f"Promoted browser model -> {BROWSER_MODEL}") + + +def backup_current_model() -> Path | None: + """Save the current training artifacts so a failed iteration can roll back.""" + existing = [p for p in [MODEL_FINAL, ONNX_MODEL] if p.exists()] + if not existing: + return None + + backup_dir = ARTIFACTS_DIR / "rollback" + if backup_dir.exists(): + shutil.rmtree(backup_dir) + backup_dir.mkdir(parents=True) + for src in existing: + shutil.copy(src, backup_dir / src.name) + return backup_dir + + +def restore_model_backup(backup_dir: Path | None): + if backup_dir is None: + return + for src in backup_dir.iterdir(): + shutil.copy(src, ARTIFACTS_DIR / src.name) + print(f"Restored previous model from {backup_dir}") + + def main(): parser = argparse.ArgumentParser(description="Iterative HTMF training") parser.add_argument( @@ -179,13 +221,31 @@ def main(): "--epochs", type=int, default=20, help="Training epochs per iteration" ) parser.add_argument( - "--num-filters", type=int, default=64, help="Number of convolutional filters" + "--eval-pairs", type=int, default=100, help="Paired evaluation seeds per iteration" ) parser.add_argument( - "--num-blocks", type=int, default=4, help="Number of residual blocks" + "--eval-playouts", type=int, default=400, help="Evaluation playouts per move" ) parser.add_argument( - "--bootstrap", action="store_true", help="Start with traditional MCTS data" + "--promotion-score", + type=float, + default=0.53, + help="Minimum model score vs uniform baseline required for promotion", + ) + parser.add_argument( + "--num-filters", + type=int, + default=None, + help="Override number of convolutional filters; defaults to checkpoint metadata", + ) + parser.add_argument( + "--num-blocks", + type=int, + default=None, + help="Override number of residual blocks; defaults to checkpoint metadata", + ) + parser.add_argument( + "--bootstrap", action="store_true", help="Start with uniform-prior baseline 
data" ) parser.add_argument( "--fresh", action="store_true", help="Start fresh (delete existing model)" @@ -202,6 +262,9 @@ def main(): if f.exists(): f.unlink() print(f" Removed {f}") + if replay.SELFPLAY_DIR.exists(): + shutil.rmtree(replay.SELFPLAY_DIR) + print(f" Removed {replay.SELFPLAY_DIR}") print("=" * 60) print("HTMF Iterative Training") @@ -210,16 +273,22 @@ def main(): print(f"Games per iteration: {args.games}") print(f"Playouts per move: {args.playouts}") print(f"Epochs per iteration: {args.epochs}") + print(f"Eval pairs: {args.eval_pairs}") + print(f"Eval playouts: {args.eval_playouts}") + print(f"Promotion score: {args.promotion_score:.3f}") print("=" * 60) # Check if we need to bootstrap if args.bootstrap or not MODEL_FINAL.exists(): - print("\nBootstrapping with traditional MCTS self-play...") + print("\nBootstrapping with uniform-prior baseline self-play...") # Generate high-quality data with many playouts new_data = generate_selfplay_data(args.games * 2, args.playouts, use_nn=False) merge_training_data(new_data) train_model(args.epochs * 2, num_filters=args.num_filters, num_blocks=args.num_blocks) # Train longer for initial model save_iteration(0) + _, _, _, bootstrap_score = evaluate_models(args.eval_pairs, args.eval_playouts) + if bootstrap_score >= args.promotion_score: + promote_browser_model() for iteration in range(1, args.iterations + 1): print(f"\n{'=' * 60}") @@ -230,20 +299,30 @@ def main(): new_data = generate_selfplay_data(args.games, args.playouts, use_nn=True) merge_training_data(new_data) + backup_dir = backup_current_model() + # Train on all data train_model(args.epochs, num_filters=args.num_filters, num_blocks=args.num_blocks) # Save this iteration save_iteration(iteration) - # Evaluate against pure MCTS - nn_wins, mcts_wins, draws = evaluate_models() - win_rate = nn_wins / max(1, nn_wins + mcts_wins + draws) * 100 - - print(f"\nIteration {iteration} complete: NN win rate = {win_rate:.1f}%") - - if win_rate >= 55: - print("Model 
is now competitive with pure MCTS!") + # Evaluate against the production uniform-prior baseline + model_wins, baseline_wins, draws, score = evaluate_models( + args.eval_pairs, args.eval_playouts + ) + + print( + f"\nIteration {iteration} complete: model={model_wins}, baseline={baseline_wins}, " + f"draws={draws}, score={score:.3f}" + ) + + if score >= args.promotion_score: + print("Model passed promotion gate.") + promote_browser_model() + else: + print("Model did not pass promotion gate; rolling back training artifact.") + restore_model_backup(backup_dir) print("\n" + "=" * 60) print("Training complete!") diff --git a/training/train.py b/training/train.py index 0aa8626..cafce4c 100644 --- a/training/train.py +++ b/training/train.py @@ -5,7 +5,7 @@ The network has two heads: - Policy head: probability distribution over moves - Drafting: 60 values (one per cell) - - Movement: 168 values (4 penguins x 6 directions x 7 distances) + - Movement: 2520 values (60 source cells x 6 directions x 7 distances) - Value head: predicted win probability for current player Usage: @@ -14,6 +14,7 @@ import argparse import json +from collections import Counter from pathlib import Path import numpy as np @@ -26,10 +27,11 @@ NUM_CELLS = 60 NUM_CHANNELS = 8 NUM_FEATURES = NUM_CHANNELS * NUM_CELLS # 480 -NUM_PENGUINS = 4 NUM_DIRECTIONS = 6 MAX_DISTANCE = 7 -MOVEMENT_POLICY_SIZE = NUM_PENGUINS * NUM_DIRECTIONS * MAX_DISTANCE # 168 +POLICY_VERSION = 2 +MOVEMENT_POLICY_SIZE = NUM_CELLS * NUM_DIRECTIONS * MAX_DISTANCE # 2520 +VALID_TEACHERS = {"uniform", "nn_root"} # Grid dimensions for Conv2D (8 rows, alternating 7/8 columns -> embed in 8x8) NUM_ROWS = 8 @@ -69,9 +71,6 @@ def _build_valid_mask(): TRAINING_DATA = ARTIFACTS_DIR / "training_data.jsonl" MODEL_CHECKPOINT = ARTIFACTS_DIR / "model_final.pt" ONNX_MODEL = ARTIFACTS_DIR / "model.onnx" -# Legacy paths for backward compatibility -ONNX_DRAFTING = ARTIFACTS_DIR / "model_drafting.onnx" -ONNX_MOVEMENT = ARTIFACTS_DIR / "model_movement.onnx" 
class HTMFDataset(Dataset): @@ -80,10 +79,63 @@ class HTMFDataset(Dataset): def __init__(self, data_path: Path): self.drafting_samples: list[dict] = [] self.movement_samples: list[dict] = [] + self.teacher_counts: Counter[str] = Counter() with open(data_path) as f: - for line in f: + for line_num, line in enumerate(f, start=1): sample = json.loads(line) + policy_version = sample.get("policy_version") + if policy_version != POLICY_VERSION: + raise ValueError( + f"{data_path}:{line_num}: expected policy_version={POLICY_VERSION}, " + f"got {policy_version!r}. Regenerate selfplay data; v1 movement " + "policies are not compatible with the absolute source-cell encoding." + ) + + if "turn" not in sample: + raise ValueError( + f"{data_path}:{line_num}: missing turn metadata. Regenerate " + "policy-v2 selfplay data before training." + ) + + teacher = sample.get("teacher") + if teacher not in VALID_TEACHERS: + raise ValueError( + f"{data_path}:{line_num}: expected teacher in " + f"{sorted(VALID_TEACHERS)}, got {teacher!r}. Rebuild replay " + "so legacy samples are normalized." 
+ ) + + if not isinstance(sample.get("teacher_playouts"), int): + raise ValueError( + f"{data_path}:{line_num}: teacher_playouts must be an integer" + ) + + if not isinstance(sample.get("model_prior_weight"), int | float): + raise ValueError( + f"{data_path}:{line_num}: model_prior_weight must be numeric" + ) + + if not isinstance(sample.get("run_id"), str) or not sample["run_id"]: + raise ValueError( + f"{data_path}:{line_num}: run_id must be a non-empty string" + ) + + if "teacher_model" in sample and not isinstance(sample["teacher_model"], str): + raise ValueError( + f"{data_path}:{line_num}: teacher_model must be a string when present" + ) + + expected_policy_len = ( + NUM_CELLS if sample["is_drafting"] else MOVEMENT_POLICY_SIZE + ) + if len(sample["policy"]) != expected_policy_len: + raise ValueError( + f"{data_path}:{line_num}: expected policy length " + f"{expected_policy_len}, got {len(sample['policy'])}" + ) + + self.teacher_counts[teacher] += 1 if sample["is_drafting"]: self.drafting_samples.append(sample) else: @@ -91,6 +143,13 @@ def __init__(self, data_path: Path): print(f"Loaded {len(self.drafting_samples)} drafting samples") print(f"Loaded {len(self.movement_samples)} movement samples") + print( + "Teacher counts: " + + ", ".join( + f"{teacher}={self.teacher_counts.get(teacher, 0)}" + for teacher in sorted(VALID_TEACHERS) + ) + ) def __len__(self): return len(self.drafting_samples) + len(self.movement_samples) @@ -388,7 +447,7 @@ class HTMFNet(nn.Module): - Input: 8 channels x 8x8 grid (60 valid cells embedded in 64) - Shared trunk: Initial conv + residual blocks (shared between all heads) - Drafting policy head: outputs 60 cell probabilities - - Movement policy head: outputs 168 move probabilities (4 penguins × 6 dirs × 7 dists) + - Movement policy head: outputs 2520 move probabilities (60 source cells × 6 dirs × 7 dists) - Value head: predicts win probability - Ownership head: predicts per-cell ownership at game end (60 cells × 3 classes) - Score 
difference head: predicts final score difference distribution (185 bins) @@ -468,6 +527,8 @@ def train_model( device: torch.device, is_drafting: bool, batch_size: int = 256, + value_weight: float = 0.25, + aux_weight: float = 0.05, ) -> tuple[float, float, float, float]: """Train the model for one epoch. @@ -557,13 +618,13 @@ def train_model( # Combine PDF and CDF losses (equal weighting as in KataGo) score_diff_loss = pdf_loss + cdf_loss - # Combined loss with auxiliary targets - # Weight auxiliary losses lower to avoid overwhelming main objectives + # Policy priors are the only v1 search signal. Value and auxiliary heads + # stay as regularizers, weighted low enough not to drown out policy loss. loss = ( policy_loss - + value_loss - + 0.5 * ownership_loss - + 0.5 * score_diff_loss + + value_weight * value_loss + + aux_weight * ownership_loss + + aux_weight * score_diff_loss ) loss.backward() optimizer.step() @@ -589,7 +650,7 @@ def export_to_onnx(model: HTMFNet, path: Path): - Input: features (batch, 480) - Outputs: - drafting_policy (batch, 60) - - movement_policy (batch, 168) + - movement_policy (batch, 2520) - value (batch, 1) - ownership (batch, 60, 3) - per-cell ownership prediction - score_diff (batch, 185) - score difference distribution @@ -626,6 +687,23 @@ def forward(self, x): ) +def compatible_state_dict( + model: HTMFNet, state_dict: dict[str, torch.Tensor] +) -> tuple[dict[str, torch.Tensor], list[str]]: + """Return only checkpoint tensors whose names and shapes match this model.""" + current_state = model.state_dict() + compatible = {} + skipped = [] + + for name, tensor in state_dict.items(): + if name in current_state and current_state[name].shape == tensor.shape: + compatible[name] = tensor + else: + skipped.append(name) + + return compatible, skipped + + def main(): parser = argparse.ArgumentParser(description="Train HTMF neural network") parser.add_argument( @@ -634,10 +712,33 @@ def main(): parser.add_argument("--lr", type=float, 
default=0.001, help="Learning rate") parser.add_argument("--batch-size", type=int, default=256, help="Batch size") parser.add_argument( - "--num-filters", type=int, default=64, help="Number of conv filters" + "--num-filters", + type=int, + default=None, + help="Number of conv filters; defaults to checkpoint metadata or 64 for a fresh model", + ) + parser.add_argument( + "--num-blocks", + type=int, + default=None, + help="Number of residual blocks; defaults to checkpoint metadata or 4 for a fresh model", ) parser.add_argument( - "--num-blocks", type=int, default=4, help="Number of residual blocks" + "--value-weight", + type=float, + default=0.25, + help="Loss weight for the value head; policy remains the primary search signal", + ) + parser.add_argument( + "--aux-weight", + type=float, + default=0.05, + help="Loss weight for ownership and score-difference auxiliary heads", + ) + parser.add_argument( + "--fresh", + action="store_true", + help="Train from random initialization instead of loading model_final.pt", ) args = parser.parse_args() @@ -683,25 +784,44 @@ def main(): if drafting_ownerships is not None or movement_ownerships is not None: print("Auxiliary targets detected: ownership, score_diff") - # Create single shared model - model = HTMFNet(num_filters=args.num_filters, num_blocks=args.num_blocks).to(device) - - # Load existing weights if available - if MODEL_CHECKPOINT.exists(): - print(f"Loading existing model from {MODEL_CHECKPOINT}...") + checkpoint = None + state_dict = None + if args.fresh: + print("Fresh training requested; not loading existing checkpoint.") + elif MODEL_CHECKPOINT.exists(): + print(f"Loading existing model metadata from {MODEL_CHECKPOINT}...") checkpoint = torch.load( MODEL_CHECKPOINT, map_location=device, weights_only=True ) - # Handle both old format (direct state_dict) and new format (dict with "model" key) - if "model" in checkpoint: + if isinstance(checkpoint, dict) and "model" in checkpoint: state_dict = checkpoint["model"] + if 
args.num_filters is None and "num_filters" in checkpoint: + args.num_filters = int(checkpoint["num_filters"]) + if args.num_blocks is None and "num_blocks" in checkpoint: + args.num_blocks = int(checkpoint["num_blocks"]) else: - # Old format: checkpoint is the state_dict directly state_dict = checkpoint - # Load weights, allowing for missing keys (e.g., new auxiliary heads) - model.load_state_dict(state_dict, strict=False) - print("Loaded existing model weights (auxiliary heads will be randomly initialized if not present)") + if args.num_filters is None: + args.num_filters = 64 + if args.num_blocks is None: + args.num_blocks = 4 + + # Create single shared model + model = HTMFNet(num_filters=args.num_filters, num_blocks=args.num_blocks).to(device) + + # Load existing weights if available + if state_dict is not None: + compatible, skipped = compatible_state_dict(model, state_dict) + missing, unexpected = model.load_state_dict(compatible, strict=False) + print( + f"Loaded {len(compatible)} tensors from existing checkpoint " + f"({len(skipped)} shape/name mismatches skipped)" + ) + if missing: + print(f"Initialized {len(missing)} tensors from scratch") + if unexpected: + print(f"Ignored {len(unexpected)} unexpected tensors") # Single optimizer for the entire model optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) @@ -717,6 +837,8 @@ def main(): print(f"Learning rate: {args.lr} (with cosine annealing schedule)") print(f"Warmup epochs: {warmup_epochs}") print(f"Batch size: {args.batch_size}") + print(f"Value loss weight: {args.value_weight}") + print(f"Auxiliary loss weight: {args.aux_weight}") print() for epoch in range(1, args.epochs + 1): @@ -732,6 +854,8 @@ def main(): device, is_drafting=True, batch_size=args.batch_size, + value_weight=args.value_weight, + aux_weight=args.aux_weight, ) # Train on movement data @@ -746,6 +870,8 @@ def main(): device, is_drafting=False, batch_size=args.batch_size, + value_weight=args.value_weight, + 
aux_weight=args.aux_weight, ) # Print detailed loss analysis @@ -757,19 +883,19 @@ def main(): print(f"\nDRAFTING PHASE:") print(f" Raw losses: P={d_policy_loss:.4f} V={d_value_loss:.4f} O={d_ownership_loss:.4f} S={d_score_diff_loss:.4f}") if d_ownership_loss > 0: - print(f" Weighted (0.5): O={0.5*d_ownership_loss:.4f} S={0.5*d_score_diff_loss:.4f}") - total_weighted = d_policy_loss + d_value_loss + 0.5*d_ownership_loss + 0.5*d_score_diff_loss + print(f" Weighted: V={args.value_weight*d_value_loss:.4f} O={args.aux_weight*d_ownership_loss:.4f} S={args.aux_weight*d_score_diff_loss:.4f}") + total_weighted = d_policy_loss + args.value_weight*d_value_loss + args.aux_weight*d_ownership_loss + args.aux_weight*d_score_diff_loss print(f" Total weighted loss: {total_weighted:.4f}") - print(f" Contribution %: P={100*d_policy_loss/total_weighted:.1f}% V={100*d_value_loss/total_weighted:.1f}% O={100*0.5*d_ownership_loss/total_weighted:.1f}% S={100*0.5*d_score_diff_loss/total_weighted:.1f}%") + print(f" Contribution %: P={100*d_policy_loss/total_weighted:.1f}% V={100*args.value_weight*d_value_loss/total_weighted:.1f}% O={100*args.aux_weight*d_ownership_loss/total_weighted:.1f}% S={100*args.aux_weight*d_score_diff_loss/total_weighted:.1f}%") # Movement losses print(f"\nMOVEMENT PHASE:") print(f" Raw losses: P={m_policy_loss:.4f} V={m_value_loss:.4f} O={m_ownership_loss:.4f} S={m_score_diff_loss:.4f}") if m_ownership_loss > 0: - print(f" Weighted (0.5): O={0.5*m_ownership_loss:.4f} S={0.5*m_score_diff_loss:.4f}") - total_weighted = m_policy_loss + m_value_loss + 0.5*m_ownership_loss + 0.5*m_score_diff_loss + print(f" Weighted: V={args.value_weight*m_value_loss:.4f} O={args.aux_weight*m_ownership_loss:.4f} S={args.aux_weight*m_score_diff_loss:.4f}") + total_weighted = m_policy_loss + args.value_weight*m_value_loss + args.aux_weight*m_ownership_loss + args.aux_weight*m_score_diff_loss print(f" Total weighted loss: {total_weighted:.4f}") - print(f" Contribution %: 
P={100*m_policy_loss/total_weighted:.1f}% V={100*m_value_loss/total_weighted:.1f}% O={100*0.5*m_ownership_loss/total_weighted:.1f}% S={100*0.5*m_score_diff_loss/total_weighted:.1f}%") +            print(f"  Contribution %: P={100*m_policy_loss/total_weighted:.1f}% V={100*args.value_weight*m_value_loss/total_weighted:.1f}% O={100*args.aux_weight*m_ownership_loss/total_weighted:.1f}% S={100*args.aux_weight*m_score_diff_loss/total_weighted:.1f}%") # Update learning rate (warmup for first N epochs, then cosine annealing) current_lr = optimizer.param_groups[0]['lr'] @@ -791,6 +917,7 @@ def main(): "model": model.state_dict(), "num_filters": args.num_filters, "num_blocks": args.num_blocks, + "policy_version": POLICY_VERSION, }, MODEL_CHECKPOINT, ) diff --git a/wasm/src/lib.rs b/wasm/src/lib.rs index 11ab2bc..e8749e4 100644 --- a/wasm/src/lib.rs +++ b/wasm/src/lib.rs @@ -2,6 +2,7 @@ use wasm_bindgen::prelude::*; use htmf::board::Player; use htmf::game::GameState; +use htmf::NUM_CELLS; use htmf_bots::MCTSBot; mod utils; @@ -99,6 +100,17 @@ impl Game { .collect() } + pub fn features_for_active_player(&self) -> Vec<f32> { + let Some(current_player) = self.bot.root_game.state.active_player().map(|p| p.id) else { + return vec![0.0; 8 * NUM_CELLS]; + }; + extract_features(&self.bot.root_game.state, current_player) + } + + pub fn apply_policy_logits(&mut self, logits: Vec<f32>) { + self.bot.update_root_priors_from_logits(&logits); + } + pub fn place_penguin(&mut self, dst: u8) -> Result<(), JsValue> { let mut new_game = self.bot.root_game.state.clone(); match new_game.place_penguin(dst) { @@ -212,3 +224,46 @@ impl Default for Game { Self::new() } } + +fn extract_features(game: &GameState, current_player: usize) -> Vec<f32> { + let opponent = 1 - current_player; + let is_drafting = !game.finished_drafting(); + + let mut features = vec![0.0f32; 8 * NUM_CELLS]; + + for cell in 0..NUM_CELLS as u8 { + if game.board.fish[0].contains(cell) { + features[cell as usize] = 1.0; + } + if 
game.board.fish[1].contains(cell) { + features[NUM_CELLS + cell as usize] = 1.0; + } + if game.board.fish[2].contains(cell) { + features[2 * NUM_CELLS + cell as usize] = 1.0; + } + } + + for cell in game.board.penguins[current_player].into_iter() { + features[3 * NUM_CELLS + cell as usize] = 1.0; + } + + for cell in game.board.penguins[opponent].into_iter() { + features[4 * NUM_CELLS + cell as usize] = 1.0; + } + + for cell in game.board.claimed[current_player].into_iter() { + features[5 * NUM_CELLS + cell as usize] = 1.0; + } + + for cell in game.board.claimed[opponent].into_iter() { + features[6 * NUM_CELLS + cell as usize] = 1.0; + } + + if is_drafting { + for i in 0..NUM_CELLS { + features[7 * NUM_CELLS + i] = 1.0; + } + } + + features +} diff --git a/www/eslint.config.mjs b/www/eslint.config.mjs index 30d08c3..1bde10c 100644 --- a/www/eslint.config.mjs +++ b/www/eslint.config.mjs @@ -15,7 +15,7 @@ const __dirname = path.dirname(__filename); * @type {ReturnType<typeof defineConfig>} */ const config = defineConfig([ - globalIgnores(["dist/**/*"]), + globalIgnores(["dist", "dist/**/*", "**/dist/**/*"]), { languageOptions: { globals: { diff --git a/www/package.json b/www/package.json index 06e38de..919697e 100644 --- a/www/package.json +++ b/www/package.json @@ -31,6 +31,7 @@ "homepage": "https://jthemphill.github.io/htmf", "dependencies": { "htmf-wasm": "*", + "onnxruntime-web": "^1.22.0", "react": "19.2.4", "react-dom": "19.2.4" }, diff --git a/www/public/models/htmf-policy.json b/www/public/models/htmf-policy.json new file mode 100644 index 0000000..b3bf072 --- /dev/null +++ b/www/public/models/htmf-policy.json @@ -0,0 +1,29 @@ +{ + "model": "htmf-policy.onnx", + "format": "onnx", + "featureVersion": 1, + "policyVersion": 2, + "features": { + "shape": [1, 480], + "channels": [ + "one_fish", + "two_fish", + "three_fish", + "current_player_penguins", + "opponent_penguins", + "current_player_claimed", + "opponent_claimed", + "is_drafting" + ] + }, + "policy": { + "draftingSize": 60, + 
"movementSize": 2520, + "movementEncoding": "source_cell * 42 + direction * 7 + distance_minus_one" + }, + "serving": { + "runtime": "onnxruntime-web", + "executionProviders": ["webgpu", "wasm"], + "modelPriorWeight": 0.05 + } +} diff --git a/www/src/webworker/Bot.ts b/www/src/webworker/Bot.ts index 441c38a..be7e50c 100644 --- a/www/src/webworker/Bot.ts +++ b/www/src/webworker/Bot.ts @@ -15,6 +15,7 @@ import { type WorkerRequest, type WorkerResponse, } from "../browser/WorkerProtocol"; +import PolicyModel from "./PolicyModel"; function getGameState(game: wasm.Game): GameState { const fish = []; @@ -65,6 +66,8 @@ class Bot { forcedMove = false; ponderStartTime?: number; totalCompletedPonderTimeMs = 0; + policyModel: PolicyModel = new PolicyModel(); + applyingPolicyPriors = false; constructor( wasmInternals: wasm.InitOutput, @@ -72,6 +75,7 @@ class Bot { ) { this.wasmInternals = wasmInternals; this.postMessage = postMessage; + void this.policyModel.load(); this.ponder(); this.postGameState({}); } @@ -88,6 +92,7 @@ class Bot { } this.ponderStartTime = performance.now(); this.ponderer = self.setInterval(() => { + void this.refreshPolicyPriors(); const activePlayer = this.game.active_player(); if (activePlayer === BOT_PLAYER) { // We need to make a move soon @@ -145,6 +150,19 @@ class Bot { this.game.playout(); } + async refreshPolicyPriors(): Promise { + if (this.applyingPolicyPriors) { + return; + } + + this.applyingPolicyPriors = true; + try { + await this.policyModel.applyRootPriors(this.game); + } finally { + this.applyingPolicyPriors = false; + } + } + getState(): GameState { return getGameState(this.game); } diff --git a/www/src/webworker/PolicyModel.ts b/www/src/webworker/PolicyModel.ts new file mode 100644 index 0000000..30e435e --- /dev/null +++ b/www/src/webworker/PolicyModel.ts @@ -0,0 +1,62 @@ +import * as ort from "onnxruntime-web"; +import type * as wasm from "htmf-wasm"; + +const MODEL_URL = "/models/htmf-policy.onnx"; +const NUM_FEATURES = 8 * 60; + 
+type Backend = "webgpu" | "wasm"; + +export default class PolicyModel { + private session?: ort.InferenceSession; + private backend?: Backend; + private loadPromise?: Promise<void>; + + load(): Promise<void> { + this.loadPromise ??= this.loadWithFallback(); + return this.loadPromise; + } + + async applyRootPriors(game: wasm.Game): Promise<void> { + await this.load(); + if (this.session === undefined || game.active_player() === undefined) { + return; + } + + const features = game.features_for_active_player(); + if (features.length !== NUM_FEATURES) { + return; + } + + const input = new ort.Tensor("float32", Float32Array.from(features), [ + 1, + NUM_FEATURES, + ]); + const outputs = await this.session.run({ features: input }); + const outputName = game.is_drafting() ? "drafting_policy" : "movement_policy"; + const output = outputs[outputName]; + if (output === undefined || !(output.data instanceof Float32Array)) { + return; + } + + game.apply_policy_logits(output.data); + } + + getBackend(): Backend | undefined { + return this.backend; + } + + private async loadWithFallback(): Promise<void> { + for (const backend of ["webgpu", "wasm"] satisfies Backend[]) { + try { + this.session = await ort.InferenceSession.create(MODEL_URL, { + executionProviders: [backend], + graphOptimizationLevel: "all", + }); + this.backend = backend; + return; + } catch (err) { + void err; + } + } + } +}