zRzRzRzRzRzRzR commited on
Commit
f308259
1 Parent(s): 37fe000

add set_input_embeddings(self, value):

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +3 -0
modeling_chatglm.py CHANGED
@@ -769,6 +769,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
769
  def get_input_embeddings(self):
770
  return self.embedding.word_embeddings
771
 
 
 
 
772
  def get_prompt(self, batch_size, device, dtype=torch.half):
773
  prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
774
  past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
 
769
  def get_input_embeddings(self):
770
  return self.embedding.word_embeddings
771
 
772
+ def set_input_embeddings(self, value):
773
+ self.embedding.word_embeddings = value
774
+
775
  def get_prompt(self, batch_size, device, dtype=torch.half):
776
  prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
777
  past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)