diff --git a/.gitignore b/.gitignore
index 7a535ad..ae44491 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,5 +35,6 @@ models/
# Temp
*.tmp
*.lock
+!uv.lock
worktrees/
dev/
diff --git a/README.md b/README.md
index 24ae498..de37e04 100644
--- a/README.md
+++ b/README.md
@@ -27,9 +27,9 @@ LocalPilot is different: it has a **visual web browsing agent** ([MolmoWeb-4B](h
You run it It does this, autonomously
───────── ──────────────────────────
uv run python experiments/run_enhanced_v3.py ─> 1. Reads train.py + past results
- 2. Qwen3.5-9B plans what to search
+ (or run_enhanced_v4.py for V4 WIP) 2. Qwen3.5-9B plans what to search
3. MolmoWeb-4B browses arXiv visually
- 4. Devstral-24B proposes a HP change, citing why
+ 4. Proposes a HP change, citing why
5. Edits train.py, trains locally
6. Keeps if val_bpb improves, reverts if not
7. Loops — gets smarter each iteration
@@ -37,7 +37,7 @@ LocalPilot is different: it has a **visual web browsing agent** ([MolmoWeb-4B](h
The key innovation is **step 3**: MolmoWeb-4B is a visual web agent that takes screenshots of web pages and interacts with them like a human would. It navigates to arXiv papers, scrolls through figures and tables, and extracts specific techniques — not just keyword matches from abstracts.
-All models run locally (Qwen3.5-9B for orchestration, MolmoWeb-4B for browsing, Devstral-24B for code). No API keys, no cloud bills.
+All models run locally (Qwen3.5-9B for orchestration, MolmoWeb-4B for browsing). No API keys, no cloud bills.
## Results
@@ -158,19 +158,22 @@ The runner scripts expect a local `.venv` created by `uv`. If `uv sync` (Step 1)
python -c "import sys; print(sys.executable)"
```
-> **Optional:** A Dockerfile is included for running training inside a Linux container (useful for Flash Attention 3 which requires Linux CUDA). Build with `docker build -t autoresearch-train .` if needed.
+> **Optional:** A Dockerfile is included for running training inside a Linux container (CUDA 13.0, useful for Flash Attention 3 on Hopper GPUs). Build with `docker build -t autoresearch-train .` if needed. Note: FA3 is Hopper-only (SM 9.0); Blackwell GPUs use FlexAttention until FA4 stabilizes.
### Step 6: Run it
```bash
-# Run the autonomous research agent (reads papers, proposes experiments)
+# V3 (stable) — paper-grounded search with visual browsing
uv run python experiments/run_enhanced_v3.py
-# Or run the random baseline for comparison (no LLMs needed)
+# V4 (WIP) — tiered pipeline with agent-grade resilience
+uv run python experiments/run_enhanced_v4.py
+
+# Random baseline for comparison (no LLMs needed)
uv run python experiments/run_baseline_v2.py
```
-The enhanced runner will pre-flight check that all models exist and print download commands if anything is missing.
+The enhanced runners will pre-flight check that all models exist and print download commands if anything is missing.
### Troubleshooting
@@ -184,7 +187,9 @@ The enhanced runner will pre-flight check that all models exist and print downlo
## How the research pipeline works
-### V3 (current)
+### V3 (stable — best validated result: 1.1507 BPB)
+
+V3 is the stable, validated pipeline. It achieved **1.1507 val_bpb** in 64 experiments using SDPA attention (cuDNN backend) with batch_size=32.
Qwen3.5-9B orchestrates the loop — it decides what to search, MolmoWeb-4B browses arXiv visually, and Devstral-24B writes the code patch:
@@ -236,6 +241,17 @@ V4 adds Semantic Scholar + arXiv API search with batch relevance scoring, plus a
This solves the rate-limiting problem — raw MolmoWeb browsing triggered CDN bans (~1500 HTTP requests per session). Tiered research cuts web requests by ~90%.
+### Attention backend status
+
+| Backend | GPU | Status |
+|---|---|---|
+| **SDPA (cuDNN)** | All NVIDIA | Working — V3 baseline (1.1507 BPB) |
+| **FlexAttention** | SM 8.0+ (PyTorch 2.5+) | Working — sliding window + GQA via compiled BlockMask |
+| **Flash Attention 3** | Hopper (SM 9.0) only | Working via `kernels` — not available on Blackwell |
+| **Flash Attention 4** | Blackwell (SM 12.0) | Beta (`flash-attn-4==4.0.0b7`) — JIT compilation crashes, waiting for stable release |
+
+`train.py` uses a 3-tier fallback: FA3 → FlexAttention → SDPA. On Blackwell GPUs, FlexAttention is the current best path until FA4 stabilizes.
+
## Adapting to your own project
LocalPilot isn't locked to karpathy's train.py. To use it on your own training script:
diff --git a/experiments/run_enhanced_v4.py b/experiments/run_enhanced_v4.py
index 7bb2116..b3f2874 100644
--- a/experiments/run_enhanced_v4.py
+++ b/experiments/run_enhanced_v4.py
@@ -671,11 +671,13 @@ def orchestrator_plan_search(hp_lines, results_history, tried_descs, best_bpb):
Current best val_bpb: {best_bpb:.4f} (lower is better). Starting was 1.379.
GPU: NVIDIA RTX 5090, 24 GB VRAM. Training budget: 5 min/experiment.
-Current hyperparameters in train.py:
+<hyperparameters>
{hp_numbered}
+</hyperparameters>
+<history>
{history}
+</history>
Your task: Decide what research direction to explore NEXT on arXiv/Scholar.
- Look at what has been tried and what worked (KEEP) vs failed (DISCARD/CRASH)
@@ -683,12 +685,22 @@ def orchestrator_plan_search(hp_lines, results_history, tried_descs, best_bpb):
- Focus on optimizer settings, learning rates, schedules, and architecture ratios
- We can only change existing hyperparameters (no new code, no new layers)
+<examples>
+SEARCH: optimal learning rate warmup schedule small transformer pretraining
+WHY: We haven't explored warmup ratio yet and papers suggest 1-5% warmup helps convergence
+
+SEARCH: Adam beta2 0.95 vs 0.999 language model training stability
+WHY: Lower beta2 showed promise in recent experiments, need evidence for optimal value
+
+</examples>
Output EXACTLY two lines:
SEARCH: <3-8 word search query>
WHY: <one sentence>"""
response = call_llm(prompt, max_tokens=200, temperature=0.7, stop=["\n\n\n"])
- m = re.search(r"SEARCH:\s*(.+)", response)
+ # Strip thinking tags if present
+ cleaned = _strip_thinking(response)
+ m = re.search(r"SEARCH:\s*(.+)", cleaned)
if m:
query = m.group(1).strip()
print(f" [Search query: {query}]")
@@ -930,11 +942,64 @@ def validate_proposal(param_name, proposed_value, results_history, tried_descs,
# Stage 4: Orchestrator proposes PARAM + VALUE (V4: open values)
# ---------------------------------------------------------------------------
+def _strip_thinking(response):
+    """Strip Qwen's <think>...</think> wrapper from response.
+    When enable_thinking=True, Qwen wraps reasoning in think tags.
+    The actual answer is after the closing </think> tag.
+    """
+    # Remove all <think>...</think> blocks (may span multiple lines)
+    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
+    return cleaned if cleaned else response
+
+
+def _parse_proposal(response):
+ """Parse PARAM/VALUE/REASON from LLM response with multiple fallback strategies.
+
+ Adapted from Claude Code's structured output validation — try strict format
+ first, then progressively looser patterns to maximize extraction rate.
+ Returns (param_name, raw_value, reason) or (None, None, None).
+ """
+ # Strip thinking tags first
+ text = _strip_thinking(response)
+
+ # Strategy 1: Exact format — PARAM: X / VALUE: Y / REASON: Z
+ param_m = re.search(r"PARAM:\s*(\w+)", text)
+ value_m = re.search(r"VALUE:\s*(.+)", text)
+ reason_m = re.search(r"REASON:\s*(.+)", text)
+ if param_m and value_m:
+ return (param_m.group(1).strip(),
+ value_m.group(1).strip(),
+ reason_m.group(1).strip() if reason_m else "no reason given")
+
+ # Strategy 2: Markdown-style — **PARAM**: X or `PARAM`: X
+ param_m = re.search(r"\*?\*?PARAM\*?\*?:\s*`?(\w+)`?", text)
+ value_m = re.search(r"\*?\*?VALUE\*?\*?:\s*`?(.+?)`?\s*$", text, re.M)
+ if param_m and value_m:
+ reason_m = re.search(r"\*?\*?REASON\*?\*?:\s*(.+)", text)
+ return (param_m.group(1).strip(),
+ value_m.group(1).strip(),
+ reason_m.group(1).strip() if reason_m else "no reason given")
+
+ # Strategy 3: Line-by-line — look for any line with "param_name = value" pattern
+ # This catches cases where the LLM outputs the actual Python assignment
+ for line in text.splitlines():
+ m = re.match(r"(\w+)\s*=\s*(.+)", line.strip())
+ if m and m.group(1) in ALL_PARAM_NAMES:
+ return (m.group(1), m.group(2).strip(), "inferred from direct assignment")
+
+ return (None, None, None)
+
+
def orchestrator_propose(hp_lines, paper_ideas, results_history,
tried_descs, best_bpb, n=PATCHES_PER_SESSION):
"""
Qwen3.5-9B reviews findings + history, proposes PARAM + exact VALUE.
Values are clamped to safe bounds after proposal.
+
+ Prompt design adapted from Claude Code's structured output patterns:
+ - Few-shot examples showing exact expected format
+ - XML-style section markers for clear context boundaries
+ - Explicit constraint listing to reduce hallucinated params
"""
start_llama_server(QWEN_MODEL, ctx_size=8192, label="Qwen3.5-9B orchestrator",
enable_thinking=True)
@@ -955,6 +1020,7 @@ def orchestrator_propose(hp_lines, paper_ideas, results_history,
proposals = []
hp_dict = get_current_hp_dict(hp_lines)
+ rejection_counts = {} # track why proposals fail for diagnostics
for attempt in range(n * 3):
if len(proposals) >= n:
@@ -966,56 +1032,83 @@ def orchestrator_propose(hp_lines, paper_ideas, results_history,
if not available:
available = [p for p in ALL_PARAM_NAMES if p not in batch_params]
- avail_str = "\n".join(f" - {name}" for name in available)
+ avail_str = ", ".join(available)
prompt = f"""You are an ML research orchestrator tuning a GPT language model.
-Current best val_bpb: {best_bpb:.4f} (lower is better). Starting was 1.379.
+
+Current best val_bpb: {best_bpb:.4f} (lower is better). Starting was 1.379.
GPU: NVIDIA RTX 5090, 24 GB VRAM. Training budget: 5 min/experiment.
-This is a small GPT (~124M params). Focus on optimizer, LR, and schedule tuning.
+Model: small GPT (~124M params, 8 layers, 512 embed dim).
+
-Current hyperparameters:
+<hyperparameters>
{hp_numbered}
+</hyperparameters>
+<history>
{param_summary}
+</history>
+<bounds>
{bounds_desc}
+</bounds>
-Parameters you can choose from (others are on cooldown):
+<available>
{avail_str}
+</available>
-Research findings:
-{paper_ideas[:2000]}
+<research>
+{paper_ideas[:2000] if paper_ideas else "No research findings available. Use your knowledge of transformer training best practices."}
+</research>
-Pick ONE parameter and propose an EXACT new value.
-- You can set ANY value within the bounds (not just +/- steps)
-- If research suggests a specific value, propose it directly
+Pick ONE parameter from <available> and propose an EXACT new value within <bounds>.
+Rules:
+- Pick from <available> ONLY (others are on cooldown)
+- Set ANY value within bounds (not just small steps)
- Prefer UNEXPLORED parameters
-- For ADAM_BETAS, format as (beta1, beta2)
-
-Output EXACTLY 3 lines:
-PARAM:
-VALUE:
-REASON:
-"""
+- For ADAM_BETAS use format (beta1, beta2)
+- For WINDOW_PATTERN use quotes like "SSSL"
+
+<examples>
+Example 1:
+PARAM: WEIGHT_DECAY
+VALUE: 0.15
+REASON: Research shows higher weight decay (0.1-0.2) improves generalization for small transformers
+
+Example 2:
+PARAM: ADAM_BETAS
+VALUE: (0.9, 0.95)
+REASON: Lower beta2 can help with training stability in small models per Wortsman et al.
+
+Example 3:
+PARAM: DEPTH
+VALUE: 12
+REASON: Deeper networks with same param count often achieve lower loss per Kaplan scaling laws
+
+</examples>
+Output EXACTLY 3 lines (no other text):
+PARAM: <parameter name>
+VALUE: <exact value>
+REASON: <one sentence>"""
try:
- response = call_llm(prompt, max_tokens=200, temperature=0.9,
+ response = call_llm(prompt, max_tokens=300, temperature=0.7,
stop=["\n\n\n"])
- param_m = re.search(r"PARAM:\s*(\w+)", response)
- value_m = re.search(r"VALUE:\s*(.+)", response)
- reason_m = re.search(r"REASON:\s*(.+)", response)
+ param_name, raw_value, reason = _parse_proposal(response)
- if not all([param_m, value_m]):
+ if param_name is None:
+ rejection_counts["parse_fail"] = rejection_counts.get("parse_fail", 0) + 1
+ print(f" [attempt {attempt+1}: parse fail — "
+ f"response: {_strip_thinking(response)[:80]}]")
continue
- param_name = param_m.group(1).strip()
- raw_value = value_m.group(1).strip()
- reason = reason_m.group(1).strip() if reason_m else "no reason given"
-
if param_name not in available:
+ rejection_counts["not_available"] = rejection_counts.get("not_available", 0) + 1
+ print(f" [attempt {attempt+1}: {param_name} not in available params]")
continue
if param_name in batch_params:
+ rejection_counts["already_in_batch"] = rejection_counts.get("already_in_batch", 0) + 1
continue
# Clamp to safe bounds
@@ -1024,33 +1117,41 @@ def orchestrator_propose(hp_lines, paper_ideas, results_history,
# Produce edit
edit = make_edit(param_name, safe_value, hp_lines)
if edit is None:
+ rejection_counts["no_edit"] = rejection_counts.get("no_edit", 0) + 1
+ print(f" [attempt {attempt+1}: {param_name}={safe_value} "
+ f"same as current or not found in train.py]")
continue
# OOM pre-flight check
test_hp = dict(hp_dict)
test_hp[param_name] = safe_value
if would_oom(test_hp):
+ rejection_counts["oom"] = rejection_counts.get("oom", 0) + 1
print(f" [BLOCKED OOM: {param_name}={safe_value}]")
continue
edit["reason"] = reason
# --- Post-proposal validation (adapted from Claude Code stop hooks) ---
- # Before accepting, do a quick sanity check: is this change
- # contradicted by our own experiment history?
rejection = validate_proposal(param_name, safe_value, results_history,
tried_descs, best_bpb)
if rejection:
+ rejection_counts["validation"] = rejection_counts.get("validation", 0) + 1
print(f" [REJECTED: {param_name}={safe_value} — {rejection}]")
continue
proposals.append(edit)
- print(f" proposal {len(proposals)}: {edit['desc']} | {reason[:50]}")
+ print(f" proposal {len(proposals)}: {edit['desc']} | {reason[:60]}")
except Exception as e:
+ rejection_counts["error"] = rejection_counts.get("error", 0) + 1
print(f" [orchestrator propose error: {e}]")
time.sleep(1)
+ # Diagnostic summary
+ if rejection_counts:
+ print(f" [Proposal diagnostics: {rejection_counts}]")
+
return proposals
diff --git a/pyproject.toml b/pyproject.toml
index b81df43..4de1241 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,7 @@ dependencies = [
"pandas>=2.3.3",
"pyarrow>=21.0.0",
"beautifulsoup4>=4.12.0",
+ "jinja2>=3.0.0",
"Pillow>=10.0.0",
"playwright>=1.40.0",
"requests>=2.32.0",
diff --git a/tests/test_v4_agent_patterns.py b/tests/test_v4_agent_patterns.py
index 4a70ac0..19f6609 100644
--- a/tests/test_v4_agent_patterns.py
+++ b/tests/test_v4_agent_patterns.py
@@ -215,6 +215,44 @@ def _categorize_error(error):
return "unknown"
+# --- _strip_thinking + _parse_proposal (copied from v4) ---
+def _strip_thinking(response):
+    """Strip Qwen's <think>...</think> wrapper from response."""
+    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL).strip()
+    return cleaned if cleaned else response
+
+
+def _parse_proposal(response):
+ """Parse PARAM/VALUE/REASON from LLM response with multiple fallback strategies."""
+ text = _strip_thinking(response)
+
+ # Strategy 1: Exact format
+ param_m = re.search(r"PARAM:\s*(\w+)", text)
+ value_m = re.search(r"VALUE:\s*(.+)", text)
+ reason_m = re.search(r"REASON:\s*(.+)", text)
+ if param_m and value_m:
+ return (param_m.group(1).strip(),
+ value_m.group(1).strip(),
+ reason_m.group(1).strip() if reason_m else "no reason given")
+
+ # Strategy 2: Markdown-style
+ param_m = re.search(r"\*?\*?PARAM\*?\*?:\s*`?(\w+)`?", text)
+ value_m = re.search(r"\*?\*?VALUE\*?\*?:\s*`?(.+?)`?\s*$", text, re.M)
+ if param_m and value_m:
+ reason_m = re.search(r"\*?\*?REASON\*?\*?:\s*(.+)", text)
+ return (param_m.group(1).strip(),
+ value_m.group(1).strip(),
+ reason_m.group(1).strip() if reason_m else "no reason given")
+
+ # Strategy 3: Direct assignment
+ for line in text.splitlines():
+ m = re.match(r"(\w+)\s*=\s*(.+)", line.strip())
+ if m and m.group(1) in ALL_PARAM_NAMES:
+ return (m.group(1), m.group(2).strip(), "inferred from direct assignment")
+
+ return (None, None, None)
+
+
# --- should_stop (copied from v4) ---
CONSEC_DISCARD_LIMIT = 15
NO_KEEP_WINDOW = 20
@@ -906,6 +944,139 @@ def test_full_pipeline_reject_duplicate():
assert "duplicate" in rejection
+# ---------------------------------------------------------------------------
+# 10. _strip_thinking tests
+# ---------------------------------------------------------------------------
+
+def test_strip_thinking_basic():
+ """Strip simple think tags."""
+    response = "<think>Let me analyze this carefully...</think>\nPARAM: DEPTH\nVALUE: 6"
+    assert _strip_thinking(response) == "PARAM: DEPTH\nVALUE: 6"
+
+
+def test_strip_thinking_multiline():
+ """Strip multiline think tags."""
+    response = "<think>\nI need to consider\nmultiple factors\nhere\n</think>\nPARAM: WEIGHT_DECAY\nVALUE: 0.15"
+    cleaned = _strip_thinking(response)
+    assert "<think>" not in cleaned
+    assert "PARAM: WEIGHT_DECAY" in cleaned
+
+
+def test_strip_thinking_no_tags():
+ """Pass through response without think tags."""
+ response = "PARAM: DEPTH\nVALUE: 6\nREASON: deeper model"
+ assert _strip_thinking(response) == response
+
+
+def test_strip_thinking_empty_content():
+ """When thinking is the entire response, return original."""
+    response = "<think>only thinking here</think>"
+    assert _strip_thinking(response) == response  # stripping leaves an empty string, so the original is returned
+
+
+def test_strip_thinking_multiple_blocks():
+ """Strip multiple think blocks."""
+    response = "<think>first thought</think>PARAM: DEPTH\n<think>second thought</think>VALUE: 8"
+    cleaned = _strip_thinking(response)
+    assert "<think>" not in cleaned
+    assert "PARAM: DEPTH" in cleaned
+    assert "VALUE: 8" in cleaned
+
+
+# ---------------------------------------------------------------------------
+# 11. _parse_proposal tests
+# ---------------------------------------------------------------------------
+
+def test_parse_exact_format():
+ """Parse standard PARAM/VALUE/REASON format."""
+ response = "PARAM: WEIGHT_DECAY\nVALUE: 0.15\nREASON: higher weight decay helps"
+ param, value, reason = _parse_proposal(response)
+ assert param == "WEIGHT_DECAY"
+ assert value == "0.15"
+ assert "higher weight decay" in reason
+
+
+def test_parse_with_thinking():
+ """Parse format wrapped in thinking tags."""
+    response = "<think>Let me think about this...\nI should try weight decay.\n</think>\nPARAM: WEIGHT_DECAY\nVALUE: 0.2\nREASON: regularization"
+ param, value, reason = _parse_proposal(response)
+ assert param == "WEIGHT_DECAY"
+ assert value == "0.2"
+
+
+def test_parse_markdown_format():
+ """Parse markdown-style bold formatting."""
+ response = "**PARAM**: `DEPTH`\n**VALUE**: `8`\n**REASON**: deeper model improves loss"
+ param, value, reason = _parse_proposal(response)
+ assert param == "DEPTH"
+ assert value == "8"
+
+
+def test_parse_direct_assignment():
+ """Parse Python-style assignment as fallback."""
+ response = "I think we should set:\nWEIGHT_DECAY = 0.2"
+ param, value, reason = _parse_proposal(response)
+ assert param == "WEIGHT_DECAY"
+ assert value == "0.2"
+
+
+def test_parse_no_value():
+ """Return None tuple when no parseable format found."""
+ response = "I'm not sure what to suggest."
+ param, value, reason = _parse_proposal(response)
+ assert param is None
+ assert value is None
+ assert reason is None
+
+
+def test_parse_adam_betas():
+ """Parse ADAM_BETAS tuple format."""
+ response = "PARAM: ADAM_BETAS\nVALUE: (0.9, 0.95)\nREASON: lower beta2 for stability"
+ param, value, reason = _parse_proposal(response)
+ assert param == "ADAM_BETAS"
+ assert "(0.9, 0.95)" in value
+
+
+def test_parse_window_pattern():
+ """Parse WINDOW_PATTERN with quotes."""
+ response = 'PARAM: WINDOW_PATTERN\nVALUE: "LLLL"\nREASON: full attention for small model'
+ param, value, reason = _parse_proposal(response)
+ assert param == "WINDOW_PATTERN"
+ assert "LLLL" in value
+
+
+def test_parse_missing_reason():
+ """Parse successfully even without REASON line."""
+ response = "PARAM: DEPTH\nVALUE: 10"
+ param, value, reason = _parse_proposal(response)
+ assert param == "DEPTH"
+ assert value == "10"
+ assert reason == "no reason given"
+
+
+def test_parse_extra_whitespace():
+ """Handle extra whitespace around values."""
+ response = "PARAM: MATRIX_LR \nVALUE: 0.06 \nREASON: midrange learning rate "
+ param, value, reason = _parse_proposal(response)
+ assert param == "MATRIX_LR"
+ assert value == "0.06"
+
+
+def test_parse_thinking_then_markdown():
+ """Parse thinking + markdown combo (common Qwen output)."""
+    response = "<think>\nThis model is small, so deeper depth might help.\nBut we should also consider width.\n</think>\n**PARAM**: ASPECT_RATIO\n**VALUE**: 80\n**REASON**: wider model trades depth for width"
+ param, value, reason = _parse_proposal(response)
+ assert param == "ASPECT_RATIO"
+ assert value == "80"
+
+
+def test_parse_direct_assignment_not_param():
+ """Direct assignment strategy only matches known params."""
+ response = "I think:\nFOO_BAR = 42\nSOMETHING_ELSE = 99"
+ param, value, reason = _parse_proposal(response)
+ assert param is None # neither FOO_BAR nor SOMETHING_ELSE is in ALL_PARAM_NAMES
+
+
# ===========================================================================
# Run all tests
# ===========================================================================
@@ -981,6 +1152,26 @@ def test_full_pipeline_reject_duplicate():
# Integration
("pipeline: full proposal flow", test_full_proposal_pipeline),
("pipeline: reject duplicate", test_full_pipeline_reject_duplicate),
+
+ # Strip thinking
+ ("thinking: basic strip", test_strip_thinking_basic),
+ ("thinking: multiline strip", test_strip_thinking_multiline),
+ ("thinking: no tags passthrough", test_strip_thinking_no_tags),
+ ("thinking: empty content", test_strip_thinking_empty_content),
+ ("thinking: multiple blocks", test_strip_thinking_multiple_blocks),
+
+ # Parse proposal
+ ("parse: exact format", test_parse_exact_format),
+ ("parse: with thinking tags", test_parse_with_thinking),
+ ("parse: markdown format", test_parse_markdown_format),
+ ("parse: direct assignment", test_parse_direct_assignment),
+ ("parse: no value returns None", test_parse_no_value),
+ ("parse: ADAM_BETAS tuple", test_parse_adam_betas),
+ ("parse: WINDOW_PATTERN", test_parse_window_pattern),
+ ("parse: missing reason", test_parse_missing_reason),
+ ("parse: extra whitespace", test_parse_extra_whitespace),
+ ("parse: thinking + markdown", test_parse_thinking_then_markdown),
+ ("parse: direct assignment non-param", test_parse_direct_assignment_not_param),
]
for name, fn in tests:
diff --git a/train.py b/train.py
index d64257b..e44904d 100644
--- a/train.py
+++ b/train.py
@@ -19,9 +19,41 @@
try:
from kernels import get_kernel
- fa3 = get_kernel('varunneal/flash-attention-3').flash_attn_interface
-except (ImportError, FileNotFoundError, OSError):
- fa3 = None # fallback to PyTorch SDPA (uses more VRAM)
+ cap = torch.cuda.get_device_capability()
+ # varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
+ repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
+ fa3 = get_kernel(repo).flash_attn_interface
+except Exception:
+ fa3 = None
+
+# FlexAttention (PyTorch 2.5+): efficient sliding window + GQA on any GPU
+try:
+ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+ _flex_attention = torch.compile(flex_attention)
+    _block_mask_cache = {}  # (window_size, seq_len, device) -> BlockMask
+
+ def _get_block_mask(window_size, seq_len, device):
+ """Get or create a cached BlockMask for the given window size."""
+ key = (window_size, seq_len, device)
+ if key not in _block_mask_cache:
+ if window_size <= 0 or window_size >= seq_len:
+ # Full causal attention
+ def causal_mask(b, h, q_idx, kv_idx):
+ return q_idx >= kv_idx
+ _block_mask_cache[key] = create_block_mask(
+ causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)
+ else:
+ # Sliding window causal attention
+ _ws = window_size # capture for closure
+ def sliding_window_causal(b, h, q_idx, kv_idx):
+ return (q_idx >= kv_idx) & (q_idx - kv_idx < _ws)
+ _block_mask_cache[key] = create_block_mask(
+ sliding_window_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)
+ return _block_mask_cache[key]
+
+ has_flex_attention = True
+except ImportError:
+ has_flex_attention = False
from localpilot.constants import MAX_SEQ_LEN, TIME_BUDGET
from prepare import Tokenizer, make_dataloader, evaluate_bpb
@@ -93,16 +125,27 @@ def forward(self, x, ve, cos_sin, window_size):
if fa3 is not None:
y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
- else:
- # Fallback: PyTorch SDPA (no sliding window, full causal attention)
+ elif has_flex_attention:
+ # FlexAttention: efficient sliding window + GQA on any GPU (PyTorch 2.5+)
q = q.transpose(1, 2) # (B, n_head, T, head_dim)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
+ ws = window_size[0] if isinstance(window_size, tuple) else window_size
+ block_mask = _get_block_mask(ws, T, q.device)
+ use_gqa = k.size(1) != q.size(1)
+ y = _flex_attention(q, k, v, block_mask=block_mask,
+ enable_gqa=use_gqa)
+ y = y.transpose(1, 2) # (B, T, n_head, head_dim)
+ else:
+ # Last resort: PyTorch SDPA (no sliding window, full causal attention)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
if k.size(1) != q.size(1):
k = k.repeat_interleave(q.size(1) // k.size(1), dim=1)
v = v.repeat_interleave(q.size(1) // v.size(1), dim=1)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
- y = y.transpose(1, 2) # (B, T, n_head, head_dim)
+ y = y.transpose(1, 2)
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
@@ -460,7 +503,7 @@ def step(self):
# Model size
DEPTH = 8 # number of transformer layers
-DEVICE_BATCH_SIZE = 64 if fa3 is None else 128 # halved for SDPA (O(n^2) vs O(n) memory)
+DEVICE_BATCH_SIZE = 128 if fa3 is not None else 64 # FlexAttention/SDPA: 64 (grad_accum=2 for 2^18 batch)
# ---------------------------------------------------------------------------
# Setup: tokenizer, model, optimizer, dataloader