diff --git a/.gitignore b/.gitignore index 7a535ad..ae44491 100644 --- a/.gitignore +++ b/.gitignore @@ -35,5 +35,6 @@ models/ # Temp *.tmp *.lock +!uv.lock worktrees/ dev/ diff --git a/README.md b/README.md index 24ae498..de37e04 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,9 @@ LocalPilot is different: it has a **visual web browsing agent** ([MolmoWeb-4B](h You run it It does this, autonomously ───────── ────────────────────────── uv run python experiments/run_enhanced_v3.py ─> 1. Reads train.py + past results - 2. Qwen3.5-9B plans what to search + (or run_enhanced_v4.py for V4 WIP) 2. Qwen3.5-9B plans what to search 3. MolmoWeb-4B browses arXiv visually - 4. Devstral-24B proposes a HP change, citing why + 4. Proposes a HP change, citing why 5. Edits train.py, trains locally 6. Keeps if val_bpb improves, reverts if not 7. Loops — gets smarter each iteration @@ -37,7 +37,7 @@ LocalPilot is different: it has a **visual web browsing agent** ([MolmoWeb-4B](h The key innovation is **step 3**: MolmoWeb-4B is a visual web agent that takes screenshots of web pages and interacts with them like a human would. It navigates to arXiv papers, scrolls through figures and tables, and extracts specific techniques — not just keyword matches from abstracts. -All models run locally (Qwen3.5-9B for orchestration, MolmoWeb-4B for browsing, Devstral-24B for code). No API keys, no cloud bills. +All models run locally (Qwen3.5-9B for orchestration, MolmoWeb-4B for browsing). No API keys, no cloud bills. ## Results @@ -158,19 +158,22 @@ The runner scripts expect a local `.venv` created by `uv`. If `uv sync` (Step 1) python -c "import sys; print(sys.executable)" ``` -> **Optional:** A Dockerfile is included for running training inside a Linux container (useful for Flash Attention 3 which requires Linux CUDA). Build with `docker build -t autoresearch-train .` if needed. +> **Optional:** A Dockerfile is included for running training inside a Linux container (CUDA 13.0, useful for Flash Attention 3 on Hopper GPUs). Build with `docker build -t autoresearch-train .` if needed. Note: FA3 is Hopper-only (SM 9.0); Blackwell GPUs use FlexAttention until FA4 stabilizes. ### Step 6: Run it ```bash -# Run the autonomous research agent (reads papers, proposes experiments) +# V3 (stable) — paper-grounded search with visual browsing uv run python experiments/run_enhanced_v3.py -# Or run the random baseline for comparison (no LLMs needed) +# V4 (WIP) — tiered pipeline with agent-grade resilience +uv run python experiments/run_enhanced_v4.py + +# Random baseline for comparison (no LLMs needed) uv run python experiments/run_baseline_v2.py ``` -The enhanced runner will pre-flight check that all models exist and print download commands if anything is missing. +The enhanced runners will pre-flight check that all models exist and print download commands if anything is missing. ### Troubleshooting @@ -184,7 +187,9 @@ The enhanced runner will pre-flight check that all models exist and print downlo ## How the research pipeline works -### V3 (current) +### V3 (stable — best validated result: 1.1507 BPB) + +V3 is the stable, validated pipeline. It achieved **1.1507 val_bpb** in 64 experiments using SDPA attention (cuDNN backend) with batch_size=32. Qwen3.5-9B orchestrates the loop — it decides what to search, MolmoWeb-4B browses arXiv visually, and Devstral-24B writes the code patch: @@ -236,6 +241,17 @@ V4 adds Semantic Scholar + arXiv API search with batch relevance scoring, plus a This solves the rate-limiting problem — raw MolmoWeb browsing triggered CDN bans (~1500 HTTP requests per session). Tiered research cuts web requests by ~90%. +### Attention backend status + +| Backend | GPU | Status | +|---|---|---| +| **SDPA (cuDNN)** | All NVIDIA | Working — V3 baseline (1.1507 BPB) | +| **FlexAttention** | SM 8.0+ (PyTorch 2.5+) | Working — sliding window + GQA via compiled BlockMask | +| **Flash Attention 3** | Hopper (SM 9.0) only | Working via `kernels` — not available on Blackwell | +| **Flash Attention 4** | Blackwell (SM 12.0) | Beta (`flash-attn-4==4.0.0b7`) — JIT compilation crashes, waiting for stable release | + +`train.py` uses a 3-tier fallback: FA3 → FlexAttention → SDPA. On Blackwell GPUs, FlexAttention is the current best path until FA4 stabilizes. + ## Adapting to your own project LocalPilot isn't locked to karpathy's train.py. To use it on your own training script: diff --git a/experiments/run_enhanced_v4.py b/experiments/run_enhanced_v4.py index 7bb2116..b3f2874 100644 --- a/experiments/run_enhanced_v4.py +++ b/experiments/run_enhanced_v4.py @@ -671,11 +671,13 @@ def orchestrator_plan_search(hp_lines, results_history, tried_descs, best_bpb): Current best val_bpb: {best_bpb:.4f} (lower is better). Starting was 1.379. GPU: NVIDIA RTX 5090, 24 GB VRAM. Training budget: 5 min/experiment. -Current hyperparameters in train.py: + {hp_numbered} + -Experiment history (most recent last): + {history} + Your task: Decide what research direction to explore NEXT on arXiv/Scholar. - Look at what has been tried and what worked (KEEP) vs failed (DISCARD/CRASH) @@ -683,12 +685,22 @@ def orchestrator_plan_search(hp_lines, results_history, tried_descs, best_bpb): - Focus on optimizer settings, learning rates, schedules, and architecture ratios - We can only change existing hyperparameters (no new code, no new layers) + +SEARCH: optimal learning rate warmup schedule small transformer pretraining +WHY: We haven't explored warmup ratio yet and papers suggest 1-5% warmup helps convergence + +SEARCH: Adam beta2 0.95 vs 0.999 language model training stability +WHY: Lower beta2 showed promise in recent experiments, need evidence for optimal value + + Output EXACTLY two lines: SEARCH: WHY: """ response = call_llm(prompt, max_tokens=200, temperature=0.7, stop=["\n\n\n"]) - m = re.search(r"SEARCH:\s*(.+)", response) + # Strip thinking tags if present + cleaned = _strip_thinking(response) + m = re.search(r"SEARCH:\s*(.+)", cleaned) if m: query = m.group(1).strip() print(f" [Search query: {query}]") @@ -930,11 +942,64 @@ def validate_proposal(param_name, proposed_value, results_history, tried_descs, # Stage 4: Orchestrator proposes PARAM + VALUE (V4: open values) # --------------------------------------------------------------------------- +def _strip_thinking(response): + """Strip Qwen's ... wrapper from response. + When enable_thinking=True, Qwen wraps reasoning in think tags. + The actual answer is after the closing tag. + """ + # Remove all ... blocks (may span multiple lines) + cleaned = re.sub(r'.*?', '', response, flags=re.DOTALL).strip() + return cleaned if cleaned else response + + +def _parse_proposal(response): + """Parse PARAM/VALUE/REASON from LLM response with multiple fallback strategies. + + Adapted from Claude Code's structured output validation — try strict format + first, then progressively looser patterns to maximize extraction rate. + Returns (param_name, raw_value, reason) or (None, None, None). + """ + # Strip thinking tags first + text = _strip_thinking(response) + + # Strategy 1: Exact format — PARAM: X / VALUE: Y / REASON: Z + param_m = re.search(r"PARAM:\s*(\w+)", text) + value_m = re.search(r"VALUE:\s*(.+)", text) + reason_m = re.search(r"REASON:\s*(.+)", text) + if param_m and value_m: + return (param_m.group(1).strip(), + value_m.group(1).strip(), + reason_m.group(1).strip() if reason_m else "no reason given") + + # Strategy 2: Markdown-style — **PARAM**: X or `PARAM`: X + param_m = re.search(r"\*?\*?PARAM\*?\*?:\s*`?(\w+)`?", text) + value_m = re.search(r"\*?\*?VALUE\*?\*?:\s*`?(.+?)`?\s*$", text, re.M) + if param_m and value_m: + reason_m = re.search(r"\*?\*?REASON\*?\*?:\s*(.+)", text) + return (param_m.group(1).strip(), + value_m.group(1).strip(), + reason_m.group(1).strip() if reason_m else "no reason given") + + # Strategy 3: Line-by-line — look for any line with "param_name = value" pattern + # This catches cases where the LLM outputs the actual Python assignment + for line in text.splitlines(): + m = re.match(r"(\w+)\s*=\s*(.+)", line.strip()) + if m and m.group(1) in ALL_PARAM_NAMES: + return (m.group(1), m.group(2).strip(), "inferred from direct assignment") + + return (None, None, None) + + def orchestrator_propose(hp_lines, paper_ideas, results_history, tried_descs, best_bpb, n=PATCHES_PER_SESSION): """ Qwen3.5-9B reviews findings + history, proposes PARAM + exact VALUE. Values are clamped to safe bounds after proposal. + + Prompt design adapted from Claude Code's structured output patterns: + - Few-shot examples showing exact expected format + - XML-style section markers for clear context boundaries + - Explicit constraint listing to reduce hallucinated params """ start_llama_server(QWEN_MODEL, ctx_size=8192, label="Qwen3.5-9B orchestrator", enable_thinking=True) @@ -955,6 +1020,7 @@ def orchestrator_propose(hp_lines, paper_ideas, results_history, proposals = [] hp_dict = get_current_hp_dict(hp_lines) + rejection_counts = {} # track why proposals fail for diagnostics for attempt in range(n * 3): if len(proposals) >= n: @@ -966,56 +1032,83 @@ def orchestrator_propose(hp_lines, paper_ideas, results_history, if not available: available = [p for p in ALL_PARAM_NAMES if p not in batch_params] - avail_str = "\n".join(f" - {name}" for name in available) + avail_str = ", ".join(available) prompt = f"""You are an ML research orchestrator tuning a GPT language model. -Current best val_bpb: {best_bpb:.4f} (lower is better). Starting was 1.379. + +Current best val_bpb: {best_bpb:.4f} (lower is better). Starting was 1.379. GPU: NVIDIA RTX 5090, 24 GB VRAM. Training budget: 5 min/experiment. -This is a small GPT (~124M params). Focus on optimizer, LR, and schedule tuning. +Model: small GPT (~124M params, 8 layers, 512 embed dim). + -Current hyperparameters: + {hp_numbered} + + {param_summary} + + {bounds_desc} + -Parameters you can choose from (others are on cooldown): + {avail_str} + -Research findings: -{paper_ideas[:2000]} + +{paper_ideas[:2000] if paper_ideas else "No research findings available. Use your knowledge of transformer training best practices."} + -Pick ONE parameter and propose an EXACT new value. -- You can set ANY value within the bounds (not just +/- steps) -- If research suggests a specific value, propose it directly +Pick ONE parameter from and propose an EXACT new value within . +Rules: +- Pick from ONLY (others are on cooldown) +- Set ANY value within bounds (not just small steps) - Prefer UNEXPLORED parameters -- For ADAM_BETAS, format as (beta1, beta2) - -Output EXACTLY 3 lines: -PARAM: -VALUE: -REASON: -""" +- For ADAM_BETAS use format (beta1, beta2) +- For WINDOW_PATTERN use quotes like "SSSL" + + +Example 1: +PARAM: WEIGHT_DECAY +VALUE: 0.15 +REASON: Research shows higher weight decay (0.1-0.2) improves generalization for small transformers + +Example 2: +PARAM: ADAM_BETAS +VALUE: (0.9, 0.95) +REASON: Lower beta2 can help with training stability in small models per Wortsman et al. + +Example 3: +PARAM: DEPTH +VALUE: 12 +REASON: Deeper networks with same param count often achieve lower loss per Kaplan scaling laws + + +Output EXACTLY 3 lines (no other text): +PARAM: +VALUE: +REASON: """ try: - response = call_llm(prompt, max_tokens=200, temperature=0.9, + response = call_llm(prompt, max_tokens=300, temperature=0.7, stop=["\n\n\n"]) - param_m = re.search(r"PARAM:\s*(\w+)", response) - value_m = re.search(r"VALUE:\s*(.+)", response) - reason_m = re.search(r"REASON:\s*(.+)", response) + param_name, raw_value, reason = _parse_proposal(response) - if not all([param_m, value_m]): + if param_name is None: + rejection_counts["parse_fail"] = rejection_counts.get("parse_fail", 0) + 1 + print(f" [attempt {attempt+1}: parse fail — " + f"response: {_strip_thinking(response)[:80]}]") continue - param_name = param_m.group(1).strip() - raw_value = value_m.group(1).strip() - reason = reason_m.group(1).strip() if reason_m else "no reason given" - if param_name not in available: + rejection_counts["not_available"] = rejection_counts.get("not_available", 0) + 1 + print(f" [attempt {attempt+1}: {param_name} not in available params]") continue if param_name in batch_params: + rejection_counts["already_in_batch"] = rejection_counts.get("already_in_batch", 0) + 1 continue # Clamp to safe bounds @@ -1024,33 +1117,41 @@ def orchestrator_propose(hp_lines, paper_ideas, results_history, # Produce edit edit = make_edit(param_name, safe_value, hp_lines) if edit is None: + rejection_counts["no_edit"] = rejection_counts.get("no_edit", 0) + 1 + print(f" [attempt {attempt+1}: {param_name}={safe_value} " + f"same as current or not found in train.py]") continue # OOM pre-flight check test_hp = dict(hp_dict) test_hp[param_name] = safe_value if would_oom(test_hp): + rejection_counts["oom"] = rejection_counts.get("oom", 0) + 1 print(f" [BLOCKED OOM: {param_name}={safe_value}]") continue edit["reason"] = reason # --- Post-proposal validation (adapted from Claude Code stop hooks) --- - # Before accepting, do a quick sanity check: is this change - # contradicted by our own experiment history? rejection = validate_proposal(param_name, safe_value, results_history, tried_descs, best_bpb) if rejection: + rejection_counts["validation"] = rejection_counts.get("validation", 0) + 1 print(f" [REJECTED: {param_name}={safe_value} — {rejection}]") continue proposals.append(edit) - print(f" proposal {len(proposals)}: {edit['desc']} | {reason[:50]}") + print(f" proposal {len(proposals)}: {edit['desc']} | {reason[:60]}") except Exception as e: + rejection_counts["error"] = rejection_counts.get("error", 0) + 1 print(f" [orchestrator propose error: {e}]") time.sleep(1) + # Diagnostic summary + if rejection_counts: + print(f" [Proposal diagnostics: {rejection_counts}]") + return proposals diff --git a/pyproject.toml b/pyproject.toml index b81df43..4de1241 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "pandas>=2.3.3", "pyarrow>=21.0.0", "beautifulsoup4>=4.12.0", + "jinja2>=3.0.0", "Pillow>=10.0.0", "playwright>=1.40.0", "requests>=2.32.0", diff --git a/tests/test_v4_agent_patterns.py b/tests/test_v4_agent_patterns.py index 4a70ac0..19f6609 100644 --- a/tests/test_v4_agent_patterns.py +++ b/tests/test_v4_agent_patterns.py @@ -215,6 +215,44 @@ def _categorize_error(error): return "unknown" +# --- _strip_thinking + _parse_proposal (copied from v4) --- +def _strip_thinking(response): + """Strip Qwen's ... wrapper from response.""" + cleaned = re.sub(r'.*?', '', response, flags=re.DOTALL).strip() + return cleaned if cleaned else response + + +def _parse_proposal(response): + """Parse PARAM/VALUE/REASON from LLM response with multiple fallback strategies.""" + text = _strip_thinking(response) + + # Strategy 1: Exact format + param_m = re.search(r"PARAM:\s*(\w+)", text) + value_m = re.search(r"VALUE:\s*(.+)", text) + reason_m = re.search(r"REASON:\s*(.+)", text) + if param_m and value_m: + return (param_m.group(1).strip(), + value_m.group(1).strip(), + reason_m.group(1).strip() if reason_m else "no reason given") + + # Strategy 2: Markdown-style + param_m = re.search(r"\*?\*?PARAM\*?\*?:\s*`?(\w+)`?", text) + value_m = re.search(r"\*?\*?VALUE\*?\*?:\s*`?(.+?)`?\s*$", text, re.M) + if param_m and value_m: + reason_m = re.search(r"\*?\*?REASON\*?\*?:\s*(.+)", text) + return (param_m.group(1).strip(), + value_m.group(1).strip(), + reason_m.group(1).strip() if reason_m else "no reason given") + + # Strategy 3: Direct assignment + for line in text.splitlines(): + m = re.match(r"(\w+)\s*=\s*(.+)", line.strip()) + if m and m.group(1) in ALL_PARAM_NAMES: + return (m.group(1), m.group(2).strip(), "inferred from direct assignment") + + return (None, None, None) + + # --- should_stop (copied from v4) --- CONSEC_DISCARD_LIMIT = 15 NO_KEEP_WINDOW = 20 @@ -906,6 +944,139 @@ def test_full_pipeline_reject_duplicate(): assert "duplicate" in rejection +# --------------------------------------------------------------------------- +# 10. _strip_thinking tests +# --------------------------------------------------------------------------- + +def test_strip_thinking_basic(): + """Strip simple think tags.""" + response = "Let me analyze this carefully...\nPARAM: DEPTH\nVALUE: 6" + assert _strip_thinking(response) == "PARAM: DEPTH\nVALUE: 6" + + +def test_strip_thinking_multiline(): + """Strip multiline think tags.""" + response = "\nI need to consider\nmultiple factors\nhere\n\nPARAM: WEIGHT_DECAY\nVALUE: 0.15" + cleaned = _strip_thinking(response) + assert "" not in cleaned + assert "PARAM: WEIGHT_DECAY" in cleaned + + +def test_strip_thinking_no_tags(): + """Pass through response without think tags.""" + response = "PARAM: DEPTH\nVALUE: 6\nREASON: deeper model" + assert _strip_thinking(response) == response + + +def test_strip_thinking_empty_content(): + """When thinking is the entire response, return original.""" + response = "only thinking here" + assert _strip_thinking(response) == response # returns original when cleaned is empty? No, cleaned is empty, so returns response + + +def test_strip_thinking_multiple_blocks(): + """Strip multiple think blocks.""" + response = "first thoughtPARAM: DEPTH\nsecond thoughtVALUE: 8" + cleaned = _strip_thinking(response) + assert "" not in cleaned + assert "PARAM: DEPTH" in cleaned + assert "VALUE: 8" in cleaned + + +# --------------------------------------------------------------------------- +# 11. _parse_proposal tests +# --------------------------------------------------------------------------- + +def test_parse_exact_format(): + """Parse standard PARAM/VALUE/REASON format.""" + response = "PARAM: WEIGHT_DECAY\nVALUE: 0.15\nREASON: higher weight decay helps" + param, value, reason = _parse_proposal(response) + assert param == "WEIGHT_DECAY" + assert value == "0.15" + assert "higher weight decay" in reason + + +def test_parse_with_thinking(): + """Parse format wrapped in thinking tags.""" + response = "Let me think about this...\nI should try weight decay.\n\nPARAM: WEIGHT_DECAY\nVALUE: 0.2\nREASON: regularization" + param, value, reason = _parse_proposal(response) + assert param == "WEIGHT_DECAY" + assert value == "0.2" + + +def test_parse_markdown_format(): + """Parse markdown-style bold formatting.""" + response = "**PARAM**: `DEPTH`\n**VALUE**: `8`\n**REASON**: deeper model improves loss" + param, value, reason = _parse_proposal(response) + assert param == "DEPTH" + assert value == "8" + + +def test_parse_direct_assignment(): + """Parse Python-style assignment as fallback.""" + response = "I think we should set:\nWEIGHT_DECAY = 0.2" + param, value, reason = _parse_proposal(response) + assert param == "WEIGHT_DECAY" + assert value == "0.2" + + +def test_parse_no_value(): + """Return None tuple when no parseable format found.""" + response = "I'm not sure what to suggest." + param, value, reason = _parse_proposal(response) + assert param is None + assert value is None + assert reason is None + + +def test_parse_adam_betas(): + """Parse ADAM_BETAS tuple format.""" + response = "PARAM: ADAM_BETAS\nVALUE: (0.9, 0.95)\nREASON: lower beta2 for stability" + param, value, reason = _parse_proposal(response) + assert param == "ADAM_BETAS" + assert "(0.9, 0.95)" in value + + +def test_parse_window_pattern(): + """Parse WINDOW_PATTERN with quotes.""" + response = 'PARAM: WINDOW_PATTERN\nVALUE: "LLLL"\nREASON: full attention for small model' + param, value, reason = _parse_proposal(response) + assert param == "WINDOW_PATTERN" + assert "LLLL" in value + + +def test_parse_missing_reason(): + """Parse successfully even without REASON line.""" + response = "PARAM: DEPTH\nVALUE: 10" + param, value, reason = _parse_proposal(response) + assert param == "DEPTH" + assert value == "10" + assert reason == "no reason given" + + +def test_parse_extra_whitespace(): + """Handle extra whitespace around values.""" + response = "PARAM: MATRIX_LR \nVALUE: 0.06 \nREASON: midrange learning rate " + param, value, reason = _parse_proposal(response) + assert param == "MATRIX_LR" + assert value == "0.06" + + +def test_parse_thinking_then_markdown(): + """Parse thinking + markdown combo (common Qwen output).""" + response = "\nThis model is small, so deeper depth might help.\nBut we should also consider width.\n\n**PARAM**: ASPECT_RATIO\n**VALUE**: 80\n**REASON**: wider model trades depth for width" + param, value, reason = _parse_proposal(response) + assert param == "ASPECT_RATIO" + assert value == "80" + + +def test_parse_direct_assignment_not_param(): + """Direct assignment strategy only matches known params.""" + response = "I think:\nFOO_BAR = 42\nSOMETHING_ELSE = 99" + param, value, reason = _parse_proposal(response) + assert param is None # neither FOO_BAR nor SOMETHING_ELSE is in ALL_PARAM_NAMES + + # =========================================================================== # Run all tests # =========================================================================== @@ -981,6 +1152,26 @@ def test_full_pipeline_reject_duplicate(): # Integration ("pipeline: full proposal flow", test_full_proposal_pipeline), ("pipeline: reject duplicate", test_full_pipeline_reject_duplicate), + + # Strip thinking + ("thinking: basic strip", test_strip_thinking_basic), + ("thinking: multiline strip", test_strip_thinking_multiline), + ("thinking: no tags passthrough", test_strip_thinking_no_tags), + ("thinking: empty content", test_strip_thinking_empty_content), + ("thinking: multiple blocks", test_strip_thinking_multiple_blocks), + + # Parse proposal + ("parse: exact format", test_parse_exact_format), + ("parse: with thinking tags", test_parse_with_thinking), + ("parse: markdown format", test_parse_markdown_format), + ("parse: direct assignment", test_parse_direct_assignment), + ("parse: no value returns None", test_parse_no_value), + ("parse: ADAM_BETAS tuple", test_parse_adam_betas), + ("parse: WINDOW_PATTERN", test_parse_window_pattern), + ("parse: missing reason", test_parse_missing_reason), + ("parse: extra whitespace", test_parse_extra_whitespace), + ("parse: thinking + markdown", test_parse_thinking_then_markdown), + ("parse: direct assignment non-param", test_parse_direct_assignment_not_param), ] for name, fn in tests: diff --git a/train.py b/train.py index d64257b..e44904d 100644 --- a/train.py +++ b/train.py @@ -19,9 +19,41 @@ try: from kernels import get_kernel - fa3 = get_kernel('varunneal/flash-attention-3').flash_attn_interface -except (ImportError, FileNotFoundError, OSError): - fa3 = None # fallback to PyTorch SDPA (uses more VRAM) + cap = torch.cuda.get_device_capability() + # varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface +except Exception: + fa3 = None + +# FlexAttention (PyTorch 2.5+): efficient sliding window + GQA on any GPU +try: + from torch.nn.attention.flex_attention import flex_attention, create_block_mask + _flex_attention = torch.compile(flex_attention) + _block_mask_cache = {} # (window_size, seq_len) -> BlockMask + + def _get_block_mask(window_size, seq_len, device): + """Get or create a cached BlockMask for the given window size.""" + key = (window_size, seq_len, device) + if key not in _block_mask_cache: + if window_size <= 0 or window_size >= seq_len: + # Full causal attention + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + _block_mask_cache[key] = create_block_mask( + causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device) + else: + # Sliding window causal attention + _ws = window_size # capture for closure + def sliding_window_causal(b, h, q_idx, kv_idx): + return (q_idx >= kv_idx) & (q_idx - kv_idx < _ws) + _block_mask_cache[key] = create_block_mask( + sliding_window_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device) + return _block_mask_cache[key] + + has_flex_attention = True +except ImportError: + has_flex_attention = False from localpilot.constants import MAX_SEQ_LEN, TIME_BUDGET from prepare import Tokenizer, make_dataloader, evaluate_bpb @@ -93,16 +125,27 @@ def forward(self, x, ve, cos_sin, window_size): if fa3 is not None: y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size) - else: - # Fallback: PyTorch SDPA (no sliding window, full causal attention) + elif has_flex_attention: + # FlexAttention: efficient sliding window + GQA on any GPU (PyTorch 2.5+) q = q.transpose(1, 2) # (B, n_head, T, head_dim) k = k.transpose(1, 2) v = v.transpose(1, 2) + ws = window_size[0] if isinstance(window_size, tuple) else window_size + block_mask = _get_block_mask(ws, T, q.device) + use_gqa = k.size(1) != q.size(1) + y = _flex_attention(q, k, v, block_mask=block_mask, + enable_gqa=use_gqa) + y = y.transpose(1, 2) # (B, T, n_head, head_dim) + else: + # Last resort: PyTorch SDPA (no sliding window, full causal attention) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) if k.size(1) != q.size(1): k = k.repeat_interleave(q.size(1) // k.size(1), dim=1) v = v.repeat_interleave(q.size(1) // v.size(1), dim=1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) - y = y.transpose(1, 2) # (B, T, n_head, head_dim) + y = y.transpose(1, 2) y = y.contiguous().view(B, T, -1) y = self.c_proj(y) return y @@ -460,7 +503,7 @@ def step(self): # Model size DEPTH = 8 # number of transformer layers -DEVICE_BATCH_SIZE = 64 if fa3 is None else 128 # halved for SDPA (O(n^2) vs O(n) memory) +DEVICE_BATCH_SIZE = 128 if fa3 is not None else 64 # FlexAttention/SDPA: 64 (grad_accum=2 for 2^18 batch) # --------------------------------------------------------------------------- # Setup: tokenizer, model, optimizer, dataloader