ltcs15 committed on
Commit: c54139a
1 Parent(s): 35ca523

add custom onnx export config


Currently only batch=1 and use_past=False are supported;
several parts of the model structure also need to be fixed to support export to ONNX.

Files changed (1)
  1. configuration_chatglm.py +310 -0
configuration_chatglm.py CHANGED
@@ -1,8 +1,15 @@
1
  """ ChatGLM model configuration """
 
 
 
 
2
 
3
  from transformers.configuration_utils import PretrainedConfig
4
  from transformers.utils import logging
5
 
 
 
 
6
  logger = logging.get_logger(__name__)
7
 
8
 
@@ -101,3 +108,306 @@ class ChatGLMConfig(PretrainedConfig):
101
  eos_token_id=eos_token_id,
102
  **kwargs
103
  )
1
  """ ChatGLM model configuration """
2
+ import torch
3
+
4
+ from collections import OrderedDict
5
+ from typing import List, Mapping, Optional, Any
6
 
7
  from transformers.configuration_utils import PretrainedConfig
8
  from transformers.utils import logging
9
 
10
+ from transformers.onnx import OnnxConfigWithPast, PatchingSpec
11
+ from transformers import PreTrainedTokenizer, TensorType, is_torch_available
12
+
13
  logger = logging.get_logger(__name__)
14
 
15
 
 
108
  eos_token_id=eos_token_id,
109
  **kwargs
110
  )
111
+
112
+
113
+ class ChatGLMOnnxConfig(OnnxConfigWithPast):
114
+ r"""
115
+ This class is the custom ONNX export configuration needed to export ChatGLMModel to ONNX.
116
+ It currently requires pre-fixing several parts of the model structure in modeling_chatglm.py.
117
+
118
+ There is also still a TODO list for the current ChatGLMOnnxConfig:
119
+ 1. add support for batch_size > 1
120
+ 2. add support for use_past
121
+
122
+ In modeling_chatglm.py, in its attention_fn function, we need to change several `view` calls into
123
+ other torch tensor operations, since the reshape parameters may get frozen into constants in the ONNX model.
124
+ Here is the code:
125
+ ```python
126
+ >>> def attention_fn(
127
+ >>> self,
128
+ >>> query_layer,
129
+ >>> key_layer,
130
+ >>> value_layer,
131
+ >>> attention_mask,
132
+ >>> hidden_size_per_partition,
133
+ >>> layer_id,
134
+ >>> layer_past=None,
135
+ >>> scaling_attention_score=True,
136
+ >>> use_cache=False,
137
+ >>> ):
138
+ >>> if layer_past is not None:
139
+ >>> past_key, past_value = layer_past[0], layer_past[1]
140
+ >>> key_layer = torch.cat((past_key, key_layer), dim=0)
141
+ >>> value_layer = torch.cat((past_value, value_layer), dim=0)
142
+ >>>
143
+ >>> # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
144
+ >>> seq_len, b, nh, hidden_size = key_layer.shape
145
+ >>>
146
+ >>> if use_cache:
147
+ >>> present = (key_layer, value_layer)
148
+ >>> else:
149
+ >>> present = None
150
+ >>>
151
+ >>> query_key_layer_scaling_coeff = float(layer_id + 1)
152
+ >>> if scaling_attention_score:
153
+ >>> query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
154
+ >>>
155
+ >>> # ===================================
156
+ >>> # Raw attention scores. [b, np, s, s]
157
+ >>> # ===================================
158
+ >>>
159
+ >>> # [b, np, sq, sk]
160
+ >>> # # output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
161
+ >>>
162
+ >>> # [sq, b, np, hn] -> [sq, b * np, hn]
163
+ >>> # query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
164
+ >>> query_layer = query_layer.flatten(start_dim=1, end_dim=2)
165
+ >>>
166
+ >>> # [sk, b, np, hn] -> [sk, b * np, hn]
167
+ >>> # key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
168
+ >>> key_layer = key_layer.flatten(start_dim=1, end_dim=2)
169
+ >>>
170
+ >>> matmul_result = torch.zeros(
171
+ >>> 1, 1, 1,
172
+ >>> dtype=query_layer.dtype,
173
+ >>> device=query_layer.device,
174
+ >>> )
175
+ >>>
176
+ >>> matmul_result = torch.baddbmm(
177
+ >>> matmul_result,
178
+ >>> query_layer.transpose(0, 1), # [b * np, sq, hn]
179
+ >>> key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
180
+ >>> beta=0.0,
181
+ >>> alpha=1.0,
182
+ >>> )
183
+ >>>
184
+ >>> # [b * np, sq, sk] -> [b, np, sq, sk]
185
+ >>> # attention_scores = matmul_result.view(*output_size)
186
+ >>> attention_scores = matmul_result.unsqueeze(0)
187
+ >>>
188
+ >>> if self.scale_mask_softmax:
189
+ >>> self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
190
+ >>> attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
191
+ >>> else:
192
+ >>> # if not (attention_mask == 0).all():
193
+ >>> # # if auto-regressive, skip
194
+ >>> attention_scores.masked_fill_(attention_mask, -10000.0)
195
+ >>> dtype = attention_scores.dtype
196
+ >>> attention_scores = attention_scores.float()
197
+ >>> attention_scores = attention_scores * query_key_layer_scaling_coeff
198
+ >>>
199
+ >>> attention_probs = F.softmax(attention_scores, dim=-1)
200
+ >>>
201
+ >>> attention_probs = attention_probs.type(dtype)
202
+ >>>
203
+ >>> # =========================
204
+ >>> # Context layer. [sq, b, hp]
205
+ >>> # =========================
206
+ >>>
207
+ >>> # value_layer -> context layer.
208
+ >>> # [sk, b, np, hn] --> [b, np, sq, hn]
209
+ >>>
210
+ >>> # context layer shape: [b, np, sq, hn]
211
+ >>> # output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
212
+ >>>
213
+ >>> # change view [sk, b * np, hn]
214
+ >>> # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
215
+ >>> value_layer = value_layer.flatten(start_dim=1, end_dim=2)
216
+ >>>
217
+ >>> # change view [b * np, sq, sk]
218
+ >>> # attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
219
+ >>> attention_probs = attention_probs.flatten(start_dim=0, end_dim=1)
220
+ >>>
221
+ >>> # matmul: [b * np, sq, hn]
222
+ >>> context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
223
+ >>>
224
+ >>> # change view [b, np, sq, hn]
225
+ >>> # context_layer = context_layer.reshape(b, np, sq, hidden_size)
226
+ >>> context_layer = context_layer.unsqueeze(0)
227
+ >>>
228
+ >>> # [b, np, sq, hn] --> [sq, b, np, hn]
229
+ >>> context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
230
+ >>>
231
+ >>> # [sq, b, np, hn] --> [sq, b, hp]
232
+ >>> # new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
233
+ >>> # context_layer = context_layer.view(*new_context_layer_shape)
234
+ >>> context_layer = context_layer.flatten(start_dim=2)
235
+ >>>
236
+ >>> outputs = (context_layer, present, attention_probs)
237
+ >>>
238
+ >>> return outputs
239
+ ```
240
+ The main point is to avoid `view` with dynamically computed sizes.
241
+
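+ As a quick standalone illustration (a sketch with hypothetical module names, not code from this repo), a `view`
+ whose target shape is computed from the input tensor can get its sizes baked in during tracing, while `flatten` stays shape-agnostic:
+ ```python
+ >>> import torch
+ >>>
+ >>> class WithView(torch.nn.Module):
+ >>>     def forward(self, x):  # x: [sq, b, np, hn]
+ >>>         sq, b, np_, hn = x.shape
+ >>>         # b * np_ is evaluated while tracing, so it can be frozen as a constant in the ONNX graph
+ >>>         return x.view(sq, b * np_, hn)
+ >>>
+ >>> class WithFlatten(torch.nn.Module):
+ >>>     def forward(self, x):  # x: [sq, b, np, hn]
+ >>>         # same result, but exported without hard-coded sizes
+ >>>         return x.flatten(start_dim=1, end_dim=2)
+ ```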
242
+ After changing modeling_chatglm.py, you can simply use the following code to export and test the ONNX model:
243
+ ```python
244
+ >>> from pathlib import Path
245
+ >>> from transformers import AutoTokenizer, AutoModel
246
+ >>> from transformers.onnx import export, validate_model_outputs
247
+ >>>
248
+ >>> # load model
249
+ >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
250
+ >>> pt_model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
251
+ >>> pt_model = pt_model.float() # only tested on CPU for now
252
+ >>> pt_model.eval()
253
+ >>> # define path for saving onnx model
254
+ >>> onnx_path = Path("model/chatglm-6b.onnx")
255
+ >>> onnx_path.parent.mkdir(exist_ok=True)
256
+ >>> # convert model to onnx
257
+ >>> onnx_config_chatglm = ChatGLMOnnxConfig(pt_model.config, task="causal-lm")
258
+ >>> onnx_inputs, onnx_outputs = export(tokenizer, pt_model,
259
+ >>> onnx_config_chatglm, onnx_config_chatglm.default_onnx_opset,
260
+ >>> onnx_path)
261
+ >>> # test onnx model
262
+ >>> validate_model_outputs(onnx_config_chatglm, tokenizer, pt_model, onnx_path, onnx_outputs, atol=1e-4)
263
+ ```
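+
+ As an optional smoke test (a sketch assuming `onnxruntime` is installed; it reuses the objects from the snippet above,
+ so `onnx_config_chatglm`, `tokenizer` and `onnx_path` refer to that code, not to anything defined here):
+ ```python
+ >>> import onnxruntime as ort
+ >>> from transformers import TensorType
+ >>>
+ >>> # build a feed dict with the same dummy-input helper the exporter used
+ >>> dummy = onnx_config_chatglm.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+ >>> feed = {name: value.numpy() for name, value in dummy.items()}
+ >>> session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
+ >>> outputs = session.run(None, feed)
+ >>> print(outputs[0].shape)  # inspect the logits shape
+ ```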
264
+ """
265
+ # TODO support dynamic batch size
266
+ default_fixed_batch = 1
267
+
268
+ def __init__(
269
+ self,
270
+ config: PretrainedConfig,
271
+ task: str = "default",
272
+ patching_specs: List[PatchingSpec] = None,
273
+ use_past: bool = False,
274
+ ):
275
+ super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
276
+
277
+ @property
278
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
279
+ common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
280
+ if self.use_past:
281
+ # TODO support use_past
282
+ # self.fill_with_past_key_values_(common_inputs, direction="inputs")
283
+ # common_inputs["attention_mask"] = \
284
+ # {0: "batch", 1: "past_sequence + sequence", 2: "past_sequence + sequence"}
285
+ raise NotImplementedError('past_key_values is not supported yet (position_ids handling would need to change).')
286
+ else:
287
+ # note: the order matters here and has to match the forward() signature;
+ # position_ids is [batch, 2, seq] (2D position encoding) and attention_mask is [batch, 1, seq, seq],
+ # hence the sequence axes 2 and (2, 3) below
288
+ common_inputs["position_ids"] = {0: "batch", 2: "sequence"}
289
+ common_inputs["attention_mask"] = {0: "batch", 2: "sequence", 3: "sequence"}
290
+
291
+ return common_inputs
292
+
293
+ @property
294
+ def num_layers(self) -> int:
295
+ return self._config.n_layer
296
+
297
+ @property
298
+ def num_attention_heads(self) -> int:
299
+ return self._config.n_head
300
+
301
+ def get_masks(self, input_ids, device=None):
302
+ """
303
+ Adapted from get_masks in modeling_chatglm.py.
304
+ """
305
+ batch_size, seq_length = input_ids.shape
306
+ context_lengths = [seq.tolist().index(self._config.bos_token_id) for seq in input_ids]
307
+ if device:
308
+ attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
309
+ else:
310
+ attention_mask = torch.ones((batch_size, seq_length, seq_length), device=input_ids.device)
311
+ attention_mask.tril_()
312
+ for i, context_length in enumerate(context_lengths):
313
+ attention_mask[i, :, :context_length] = 1
314
+ attention_mask.unsqueeze_(1)
315
+ attention_mask = (attention_mask < 0.5).bool()
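+ # after the comparison above, True marks positions that must NOT be attended to;
+ # attention_fn consumes this mask via masked_fill_ (see the snippet in the class docstring)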
316
+
317
+ # print("attention_mask", attention_mask.shape)
318
+ return attention_mask
319
+
320
+ def get_position_ids(self, input_ids, mask_positions, device=None, use_gmasks=None):
321
+ batch_size, seq_length = input_ids.shape
322
+ if device is None:
323
+ device = input_ids.device
324
+ if use_gmasks is None:
325
+ use_gmasks = [False] * batch_size
326
+ context_lengths = [seq.tolist().index(self._config.bos_token_id) for seq in input_ids]
327
+ if self._config.position_encoding_2d:
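+ # 2D position encoding: row 0 holds absolute token positions (clamped to the mask position after the
+ # context); row 1 holds block positions inside the generated span; stacked into [batch, 2, seq] below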
328
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
329
+ for i, context_length in enumerate(context_lengths):
330
+ position_ids[i, context_length:] = mask_positions[i]
331
+ block_position_ids = [torch.cat((
332
+ torch.zeros(context_length, dtype=torch.long, device=device),
333
+ torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
334
+ )) for context_length in context_lengths]
335
+ block_position_ids = torch.stack(block_position_ids, dim=0)
336
+ position_ids = torch.stack((position_ids, block_position_ids), dim=1)
337
+ else:
338
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
339
+ for i, context_length in enumerate(context_lengths):
340
+ if not use_gmasks[i]:
341
+ position_ids[i, context_length:] = mask_positions[i]
342
+
343
+ # print("position_ids", position_ids.shape)
344
+ return position_ids
345
+
346
+ def generate_dummy_inputs(
347
+ self,
348
+ tokenizer: PreTrainedTokenizer,
349
+ batch_size: int = default_fixed_batch,
350
+ seq_length: int = -1,
351
+ is_pair: bool = False,
352
+ framework: Optional[TensorType] = None,
353
+ ) -> Mapping[str, Any]:
354
+ common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
355
+ tokenizer, batch_size=self.default_fixed_batch, seq_length=seq_length, is_pair=is_pair, framework=framework
356
+ )
357
+ # check that the requested batch size matches the fixed batch size
358
+ if batch_size != self.default_fixed_batch:
359
+ logger.warning('requested batch size is ignored; forcing the fixed batch size: %d.'
360
+ % self.default_fixed_batch)
361
+
362
+ # We need to order the inputs in the way they appear in forward()
363
+ ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
364
+
365
+ # Need to add the past_keys
366
+ if self.use_past:
367
+ if not is_torch_available():
368
+ raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
369
+ else:
370
+ # TODO support use_past
371
+ # import torch
372
+ #
373
+ # batch, seqlen = common_inputs["input_ids"].shape
374
+ # # Not using the same length for past_key_values
375
+ # past_key_values_length = seqlen + 2
376
+ # past_shape = (
377
+ # batch,
378
+ # self.num_attention_heads,
379
+ # past_key_values_length,
380
+ # self._config.hidden_size // self.num_attention_heads,
381
+ # )
382
+ # ordered_inputs["past_key_values"] = [
383
+ # (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
384
+ # ]
385
+ raise NotImplementedError('past_key_values is not supported yet (position_ids handling would need to change).')
386
+
387
+ # Need to add the attention_mask manually
388
+ # 1. add attention_mask
389
+ ordered_inputs["attention_mask"] = self.get_masks(common_inputs["input_ids"])
390
+ # 2. add position_ids
391
+ MASK, gMASK = self._config.mask_token_id, self._config.gmask_token_id
392
+ seqs = common_inputs["input_ids"].tolist()
393
+ mask_positions, use_gmasks = [], []
394
+ for seq in seqs:
395
+ mask_token = gMASK if gMASK in seq else MASK
396
+ use_gmask = mask_token == gMASK
397
+ mask_positions.append(seq.index(mask_token))
398
+ use_gmasks.append(use_gmask)
399
+ ordered_inputs["position_ids"] = self.get_position_ids(common_inputs["input_ids"],
400
+ mask_positions, use_gmasks=use_gmasks)
401
+
402
+ if self.use_past:
403
+ # mask_dtype = ordered_inputs["attention_mask"].dtype
404
+ # ordered_inputs["attention_mask"] = torch.cat(
405
+ # [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
406
+ # )
407
+ raise NotImplementedError('past_key_values is not supported yet (position_ids handling would need to change).')
408
+
409
+ return ordered_inputs
410
+
411
+ @property
412
+ def default_onnx_opset(self) -> int:
413
+ return 13