zxdu20 committed
Commit c949d03
1 Parent(s): 0cfae21

Use dynamic dtype for prompts

Files changed (1):
  modeling_chatglm.py  +7 -5
modeling_chatglm.py CHANGED
@@ -804,9 +804,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-    def get_prompt(self, batch_size, device):
+    def get_prompt(self, batch_size, device, dtype=torch.half):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
-        past_key_values = self.prefix_encoder(prefix_tokens).half()
+        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
         past_key_values = past_key_values.view(
             batch_size,
             self.pre_seq_len,
@@ -896,9 +896,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         if past_key_values is None:
             if self.pre_seq_len is not None:
-                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
+                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
+                                                  dtype=inputs_embeds.dtype)
             else:
                 past_key_values = tuple([None] * len(self.layers))
 
@@ -927,8 +931,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 gmask=use_gmask
             )
 
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
 
         # [seq_len, batch, hidden_size]
         hidden_states = inputs_embeds.transpose(0, 1)
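
For context on what the change does: the learned prompt (past_key_values) produced by the prefix encoder was previously cast to fp16 with an unconditional .half(), which mismatches the rest of the model whenever it runs in float32 or bfloat16 (e.g., on CPU). The commit instead computes inputs_embeds first and threads its dtype through get_prompt. A minimal sketch of the idea, using hypothetical stand-in modules rather than the actual ChatGLM classes:

import torch

# Hypothetical stand-ins for PrefixEncoder / word_embeddings, sized arbitrarily.
prefix_encoder = torch.nn.Embedding(16, 8)       # holds float32 parameters
prefix_tokens = torch.arange(16).unsqueeze(0)    # [1, pre_seq_len], like self.prefix_tokens

for dtype in (torch.half, torch.bfloat16, torch.float32):
    # What word_embeddings(input_ids) would return under this runtime dtype.
    inputs_embeds = torch.zeros(1, 4, 8, dtype=dtype)
    # Old behaviour: .half() always -> wrong dtype for fp32/bf16 runs.
    # New behaviour: cast the learned prompt to the embeddings' dtype.
    past_key_values = prefix_encoder(prefix_tokens).type(inputs_embeds.dtype)
    assert past_key_values.dtype == inputs_embeds.dtype

Note that dtype defaults to torch.half in the new get_prompt signature, so any caller that does not pass a dtype keeps the previous fp16 behaviour.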