From 4f09aaf31b51a1e11871c66d78f368400e02eedd Mon Sep 17 00:00:00 2001 From: zhangyangjie Date: Mon, 27 Apr 2026 10:57:18 +0800 Subject: [PATCH 1/2] add gdn and kda test --- README_linear_attention.md | 130 ++++ benchmark_linear_attention_registry.py | 177 +++++ benchmark_linear_attention_run.py | 290 ++++++++ linear_attention_conftest.py | 184 +++++ test_gated_delta.py | 974 +++++++++++++++++++++++++ test_kda.py | 915 +++++++++++++++++++++++ 6 files changed, 2670 insertions(+) create mode 100644 README_linear_attention.md create mode 100644 benchmark_linear_attention_registry.py create mode 100644 benchmark_linear_attention_run.py create mode 100644 linear_attention_conftest.py create mode 100644 test_gated_delta.py create mode 100644 test_kda.py diff --git a/README_linear_attention.md b/README_linear_attention.md new file mode 100644 index 0000000..9f010ea --- /dev/null +++ b/README_linear_attention.md @@ -0,0 +1,130 @@ +# linear_attn + +Triton-based GDN (Gated Delta Networks) and KDA (Kimi Delta Attention) operators with chunk-wise and fused-recurrent execution modes. + +## Dependencies + +- PaddlePaddle-GPU (GPU required) +- triton +- pytest (for tests) +- einops (for GDN tests) + +## Environment Setup + +```bash +# Install the flash_mask package (includes linear_attn) +cd /path/to/flash-attention/flashmask +pip install -e . --no-build-isolation +``` + +## Running Tests + +Test files are located at `test_flashmask/`: + +| File | Description | +|------|-------------| +| `test_gated_delta.py` | GDN operator correctness tests | +| `test_kda.py` | KDA operator correctness tests | + +Each test compares the Triton-optimized implementation against a naive Python reference, checking both forward output and backward gradients. 
+ +```bash +cd /path/to/test_flashmask + +# Run all tests +pytest test_gated_delta.py test_kda.py -v + +# Run GDN tests only +pytest test_gated_delta.py -v + +# Run KDA tests only +pytest test_kda.py -v + +# Run a single test function +pytest test_gated_delta.py::test_fused_recurrent -v +pytest test_kda.py::test_chunk -v + +# Filter by parametrized id +pytest test_gated_delta.py -k "test_fused_recurrent and B1-T63" -v +``` + +### Optional Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `SKIP_TEST_CHUNK_VARLEN=1` | unset | Skip varlen (variable-length sequence) tests | +| `FLA_BENCHMARK=1` | `0` | Disable driver probing overhead | + +```bash +SKIP_TEST_CHUNK_VARLEN=1 pytest test_gated_delta.py test_kda.py -v +``` + +## Running Benchmarks + +The benchmark framework is located at `test_flashmask/` and supports 4 operators: + +| Operator | Description | Modes | +|----------|-------------|-------| +| `chunk_gdn` | GDN chunk-level | fwd / fwdbwd | +| `chunk_kda` | KDA chunk-level | fwd / fwdbwd | +| `recurrent_gdn` | GDN fused recurrent | fwd only | +| `recurrent_kda` | KDA fused recurrent | fwd only | + +```bash +cd /path/to/test_flashmask + +# List registered operators +python benchmark_linear_attention_run.py --list + +# Run all benchmarks +python benchmark_linear_attention_run.py --op all + +# Run specific operators +python benchmark_linear_attention_run.py --op chunk_gdn +python benchmark_linear_attention_run.py --op chunk_kda recurrent_kda + +# Forward only +python benchmark_linear_attention_run.py --op chunk_gdn --modes fwd + +# Custom shapes +python benchmark_linear_attention_run.py --op chunk_gdn \ + --custom-shapes '{"smoke":{"B":1,"T":64,"H":2,"D":32}}' + +# Save results as JSON +python benchmark_linear_attention_run.py --op all --json results.json +``` + +### Default Shape Configs + +| Config Name | B | T | H | D | +|-------------|---|---|---|---| +| B1_T8192_H96_D128 | 1 | 8192 | 96 | 128 | +| 
B2_T16384_H16_D128 | 2 | 16384 | 16 | 128 | +| B4_T2048_H16_D128 | 4 | 2048 | 16 | 128 | +| B4_T4096_H64_D128 | 4 | 4096 | 64 | 128 | +| B8_T2048_H32_D256 | 8 | 2048 | 32 | 256 | +| B8_T1024_H8_D64 | 8 | 1024 | 8 | 64 | + +### Benchmark Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `FLA_BENCH_OP_WARMUP_ITERS` | `5` | Number of warmup iterations | +| `FLA_BENCH_WARMUP_MS` | `100` | `do_bench` warmup time (ms) | +| `FLA_BENCH_REP_MS` | `500` | `do_bench` repeat measurement time (ms) | + +## Known Limitations + +- Context Parallel (CP) is NOT supported. +- `fused_recurrent_gdn` / `fused_recurrent_kda` are forward-only. Use `chunk_gdn` / `chunk_kda` for training workloads that require gradients. + +## Known Issue: GDN Backward Precision on Hopper GPUs with Triton >= 3.4.0 + +The upstream fla-org/flash-linear-attention project has identified a backward precision issue in the gated `chunk_bwd_dqkwg` kernel when running on Hopper-class GPUs (H20, H100, GB200, etc.) with Triton >= 3.4.0 ([upstream PR #827](https://github.com/fla-org/flash-linear-attention/pull/827)). The upstream fix introduces a TileLang-based kernel as an alternative backend. + +**Current status in this fork:** +- On NVIDIA H800 (Hopper) with the current Triton version, this issue has **not** been observed in practice. +- If you plan to deploy on other Hopper GPUs (H20, GB200, etc.), or upgrade Triton to >= 3.4.0, you may encounter this backward precision regression. +- The TileLang backend has **not** been integrated into this Paddle port yet. + +**Action needed:** When targeting Hopper GPUs other than H800 or upgrading Triton, consider integrating the TileLang backend from the upstream fix (`pip install tilelang` + dispatch logic in `fla/ops/common/chunk_o.py`). 
diff --git a/benchmark_linear_attention_registry.py b/benchmark_linear_attention_registry.py new file mode 100644 index 0000000..ca3b2c7 --- /dev/null +++ b/benchmark_linear_attention_registry.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import logging +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import paddle +import paddle.nn.functional as F + +logger = logging.getLogger(__name__) + + +def shape_BTHD(B, T, H, D, **kw): + return (B, T, H, D) + + +def shape_BTH(B, T, H, D, **kw): + return (B, T, H) + + +logsigmoid = F.log_sigmoid + + +def sigmoid_transform(t): + return t.sigmoid() + + +@dataclass +class TensorSpec: + shape_fn: Callable + requires_grad: bool = True + dtype: Any = 'default' + transform: Callable | None = None + + +@dataclass +class OpConfig: + name: str + import_path: str + inputs: dict[str, TensorSpec] + func_name: str | None = None + extra_kwargs: dict[str, Any] = field(default_factory=dict) + output_is_tuple: bool = True + skip_backward: bool = False + category: str = '' + + +_REGISTRY: dict[str, OpConfig] = {} + + +def register_op(config: OpConfig) -> None: + _REGISTRY[config.name] = config + + +SHAPE_CONFIGS = { + 'B1_T8192_H96_D128': {'B': 1, 'T': 8192, 'H': 96, 'D': 128}, + 'B2_T16384_H16_D128': {'B': 2, 'T': 16384, 'H': 16, 'D': 128}, + 'B4_T2048_H16_D128': {'B': 4, 'T': 2048, 'H': 16, 'D': 128}, + 'B4_T4096_H64_D128': {'B': 4, 'T': 4096, 'H': 64, 'D': 128}, + 'B8_T2048_H32_D256': {'B': 8, 'T': 2048, 'H': 32, 'D': 256}, + 'B8_T1024_H8_D64': {'B': 8, 'T': 1024, 'H': 8, 'D': 64}, +} + + +def get_op(name: str) -> OpConfig: + if name not in _REGISTRY: + raise KeyError(f"Op '{name}' not registered. 
Available: {sorted(_REGISTRY)}") + return _REGISTRY[name] + + +def list_ops() -> list[str]: + return sorted(_REGISTRY.keys()) + + +def _resolve_dtype(dtype): + if dtype == 'default': + return paddle.bfloat16 + if dtype == 'float32': + return paddle.float32 + if dtype == 'int64': + return paddle.int64 + return dtype + + +def _set_device(device: str | None): + if device is None: + return + current = paddle.get_device() + if current != device: + paddle.device.set_device(device) + + +def generate_inputs( + config: OpConfig, + B: int, + T: int, + H: int, + D: int, + dtype=paddle.bfloat16, + device: str | None = None, +) -> dict[str, paddle.Tensor]: + _set_device(device) + inputs: dict[str, paddle.Tensor] = {} + for param_name, spec in config.inputs.items(): + shape = spec.shape_fn(B, T, H, D) + tensor_dtype = dtype if spec.dtype == 'default' else _resolve_dtype(spec.dtype) + if tensor_dtype == paddle.int64: + tensor = paddle.randint(0, 10, shape=shape, dtype=tensor_dtype) + else: + tensor = paddle.randn(shape, dtype=tensor_dtype) + if spec.transform is not None: + tensor = spec.transform(tensor) + if spec.requires_grad and paddle.is_floating_point(tensor): + tensor.stop_gradient = False + inputs[param_name] = tensor + return inputs + + +_simple_qkv = { + 'q': TensorSpec(shape_BTHD), + 'k': TensorSpec(shape_BTHD), + 'v': TensorSpec(shape_BTHD), +} + +register_op(OpConfig( + name='chunk_gdn', + import_path='flash_mask.linear_attn.ops.gated_delta_rule', + func_name='chunk_gated_delta_rule', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTH, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + extra_kwargs={'use_qk_l2norm_in_kernel': True}, + category='gate_beta', +)) + +register_op(OpConfig( + name='chunk_kda', + import_path='flash_mask.linear_attn.ops.kda', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTHD, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + 
extra_kwargs={'use_qk_l2norm_in_kernel': True, 'safe_gate': True, 'lower_bound': -5}, + category='gate_beta', +)) + +register_op(OpConfig( + name='recurrent_gdn', + import_path='flash_mask.linear_attn.ops.gated_delta_rule', + func_name='fused_recurrent_gated_delta_rule', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTH, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + extra_kwargs={'use_qk_l2norm_in_kernel': True}, + skip_backward=True, + category='gate_beta', +)) + +register_op(OpConfig( + name='recurrent_kda', + import_path='flash_mask.linear_attn.ops.kda', + func_name='fused_recurrent_kda', + inputs={ + **_simple_qkv, + 'g': TensorSpec(shape_BTHD, transform=logsigmoid), + 'beta': TensorSpec(shape_BTH, transform=sigmoid_transform), + }, + extra_kwargs={'use_qk_l2norm_in_kernel': True, 'safe_gate': True, 'lower_bound': -5}, + skip_backward=True, + category='gate_beta', +)) diff --git a/benchmark_linear_attention_run.py b/benchmark_linear_attention_run.py new file mode 100644 index 0000000..530b1a0 --- /dev/null +++ b/benchmark_linear_attention_run.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import argparse +import importlib +import json +import logging +import os +import platform +import socket +import sys +from contextlib import contextmanager + +import paddle + +# Must call paddle.enable_compat(scope={"triton"}) BEFORE import triton. +# In a pure-Paddle env this registers the Paddle triton driver so that +# triton can discover it during initialization. In a mixed torch+paddle +# env this is also safe — both drivers are registered and the +# swap_driver_guard / activate_paddle_driver mechanism handles switching. 
+paddle.enable_compat(scope={"triton"}) + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from benchmark_linear_attention_registry import SHAPE_CONFIGS, OpConfig, generate_inputs, get_op, list_ops # noqa: E402 + +logger = logging.getLogger(__name__) + + +@contextmanager +def _activate_paddle_driver(): + try: + from flash_mask.linear_attn.triton_utils import paddle_driver + from triton.runtime.driver import driver + except Exception: + yield + return + + if paddle_driver is None: + yield + return + + driver.set_active(paddle_driver) + try: + yield + finally: + driver.reset_active() + + +def _import_op(config: OpConfig): + mod = importlib.import_module(config.import_path) + attr = config.func_name or config.name + fn = getattr(mod, attr, None) + if fn is None: + raise ImportError( + f"Cannot find '{attr}' in module '{config.import_path}'. " + f"Available: {[x for x in dir(mod) if not x.startswith('_')] }" + ) + return fn + + +def _get_machine_info() -> dict: + info = { + 'hostname': socket.gethostname(), + 'platform': platform.platform(), + 'paddle_version': paddle.__version__, + } + try: + import triton + info['triton_version'] = triton.__version__ + except Exception: + info['triton_version'] = 'N/A' + + if paddle.device.is_compiled_with_cuda(): + info['gpu_name'] = paddle.device.cuda.get_device_name() + info['gpu_count'] = paddle.device.cuda.device_count() + else: + info['gpu_name'] = 'N/A' + info['gpu_count'] = 0 + return info + + +def _warmup_iters() -> int: + return max(1, int(os.environ.get('FLA_BENCH_OP_WARMUP_ITERS', '5'))) + + +def _do_bench_kw(): + warmup_ms = int(os.environ.get('FLA_BENCH_WARMUP_MS', '100')) + rep_ms = int(os.environ.get('FLA_BENCH_REP_MS', '500')) + return {'warmup': max(1, warmup_ms), 'rep': max(1, rep_ms)} + + +def _synchronize(): + if paddle.device.is_compiled_with_cuda(): + paddle.device.synchronize() + + +def _clear_gradients(inputs: dict[str, paddle.Tensor]): + for tensor in inputs.values(): + if isinstance(tensor, 
paddle.Tensor) and not tensor.stop_gradient: + tensor.clear_gradient(False) + + +def _backward(tensor: paddle.Tensor, grad: paddle.Tensor): + paddle.autograd.backward([tensor], [grad]) + + +def _warmup_autotune(fn, n: int | None = None): + if n is None: + n = _warmup_iters() + with _activate_paddle_driver(): + for _ in range(n): + fn() + _synchronize() + + +def benchmark_op( + op_name: str, + shapes: dict[str, dict[str, int]], + modes: list[str] | None = None, +) -> list[dict]: + import triton + + if modes is None: + modes = ['fwd', 'fwdbwd'] + + config = get_op(op_name) + op_fn = _import_op(config) + if config.skip_backward and 'fwdbwd' in modes: + modes = [mode for mode in modes if mode != 'fwdbwd'] + + dtype = paddle.bfloat16 + device = 'gpu' + + print(f"\n [{op_name}] Warming up {len(shapes)} shape(s)...") + failed_shapes = set() + for shape_name, shape_dict in shapes.items(): + B, T, H, D = shape_dict['B'], shape_dict['T'], shape_dict['H'], shape_dict['D'] + try: + inputs = generate_inputs(config, B, T, H, D, dtype=dtype, device=device) + with _activate_paddle_driver(): + out = op_fn(**inputs, **config.extra_kwargs) + out_tensor = out[0] if config.output_is_tuple else out + do = paddle.randn(out_tensor.shape, dtype=out_tensor.dtype) + + def _fwd_fn(inputs=inputs): + return op_fn(**inputs, **config.extra_kwargs) + + def _fwdbwd_fn(inputs=inputs, do=do): + _clear_gradients(inputs) + result = op_fn(**inputs, **config.extra_kwargs) + tensor = result[0] if config.output_is_tuple else result + _backward(tensor, do) + + warmup_fn = _fwdbwd_fn if 'fwdbwd' in modes else _fwd_fn + _warmup_autotune(warmup_fn) + except Exception as error: + logger.warning(f"Warmup failed for {op_name} @ {shape_name}: {error}") + failed_shapes.add(shape_name) + + valid_shapes = {name: cfg for name, cfg in shapes.items() if name not in failed_shapes} + print(f" [{op_name}] Warmup done.") + + results = [] + for shape_name, shape_dict in valid_shapes.items(): + B, T, H, D = shape_dict['B'], 
shape_dict['T'], shape_dict['H'], shape_dict['D'] + try: + inputs = generate_inputs(config, B, T, H, D, dtype=dtype, device=device) + with _activate_paddle_driver(): + out = op_fn(**inputs, **config.extra_kwargs) + out_tensor = out[0] if config.output_is_tuple else out + do = paddle.randn(out_tensor.shape, dtype=out_tensor.dtype) + except Exception as error: + logger.warning(f"Input generation failed for {op_name} @ {shape_name}: {error}") + continue + + for mode in modes: + if mode == 'fwd': + def fn(inputs=inputs): + return op_fn(**inputs, **config.extra_kwargs) + else: + def fn(inputs=inputs, do=do): + _clear_gradients(inputs) + result = op_fn(**inputs, **config.extra_kwargs) + tensor = result[0] if config.output_is_tuple else result + _backward(tensor, do) + + try: + with _activate_paddle_driver(): + ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8], **_do_bench_kw()) + except Exception as error: + logger.warning(f"Bench failed for {op_name} {mode} @ {shape_name}: {error}") + continue + + results.append({ + 'op': op_name, + 'mode': mode, + 'B': B, + 'T': T, + 'H': H, + 'D': D, + 'median_ms': ms[0], + 'p20_ms': ms[1], + 'p80_ms': ms[2], + }) + + return results + + +def print_results_table(results: list[dict], machine_info: dict | None = None): + if not results: + print("\n No results to display.") + return + + width = 92 + print(f"\n{'=' * width}") + if machine_info: + gpu = machine_info.get('gpu_name', 'N/A') + paddle_version = machine_info.get('paddle_version', 'N/A') + print(f" Machine: {gpu} | Paddle {paddle_version}") + print(f"{'=' * width}") + print(f" {'op':<18s} {'mode':<7s} {'B':>4s} {'T':>6s} {'H':>4s} {'D':>4s} {'median(ms)':>12s} {'p20(ms)':>12s} {'p80(ms)':>12s}") + print(f" {'-' * 18} {'-' * 7} {'-' * 4} {'-' * 6} {'-' * 4} {'-' * 4} {'-' * 12} {'-' * 12} {'-' * 12}") + for result in results: + print( + f" {result['op']:<18s} {result['mode']:<7s} {result['B']:>4d} {result['T']:>6d} " + f"{result['H']:>4d} {result['D']:>4d} 
{result['median_ms']:>12.3f} " + f"{result['p20_ms']:>12.3f} {result['p80_ms']:>12.3f}" + ) + print(f"{'=' * width}") + + +def main(argv: list[str] | None = None): + parser = argparse.ArgumentParser(description='Paddle benchmark runner for flash-linear-attention ops') + parser.add_argument('--op', nargs='+', default=None, help='Op name(s) to benchmark, or "all"') + parser.add_argument( + '--custom-shapes', + default=None, + help='JSON string to override default shapes, e.g. \'{"my": {"B":1,"T":2048,"H":16,"D":128}}\'', + ) + parser.add_argument( + '--modes', + nargs='+', + default=['fwd', 'fwdbwd'], + choices=['fwd', 'fwdbwd'], + help='Benchmark modes (default: fwd fwdbwd)', + ) + parser.add_argument('--json', dest='json_file', default=None, help='Output file path for JSON results') + parser.add_argument('--list', action='store_true', help='List all registered ops and exit') + args = parser.parse_args(argv) + + if args.list: + ops = list_ops() + print(f"Registered ops ({len(ops)}):") + for name in ops: + cfg = get_op(name) + print(f" {name:18s} [{cfg.category}] {cfg.import_path}") + return [] + + if args.op is None: + parser.error('--op is required unless --list') + + op_names = list_ops() if args.op == ['all'] else args.op + shape_configs = json.loads(args.custom_shapes) if args.custom_shapes else SHAPE_CONFIGS + + machine_info = _get_machine_info() + print(f"Machine: {machine_info.get('gpu_name', 'N/A')} | Paddle {machine_info.get('paddle_version', 'N/A')}") + print(f"Shapes: {len(shape_configs)} configs") + print(f"Ops: {op_names}") + + all_results = [] + for op_name in op_names: + try: + all_results.extend(benchmark_op(op_name, shape_configs, modes=args.modes)) + except Exception as error: + logger.error(f"Failed to benchmark {op_name}: {error}") + + mode_order = {'fwd': 0, 'fwdbwd': 1} + all_results.sort(key=lambda result: (mode_order.get(result['mode'], 9), result['B'], result['T'], result['H'], result['D'], result['op'])) + print_results_table(all_results, 
machine_info) + + if args.json_file: + output = {'machine_info': machine_info, 'results': all_results} + with open(args.json_file, 'w') as handle: + json.dump(output, handle, indent=2) + print(f"\nResults saved to {args.json_file}") + + return all_results + + +if __name__ == '__main__': + main() diff --git a/linear_attention_conftest.py b/linear_attention_conftest.py new file mode 100644 index 0000000..9aa4a6e --- /dev/null +++ b/linear_attention_conftest.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Test fixtures for Paddle migration tests + +import logging +import os +import warnings +import importlib + +import paddle +import pytest + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# assert_close: matches torch fla.utils.assert_close semantics +# --------------------------------------------------------------------------- + +def get_abs_err(x, y): + return (x.detach() - y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x.detach() - y.detach()).flatten().pow(2).mean().sqrt().item() + base = x.detach().flatten().pow(2).mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + error_rate = get_err_ratio(ref, tri) + msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {error_rate:.6f}" + logger.info(msg) + if abs_atol <= err_atol: + return + assert not paddle.isnan(ref).any(), f"{prefix}: NaN detected in ref" + assert not paddle.isnan(tri).any(), f"{prefix}: NaN detected in tri" + if warning: + if error_rate > ratio: + warnings.warn(msg) + else: + assert error_rate < ratio, msg + + +class FrameworkTracker: + """Track the framework backend of tensors and triton driver at runtime. + + Driver detection modes: + - probe (default in tests): reads the result captured *during* kernel + execution by the swap_driver_guard probe hook. 
This is accurate even + in mixed torch+paddle environments. + - snapshot: inspects the triton active driver at call time (outside + kernel execution). Fast but shows the *default* driver, which is + torch when torch is installed. + + Set ``FLA_BENCHMARK=1`` to disable all probing overhead. In that mode, + ``detect_triton_driver()`` falls back to snapshot and the report omits + the Triton Driver line entirely. + """ + + def __init__(self): + self._benchmark = os.environ.get("FLA_BENCHMARK", "0") == "1" + + # ------ tensor framework ------ + + @staticmethod + def detect_tensor_framework(tensor) -> str: + """Detect which framework a tensor belongs to.""" + module = type(tensor).__module__ + if 'paddle' in module: + return 'paddle' + elif 'torch' in module: + return 'torch' + return f'unknown({module})' + + # ------ triton driver ------ + + def detect_triton_driver(self) -> str: + """Return the triton driver detected during the last kernel launch. + + Uses the probe result captured inside swap_driver_guard when probing + is enabled; falls back to a snapshot of the current active driver + otherwise. 
+ """ + if self._benchmark: + return self._detect_triton_driver_snapshot() + from flash_mask.linear_attn.triton_utils import get_driver_probe_result + result = get_driver_probe_result() + if result == "not_probed": + return self._detect_triton_driver_snapshot() + return result + + @staticmethod + def _detect_triton_driver_snapshot() -> str: + """Inspect the current triton active driver (outside kernel execution).""" + try: + from triton.runtime.driver import driver + from flash_mask.linear_attn.triton_utils import _detect_driver_framework + return _detect_driver_framework(driver.active) + except Exception as e: + return f'error({e})' + + # ------ autograd ------ + + @staticmethod + def detect_autograd_framework(tensor) -> str: + """Detect the autograd backend of a tensor.""" + import paddle + if isinstance(tensor, paddle.Tensor): + if not tensor.stop_gradient: + return 'paddle' + return 'paddle (no_grad)' + try: + import torch + if isinstance(tensor, torch.Tensor): + if tensor.requires_grad: + return 'torch' + return 'torch (no_grad)' + except ImportError: + pass + return 'unknown' + + # ------ report ------ + + def report(self, tensors: dict, label: str = ""): + """Generate a framework detection report.""" + lines = [] + if label: + lines.append(f"\n{'='*60}") + lines.append(f" Framework Detection Report: {label}") + lines.append(f"{'='*60}") + + if not self._benchmark: + lines.append(f" Triton Driver: {self.detect_triton_driver()}") + lines.append(f" {'─'*56}") + + for name, tensor in tensors.items(): + fw = self.detect_tensor_framework(tensor) + ag = self.detect_autograd_framework(tensor) + lines.append(f" {name:20s} | framework: {fw:8s} | autograd: {ag}") + + lines.append(f"{'='*60}\n") + return '\n'.join(lines) + + +@pytest.fixture(autouse=True) +def _driver_probe_lifecycle(): + """Enable driver probing before each test, disable after.""" + if os.environ.get("FLA_BENCHMARK", "0") == "1": + yield + return + from flash_mask.linear_attn.triton_utils import 
enable_driver_probe, disable_driver_probe + enable_driver_probe() + yield + disable_driver_probe() + + +@pytest.fixture(autouse=True) +def _linear_attn_cache_isolation(): + modules = [] + for name in ( + 'flash_mask.linear_attn.ops.common.chunk_o', + 'flash_mask.linear_attn.ops.kda.wy_fast', + ): + try: + modules.append(importlib.import_module(name)) + except Exception: + continue + for mod in modules: + for attr in ('_const_tiling', '_chunk_o_launch_meta', '_wy_tiling', '_wy_launch_meta'): + fn = getattr(mod, attr, None) + if fn is not None and hasattr(fn, 'cache_clear'): + fn.cache_clear() + yield + for mod in modules: + for attr in ('_const_tiling', '_chunk_o_launch_meta', '_wy_tiling', '_wy_launch_meta'): + fn = getattr(mod, attr, None) + if fn is not None and hasattr(fn, 'cache_clear'): + fn.cache_clear() + + +@pytest.fixture +def framework_tracker(): + return FrameworkTracker() diff --git a/test_gated_delta.py b/test_gated_delta.py new file mode 100644 index 0000000..199c4cf --- /dev/null +++ b/test_gated_delta.py @@ -0,0 +1,974 @@ +# -*- coding: utf-8 -*- +# Tests for Gated Delta Rule operators on PaddlePaddle +# Migrated from flash-attention/flashmask/tests/linear_attn/test_gated_delta.py + +import os + +import paddle +import paddle.nn.functional as F +import pytest +from einops import repeat + +from flash_mask.linear_attn.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +from flash_mask.linear_attn.ops.gated_delta_rule.gate import fused_gdn_gate, naive_gdn_gate +from flash_mask.linear_attn.ops.gated_delta_rule.naive import naive_recurrent_gated_delta_rule + +from linear_attention_conftest import ( + assert_close, + _driver_probe_lifecycle, + _linear_attn_cache_isolation, +) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HV', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HV{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 1, 1, 
paddle.float32), + (2, 500, 4, 4, 60, 1, 1, paddle.float32), + (2, 1000, 2, 8, 128, 1, 0.1, paddle.float32), + (3, 1024, 2, 2, 128, 0.1, 1, paddle.float32), + (4, 1024, 3, 3, 128, 1, 10, paddle.float32), + (4, 2048, 4, 4, 64, 0.1, 1, paddle.float32), + (2, 1024, 4, 4, 128, 1, 0.1, paddle.float16), + (2, 1024, 4, 8, 128, 1, 10, paddle.float16), + ] + ], +) +def test_fused_recurrent( + B: int, + T: int, + H: int, + HV: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.randn([B, T, H, D], dtype=paddle.float32) + k = paddle.randn([B, T, H, D], dtype=paddle.float32) + v = paddle.randn([B, T, HV, D], dtype=dtype) + beta = paddle.rand([B, T, HV], dtype=dtype).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, HV], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0 = paddle.randn([B, HV, D, D], dtype=paddle.float32) + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(repeat(q.clone(), 'b t h d -> b t (h g) d', g=HV // H), p=2, axis=-1).cast(dtype), + k=F.normalize(repeat(k.clone(), 'b t h d -> b t (h g) d', g=HV // H), p=2, axis=-1).cast(dtype), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + tri, tri_ht = fused_recurrent_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0.clone(), + use_qk_l2norm_in_kernel=True, + output_final_state=True, + ) + assert_close('o', ref, tri, 0.002) + assert_close('ht', ref_ht, tri_ht, 0.002) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'mask_p', 'use_qk_l2norm_in_kernel', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-use_qk_l2norm_in_kernel{}-{}".format(*test), + ) + for test in [ + (2, 75, 4, 64, 1, 0.01, 0, False, paddle.float16), + (2, 500, 3, 60, 1, 1, 0, False, paddle.float16), + (2, 1000, 3, 64, 0.1, 1, 
0.5, False, paddle.float16), + (3, 1024, 4, 100, 1, 0.1, 0, False, paddle.float16), + (4, 1024, 4, 128, 0.1, 1, 0, False, paddle.float16), + (4, 1024, 4, 128, 0.1, 1, 0, True, paddle.float16), + (2, 1500, 4, 128, 0.1, 10, 0, False, paddle.float16), + (4, 2048, 8, 64, 0.1, 1, 0, False, paddle.float16), + ] + ], +) +def test_chunk( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + mask_p: float, + use_qk_l2norm_in_kernel: bool, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=paddle.float32).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, H], dtype=paddle.float32)) + g = g / gate_logit_normalizer + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + for t in [q, k, v, beta, g, h0]: + t.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0.clear_gradient() + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + 
beta=beta.clone(), + g=g.clone(), + scale=scale, + output_final_state=True, + initial_state=h0.clone(), + ) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + + +@pytest.mark.parametrize( + ('B', 'T', 'Hq', 'H', 'D', 'scale', 'gate_logit_normalizer', 'use_qk_l2norm_in_kernel', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-Hq{}-H{}-D{}-scale{}-gate_logit_normalizer{}-use_qk_l2norm_in_kernel{}-{}".format(*test), + ) + for test in [ + (2, 256, 2, 4, 64, 1, 1, False, paddle.float16), + (2, 512, 1, 4, 64, 0.1, 1, False, paddle.float16), + (2, 512, 2, 8, 64, 1, 0.1, True, paddle.float16), + (2, 1024, 4, 8, 128, 0.1, 1, False, paddle.float16), + ] + ], +) +def test_chunk_gqa( + B: int, + T: int, + Hq: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + use_qk_l2norm_in_kernel: bool, + dtype, +): + paddle.seed(42) + assert H % Hq == 0 + G = H // Hq + + q = paddle.rand([B, T, Hq, D], dtype=dtype) + k = paddle.rand([B, T, Hq, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=paddle.float32).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, H], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + for t in [q, k, v, beta, g, h0]: + t.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), 
p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0.clear_gradient() + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(repeat(q.clone(), 'b t h d -> b t (h g) d', g=G), p=2, axis=-1), + k=F.normalize(repeat(k.clone(), 'b t h d -> b t (h g) d', g=G), p=2, axis=-1), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + output_final_state=True, + initial_state=h0.clone(), + ) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 64, 1, 1, paddle.float16), + (2, 500, 3, 60, 1, 1, paddle.float16), + (3, 1024, 4, 128, 0.1, 1, paddle.float16), + (4, 2048, 8, 64, 0.1, 1, 
paddle.float16), + ] + ], +) +def test_chunk_transpose_state( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=dtype).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, H], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0_kv = paddle.randn([B, H, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + for t in [q, k, v, beta, g, h0_kv, h0_vk]: + t.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_vk.clone(), + output_final_state=True, + transpose_state_layout=True, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht_vk = paddle.randn([B, H, D, D], dtype=paddle.float32) + dht_kv = dht_vk.transpose([0, 1, 3, 2]).contiguous() + ((tri * do).sum() + (tri_ht * dht_vk).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0_vk.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0_vk.clear_gradient() + + ref, ref_ht = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_kv.clone(), + output_final_state=True, + transpose_state_layout=False, + ) + ((ref * do).sum() + (ref_ht * dht_kv).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0_kv.grad.clone() + ) 
+ + assert_close('o', ref, tri, 1e-4) + assert_close('ht', ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + assert_close('dq', ref_dq, tri_dq, 1e-4) + assert_close('dk', ref_dk, tri_dk, 1e-4) + assert_close('dv', ref_dv, tri_dv, 1e-4) + assert_close('db', ref_dbeta, tri_dbeta, 1e-4) + assert_close('dg', ref_dg, tri_dg, 1e-4) + assert_close('dh0', ref_dh0, tri_dh0.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'HV', 'D', 'scale', 'gate_logit_normalizer', 'dtype'), + [ + pytest.param(*test, id="B{}-T{}-H{}-HV{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test)) + for test in [ + (1, 63, 1, 1, 64, 1, 1, paddle.float32), + (2, 500, 4, 4, 60, 1, 1, paddle.float32), + (2, 1000, 2, 8, 128, 1, 0.1, paddle.float32), + (3, 1024, 2, 2, 128, 0.1, 1, paddle.float32), + (4, 2048, 4, 4, 64, 0.1, 1, paddle.float32), + ] + ], +) +def test_fused_recurrent_transpose_state( + B: int, + T: int, + H: int, + HV: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.randn([B, T, H, D], dtype=paddle.float32) + k = paddle.randn([B, T, H, D], dtype=paddle.float32) + v = paddle.randn([B, T, HV, D], dtype=dtype) + beta = paddle.rand([B, T, HV], dtype=dtype).sigmoid() + g = F.log_sigmoid(paddle.rand([B, T, HV], dtype=paddle.float32)) + g = g / gate_logit_normalizer + h0_kv = paddle.randn([B, HV, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + + ref, ref_ht = fused_recurrent_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0_kv.clone(), + use_qk_l2norm_in_kernel=True, + output_final_state=True, + transpose_state_layout=False, + ) + tri, tri_ht = fused_recurrent_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + scale=scale, + initial_state=h0_vk.clone(), + use_qk_l2norm_in_kernel=True, + output_final_state=True, + transpose_state_layout=True, 
+ ) + assert_close('o', ref, tri, 1e-4) + assert_close('ht', ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 60, 0, [0, 15], paddle.float16), + (4, 64, 0, [0, 256, 500, 1000], paddle.float16), + (4, 64, 0.5, [0, 256, 500, 1000], paddle.float16), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16), + ] + ], +) +@pytest.mark.skipif( + os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', + reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set', +) +def test_chunk_varlen( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, +): + paddle.seed(42) + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.rand([1, T, H], dtype=dtype)) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + beta = paddle.rand([1, T, H], dtype=paddle.float32).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=dtype) + + for t in [q, k, v, beta, g, h0]: + t.stop_gradient = False + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.rand(h0.shape, dtype=h0.dtype) + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + 
v.clear_gradient() + beta.clear_gradient() + g.clear_gradient() + h0.clear_gradient() + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_gated_delta_rule( + q=q[:, s:e], + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=g[:, s:e], + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dg, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g.grad.clone(), h0.grad.clone() + ) + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.007) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.007) + assert_close('db', ref_dbeta, tri_dbeta, 0.015) + assert_close('dg', ref_dg, tri_dg, 0.015) + assert_close('dh0', ref_dh0, tri_dh0, 0.007) + + +@pytest.mark.parametrize( + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 60, 0, [0, 8192], paddle.float16), + (4, 60, 0, [0, 15], paddle.float16), + (4, 64, 0, [0, 256, 500, 1000], paddle.float16), + (4, 64, 0.5, [0, 256, 500, 1000], paddle.float16), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16), + ] + ], +) +@pytest.mark.skipif( + os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', + reason='Skipping test_chunk_varlen because SKIP_TEST_CHUNK_VARLEN is set', +) +def test_chunk_varlen_prefill( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, +): + paddle.seed(42) + with paddle.no_grad(): + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = 
paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.rand([1, T, H], dtype=dtype)) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + beta = paddle.rand([1, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=dtype) + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + beta=beta.clone(), + g=g.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + ) + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_gated_delta_rule( + q=q[:, s:e], + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=g[:, s:e], + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ('B', 'T', 'H', 'D', 'scale', 'has_dt_bias', 'use_qk_l2norm_in_kernel', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-has_dt_bias{}-use_qk_l2norm{}-{}".format(*test), + ) + for test in [ + (2, 75, 4, 64, 1, True, True, paddle.float16), + (2, 500, 3, 60, 1, False, False, paddle.float16), + (2, 1000, 3, 64, 0.1, True, False, paddle.float16), + (3, 1024, 4, 100, 1, True, True, paddle.float16), + (4, 1024, 4, 128, 0.1, False, True, paddle.float16), + (4, 2048, 8, 64, 0.1, True, False, paddle.float16), + ] + ], +) +def test_chunk_gate_in_kernel( + B: int, + T: int, + H: int, + D: int, + scale: float, + has_dt_bias: bool, + use_qk_l2norm_in_kernel: bool, + dtype, +): + """Test use_gate_in_kernel=True path: fused gate activation + chunk cumsum inside kernel.""" + paddle.seed(42) + q = paddle.rand([B, T, H, 
D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=paddle.float32).sigmoid() + g_raw = paddle.randn([B, T, H], dtype=paddle.float32) + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H], dtype=paddle.float32) if has_dt_bias else None + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, beta, g_raw, h0]: + t.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + + # === Triton path: use_gate_in_kernel=True === + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone() if use_qk_l2norm_in_kernel else F.normalize(q.clone(), p=2, axis=-1), + k=k.clone() if use_qk_l2norm_in_kernel else F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g_raw.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=True, + A_log=A_log.clone(), + dt_bias=dt_bias.clone() if dt_bias is not None else None, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g_raw.grad.clone(), h0.grad.clone() + ) + tri_dA_log = A_log.grad.clone() + tri_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g_raw.clear_gradient() + h0.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + # === Reference path: manually compute gate, then use_gate_in_kernel=False === + g_ref = naive_gdn_gate(g_raw, A_log, dt_bias) + ref, ref_ht = chunk_gated_delta_rule( + q=q.clone() if use_qk_l2norm_in_kernel else 
F.normalize(q.clone(), p=2, axis=-1), + k=k.clone() if use_qk_l2norm_in_kernel else F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g_ref, + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), h0.grad.clone() + ) + ref_dg = g_raw.grad.clone() + ref_dA_log = A_log.grad.clone() + ref_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + assert_close('dA_log', ref_dA_log, tri_dA_log, 0.02) + if dt_bias is not None: + assert_close('ddt_bias', ref_ddt_bias, tri_ddt_bias, 0.02) + + +@pytest.mark.parametrize( + ('B', 'T', 'Hq', 'H', 'D', 'scale', 'has_dt_bias', 'dtype'), + [ + pytest.param( + *test, + id="B{}-T{}-Hq{}-H{}-D{}-scale{}-has_dt_bias{}-{}".format(*test), + ) + for test in [ + (2, 256, 2, 4, 64, 1, True, paddle.float16), + (2, 512, 1, 4, 64, 0.1, False, paddle.float16), + (2, 512, 2, 8, 64, 1, True, paddle.float16), + (2, 1024, 4, 8, 128, 0.1, True, paddle.float16), + ] + ], +) +def test_chunk_gate_in_kernel_gqa( + B: int, + T: int, + Hq: int, + H: int, + D: int, + scale: float, + has_dt_bias: bool, + dtype, +): + """Test use_gate_in_kernel=True with grouped value attention (HV > H).""" + paddle.seed(42) + assert H % Hq == 0 + + q = paddle.rand([B, T, Hq, D], dtype=dtype) + k = paddle.rand([B, T, Hq, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], 
dtype=paddle.float32).sigmoid() + g_raw = paddle.randn([B, T, H], dtype=paddle.float32) + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H], dtype=paddle.float32) if has_dt_bias else None + h0 = paddle.zeros([B, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, beta, g_raw, h0]: + t.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_raw.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=True, + use_gate_in_kernel=True, + A_log=A_log.clone(), + dt_bias=dt_bias.clone() if dt_bias is not None else None, + ) + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g_raw.grad.clone(), h0.grad.clone() + ) + tri_dA_log = A_log.grad.clone() + tri_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g_raw.clear_gradient() + h0.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + g_ref = naive_gdn_gate(g_raw, A_log, dt_bias) + ref, ref_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_ref, + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), h0.grad.clone() + ) + ref_dg = g_raw.grad.clone() + ref_dA_log = A_log.grad.clone() + ref_ddt_bias = 
dt_bias.grad.clone() if dt_bias is not None else None + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + assert_close('dh0', ref_dh0, tri_dh0, 0.008) + assert_close('dA_log', ref_dA_log, tri_dA_log, 0.02) + if dt_bias is not None: + assert_close('ddt_bias', ref_ddt_bias, tri_ddt_bias, 0.02) + + +@pytest.mark.parametrize( + ('H', 'D', 'has_dt_bias', 'cu_seqlens', 'dtype'), + [ + pytest.param(*test, id="H{}-D{}-has_dt_bias{}-cu_seqlens{}-{}".format(*test)) + for test in [ + (4, 60, True, [0, 15], paddle.float16), + (4, 64, False, [0, 256, 500, 1000], paddle.float16), + (4, 64, True, [0, 256, 500, 1000], paddle.float16), + (4, 100, True, [0, 15, 100, 300, 1200, 2000], paddle.float16), + ] + ], +) +@pytest.mark.skipif( + os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1', + reason='Skipping test because SKIP_TEST_CHUNK_VARLEN is set', +) +def test_chunk_gate_in_kernel_varlen( + H: int, + D: int, + has_dt_bias: bool, + cu_seqlens: list, + dtype, +): + """Test use_gate_in_kernel=True with variable-length sequences.""" + paddle.seed(42) + os.environ['TRITON_F32_DEFAULT'] = 'ieee' + + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = paddle.randn([1, T, H, D], dtype=dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + beta = paddle.rand([1, T, H], dtype=paddle.float32).sigmoid() + g_raw = paddle.randn([1, T, H], dtype=paddle.float32) + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H], dtype=paddle.float32) if has_dt_bias else None + h0 = paddle.randn([N, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, beta, g_raw, h0]: + t.stop_gradient = False + A_log.stop_gradient = False + if 
dt_bias is not None: + dt_bias.stop_gradient = False + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.rand(h0.shape, dtype=h0.dtype) + + tri, tri_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_raw.clone(), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + use_qk_l2norm_in_kernel=True, + use_gate_in_kernel=True, + A_log=A_log.clone(), + dt_bias=dt_bias.clone() if dt_bias is not None else None, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dbeta, tri_dg, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), g_raw.grad.clone(), h0.grad.clone() + ) + tri_dA_log = A_log.grad.clone() + tri_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + beta.clear_gradient() + g_raw.clear_gradient() + h0.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + g_ref = naive_gdn_gate(g_raw, A_log, dt_bias) + ref, ref_ht = chunk_gated_delta_rule( + q=q.clone(), + k=k.clone(), + v=v.clone(), + g=g_ref, + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + use_qk_l2norm_in_kernel=True, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dbeta, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + beta.grad.clone(), h0.grad.clone() + ) + ref_dg = g_raw.grad.clone() + ref_dA_log = A_log.grad.clone() + ref_ddt_bias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close('o', ref, tri, 0.005) + assert_close('ht', ref_ht, tri_ht, 0.005) + assert_close('dq', ref_dq, tri_dq, 0.008) + assert_close('dk', ref_dk, tri_dk, 0.008) + assert_close('dv', ref_dv, tri_dv, 0.008) + assert_close('db', ref_dbeta, tri_dbeta, 0.02) + assert_close('dg', ref_dg, tri_dg, 0.02) + 
assert_close('dh0', ref_dh0, tri_dh0, 0.008) + assert_close('dA_log', ref_dA_log, tri_dA_log, 0.02) + if dt_bias is not None: + assert_close('ddt_bias', ref_ddt_bias, tri_ddt_bias, 0.02) + + +@pytest.mark.parametrize( + ('B', 'T', 'HV', 'HAS_BIAS'), + [ + pytest.param(*test, id="B{}-T{}-HV{}-bias{}".format(*test)) + for test in [ + (1, 32, 2, False), + (2, 64, 4, True), + (4, 128, 8, True), + (4, 128, 16, False), + ] + ], +) +def test_gate( + B: int, + T: int, + HV: int, + HAS_BIAS: bool, +): + paddle.seed(42) + g = paddle.randn([B, T, HV], dtype=paddle.float32) + A_log = paddle.log(paddle.uniform([HV], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([HV], dtype=paddle.float32) if HAS_BIAS else None + g.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + do = paddle.randn([B, T, HV], dtype=paddle.float32) + + ref = naive_gdn_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + ) + tri = fused_gdn_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + ) + (ref * do).sum().backward(retain_graph=True) + + ref_dg = g.grad.clone() + ref_dA = A_log.grad.clone() + ref_dbias = dt_bias.grad.clone() if dt_bias is not None else None + g.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + (tri * do).sum().backward(retain_graph=True) + tri_dg = g.grad.clone() + tri_dA = A_log.grad.clone() + tri_dbias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close("o", ref, tri, 1e-4) + assert_close("dg", ref_dg, tri_dg, 1e-4) + assert_close("dA", ref_dA, tri_dA, 1e-4) + if HAS_BIAS: + assert_close("dbias", ref_dbias, tri_dbias, 1e-4) diff --git a/test_kda.py b/test_kda.py new file mode 100644 index 0000000..c3883d5 --- /dev/null +++ b/test_kda.py @@ -0,0 +1,915 @@ +# -*- coding: utf-8 -*- +# Tests for KDA (Kimi Delta Attention) operators on PaddlePaddle +# Migrated from 
flash-attention/flashmask/tests/linear_attn/test_kda.py + +import paddle +import paddle.nn.functional as F +import pytest + +from flash_mask.linear_attn.ops.kda import chunk_kda, fused_recurrent_kda +from flash_mask.linear_attn.ops.kda.fused_recurrent import fused_recurrent_kda_fwd +from flash_mask.linear_attn.ops.kda.gate import fused_kda_gate, naive_kda_gate, naive_kda_lowerbound_gate +from flash_mask.linear_attn.ops.kda.naive import naive_chunk_kda, naive_recurrent_kda + +from linear_attention_conftest import ( + assert_close, + _driver_probe_lifecycle, + _linear_attn_cache_isolation, +) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test), + ) + for test in [ + (1, 64, 1, 64, 1, 1, paddle.float32), + (2, 512, 3, 60, 1, 1, paddle.float32), + (4, 1024, 4, 128, 0.1, 1, paddle.float32), + (4, 1024, 4, 128, 1, 10, paddle.float32), + ] + ], +) +def test_naive_chunk( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([B, H, D, D], dtype=paddle.float32) + + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + + tri, tri_ht = naive_chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + assert_close("o", ref, 
tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "use_qk_l2norm_in_kernel", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-use_qk_l2norm_in_kernel{}-{}".format(*test), + ) + for test in [ + (1, 64, 1, 64, 1, 1, False, paddle.float32), + (2, 512, 3, 60, 1, 1, False, paddle.float32), + (3, 1000, 4, 100, 0.1, 1, True, paddle.float32), + (4, 1024, 4, 128, 0.1, 1, False, paddle.float32), + ] + ], +) +def test_fused_recurrent( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + use_qk_l2norm_in_kernel: bool, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([B, H, D, D], dtype=paddle.float32) + + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + + tri, tri_ht = fused_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test), + ) + for test in [ 
+ (1, 64, 1, 64, 1, 1, paddle.float32), + (2, 512, 3, 60, 1, 1, paddle.float32), + (4, 1024, 4, 128, 0.1, 1, paddle.float32), + (4, 1024, 4, 128, 1, 10, paddle.float32), + ] + ], +) +def test_fused_recurrent_transpose_state( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0_kv = paddle.randn([B, H, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + + ref, ref_ht = fused_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_kv.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + transpose_state_layout=False, + ) + tri, tri_ht = fused_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_vk.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + transpose_state_layout=True, + ) + assert_close("o", ref, tri, 1e-4) + assert_close("ht", ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ("B", "H", "D", "scale", "gate_logit_normalizer", "use_qk_l2norm_in_kernel", "use_gate_in_kernel", "safe_gate", "dtype"), + [ + pytest.param( + *test, + id="B{}-H{}-D{}-scale{}-norm{}-qk_l2{}-gate{}-safe_gate{}-dtype{}".format(*test), + ) + for test in [ + (16, 16, 128, 0.1, 1.0, True, False, False, paddle.bfloat16), + (32, 8, 64, 1.0, 1.0, False, False, False, paddle.float16), + (16, 16, 128, 0.1, 1.0, True, True, False, paddle.bfloat16), + (32, 8, 64, 1.0, 1.0, False, True, False, 
paddle.float16), + (7, 32, 128, 0.5, 0.5, True, True, True, paddle.bfloat16), + ] + ], +) +def test_fused_recurrent_vllm_decode( + B: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + use_qk_l2norm_in_kernel: bool, + use_gate_in_kernel: bool, + safe_gate: bool, + dtype, +): + """Test vLLM-style decoding with continuous batching and paged state storage.""" + paddle.seed(42) + + # Setup cache pool and inputs + max_cache_slots = B * 3 + state_pool = paddle.randn([max_cache_slots, H, D, D], dtype=paddle.float32) + state_indices = paddle.randperm(max_cache_slots)[:B].cast(paddle.int32) + + # Fill unaccessed slots with a huge value to detect out-of-bound access + HUGE_VALUE = 1e30 + mask = paddle.ones([max_cache_slots], dtype='bool') + mask[state_indices.cast(paddle.int64)] = False + state_pool[mask] = HUGE_VALUE + + T = 1 + total_tokens = B * T + + q = paddle.rand([1, total_tokens, H, D], dtype=dtype) + k = paddle.rand([1, total_tokens, H, D], dtype=dtype) + v = paddle.rand([1, total_tokens, H, D], dtype=dtype) + g = paddle.randn([1, total_tokens, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + + if use_gate_in_kernel: + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)).squeeze() + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + lower_bound = -5.0 if safe_gate else None + naive_kda_gate_fn = naive_kda_lowerbound_gate if safe_gate else naive_kda_gate + else: + g = F.log_sigmoid(g) / gate_logit_normalizer + A_log = None + dt_bias = None + lower_bound = None + naive_kda_gate_fn = None + + beta = paddle.randn([1, total_tokens, H], dtype=dtype).sigmoid() + + cu_seqlens = paddle.arange(0, total_tokens + 1, step=T, dtype=paddle.int32) + ref_state_pool = state_pool.clone() + tri_state_pool = state_pool.clone() + + # Reference implementation (loop over batch) + ref_outputs = [] + for i in range(B): + start, end = i, i + 1 + slot_idx = state_indices[i].item() + + q_i = q[:, start:end].clone() 
+ k_i = k[:, start:end].clone() + v_i = v[:, start:end].clone() + g_i = g[:, start:end].clone() + beta_i = beta[:, start:end].clone() + + h_init = ref_state_pool[slot_idx].clone().unsqueeze(0) + ref_o_i, ref_ht_i = naive_recurrent_kda( + q=F.normalize(q_i, p=2, axis=-1), + k=F.normalize(k_i, p=2, axis=-1), + v=v_i, + g=(naive_kda_gate_fn(g_i, A_log, dt_bias) if use_gate_in_kernel else g_i), + beta=beta_i, + scale=scale, + initial_state=h_init, + output_final_state=True + ) + ref_outputs.append(ref_o_i) + ref_state_pool[slot_idx] = ref_ht_i.squeeze(0) + + ref_out = paddle.concat(ref_outputs, axis=1) + + # Triton kernel + q_in = q.clone() + k_in = k.clone() + if not use_qk_l2norm_in_kernel: + q_in = F.normalize(q_in, p=2, axis=-1) + k_in = F.normalize(k_in, p=2, axis=-1) + + tri_out, _ = fused_recurrent_kda_fwd( + q=q_in, + k=k_in, + v=v, + g=g, + beta=beta, + A_log=A_log, + dt_bias=dt_bias, + initial_state=tri_state_pool, + scale=scale, + output_final_state=False, + inplace_final_state=True, + cu_seqlens=cu_seqlens, + ssm_state_indices=state_indices, + num_accepted_tokens=None, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + lower_bound=lower_bound, + ) + + # Verify results + assert_close("o", ref_out, tri_out, 0.005) + assert_close("ht", ref_state_pool[state_indices.cast(paddle.int64)], + tri_state_pool[state_indices.cast(paddle.int64)], 0.005) + + mask = paddle.ones([max_cache_slots], dtype='bool') + mask[state_indices.cast(paddle.int64)] = False + assert_close("Untouched ht", ref_state_pool[mask], tri_state_pool[mask], 0.0) + + +@pytest.mark.parametrize( + ( + "B", "T", "H", "D", "scale", "gate_logit_normalizer", + "mask_p", "use_qk_l2norm_in_kernel", "use_gate_in_kernel", + "dtype", "safe_gate", "disable_recompute", + ), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-qk_l2norm{}-gate{}-dtype{}-safe_gate{}-disable_recompute{}".format( + *test), + ) + for test in [ + (1, 63, 
1, 64, 1, 1, 0, False, False, paddle.float16, True, False), + (2, 500, 3, 60, 1, 1, 0, False, False, paddle.float16, True, True), + (2, 1000, 3, 64, 0.1, 1, 0.5, False, False, paddle.float16, False, True), + (3, 1024, 4, 100, 1, 0.1, 0, False, False, paddle.float16, False, False), + (4, 1024, 4, 128, 0.1, 1, 0, False, False, paddle.float16, True, True), + (4, 1024, 4, 128, 0.1, 1, 0, True, False, paddle.float16, True, False), + (2, 1500, 4, 128, 0.1, 10, 0, False, True, paddle.float16, False, True), + (4, 2048, 8, 64, 0.1, 1, 0, False, True, paddle.float16, True, True), + ] + ], +) +def test_chunk( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + mask_p: float, + use_qk_l2norm_in_kernel: bool, + use_gate_in_kernel: bool, + dtype, + safe_gate: bool, + disable_recompute: bool, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = paddle.randn([B, T, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = paddle.randn([H], dtype=paddle.float32) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + else: + g = F.log_sigmoid(g) / gate_logit_normalizer + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + if safe_gate: + lower_bound = -5.0 + if not use_gate_in_kernel: + g = g.clip(-5, 0) + naive_kda_gate_fn = naive_kda_lowerbound_gate + else: + lower_bound = None + naive_kda_gate_fn = naive_kda_gate + + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([B, H, D, D], dtype=paddle.float32) + + if use_gate_in_kernel: + A_log.stop_gradient = False + dt_bias.stop_gradient = False + for t in [q, k, v, g, beta, h0]: + t.stop_gradient = False + + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.randn(h0.shape, dtype=h0.dtype) + + ref, ref_ht = naive_recurrent_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), 
p=2, axis=-1), + v=v.clone(), + g=(naive_kda_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), + beta=beta.clone(), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + ) + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + if use_gate_in_kernel: + ref_dA = A_log.grad.clone() + A_log.clear_gradient() + ref_dbias = dt_bias.grad.clone() + dt_bias.clear_gradient() + ref_dq, ref_dk, ref_dv, ref_dg, ref_db, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + g.clear_gradient() + beta.clear_gradient() + h0.clear_gradient() + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, axis=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + scale=scale, + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + disable_recompute=disable_recompute, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + if use_gate_in_kernel: + tri_dA = A_log.grad.clone() + A_log.clear_gradient() + tri_dbias = dt_bias.grad.clone() + dt_bias.clear_gradient() + tri_dq, tri_dk, tri_dv, tri_dg, tri_db, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("dq", ref_dq, tri_dq, 0.008) + assert_close("dk", ref_dk, tri_dk, 0.008) + assert_close("dv", ref_dv, tri_dv, 0.008) + assert_close("dg", ref_dg, tri_dg, 0.02) + assert_close("db", 
ref_db, tri_db, 0.02) + if use_gate_in_kernel: + assert_close("dA", ref_dA, tri_dA, 0.003, warning=True) + # Paddle migration shows slightly larger numerical drift on dt_bias grad than Torch. + # Keep a slightly looser tolerance here to avoid rejecting acceptable backend differences. + assert_close("dbias", ref_dbias, tri_dbias, 0.01) + assert_close("dh0", ref_dh0, tri_dh0, 0.008) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "scale", "gate_logit_normalizer", "dtype"), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-{}".format(*test), + ) + for test in [ + (1, 63, 1, 64, 1, 1, paddle.float16), + (2, 500, 3, 60, 1, 1, paddle.float16), + (3, 1024, 4, 128, 0.1, 1, paddle.float16), + (4, 2048, 8, 64, 0.1, 1, paddle.float16), + ] + ], +) +def test_chunk_transpose_state( + B: int, + T: int, + H: int, + D: int, + scale: float, + gate_logit_normalizer: float, + dtype, +): + paddle.seed(42) + q = paddle.rand([B, T, H, D], dtype=dtype) + k = paddle.rand([B, T, H, D], dtype=dtype) + v = paddle.rand([B, T, H, D], dtype=dtype) + g = F.log_sigmoid(paddle.randn([B, T, H, D], dtype=paddle.float32)) / gate_logit_normalizer + beta = paddle.randn([B, T, H], dtype=dtype).sigmoid() + h0_kv = paddle.randn([B, H, D, D], dtype=paddle.float32) + h0_vk = h0_kv.transpose([0, 1, 3, 2]).contiguous() + + for t in [q, k, v, g, beta, h0_kv, h0_vk]: + t.stop_gradient = False + + do = paddle.randn(v.shape, dtype=v.dtype) + dht_vk = paddle.randn([B, H, D, D], dtype=paddle.float32) + dht_kv = dht_vk.transpose([0, 1, 3, 2]).contiguous() + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_vk.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + transpose_state_layout=True, + ) + ((tri * do).sum() + (tri_ht * dht_vk).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dg, tri_db, tri_dh0 
= ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0_vk.grad.clone() + ) + q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + g.clear_gradient() + beta.clear_gradient() + h0_vk.clear_gradient() + + ref, ref_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=F.normalize(k.clone(), p=2, axis=-1), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + scale=scale, + initial_state=h0_kv.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=False, + transpose_state_layout=False, + ) + ((ref * do).sum() + (ref_ht * dht_kv).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dg, ref_db, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0_kv.grad.clone() + ) + + assert_close("o", ref, tri, 1e-4) + assert_close("ht", ref_ht, tri_ht.transpose([0, 1, 3, 2]), 1e-4) + assert_close("dq", ref_dq, tri_dq, 1e-4) + assert_close("dk", ref_dk, tri_dk, 1e-4) + assert_close("dv", ref_dv, tri_dv, 1e-4) + assert_close("dg", ref_dg, tri_dg, 1e-4) + assert_close("db", ref_db, tri_db, 1e-4) + assert_close("dh0", ref_dh0, tri_dh0.transpose([0, 1, 3, 2]), 1e-4) + + +@pytest.mark.parametrize( + ("H", "D", "mask_p", "cu_seqlens", "dtype", "use_gate_in_kernel", "safe_gate", "disable_recompute"), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-gate{}-safe_gate{}-disable_recompute{}".format(*test)) + for test in [ + (4, 60, 0.1, [0, 15], paddle.float16, True, False, False), + (4, 64, 0.9, [0, 256, 500, 1000], paddle.float16, True, False, False), + (4, 128, 0.5, [0, 256, 500, 1000], paddle.float16, False, False, False), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16, True, False, False), + (4, 256, 0, [0, 100, 300, 1200, 3000, 4096], paddle.float16, False, True, True), + ] + ], +) +def test_chunk_varlen( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, + use_gate_in_kernel: bool, + safe_gate: bool, + disable_recompute: bool, 
+): + paddle.seed(42) + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + cu_seqlens_cpu = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = paddle.randn([1, T, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + else: + g = F.log_sigmoid(g) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + mask = (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + g = g * mask + (1 - mask) * (-1000) + if safe_gate: + assert use_gate_in_kernel is False + g = g.clip(-5, 0) + + beta = paddle.rand([1, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=paddle.float32) + + for t in [q, k, v, g, beta, h0]: + t.stop_gradient = False + if use_gate_in_kernel: + A_log.stop_gradient = False + dt_bias.stop_gradient = False + do = paddle.randn(v.shape, dtype=v.dtype) + dht = paddle.rand(h0.shape, dtype=h0.dtype) + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + cu_seqlens_cpu=cu_seqlens_cpu, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + disable_recompute=disable_recompute, + ) + ((tri * do).sum() + (tri_ht * dht).sum()).backward(retain_graph=True) + tri_dq, tri_dk, tri_dv, tri_dg, tri_db, tri_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + 
q.clear_gradient() + k.clear_gradient() + v.clear_gradient() + g.clear_gradient() + beta.clear_gradient() + h0.clear_gradient() + if use_gate_in_kernel: + tri_dA = A_log.grad.clone() + A_log.clear_gradient() + tri_dbias = dt_bias.grad.clone() + dt_bias.clear_gradient() + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_kda( + q=F.normalize(q[:, s:e], p=2, axis=-1), + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=(naive_kda_gate(g[:, s:e].cast(paddle.float32), A_log.cast(paddle.float32), + dt_bias.cast(paddle.float32)) if use_gate_in_kernel else g[:, s:e]), + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + ((ref * do).sum() + (ref_ht * dht).sum()).backward(retain_graph=True) + ref_dq, ref_dk, ref_dv, ref_dg, ref_db, ref_dh0 = ( + q.grad.clone(), k.grad.clone(), v.grad.clone(), + g.grad.clone(), beta.grad.clone(), h0.grad.clone() + ) + if use_gate_in_kernel: + ref_dA = A_log.grad.clone() + ref_dbias = dt_bias.grad.clone() + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("dq", ref_dq, tri_dq, 0.007) + assert_close("dk", ref_dk, tri_dk, 0.008) + assert_close("dv", ref_dv, tri_dv, 0.007) + assert_close("dg", ref_dg, tri_dg, 0.015) + assert_close("db", ref_db, tri_db, 0.015) + assert_close("dh0", ref_dh0, tri_dh0, 0.007) + if use_gate_in_kernel: + assert_close("dA", ref_dA, tri_dA, 0.008, warning=True) + assert_close("dbias", ref_dbias, tri_dbias, 0.005) + + +@pytest.mark.parametrize( + ("H", "D", "mask_p", "cu_seqlens", "dtype", "use_gate_in_kernel", "safe_gate", "disable_recompute"), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-gate{}-safe_gate{}-disable_recompute{}".format(*test)) + for test in [ + (4, 60, 0.1, [0, 8192], paddle.float16, True, False, False), + (4, 
64, 0.9, [0, 256, 500, 1000], paddle.float16, True, False, False), + (4, 128, 0.5, [0, 256, 500, 1000], paddle.float16, False, False, False), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], paddle.float16, True, False, False), + (4, 256, 0, [0, 100, 300, 1200, 3000, 4096], paddle.float16, False, True, True), + ] + ], +) +def test_chunk_varlen_prefill( + H: int, + D: int, + mask_p: float, + cu_seqlens: list, + dtype, + use_gate_in_kernel: bool, + safe_gate: bool, + disable_recompute: bool, +): + paddle.seed(42) + with paddle.no_grad(): + cu_seqlens_t = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + cu_seqlens_cpu = paddle.to_tensor(cu_seqlens, dtype=paddle.int64) + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = paddle.randn([1, T, H, D], dtype=dtype) + k = F.normalize(paddle.randn([1, T, H, D], dtype=paddle.float32), p=2, axis=-1).cast(dtype) + v = paddle.randn([1, T, H, D], dtype=dtype) + g = paddle.randn([1, T, H, D], dtype=paddle.float32 if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) + else: + g = F.log_sigmoid(g) + g = g * (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + mask = (paddle.rand(g.shape, dtype=g.dtype) > mask_p).cast(g.dtype) + g = g * mask + (1 - mask) * (-1000) + if safe_gate: + assert use_gate_in_kernel is False + g = g.clip(-5, 0) + + beta = paddle.rand([1, T, H], dtype=dtype).sigmoid() + h0 = paddle.randn([N, H, D, D], dtype=paddle.float32) + + tri, tri_ht = chunk_kda( + q=F.normalize(q.clone(), p=2, axis=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens_t, + cu_seqlens_cpu=cu_seqlens_cpu, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + 
disable_recompute=disable_recompute, + ) + + ref_list = [] + ref_ht_list = [] + for i in range(N): + s, e = cu_seqlens[i], cu_seqlens[i + 1] + ref_i, ref_ht_i = naive_recurrent_kda( + q=F.normalize(q[:, s:e], p=2, axis=-1), + k=k[:, s:e], + v=v[:, s:e], + beta=beta[:, s:e], + g=(naive_kda_gate(g[:, s:e].cast(paddle.float32), A_log.cast(paddle.float32), + dt_bias.cast(paddle.float32)) if use_gate_in_kernel else g[:, s:e]), + initial_state=h0[i], + output_final_state=True, + ) + ref_list.append(ref_i) + ref_ht_list.append(ref_ht_i) + ref = paddle.concat(ref_list, axis=1) + ref_ht = paddle.concat(ref_ht_list, axis=0) + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ("B", "T", "H", "D", "HAS_BIAS", "LOWER_BOUND"), + [ + pytest.param(*test, id="B{}-T{}-H{}-D{}-bias{}-lowerbound{}".format(*test)) + for test in [ + (1, 2, 2, 12, False, -5.0), + (1, 32, 2, 16, False, -5.0), + (2, 64, 4, 32, False, -5.0), + (4, 128, 8, 64, False, -5.0), + (4, 128, 8, 128, False, None), + (1, 2, 2, 12, True, None), + (1, 32, 2, 16, True, None), + (2, 64, 4, 32, True, None), + (4, 128, 8, 64, True, None), + (4, 128, 8, 128, True, None), + ] + ], +) +def test_gate( + B: int, + T: int, + H: int, + D: int, + HAS_BIAS: bool, + LOWER_BOUND, +): + paddle.seed(42) + g = paddle.randn([B, T, H, D], dtype=paddle.float32) * 10 + A_log = paddle.log(paddle.uniform([1, 1, H, 1], dtype=paddle.float32, min=1, max=16)) + dt_bias = paddle.randn([H * D], dtype=paddle.float32) if HAS_BIAS else None + g.stop_gradient = False + A_log.stop_gradient = False + if dt_bias is not None: + dt_bias.stop_gradient = False + do = paddle.randn([B, T, H, D], dtype=paddle.float32) + + if LOWER_BOUND is not None: + ref = naive_kda_lowerbound_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, LOWER_BOUND + ) + else: + ref = naive_kda_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + ) + tri = 
fused_kda_gate( + g.clone(), A_log.clone(), dt_bias.clone() if dt_bias is not None else None, + lower_bound=LOWER_BOUND + ) + (ref * do).sum().backward(retain_graph=True) + + ref_dg = g.grad.clone() + ref_dA = A_log.grad.clone() + ref_dbias = dt_bias.grad.clone() if dt_bias is not None else None + g.clear_gradient() + A_log.clear_gradient() + if dt_bias is not None: + dt_bias.clear_gradient() + + ((tri * do).sum()).backward(retain_graph=True) + tri_dg = g.grad.clone() + tri_dA = A_log.grad.clone() + tri_dbias = dt_bias.grad.clone() if dt_bias is not None else None + + assert_close("o", ref, tri, 1e-4) + assert_close("dg", ref_dg, tri_dg, 1e-4) + assert_close("dA", ref_dA, tri_dA, 1e-4) + if HAS_BIAS: + assert_close("dbias", ref_dbias, tri_dbias, 1e-4) + + +@pytest.mark.parametrize("dtype", [paddle.bfloat16]) +def test_chunk_return_intermediate_states(dtype): + """Test that return_intermediate_states=True works in inference mode and returns h with correct shape.""" + paddle.seed(42) + B, T, H, D = 2, 1024, 4, 128 + chunk_size = 64 + + with paddle.no_grad(): + q = paddle.randn([B, T, H, D], dtype=dtype) + k = paddle.randn([B, T, H, D], dtype=dtype) + v = paddle.randn([B, T, H, D], dtype=dtype) + g = paddle.randn([B, T, H, D], dtype=dtype) + beta = paddle.rand([B, T, H], dtype=dtype) + + # Test equal-length sequences + o, final_state, h = chunk_kda( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=None, + output_final_state=True, + return_intermediate_states=True, + disable_recompute=False, + ) + + # Verify shapes + assert list(o.shape) == [B, T, H, D], f"Output shape mismatch: {o.shape}" + assert list(final_state.shape) == [B, H, D, D], f"Final state shape mismatch: {final_state.shape}" + + expected_nt = (T + chunk_size - 1) // chunk_size + assert list(h.shape) == [B, expected_nt, H, D, D], f"h shape mismatch: {h.shape}" + assert h.dtype == dtype, f"h dtype should be {dtype}, got: {h.dtype}" + + # Test variable-length sequences + total_tokens = 1024 + N = 2 + 
seq_len = total_tokens // N + cu_seqlens = paddle.to_tensor([0, seq_len, total_tokens], dtype=paddle.int64) + + q_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + k_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + v_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + g_varlen = paddle.randn([1, total_tokens, H, D], dtype=dtype) + beta_varlen = paddle.rand([1, total_tokens, H], dtype=dtype) + + o_varlen, final_state_varlen, h_varlen = chunk_kda( + q=q_varlen, + k=k_varlen, + v=v_varlen, + g=g_varlen, + beta=beta_varlen, + initial_state=None, + output_final_state=True, + cu_seqlens=cu_seqlens, + return_intermediate_states=True, + disable_recompute=False, + ) + + assert list(o_varlen.shape) == [1, total_tokens, H, D], f"Varlen output shape mismatch: {o_varlen.shape}" + assert list(final_state_varlen.shape) == [N, H, D, D], f"Varlen final state shape mismatch: {final_state_varlen.shape}" + assert h_varlen.shape[0] == 1, f"Varlen h batch dim should be 1, got: {h_varlen.shape[0]}" + assert list(h_varlen.shape[2:]) == [H, D, D], f"Varlen h dims mismatch: {h_varlen.shape[2:]}" + assert h_varlen.dtype == dtype, f"Varlen h dtype should be {dtype}, got: {h_varlen.dtype}" From 14e9e0fce630a2372bbb6fce336b9f5085c7e96b Mon Sep 17 00:00:00 2001 From: zhangyangjie Date: Mon, 27 Apr 2026 15:56:07 +0800 Subject: [PATCH 2/2] update linear attention benchmark script --- benchmark_linear_attention_run.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark_linear_attention_run.py b/benchmark_linear_attention_run.py index 530b1a0..5c45a35 100644 --- a/benchmark_linear_attention_run.py +++ b/benchmark_linear_attention_run.py @@ -217,8 +217,9 @@ def print_results_table(results: list[dict], machine_info: dict | None = None): paddle_version = machine_info.get('paddle_version', 'N/A') print(f" Machine: {gpu} | Paddle {paddle_version}") print(f"{'=' * width}") - print(f" {'op':':<18s} {'mode':':<7s} {'B':>4s} {'T':>6s} {'H':>4s} 
{'D':>4s} {'median(ms)':>12s} {'p20(ms)':>12s} {'p80(ms)':>12s}") + print(f" {'op':<18s} {'mode':<7s} {'B':>4s} {'T':>6s} {'H':>4s} {'D':>4s} {'median(ms)':>12s} {'p20(ms)':>12s} {'p80(ms)':>12s}") print(f" {'-' * 18} {'-' * 7} {'-' * 4} {'-' * 6} {'-' * 4} {'-' * 4} {'-' * 12} {'-' * 12} {'-' * 12}") + for result in results: print( f" {result['op']:<18s} {result['mode']:<7s} {result['B']:>4d} {result['T']:>6d} "