diff --git a/examples/scripts/benchmark/benchmark_config.py b/examples/scripts/benchmark/benchmark_config.py new file mode 100644 index 0000000..a85fa7c --- /dev/null +++ b/examples/scripts/benchmark/benchmark_config.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +# --------------------------------------------------------------------------- +# Trie benchmarks (benchmark_fused_sample, benchmark_vtnk) +# --------------------------------------------------------------------------- + +@dataclass +class RunConfig: + """Parameters for a single trie benchmark run.""" + B: int + N: int + k: int + k_top: int + sparsity: float + + +@dataclass +class BenchmarkSuite: + """Collection of trie benchmark runs with shared timing / mode settings.""" + runs: list[RunConfig] + warmup: int = 25 + rep: int = 100 + diverse_nodes: bool = False + + +def trie_runs( + B_vals: list[int], + N_vals: list[int], + k: int, + k_top: int, + sparsity: float, +) -> list[RunConfig]: + """Cross-product of B_vals × N_vals, all sharing the same sparsity.""" + return [ + RunConfig(B=B, N=N, k=k, k_top=k_top, sparsity=sparsity) + for B in B_vals + for N in N_vals + ] + + +# --------------------------------------------------------------------------- +# Quantize benchmarks (benchmark_nn_quantize) +# --------------------------------------------------------------------------- + +@dataclass +class QuantizeRunConfig: + """Parameters for a single nearest-neighbour quantize benchmark run.""" + B: int + D: int + N: int # codebook size + + +@dataclass +class QuantizeBenchmarkSuite: + """Collection of quantize benchmark runs with shared timing settings.""" + runs: list[QuantizeRunConfig] + warmup: int = 25 + rep: int = 100 + faiss_warmup: int = 100 + + +# --------------------------------------------------------------------------- +# Update-speed benchmarks (benchmark_update_speed) +# --------------------------------------------------------------------------- + +@dataclass +class UpdateRunConfig: + """Parameters for a single trie-update benchmark run.""" + initial_n: int + update_n: int + + +@dataclass +class UpdateBenchmarkSuite: + """Collection of update benchmark runs with shared timing / vocab settings.""" + runs: list[UpdateRunConfig] + warmup: int = 5 + rep: int = 20 + vocab_size: int = 256 + seq_len: int = 4 + growing_rounds: int = 30 + + +# --------------------------------------------------------------------------- +# Default suites +# --------------------------------------------------------------------------- + +FUSED_SAMPLE_SUITE_N256 = BenchmarkSuite( + runs=trie_runs(B_vals=[256, 1024], N_vals=[256], k=1024, k_top=50, sparsity=0.9), +) + +FUSED_SAMPLE_SUITE_N150K = BenchmarkSuite( + runs=trie_runs(B_vals=[256, 1024], N_vals=[150_000], k=1024, k_top=50, sparsity=0.01), +) + +VTNK_SUITE = BenchmarkSuite( + runs=trie_runs( + B_vals=[256, 1024], + N_vals=[150_000], + k=512, + k_top=50, + sparsity=0.01, + ), +) + +QUANTIZE_SUITE = QuantizeBenchmarkSuite( + runs=[ + QuantizeRunConfig(B=B, D=D, N=N) + for N in [64, 128, 256, 512] + for B in [32, 256, 1024, 16384, 32768, 65536] + for D in [64, 128, 256] + ], +) + +UPDATE_SUITE = UpdateBenchmarkSuite( + runs=[ + UpdateRunConfig(initial_n=initial_n, update_n=update_n) + for initial_n in [10_000, 100_000, 100_000] + for update_n in [100, 1_000, 10_000] + ], +) diff --git a/examples/scripts/benchmark/benchmark_fused_sample.py b/examples/scripts/benchmark/benchmark_fused_sample.py index 60eec73..dc776e5 100644 --- a/examples/scripts/benchmark/benchmark_fused_sample.py +++ b/examples/scripts/benchmark/benchmark_fused_sample.py @@ -5,11 +5,12 @@ fused_linear_constrained_node_transition_topk (Triton, single kernel) vs torch.compile(sparse_linear_pytorch) + separate torch.topk -Grid: B (batch size) × N (vocab / logits size). K (hidden dim) fixed. +Grid: B (batch size) × N (vocab / logits size). K (hidden dim) fixed per run. """ import argparse import os +from typing import Any, cast os.environ["TRITON_PRINT_AUTOTUNING"] = "1" @@ -21,31 +22,49 @@ import pandas as pd from rectokens.schemas.compact_csr_trie import CompactCSRTrie +from rectokens.schemas.compact_ell_trie import CompactELLTrie from rectokens.schemas.state import ConstraintState -from rectokens.decoding.vntk import sparse_linear_pytorch, sparse_linear_compact_pytorch +from rectokens.decoding.vntk import ( + sparse_linear_pytorch, + sparse_linear_compact_pytorch, + sparse_linear_ell_pytorch, + sparse_linear_compact_ell_pytorch, +) from rectokens.ops.constrained_node_transition import ( + constrained_node_transition, fused_linear_constrained_node_transition_sampling, fused_linear_constrained_node_transition_topk, ) +from rectokens.kernels.constrained_node_transition_ell import ( + _ell_fused_linear_constrained_node_transition_sampling_op as ell_sampling_op, + _ell_fused_linear_constrained_node_transition_topk_op as ell_topk_op, +) +from benchmark_config import FUSED_SAMPLE_SUITE_N256, FUSED_SAMPLE_SUITE_N150K, BenchmarkSuite, RunConfig DEVICE = torch.device("cuda") -K = 1024 -K_TOP = 50 -WARMUP = 25 -REP = 100 ALL_ALGORITHMS = [ "fused_sample", + "ell_sample", "sparse_pytorch_sample", + "sparse_pytorch_sample_compact", + "sparse_pytorch_ell_sample", "fused_topk", + "ell_topk", "sparse_pytorch_topk", "sparse_pytorch_topk_compact", + "sparse_pytorch_ell_topk_compact", + "dense_topk", ] DEFAULT_ALGORITHMS = [ + "fused_topk", + #"ell_topk", + "sparse_pytorch_topk_compact", + #"sparse_pytorch_ell_topk_compact", "fused_sample", - "sparse_pytorch_sample", + "sparse_pytorch_sample_compact", + "dense_topk", ] -DEFAULT_SPARSITY = 0.01 def lex_sort(rows: list[list[int]]) -> torch.Tensor: @@ -85,139 +104,274 @@ def make_csr_diverse( dense_mask_by_layer=[v.to(DEVICE) for v in csr.dense_mask_by_layer], dense_states=csr.dense_states.to(DEVICE), ) - # Level-1 BFS node IDs are 1..num_nodes (root is 0, its children follow in BFS order) cur_node = (torch.arange(B, dtype=torch.long) % num_nodes + 1).to(DEVICE) return csr, cur_node -def run_bench(fn): - return testing.do_bench(fn, warmup=WARMUP, rep=REP) +def make_ell(csr: CompactCSRTrie) -> CompactELLTrie: + return CompactELLTrie.from_csr(csr) + + +def _dense_matmul_inner(a: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + return (a @ weight.T).float() -def benchmark_grid(B_vals, N_vals, algorithms, sparsity, k_top, diverse_nodes=False): +_dense_matmul_compiled = torch.compile(_dense_matmul_inner) + + +def run_bench(fn, suite: BenchmarkSuite) -> float: + return cast(float, testing.do_bench(fn, warmup=suite.warmup, rep=suite.rep)) + + +def benchmark_run(run: RunConfig, suite: BenchmarkSuite, algorithms: list[str]) -> dict: alg_set = set(algorithms) - records = [] - - for B in B_vals: - for N in N_vals: - max_branches = max(1, int(N * sparsity)) - k = min(k_top, max_branches) - print(f" B={B:6d} N={N:6d} max_branches={max_branches} k={k}") - - if diverse_nodes: - csr, cur_node = make_csr_diverse( - vocab_size=N, max_branches=max_branches, B=B - ) - step = 1 - else: - csr = make_csr(vocab_size=N, max_branches=max_branches) - cur_node = torch.zeros(B, dtype=torch.long, device=DEVICE) - step = 0 - - a = torch.randn(B, K, device=DEVICE, dtype=torch.bfloat16) - weight = torch.randn(N, K, device=DEVICE, dtype=torch.bfloat16) - - cs = ConstraintState(step=step, trie=csr, cur_node=cur_node) - - needs_sparse_full = alg_set & {"sparse_pytorch_sample", "sparse_pytorch_topk"} - needs_sparse_compact = "sparse_pytorch_topk_compact" in alg_set - if needs_sparse_full: - sparse_linear_pytorch_compiled = torch.compile(sparse_linear_pytorch) - if needs_sparse_compact: - - def _sparse_compact_topk(a, weight, cur_node, csr, step, k): - nn, vi, branch_logits = sparse_linear_compact_pytorch( - a, weight, cur_node, csr, step - ) - topk_logits, topk_branch_idxs = torch.topk(branch_logits, k, dim=-1) - topk_idxs = vi.gather(1, topk_branch_idxs) - return nn, vi, topk_logits, topk_idxs - - sparse_compact_topk_compiled = torch.compile(_sparse_compact_topk) - - if "sparse_pytorch_sample" in alg_set: - - def sparse_pytorch_with_sample(): - _, _, corrected_logits = sparse_linear_pytorch_compiled( - a, weight, cur_node, csr, step=step - ) - probs = F.softmax(corrected_logits, dim=-1) - return torch.multinomial(probs, num_samples=1).squeeze(-1) - - if "sparse_pytorch_topk" in alg_set: - - def sparse_pytorch_with_topk(): - _, _, corrected_logits = sparse_linear_pytorch_compiled( - a, weight, cur_node, csr, step=step - ) - return torch.topk(corrected_logits, k, dim=-1) - - if needs_sparse_compact: - - def sparse_pytorch_compact_with_topk(): - return sparse_compact_topk_compiled( - a, weight, cur_node, csr, step, k - ) - - # --- warmup / force compilation --- - with torch.no_grad(): - if "fused_sample" in alg_set: - fused_linear_constrained_node_transition_sampling(a, weight.T, cs) - if "sparse_pytorch_sample" in alg_set: - sparse_pytorch_with_sample() - if "fused_topk" in alg_set: - fused_linear_constrained_node_transition_topk(a, weight.T, cs, k=k) - if "sparse_pytorch_topk" in alg_set: - sparse_pytorch_with_topk() - if "sparse_pytorch_topk_compact" in alg_set: - sparse_pytorch_compact_with_topk() - record = {"B": B, "N": N} - - # --- benchmark --- - with torch.no_grad(): - if "fused_sample" in alg_set: - record["ms_fused_sample"] = run_bench( - lambda: fused_linear_constrained_node_transition_sampling( - a, weight.T, cs - ) - ) - if "sparse_pytorch_sample" in alg_set: - record["ms_sparse_pytorch_sample"] = run_bench( - sparse_pytorch_with_sample - ) - if "fused_topk" in alg_set: - record["ms_fused_topk"] = run_bench( - lambda: fused_linear_constrained_node_transition_topk( - a, weight.T, cs, k=k - ) - ) - if "sparse_pytorch_topk" in alg_set: - record["ms_sparse_pytorch_topk"] = run_bench( - sparse_pytorch_with_topk - ) - if "sparse_pytorch_topk_compact" in alg_set: - record["ms_sparse_pytorch_topk_compact"] = run_bench( - sparse_pytorch_compact_with_topk - ) - if "fused_sample" in alg_set and "sparse_pytorch_sample" in alg_set: - record["speedup_fused_vs_sparse_pytorch_sample"] = ( - record["ms_sparse_pytorch_sample"] / record["ms_fused_sample"] - ) - if "fused_topk" in alg_set and "sparse_pytorch_topk" in alg_set: - record["speedup_fused_topk_vs_sparse_pytorch_topk"] = ( - record["ms_sparse_pytorch_topk"] / record["ms_fused_topk"] - ) - if "fused_topk" in alg_set and "sparse_pytorch_topk_compact" in alg_set: - record["speedup_fused_topk_vs_sparse_pytorch_topk_compact"] = ( - record["ms_sparse_pytorch_topk_compact"] / record["ms_fused_topk"] - ) - if "sparse_pytorch_topk" in alg_set and "sparse_pytorch_topk_compact" in alg_set: - record["speedup_compact_vs_full_pytorch_topk"] = ( - record["ms_sparse_pytorch_topk"] / record["ms_sparse_pytorch_topk_compact"] - ) - records.append(record) - - return pd.DataFrame(records) + max_branches = max(1, int(run.N * run.sparsity)) + k = min(run.k_top, max_branches) + print(f" B={run.B:6d} N={run.N:6d} max_branches={max_branches} k={k}") + + if suite.diverse_nodes: + csr, cur_node = make_csr_diverse( + vocab_size=run.N, max_branches=max_branches, B=run.B + ) + step = 1 + else: + csr = make_csr(vocab_size=run.N, max_branches=max_branches) + cur_node = torch.zeros(run.B, dtype=torch.long, device=DEVICE) + step = 0 + + a = torch.randn(run.B, run.k, device=DEVICE, dtype=torch.bfloat16) + weight = torch.randn(run.N, run.k, device=DEVICE, dtype=torch.bfloat16) + cs = ConstraintState(step=step, trie=csr, cur_node=cur_node) + + needs_ell = alg_set & {"ell_sample", "ell_topk"} + if needs_ell: + ell = make_ell(csr) + _bias = a.new_empty(0) + + needs_sparse_full = alg_set & {"sparse_pytorch_sample", "sparse_pytorch_topk"} + needs_sparse_compact = "sparse_pytorch_topk_compact" in alg_set + needs_sparse_compact_sample = "sparse_pytorch_sample_compact" in alg_set + needs_sparse_ell_full = "sparse_pytorch_ell_sample" in alg_set + needs_sparse_ell_compact = "sparse_pytorch_ell_topk_compact" in alg_set + if needs_sparse_full: + sparse_linear_pytorch_compiled = torch.compile(sparse_linear_pytorch) + if needs_sparse_compact: + + def _sparse_compact_topk(a, weight, cur_node, csr, step, k): + nn, vi, branch_logits = sparse_linear_compact_pytorch( + a, weight, cur_node, csr, step + ) + topk_logits, topk_branch_idxs = torch.topk(branch_logits, k, dim=-1) + topk_idxs = vi.gather(1, topk_branch_idxs) + return nn, vi, topk_logits, topk_idxs + + sparse_compact_topk_compiled = torch.compile(_sparse_compact_topk) + + if needs_sparse_ell_full or needs_sparse_ell_compact: + if not needs_ell: + ell = make_ell(csr) + if needs_sparse_ell_full: + sparse_linear_ell_pytorch_compiled = torch.compile(sparse_linear_ell_pytorch) + if needs_sparse_ell_compact: + + def _sparse_ell_compact_topk(a, weight, cur_node, ell_trie, step, k): + nn, vi, branch_logits = sparse_linear_compact_ell_pytorch( + a, weight, cur_node, ell_trie, step + ) + topk_logits, topk_branch_idxs = torch.topk(branch_logits, k, dim=-1) + topk_idxs = vi.gather(1, topk_branch_idxs) + return nn, vi, topk_logits, topk_idxs + + sparse_ell_compact_topk_compiled = torch.compile(_sparse_ell_compact_topk) + + if "sparse_pytorch_sample" in alg_set: + + def sparse_pytorch_with_sample(): + _, _, corrected_logits = sparse_linear_pytorch_compiled( + a, weight, cur_node, csr, step=step + ) + probs = F.softmax(corrected_logits, dim=-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1) + + if needs_sparse_compact_sample: + + def _sparse_compact_sample(a, weight, cur_node, csr, step): + nn, vi, branch_logits = sparse_linear_compact_pytorch( + a, weight, cur_node, csr, step + ) + probs = F.softmax(branch_logits, dim=-1) + branch_sample = torch.multinomial(probs, num_samples=1) + return nn, vi.gather(1, branch_sample).squeeze(-1) + + sparse_compact_sample_compiled = torch.compile(_sparse_compact_sample) + + if "sparse_pytorch_ell_sample" in alg_set: + + def sparse_pytorch_ell_with_sample(): + _, _, corrected_logits = sparse_linear_ell_pytorch_compiled( + a, weight, cur_node, ell, step=step + ) + probs = F.softmax(corrected_logits, dim=-1) + return torch.multinomial(probs, num_samples=1).squeeze(-1) + + if "sparse_pytorch_topk" in alg_set: + + def sparse_pytorch_with_topk(): + _, _, corrected_logits = sparse_linear_pytorch_compiled( + a, weight, cur_node, csr, step=step + ) + return torch.topk(corrected_logits, k, dim=-1) + + if needs_sparse_compact: + + def sparse_pytorch_compact_with_topk(): + return sparse_compact_topk_compiled(a, weight, cur_node, csr, step, k) + + if needs_sparse_ell_compact: + + def sparse_pytorch_ell_compact_with_topk(): + return sparse_ell_compact_topk_compiled(a, weight, cur_node, ell, step, k) + + # --- warmup / force compilation --- + with torch.no_grad(): + if "fused_sample" in alg_set: + fused_linear_constrained_node_transition_sampling(a, weight.T, cs) + if "ell_sample" in alg_set: + ell_sampling_op(a, weight.T, _bias, cur_node, ell.ell_cols_vals, ell.n_children, max_branches, False) + if "sparse_pytorch_sample" in alg_set: + sparse_pytorch_with_sample() + if needs_sparse_compact_sample: + sparse_compact_sample_compiled(a, weight, cur_node, csr, step) + if "sparse_pytorch_ell_sample" in alg_set: + sparse_pytorch_ell_with_sample() + if "fused_topk" in alg_set: + fused_linear_constrained_node_transition_topk(a, weight.T, cs, k=k) + if "ell_topk" in alg_set: + ell_topk_op(a, weight.T, _bias, cur_node, ell.ell_cols_vals, ell.n_children, max_branches, False, k) + if "sparse_pytorch_topk" in alg_set: + sparse_pytorch_with_topk() + if "sparse_pytorch_topk_compact" in alg_set: + sparse_pytorch_compact_with_topk() + if "sparse_pytorch_ell_topk_compact" in alg_set: + sparse_pytorch_ell_compact_with_topk() + if "dense_topk" in alg_set: + _, _, bl = constrained_node_transition(_dense_matmul_compiled(a, weight), cs) + torch.topk(bl, k, dim=-1) + + record: dict[str, Any] = {"B": run.B, "N": run.N} + + # --- benchmark --- + with torch.no_grad(): + if "fused_sample" in alg_set: + record["ms_fused_sample"] = run_bench( + lambda: fused_linear_constrained_node_transition_sampling(a, weight.T, cs), + suite, + ) + if "ell_sample" in alg_set: + record["ms_ell_sample"] = run_bench( + lambda: ell_sampling_op( + a, weight.T, _bias, cur_node, + ell.ell_cols_vals, ell.n_children, max_branches, False, + ), + suite, + ) + if "sparse_pytorch_sample" in alg_set: + record["ms_sparse_pytorch_sample"] = run_bench(sparse_pytorch_with_sample, suite) + if needs_sparse_compact_sample: + record["ms_sparse_pytorch_sample_compact"] = run_bench( + lambda: sparse_compact_sample_compiled(a, weight, cur_node, csr, step), + suite, + ) + if "sparse_pytorch_ell_sample" in alg_set: + record["ms_sparse_pytorch_ell_sample"] = run_bench(sparse_pytorch_ell_with_sample, suite) + if "fused_topk" in alg_set: + record["ms_fused_topk"] = run_bench( + lambda: fused_linear_constrained_node_transition_topk(a, weight.T, cs, k=k), + suite, + ) + if "ell_topk" in alg_set: + record["ms_ell_topk"] = run_bench( + lambda: ell_topk_op( + a, weight.T, _bias, cur_node, + ell.ell_cols_vals, ell.n_children, max_branches, False, k, + ), + suite, + ) + if "sparse_pytorch_topk" in alg_set: + record["ms_sparse_pytorch_topk"] = run_bench(sparse_pytorch_with_topk, suite) + if "sparse_pytorch_topk_compact" in alg_set: + record["ms_sparse_pytorch_topk_compact"] = run_bench(sparse_pytorch_compact_with_topk, suite) + if "sparse_pytorch_ell_topk_compact" in alg_set: + record["ms_sparse_pytorch_ell_topk_compact"] = run_bench(sparse_pytorch_ell_compact_with_topk, suite) + if "dense_topk" in alg_set: + record["ms_dense_topk"] = run_bench( + lambda: torch.topk( + constrained_node_transition(_dense_matmul_compiled(a, weight), cs)[2], k, dim=-1 + ), + suite, + ) + + if "fused_sample" in alg_set and "sparse_pytorch_sample" in alg_set: + record["speedup_fused_vs_sparse_pytorch_sample"] = ( + record["ms_sparse_pytorch_sample"] / record["ms_fused_sample"] + ) + if "fused_sample" in alg_set and "sparse_pytorch_sample_compact" in alg_set: + record["speedup_fused_vs_sparse_pytorch_sample_compact"] = ( + record["ms_sparse_pytorch_sample_compact"] / record["ms_fused_sample"] + ) + if "ell_sample" in alg_set and "fused_sample" in alg_set: + record["speedup_ell_vs_csr_sample"] = ( + record["ms_fused_sample"] / record["ms_ell_sample"] + ) + if "ell_sample" in alg_set and "sparse_pytorch_sample" in alg_set: + record["speedup_ell_vs_sparse_pytorch_sample"] = ( + record["ms_sparse_pytorch_sample"] / record["ms_ell_sample"] + ) + if "fused_topk" in alg_set and "sparse_pytorch_topk" in alg_set: + record["speedup_fused_topk_vs_sparse_pytorch_topk"] = ( + record["ms_sparse_pytorch_topk"] / record["ms_fused_topk"] + ) + if "ell_topk" in alg_set and "fused_topk" in alg_set: + record["speedup_ell_vs_csr_topk"] = ( + record["ms_fused_topk"] / record["ms_ell_topk"] + ) + if "ell_topk" in alg_set and "sparse_pytorch_topk" in alg_set: + record["speedup_ell_vs_sparse_pytorch_topk"] = ( + record["ms_sparse_pytorch_topk"] / record["ms_ell_topk"] + ) + if "fused_topk" in alg_set and "sparse_pytorch_topk_compact" in alg_set: + record["speedup_fused_topk_vs_sparse_pytorch_topk_compact"] = ( + record["ms_sparse_pytorch_topk_compact"] / record["ms_fused_topk"] + ) + if "sparse_pytorch_topk" in alg_set and "sparse_pytorch_topk_compact" in alg_set: + record["speedup_compact_vs_full_pytorch_topk"] = ( + record["ms_sparse_pytorch_topk"] / record["ms_sparse_pytorch_topk_compact"] + ) + if "sparse_pytorch_ell_sample" in alg_set and "sparse_pytorch_sample" in alg_set: + record["speedup_ell_pytorch_vs_csr_pytorch_sample"] = ( + record["ms_sparse_pytorch_sample"] / record["ms_sparse_pytorch_ell_sample"] + ) + if "sparse_pytorch_ell_sample" in alg_set and "ell_sample" in alg_set: + record["speedup_ell_triton_vs_ell_pytorch_sample"] = ( + record["ms_sparse_pytorch_ell_sample"] / record["ms_ell_sample"] + ) + if "sparse_pytorch_ell_topk_compact" in alg_set and "sparse_pytorch_topk_compact" in alg_set: + record["speedup_ell_pytorch_vs_csr_pytorch_topk_compact"] = ( + record["ms_sparse_pytorch_topk_compact"] / record["ms_sparse_pytorch_ell_topk_compact"] + ) + if "sparse_pytorch_ell_topk_compact" in alg_set and "ell_topk" in alg_set: + record["speedup_ell_triton_vs_ell_pytorch_topk_compact"] = ( + record["ms_sparse_pytorch_ell_topk_compact"] / record["ms_ell_topk"] + ) + if "fused_topk" in alg_set and "dense_topk" in alg_set: + record["speedup_fused_topk_vs_dense_topk"] = ( + record["ms_dense_topk"] / record["ms_fused_topk"] + ) + return record + + +def benchmark_grid(suite: BenchmarkSuite, algorithms: list[str]) -> pd.DataFrame: + return pd.DataFrame([benchmark_run(run, suite, algorithms) for run in suite.runs]) def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup"): @@ -251,61 +405,64 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" metavar="ALGO", help=f"Algorithms to benchmark. Choices: {ALL_ALGORITHMS} (default: {DEFAULT_ALGORITHMS})", ) - parser.add_argument( - "--sparsity", - type=float, - default=DEFAULT_SPARSITY, - help="Fraction of vocab used as max branches (default: %(default)s)", - ) - parser.add_argument( - "--topk", - type=int, - default=K_TOP, - help="k for top-k benchmarks (default: %(default)s)", - ) - parser.add_argument( - "--diverse-nodes", - action="store_true", - default=False, - help=( - "Place each batch element on a different trie node (2-level trie, step=1) " - "instead of all starting at the root. Tests realistic cache-miss pressure." - ), - ) args = parser.parse_args() assert torch.cuda.is_available(), "CUDA required" os.makedirs("out", exist_ok=True) - B_vals = [256, 1024, 4096] - N_vals = [256, 150000] + suites = [FUSED_SAMPLE_SUITE_N256, FUSED_SAMPLE_SUITE_N150K] - print( - f"Benchmarking K={K}, sparsity={args.sparsity}, topk={args.topk}, diverse_nodes={args.diverse_nodes}" - ) print(f"Algorithms: {args.algorithms}") - print(f"B_vals={B_vals}") - print(f"N_vals={N_vals}\n") - - df = benchmark_grid( - B_vals, - N_vals, - algorithms=args.algorithms, - sparsity=args.sparsity, - k_top=args.topk, - diverse_nodes=args.diverse_nodes, - ) + for suite in suites: + print(f" diverse_nodes={suite.diverse_nodes}") + for run in suite.runs: + print(f" B={run.B}, N={run.N}, k={run.k}, k_top={run.k_top}, sparsity={run.sparsity}") + print() + + df = pd.concat([benchmark_grid(suite, algorithms=args.algorithms) for suite in suites], ignore_index=True) csv_path = "out/bench_fused_sample.csv" df.to_csv(csv_path, index=False) print(f"\nSaved {csv_path}\n") print(df.to_string(index=False)) + if "speedup_ell_vs_csr_sample" in df.columns: + plot_heatmap( + df, + value_col="speedup_ell_vs_csr_sample", + title="ELL sample speedup vs CSR sample", + filename="out/heatmap_ell_vs_csr_sample.jpg", + cbar_label="Speedup (>1 = ELL faster)", + ) + if "speedup_ell_vs_sparse_pytorch_sample" in df.columns: + plot_heatmap( + df, + value_col="speedup_ell_vs_sparse_pytorch_sample", + title="ELL sample speedup vs compile(sparse_linear_pytorch)+multinomial", + filename="out/heatmap_ell_sample_vs_sparse_pytorch.jpg", + cbar_label="Speedup (>1 = ELL faster)", + ) + if "speedup_ell_vs_csr_topk" in df.columns: + plot_heatmap( + df, + value_col="speedup_ell_vs_csr_topk", + title="ELL top-k speedup vs CSR top-k", + filename="out/heatmap_ell_vs_csr_topk.jpg", + cbar_label="Speedup (>1 = ELL faster)", + ) + if "speedup_ell_vs_sparse_pytorch_topk" in df.columns: + plot_heatmap( + df, + value_col="speedup_ell_vs_sparse_pytorch_topk", + title="ELL top-k speedup vs compile(sparse_linear_pytorch)+topk", + filename="out/heatmap_ell_topk_vs_sparse_pytorch.jpg", + cbar_label="Speedup (>1 = ELL faster)", + ) if "speedup_fused_vs_sparse_pytorch_sample" in df.columns: plot_heatmap( df, value_col="speedup_fused_vs_sparse_pytorch_sample", - title=f"Fused sample speedup vs compile(sparse_linear_pytorch)+multinomial (K={K})", + title="Fused sample speedup vs compile(sparse_linear_pytorch)+multinomial", filename="out/heatmap_fused_sample_vs_sparse_pytorch.jpg", cbar_label="Speedup (>1 = fused faster)", ) @@ -313,7 +470,7 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" plot_heatmap( df, value_col="speedup_fused_topk_vs_sparse_pytorch_topk", - title=f"Fused top-k speedup vs compile(sparse_linear_pytorch)+topk (K={K}, k={args.topk})", + title="Fused top-k speedup vs compile(sparse_linear_pytorch)+topk", filename="out/heatmap_fused_topk_vs_sparse_pytorch_topk.jpg", cbar_label="Speedup (>1 = fused faster)", ) @@ -321,7 +478,7 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" plot_heatmap( df, value_col="speedup_fused_topk_vs_sparse_pytorch_topk_compact", - title=f"Fused top-k speedup vs compile(sparse_linear_compact_pytorch)+topk (K={K}, k={args.topk})", + title="Fused top-k speedup vs compile(sparse_linear_compact_pytorch)+topk", filename="out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg", cbar_label="Speedup (>1 = fused faster)", ) @@ -329,7 +486,15 @@ def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup" plot_heatmap( df, value_col="speedup_compact_vs_full_pytorch_topk", - title=f"Compact pytorch top-k speedup vs full (B,N) pytorch top-k (K={K}, k={args.topk})", + title="Compact pytorch top-k speedup vs full (B,N) pytorch top-k", filename="out/heatmap_compact_vs_full_pytorch_topk.jpg", cbar_label="Speedup (>1 = compact faster)", ) + if "speedup_fused_topk_vs_dense_topk" in df.columns: + plot_heatmap( + df, + value_col="speedup_fused_topk_vs_dense_topk", + title="Fused top-k speedup vs compile(dense_matmul)+topk", + filename="out/heatmap_fused_topk_vs_dense_topk.jpg", + cbar_label="Speedup (>1 = fused faster)", + ) diff --git a/examples/scripts/benchmark/benchmark_nn_quantize.py b/examples/scripts/benchmark/benchmark_nn_quantize.py index 57f568c..3550446 100644 --- a/examples/scripts/benchmark/benchmark_nn_quantize.py +++ b/examples/scripts/benchmark/benchmark_nn_quantize.py @@ -5,12 +5,14 @@ cdist_compiled – torch.compile(torch.cdist + argmin) faiss_search – FAISS-GPU flat L2, index pre-built (static codebook) -Grid: B (batch size) × D (embedding dim). N (codebook size) fixed per run. +Grid: B (batch size) × D (embedding dim). N (codebook size) fixed per run group. Heatmap axes: B (batch size) vs D (embedding dim). """ import argparse import os +from typing import Any, cast + import torch import triton.testing as testing import matplotlib.pyplot as plt @@ -19,12 +21,9 @@ from rectokens.kernels.nn_quantize import quantize_fwd, quantize_fwd_mm from rectokens.ops.faiss_quantize import make_gpu_index +from benchmark_config import QUANTIZE_SUITE, QuantizeBenchmarkSuite, QuantizeRunConfig DEVICE = torch.device("cuda") -N = 256 -WARMUP = 25 -FAISS_WARMUP = 100 -REP = 100 ALL_ALGORITHMS = ["quantize_fwd", "quantize_fwd_mm", "cdist_compiled", "faiss_search"] @@ -36,161 +35,134 @@ def cdist_nn(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor: cdist_nn_compiled = torch.compile(cdist_nn) -def run_bench(fn, warmup=WARMUP): - return testing.do_bench(fn, warmup=warmup, rep=REP) +def run_bench(fn, suite: QuantizeBenchmarkSuite, faiss: bool = False) -> float: + warmup = suite.faiss_warmup if faiss else suite.warmup + return cast(float, testing.do_bench(fn, warmup=warmup, rep=suite.rep)) -def benchmark_grid(B_vals, D_vals, algorithms): +def benchmark_run(run: QuantizeRunConfig, suite: QuantizeBenchmarkSuite, algorithms: list[str]) -> dict[str, Any]: alg_set = set(algorithms) - records = [] - - for B in B_vals: - for D in D_vals: - print(f" B={B:6d} D={D:6d}") - - x = torch.randn(B, D, device=DEVICE) - codebook = torch.randn(N, D, device=DEVICE) - - if "faiss_search" in alg_set: - gpu_index = make_gpu_index(codebook) - - # warmup / force compilation / autotuning - with torch.no_grad(): - if "quantize_fwd" in alg_set: - quantize_fwd(x, codebook) - if "quantize_fwd_mm" in alg_set: - quantize_fwd_mm(x, codebook) - if "cdist_compiled" in alg_set: - cdist_nn_compiled(x, codebook) - if "faiss_search" in alg_set: - gpu_index.search(x.contiguous(), 1) - - record = {"B": B, "D": D, "BN": B * N} - - with torch.no_grad(): - if "quantize_fwd" in alg_set: - record["ms_quantize_fwd"] = run_bench( - lambda: quantize_fwd(x, codebook) - ) - if "quantize_fwd_mm" in alg_set: - record["ms_quantize_fwd_mm"] = run_bench( - lambda: quantize_fwd_mm(x, codebook) - ) - if "cdist_compiled" in alg_set: - record["ms_cdist_compiled"] = run_bench( - lambda: cdist_nn_compiled(x, codebook) - ) - if "faiss_search" in alg_set: - record["ms_faiss_search"] = run_bench( - lambda: gpu_index.search(x.contiguous(), 1), - warmup=FAISS_WARMUP, - ) - - if "quantize_fwd" in alg_set and "quantize_fwd_mm" in alg_set: - record["speedup_fwd_vs_mm"] = ( - record["ms_quantize_fwd_mm"] / record["ms_quantize_fwd"] - ) - if "quantize_fwd" in alg_set and "cdist_compiled" in alg_set: - record["speedup_fwd_vs_cdist"] = ( - record["ms_cdist_compiled"] / record["ms_quantize_fwd"] - ) - if "quantize_fwd_mm" in alg_set and "cdist_compiled" in alg_set: - record["speedup_mm_vs_cdist"] = ( - record["ms_cdist_compiled"] / record["ms_quantize_fwd_mm"] - ) - if "quantize_fwd" in alg_set and "faiss_search" in alg_set: - record["speedup_fwd_vs_faiss"] = ( - record["ms_faiss_search"] / record["ms_quantize_fwd"] - ) - if "quantize_fwd_mm" in alg_set and "faiss_search" in alg_set: - record["speedup_mm_vs_faiss"] = ( - record["ms_faiss_search"] / record["ms_quantize_fwd_mm"] - ) - - records.append(record) - - return pd.DataFrame(records) - - -def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup"): - pivot = df.pivot(index="B", columns="D", values=value_col) - pivot = pivot.sort_index() - plt.figure(figsize=(10, 6)) - sns.heatmap( - pivot, - annot=True, - fmt=fmt, - cmap="viridis", - cbar_kws={"label": cbar_label}, - ) - plt.title(title) - plt.ylabel("Batch size (B)") - plt.xlabel("Embedding dim (D)") - plt.tight_layout() - plt.savefig(filename, dpi=150, format="jpg") - plt.close() - print(f" Saved {filename}") - - -def run_for_N(n_val, B_vals, D_vals, algorithms): - global N - N = n_val + print(f" B={run.B:6d} D={run.D:6d} N={run.N}") + + x = torch.randn(run.B, run.D, device=DEVICE) + codebook = torch.randn(run.N, run.D, device=DEVICE) + + if "faiss_search" in alg_set: + gpu_index = make_gpu_index(codebook) + + with torch.no_grad(): + if "quantize_fwd" in alg_set: + quantize_fwd(x, codebook) + if "quantize_fwd_mm" in alg_set: + quantize_fwd_mm(x, codebook) + if "cdist_compiled" in alg_set: + cdist_nn_compiled(x, codebook) + if "faiss_search" in alg_set: + gpu_index.search(x.contiguous(), 1) + + record: dict[str, Any] = {"B": run.B, "D": run.D, "N": run.N, "BN": run.B * run.N} + + with torch.no_grad(): + if "quantize_fwd" in alg_set: + record["ms_quantize_fwd"] = run_bench(lambda: quantize_fwd(x, codebook), suite) + if "quantize_fwd_mm" in alg_set: + record["ms_quantize_fwd_mm"] = run_bench(lambda: quantize_fwd_mm(x, codebook), suite) + if "cdist_compiled" in alg_set: + record["ms_cdist_compiled"] = run_bench(lambda: cdist_nn_compiled(x, codebook), suite) + if "faiss_search" in alg_set: + record["ms_faiss_search"] = run_bench( + lambda: gpu_index.search(x.contiguous(), 1), suite, faiss=True + ) + + if "quantize_fwd" in alg_set and "quantize_fwd_mm" in alg_set: + record["speedup_fwd_vs_mm"] = record["ms_quantize_fwd_mm"] / record["ms_quantize_fwd"] + if "quantize_fwd" in alg_set and "cdist_compiled" in alg_set: + record["speedup_fwd_vs_cdist"] = record["ms_cdist_compiled"] / record["ms_quantize_fwd"] + if "quantize_fwd_mm" in alg_set and "cdist_compiled" in alg_set: + record["speedup_mm_vs_cdist"] = record["ms_cdist_compiled"] / record["ms_quantize_fwd_mm"] + if "quantize_fwd" in alg_set and "faiss_search" in alg_set: + record["speedup_fwd_vs_faiss"] = record["ms_faiss_search"] / record["ms_quantize_fwd"] + if "quantize_fwd_mm" in alg_set and "faiss_search" in alg_set: + record["speedup_mm_vs_faiss"] = record["ms_faiss_search"] / record["ms_quantize_fwd_mm"] + + return record + + +def benchmark_for_n(n: int, suite: QuantizeBenchmarkSuite, algorithms: list[str]) -> None: + runs = [r for r in suite.runs if r.N == n] print(f"\n{'=' * 50}") - print(f"Benchmarking N={N} (fixed)") - print(f"Algorithms: {algorithms}") - print(f"B_vals={B_vals}") - print(f"D_vals={D_vals}\n") + print(f"Benchmarking N={n} (codebook size fixed)") + print(f"Algorithms: {algorithms}\n") - df = benchmark_grid(B_vals, D_vals, algorithms=algorithms) - csv_path = f"out/bench_nn_quantize_N{N}.csv" + df = pd.DataFrame([benchmark_run(run, suite, algorithms) for run in runs]) + csv_path = f"out/bench_nn_quantize_N{n}.csv" df.to_csv(csv_path, index=False) print(f"\nSaved {csv_path}\n") - print(df.to_string(index=False)) if "speedup_fwd_vs_mm" in df.columns: plot_heatmap( df, value_col="speedup_fwd_vs_mm", - title=f"quantize_fwd speedup vs quantize_fwd_mm (N={N})", - filename=f"out/heatmap_fwd_vs_mm_N{N}.jpg", + title=f"quantize_fwd speedup vs quantize_fwd_mm (N={n})", + filename=f"out/heatmap_fwd_vs_mm_N{n}.jpg", cbar_label="Speedup (>1 = fwd faster)", ) if "speedup_fwd_vs_cdist" in df.columns: plot_heatmap( df, value_col="speedup_fwd_vs_cdist", - title=f"quantize_fwd speedup vs cdist_compiled (N={N})", - filename=f"out/heatmap_fwd_vs_cdist_N{N}.jpg", + title=f"quantize_fwd speedup vs cdist_compiled (N={n})", + filename=f"out/heatmap_fwd_vs_cdist_N{n}.jpg", cbar_label="Speedup (>1 = fwd faster)", ) if "speedup_mm_vs_cdist" in df.columns: plot_heatmap( df, value_col="speedup_mm_vs_cdist", - title=f"quantize_fwd_mm speedup vs cdist_compiled (N={N})", - filename=f"out/heatmap_mm_vs_cdist_N{N}.jpg", + title=f"quantize_fwd_mm speedup vs cdist_compiled (N={n})", + filename=f"out/heatmap_mm_vs_cdist_N{n}.jpg", cbar_label="Speedup (>1 = mm faster)", ) if "speedup_fwd_vs_faiss" in df.columns: plot_heatmap( df, value_col="speedup_fwd_vs_faiss", - title=f"quantize_fwd speedup vs faiss_search (N={N})", - filename=f"out/heatmap_fwd_vs_faiss_N{N}.jpg", + title=f"quantize_fwd speedup vs faiss_search (N={n})", + filename=f"out/heatmap_fwd_vs_faiss_N{n}.jpg", cbar_label="Speedup (>1 = fwd faster)", ) if "speedup_mm_vs_faiss" in df.columns: plot_heatmap( df, value_col="speedup_mm_vs_faiss", - title=f"quantize_fwd_mm speedup vs faiss_search (N={N})", - filename=f"out/heatmap_mm_vs_faiss_N{N}.jpg", + title=f"quantize_fwd_mm speedup vs faiss_search (N={n})", + filename=f"out/heatmap_mm_vs_faiss_N{n}.jpg", cbar_label="Speedup (>1 = mm faster)", ) +def plot_heatmap(df, value_col, title, filename, fmt=".2f", cbar_label="Speedup"): + pivot = df.pivot(index="B", columns="D", values=value_col) + pivot = pivot.sort_index() + plt.figure(figsize=(10, 6)) + sns.heatmap( + pivot, + annot=True, + fmt=fmt, + cmap="viridis", + cbar_kws={"label": cbar_label}, + ) + plt.title(title) + plt.ylabel("Batch size (B)") + plt.xlabel("Embedding dim (D)") + plt.tight_layout() + plt.savefig(filename, dpi=150, format="jpg") + plt.close() + print(f" Saved {filename}") + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Benchmark nearest-neighbor quantization algorithms." @@ -208,8 +180,7 @@ def run_for_N(n_val, B_vals, D_vals, algorithms): assert torch.cuda.is_available(), "CUDA required" os.makedirs("out", exist_ok=True) - B_vals = [32, 256, 1024, 4096, 16384, 32768, 65536] - D_vals = [64, 128, 256] - - for n_val in [64, 128, 256, 512]: - run_for_N(n_val, B_vals, D_vals, algorithms=args.algorithms) + suite = QUANTIZE_SUITE + n_vals = sorted({r.N for r in suite.runs}) + for n in n_vals: + benchmark_for_n(n, suite, algorithms=args.algorithms) diff --git a/examples/scripts/benchmark/benchmark_update_speed.py b/examples/scripts/benchmark/benchmark_update_speed.py new file mode 100644 index 0000000..5872832 --- /dev/null +++ b/examples/scripts/benchmark/benchmark_update_speed.py @@ -0,0 +1,344 @@ +""" +Benchmark: trie update speed — CSR full rebuild vs ELL incremental update. + +CSR must sort ALL sequences and call from_sorted_batch() from scratch on every +update. ELL uses MutableELLTrie which pre-allocates a buffer and inserts only +the new nodes/edges, touching zero existing data. + +Two benchmark modes +------------------- +single – Fixed initial catalog of N seqs. Time one update of M new seqs. + CSR: sort(initial + new) + from_sorted_batch. + ELL: MutableELLTrie.update(new_seqs) — pure traversal, no copy. + State is reset between timing reps by copy_() outside the hot path. + +growing – Simulate a catalog that grows over time. Each of K rounds adds M + new sequences. CSR rebuilds from all sequences every round (cost + grows). ELL calls update() every round (cost stays constant). + No state reset needed — rounds chain naturally. +""" + +import argparse +import os +import random +import time + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import torch + +from rectokens.schemas.compact_csr_trie import CompactCSRTrie +from rectokens.schemas.compact_ell_trie import CompactELLTrie, MutableELLTrie +from benchmark_config import UPDATE_SUITE, UpdateBenchmarkSuite, UpdateRunConfig + + +# --------------------------------------------------------------------------- +# Data helpers +# --------------------------------------------------------------------------- + +def lex_sort(seqs: list[list[int]]) -> torch.Tensor: + return torch.tensor(sorted(seqs), dtype=torch.long) + + +def generate_unique_seqs( + n: int, + L: int, + vocab_size: int, + exclude: set[tuple[int, ...]] | None = None, + seed: int = 0, +) -> list[list[int]]: + rng = random.Random(seed) + exclude = exclude or set() + seqs: set[tuple[int, ...]] = set() + while len(seqs) < n: + seq = tuple(rng.randint(0, vocab_size - 1) for _ in range(L)) + if seq not in exclude: + seqs.add(seq) + return [list(s) for s in seqs] + + +# --------------------------------------------------------------------------- +# "single" benchmark +# --------------------------------------------------------------------------- + +def bench_single(run: UpdateRunConfig, suite: UpdateBenchmarkSuite) -> dict: + """One update of run.update_n sequences into a trie pre-loaded with run.initial_n seqs.""" + initial_seqs = generate_unique_seqs(run.initial_n, suite.seq_len, suite.vocab_size, seed=0) + initial_set = {tuple(s) for s in initial_seqs} + new_seqs = generate_unique_seqs( + run.update_n, suite.seq_len, suite.vocab_size, exclude=initial_set, seed=1 + ) + + initial_tensor = lex_sort(initial_seqs) + csr_init = CompactCSRTrie.from_sorted_batch(initial_tensor, suite.vocab_size) + ell_init = CompactELLTrie.from_csr(csr_init) + + all_seqs = initial_seqs + new_seqs + + def csr_rebuild() -> None: + CompactCSRTrie.from_sorted_batch(lex_sort(all_seqs), suite.vocab_size) + + capacity = run.update_n * suite.seq_len + mutable = MutableELLTrie(ell_init, extra_capacity=capacity) + + snap_cv = mutable.ell_cv.clone() + snap_nch = mutable.n_ch.clone() + snap_nn = mutable.num_nodes + snap_mb = mutable.max_branches + + def ell_update() -> None: + mutable.update(new_seqs) + + def ell_reset() -> None: + if mutable.ell_cv.shape == snap_cv.shape: + mutable.ell_cv.copy_(snap_cv) + mutable.n_ch.copy_(snap_nch) + else: + mutable.ell_cv = snap_cv.clone() + mutable.n_ch = snap_nch.clone() + mutable.num_nodes = snap_nn + mutable.max_branches = snap_mb + + for _ in range(suite.warmup): + ell_reset() + ell_update() + ell_times: list[float] = [] + for _ in range(suite.rep): + ell_reset() + t0 = time.perf_counter() + ell_update() + ell_times.append((time.perf_counter() - t0) * 1e3) + + for _ in range(suite.warmup): + csr_rebuild() + csr_times: list[float] = [] + for _ in range(suite.rep): + t0 = time.perf_counter() + csr_rebuild() + csr_times.append((time.perf_counter() - t0) * 1e3) + + def median_lower(ts: list[float]) -> float: + ts = sorted(ts) + half = max(1, len(ts) // 2) + return sum(ts[:half]) / half + + ms_csr = median_lower(csr_times) + ms_ell = median_lower(ell_times) + return { + "initial_n": run.initial_n, + "update_n": run.update_n, + "ms_csr": ms_csr, + "ms_ell": ms_ell, + "speedup_ell_vs_csr": ms_csr / ms_ell, + } + + +# --------------------------------------------------------------------------- +# "growing" benchmark +# --------------------------------------------------------------------------- + +def bench_growing(run: UpdateRunConfig, suite: UpdateBenchmarkSuite) -> dict: + """Simulate a growing catalog. Each round adds run.update_n new sequences.""" + total_needed = run.initial_n + run.update_n * suite.growing_rounds + all_seqs = generate_unique_seqs(total_needed, suite.seq_len, suite.vocab_size, seed=42) + initial_seqs = all_seqs[:run.initial_n] + update_batches = [ + all_seqs[run.initial_n + i * run.update_n: run.initial_n + (i + 1) * run.update_n] + for i in range(suite.growing_rounds) + ] + + csr_times: list[float] = [] + current_seqs = list(initial_seqs) + for batch in update_batches: + current_seqs.extend(batch) + t0 = time.perf_counter() + CompactCSRTrie.from_sorted_batch(lex_sort(current_seqs), suite.vocab_size) + csr_times.append((time.perf_counter() - t0) * 1e3) + + csr_init = CompactCSRTrie.from_sorted_batch(lex_sort(initial_seqs), suite.vocab_size) + mutable = MutableELLTrie( + CompactELLTrie.from_csr(csr_init), + extra_capacity=run.update_n * suite.seq_len * suite.growing_rounds, + ) + ell_times: list[float] = [] + for batch in update_batches: + t0 = time.perf_counter() + mutable.update(batch) + ell_times.append((time.perf_counter() - t0) * 1e3) + + catalog_sizes = [run.initial_n + (i + 1) * run.update_n for i in range(suite.growing_rounds)] + return { + "catalog_sizes": catalog_sizes, + "csr_times": csr_times, + "ell_times": ell_times, + "speedups": [c / e for c, e in zip(csr_times, ell_times)], + } + + +# --------------------------------------------------------------------------- +# Plotting +# --------------------------------------------------------------------------- + +def plot_speedup_heatmap(df: pd.DataFrame, filename: str) -> None: + pivot = df.pivot(index="initial_n", columns="update_n", values="speedup_ell_vs_csr") + plt.figure(figsize=(9, 5)) + sns.heatmap( + pivot.sort_index(), + annot=True, + fmt=".1f", + cmap="viridis", + cbar_kws={"label": "Speedup (CSR / ELL, >1 = ELL faster)"}, + ) + plt.title("ELL incremental update speedup over CSR full rebuild\n(single-update mode)") + plt.ylabel("Initial catalogue size") + plt.xlabel("Update batch size") + plt.tight_layout() + plt.savefig(filename, dpi=150, format="jpg") + plt.close() + print(f" Saved {filename}") + + +def plot_abs_times(df: pd.DataFrame, initial_n: int, filename: str) -> None: + sub = df[df["initial_n"] == initial_n].copy() + x = list(range(len(sub))) + labels = [str(v) for v in sub["update_n"]] + w = 0.35 + + fig, ax = plt.subplots(figsize=(9, 5)) + bars_csr = ax.bar([i - w / 2 for i in x], sub["ms_csr"], w, label="CSR (full rebuild)") + bars_ell = ax.bar([i + w / 2 for i in x], sub["ms_ell"], w, label="ELL (incremental)") + ax.set_xticks(x) + ax.set_xticklabels(labels) + ax.set_xlabel("Update batch size") + ax.set_ylabel("Time (ms, log scale)") + ax.set_title( + f"Update latency: CSR vs ELL (initial catalogue = {initial_n:,} seqs,\n" + f"CSR time includes sort of all sequences; ELL times only new inserts)" + ) + ax.set_yscale("log") + ax.legend() + for bar in list(bars_csr) + list(bars_ell): + h = bar.get_height() + ax.text( + bar.get_x() + bar.get_width() / 2, h * 1.1, + f"{h:.2f}", ha="center", va="bottom", fontsize=8, + ) + plt.tight_layout() + plt.savefig(filename, dpi=150, format="jpg") + plt.close() + print(f" Saved {filename}") + + +def plot_growing(result: dict, initial_n: int, update_n: int, filename: str) -> None: + sizes = result["catalog_sizes"] + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 5)) + + ax1.plot(sizes, result["csr_times"], marker="o", label="CSR (full rebuild)") + ax1.plot(sizes, result["ell_times"], marker="s", label="ELL (incremental)") + ax1.set_xlabel("Total catalogue size") + ax1.set_ylabel("Update latency (ms)") + ax1.set_title( + f"Update latency vs catalogue size\n" + f"(initial={initial_n:,}, +{update_n} seqs/round)" + ) + ax1.legend() + + ax2.plot(sizes, result["speedups"], marker="^", color="C2") + ax2.axhline(1, linestyle="--", color="gray", linewidth=0.8) + ax2.set_xlabel("Total catalogue size") + ax2.set_ylabel("Speedup (CSR / ELL)") + ax2.set_title("Speedup over CSR\n(>1 = ELL faster)") + + plt.tight_layout() + plt.savefig(filename, dpi=150, format="jpg") + plt.close() + print(f" Saved {filename}") + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark CSR full rebuild vs ELL incremental update." + ) + parser.add_argument("--mode", choices=["single", "growing", "both"], default="both") + args = parser.parse_args() + + os.makedirs("out", exist_ok=True) + + suite = UPDATE_SUITE + print( + f"vocab_size={suite.vocab_size} seq_len={suite.seq_len} " + f"warmup={suite.warmup} rep={suite.rep}" + ) + + # ---- single mode ------------------------------------------------------- + if args.mode in ("single", "both"): + print("\n=== single-update mode ===") + print(f"{'initial_n':>10} {'update_n':>9} {'ms_csr':>9} {'ms_ell':>9} {'speedup':>9}") + print("-" * 55) + + records = [] + for run in suite.runs: + rec = bench_single(run, suite) + records.append(rec) + print( + f"{rec['initial_n']:>10} {rec['update_n']:>9} " + f"{rec['ms_csr']:>9.3f} {rec['ms_ell']:>9.3f} " + f"{rec['speedup_ell_vs_csr']:>8.1f}x" + ) + + df = pd.DataFrame(records) + df.to_csv("out/bench_update_single.csv", index=False) + print(f"\nSaved out/bench_update_single.csv") + plot_speedup_heatmap(df, "out/bench_update_speedup_heatmap.jpg") + largest = max(r.initial_n for r in suite.runs) + plot_abs_times(df, largest, f"out/bench_update_abs_times_n{largest}.jpg") + + # ---- growing mode ------------------------------------------------------- + if args.mode in ("growing", "both"): + print("\n=== growing-catalog mode ===") + initial_n = suite.runs[0].initial_n + growing_records = [] + + for run in suite.runs: + print(f" initial_n={run.initial_n} update_n={run.update_n} rounds={suite.growing_rounds}") + result = bench_growing(run, suite) + plot_growing( + result, + initial_n=run.initial_n, + update_n=run.update_n, + filename=f"out/bench_update_growing_m{run.update_n}.jpg", + ) + for size, tc, te, sp in zip( + result["catalog_sizes"], + result["csr_times"], + result["ell_times"], + result["speedups"], + ): + growing_records.append( + { + "update_n": run.update_n, + "catalog_size": size, + "ms_csr": tc, + "ms_ell": te, + "speedup": sp, + } + ) + + gdf = pd.DataFrame(growing_records) + gdf.to_csv("out/bench_update_growing.csv", index=False) + print(f" Saved out/bench_update_growing.csv") + + print("\n Final-round speedup (after all catalog growth):") + print(f" {'update_n':>9} {'catalog_size':>13} {'ms_csr':>9} {'ms_ell':>9} {'speedup':>9}") + unique_update_ns = sorted(gdf["update_n"].unique()) + for update_n in unique_update_ns: + row = gdf[gdf["update_n"] == update_n].iloc[-1] + print( + f" {int(row.update_n):>9} {int(row.catalog_size):>13} " + f"{row.ms_csr:>9.3f} {row.ms_ell:>9.3f} {row.speedup:>8.1f}x" + ) diff --git a/examples/scripts/benchmark/benchmark_vtnk.py b/examples/scripts/benchmark/benchmark_vtnk.py index 2e5aff3..3a346d7 100644 --- a/examples/scripts/benchmark/benchmark_vtnk.py +++ b/examples/scripts/benchmark/benchmark_vtnk.py @@ -4,12 +4,14 @@ vs torch.compile(nn.Linear) + vtnk_pytorch vs torch.compile(sparse_linear_pytorch) -Grid: B (batch size) × N (vocab / logits size). K (hidden dim) fixed. +Grid: B (batch size) × N (vocab / logits size). K (hidden dim) fixed per run. """ import argparse import os import time +from typing import Any, cast + import torch import torch.nn as nn import triton.testing as testing @@ -25,15 +27,12 @@ constrained_node_transition, fused_linear_constrained_node_transition, ) +from benchmark_config import VTNK_SUITE, BenchmarkSuite, RunConfig, trie_runs DEVICE = torch.device("cuda") -K = 512 -WARMUP = 25 -REP = 100 -ALL_ALGORITHMS = ["fused", "kernel", "pytorch", "sparse_pytorch", "dense_lookup", "trie_cpu"] -DEFAULT_ALGORITHMS = ["fused", "kernel", "pytorch", "sparse_pytorch", "dense_lookup"] -DEFAULT_SPARSITY = 0.01 +ALL_ALGORITHMS = ["fused", "kernel", "pytorch", "sparse_pytorch", "dense_lookup", "dense_topk", "trie_cpu"] +DEFAULT_ALGORITHMS = ["fused", "kernel", "pytorch", "sparse_pytorch", "dense_lookup", "dense_topk"] torch.set_float32_matmul_precision("high") @@ -71,17 +70,15 @@ def make_csr_diverse( dense_mask_by_layer=[v.to(DEVICE) for v in csr.dense_mask_by_layer], dense_states=csr.dense_states.to(DEVICE), ) - # Level-1 BFS node IDs are 1..num_nodes (root is 0, its children follow in BFS order) cur_node = (torch.arange(B, dtype=torch.long) % num_nodes + 1).to(DEVICE) return csr, cur_node def make_trie(max_branches: int) -> tuple[Trie, list[TrieNode]]: - """Build a Trie and return it along with a BFS-ordered node list for integer indexing.""" + """Build a Trie and return a BFS-ordered node list for integer indexing.""" trie = Trie() for i in range(max_branches): trie.insert([i]) - # BFS node list so that node index 0 == root (matching cur_node convention) nodes: list[TrieNode] = [] queue = [trie.root] while queue: @@ -114,35 +111,30 @@ def _dense_lookup_inner(a: torch.Tensor, weight: torch.Tensor, mask1d: torch.Ten _dense_lookup_compiled = torch.compile(_dense_lookup_inner) +def _dense_matmul_inner(a: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + return (a @ weight.T).float() + + +_dense_matmul_compiled = torch.compile(_dense_matmul_inner) + + def dense_lookup_pytorch( a: torch.Tensor, weight: torch.Tensor, cur_node: torch.Tensor, csr, step: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Full dense linear + dense mask lookup. - Computes the full logit vector (a @ weight.T) via torch.compile, then masks - invalid tokens using csr.dense_mask_by_layer[step]. Valid branch indices and - next-node IDs are looked up from csr.dense_states (no CSR scatter needed). - Only supports step=0 (all batch items at the root node); incompatible with --diverse-nodes because step=1 requires the previous-token path to index the 2-D dense_mask_by_layer, which is not stored in cur_node. - - Returns (next_node, valid_idxs, corrected_logits) matching vtnk_pytorch output: - next_node: (B, max_branches) int64 — child BFS IDs, -1 for padding - valid_idxs: (B, max_branches) int64 — valid token indices, -1 for padding - corrected_logits: (B, N) float32 — logits, -inf for invalid tokens """ device = a.device B = a.shape[0] - mask1d: torch.Tensor = csr.dense_mask_by_layer[step] # [N] - corrected_logits = _dense_lookup_compiled(a, weight, mask1d) # [B, N] + mask1d: torch.Tensor = csr.dense_mask_by_layer[step] + corrected_logits = _dense_lookup_compiled(a, weight, mask1d) - # Pack valid_idxs and next_node from dense_states — no CSR traversal needed. - # dense_states is 1-D for dense_lookup_layers=1 (the step=0 case): - # dense_states[tok] = BFS node ID of the child reached by tok from root. max_branches = csr.layer_max_branches[step] - valid_toks = mask1d.nonzero(as_tuple=True)[0] # [n_valid] + valid_toks = mask1d.nonzero(as_tuple=True)[0] n_valid = valid_toks.shape[0] padded_valid = valid_toks.new_full((max_branches,), -1) @@ -157,146 +149,150 @@ def dense_lookup_pytorch( return next_node, valid_idxs, corrected_logits -def run_bench(fn): - return testing.do_bench(fn, warmup=WARMUP, rep=REP) +def run_bench(fn, suite: BenchmarkSuite) -> float: + return cast(float, testing.do_bench(fn, warmup=suite.warmup, rep=suite.rep)) -def run_bench_cpu(fn, warmup: int = WARMUP, rep: int = REP) -> float: +def run_bench_cpu(fn, suite: BenchmarkSuite) -> float: """Wall-clock benchmark for CPU-only functions (ms).""" - for _ in range(warmup): + for _ in range(suite.warmup): fn() times = [] - for _ in range(rep): + for _ in range(suite.rep): t0 = time.perf_counter() fn() times.append((time.perf_counter() - t0) * 1e3) times.sort() - return float(sum(times[: max(1, rep // 2)]) / max(1, rep // 2)) + return float(sum(times[: max(1, suite.rep // 2)]) / max(1, suite.rep // 2)) -def benchmark_grid(B_vals, N_vals, algorithms, sparsity, diverse_nodes=False): +def benchmark_run(run: RunConfig, suite: BenchmarkSuite, algorithms: list[str], max_N: int, max_B: int) -> dict[str, Any]: alg_set = set(algorithms) - gpu_algos = alg_set & {"fused", "kernel", "pytorch", "sparse_pytorch", "dense_lookup"} - max_B, max_N = max(B_vals), max(N_vals) - records = [] - - for B in B_vals: - for N in N_vals: - max_branches = max(1, int(N * sparsity)) - print(f" B={B:6d} N={N:6d} max_branches={max_branches}") - - # CPU traversal is too slow to run at the highest N and B values, - # and incompatible with diverse_nodes (requires a deeper Trie object). - # dense_lookup requires step=0 (all batch items at root); diverse_nodes - # uses step=1 where the 2-D dense mask must be indexed by prev_token. - skip_cpu = B == max_B or N == max_N or diverse_nodes - skip_algos = {"trie_cpu"} if skip_cpu else set() - if diverse_nodes: - skip_algos.add("dense_lookup") - active_alg_set = alg_set - skip_algos - - if gpu_algos: - if diverse_nodes: - csr, cur_node = make_csr_diverse( - vocab_size=N, max_branches=max_branches, B=B - ) - step = 1 - else: - csr = make_csr(vocab_size=N, max_branches=max_branches) - cur_node = torch.zeros(B, dtype=torch.long, device=DEVICE) - step = 0 - if "trie_cpu" in active_alg_set: - _, trie_nodes = make_trie(max_branches=max_branches) - - a = torch.randn(B, K, device=DEVICE) - weight = torch.randn(N, K, device=DEVICE) - if not gpu_algos: - cur_node = torch.zeros(B, dtype=torch.long, device=DEVICE) - step = 0 - if "trie_cpu" in active_alg_set: - cur_node_cpu = cur_node.cpu() - - if "pytorch" in active_alg_set: - linear = torch.compile(nn.Linear(K, N, bias=False).to(DEVICE)) - with torch.no_grad(): - linear.weight.data.copy_(weight) - - if "sparse_pytorch" in active_alg_set: - sparse_linear_pytorch_compiled = torch.compile(sparse_linear_pytorch) - - if gpu_algos: - cs = ConstraintState(step=step, trie=csr, cur_node=cur_node) - - # --- warmup / force compilation --- - with torch.no_grad(): - if "fused" in active_alg_set: - fused_linear_constrained_node_transition(a, weight.T, cs) - if "kernel" in active_alg_set: - constrained_node_transition(a @ weight.T, cs) - if "pytorch" in active_alg_set: - vtnk_pytorch(linear(a), cur_node, csr, step=step) - if "sparse_pytorch" in active_alg_set: - sparse_linear_pytorch_compiled(a, weight, cur_node, csr, step=step) - if "dense_lookup" in active_alg_set: - dense_lookup_pytorch(a, weight, cur_node, csr, step=step) - - record = {"B": B, "N": N} - - # --- benchmark --- - with torch.no_grad(): - if "fused" in active_alg_set: - record["ms_fused"] = run_bench( - lambda: fused_linear_constrained_node_transition( - a, weight.T, cs - ) - ) - if "kernel" in active_alg_set: - record["ms_kernel"] = run_bench( - lambda: constrained_node_transition(a @ weight.T, cs) - ) - if "pytorch" in active_alg_set: - record["ms_pytorch"] = run_bench( - lambda: vtnk_pytorch(linear(a), cur_node, csr, step=step) - ) - if "sparse_pytorch" in active_alg_set: - record["ms_sparse_pytorch"] = run_bench( - lambda: sparse_linear_pytorch_compiled( - a, weight, cur_node, csr, step=step - ) - ) - if "dense_lookup" in active_alg_set: - record["ms_dense_lookup"] = run_bench( - lambda: dense_lookup_pytorch(a, weight, cur_node, csr, step=step) - ) - if "trie_cpu" in active_alg_set: - record["ms_trie_cpu"] = run_bench_cpu( - lambda: trie_cpu_traversal(cur_node_cpu, trie_nodes, N) - ) - - if "fused" in active_alg_set and "kernel" in active_alg_set: - record["speedup_fused_vs_kernel"] = ( - record["ms_kernel"] / record["ms_fused"] - ) - if "fused" in active_alg_set and "pytorch" in active_alg_set: - record["speedup_fused_vs_pytorch"] = ( - record["ms_pytorch"] / record["ms_fused"] - ) - if "fused" in active_alg_set and "sparse_pytorch" in active_alg_set: - record["speedup_fused_vs_sparse_pytorch"] = ( - record["ms_sparse_pytorch"] / record["ms_fused"] - ) - if "fused" in active_alg_set and "trie_cpu" in active_alg_set: - record["speedup_fused_vs_trie_cpu"] = ( - record["ms_trie_cpu"] / record["ms_fused"] - ) - if "fused" in active_alg_set and "dense_lookup" in active_alg_set: - record["speedup_fused_vs_dense_lookup"] = ( - record["ms_dense_lookup"] / record["ms_fused"] - ) - - records.append(record) - - return pd.DataFrame(records) + gpu_algos = alg_set & {"fused", "kernel", "pytorch", "sparse_pytorch", "dense_lookup", "dense_topk"} + max_branches = max(1, int(run.N * run.sparsity)) + print(f" B={run.B:6d} N={run.N:6d} max_branches={max_branches}") + + # trie_cpu is too slow at max B/N and incompatible with diverse_nodes. + # dense_lookup requires step=0; diverse_nodes uses step=1. + skip_cpu = run.B == max_B or run.N == max_N or suite.diverse_nodes + skip_algos = {"trie_cpu"} if skip_cpu else set() + if suite.diverse_nodes: + skip_algos.add("dense_lookup") + active_alg_set = alg_set - skip_algos + + if gpu_algos: + if suite.diverse_nodes: + csr, cur_node = make_csr_diverse( + vocab_size=run.N, max_branches=max_branches, B=run.B + ) + step = 1 + else: + csr = make_csr(vocab_size=run.N, max_branches=max_branches) + cur_node = torch.zeros(run.B, dtype=torch.long, device=DEVICE) + step = 0 + if "trie_cpu" in active_alg_set: + _, trie_nodes = make_trie(max_branches=max_branches) + + a = torch.randn(run.B, run.k, device=DEVICE) + weight = torch.randn(run.N, run.k, device=DEVICE) + if not gpu_algos: + cur_node = torch.zeros(run.B, dtype=torch.long, device=DEVICE) + step = 0 + if "trie_cpu" in active_alg_set: + cur_node_cpu = cur_node.cpu() + + if "pytorch" in active_alg_set: + linear = torch.compile(nn.Linear(run.k, run.N, bias=False).to(DEVICE)) + with torch.no_grad(): + linear.weight.data.copy_(weight) + + if "sparse_pytorch" in active_alg_set: + sparse_linear_pytorch_compiled = torch.compile(sparse_linear_pytorch) + + if gpu_algos: + cs = ConstraintState(step=step, trie=csr, cur_node=cur_node) + + # --- warmup / force compilation --- + with torch.no_grad(): + if "fused" in active_alg_set: + fused_linear_constrained_node_transition(a, weight.T, cs) + if "kernel" in active_alg_set: + constrained_node_transition(a @ weight.T, cs) + if "pytorch" in active_alg_set: + vtnk_pytorch(linear(a), cur_node, csr, step=step) + if "sparse_pytorch" in active_alg_set: + sparse_linear_pytorch_compiled(a, weight, cur_node, csr, step=step) + if "dense_lookup" in active_alg_set: + dense_lookup_pytorch(a, weight, cur_node, csr, step=step) + if "dense_topk" in active_alg_set: + _, _, bl = constrained_node_transition(_dense_matmul_compiled(a, weight), cs) + torch.topk(bl, run.k_top, dim=-1) + + record: dict[str, Any] = {"B": run.B, "N": run.N} + + # --- benchmark --- + with torch.no_grad(): + if "fused" in active_alg_set: + record["ms_fused"] = run_bench( + lambda: fused_linear_constrained_node_transition(a, weight.T, cs), + suite, + ) + if "kernel" in active_alg_set: + record["ms_kernel"] = run_bench( + lambda: constrained_node_transition(a @ weight.T, cs), + suite, + ) + if "pytorch" in active_alg_set: + record["ms_pytorch"] = run_bench( + lambda: vtnk_pytorch(linear(a), cur_node, csr, step=step), + suite, + ) + if "sparse_pytorch" in active_alg_set: + record["ms_sparse_pytorch"] = run_bench( + lambda: sparse_linear_pytorch_compiled(a, weight, cur_node, csr, step=step), + suite, + ) + if "dense_lookup" in active_alg_set: + record["ms_dense_lookup"] = run_bench( + lambda: dense_lookup_pytorch(a, weight, cur_node, csr, step=step), + suite, + ) + if "dense_topk" in active_alg_set: + record["ms_dense_topk"] = run_bench( + lambda: torch.topk( + constrained_node_transition(_dense_matmul_compiled(a, weight), cs)[2], run.k_top, dim=-1 + ), + suite, + ) + if "trie_cpu" in active_alg_set: + record["ms_trie_cpu"] = run_bench_cpu( + lambda: trie_cpu_traversal(cur_node_cpu, trie_nodes, run.N), + suite, + ) + + if "fused" in active_alg_set and "kernel" in active_alg_set: + record["speedup_fused_vs_kernel"] = record["ms_kernel"] / record["ms_fused"] + if "fused" in active_alg_set and "pytorch" in active_alg_set: + record["speedup_fused_vs_pytorch"] = record["ms_pytorch"] / record["ms_fused"] + if "fused" in active_alg_set and "sparse_pytorch" in active_alg_set: + record["speedup_fused_vs_sparse_pytorch"] = record["ms_sparse_pytorch"] / record["ms_fused"] + if "fused" in active_alg_set and "trie_cpu" in active_alg_set: + record["speedup_fused_vs_trie_cpu"] = record["ms_trie_cpu"] / record["ms_fused"] + if "fused" in active_alg_set and "dense_lookup" in active_alg_set: + record["speedup_fused_vs_dense_lookup"] = record["ms_dense_lookup"] / record["ms_fused"] + if "fused" in active_alg_set and "dense_topk" in active_alg_set: + record["speedup_fused_vs_dense_topk"] = record["ms_dense_topk"] / record["ms_fused"] + + return record + + +def benchmark_grid(suite: BenchmarkSuite, algorithms: list[str]) -> pd.DataFrame: + max_B = max(r.B for r in suite.runs) + max_N = max(r.N for r in suite.runs) + return pd.DataFrame([ + benchmark_run(run, suite, algorithms, max_N=max_N, max_B=max_B) + for run in suite.runs + ]) def plot_heatmap( @@ -332,44 +328,20 @@ def plot_heatmap( metavar="ALGO", help=f"Algorithms to benchmark. Choices: {ALL_ALGORITHMS} (default: {DEFAULT_ALGORITHMS})", ) - parser.add_argument( - "--sparsity", - type=float, - default=DEFAULT_SPARSITY, - help="Fraction of vocab used as max branches (default: %(default)s)", - ) - parser.add_argument( - "--diverse-nodes", - action="store_true", - default=False, - help=( - "Place each batch element on a different trie node (2-level trie, step=1) " - "instead of all starting at the root. Tests realistic cache-miss pressure. " - "Incompatible with trie_cpu (automatically skipped)." - ), - ) args = parser.parse_args() assert torch.cuda.is_available(), "CUDA required" os.makedirs("out", exist_ok=True) - B_vals = [256, 1024, 4096] - N_vals = [150000] + suite = VTNK_SUITE - print( - f"Benchmarking K={K}, sparsity={args.sparsity}, diverse_nodes={args.diverse_nodes}" - ) print(f"Algorithms: {args.algorithms}") - print(f"B_vals={B_vals}") - print(f"N_vals={N_vals}\n") - - df = benchmark_grid( - B_vals, - N_vals, - algorithms=args.algorithms, - sparsity=args.sparsity, - diverse_nodes=args.diverse_nodes, - ) + print(f"diverse_nodes={suite.diverse_nodes}") + for run in suite.runs: + print(f" B={run.B}, N={run.N}, k={run.k}, sparsity={run.sparsity}") + print() + + df = benchmark_grid(suite, algorithms=args.algorithms) csv_path = "out/bench_vtnk.csv" df.to_csv(csv_path, index=False) print(f"\nSaved {csv_path}\n") @@ -380,7 +352,7 @@ def plot_heatmap( plot_heatmap( df, value_col="speedup_fused_vs_kernel", - title=f"Fused speedup vs compiled_linear+constrained_kernel (K={K})", + title=f"Fused speedup vs compiled_linear+constrained_kernel (K={suite.runs[0].k})", filename="out/heatmap_fused_vs_kernel.jpg", cbar_label="Speedup (>1 = fused faster)", ) @@ -388,7 +360,7 @@ def plot_heatmap( plot_heatmap( df, value_col="speedup_fused_vs_pytorch", - title=f"Fused speedup vs compiled_linear+vtnk_pytorch (K={K})", + title=f"Fused speedup vs compiled_linear+vtnk_pytorch (K={suite.runs[0].k})", filename="out/heatmap_fused_vs_pytorch.jpg", cbar_label="Speedup (>1 = fused faster)", ) @@ -396,7 +368,7 @@ def plot_heatmap( plot_heatmap( df, value_col="speedup_fused_vs_sparse_pytorch", - title=f"Fused speedup vs sparse_linear_pytorch (K={K})", + title=f"Fused speedup vs sparse_linear_pytorch (K={suite.runs[0].k})", filename="out/heatmap_fused_vs_sparse_pytorch.jpg", cbar_label="Speedup (>1 = fused faster)", ) @@ -404,7 +376,7 @@ def plot_heatmap( plot_heatmap( df, value_col="speedup_fused_vs_trie_cpu", - title=f"Fused speedup vs CPU trie traversal (K={K})", + title=f"Fused speedup vs CPU trie traversal (K={suite.runs[0].k})", filename="out/heatmap_fused_vs_trie_cpu.jpg", cbar_label="Speedup (>1 = fused faster)", ) @@ -412,7 +384,15 @@ def plot_heatmap( plot_heatmap( df, value_col="speedup_fused_vs_dense_lookup", - title=f"Fused speedup vs dense_lookup (K={K})", + title=f"Fused speedup vs dense_lookup (K={suite.runs[0].k})", filename="out/heatmap_fused_vs_dense_lookup.jpg", cbar_label="Speedup (>1 = fused faster)", ) + if "speedup_fused_vs_dense_topk" in df.columns: + plot_heatmap( + df, + value_col="speedup_fused_vs_dense_topk", + title=f"Fused speedup vs compile(dense_matmul)+topk (K={suite.runs[0].k}, k_top={suite.runs[0].k_top})", + filename="out/heatmap_fused_vs_dense_topk.jpg", + cbar_label="Speedup (>1 = fused faster)", + ) diff --git a/examples/scripts/benchmark/out/bench_fused_sample.csv b/examples/scripts/benchmark/out/bench_fused_sample.csv new file mode 100644 index 0000000..25d8836 --- /dev/null +++ b/examples/scripts/benchmark/out/bench_fused_sample.csv @@ -0,0 +1,5 @@ +B,N,ms_fused_sample,ms_sparse_pytorch_sample_compact,ms_fused_topk,ms_sparse_pytorch_topk_compact,speedup_fused_vs_sparse_pytorch_sample_compact,speedup_fused_topk_vs_sparse_pytorch_topk_compact +256,256,0.06371258458168041,0.17934747504583304,0.0755965249693912,0.08505579645309268,2.814945841914592,1.1251283903265594 +1024,256,0.18795623520717902,0.4799594249307495,0.19442370589132663,0.26744760695050973,2.553570113817737,1.375591550034541 +256,150000,0.28657493044932686,0.9350657916780728,0.29325902286697836,0.572092049577263,3.262901574160599,1.9508080057838926 +1024,150000,0.9918080000206828,3.537914663553238,0.9786599370149466,2.160640864758878,3.567136646890789,2.2077544845139396 diff --git a/examples/scripts/benchmark/out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg b/examples/scripts/benchmark/out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg new file mode 100644 index 0000000..5ba3354 Binary files /dev/null and b/examples/scripts/benchmark/out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg differ diff --git a/out/bench_fused_sample.csv b/out/bench_fused_sample.csv index 729f224..9608b85 100644 --- a/out/bench_fused_sample.csv +++ b/out/bench_fused_sample.csv @@ -1,7 +1,5 @@ -B,N,ms_fused_sample,ms_sparse_pytorch_sample,speedup_fused_vs_sparse_pytorch_sample -256,256,0.052990657996810124,0.4847820407273818,9.148443500297056 -256,150000,0.2794327307338557,5.122409343719482,18.331457915709578 -1024,256,0.053247999399900436,0.5665876357392832,10.640543158891761 -1024,150000,0.9518530856479298,19.950080394744873,20.95920126283436 -4096,256,0.06617873177931803,0.6101244225221522,9.219342923595073 -4096,150000,4.0939519973028276,76.34329223632812,18.64782300491665 +B,N,ms_fused_sample,ms_sparse_pytorch_sample_compact,ms_fused_topk,ms_sparse_pytorch_topk_compact,ms_dense_topk,speedup_fused_vs_sparse_pytorch_sample_compact,speedup_fused_topk_vs_sparse_pytorch_topk_compact,speedup_fused_topk_vs_dense_topk +256,256,0.0639494627078132,0.17952672111383383,0.07594812758590865,0.0851272442993128,0.26681392294604606,2.807321805565376,1.120860342514978,3.513107319785333 +1024,256,0.1877503998577595,0.4797707233846802,0.19362583933053193,0.26734662324678704,0.3056407542819651,2.5553645890935868,1.3807383568801936,1.5785122240850116 +256,150000,0.28675306439399717,0.9354942076241792,0.2990090902579033,0.5744870197906923,39.74041557312012,3.262368650187519,1.9213028583685599,132.90704820660454 +1024,150000,0.9933454981073737,3.5444053411483765,0.9791015414091256,2.151728365872357,157.49119567871094,3.5681495994108294,2.1976559885459723,160.8527706452685 diff --git a/out/bench_vtnk.csv b/out/bench_vtnk.csv index c8e66c6..42b244f 100644 --- a/out/bench_vtnk.csv +++ b/out/bench_vtnk.csv @@ -1,4 +1,3 @@ -B,N,ms_fused,ms_kernel,ms_pytorch,ms_sparse_pytorch,ms_dense_lookup,speedup_fused_vs_kernel,speedup_fused_vs_pytorch,speedup_fused_vs_sparse_pytorch,speedup_fused_vs_dense_lookup -256,150000,0.3935734379403996,37.604352951049805,2.139191369752626,1.098073067267736,2.1511204371581205,95.54596252185185,5.4353042241549145,2.790008068161402,5.46561385955084 -1024,150000,1.507946029305458,148.959228515625,7.843095259232954,3.6962379268977954,7.717296004295349,98.78286465214795,5.201177699208101,2.451173884916981,5.1177534568991465 -4096,150000,7.3295360406239825,592.3890991210938,30.60221799214681,14.427306652069092,30.289578119913738,80.82218244617052,4.175191693244142,1.9683792496695136,4.132536896201019 +B,N,ms_fused,ms_kernel,ms_pytorch,ms_sparse_pytorch,ms_dense_lookup,ms_dense_topk,speedup_fused_vs_kernel,speedup_fused_vs_pytorch,speedup_fused_vs_sparse_pytorch,speedup_fused_vs_dense_lookup,speedup_fused_vs_dense_topk +256,150000,0.5463001217160907,37.63763236999512,2.1372595495647855,1.0868175733284873,2.1542763065647437,39.379966735839844,68.89552257789025,3.91224432249991,1.9894148474916518,3.9433934222777087,72.08485806690962 +1024,150000,2.107834861085222,148.9899444580078,7.86804368279197,3.6992222744485606,7.733674685160319,155.652099609375,70.68387908780488,3.7327609615210067,1.754986760463773,3.669013558860406,73.84454184861393 diff --git a/out/heatmap_ell_sample_vs_sparse_pytorch.jpg b/out/heatmap_ell_sample_vs_sparse_pytorch.jpg new file mode 100644 index 0000000..1f48fa0 Binary files /dev/null and b/out/heatmap_ell_sample_vs_sparse_pytorch.jpg differ diff --git a/out/heatmap_ell_vs_csr_sample.jpg b/out/heatmap_ell_vs_csr_sample.jpg new file mode 100644 index 0000000..88fe6de Binary files /dev/null and b/out/heatmap_ell_vs_csr_sample.jpg differ diff --git a/out/heatmap_ell_vs_csr_topk.jpg b/out/heatmap_ell_vs_csr_topk.jpg new file mode 100644 index 0000000..849240f Binary files /dev/null and b/out/heatmap_ell_vs_csr_topk.jpg differ diff --git a/out/heatmap_fused_sample_vs_sparse_pytorch.jpg b/out/heatmap_fused_sample_vs_sparse_pytorch.jpg index a5dacba..3218ed7 100644 Binary files a/out/heatmap_fused_sample_vs_sparse_pytorch.jpg and b/out/heatmap_fused_sample_vs_sparse_pytorch.jpg differ diff --git a/out/heatmap_fused_topk_vs_dense_topk.jpg b/out/heatmap_fused_topk_vs_dense_topk.jpg new file mode 100644 index 0000000..dc7c625 Binary files /dev/null and b/out/heatmap_fused_topk_vs_dense_topk.jpg differ diff --git a/out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg b/out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg index a8d2943..9511959 100644 Binary files a/out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg and b/out/heatmap_fused_topk_vs_sparse_pytorch_topk_compact.jpg differ diff --git a/out/heatmap_fused_vs_dense_lookup.jpg b/out/heatmap_fused_vs_dense_lookup.jpg index 86a1ee0..865efcd 100644 Binary files a/out/heatmap_fused_vs_dense_lookup.jpg and b/out/heatmap_fused_vs_dense_lookup.jpg differ diff --git a/out/heatmap_fused_vs_dense_topk.jpg b/out/heatmap_fused_vs_dense_topk.jpg new file mode 100644 index 0000000..f9fcd90 Binary files /dev/null and b/out/heatmap_fused_vs_dense_topk.jpg differ diff --git a/out/heatmap_fused_vs_kernel.jpg b/out/heatmap_fused_vs_kernel.jpg index bca4cc6..1a5d484 100644 Binary files a/out/heatmap_fused_vs_kernel.jpg and b/out/heatmap_fused_vs_kernel.jpg differ diff --git a/out/heatmap_fused_vs_pytorch.jpg b/out/heatmap_fused_vs_pytorch.jpg index 0c04fb5..3a9c5a4 100644 Binary files a/out/heatmap_fused_vs_pytorch.jpg and b/out/heatmap_fused_vs_pytorch.jpg differ diff --git a/out/heatmap_fused_vs_sparse_pytorch.jpg b/out/heatmap_fused_vs_sparse_pytorch.jpg index 4e93637..42a5224 100644 Binary files a/out/heatmap_fused_vs_sparse_pytorch.jpg and b/out/heatmap_fused_vs_sparse_pytorch.jpg differ diff --git a/rectokens/decoding/vntk.py b/rectokens/decoding/vntk.py index a032d7f..de363d7 100644 --- a/rectokens/decoding/vntk.py +++ b/rectokens/decoding/vntk.py @@ -1,6 +1,66 @@ import torch +def _sparse_ell_branch_logits(a, weight, cur_node, ell_trie, step): + """Shared inner: ELL trie traversal + dot products for valid branches only. + + Returns (next_node, valid_idxs, branch_logits) where branch_logits has + shape (B, max_branches) with -inf padding — no (B, N) scatter. + weight shape: (N, K) — standard nn.Linear weight layout. + """ + max_branches = ell_trie.layer_max_branches[step] + + # Direct index by node ID — no row_ptr indirection (ELL advantage over CSR) + cols = ell_trie.ell_cols_vals[cur_node, 0, :max_branches] # (B, max_branches) + vals = ell_trie.ell_cols_vals[cur_node, 1, :max_branches] # (B, max_branches) + + valid_range = cols >= 0 # -1 sentinel marks padding + valid_idxs = torch.where(valid_range, cols, -1) + next_node = torch.where(valid_range, vals, -1) + + clamped_idxs = valid_idxs.clamp(min=0) # (B, max_branches) + valid_weights = weight[clamped_idxs] # (B, max_branches, K) + logits = (a.unsqueeze(1) * valid_weights).to(torch.float32).sum(dim=-1) # (B, max_branches) + branch_logits = torch.where(valid_range, logits, float("-inf")) + + return next_node, valid_idxs, branch_logits + + +def sparse_linear_ell_pytorch(a, weight, cur_node, ell_trie, step): + """ + PyTorch impl for ELL-format trie that only computes logits for valid tokens. + Equivalent to sparse_linear_pytorch but uses ELL format (no row_ptr lookup). + weight shape: (N, K) — standard nn.Linear weight layout. + """ + device = ell_trie.ell_cols_vals.device + B = a.shape[0] + N = weight.shape[0] + + next_node, valid_idxs, branch_logits = _sparse_ell_branch_logits( + a, weight, cur_node, ell_trie, step + ) + + corrected_logits = torch.full( + (B, N), float("-inf"), dtype=torch.float32, device=device + ) + b_idx = torch.arange(B, device=device).unsqueeze(-1).expand_as(valid_idxs) + valid = valid_idxs >= 0 + corrected_logits[b_idx[valid], valid_idxs[valid]] = branch_logits[valid] + + return next_node, valid_idxs, corrected_logits + + +def sparse_linear_compact_ell_pytorch(a, weight, cur_node, ell_trie, step): + """Like sparse_linear_ell_pytorch but skips the (B, N) scatter. + + Returns (next_node, valid_idxs, branch_logits) where branch_logits has + shape (B, max_branches). Avoids allocating the full vocab-sized logit + matrix — top-k can be applied directly on the compact buffer. + weight shape: (N, K) — standard nn.Linear weight layout. + """ + return _sparse_ell_branch_logits(a, weight, cur_node, ell_trie, step) + + def _sparse_branch_logits(a, weight, cur_node, trie, step): """Shared inner: trie traversal + dot products for valid branches only. diff --git a/rectokens/kernels/constrained_node_transition.py b/rectokens/kernels/constrained_node_transition.py index 819162c..0b51a8c 100644 --- a/rectokens/kernels/constrained_node_transition.py +++ b/rectokens/kernels/constrained_node_transition.py @@ -177,13 +177,15 @@ def _constrained_node_transition_kernel( _FUSED_AUTOTUNE_CONFIGS = [ - triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}), - triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}), - triton.Config({"BLOCK_B": 256, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}), - triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}), - triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}), - triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}), - triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_stages=2), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_stages=3), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_stages=2), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_stages=3), + triton.Config({"BLOCK_B": 256, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_stages=2), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_stages=2), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_stages=2), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_stages=2), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_stages=2), # triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 64}), # triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 64}), ] @@ -252,8 +254,11 @@ def _compute_branch_logits( ): """Load all branch data and compute per-branch dot-product logits. - Loads a_chunk once per K-tile and amortizes it across all BLOCK_BRANCHES branches, - reducing query-vector HBM traffic by BLOCK_BRANCHES× vs computing each branch separately. + Loads a_chunk once per K-tile (amortized across BLOCK_BRANCHES branches). + Within each K-tile, b_chunk loads are software-pipelined: the load for branch + i+1 is issued before the dot-product computation for branch i, overlapping + HBM gather latency with compute. num_stages on the outer k_tile loop drives + a_chunk prefetch via Triton's auto-pipeliner. Returns (branch_cols, branch_vals, branch_valid, logits) as [BLOCK_B, BLOCK_BRANCHES] tensors. Bias is folded into logits when HAS_BIAS=True. @@ -284,22 +289,46 @@ def _compute_branch_logits( k_mask = offs_K < K # Load a_chunk once per tile; reused across all BLOCK_BRANCHES branches. + # Triton's auto-pipeliner (num_stages) overlaps this load with the + # previous tile's inner branch loop. a_chunk = tl.load( a_ptr + offs_B[:, None] * a_stride_B + offs_K[None, :] * a_stride_K, mask=b_mask[:, None] & k_mask[None, :], other=0.0, ) # [BLOCK_B, BLOCK_K] + # Software pipeline prologue: issue b_chunk load for branch 0 before + # entering the inner loop so it is in flight during the first dot-product. + _, col_k_next, c_mask_next = _select_branch(0, branch_cols, branch_valid, BLOCK_BRANCHES) + b_chunk_next = tl.load( + b_ptr + offs_K[None, :] * b_stride_K + col_k_next[:, None] * b_stride_N, + mask=c_mask_next[:, None] & k_mask[None, :], + other=0.0, + ) + + # tl.static_range fully unrolls the loop at compile time, turning each + # iteration into a distinct sequence of IR operations. This lets us alias + # b_chunk_next → b_chunk_cur and immediately rebind b_chunk_next to the + # *next* branch's load, creating a double-buffer: load i+1 is in flight + # while dot(a_chunk, b_chunk_i) executes. for local_br in tl.static_range(BLOCK_BRANCHES): - br_sel, col_k, c_mask = _select_branch( - local_br, branch_cols, branch_valid, BLOCK_BRANCHES - ) - b_chunk = tl.load( - b_ptr + offs_K[None, :] * b_stride_K + col_k[:, None] * b_stride_N, - mask=c_mask[:, None] & k_mask[None, :], - other=0.0, - ) # [BLOCK_B, BLOCK_K] - dot = tl.sum(a_chunk * b_chunk, axis=1) # [BLOCK_B] + b_chunk_cur = b_chunk_next # use the prefetched data + br_sel = tl.arange(0, BLOCK_BRANCHES) == local_br + + # Prefetch next branch's b_chunk while the current dot-product executes. + # The Python-level if is a compile-time guard; the last iteration emits + # no prefetch load, avoiding an out-of-bounds address. + if local_br + 1 < BLOCK_BRANCHES: + _, col_k_next, c_mask_next = _select_branch( + local_br + 1, branch_cols, branch_valid, BLOCK_BRANCHES + ) + b_chunk_next = tl.load( + b_ptr + offs_K[None, :] * b_stride_K + col_k_next[:, None] * b_stride_N, + mask=c_mask_next[:, None] & k_mask[None, :], + other=0.0, + ) + + dot = tl.sum(a_chunk * b_chunk_cur, axis=1) # [BLOCK_B] logits = tl.where(br_sel[None, :], logits + dot[:, None], logits) if HAS_BIAS: diff --git a/rectokens/kernels/constrained_node_transition_ell.py b/rectokens/kernels/constrained_node_transition_ell.py new file mode 100644 index 0000000..43c1395 --- /dev/null +++ b/rectokens/kernels/constrained_node_transition_ell.py @@ -0,0 +1,500 @@ +import time +from typing import Optional + +import torch + +assert torch.cuda.is_available(), "CUDA is required to import ELL constrained node transition kernels." + +import triton +import triton.language as tl +from torch.library import triton_op, wrap_triton + +from rectokens.kernels.constrained_node_transition import ( + _select_branch, + _FUSED_MIN_BLOCK_BRANCHES, +) + + +# ───────────────────────────────────────────────────────────────────────────── +# ELL-specific autotune configs +# ───────────────────────────────────────────────────────────────────────────── + +# Separate from _FUSED_AUTOTUNE_CONFIGS: adds BLOCK_B=32 to prevent register +# spilling at large batch sizes, and explicit num_warps for occupancy tuning. +_ELL_AUTOTUNE_CONFIGS = [ + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=4), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=4), + triton.Config({"BLOCK_B": 256, "BLOCK_K": 64, "BLOCK_BRANCHES": 4}, num_warps=4), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=4), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 128, "BLOCK_BRANCHES": 8}, num_warps=4), + triton.Config({"BLOCK_B": 64, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_warps=4), + triton.Config({"BLOCK_B": 128, "BLOCK_K": 64, "BLOCK_BRANCHES": 16}, num_warps=4), +] + + +# ───────────────────────────────────────────────────────────────────────────── +# Shared Triton device-function helpers +# ───────────────────────────────────────────────────────────────────────────── + + +@triton.jit +def _compute_ell_branch_logits( + offs_B, + offs_BR, + b_mask, + cur_node, + branch_cols, # [BLOCK_B, BLOCK_BRANCHES] — pre-loaded by caller + branch_valid, # [BLOCK_B, BLOCK_BRANCHES] — pre-computed by caller + a_ptr, + b_ptr, + bias_ptr, + ell_cols_vals_ptr, + a_stride_B, + a_stride_K, + b_stride_K, + b_stride_N, + ell_node_stride, # ell_cols_vals.stride(0) = 2 * max_branches + ell_cv_stride, # ell_cols_vals.stride(1) = max_branches (cols→vals gap within a node) + K: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_BRANCHES: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """Compute per-branch dot-product logits from ELL format. + + branch_cols and branch_valid are pre-loaded by the caller so validity can be + checked before this function is called, enabling an early exit for empty tiles. + + ell_node_base [BLOCK_B] avoids the former ell_row_offsets [BLOCK_B, BLOCK_BRANCHES] + intermediate, reducing register footprint by BLOCK_BRANCHES×. + """ + ell_node_base = cur_node.to(tl.int64) * ell_node_stride # [BLOCK_B] + + branch_vals = tl.load( + ell_cols_vals_ptr + ell_cv_stride + ell_node_base[:, None] + offs_BR[None, :].to(tl.int64), + mask=branch_valid, + other=-1, + ) # [BLOCK_B, BLOCK_BRANCHES] + + logits = tl.zeros((BLOCK_B, BLOCK_BRANCHES), dtype=tl.float32) + + for k_tile in range(0, tl.cdiv(K, BLOCK_K)): + offs_K = k_tile * BLOCK_K + tl.arange(0, BLOCK_K) + k_mask = offs_K < K + + a_chunk = tl.load( + a_ptr + offs_B[:, None] * a_stride_B + offs_K[None, :] * a_stride_K, + mask=b_mask[:, None] & k_mask[None, :], + other=0.0, + ) # [BLOCK_B, BLOCK_K] + + for local_br in tl.static_range(BLOCK_BRANCHES): + br_sel, col_k, c_mask = _select_branch( + local_br, branch_cols, branch_valid, BLOCK_BRANCHES + ) + b_chunk = tl.load( + b_ptr + offs_K[None, :] * b_stride_K + col_k[:, None] * b_stride_N, + mask=c_mask[:, None] & k_mask[None, :], + other=0.0, + ) # [BLOCK_B, BLOCK_K] + dot = tl.sum(a_chunk * b_chunk, axis=1) # [BLOCK_B] + logits = tl.where(br_sel[None, :], logits + dot[:, None], logits) + + if HAS_BIAS: + for local_br in tl.static_range(BLOCK_BRANCHES): + br_sel, col_k, c_mask = _select_branch( + local_br, branch_cols, branch_valid, BLOCK_BRANCHES + ) + bias_k = tl.load(bias_ptr + col_k, mask=c_mask, other=0.0) + logits = tl.where(br_sel[None, :], logits + bias_k[:, None], logits) + + return branch_vals, logits + + +@triton.jit +def _ell_fused_prologue( + cur_node_ptr, + ell_n_children_ptr, + a_ptr, + b_ptr, + bias_ptr, + ell_cols_vals_ptr, + next_node_ptr, + valid_idxs_ptr, + a_stride_B, + a_stride_K, + b_stride_K, + b_stride_N, + ell_node_stride, + ell_cv_stride, + next_node_stride_B, + next_node_stride_N, + valid_idxs_stride_B, + valid_idxs_stride_N, + max_branches: tl.constexpr, + B: tl.constexpr, + K: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_BRANCHES: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + """Shared ELL prologue: pid setup, branch-cols load, validity check, logit compute, stores. + + Uses n_children as a block-level gate: if no batch item has a child in + [pid_BR*BLOCK_BRANCHES, ...), the ell_cols_vals load is skipped entirely, avoiding + HBM traffic for padding blocks when max_branches is large. Within valid blocks, + per-lane validity is derived from the -1 sentinel in branch_cols. + + Returns any_valid so the calling kernel can skip its own work (gumbel sampling / + logit stores) when the block has no valid branches. + """ + pid_B = tl.program_id(axis=0) + pid_BR = tl.program_id(axis=1) + + offs_B = pid_B * BLOCK_B + tl.arange(0, BLOCK_B) + offs_BR = pid_BR * BLOCK_BRANCHES + tl.arange(0, BLOCK_BRANCHES) + b_mask = offs_B < B + + cur_node = tl.load(cur_node_ptr + offs_B, mask=b_mask, other=-1) + n_children = tl.load(ell_n_children_ptr + cur_node, mask=cur_node >= 0, other=0) + + # Block-level gate: skip ELL load if no batch item has a child in this branch slice. + # n_children is a compact (num_nodes,) int tensor that fits in L2; checking it here + # avoids reading the much larger ell_cols_vals for blocks that are entirely padding. + block_br_start = pid_BR * BLOCK_BRANCHES + n_in_range = tl.sum((n_children > block_br_start).to(tl.int32)) + if n_in_range > 0: + ell_node_base = cur_node.to(tl.int64) * ell_node_stride + load_mask = b_mask[:, None] & (offs_BR[None, :] < max_branches) + branch_cols = tl.load( + ell_cols_vals_ptr + ell_node_base[:, None] + offs_BR[None, :].to(tl.int64), + mask=load_mask, + other=-1, + ) # [BLOCK_B, BLOCK_BRANCHES] + branch_valid = b_mask[:, None] & (branch_cols >= 0) # -1 sentinel marks padding + + any_valid = tl.sum(branch_valid.to(tl.int32)) > 0 + if any_valid: + branch_vals, logits = _compute_ell_branch_logits( + offs_B, offs_BR, b_mask, cur_node, branch_cols, branch_valid, + a_ptr, b_ptr, bias_ptr, ell_cols_vals_ptr, + a_stride_B, a_stride_K, b_stride_K, b_stride_N, + ell_node_stride, ell_cv_stride, + K, BLOCK_B, BLOCK_K, BLOCK_BRANCHES, HAS_BIAS, + ) + store_mask = b_mask[:, None] & (offs_BR[None, :] < max_branches) + tl.store( + next_node_ptr + + offs_B[:, None] * next_node_stride_B + + offs_BR[None, :] * next_node_stride_N, + branch_vals, + mask=store_mask, + ) + tl.store( + valid_idxs_ptr + + offs_B[:, None] * valid_idxs_stride_B + + offs_BR[None, :] * valid_idxs_stride_N, + branch_cols, + mask=store_mask, + ) + else: + logits = tl.zeros([BLOCK_B, BLOCK_BRANCHES], dtype=tl.float32) + else: + branch_cols = tl.full([BLOCK_B, BLOCK_BRANCHES], -1, dtype=tl.int64) + branch_valid = branch_cols >= 0 # all False + logits = tl.zeros([BLOCK_B, BLOCK_BRANCHES], dtype=tl.float32) + any_valid = n_in_range > 0 # False + + return pid_BR, offs_B, offs_BR, b_mask, branch_cols, branch_valid, logits, any_valid + + +# ───────────────────────────────────────────────────────────────────────────── +# Fused sparse linear + Gumbel-max sampling (ELL) +# ───────────────────────────────────────────────────────────────────────────── + + +@triton_op("vtnk_ell::_fused_linear_constrained_node_transition_sampling_op", mutates_args={}) +def _ell_fused_linear_constrained_node_transition_sampling_op( + a: torch.Tensor, + b: torch.Tensor, + bias_val: torch.Tensor, + cur_node: torch.Tensor, + ell_cols_vals: torch.Tensor, + ell_n_children: torch.Tensor, + max_branches: int, + has_bias: bool, + rng_seed: Optional[int] = None, + temperature: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if rng_seed is None: + rng_seed = time.time_ns() & 0x7FFFFFFF + if temperature is None or temperature == 1.0: + temperature = torch.ones(1, dtype=torch.float32, device=a.device) + elif isinstance(temperature, float): + temperature = torch.tensor(temperature, dtype=torch.float32, device=a.device) + + B, K = a.shape + + assert cur_node.shape == (B,), f"Expected cur_node shape ({B},), got {cur_node.shape}" + + a = a.contiguous() + cur_node = cur_node.contiguous() + ell_cols_vals = ell_cols_vals.contiguous() + ell_n_children = ell_n_children.contiguous() + bias_val = bias_val.contiguous() + + next_node = cur_node.new_full((B, max_branches), -1) + valid_idxs = cur_node.new_full((B, max_branches), -1) + max_br_blocks = triton.cdiv(max_branches, _FUSED_MIN_BLOCK_BRANCHES) + gumbel_block_max = torch.full( + (B, max_br_blocks), float("-inf"), dtype=torch.float32, device=a.device + ) + block_sample_buf = torch.full( + (B, max_br_blocks), -1.0, dtype=torch.float32, device=a.device + ) + + grid = lambda meta: ( + triton.cdiv(B, meta["BLOCK_B"]), + triton.cdiv(max_branches, meta["BLOCK_BRANCHES"]), + ) + wrap_triton(_ell_fused_sampling_kernel)[grid]( + a_ptr=a, + b_ptr=b, + bias_ptr=bias_val, + cur_node_ptr=cur_node, + ell_cols_vals_ptr=ell_cols_vals, + ell_n_children_ptr=ell_n_children, + temperature_ptr=temperature, + a_stride_B=a.stride(0), + a_stride_K=a.stride(1), + b_stride_K=b.stride(0), + b_stride_N=b.stride(1), + ell_node_stride=ell_cols_vals.stride(0), + ell_cv_stride=ell_cols_vals.stride(1), + next_node_ptr=next_node, + valid_idxs_ptr=valid_idxs, + gumbel_block_max_ptr=gumbel_block_max, + block_sample_ptr=block_sample_buf, + next_node_stride_B=next_node.stride(0), + next_node_stride_N=next_node.stride(1), + valid_idxs_stride_B=valid_idxs.stride(0), + valid_idxs_stride_N=valid_idxs.stride(1), + max_br_blocks=max_br_blocks, + rng_seed=rng_seed, + B=B, + K=K, + max_branches=max_branches, + HAS_BIAS=has_bias, + ) + + argmax = torch.max(gumbel_block_max, dim=1).indices # [B] + sample = block_sample_buf.gather(1, argmax.unsqueeze(1)).squeeze(1) + return next_node, valid_idxs, sample + + +@triton.autotune( + configs=_ELL_AUTOTUNE_CONFIGS, + key=["B", "K", "max_branches"], + restore_value=["next_node_ptr", "valid_idxs_ptr", "gumbel_block_max_ptr", "block_sample_ptr"], +) +@triton.jit +def _ell_fused_sampling_kernel( + # Inputs + a_ptr, + b_ptr, + bias_ptr, + cur_node_ptr, + ell_cols_vals_ptr, + ell_n_children_ptr, + temperature_ptr, + a_stride_B, + a_stride_K, + b_stride_K, + b_stride_N, + ell_node_stride, + ell_cv_stride, + # Outputs + next_node_ptr, + valid_idxs_ptr, + gumbel_block_max_ptr, + block_sample_ptr, + next_node_stride_B, + next_node_stride_N, + valid_idxs_stride_B, + valid_idxs_stride_N, + max_br_blocks, + rng_seed, + # Constants + B: tl.constexpr, + K: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_BRANCHES: tl.constexpr, + max_branches: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + pid_BR, offs_B, offs_BR, b_mask, branch_cols, branch_valid, logits, any_valid = _ell_fused_prologue( + cur_node_ptr, ell_n_children_ptr, + a_ptr, b_ptr, bias_ptr, ell_cols_vals_ptr, + next_node_ptr, valid_idxs_ptr, + a_stride_B, a_stride_K, b_stride_K, b_stride_N, + ell_node_stride, ell_cv_stride, + next_node_stride_B, next_node_stride_N, + valid_idxs_stride_B, valid_idxs_stride_N, + max_branches, B, K, BLOCK_B, BLOCK_K, BLOCK_BRANCHES, HAS_BIAS, + ) + if any_valid == 0: + return + + temperature = tl.load(temperature_ptr) + u = tl.rand(seed=rng_seed, offset=offs_B[:, None] * max_branches + offs_BR[None, :]) + gumbel = -tl.log(-tl.log(u + 1e-10) + 1e-10) + g_vals = tl.where(branch_valid, logits / temperature + gumbel, float("-inf")) + block_max_gumbel = tl.max(g_vals, axis=1) # [BLOCK_B] + + winner_idx = tl.argmax(g_vals, axis=1) # [BLOCK_B] + br_sel = tl.arange(0, BLOCK_BRANCHES)[None, :] == winner_idx[:, None] + block_sample = tl.sum( + tl.where(br_sel & branch_valid, branch_cols.to(tl.float32), 0.0), axis=1 + ) + + tl.store( + gumbel_block_max_ptr + offs_B * max_br_blocks + pid_BR, + block_max_gumbel, + mask=b_mask, + ) + tl.store( + block_sample_ptr + offs_B * max_br_blocks + pid_BR, + block_sample, + mask=b_mask & (block_max_gumbel > float("-inf")), + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Fused sparse linear + top-K (ELL) +# ───────────────────────────────────────────────────────────────────────────── + + +@triton_op("vtnk_ell::_fused_linear_constrained_node_transition_topk_op", mutates_args={}) +def _ell_fused_linear_constrained_node_transition_topk_op( + a: torch.Tensor, + b: torch.Tensor, + bias_val: torch.Tensor, + cur_node: torch.Tensor, + ell_cols_vals: torch.Tensor, + ell_n_children: torch.Tensor, + max_branches: int, + has_bias: bool, + k: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, K = a.shape + + assert cur_node.shape == (B,), f"Expected cur_node shape ({B},), got {cur_node.shape}" + + a = a.contiguous() + cur_node = cur_node.contiguous() + ell_cols_vals = ell_cols_vals.contiguous() + ell_n_children = ell_n_children.contiguous() + bias_val = bias_val.contiguous() + + next_node = cur_node.new_full((B, max_branches), -1) + valid_idxs = cur_node.new_full((B, max_branches), -1) + branch_logits = torch.full( + (B, max_branches), float("-inf"), dtype=torch.float32, device=a.device + ) + + grid = lambda meta: ( + triton.cdiv(B, meta["BLOCK_B"]), + triton.cdiv(max_branches, meta["BLOCK_BRANCHES"]), + ) + wrap_triton(_ell_fused_compact_kernel)[grid]( + a_ptr=a, + b_ptr=b, + bias_ptr=bias_val, + cur_node_ptr=cur_node, + ell_cols_vals_ptr=ell_cols_vals, + ell_n_children_ptr=ell_n_children, + a_stride_B=a.stride(0), + a_stride_K=a.stride(1), + b_stride_K=b.stride(0), + b_stride_N=b.stride(1), + ell_node_stride=ell_cols_vals.stride(0), + ell_cv_stride=ell_cols_vals.stride(1), + next_node_ptr=next_node, + valid_idxs_ptr=valid_idxs, + branch_logits_ptr=branch_logits, + next_node_stride_B=next_node.stride(0), + next_node_stride_N=next_node.stride(1), + valid_idxs_stride_B=valid_idxs.stride(0), + valid_idxs_stride_N=valid_idxs.stride(1), + B=B, + K=K, + max_branches=max_branches, + HAS_BIAS=has_bias, + ) + + if k >= max_branches: + return next_node, valid_idxs, branch_logits, valid_idxs.clone() + topk_logits, topk_branch_idxs = torch.topk(branch_logits, k, dim=-1) + topk_idxs = valid_idxs.gather(1, topk_branch_idxs) + return next_node, valid_idxs, topk_logits, topk_idxs + + +@triton.autotune( + configs=_ELL_AUTOTUNE_CONFIGS, + key=["B", "K", "max_branches"], + restore_value=["next_node_ptr", "valid_idxs_ptr", "branch_logits_ptr"], +) +@triton.jit +def _ell_fused_compact_kernel( + # Inputs + a_ptr, + b_ptr, + bias_ptr, + cur_node_ptr, + ell_cols_vals_ptr, + ell_n_children_ptr, + a_stride_B, + a_stride_K, + b_stride_K, + b_stride_N, + ell_node_stride, + ell_cv_stride, + # Outputs + next_node_ptr, + valid_idxs_ptr, + branch_logits_ptr, + next_node_stride_B, + next_node_stride_N, + valid_idxs_stride_B, + valid_idxs_stride_N, + # Constants + B: tl.constexpr, + K: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_BRANCHES: tl.constexpr, + max_branches: tl.constexpr, + HAS_BIAS: tl.constexpr, +): + _, offs_B, offs_BR, b_mask, _, branch_valid, logits, any_valid = _ell_fused_prologue( + cur_node_ptr, ell_n_children_ptr, + a_ptr, b_ptr, bias_ptr, ell_cols_vals_ptr, + next_node_ptr, valid_idxs_ptr, + a_stride_B, a_stride_K, b_stride_K, b_stride_N, + ell_node_stride, ell_cv_stride, + next_node_stride_B, next_node_stride_N, + valid_idxs_stride_B, valid_idxs_stride_N, + max_branches, B, K, BLOCK_B, BLOCK_K, BLOCK_BRANCHES, HAS_BIAS, + ) + if any_valid == 0: + return + + store_mask = b_mask[:, None] & (offs_BR[None, :] < max_branches) + tl.store( + branch_logits_ptr + offs_B[:, None] * max_branches + offs_BR[None, :], + tl.where(branch_valid, logits, float("-inf")), + mask=store_mask, + ) diff --git a/rectokens/modules/sparse_linear.py b/rectokens/modules/sparse_linear.py index 62c0f8a..6be34ca 100644 --- a/rectokens/modules/sparse_linear.py +++ b/rectokens/modules/sparse_linear.py @@ -53,7 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self._ctx is None: return self.base_linear(x) - w, bias = self.base_linear.weight, self.base_linear.bias + w, bias = self.base_linear.weight.T, self.base_linear.bias if self._strategy == "sample": next_nodes, valid_idxs, sample = ( @@ -72,7 +72,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: sample, ) # Return zeros — decode_one_step uses self.sample directly. - return x.new_zeros(x.shape[0], w.shape[0]) + return x.new_zeros(x.shape[0], w.shape[1]) if self._strategy == "topk": next_nodes, valid_idxs, topk_logits, topk_idxs = ( diff --git a/rectokens/schemas/compact_ell_trie.py b/rectokens/schemas/compact_ell_trie.py new file mode 100644 index 0000000..ad80219 --- /dev/null +++ b/rectokens/schemas/compact_ell_trie.py @@ -0,0 +1,202 @@ +from __future__ import annotations + +from typing import NamedTuple +from typing import TYPE_CHECKING + +import torch +from torch import Tensor + +if TYPE_CHECKING: + from rectokens.schemas.compact_csr_trie import CompactCSRTrie + + +class CompactELLTrie(NamedTuple): + """A trie encoded in ELLPACK (ELL) format for GPU-accelerated constrained decoding. + + Compared to CompactCSRTrie, ELL eliminates the dependent row-pointer load by storing + adjacency data in a padded (num_nodes, 2, max_branches) matrix. Branch k of node i + sits at ell_cols_vals[i, :, k], so its address is pure arithmetic from cur_node — no + prior row_ptrs load is required. This shortens the dependent load chain from depth 3 + to depth 2 in Triton kernels. + + Layout (num_nodes, 2, max_branches): + ell_cols_vals[node, 0, branch] = col (token index), -1 = padding + ell_cols_vals[node, 1, branch] = val (child node id), -1 = padding + + Cols and vals for the same node are max_branches apart in memory (stride(1) = max_branches), + keeping both in the same cache region rather than being separated by node_count*max_branches. + + max_branches = max(layer_max_branches). The dense-phase fields are identical to CompactCSRTrie. + """ + + ell_cols_vals: Tensor # (num_nodes, 2, max_branches): [cols, vals], -1 = padding + n_children: Tensor # (num_nodes,): child count per node — used as a block-level gate + layer_max_branches: list[int] + dense_mask_by_layer: list[Tensor] + dense_states: Tensor + vocab_size: int + + @classmethod + def from_csr(cls, csr: CompactCSRTrie) -> CompactELLTrie: + row_ptrs = csr.row_ptrs + device = row_ptrs.device + num_nodes = len(row_ptrs) + nnz = csr.stacked_cols_vals.shape[1] - 1 # exclude sentinel + max_branches = max(csr.layer_max_branches) if csr.layer_max_branches else 0 + + row_ptrs_ext = torch.cat( + [row_ptrs, torch.tensor([nnz], dtype=row_ptrs.dtype, device=device)] + ) + n_children = row_ptrs_ext.diff() # (num_nodes,) + + ell = torch.full( + (num_nodes, 2, max_branches), -1, dtype=torch.long, device=device + ) + + if nnz > 0: + edge_node_id = torch.repeat_interleave( + torch.arange(num_nodes, device=device), n_children + ) # (nnz,) + edge_local_idx = ( + torch.arange(nnz, device=device) - row_ptrs[edge_node_id] + ) # (nnz,) + + ell[edge_node_id, 0, edge_local_idx] = csr.stacked_cols_vals[0, :nnz] + ell[edge_node_id, 1, edge_local_idx] = csr.stacked_cols_vals[1, :nnz] + + return cls( + ell_cols_vals=ell, + n_children=n_children, + layer_max_branches=csr.layer_max_branches, + dense_mask_by_layer=csr.dense_mask_by_layer, + dense_states=csr.dense_states, + vocab_size=csr.vocab_size, + ) + + def update(self, new_seqs: list[list[int]]) -> "CompactELLTrie": + """Add new sequences incrementally without full reconstruction. + + Traverses the existing trie for each new sequence, inserting only the + nodes and edges that do not already exist. Cost is O(M * L * B) where + M = len(new_seqs), L = sequence length, B = avg children per node — + versus O(N * L) for a CSR full rebuild over N total sequences. + + Note: layer_max_branches, dense_mask_by_layer, and dense_states are + preserved as-is and will be stale for nodes added by this call. + """ + if not new_seqs: + return self + + ell_cv = self.ell_cols_vals.clone() # (num_nodes, 2, max_branches) + n_ch = self.n_children.clone() + max_branches = ell_cv.shape[2] + num_nodes = ell_cv.shape[0] + + # Pre-allocate capacity for worst-case new nodes (one per token in new_seqs) + capacity = sum(len(s) for s in new_seqs) + ell_cv = torch.cat([ + ell_cv, + torch.full((capacity, 2, max_branches), -1, dtype=torch.long), + ]) + n_ch = torch.cat([n_ch, torch.zeros(capacity, dtype=torch.long)]) + + for seq in new_seqs: + node = 0 + for token in seq: + nc = int(n_ch[node]) + child = -1 + if nc > 0: + hit = (ell_cv[node, 0, :nc] == token).nonzero(as_tuple=False) + if len(hit): + child = int(ell_cv[node, 1, hit[0, 0]]) + + if child == -1: + if nc == max_branches: + new_mb = max_branches * 2 + expanded = torch.full( + (ell_cv.shape[0], 2, new_mb), -1, + dtype=torch.long, device=ell_cv.device, + ) + expanded[:, :, :max_branches] = ell_cv + ell_cv = expanded + max_branches = new_mb + child = num_nodes + num_nodes += 1 + ell_cv[node, 0, nc] = token + ell_cv[node, 1, nc] = child + n_ch[node] += 1 + + node = child + + return CompactELLTrie( + ell_cols_vals=ell_cv[:num_nodes].contiguous(), + n_children=n_ch[:num_nodes].contiguous(), + layer_max_branches=self.layer_max_branches, + dense_mask_by_layer=self.dense_mask_by_layer, + dense_states=self.dense_states, + vocab_size=self.vocab_size, + ) + + +class MutableELLTrie: + """ELL trie with pre-allocated capacity for zero-copy incremental updates. + + Unlike CompactELLTrie.update(), which clones the full tensor on every call, + MutableELLTrie pre-allocates a fixed capacity buffer and writes new nodes + directly into it. Each update() call is purely O(M * L * B) — proportional + to the new sequences, not the total trie size. + + Use this when you need repeated incremental updates without rebuilding. + """ + + def __init__( + self, + ell: CompactELLTrie, + extra_capacity: int, + ) -> None: + num_nodes = ell.ell_cols_vals.shape[0] + max_branches = ell.ell_cols_vals.shape[2] + device = ell.ell_cols_vals.device + + # Allocate buffer large enough for existing nodes + future inserts. + total = num_nodes + extra_capacity + self.ell_cv = torch.full((total, 2, max_branches), -1, dtype=torch.long, device=device) + self.ell_cv[:num_nodes] = ell.ell_cols_vals + + self.n_ch = torch.zeros(total, dtype=torch.long, device=device) + self.n_ch[:num_nodes] = ell.n_children + + self.num_nodes = num_nodes + self.max_branches = max_branches + self.vocab_size = ell.vocab_size + + def update(self, new_seqs: list[list[int]]) -> None: + """Insert new sequences in-place. No allocation is performed unless + a node needs more children than max_branches (rare; triggers realloc).""" + for seq in new_seqs: + node = 0 + for token in seq: + nc = int(self.n_ch[node]) + child = -1 + if nc > 0: + hit = (self.ell_cv[node, 0, :nc] == token).nonzero(as_tuple=False) + if len(hit): + child = int(self.ell_cv[node, 1, hit[0, 0]]) + + if child == -1: + if nc == self.max_branches: + new_mb = self.max_branches * 2 + expanded = torch.full( + (self.ell_cv.shape[0], 2, new_mb), -1, + dtype=torch.long, device=self.ell_cv.device, + ) + expanded[:, :, :self.max_branches] = self.ell_cv + self.ell_cv = expanded + self.max_branches = new_mb + child = self.num_nodes + self.num_nodes += 1 + self.ell_cv[node, 0, nc] = token + self.ell_cv[node, 1, nc] = child + self.n_ch[node] += 1 + + node = child diff --git a/tests/test_kernel_ell.py b/tests/test_kernel_ell.py new file mode 100644 index 0000000..6fa83fd --- /dev/null +++ b/tests/test_kernel_ell.py @@ -0,0 +1,509 @@ +"""Tests that the ELL-based fused kernels produce identical outputs to the CSR-based kernels.""" +from __future__ import annotations + +import unittest + +import torch +import torch.nn.functional as F + +if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA required") + +from rectokens.schemas.compact_csr_trie import CompactCSRTrie +from rectokens.schemas.compact_ell_trie import CompactELLTrie +from rectokens.kernels.constrained_node_transition import ( + _fused_linear_constrained_node_transition_sampling_op as csr_sampling_op, + _fused_linear_constrained_node_transition_topk_op as csr_topk_op, +) +from rectokens.kernels.constrained_node_transition_ell import ( + _ell_fused_linear_constrained_node_transition_sampling_op as ell_sampling_op, + _ell_fused_linear_constrained_node_transition_topk_op as ell_topk_op, +) +from rectokens.decoding.vntk import ( + sparse_linear_pytorch, + sparse_linear_compact_pytorch, + sparse_linear_ell_pytorch, + sparse_linear_compact_ell_pytorch, +) + + +DEVICE = torch.device("cuda") +VOCAB_SIZE = 8 + + +def lex_sort(rows: list[list[int]]) -> torch.Tensor: + return torch.tensor(sorted(rows), dtype=torch.long) + + +def to_device(csr: CompactCSRTrie) -> CompactCSRTrie: + return csr._replace( + row_ptrs=csr.row_ptrs.to(DEVICE), + stacked_cols_vals=csr.stacked_cols_vals.to(DEVICE), + dense_mask_by_layer=[v.to(DEVICE) for v in csr.dense_mask_by_layer], + dense_states=csr.dense_states.to(DEVICE), + ) + + +def ell_to_device(ell: CompactELLTrie) -> CompactELLTrie: + return ell._replace( + ell_cols_vals=ell.ell_cols_vals.to(DEVICE), + n_children=ell.n_children.to(DEVICE), + dense_mask_by_layer=[v.to(DEVICE) for v in ell.dense_mask_by_layer], + dense_states=ell.dense_states.to(DEVICE), + ) + + +def make_tries(seqs: list[list[int]], vocab_size: int) -> tuple[CompactCSRTrie, CompactELLTrie]: + csr = to_device(CompactCSRTrie.from_sorted_batch(lex_sort(seqs), vocab_size=vocab_size)) + ell = ell_to_device(CompactELLTrie.from_csr(csr)) + return csr, ell + + +class TestELLvCSRSampling(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + seqs_small = [[1, 2, 1], [3, 1, 2], [3, 1, 3]] + cls.csr_small, cls.ell_small = make_tries(seqs_small, VOCAB_SIZE) + + seqs_dense = [[i, j, k] for i in range(4) for j in range(4) for k in range(4)] + cls.csr_dense, cls.ell_dense = make_tries(seqs_dense, vocab_size=16) + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + + def _run_both( + self, + B: int, + K: int, + step: int, + cur_node_vals: list[int], + seed: int, + csr: CompactCSRTrie, + ell: CompactELLTrie, + ): + torch.manual_seed(0) + a = torch.randn(B, K, device=DEVICE) + b = torch.randn(K, csr.vocab_size, device=DEVICE) + cur_node = torch.tensor(cur_node_vals, device=DEVICE) + bias_val = a.new_empty(0) + max_branches = csr.layer_max_branches[step] + + csr_out = csr_sampling_op( + a, b, bias_val, cur_node, + csr.row_ptrs, csr.stacked_cols_vals, + max_branches, False, rng_seed=seed, + ) + ell_out = ell_sampling_op( + a, b, bias_val, cur_node, + ell.ell_cols_vals, ell.n_children, + max_branches, False, rng_seed=seed, + ) + return csr_out, ell_out + + def _assert_sampling_match( + self, B, K, step, cur_node_vals, seed, csr=None, ell=None + ): + if csr is None: + csr, ell = self.csr_small, self.ell_small + (csr_nn, csr_vi, csr_s), (ell_nn, ell_vi, ell_s) = self._run_both( + B, K, step, cur_node_vals, seed, csr, ell + ) + self.assertTrue(torch.equal(ell_nn, csr_nn), "next_node mismatch") + self.assertTrue(torch.equal(ell_vi, csr_vi), "valid_idxs mismatch") + self.assertTrue(torch.equal(ell_s, csr_s), "sample mismatch") + + # ------------------------------------------------------------------------- + # next_node / valid_idxs / sample agree with CSR + # ------------------------------------------------------------------------- + + def test_b1_step0(self) -> None: + self._assert_sampling_match(1, 16, 0, [0], seed=42) + + def test_b2_step0_same_node(self) -> None: + self._assert_sampling_match(2, 16, 0, [0, 0], seed=7) + + def test_b2_step1_diff_nodes(self) -> None: + self._assert_sampling_match(2, 16, 1, [1, 2], seed=99) + + def test_b3_step2(self) -> None: + self._assert_sampling_match(3, 16, 2, [3, 4, 3], seed=123) + + def test_b8_large_k(self) -> None: + self._assert_sampling_match(8, 128, 0, [0] * 8, seed=1) + + def test_b32_step0(self) -> None: + self._assert_sampling_match(32, 64, 0, [0] * 32, seed=5) + + def test_dense_trie_step0(self) -> None: + self._assert_sampling_match( + 8, 16, 0, [0] * 8, seed=42, csr=self.csr_dense, ell=self.ell_dense + ) + + def test_dense_trie_step1(self) -> None: + self._assert_sampling_match( + 4, 16, 1, [1, 2, 3, 4], seed=7, csr=self.csr_dense, ell=self.ell_dense + ) + + # ------------------------------------------------------------------------- + # Determinism: same seed → same result (ELL-internal consistency) + # ------------------------------------------------------------------------- + + def test_deterministic_b1(self) -> None: + _, (_, _, s1) = self._run_both(1, 16, 0, [0], 42, self.csr_small, self.ell_small) + _, (_, _, s2) = self._run_both(1, 16, 0, [0], 42, self.csr_small, self.ell_small) + self.assertTrue(torch.equal(s1, s2)) + + def test_deterministic_b4(self) -> None: + (_, _, s1), _ = self._run_both(4, 16, 0, [0, 0, 0, 0], 77, self.csr_small, self.ell_small) + (_, _, s2), _ = self._run_both(4, 16, 0, [0, 0, 0, 0], 77, self.csr_small, self.ell_small) + self.assertTrue(torch.equal(s1, s2)) + + # ------------------------------------------------------------------------- + # Sample must be a valid child of the current node + # ------------------------------------------------------------------------- + + def test_sample_is_valid_child_step0(self) -> None: + _, (_, vi, sample) = self._run_both( + 1, 16, 0, [0], 123, self.csr_small, self.ell_small + ) + valid = vi[0][vi[0] >= 0].tolist() + self.assertIn(int(sample[0].item()), valid) + + def test_sample_is_valid_child_step1(self) -> None: + _, (_, vi, sample) = self._run_both( + 2, 16, 1, [1, 2], 456, self.csr_small, self.ell_small + ) + for b in range(2): + valid = vi[b][vi[b] >= 0].tolist() + self.assertIn(int(sample[b].item()), valid) + + def test_sample_is_valid_child_step2(self) -> None: + _, (_, vi, sample) = self._run_both( + 3, 16, 2, [3, 4, 3], 789, self.csr_small, self.ell_small + ) + for b in range(3): + valid = vi[b][vi[b] >= 0].tolist() + self.assertIn(int(sample[b].item()), valid) + + # ------------------------------------------------------------------------- + # Bias path + # ------------------------------------------------------------------------- + + def test_with_bias(self) -> None: + B, K, step = 2, 16, 0 + cur_node_vals = [0, 0] + seed = 42 + torch.manual_seed(0) + a = torch.randn(B, K, device=DEVICE) + b = torch.randn(K, VOCAB_SIZE, device=DEVICE) + bias = torch.randn(VOCAB_SIZE, device=DEVICE) + cur_node = torch.tensor(cur_node_vals, device=DEVICE) + max_branches = self.csr_small.layer_max_branches[step] + + csr_nn, csr_vi, csr_s = csr_sampling_op( + a, b, bias, cur_node, + self.csr_small.row_ptrs, self.csr_small.stacked_cols_vals, + max_branches, True, rng_seed=seed, + ) + ell_nn, ell_vi, ell_s = ell_sampling_op( + a, b, bias, cur_node, + self.ell_small.ell_cols_vals, self.ell_small.n_children, + max_branches, True, rng_seed=seed, + ) + self.assertTrue(torch.equal(ell_nn, csr_nn)) + self.assertTrue(torch.equal(ell_vi, csr_vi)) + self.assertTrue(torch.equal(ell_s, csr_s)) + + +class TestELLvCSRTopK(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + seqs_small = [[1, 2, 1], [3, 1, 2], [3, 1, 3]] + cls.csr_small, cls.ell_small = make_tries(seqs_small, VOCAB_SIZE) + + seqs_dense = [[i, j, k] for i in range(4) for j in range(4) for k in range(4)] + cls.csr_dense, cls.ell_dense = make_tries(seqs_dense, vocab_size=16) + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + + def _run_both( + self, + B: int, + K: int, + step: int, + cur_node_vals: list[int], + k: int, + csr: CompactCSRTrie, + ell: CompactELLTrie, + ): + torch.manual_seed(42) + a = torch.randn(B, K, device=DEVICE) + b = torch.randn(K, csr.vocab_size, device=DEVICE) + cur_node = torch.tensor(cur_node_vals, device=DEVICE) + bias_val = a.new_empty(0) + max_branches = csr.layer_max_branches[step] + + csr_out = csr_topk_op( + a, b, bias_val, cur_node, + csr.row_ptrs, csr.stacked_cols_vals, + max_branches, False, k, + ) + ell_out = ell_topk_op( + a, b, bias_val, cur_node, + ell.ell_cols_vals, ell.n_children, + max_branches, False, k, + ) + return csr_out, ell_out + + def _assert_topk_match(self, B, K, step, cur_node_vals, k, csr=None, ell=None): + if csr is None: + csr, ell = self.csr_small, self.ell_small + (csr_nn, csr_vi, csr_tl, csr_ti), (ell_nn, ell_vi, ell_tl, ell_ti) = self._run_both( + B, K, step, cur_node_vals, k, csr, ell + ) + + self.assertTrue(torch.equal(ell_nn, csr_nn), "next_node mismatch") + self.assertTrue(torch.equal(ell_vi, csr_vi), "valid_idxs mismatch") + + # Sort along k-dim to handle tie-breaking, then compare. + self.assertTrue( + torch.allclose( + ell_tl.float().sort(dim=-1).values, + csr_tl.float().sort(dim=-1).values, + atol=1e-3, + equal_nan=True, + ), + f"topk logits mismatch\nELL: {ell_tl}\nCSR: {csr_tl}", + ) + self.assertTrue( + torch.equal(ell_ti.sort(dim=-1).values, csr_ti.sort(dim=-1).values), + f"topk idxs mismatch\nELL: {ell_ti}\nCSR: {csr_ti}", + ) + + # ------------------------------------------------------------------------- + # next_node / valid_idxs / topk_logits / topk_idxs agree with CSR + # ------------------------------------------------------------------------- + + def test_k1_b1_step0(self) -> None: + self._assert_topk_match(1, 16, 0, [0], k=1) + + def test_k1_b2_step1(self) -> None: + self._assert_topk_match(2, 16, 1, [1, 2], k=1) + + def test_k2_b1_step0(self) -> None: + self._assert_topk_match(1, 16, 0, [0], k=2) + + def test_k2_b2_step1(self) -> None: + self._assert_topk_match(2, 16, 1, [1, 2], k=2) + + def test_k1_b3_step2(self) -> None: + self._assert_topk_match(3, 16, 2, [3, 4, 3], k=1) + + def test_k2_b3_step2(self) -> None: + self._assert_topk_match(3, 16, 2, [3, 4, 3], k=2) + + def test_k2_b8_large_k(self) -> None: + self._assert_topk_match(8, 128, 0, [0] * 8, k=2) + + def test_k_exceeds_branches(self) -> None: + # When k >= max_branches the kernel returns all branches; both CSR and ELL do the same. + self._assert_topk_match(2, 16, 1, [1, 2], k=10) + + def test_dense_trie_step0(self) -> None: + self._assert_topk_match( + 8, 16, 0, [0] * 8, k=2, csr=self.csr_dense, ell=self.ell_dense + ) + + def test_dense_trie_step1_k3(self) -> None: + self._assert_topk_match( + 4, 16, 1, [1, 2, 3, 4], k=3, csr=self.csr_dense, ell=self.ell_dense + ) + + def test_b32_step0(self) -> None: + self._assert_topk_match(32, 64, 0, [0] * 32, k=2) + + # ------------------------------------------------------------------------- + # Bias path + # ------------------------------------------------------------------------- + + def test_with_bias(self) -> None: + B, K, step, k = 2, 16, 0, 1 + cur_node_vals = [0, 0] + torch.manual_seed(0) + a = torch.randn(B, K, device=DEVICE) + b = torch.randn(K, VOCAB_SIZE, device=DEVICE) + bias = torch.randn(VOCAB_SIZE, device=DEVICE) + cur_node = torch.tensor(cur_node_vals, device=DEVICE) + max_branches = self.csr_small.layer_max_branches[step] + + csr_nn, csr_vi, csr_tl, csr_ti = csr_topk_op( + a, b, bias, cur_node, + self.csr_small.row_ptrs, self.csr_small.stacked_cols_vals, + max_branches, True, k, + ) + ell_nn, ell_vi, ell_tl, ell_ti = ell_topk_op( + a, b, bias, cur_node, + self.ell_small.ell_cols_vals, self.ell_small.n_children, + max_branches, True, k, + ) + self.assertTrue(torch.equal(ell_nn, csr_nn)) + self.assertTrue(torch.equal(ell_vi, csr_vi)) + self.assertTrue( + torch.allclose(ell_tl.float(), csr_tl.float(), atol=1e-3, equal_nan=True) + ) + self.assertTrue(torch.equal(ell_ti, csr_ti)) + + +class TestELLPytorchvCSRPytorch(unittest.TestCase): + """ELL sparse_linear_pytorch outputs match CSR sparse_linear_pytorch outputs.""" + + @classmethod + def setUpClass(cls) -> None: + seqs_small = [[1, 2, 1], [3, 1, 2], [3, 1, 3]] + cls.csr_small, cls.ell_small = make_tries(seqs_small, VOCAB_SIZE) + + seqs_dense = [[i, j, k] for i in range(4) for j in range(4) for k in range(4)] + cls.csr_dense, cls.ell_dense = make_tries(seqs_dense, vocab_size=16) + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + + def _run_both(self, B, K, step, cur_node_vals, csr, ell): + torch.manual_seed(0) + a = torch.randn(B, K, device=DEVICE) + weight = torch.randn(csr.vocab_size, K, device=DEVICE) + cur_node = torch.tensor(cur_node_vals, device=DEVICE) + + csr_out = sparse_linear_pytorch(a, weight, cur_node, csr, step) + ell_out = sparse_linear_ell_pytorch(a, weight, cur_node, ell, step) + return csr_out, ell_out + + def _run_both_compact(self, B, K, step, cur_node_vals, csr, ell): + torch.manual_seed(0) + a = torch.randn(B, K, device=DEVICE) + weight = torch.randn(csr.vocab_size, K, device=DEVICE) + cur_node = torch.tensor(cur_node_vals, device=DEVICE) + + csr_out = sparse_linear_compact_pytorch(a, weight, cur_node, csr, step) + ell_out = sparse_linear_compact_ell_pytorch(a, weight, cur_node, ell, step) + return csr_out, ell_out + + def _assert_full_match(self, B, K, step, cur_node_vals, csr=None, ell=None): + if csr is None: + csr, ell = self.csr_small, self.ell_small + (csr_nn, csr_vi, csr_cl), (ell_nn, ell_vi, ell_cl) = self._run_both( + B, K, step, cur_node_vals, csr, ell + ) + self.assertTrue(torch.equal(ell_nn, csr_nn), "next_node mismatch") + self.assertTrue(torch.equal(ell_vi, csr_vi), "valid_idxs mismatch") + self.assertTrue( + torch.allclose(ell_cl.float(), csr_cl.float(), equal_nan=True), + "corrected_logits mismatch", + ) + + def _assert_compact_match(self, B, K, step, cur_node_vals, csr=None, ell=None): + if csr is None: + csr, ell = self.csr_small, self.ell_small + (csr_nn, csr_vi, csr_bl), (ell_nn, ell_vi, ell_bl) = self._run_both_compact( + B, K, step, cur_node_vals, csr, ell + ) + self.assertTrue(torch.equal(ell_nn, csr_nn), "next_node mismatch") + self.assertTrue(torch.equal(ell_vi, csr_vi), "valid_idxs mismatch") + self.assertTrue( + torch.allclose(ell_bl.float(), csr_bl.float(), equal_nan=True), + "branch_logits mismatch", + ) + + # ------------------------------------------------------------------------- + # sparse_linear_ell_pytorch matches sparse_linear_pytorch (full scatter) + # ------------------------------------------------------------------------- + + def test_full_b1_step0(self) -> None: + self._assert_full_match(1, 16, 0, [0]) + + def test_full_b2_step0_same_node(self) -> None: + self._assert_full_match(2, 16, 0, [0, 0]) + + def test_full_b2_step1_diff_nodes(self) -> None: + self._assert_full_match(2, 16, 1, [1, 2]) + + def test_full_b3_step2(self) -> None: + self._assert_full_match(3, 16, 2, [3, 4, 3]) + + def test_full_b8_large_k(self) -> None: + self._assert_full_match(8, 128, 0, [0] * 8) + + def test_full_dense_trie_step0(self) -> None: + self._assert_full_match( + 8, 16, 0, [0] * 8, csr=self.csr_dense, ell=self.ell_dense + ) + + def test_full_dense_trie_step1(self) -> None: + self._assert_full_match( + 4, 16, 1, [1, 2, 3, 4], csr=self.csr_dense, ell=self.ell_dense + ) + + # ------------------------------------------------------------------------- + # sparse_linear_compact_ell_pytorch matches sparse_linear_compact_pytorch + # ------------------------------------------------------------------------- + + def test_compact_b1_step0(self) -> None: + self._assert_compact_match(1, 16, 0, [0]) + + def test_compact_b2_step1_diff_nodes(self) -> None: + self._assert_compact_match(2, 16, 1, [1, 2]) + + def test_compact_b3_step2(self) -> None: + self._assert_compact_match(3, 16, 2, [3, 4, 3]) + + def test_compact_b8_large_k(self) -> None: + self._assert_compact_match(8, 128, 0, [0] * 8) + + def test_compact_dense_trie_step1(self) -> None: + self._assert_compact_match( + 4, 16, 1, [1, 2, 3, 4], csr=self.csr_dense, ell=self.ell_dense + ) + + # ------------------------------------------------------------------------- + # sample built on ELL full output matches sample built on CSR full output + # ------------------------------------------------------------------------- + + def test_sample_valid_child_step0(self) -> None: + torch.manual_seed(7) + a = torch.randn(2, 16, device=DEVICE) + weight = torch.randn(VOCAB_SIZE, 16, device=DEVICE) + cur_node = torch.tensor([0, 0], device=DEVICE) + _, vi, corrected_logits = sparse_linear_ell_pytorch( + a, weight, cur_node, self.ell_small, step=0 + ) + probs = F.softmax(corrected_logits, dim=-1) + sample = torch.multinomial(probs, num_samples=1).squeeze(-1) + for b in range(2): + valid = vi[b][vi[b] >= 0].tolist() + self.assertIn(int(sample[b].item()), valid) + + # ------------------------------------------------------------------------- + # top-k built on ELL compact output matches top-k built on CSR compact output + # ------------------------------------------------------------------------- + + def test_topk_compact_b2_k1_step1(self) -> None: + torch.manual_seed(42) + B, K, step, k = 2, 16, 1, 1 + a = torch.randn(B, K, device=DEVICE) + weight = torch.randn(VOCAB_SIZE, K, device=DEVICE) + cur_node = torch.tensor([1, 2], device=DEVICE) + + _, csr_vi, csr_bl = sparse_linear_compact_pytorch(a, weight, cur_node, self.csr_small, step) + csr_topk_l, csr_topk_bi = torch.topk(csr_bl, k, dim=-1) + csr_topk_i = csr_vi.gather(1, csr_topk_bi) + + _, ell_vi, ell_bl = sparse_linear_compact_ell_pytorch(a, weight, cur_node, self.ell_small, step) + ell_topk_l, ell_topk_bi = torch.topk(ell_bl, k, dim=-1) + ell_topk_i = ell_vi.gather(1, ell_topk_bi) + + self.assertTrue(torch.allclose(ell_topk_l, csr_topk_l, equal_nan=True)) + self.assertTrue(torch.equal(ell_topk_i, csr_topk_i))