130 changes: 130 additions & 0 deletions README_linear_attention.md
@@ -0,0 +1,130 @@
# linear_attn

Triton-based GDN (Gated Delta Networks) and KDA (Kimi Delta Attention) operators with chunk-wise and fused-recurrent execution modes.

## Dependencies

- PaddlePaddle-GPU (GPU required)
- triton
- pytest (for tests)
- einops (for GDN tests)

## Environment Setup

```bash
# Install the flash_mask package (includes linear_attn)
cd /path/to/flash-attention/flashmask
pip install -e . --no-build-isolation
```

## Running Tests

Test files are located in `test_flashmask/`:

| File | Description |
|------|-------------|
| `test_gated_delta.py` | GDN operator correctness tests |
| `test_kda.py` | KDA operator correctness tests |

Each test compares the Triton-optimized implementation against a naive Python reference, checking both forward output and backward gradients.

```bash
cd /path/to/test_flashmask

# Run all tests
pytest test_gated_delta.py test_kda.py -v

# Run GDN tests only
pytest test_gated_delta.py -v

# Run KDA tests only
pytest test_kda.py -v

# Run a single test function
pytest test_gated_delta.py::test_fused_recurrent -v
pytest test_kda.py::test_chunk -v

# Filter by parametrized id
pytest test_gated_delta.py -k "test_fused_recurrent and B1-T63" -v
```
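The forward/backward comparison each test performs can be sketched as below. This is a minimal illustration only: the relative-error metric, the helper names `relative_error` and `assert_close`, and the tolerance values are assumptions, not the checks actually implemented in `test_gated_delta.py` or `test_kda.py`.

```python
import math


def relative_error(reference, candidate):
    """Return ||reference - candidate||_2 / ||reference||_2 for flat float sequences."""
    if len(reference) != len(candidate):
        raise ValueError("reference and candidate must have the same length")
    diff_sq = sum((r - c) ** 2 for r, c in zip(reference, candidate))
    ref_sq = sum(r * r for r in reference)
    # The small epsilon avoids division by zero for an all-zero reference.
    return math.sqrt(diff_sq) / (math.sqrt(ref_sq) + 1e-12)


def assert_close(name, reference, candidate, tol):
    """Raise AssertionError if the relative error exceeds tol; return the error otherwise."""
    err = relative_error(reference, candidate)
    if err > tol:
        raise AssertionError(f"{name}: relative error {err:.3e} exceeds {tol:.1e}")
    return err


# Hypothetical values standing in for a flattened forward output and one gradient.
assert_close("o", [1.0, 2.0, 3.0], [1.0, 2.0, 3.001], tol=1e-2)
assert_close("dq", [0.5, -0.5], [0.5, -0.5], tol=1e-2)
```

A norm-based metric like this tolerates the small per-element deviations expected from low-precision kernels while still catching systematic errors.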

### Optional Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `SKIP_TEST_CHUNK_VARLEN=1` | unset | Skip varlen (variable-length sequence) tests |
| `FLA_BENCHMARK=1` | `0` | Disable driver probing overhead |

```bash
SKIP_TEST_CHUNK_VARLEN=1 pytest test_gated_delta.py test_kda.py -v
```
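A flag such as `SKIP_TEST_CHUNK_VARLEN` is typically read once at import time and used as a skip condition. The sketch below shows one way to do that; the helper `env_flag` is hypothetical, and how the test files actually parse the variable is an assumption.

```python
import os


def env_flag(name, environ=None):
    """Return True only when the variable is set to the string "1"."""
    environ = os.environ if environ is None else environ
    return environ.get(name, "0") == "1"


SKIP_VARLEN = env_flag("SKIP_TEST_CHUNK_VARLEN")

# In a test module, the flag could then gate a test, for example:
#   @pytest.mark.skipif(SKIP_VARLEN, reason="varlen tests disabled via environment")
#   def test_chunk_varlen(...): ...
```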

## Running Benchmarks

The benchmark framework is located in `test_flashmask/` and supports four operators:

| Operator | Description | Modes |
|----------|-------------|-------|
| `chunk_gdn` | GDN chunk-wise | fwd / fwdbwd |
| `chunk_kda` | KDA chunk-wise | fwd / fwdbwd |
| `recurrent_gdn` | GDN fused recurrent | fwd only |
| `recurrent_kda` | KDA fused recurrent | fwd only |

```bash
cd /path/to/test_flashmask

# List registered operators
python benchmark_linear_attention_run.py --list

# Run all benchmarks
python benchmark_linear_attention_run.py --op all

# Run specific operators
python benchmark_linear_attention_run.py --op chunk_gdn
python benchmark_linear_attention_run.py --op chunk_kda recurrent_kda

# Forward only
python benchmark_linear_attention_run.py --op chunk_gdn --modes fwd

# Custom shapes
python benchmark_linear_attention_run.py --op chunk_gdn \
    --custom-shapes '{"smoke":{"B":1,"T":64,"H":2,"D":32}}'

# Save results as JSON
python benchmark_linear_attention_run.py --op all --json results.json
```
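Hand-writing the `--custom-shapes` JSON is error-prone once several shapes are involved. The sketch below builds the argument from a Python dict and quotes it for the shell. The helper `custom_shapes_arg` is hypothetical, and the assumption that every shape needs all of `B`, `T`, `H`, and `D` is inferred from the example above rather than from the runner's code.

```python
import json
import shlex

REQUIRED_KEYS = ("B", "T", "H", "D")


def custom_shapes_arg(shapes):
    """Validate a {name: {B, T, H, D}} mapping and serialize it as compact JSON."""
    for name, dims in shapes.items():
        missing = [key for key in REQUIRED_KEYS if key not in dims]
        if missing:
            raise ValueError(f"shape '{name}' is missing keys: {missing}")
    return json.dumps(shapes, separators=(",", ":"))


payload = custom_shapes_arg({"smoke": {"B": 1, "T": 64, "H": 2, "D": 32}})
command = [
    "python", "benchmark_linear_attention_run.py",
    "--op", "chunk_gdn",
    "--custom-shapes", payload,
]
# shlex.join adds the shell quoting the JSON payload needs.
print(shlex.join(command))
```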

### Default Shape Configs

| Config Name | B | T | H | D |
|-------------|---|---|---|---|
| B1_T8192_H96_D128 | 1 | 8192 | 96 | 128 |
| B2_T16384_H16_D128 | 2 | 16384 | 16 | 128 |
| B4_T2048_H16_D128 | 4 | 2048 | 16 | 128 |
| B4_T4096_H64_D128 | 4 | 4096 | 64 | 128 |
| B8_T2048_H32_D256 | 8 | 2048 | 32 | 256 |
| B8_T1024_H8_D64 | 8 | 1024 | 8 | 64 |
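The config names follow the pattern `B{B}_T{T}_H{H}_D{D}`, and the registry in `benchmark_linear_attention_registry.py` gives `q`, `k`, and `v` the shape `(B, T, H, D)` with bfloat16 as the default dtype. A rough sketch for naming a shape and estimating the memory held by those three inputs follows; the helper names are hypothetical, and the estimate ignores `g`, `beta`, outputs, and gradients.

```python
def config_name(B, T, H, D):
    """Build a config name in the same style as the default shape table."""
    return f"B{B}_T{T}_H{H}_D{D}"


def qkv_bytes(B, T, H, D, bytes_per_element=2):
    """Bytes occupied by q, k, and v of shape (B, T, H, D); bfloat16 uses 2 bytes."""
    return 3 * B * T * H * D * bytes_per_element


for dims in ({"B": 1, "T": 8192, "H": 96, "D": 128}, {"B": 8, "T": 1024, "H": 8, "D": 64}):
    mib = qkv_bytes(**dims) / 2**20
    print(f"{config_name(**dims)}: {mib:.0f} MiB for q, k, v")
```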

### Benchmark Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `FLA_BENCH_OP_WARMUP_ITERS` | `5` | Number of warmup iterations |
| `FLA_BENCH_WARMUP_MS` | `100` | `do_bench` warmup time (ms) |
| `FLA_BENCH_REP_MS` | `500` | `do_bench` repeat measurement time (ms) |
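One plausible way to consume these variables is to read them once into a small settings object, falling back to the defaults listed in the table. The sketch below is an assumption about usage, not the runner's actual code; `BenchTiming` and `load_bench_timing` are hypothetical names.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class BenchTiming:
    op_warmup_iters: int = 5  # FLA_BENCH_OP_WARMUP_ITERS
    warmup_ms: int = 100      # FLA_BENCH_WARMUP_MS
    rep_ms: int = 500         # FLA_BENCH_REP_MS


def load_bench_timing(environ=None):
    """Read the benchmark timing variables, using table defaults when unset or empty."""
    environ = os.environ if environ is None else environ

    def read_int(name, default):
        raw = environ.get(name)
        return default if raw is None or raw == "" else int(raw)

    return BenchTiming(
        op_warmup_iters=read_int("FLA_BENCH_OP_WARMUP_ITERS", 5),
        warmup_ms=read_int("FLA_BENCH_WARMUP_MS", 100),
        rep_ms=read_int("FLA_BENCH_REP_MS", 500),
    )


timing = load_bench_timing({"FLA_BENCH_REP_MS": "1000"})
```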

## Known Limitations

- Context Parallel (CP) is NOT supported.
- `fused_recurrent_gdn` / `fused_recurrent_kda` are forward-only. Use `chunk_gdn` / `chunk_kda` for training workloads that require gradients.

## Known Issue: GDN Backward Precision on Hopper GPUs with Triton >= 3.4.0

The upstream fla-org/flash-linear-attention project has identified a backward precision issue in the gated `chunk_bwd_dqkwg` kernel when running on Hopper-class GPUs (H20, H100, GB200, etc.) with Triton >= 3.4.0 ([upstream PR #827](https://github.com/fla-org/flash-linear-attention/pull/827)). The upstream fix introduces a TileLang-based kernel as an alternative backend.

**Current status in this fork:**
- On NVIDIA H800 (Hopper) with the current Triton version, this issue has **not** been observed in practice.
- If you plan to deploy on other Hopper GPUs (H20, GB200, etc.), or upgrade Triton to >= 3.4.0, you may encounter this backward precision regression.
- The TileLang backend has **not** been integrated into this Paddle port yet.

**Action needed:** When targeting Hopper GPUs other than H800 or upgrading Triton, consider integrating the TileLang backend from the upstream fix (`pip install tilelang` + dispatch logic in `fla/ops/common/chunk_o.py`).
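Before deploying, it can help to check whether the installed Triton falls in the affected version range. The sketch below compares a version string against 3.4.0; the helper `triton_at_least` is hypothetical, and in practice the string would come from `triton.__version__`. It does not detect the GPU architecture, so it covers only the Triton half of the condition.

```python
import re


def triton_at_least(version, minimum=(3, 4, 0)):
    """Return True if a version string such as '3.4.0' or '3.5.0+git1a2b' is >= minimum."""
    # Drop any local build suffix, then keep the first three numeric components.
    public = version.split("+")[0]
    numbers = [int(n) for n in re.findall(r"\d+", public)[:3]]
    numbers += [0] * (3 - len(numbers))
    return tuple(numbers) >= tuple(minimum)


if triton_at_least("3.4.0"):
    print("Triton >= 3.4.0: verify GDN backward precision before training")
```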
177 changes: 177 additions & 0 deletions benchmark_linear_attention_registry.py
@@ -0,0 +1,177 @@
from __future__ import annotations

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

import paddle
import paddle.nn.functional as F

logger = logging.getLogger(__name__)


def shape_BTHD(B, T, H, D, **kw):
    return (B, T, H, D)


def shape_BTH(B, T, H, D, **kw):
    return (B, T, H)


logsigmoid = F.log_sigmoid


def sigmoid_transform(t):
    return t.sigmoid()


@dataclass
class TensorSpec:
    shape_fn: Callable
    requires_grad: bool = True
    dtype: Any = 'default'
    transform: Callable | None = None


@dataclass
class OpConfig:
    name: str
    import_path: str
    inputs: dict[str, TensorSpec]
    func_name: str | None = None
    extra_kwargs: dict[str, Any] = field(default_factory=dict)
    output_is_tuple: bool = True
    skip_backward: bool = False
    category: str = ''


_REGISTRY: dict[str, OpConfig] = {}


def register_op(config: OpConfig) -> None:
    _REGISTRY[config.name] = config


SHAPE_CONFIGS = {
    'B1_T8192_H96_D128': {'B': 1, 'T': 8192, 'H': 96, 'D': 128},
    'B2_T16384_H16_D128': {'B': 2, 'T': 16384, 'H': 16, 'D': 128},
    'B4_T2048_H16_D128': {'B': 4, 'T': 2048, 'H': 16, 'D': 128},
    'B4_T4096_H64_D128': {'B': 4, 'T': 4096, 'H': 64, 'D': 128},
    'B8_T2048_H32_D256': {'B': 8, 'T': 2048, 'H': 32, 'D': 256},
    'B8_T1024_H8_D64': {'B': 8, 'T': 1024, 'H': 8, 'D': 64},
}


def get_op(name: str) -> OpConfig:
    if name not in _REGISTRY:
        raise KeyError(f"Op '{name}' not registered. Available: {sorted(_REGISTRY)}")
    return _REGISTRY[name]


def list_ops() -> list[str]:
    return sorted(_REGISTRY.keys())


def _resolve_dtype(dtype):
    if dtype == 'default':
        return paddle.bfloat16
    if dtype == 'float32':
        return paddle.float32
    if dtype == 'int64':
        return paddle.int64
    return dtype


def _set_device(device: str | None):
    if device is None:
        return
    current = paddle.get_device()
    if current != device:
        paddle.device.set_device(device)


def generate_inputs(
    config: OpConfig,
    B: int,
    T: int,
    H: int,
    D: int,
    dtype=paddle.bfloat16,
    device: str | None = None,
) -> dict[str, paddle.Tensor]:
    _set_device(device)
    inputs: dict[str, paddle.Tensor] = {}
    for param_name, spec in config.inputs.items():
        shape = spec.shape_fn(B, T, H, D)
        tensor_dtype = dtype if spec.dtype == 'default' else _resolve_dtype(spec.dtype)
        if tensor_dtype == paddle.int64:
            tensor = paddle.randint(0, 10, shape=shape, dtype=tensor_dtype)
        else:
            tensor = paddle.randn(shape, dtype=tensor_dtype)
        if spec.transform is not None:
            tensor = spec.transform(tensor)
        if spec.requires_grad and paddle.is_floating_point(tensor):
            tensor.stop_gradient = False
        inputs[param_name] = tensor
    return inputs


_simple_qkv = {
    'q': TensorSpec(shape_BTHD),
    'k': TensorSpec(shape_BTHD),
    'v': TensorSpec(shape_BTHD),
}

register_op(OpConfig(
    name='chunk_gdn',
    import_path='flash_mask.linear_attn.ops.gated_delta_rule',
    func_name='chunk_gated_delta_rule',
    inputs={
        **_simple_qkv,
        'g': TensorSpec(shape_BTH, transform=logsigmoid),
        'beta': TensorSpec(shape_BTH, transform=sigmoid_transform),
    },
    extra_kwargs={'use_qk_l2norm_in_kernel': True},
    category='gate_beta',
))

register_op(OpConfig(
    name='chunk_kda',
    import_path='flash_mask.linear_attn.ops.kda',
    inputs={
        **_simple_qkv,
        'g': TensorSpec(shape_BTHD, transform=logsigmoid),
        'beta': TensorSpec(shape_BTH, transform=sigmoid_transform),
    },
    extra_kwargs={'use_qk_l2norm_in_kernel': True, 'safe_gate': True, 'lower_bound': -5},
    category='gate_beta',
))

register_op(OpConfig(
    name='recurrent_gdn',
    import_path='flash_mask.linear_attn.ops.gated_delta_rule',
    func_name='fused_recurrent_gated_delta_rule',
    inputs={
        **_simple_qkv,
        'g': TensorSpec(shape_BTH, transform=logsigmoid),
        'beta': TensorSpec(shape_BTH, transform=sigmoid_transform),
    },
    extra_kwargs={'use_qk_l2norm_in_kernel': True},
    skip_backward=True,
    category='gate_beta',
))

register_op(OpConfig(
    name='recurrent_kda',
    import_path='flash_mask.linear_attn.ops.kda',
    func_name='fused_recurrent_kda',
    inputs={
        **_simple_qkv,
        'g': TensorSpec(shape_BTHD, transform=logsigmoid),
        'beta': TensorSpec(shape_BTH, transform=sigmoid_transform),
    },
    extra_kwargs={'use_qk_l2norm_in_kernel': True, 'safe_gate': True, 'lower_bound': -5},
    skip_backward=True,
    category='gate_beta',
))