在byte_infer_perf中,以byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_llama3.py为例:
1、LlamaFlashAttention2中的causal功能由is_causal判断。
但是LlamaFlashAttention2同样使用了父类LlamaAttention中的attention_mask,attention_mask的用法为:
1)attention_mask在_flash_attention_forward下的解释为
“
attention_mask (torch.Tensor):
The padding mask - corresponds to a tensor of size **(batch_size, seq_len)** where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
”
2)attention_mask传入self._upad_input后,用于对Q进行padding,具体见下
“
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
”
2、LlamaAttention中的causal有参数causal_mask,虽然命名为causal_mask,但其定义为causal_mask = attention_mask[:, :, :, : key_states.shape[-2]],而由LlamaFlashAttention2可知,attention_mask的shape为**(batch_size, seq_len)**,且用于变长序列的padding。因此,我们推测causal_mask并未实现因果掩码功能。
3、同样的,在LlamaSdpaAttention中,causal_mask = ~attention_mask,causal_mask同样未实现因果掩码功能。
综上所述,我们存在以下两个问题:
首先,在LlamaAttention和LlamaSdpaAttention中,attention_mask实现的是变长序列的padding功能,还是因果掩码causal功能?
其次,请问LlamaAttention和LlamaSdpaAttention是否支持因果掩码?
在byte_infer_perf中,以byte_infer_perf/llm_perf/backends/GPU/model_impl/modeling_llama3.py为例:
1、LlamaFlashAttention2中的causal功能由is_causal判断。
但是LlamaFlashAttention2同样使用了父类LlamaAttention中的attention_mask,attention_mask的用法为:
1)attention_mask在_flash_attention_forward下的解释为
“
attention_mask (
torch.Tensor):The padding mask - corresponds to a tensor of size
**(batch_size, seq_len)**where 0 stands for theposition of padding tokens and 1 for the position of non-padding tokens.
”
2)attention_mask传入self._upad_input后,用于对Q进行padding,具体见下
“
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
”
2、LlamaAttention中的causal有参数causal_mask,虽然命名为causal_mask,但其定义为causal_mask = attention_mask[:, :, :, : key_states.shape[-2]],而由LlamaFlashAttention2可知,attention_mask的shape为**(batch_size, seq_len)**,且用于变长序列的padding。因此,我们推测causal_mask并未实现因果掩码功能。
3、同样的,在LlamaSdpaAttention中,causal_mask = ~attention_mask,causal_mask同样未实现因果掩码功能。
综上所述,我们存在以下两个问题:
首先,在LlamaAttention和LlamaSdpaAttention中,attention_mask实现的是变长序列的padding功能,还是因果掩码causal功能?
其次,请问LlamaAttention和LlamaSdpaAttention是否支持因果掩码?