zxdu20 committed on
Commit
23ad39b
1 Parent(s): fdb7a60

Fix decode method for torch tensor

Browse files
Files changed (1) hide show
  1. tokenization_chatglm.py +6 -15
tokenization_chatglm.py CHANGED
@@ -253,29 +253,20 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
253
 
254
  return seq
255
 
256
- def decode(
257
  self,
258
- token_ids: Union[List[int], List[List[int]]],
259
  skip_special_tokens: bool = False,
260
  clean_up_tokenization_spaces: bool = True,
261
- spaces_between_special_tokens: bool = True,
262
  **kwargs
263
  ) -> str:
264
- if not isinstance(token_ids, list):
265
  token_ids = [token_ids]
266
  if len(token_ids) == 0:
267
  return ""
268
- if isinstance(token_ids[0], list):
269
- tokens = []
270
- for single_token_ids in token_ids:
271
- if self.pad_token_id in single_token_ids: # remove pad
272
- single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids))
273
- tokens.append(self.sp_tokenizer.decode(single_token_ids))
274
- return (tokens)
275
- else:
276
- if self.pad_token_id in token_ids: # remove pad
277
- token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
278
- return self.sp_tokenizer.decode(token_ids)
279
 
280
  def _convert_token_to_id(self, token):
281
  """ Converts a token (str) in an id using the vocab. """
 
253
 
254
  return seq
255
 
256
+ def _decode(
257
  self,
258
+ token_ids: Union[int, List[int]],
259
  skip_special_tokens: bool = False,
260
  clean_up_tokenization_spaces: bool = True,
 
261
  **kwargs
262
  ) -> str:
263
+ if isinstance(token_ids, int):
264
  token_ids = [token_ids]
265
  if len(token_ids) == 0:
266
  return ""
267
+ if self.pad_token_id in token_ids: # remove pad
268
+ token_ids = list(filter((self.pad_token_id).__ne__, token_ids))
269
+ return self.sp_tokenizer.decode(token_ids)
 
 
 
 
 
 
 
 
270
 
271
  def _convert_token_to_id(self, token):
272
  """ Converts a token (str) in an id using the vocab. """