Skip to content

byte_infer_perf中,以byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_llama3.py为例,模型是否支持因果掩码(即causal)功能? #172

@zlk0306

Description

@zlk0306

在byte_infer_perf中,以byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_llama3.py为例:

1、LlamaFlashAttention2中的causal功能由is_causal判断。
但是LlamaFlashAttention2同样使用了父类LlamaAttention中的attention_mask,attention_mask的用法为:
1)attention_mask在_flash_attention_forward下的解释为

attention_mask (torch.Tensor):
The padding mask - corresponds to a tensor of size **(batch_size, seq_len)** where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.

2)attention_mask传入self._upad_input后,用于对Q进行padding,具体见下

# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

2、LlamaAttention中的causal有参数causal_mask,虽然命名为causal_mask,但其定义为causal_mask = attention_mask[:, :, :, : key_states.shape[-2]],而由LlamaFlashAttention2可知,attention_mask的shape为**(batch_size, seq_len)**,且用于变长序列的padding。因此,我们推测causal_mask并未实现因果掩码功能。

3、同样的,在LlamaSdpaAttention中,causal_mask = ~attention_mask,causal_mask同样未实现因果掩码功能。

综上所述,我们存在以下两个问题:
首先,在LlamaAttention和LlamaSdpaAttention中,attention_mask实现的是变长序列的padding功能,还是因果掩码causal功能?
其次,请问LlamaAttention和LlamaSdpaAttention是否支持因果掩码?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions