zxdu20 committed on
Commit
01717dd
1 Parent(s): 507dfe3

Update tokenizer

Files changed (1)
  1. tokenization_glm.py +270 -55
tokenization_glm.py CHANGED
@@ -1,66 +1,41 @@
 import os
-from typing import Optional, Tuple, List
+from typing import Optional, Tuple, List, Union
 from shutil import copyfile
-
 import torch
-from transformers import PreTrainedTokenizer
+
+from transformers import PreTrainedTokenizer, RobertaTokenizer, GPT2Tokenizer, BertTokenizer
 from transformers.utils import logging
 from transformers.tokenization_utils_base import BatchEncoding
+from transformers.models.auto.tokenization_auto import get_tokenizer_config
+from transformers.utils.generic import _is_torch_device
 import sentencepiece as spm
 
 logger = logging.get_logger(__name__)
-VOCAB_FILES_NAMES = {"vocab_file": "cog-pretrain.model"}
-
-
-class GLMChineseTokenizer(PreTrainedTokenizer):
-    vocab_files_names = VOCAB_FILES_NAMES
-
-    def __init__(self, vocab_file, **kwargs):
-        super().__init__(**kwargs)
-
-        self.sp_model = spm.SentencePieceProcessor()
-        self.sp_model.Load(vocab_file)
-
-    @property
-    def vocab_size(self):
-        return len(self.sp_model)
-
-    def get_vocab(self):
-        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
-        vocab.update(self.added_tokens_encoder)
-        return vocab
-
-    def _tokenize(self, text, **kwargs):
-        return self.sp_model.encode(text, out_type=str)
-
-    def _convert_token_to_id(self, token):
-        """Converts a token (str) in an id using the vocab."""
-        return self.sp_model.PieceToId(token)
-
-    def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (str) using the vocab."""
-        return self.sp_model.IdToPiece(index)
-
-    def convert_tokens_to_string(self, tokens):
-        return self.sp_model.decode(tokens)
-
-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-            return
-        out_vocab_file = os.path.join(
-            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
-        )
-
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
-            copyfile(self.vocab_file, out_vocab_file)
-        elif not os.path.isfile(self.vocab_file):
-            with open(out_vocab_file, "wb") as fi:
-                content_spiece_model = self.sp_model.serialized_model_proto()
-                fi.write(content_spiece_model)
-
-        return (out_vocab_file,)
 
+
+class GLMBatchEncoding(BatchEncoding):
+    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
+        """
+        Send all values to device by calling `v.to(device)` (PyTorch only).
+
+        Args:
+            device (`str` or `torch.device`): The device to put the tensors on.
+
+        Returns:
+            [`BatchEncoding`]: The same instance after modification.
+        """
+        # This check catches things like APEX blindly calling "to" on all inputs to a module
+        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
+        # into a HalfTensor
+        if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
+            self.data = {k: v.to(device=device) if torch.is_tensor(v) else v for k, v in self.data.items()}
+        else:
+            logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
+        return self
+
+
+class GLMTokenizerMixin:
     @property
     def sop_token(self) -> Optional[str]:
         return "<|startofpiece|>"
@@ -68,7 +43,7 @@ class GLMChineseTokenizer(PreTrainedTokenizer):
     @property
     def sop_token_id(self) -> Optional[int]:
         """
-        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling. Returns `None` if the token has not been set.
+        `Optional[int]`: Id of the start token in the vocabulary, used when training a model with autoregressive blank filling.
        """
         return self.convert_tokens_to_ids(self.sop_token)
 
@@ -79,7 +54,7 @@ class GLMChineseTokenizer(PreTrainedTokenizer):
     @property
     def eop_token_id(self) -> Optional[int]:
        """
-        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling. Returns `None` if the token has not been set.
+        `Optional[int]`: Id of the end token in the vocabulary, used when training a model with autoregressive blank filling.
        """
         return self.convert_tokens_to_ids(self.eop_token)
 
@@ -91,12 +66,113 @@ class GLMChineseTokenizer(PreTrainedTokenizer):
     def smask_token_id(self) -> int:
         return self.convert_tokens_to_ids("[sMASK]")
 
-    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512):
-        mask_ids = [self.mask_token_id, self.smask_token_id, self.gmask_token_id]
+    @property
+    def mask_token_ids(self):
+        return [self.mask_token_id, self.smask_token_id, self.gmask_token_id]
+
+    def _build_input_for_multiple_choice(self, context, choices):
+        context_id = context["input_ids"]
+        if torch.is_tensor(context_id):
+            context_id = context_id.tolist()
+
+        division = len(context_id)
+        mask_position = context_id.index(self.mask_token_id)
+
+        token = torch.tensor(context_id, dtype=torch.long)
+        attention_mask = [context["attention_mask"].expand(division, -1)]
+        position_id = torch.arange(division, dtype=torch.long)
+        block_position_id = torch.zeros(division, dtype=torch.long)
+
+        choice_ids, choice_indices = [], []
+
+        for choice_str in choices:
+            choice = torch.tensor(self(choice_str, add_special_tokens=False, padding=False)['input_ids'],
+                                  dtype=torch.long)
+            choice_ids.append(choice)
+            choice_indices.append(torch.arange(len(token), len(token) + len(choice), dtype=torch.long))
+            attention_mask.append(torch.tril(torch.ones((len(choice), len(choice)), dtype=torch.long)))
+
+            token = torch.cat((token, torch.tensor([self.sop_token_id], dtype=torch.long), choice[:-1]))
+            position_id = torch.cat((position_id, torch.tensor([mask_position] * len(choice), dtype=torch.long)))
+            block_position_id = torch.cat((block_position_id, torch.arange(1, 1 + len(choice), dtype=torch.long)))
+
+        attention_mask = torch.block_diag(*attention_mask)
+        attention_mask[division:, :division] = context["attention_mask"].unsqueeze(0)
+
+        return {
+            "input_ids": token,
+            "position_ids": torch.stack((position_id, block_position_id)),
+            "attention_mask": attention_mask,
+            "choice_ids": choice_ids,
+            "choice_indices": choice_indices
+        }
+
+    def _pad_batch(self, tokens, position_ids, attention_mask, max_seq_length):
+        pad_length = max_seq_length - len(tokens)
+        attention_mask = torch.nn.functional.pad(
+            attention_mask,
+            (0, pad_length, 0, pad_length),
+            mode="constant",
+            value=0,
+        )
+        tokens = torch.cat((tokens, torch.zeros(pad_length, dtype=torch.long)))
+        position_ids = torch.cat((position_ids, position_ids[..., -1:].expand(-1, pad_length)), dim=-1)
+        return tokens, position_ids, attention_mask
+
+    def _collate(self, samples):
+        TILE = 1
+        length_to_pad = (max(map(lambda spl: len(spl["input_ids"]), samples)) + TILE - 1) // TILE * TILE
+
+        token_batch, position_id_batch, attention_mask_batch = [], [], []
+        choices_batch, choice_target_ids_batch = [], []
+
+        for sample in samples:
+            token, position_id, attention_mask = self._pad_batch(
+                sample["input_ids"], sample["position_ids"], sample["attention_mask"], length_to_pad
+            )
+            token_batch.append(token)
+            position_id_batch.append(position_id)
+            attention_mask_batch.append(attention_mask)
+            choices_batch.append(sample["choice_ids"])
+            choice_target_ids_batch.append(sample["choice_indices"])
+        return {
+            "input_ids": torch.stack(token_batch),
+            "position_ids": torch.stack(position_id_batch),
+            "attention_mask": torch.stack(attention_mask_batch).unsqueeze(1),
+            "choice_ids": choices_batch,
+            "choice_indices": choice_target_ids_batch,
+        }
+
+    def build_inputs_for_multiple_choice(self, model_input: BatchEncoding, choices, max_length=None):
+        samples = [{key: value[i] for key, value in model_input.items()} for i in range(len(model_input["input_ids"]))]
+        samples = [self._build_input_for_multiple_choice(sample, choice) for sample, choice in
+                   zip(samples, choices)]
+        inputs = self._collate(samples)
+        return GLMBatchEncoding(inputs)
+
+    def build_inputs_for_generation(self, model_input: BatchEncoding, max_gen_length=512, targets=None, padding=False):
+        mask_ids = self.mask_token_ids
         input_ids = model_input.input_ids
         batch_size, seq_length = input_ids.shape[:2]
         position_id, block_position_id = list(range(seq_length)), [0 for _ in range(seq_length)]
         position_ids, block_position_ids = [], []
+        labels = None
+        if targets is not None:
+            is_batched = isinstance(targets, (list, tuple))
+            targets = self(targets, add_special_tokens=False, padding=False).input_ids
+            if not is_batched:
+                targets = [targets]
+            assert len(targets) == len(input_ids)
+            targets = [(target + [self.eop_token_id])[:max_gen_length] for target in targets]
+            if not padding:
+                max_gen_length = max(map(len, targets))
+            targets = [[self.sop_token_id] + target for target in targets]
+            labels = [target[1:] for target in targets]
+            targets = [target + [self.pad_token_id] * (max_gen_length + 1 - len(target)) for target in targets]
+            labels = [label + [-100] * (max_gen_length - len(label)) for label in labels]
+            targets = torch.tensor(targets, dtype=input_ids.dtype, device=input_ids.device)
+            labels = torch.tensor(labels, dtype=input_ids.dtype, device=input_ids.device)
+            labels = torch.cat((input_ids.new_full((batch_size, seq_length), -100), labels), dim=1)
         for i in range(batch_size):
             mask_positions = []
             for mask_id in mask_ids:
@@ -117,11 +193,86 @@ class GLMChineseTokenizer(PreTrainedTokenizer):
                                               dim=0).unsqueeze(0).expand(batch_size, -1, -1)
         attention_mask = torch.cat((attention_mask, generation_attention_mask), dim=2)
         attention_mask = attention_mask.unsqueeze(1)
-        input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
-        return BatchEncoding(
-            {"input_ids": input_ids, "position_ids": position_ids, "generation_attention_mask": attention_mask}
+        if targets is None:
+            input_ids = torch.cat((input_ids, input_ids.new_full((batch_size, 1), self.sop_token_id)), dim=-1)
+        else:
+            input_ids = torch.cat((input_ids, targets[:, :-1]), dim=1)
+        batch = {"input_ids": input_ids, "position_ids": position_ids}
+        if labels is None:
+            batch["generation_attention_mask"] = attention_mask
+        else:
+            batch["attention_mask"] = attention_mask
+            batch["labels"] = labels
+        return BatchEncoding(batch)
+
+
+class GLMRobertaTokenizer(RobertaTokenizer, GLMTokenizerMixin):
+    model_input_names = ["input_ids", "position_ids", "attention_mask"]
+    truncation_side: str = "left"
+
+    @property
+    def gmask_token_id(self) -> int:
+        raise NotImplementedError("The model doesn't support gMASK")
+
+    @property
+    def smask_token_id(self) -> int:
+        raise NotImplementedError("The model doesn't support sMASK")
+
+    @property
+    def mask_token_ids(self):
+        return [self.mask_token_id]
+
+
+class GLMChineseTokenizer(PreTrainedTokenizer, GLMTokenizerMixin):
+    vocab_files_names = {"vocab_file": "cog-pretrain.model"}
+    truncation_side: str = "left"
+
+    def __init__(self, vocab_file, **kwargs):
+        super().__init__(**kwargs)
+        self.vocab_file = vocab_file
+        self.sp_model = spm.SentencePieceProcessor()
+        self.sp_model.Load(vocab_file)
+
+    @property
+    def vocab_size(self):
+        return len(self.sp_model)
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text, **kwargs):
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.sp_model.PieceToId(token)
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
+
+    def convert_tokens_to_string(self, tokens):
+        return self.sp_model.decode(tokens)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
         )
 
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        return (out_vocab_file,)
+
     def build_inputs_with_special_tokens(
             self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
@@ -145,3 +296,67 @@ class GLMChineseTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         eos = [self.eos_token_id]
         return cls + token_ids_0 + eos
+
+
+class GLMGPT2Tokenizer(GPT2Tokenizer, GLMTokenizerMixin):
+    model_input_names = ["input_ids", "position_ids", "attention_mask"]
+    truncation_side: str = "left"
+
+    def build_inputs_with_special_tokens(
+            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+        and adding special tokens. A GLM sequence has the following format:
+
+        - single sequence: ``[CLS] X [SEP]``
+        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
+
+        Args:
+            token_ids_0 (:obj:`List[int]`):
+                List of IDs to which the special tokens will be added.
+            token_ids_1 (:obj:`List[int]`, `optional`):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+        """
+        assert token_ids_1 is None
+        cls = [self.cls_token_id]
+        eos = [self.eos_token_id]
+        return cls + token_ids_0 + eos
+
+
+class GLMBertTokenizer(BertTokenizer, GLMTokenizerMixin):
+    model_input_names = ["input_ids", "position_ids", "attention_mask"]
+    truncation_side: str = "left"
+
+    @property
+    def gmask_token_id(self) -> int:
+        raise NotImplementedError("The model doesn't support gMASK")
+
+    @property
+    def smask_token_id(self) -> int:
+        raise NotImplementedError("The model doesn't support sMASK")
+
+    @property
+    def mask_token_ids(self):
+        return [self.mask_token_id]
+
+
+class GLMTokenizer:
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+        tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
+        config_tokenizer_class = tokenizer_config.get("tokenizer_class")
+        if config_tokenizer_class == "GLMRobertaTokenizer":
+            tokenizer_class = GLMRobertaTokenizer
+        elif config_tokenizer_class == "GLMChineseTokenizer":
+            tokenizer_class = GLMChineseTokenizer
+        elif config_tokenizer_class == "GLMGPT2Tokenizer":
+            tokenizer_class = GLMGPT2Tokenizer
+        elif config_tokenizer_class == "GLMBertTokenizer":
+            tokenizer_class = GLMBertTokenizer
+        else:
+            raise NotImplementedError(f"Not implemented tokenizer type: {config_tokenizer_class}")
+        return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
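
A minimal usage sketch of the new GLMTokenizer dispatch entry point together with the extended build_inputs_for_generation; the checkpoint path and example strings below are placeholders, not part of this commit:

    from tokenization_glm import GLMTokenizer

    # Placeholder path: any checkpoint whose tokenizer_config.json names one of
    # the four GLM* tokenizer classes defined above.
    tokenizer = GLMTokenizer.from_pretrained("path/to/glm-checkpoint")

    # Inference: appends one <|startofpiece|> token and builds the 2D
    # (position, block-position) ids plus the triangular generation mask
    # for the blank marked by [MASK].
    prompt = tokenizer("Tsinghua University is located in [MASK].", return_tensors="pt")
    inputs = tokenizer.build_inputs_for_generation(prompt, max_gen_length=32)
    # keys: input_ids, position_ids, generation_attention_mask

    # Training: passing `targets` appends <|startofpiece|> plus the target
    # tokens instead, and returns labels with all prompt positions set to -100.
    inputs = tokenizer.build_inputs_for_generation(prompt, targets="Beijing", max_gen_length=32)
    # keys: input_ids, position_ids, attention_mask, labels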
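
Similarly, a sketch of the new multiple-choice path, assuming one list of candidate strings per context. Note that GLMBatchEncoding.to only moves tensor values, so the nested choice_ids/choice_indices lists are left untouched:

    contexts = ["Tsinghua University is located in [MASK]."]
    choices = [["Beijing", "Shanghai"]]

    encoded = tokenizer(contexts, return_tensors="pt", padding=True)
    mc_inputs = tokenizer.build_inputs_for_multiple_choice(encoded, choices)
    # keys: input_ids, position_ids, attention_mask (block-diagonal per choice),
    # plus choice_ids / choice_indices as lists of per-choice tensors
    mc_inputs = mc_inputs.to("cuda")  # only the tensors move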
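
Finally, the build_inputs_with_special_tokens overrides reject sequence pairs and wrap a single sequence as [CLS] X [EOS]; a quick sanity check of that contract:

    ids = tokenizer.encode("hello", add_special_tokens=False)
    with_special = tokenizer.build_inputs_with_special_tokens(ids)
    assert with_special == [tokenizer.cls_token_id] + ids + [tokenizer.eos_token_id]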