BigMaoGoGoGo committed
Commit ac792e3
1 Parent(s): 8fe54fa

fix gpu cache

Files changed (1)
  1. gptq_quantization.py +39 -47
gptq_quantization.py CHANGED
@@ -138,38 +138,34 @@ class GPTQLayerWrapper:
 
         if is_transformer_conv1d(self.layer):
             Q = Q.t()
-        self.layer.weight = nn.Parameter(Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype), requires_grad=False)
-
+        shape = self.layer.weight.shape
+        dtype = self.layer.weight.data.dtype
+        del self.layer.weight
+        setattr(self.layer, "weight", nn.Parameter(Q.reshape(shape).to(dtype), requires_grad=False))
         del self.H
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-    def release_gpu_memory(self):
-        if hasattr(self, "H"):
-            del self.H
 
 
 class GPTQBlockWrapper:
-    def __init__(self, module_name: str, module: nn.Module, weight_bit_width=8):
+    def __init__(self, block_name: str, block: nn.Module, weight_bit_width=8):
         self.layer_wrappers = {}
         self.hook_handles = []
-        # module order in the whole network
+        # block order in the whole network
         self.order = 0
-        self.module_name = module_name
+        self.block_name = block_name
 
         def get_hook(layer_name):
             def record_hook(_, x):
                 self.layer_wrappers[layer_name].record_h(x[0])
             return record_hook
 
-        for layer_name, layer in module.named_modules():
+        for layer_name, layer in block.named_modules():
             if isinstance(layer, tuple(QUANT_LAYERS)):
-                full_layer_name = f"{module_name}.{layer_name}" if layer_name else f"{module_name}"
+                full_layer_name = f"{block_name}.{layer_name}" if layer_name else f"{block_name}"
                 self.layer_wrappers[full_layer_name] = GPTQLayerWrapper(full_layer_name, layer, weight_bit_width)
                 handle = layer.register_forward_pre_hook(get_hook(full_layer_name))
                 self.hook_handles.append(handle)
 
-    def quant_module(self):
+    def quant_block(self):
         for _, wrapper in self.layer_wrappers.items():
             wrapper.quant_weight()
 
@@ -190,10 +186,6 @@ class GPTQBlockWrapper:
         for n, l in self.layer_wrappers.items():
             l.is_record = False
 
-    def release_gpu_memory(self):
-        for _, wrapper in self.layer_wrappers.items():
-            wrapper.release_gpu_memory()
-
 
 class GPTQuantizer:
     def __init__(self, block_type: Optional[List[type]] = None):
@@ -207,19 +199,13 @@ class GPTQuantizer:
                 child_prefix = f"{prefix}.{name}" if prefix else name
                 if isinstance(child, tuple(self.block_type)):
                     self.gptq_block_wrappers[name] = GPTQBlockWrapper(child_prefix, child, weight_bit_width)
-                    LOGGER.debug(f"Calibrate module {child_prefix} as a whole block in GPTQ")
+                    LOGGER.debug(f"Calibrate block {child_prefix} as a whole block in GPTQ")
                 else:
                     wrap_block(child, child_prefix)
 
         wrap_block(model)
         return model
 
-    def quantize(self, model: nn.Module):
-        for _, module_wrapper in self.gptq_block_wrappers.items():
-            module_wrapper.quant_module()
-
-        return model
-
     @property
     def calibration_iters(self):
         return len(self.gptq_block_wrappers)
@@ -230,56 +216,59 @@ class GPTQuantizer:
         record_handles = []
         orders = {}
         try:
-            def get_record_order_hook(module_name):
+            def get_record_order_hook(block_name):
                 def record_hook(*args, **kwargs):
                     nonlocal counter
-                    if module_name not in orders:
-                        orders[module_name] = counter
+                    if block_name not in orders:
+                        orders[block_name] = counter
                     counter += 1
                 return record_hook
 
-            for module_name, module_wrapper in self.gptq_block_wrappers.items():
+            for block_name, block_wrapper in self.gptq_block_wrappers.items():
                 # disable the record
-                for _, layer_wrapper in module_wrapper.layer_wrappers.items():
+                for _, layer_wrapper in block_wrapper.layer_wrappers.items():
                     layer_wrapper.is_record = False
 
-                one_layer_wrapper_in_module = list(module_wrapper.layer_wrappers.values())[0]
-                handles = one_layer_wrapper_in_module.layer.register_forward_pre_hook(get_record_order_hook(module_name))
+                one_layer_wrapper_in_block = list(block_wrapper.layer_wrappers.values())[0]
+                handles = one_layer_wrapper_in_block.layer.register_forward_pre_hook(get_record_order_hook(block_name))
                 record_handles.append(handles)
             yield
         except Exception as e:
             logging.warning(e)
         finally:
-            for module_name, order in orders.items():
-                self.gptq_block_wrappers[module_name].set_order(order)
+            for block_name, order in orders.items():
+                self.gptq_block_wrappers[block_name].set_order(order)
 
             for h in record_handles:
                 h.remove()
 
-            for module_name, module_wrapper in self.gptq_block_wrappers.items():
+            for _, block_wrapper in self.gptq_block_wrappers.items():
                 # disable the record
-                for _, layer_wrapper in module_wrapper.layer_wrappers.items():
+                for _, layer_wrapper in block_wrapper.layer_wrappers.items():
                     layer_wrapper.is_record = True
 
 
     @contextlib.contextmanager
     def start_calib_iter(self, i):
         assert i < len(self.gptq_block_wrappers)
-        target_module_wrapper = None
+        target_block_wrapper = None
         try:
-            for _, module_wrapper in self.gptq_block_wrappers.items():
-                if module_wrapper.get_order() == i:
-                    module_wrapper.enable()
-                    target_module_wrapper = module_wrapper
+            for _, block_wrapper in self.gptq_block_wrappers.items():
+                if block_wrapper.get_order() == i:
+                    block_wrapper.enable()
+                    target_block_wrapper = block_wrapper
                 else:
-                    module_wrapper.disable()
+                    block_wrapper.disable()
             yield
         finally:
-            target_module_wrapper.quant_module()
+            target_block_wrapper.quant_block()
 
-    def release_gpu_memory(self):
-        for block_name, block_wrapper in self.gptq_block_wrappers.items():
-            block_wrapper.release_gpu_memory()
+    def release_reference(self):
+        # delete reference so that `torch.cuda.empty_cache()` can
+        # release all the gpu memory cache used during calibration
+        for _, block_wrapper in self.gptq_block_wrappers.items():
+            for _, layer_wrapper in block_wrapper.layer_wrappers.items():
+                del layer_wrapper.layer
 
         torch.cuda.empty_cache()
 
@@ -301,10 +290,12 @@ def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
     calib_model = quantizer.wrap_model(model, weight_bit_width)
     with quantizer.record_order():
         calib_model.chat(tokenizer, calib_data[0], history=[])
+
     logging.info("Start doing calibration using GPTQ ")
     for i in range(quantizer.calibration_iters):
         logging.info(f"Process: {i + 1}/{quantizer.calibration_iters}")
         # todo: should add early return to speed up the calibration
+        # todo: add cpu offload to reduce the gpu memory requirements.
         with quantizer.start_calib_iter(i):
             for prompt in calib_data:
                 model.chat(tokenizer, prompt, history=[])
@@ -328,5 +319,6 @@ def gptq_quantize(model, tokenizer, weight_bit_width, calib_data):
         )
         parent.add_module(name_in_parent, quantized_layer)
 
-    torch.cuda.empty_cache()
+    # release the memory cache during calibration
+    quantizer.release_reference()
     return
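
Note: `torch.cuda.empty_cache()` only returns cached allocator blocks that no live tensor still references, which is why this commit drops the wrappers' references (the recorded Hessian and the wrapped layer) before emptying the cache. The following is a minimal, self-contained sketch of that pattern for illustration only; it is not code from this commit, and the `DummyWrapper` / `release_references` names are invented:

import torch
import torch.nn as nn


class DummyWrapper:
    # Stands in for a layer wrapper: it keeps calibration state alive on the GPU.
    def __init__(self, layer: nn.Module):
        self.layer = layer
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # e.g. Hessian statistics accumulated while recording activations
        self.H = torch.zeros(16, 16, device=device)


def release_references(wrappers):
    # Drop the Python references first so the tensors become unreachable ...
    for w in wrappers:
        if hasattr(w, "H"):
            del w.H
        if hasattr(w, "layer"):
            del w.layer
    # ... then ask the caching allocator to hand the freed blocks back to the driver.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()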