How exactly is the attention_mask implemented when fine-tuning chatglm2-6b? Which part is bidirectional and which part is unidirectional?

#97
by someone652314 - opened

The relevant method of the ChatGLMPreTrainedModel class:

```python
def get_masks(self, input_ids, past_key_values, padding_mask=None):
    # padding_mask is the attention_mask passed in, shape [batch_size, seq_len]
    batch_size, seq_length = input_ids.shape
    full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
    full_attention_mask.tril_()
    past_length = 0
    if past_key_values:
        past_length = past_key_values[0][0].shape[0]
    if past_length:
        # full_attention_mask shape becomes [batch_size, seq_length, past_length + seq_length]
        full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                    device=input_ids.device), full_attention_mask), dim=-1)
    if padding_mask is not None:
        # padding_mask.unsqueeze(1) has shape [batch_size, 1, seq_len]
        full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
    if not past_length and padding_mask is not None:
        full_attention_mask -= padding_mask.unsqueeze(-1) - 1
    full_attention_mask = (full_attention_mask < 0.5).bool()
    full_attention_mask.unsqueeze_(1)
    return full_attention_mask
```
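
To make the question concrete, here is a minimal standalone sketch of the same masking logic (not the method itself). The toy sizes, a prefix of length 3 injected via past_key_values and a text input of length 4, are assumptions purely for illustration:

```python
import torch

# Standalone toy: which positions end up masked when a prefix is present?
batch_size, seq_length, past_length = 1, 4, 3

full_attention_mask = torch.ones(batch_size, seq_length, seq_length)
full_attention_mask.tril_()                     # causal pattern over the text tokens
full_attention_mask = torch.cat(
    (torch.ones(batch_size, seq_length, past_length), full_attention_mask),
    dim=-1,
)                                               # prefix columns are all ones
mask = (full_attention_mask < 0.5)              # True = masked out, as in get_masks

print(mask[0].int())
# tensor([[0, 0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0, 0]])
# Columns 0-2 (the prefix) are visible to every query position,
# while columns 3-6 (the text) keep the usual causal pattern.
```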

Judging from this function, is it the case that only the prompt produced by the prefix_encoder gets bidirectional attention, while the rest of the text fed into the model gets only unidirectional (causal) attention?
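
The padding branch of the same function can be checked the same way. A minimal sketch, assuming no prefix and a left-padded batch where the first two positions are padding (padding_mask = [0, 0, 1, 1]; again a toy, not the real forward pass):

```python
import torch

# No past_key_values here; padding_mask marks real tokens with 1, padding with 0.
batch_size, seq_length = 1, 4
padding_mask = torch.tensor([[0, 0, 1, 1]])      # two left-padding positions (assumed)

full_attention_mask = torch.ones(batch_size, seq_length, seq_length)
full_attention_mask.tril_()                                            # causal pattern
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)  # hide padding keys
full_attention_mask -= padding_mask.unsqueeze(-1) - 1                  # un-mask padding query rows
mask = (full_attention_mask < 0.5)               # True = masked out

print(mask[0].int())
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 0],
#         [1, 1, 0, 1],
#         [1, 1, 0, 0]])
# Real tokens (rows 2-3) never attend to the padding columns and stay causal;
# the padding query rows are left fully unmasked, which is harmless because
# their outputs are discarded anyway.
```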
