Trajectory-balance normalization experiments#14
Open
josephdviviano wants to merge 34 commits into
Open
Conversation
- Fix Modified TB and VarGrad losses: normalize score by trajectory length *before* squaring (per-step average error) instead of after (was effectively O(T) bias toward long trajectories) - Fix traj_len computation: count actions via ~actions.is_dummy instead of counting non-sink states (semantically correct, same values) - Add replay buffer support: sample batch_size fresh trajectories, keep (1-frac) on-policy, replace the rest with buffer samples, compute single loss with recalculated log-probs - Add optimizer selection (adamw/adam/sgd), beta2 tuning, cosine LR schedule, loss clamping, algorithm filtering (--algos) - Add --output-dir flag for per-job result isolation - Add SLURM scripts: run_single.sh (one job), launch_sweep.sh (192-job grid over algos x envs x lr x beta2 x grad_clip) - Add aggregate_results.py to merge per-job CSVs, rank HP configs, and plot best-per-algo comparisons - Move tb.plan.md into .claude/, add IDEAS_TO_TRY.md for deferred stabilization ideas Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- _store_all_states_tensor() -> _enumerate_all_states_tensor() (upstream rename) - check_action_validity -> debug param in DiscreteEnv.__init__ - trajectories.conditioning -> trajectories.states.conditions[0] - EPS_REWARD_CMP 1e-12 -> 1e-6 (float32 vs float64 precision mismatch) - Disable validate_modes in _build_env (quick-check heuristics have false negatives) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Keep fixed loss formulations (per-step normalization) and API compat fixes; discard duplicate old class definitions from remote. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Use conda shell hook instead of source activate (args leak fix) - Use SLURM_SUBMIT_DIR for reliable repo path resolution - Reduce cpus-per-task from 4 to 1 - Fix _enumerate_all_states_tensor -> _store_all_states_tensor - Add missing _calculate_log_partition() call - Fix DiscreteEnv debug -> check_action_validity kwarg - Enable performance_mode in set_seed to skip deterministic algos Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- run_single.sh: Save/restore $@ around conda init to prevent arg leakage
into conda activate (fixed 90 jobs); set CUBLAS_WORKSPACE_CONFIG for
deterministic mode; switch partition from main to long
- tb_normalize.py: Fix torchgfn 2.3.1 API — trajectories.states.conditions
→ trajectories.conditioning (fixed all 48 ModifiedTBGFlowNet jobs)
- checkpoint.py: Add fsync + retry with exponential backoff to
_write_json_atomic for shared filesystem resilience (fixed 23 jobs)
- launch_sweep.sh: Fix job name collision — Modified{TB,LogPart} both
truncated to "Modifi"; now uses unique short tags (TB/ModTB/LogPV/ModLP)
- relaunch_failed.sh: Script to relaunch exactly the 79 incomplete combos
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Replace upstream gfn.utils.training.validate with local version that uses .sum() instead of .mean() for proper L1 distance between probability distributions (was off by factor of n_terminating_states) - Sample fresh from current policy instead of reusing biased training states (visited_terminating_states[-n:]) - Bump validation_samples from 20k to 100k for reliable L1 estimates across the 331k-state HyperGrid - Add results notebook for sweep analysis Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Expanded LR/schedule grid (4 cosine + 2 linear), added smoke_test.sh for local pre-flight validation, hardened aggregate_results.py against corrupt configs, and removed obsolete relaunch script and old results. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- checkpoint.py: strip all resume machinery (auto-discover, partial CSV, hydrate completed); RunState is now a simple path container; add write_completed_checkpoint() as one-shot completion marker; include algo_names + timestamp in build_effective_config to prevent hash collisions across algorithms - tb_normalize.py: collect all records in memory and write CSV once at end instead of incremental append; remove --resume-from CLI arg; fix set_seed() API (remove stale performance_mode kwarg); switch from cosine_schedule bool to lr_schedule enum (cosine/linear/none) with lr_end_factor; change normalization to divide after squaring to avoid 1/T² gradient attenuation; add --lr-logz-multiplier CLI option - aggregate_results.py: filter glob with timestamp regex to skip aggregated_results.csv; add _deduplicate() to detect and remove duplicate config rows; fix n_mode_states_found → n_modes_found - launch_sweep.sh: add lr_logz_multiplier sweep dimension - Remove tests/test_checkpointing.py (no more checkpointing) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- aggregate_results.py: add lr_logz, lr_logz_multiplier, lr_schedule to augmented HP columns; drop stale cosine_schedule - tb_normalize.py: persist lr_logz_multiplier in CONFIG so it appears in the config snapshot - results.ipynb: updated analysis notebook Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Fix n_modes_found → n_mode_states_found in aggregate_results.py - Update results.ipynb with latest analysis - Add .gitignore entries for artifacts (.DS_Store, sweep_results, etc.) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rastructure - Log unnormalized (base-equation) loss for modified algorithms alongside the normalized loss, enabling cross-algorithm loss comparison on the same scale after training. - Add JSD metric computation with chunked sampling to avoid OOM. - Add high-quality final validation (10M samples) for reliable comparison. - Add Optuna HP sweep infrastructure (search + confirmation phases). - Add parallelized SLURM launch scripts for confirmation runs. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
original,cosine,bitwise_xor,multiplicative_coprime), Optuna HP search (50 trials per algo×env), and confirmation runs (top-3 configs × 5 seeds × 4000 iterations).code/tolir/, addsconftest.py,environment.yaml, and a test suite for LIR gradient correctness, convergence, and pruning.What changed
lir/gflownet/:tb_normalize.py(training loop, 4 algorithms, CLI),hypergrid.py(ModifiedHyperGrid with 4 reward functions, mode analysis, GF(2) feasibility checks),checkpoint.py(atomic run bookkeeping).experiments/:optuna_sweep.py(two-phase Optuna driver), SLURM launch scripts,gflownet_results.ipynb(all figures and tables for the paper).lir/test/: Gradient sign verification, fixed-point convergence, attention-mask pruning tests.lir/lir__simpler.py: Core LIR training loop (moved fromcode/, expanded).WRITEUP.md: Synthesized experiment notes with key findings on gradient clipping and logZ learning rate behavior.Key results
multiplicative_coprime(20k modes), but similar onoriginal(256 modes).Test plan
pytestpasses (core LIR tests + checkpoint tests)python -m lir.gflownet.tb_normalize --envs original --algos TBGFlowNet --n_iterations 10 --n-seeds 1runs without errorexperiments/gflownet_results.ipynbrenders against existingoptuna_results/🤖 Generated with Claude Code