
[Issue]: Inferencing GPT-OSS 20B with FlashAttentionV2 in ROCm7.0 on MI300X results in unusable & degenerate repetitive output #158

@geozhai

Description


Problem Description

I followed the instructions in the ROCm documentation to run inference on GPT-OSS 20B with CK (Composable Kernel) FlashAttentionV2.

Initially, in the ROCm 6.4.3 container, I used "eager mode" execution for attention, and it worked: the output was coherent, and the throughput was about 30 tokens per second. However, when I tried to run inference with FlashAttentionV2, it hung indefinitely and generated no output.
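To tell a hang like this apart from very slow generation without blocking the shell forever, the `generate()` call can be run under a watchdog. A minimal stdlib-only sketch follows; `run_with_watchdog` is a hypothetical helper, not part of transformers or ROCm. It only detects that the call has not returned within the timeout: a stuck GPU kernel keeps running in the background, so the process will most likely still have to be restarted afterwards.

```python
import threading
from typing import Any, Callable, Dict


def run_with_watchdog(fn: Callable[[], Any], timeout_s: float) -> Any:
    """Run fn() in a daemon thread and return its result.

    Raises TimeoutError if fn() has not returned after timeout_s seconds, and
    re-raises any exception fn() itself raised. The worker thread cannot be
    killed; on timeout it is simply abandoned (daemon threads do not keep the
    interpreter alive).
    """
    box: Dict[str, Any] = {}

    def target() -> None:
        try:
            box["value"] = fn()
        except BaseException as exc:  # hand the error back to the caller
            box["error"] = exc

    worker = threading.Thread(target=target, daemon=True)
    worker.start()
    worker.join(timeout_s)
    if worker.is_alive():
        raise TimeoutError(f"call still running after {timeout_s} s (possible hang)")
    if "error" in box:
        raise box["error"]
    return box["value"]
```

For example, `run_with_watchdog(lambda: gpt_oss.generate(**inputs, max_new_tokens=256), timeout_s=600)` would raise `TimeoutError` instead of waiting forever; the 600-second limit is an arbitrary choice.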

Then, with the ROCm 7.0 release, I ran the same code snippet with FlashAttentionV2 and the same prompt. It no longer hangs, but it produces the following output at a lower throughput of ~21 tokens per second:

today is from I is from I have I and I with I and I and I and I and I and I and I and I and I and I and I and I and I and
I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I 
and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I 
and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and I and i and i and i and i 
and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i 
and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i and i 
and i and i and i and i and i and i and i and i and i and

OR

today is i is i is i i i is i of i i j i of i j i of i j of i j i of i j of i j i of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j as i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i j of i
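The failure mode in both outputs is a short phrase looping ("and I", "of i j"). To check runs automatically, for example when comparing containers or attention backends, one simple signal is the fraction of word n-grams that repeat an earlier one. A minimal sketch; `repeated_ngram_fraction` is a hypothetical helper, and any threshold for calling an output "degenerate" is an assumption to tune.

```python
def repeated_ngram_fraction(text: str, n: int = 3) -> float:
    """Fraction of word n-grams in `text` that repeat an earlier n-gram.

    0.0 means every n-gram is distinct; values close to 1.0 indicate the
    kind of looping output shown above.
    """
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:  # fewer than n words: nothing can repeat
        return 0.0
    return 1.0 - len(set(ngrams)) / len(ngrams)
```

For instance, `repeated_ngram_fraction("the cat sat on the mat")` returns `0.0`, while `repeated_ngram_fraction("and I " * 50)` returns about 0.98.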

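In summary, there are two problems: (1) with FlashAttentionV2, the output is unusable, degenerate, and repetitive, and (2) FlashAttentionV2 gives lower throughput (~21 tokens per second on ROCm 7.0) than eager-mode execution did (~30 tokens per second on ROCm 6.4.3). Please fix.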
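The two throughput figures are only comparable if they are measured the same way. A minimal sketch of one way to do that, assuming throughput means newly generated tokens divided by the wall-clock time of the `generate()` call; `measure_throughput` and the `run` callable are hypothetical names. The wrapped call should not return before GPU work has finished (calling `torch.cuda.synchronize()` at the end of the wrapper is a cheap way to make sure).

```python
import time
from typing import Callable, Tuple


def measure_throughput(run: Callable[[], Tuple[int, int]]) -> float:
    """Time a single call of `run` and return newly generated tokens per second.

    `run` is a hypothetical wrapper around generate() that returns
    (prompt_length, total_length) for one sequence. For decoder-only models,
    generate() returns the prompt tokens followed by the new ones, so the
    difference is the number of generated tokens.
    """
    start = time.perf_counter()
    prompt_len, total_len = run()
    elapsed = time.perf_counter() - start
    new_tokens = total_len - prompt_len
    return new_tokens / elapsed
```

With the snippet from "Steps to Reproduce", `run` could call `gen = gpt_oss.generate(**inputs, max_new_tokens=256)` and return `(inputs["input_ids"].shape[1], gen.shape[1])`. Counting actual new tokens rather than assuming 256 matters if one backend stops early at an end-of-sequence token.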

Operating System

Ubuntu 22.04.5 LTS (Jammy Jellyfish)

CPU

AMD Ryzen 9 7950X 16-Core Processor

GPU

MI300X

ROCm Version

7.0.51831-a3e329ad8

ROCm Component

composable_kernel

Steps to Reproduce

Code snippet to reproduce:

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# Placeholders (not given in the original report): adjust to your setup.
model_path = "/path/to/gpt-oss-20b"  # local checkpoint directory (local_files_only=True below)
device = "cuda"                      # ROCm builds of PyTorch expose AMD GPUs as "cuda" devices

tok = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)

config = AutoConfig.from_pretrained(model_path)

gpt_oss = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device,
    config=config,
    offload_folder=None,
    low_cpu_mem_usage=True,
    torch_dtype="auto",
    attn_implementation="flash_attention_2",  # "eager" gave coherent output on ROCm 6.4.3
    trust_remote_code=True,
    local_files_only=True,
).eval()

prompt = "today is the day of the Lord, and we are the"
inputs = tok(prompt, return_tensors="pt").to(gpt_oss.device)
gen = gpt_oss.generate(**inputs, max_new_tokens=256)
print(tok.decode(gen[0], skip_special_tokens=True))
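Because eager mode gave coherent output, a direct way to localize the problem is to run the same prompt through both attention backends with greedy decoding and find the first token where they disagree. A minimal sketch, assuming a hypothetical `generate_ids(attn_impl)` callable that loads the model as in the snippet above with `attn_implementation=attn_impl` and returns the token ids from `gpt_oss.generate(**inputs, max_new_tokens=256, do_sample=False)`. `first_divergence` and `compare_backends` are illustrative names, and positions are counted over the full returned sequence, prompt included.

```python
from typing import Callable, Dict, List, Optional, Sequence


def first_divergence(a: Sequence[int], b: Sequence[int]) -> Optional[int]:
    """Index of the first position where two token sequences differ, or None if identical."""
    for i, (x, y) in enumerate(zip(a, b)):
        if x != y:
            return i
    if len(a) != len(b):
        # One sequence is a prefix of the other (e.g. one stopped at EOS earlier).
        return min(len(a), len(b))
    return None


def compare_backends(
    generate_ids: Callable[[str], Sequence[int]],
    reference: str = "eager",
    candidate: str = "flash_attention_2",
) -> Dict[str, object]:
    """Generate with both attention backends and report where the outputs split."""
    outputs: Dict[str, List[int]] = {
        name: list(generate_ids(name)) for name in (reference, candidate)
    }
    return {
        "outputs": outputs,
        "first_divergence": first_divergence(outputs[reference], outputs[candidate]),
    }
```

If the two backends split within the first few generated tokens, that points at the flash_attention_2 path rather than at ordinary floating-point differences between kernels, which tend to show up, if at all, much later in a greedy decode.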

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

Python version: 3.10.12 (the version shipped with the ROCm 7.0 container; the specific tag is "rocm7.0_ubuntu_22.04_vllm_0.10.1_instinct_20250915")

Versions of relevant libraries:

transformers -> /opt/venv/lib/python3.10/site-packages/transformers/__init__.py 4.56.1
accelerate -> /opt/venv/lib/python3.10/site-packages/accelerate/__init__.py 1.10.1
safetensors -> /opt/venv/lib/python3.10/site-packages/safetensors/__init__.py 0.6.2
kernels -> /opt/venv/lib/python3.10/site-packages/kernels/__init__.py 0.10.1
triton -> /opt/venv/lib/python3.10/site-packages/triton/__init__.py 3.4.0+rocm7.0.2.git56765e8c
flash_attn -> /opt/venv/lib/python3.10/site-packages/flash_attn/__init__.py 2.8.0.post2
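The list above (module path and `__version__` for each library) can be regenerated with a short script, which makes it easy to report the same information for other containers. A minimal sketch; `collect_versions` is a hypothetical helper, and it assumes each package exposes a `__version__` attribute (it reports `None` otherwise).

```python
import importlib
from typing import Dict, Iterable, Optional, Tuple

# Libraries listed in this report.
DEFAULT_PACKAGES = ("transformers", "accelerate", "safetensors", "kernels", "triton", "flash_attn")


def collect_versions(
    names: Iterable[str] = DEFAULT_PACKAGES,
) -> Dict[str, Tuple[Optional[str], Optional[str]]]:
    """Map each module name to (path of the imported module file, its __version__).

    Modules that fail to import are reported as (None, None).
    """
    report: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except Exception:  # broken GPU builds can fail with errors other than ImportError
            report[name] = (None, None)
            continue
        report[name] = (getattr(module, "__file__", None), getattr(module, "__version__", None))
    return report
```

Running `for name, (path, version) in collect_versions().items(): print(f"{name} -> {path} {version}")` prints lines in the same `name -> path version` format as the list above.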

Metadata

Assignees

No one assigned

Labels

bug (Something isn't working)
