Flash Attention Shared Memory constraint #10

@AntonOresten

Hi! Thanks for making this.
I understand this package is still experimental, but I wanted to document this constraint and ask whether anyone knows if a head dimension of 128 would be difficult to support. A head dimension of 64 works, while 128 fails with the error below:

julia> x = CUDA.rand(64, 4096, 4, 4); Jjama3.NNop.flash_attention(x, x, x; causal=true);

julia> x = CUDA.rand(128, 4096, 4, 4); NNop.flash_attention(x, x, x; causal=true);
ERROR: Failed to find groupsize for Flash Attention that satisfies Shared Memory constraint.
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] #flash_attention_groupsize#8
   @ ~/.julia/packages/NNop/cRsoL/src/attention.jl:186 [inlined]
 [3] flash_attention_groupsize
   @ ~/.julia/packages/NNop/cRsoL/src/attention.jl:175 [inlined]
 [4] _flash_attention(q::CuArray{Float32, 4, CUDA.DeviceMemory}, k::CuArray{Float32, 4, CUDA.DeviceMemory}, v::CuArray{Float32, 4, CUDA.DeviceMemory}; causal::Bool)
   @ NNop ~/.julia/packages/NNop/cRsoL/src/attention.jl:127
 [5] _flash_attention
   @ ~/.julia/packages/NNop/cRsoL/src/attention.jl:113 [inlined]
 [6] #flash_attention#18
   @ ~/.julia/packages/NNop/cRsoL/src/attention_crc.jl:5 [inlined]
 [7] top-level scope
   @ REPL[121]:1
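
To illustrate why doubling the head dimension can trip a check like this, here is a minimal sketch of a group size search against a shared memory budget. It is not NNop's actual formula: `shmem_bytes`, `find_groupsize`, the tile layout (one Q, K and V tile plus one score tile per block), the candidate group sizes, and the 48 KiB budget are all assumptions chosen for illustration.

```julia
# Hypothetical shared-memory model, NOT taken from NNop.
# Assumes each block keeps Q, K and V tiles of size (head_dim × groupsize)
# plus one (groupsize × groupsize) score tile in shared memory.
function shmem_bytes(head_dim::Int, groupsize::Int; T::DataType=Float32)
    return sizeof(T) * (3 * head_dim * groupsize + groupsize^2)
end

# Try candidate group sizes from largest to smallest and return the first one
# whose estimated footprint fits the budget, or `nothing` if none fits.
# The 48 KiB default budget and the candidate list are assumptions.
function find_groupsize(head_dim::Int; budget::Int=48 * 1024,
                        candidates=(256, 128, 64, 32), T::DataType=Float32)
    for g in candidates
        shmem_bytes(head_dim, g; T=T) <= budget && return g
    end
    return nothing
end

find_groupsize(64)   # 32 under this model (28672 bytes)
find_groupsize(128)  # nothing: even groupsize 32 needs 53248 bytes
```

Under this toy model the footprint grows linearly with the head dimension, so a configuration that fits at 64 can exceed the budget at 128 for every candidate group size. The real kernel may tile differently, so the numbers here only show the shape of the constraint.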
