Problem Description
I followed the instructions in the ROCm documentation to run inference on GPT-OSS 20B with CK FlashAttentionV2.
In the ROCm 6.4.3 container, I first tried "eager mode" execution for attention, and it worked: the output was coherent and throughput was about 30 tokens per second. However, when I ran inference with FlashAttentionV2 in that container, it hung indefinitely and no output was generated.
With the ROCm 7.0 release, I ran the same code snippet with FlashAttentionV2 and the same prompt. It now produces output, but at a lower throughput of about 21 tokens per second, and the text looks like this:
today is from I is from I have I and I with I and I and I and I and I and I and I and I and I and I and I and I and I and
I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I
and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I
and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and i and i and i and i
and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i
and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i
and i and i and i and i and i and i and i and i and i and
OR
today is i is i is i i i is i of i i j i of i j i of i j of i j i of i j of i j i of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j as i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i
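The issue: with FlashAttentionV2 the output is degenerate, repetitive, and unusable, and throughput is lower than with eager mode execution (about 21 tokens per second on ROCm 7.0 versus about 30 with eager mode on ROCm 6.4.3). Please fix.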
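To put a number on "degenerate repetitive output" when comparing attention backends, one option is a distinct n-gram ratio over the decoded text. This is only a minimal sketch: the function name `distinct_ngram_ratio` and the choice of whitespace tokenization are my own assumptions, not part of any ROCm or transformers API.

```python
def distinct_ngram_ratio(text: str, n: int = 3) -> float:
    """Return the fraction of word n-grams in `text` that are unique.

    1.0 means no n-gram repeats; values near 0 indicate heavy repetition.
    Hypothetical helper for illustration, using plain whitespace splitting.
    """
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        # Too short to form any n-gram, so treat it as non-repetitive.
        return 1.0
    return len(set(ngrams)) / len(ngrams)
```

Coherent text should score close to 1.0, while outputs like the ones above should score far lower, which gives a quick pass/fail signal when re-testing after a fix.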
Operating System
Ubuntu 22.04.5 LTS (Jammy Jellyfish)
CPU
AMD Ryzen 9 7950X 16-Core Processor
GPU
MI300X
ROCm Version
7.0.51831-a3e329ad8
ROCm Component
composable_kernel
Steps to Reproduce
Code snippet to reproduce:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# Placeholders: these two values were not included in the original snippet.
model_path = "/path/to/gpt-oss-20b"  # local checkpoint directory
device = "cuda"  # PyTorch exposes ROCm GPUs through the "cuda" device type

tok = AutoTokenizer.from_pretrained(
    model_path,
    use_fast=True,
    trust_remote_code=True,
)
config = AutoConfig.from_pretrained(model_path)
gpt_oss = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
    config=config,
    offload_folder=None,
    low_cpu_mem_usage=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",  # "eager" gives coherent output
    trust_remote_code=True,
    local_files_only=True,
).eval()

prompt = "today is the day of the Lord, and we are the"
inputs = tok(prompt, return_tensors="pt").to(gpt_oss.device)
gen = gpt_oss.generate(**inputs, max_new_tokens=256)
print(tok.decode(gen[0], skip_special_tokens=True))
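The tokens-per-second figures quoted above can be reproduced with a small timing wrapper around the `generate` call. The sketch below is an assumption about how to measure it, not the exact method used for the numbers in this report; `tokens_per_second` and its parameters are hypothetical names.

```python
import time


def tokens_per_second(generate_fn, prompt_len: int) -> float:
    """Time one call to `generate_fn` and return new tokens per second.

    `generate_fn` is any zero-argument callable returning the full output
    sequence (prompt plus generated tokens); `prompt_len` is the number of
    prompt tokens to subtract. Hypothetical helper for illustration.
    """
    start = time.perf_counter()
    output_ids = generate_fn()
    elapsed = time.perf_counter() - start
    new_tokens = len(output_ids) - prompt_len
    return new_tokens / elapsed


# Possible usage with the snippet above (not executed here):
# rate = tokens_per_second(
#     lambda: gpt_oss.generate(**inputs, max_new_tokens=256)[0],
#     prompt_len=inputs["input_ids"].shape[1],
# )
```

Running this once for `attn_implementation="eager"` and once for `"flash_attention_2"` in the same container would make the throughput comparison like-for-like. A warm-up call before timing is probably worthwhile, since the first call can include kernel compilation or loading time.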
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
Python version 3.10.12, as shipped with the ROCm 7.0 container (tag "rocm7.0_ubuntu_22.04_vllm_0.10.1_instinct_20250915").
Versions of relevant libraries:
transformers -> /opt/venv/lib/python3.10/site-packages/transformers/__init__.py 4.56.1
accelerate -> /opt/venv/lib/python3.10/site-packages/accelerate/__init__.py 1.10.1
safetensors -> /opt/venv/lib/python3.10/site-packages/safetensors/__init__.py 0.6.2
kernels -> /opt/venv/lib/python3.10/site-packages/kernels/__init__.py 0.10.1
triton -> /opt/venv/lib/python3.10/site-packages/triton/__init__.py 3.4.0+rocm7.0.2.git56765e8c
flash_attn -> /opt/venv/lib/python3.10/site-packages/flash_attn/__init__.py 2.8.0.post2
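For anyone comparing environments, a listing like the one above can be gathered with a short script. This is a minimal sketch assuming each package exposes `__file__` and `__version__`; `describe_packages` is a hypothetical helper name, and the exact script used for the listing above is not shown in this report.

```python
import importlib


def describe_packages(names):
    """Return (name, module file, version) tuples for each package name.

    Packages that fail to import are reported as "not installed" rather
    than raising, so one missing library does not hide the others.
    """
    rows = []
    for name in names:
        try:
            mod = importlib.import_module(name)
        except Exception:
            # Covers ImportError and load-time failures in native extensions.
            rows.append((name, None, "not installed"))
            continue
        rows.append((
            name,
            getattr(mod, "__file__", None),
            getattr(mod, "__version__", "unknown"),
        ))
    return rows


# Possible usage (not executed here):
# for name, path, version in describe_packages(
#     ["transformers", "accelerate", "safetensors", "kernels", "triton", "flash_attn"]
# ):
#     print(f"{name} -> {path} {version}")
```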