diff --git a/tests/ops/test_gqa.py b/tests/ops/test_gqa.py
index 69275c1..1d0208b 100644
--- a/tests/ops/test_gqa.py
+++ b/tests/ops/test_gqa.py
@@ -1,5 +1,6 @@
 import argparse
 
+import pytest
 import torch
 
 from benchmarks import GroupQueryAttentionBwdBenchmark, GroupQueryAttentionFwdBenchmark
@@ -7,14 +8,14 @@
 from top.utils import str2dtype
 
 
-def test_gqa_fwd(batch: int,
-                 seq_len: int,
-                 heads: int,
-                 heads_kv: int,
-                 dim: int,
-                 causal: bool,
-                 dtype: torch.dtype,
-                 tune: bool = False) -> None:
+@pytest.mark.parametrize("batch, seq_len, heads, heads_kv, dim, causal, dtype, tune", [
+    (1, 1024, 8, 4, 64, False, torch.float16, False),
+    (2, 2048, 16, 8, 128, True, torch.float16, False),
+    (1, 1024, 8, 4, 64, False, torch.bfloat16, True),
+    (2, 2048, 16, 8, 128, True, torch.bfloat16, True),
+])
+def test_gqa_fwd(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int, causal: bool,
+                 dtype: torch.dtype, tune: bool) -> None:
     op = GroupQueryAttentionFwdOp(batch, heads, heads_kv, seq_len, dim, causal, dtype, tune=tune)
     benchmark = GroupQueryAttentionFwdBenchmark(batch, heads, heads_kv, seq_len, dim, causal, dtype)
     inputs = benchmark.gen_inputs()
@@ -24,14 +25,14 @@
     benchmark.profile(op, *inputs)
 
 
-def test_gqa_bwd(batch: int,
-                 seq_len: int,
-                 heads: int,
-                 heads_kv: int,
-                 dim: int,
-                 causal: bool,
-                 dtype: torch.dtype,
-                 tune: bool = False) -> None:
+@pytest.mark.parametrize("batch, seq_len, heads, heads_kv, dim, causal, dtype, tune", [
+    (1, 512, 8, 4, 64, False, torch.float16, False),
+    (2, 1024, 16, 8, 128, True, torch.float16, False),
+    (1, 512, 8, 4, 64, False, torch.bfloat16, True),
+    (2, 1024, 16, 8, 128, True, torch.bfloat16, True),
+])
+def test_gqa_bwd(batch: int, seq_len: int, heads: int, heads_kv: int, dim: int, causal: bool,
+                 dtype: torch.dtype, tune: bool) -> None:
     op = GroupQueryAttentionBwdOp(batch, heads, heads_kv, seq_len, dim, causal, dtype, tune=tune)
     benchmark = GroupQueryAttentionBwdBenchmark(batch, heads, heads_kv, seq_len, dim, causal, dtype)
     inputs = benchmark.gen_inputs()
diff --git a/tests/ops/test_mha.py b/tests/ops/test_mha.py
index 03f12a7..358e49a 100644
--- a/tests/ops/test_mha.py
+++ b/tests/ops/test_mha.py
@@ -1,28 +1,43 @@
 import argparse
 
+import pytest
+import torch
 
 from benchmarks import MultiHeadAttentionBwdBenchmark, MultiHeadAttentionFwdBenchmark
 from top.ops import MultiHeadAttentionBwdOp, MultiHeadAttentionFwdOp
 from top.utils import str2dtype
 
 
-def test_mha_fwd(batch, seq_len, heads, dim, causal, dtype, tune=False):
+@pytest.mark.parametrize("batch, seq_len, heads, dim, causal, dtype, tune", [
+    (1, 1024, 8, 64, False, torch.float16, False),
+    (2, 2048, 16, 128, True, torch.float16, False),
+    (1, 1024, 8, 64, False, torch.float16, True),
+    (2, 2048, 16, 128, True, torch.bfloat16, True),
+])
+def test_mha_fwd(batch, seq_len, heads, dim, causal, dtype, tune):
     op = MultiHeadAttentionFwdOp(batch, heads, seq_len, dim, causal, dtype, tune=tune)
     benchmark = MultiHeadAttentionFwdBenchmark(batch, heads, seq_len, dim, causal, dtype)
     inputs = benchmark.gen_inputs()
-    print("Forward Results:")
+    print(
+        f"Forward Results for batch={batch}, seq_len={seq_len}, heads={heads}, dim={dim}, causal={causal}, dtype={dtype}, tune={tune}:"
+    )
     benchmark.check(op, *inputs)
-    benchmark.profile(op, *inputs)
 
 
-def test_mha_bwd(batch, seq_len, heads, dim, causal, dtype, tune=False):
+@pytest.mark.parametrize("batch, seq_len, heads, dim, causal, dtype, tune", [
+    (1, 1024, 8, 64, False, torch.float16, False),
+    (2, 2048, 16, 128, True, torch.float16, False),
+    (1, 1024, 8, 64, False, torch.float16, True),
+])
+def test_mha_bwd(batch, seq_len, heads, dim, causal, dtype, tune):
     op = MultiHeadAttentionBwdOp(batch, heads, seq_len, dim, causal, dtype, tune=tune)
     benchmark = MultiHeadAttentionBwdBenchmark(batch, heads, seq_len, dim, causal, dtype)
     inputs = benchmark.gen_inputs()
-    print("Backward Results:")
+    print(
+        f"Backward Results for batch={batch}, seq_len={seq_len}, heads={heads}, dim={dim}, causal={causal}, dtype={dtype}, tune={tune}:"
+    )
     benchmark.check(op, *inputs)
-    benchmark.profile(op, *inputs)
 
 
 if __name__ == "__main__":