zxdu20 committed
Commit e1494f2
Parent(s): cc96a22

Fix batch input

Files changed (1)
  1. tokenization_chatglm.py +3 -3
tokenization_chatglm.py CHANGED
@@ -177,7 +177,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
     vocab_files_names = {"vocab_file": "ice_text.model"}
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["input_ids"]
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]
 
     def __init__(
         self,
@@ -397,7 +397,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
         # Initialize attention mask if not present.
-        if needs_to_be_padded or return_attention_mask:
+        if return_attention_mask:
             context_length = required_input.index(bos_token_id)
             attention_mask = np.ones((1, seq_length, seq_length))
             attention_mask = np.tril(attention_mask)
@@ -405,7 +405,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             attention_mask = np.bool_(attention_mask < 0.5)
             encoded_inputs["attention_mask"] = attention_mask
 
-        if needs_to_be_padded or return_attention_mask:
+        if return_attention_mask:
             mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
             mask_position = required_input.index(mask_token)
             context_length = required_input.index(bos_token_id)
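
For context, a minimal usage sketch of the batch path this commit touches. It assumes the tokenizer is loaded from the THUDM/chatglm-6b repository with trust_remote_code=True; the repository name and example sentences are illustrative and not part of the commit. With model_input_names now listing "attention_mask" and "position_ids", and the mask/position construction gated only on return_attention_mask, a padded batch should carry those tensors for every sequence rather than only when padding happened to be applied.

# Minimal sketch (assumptions: THUDM/chatglm-6b repo, trust_remote_code=True,
# illustrative example inputs).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# Padding a batch of unequal-length inputs exercises the _pad path changed above.
batch = tokenizer(
    ["Hello", "How can I improve my sleep?"],
    padding=True,
    return_tensors="pt",
)

# With the fix, attention_mask and position_ids are expected alongside input_ids
# for each sequence in the batch.
print({name: tuple(tensor.shape) for name, tensor in batch.items()})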