Commit 8033dd0

support packing immediately in new quantization api to save ram usage (#466)
1 parent: 34336d7
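
The commit title refers to packing each layer into its compact quantized format right after that layer is quantized, instead of keeping every layer's intermediate tensors alive until a final packing pass at the end. A minimal conceptual sketch of that idea follows; it is not this repository's implementation, and the function names and toy 4-bit packing below are illustrative assumptions only.

# Conceptual sketch only: quantize_layer, pack_layer and the toy 4-bit packing
# are illustrative assumptions, not this repository's API.
def quantize_layer(weights, bits=4):
    # toy symmetric per-tensor quantization to signed ints
    scale = (max(abs(w) for w in weights) / (2 ** (bits - 1) - 1)) or 1.0
    qmin, qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return [min(qmax, max(qmin, round(w / scale))) for w in weights], scale

def pack_layer(qweights):
    # toy packing: two 4-bit values per byte
    pairs = zip(qweights[::2], qweights[1::2])
    return bytes((a & 0x0F) | ((b & 0x0F) << 4) for a, b in pairs)

def quantize_model(layers):
    packed = []
    for weights in layers:
        q, scale = quantize_layer(weights)
        # pack immediately: only this layer's unpacked ints are alive at a time,
        # instead of holding every layer's until a final packing pass
        packed.append((pack_layer(q), scale))
    return packed

print(quantize_model([[0.1, -0.2, 0.3, 0.05], [1.0, -1.0, 0.5, 0.25]]))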

File tree

15 files changed: +801 -375 lines

auto_round/auto_quantizer.py

Lines changed: 20 additions & 0 deletions
@@ -178,6 +178,10 @@ def merge_quantization_configs(
         if isinstance(quantization_config_from_args, (AutoRoundConfig)):
             logger.info(f"Loading quantized model in auto_round format.")
             tmp_backend = quantization_config["quant_method"]
+            if "auto-round" not in tmp_backend and "gptq" not in tmp_backend and "awq" not in tmp_backend:
+                logger.error("could not convert to auto_round format, currently only supports `gptq`,`awq` or "
+                             "`auto-round` format")
+                exit(-1)
             target_backend = quantization_config["backend"] if "backend" in quantization_config else "auto"
             if loading_attr_dict is not None and "backend" in loading_attr_dict:
                 target_backend = loading_attr_dict["backend"]
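
For reference, a standalone restatement of the guard added above (not an import from the repository): a checkpoint can only be converted to the auto_round format when its quant_method mentions gptq, awq, or auto-round; anything else hits the logger.error plus exit(-1) path.

# Standalone sketch of the new quant_method guard; mirrors the diff above.
def can_convert_to_auto_round(quant_method: str) -> bool:
    return any(fmt in quant_method for fmt in ("auto-round", "gptq", "awq"))

assert can_convert_to_auto_round("gptq")
assert can_convert_to_auto_round("awq")
assert not can_convert_to_auto_round("bitsandbytes")  # would log an error and exit(-1)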
@@ -470,6 +474,22 @@ def convert_model(self, model: nn.Module):
         extra_config = {}
         if hasattr(quantization_config, "extra_config"):
             extra_config = quantization_config.extra_config
+        if hasattr(quantization_config, "modules_in_block_to_quantize"):  # gptq format
+            modules_in_block_to_quantize_tmp = quantization_config.modules_in_block_to_quantize
+            modules_in_block_to_quantize = [item for sublist in modules_in_block_to_quantize_tmp for item in sublist]
+            for layer_name in layer_names:
+                quantized = False
+                for qname in modules_in_block_to_quantize:
+                    if qname in layer_name:
+                        quantized = True
+                        break
+                if not quantized:
+                    extra_config[layer_name] = {"bits": 16}
+        if hasattr(quantization_config, "modules_to_not_convert"):
+            for layer_name in quantization_config.modules_to_not_convert:
+                extra_config[layer_name] = {"bits": 16}
+
+

         layer_names += extra_config.keys()
         layer_names = list(set(layer_names))
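
The second hunk maps GPTQ's nested modules_in_block_to_quantize list onto auto_round's extra_config: any layer not matched by a listed module name, and any entry in modules_to_not_convert, is marked with bits=16, i.e. left unquantized. A small self-contained illustration with toy layer names (the names are examples, not taken from the repository):

def mark_unquantized_layers(layer_names, modules_in_block_to_quantize, modules_to_not_convert=()):
    extra_config = {}
    # GPTQ stores the modules as a list of lists; flatten it first.
    flat = [item for sublist in modules_in_block_to_quantize for item in sublist]
    for layer_name in layer_names:
        # a layer counts as quantized if any listed module name is a substring of its full name
        if not any(qname in layer_name for qname in flat):
            extra_config[layer_name] = {"bits": 16}
    for layer_name in modules_to_not_convert:
        extra_config[layer_name] = {"bits": 16}
    return extra_config

layers = ["model.layers.0.self_attn.q_proj", "model.layers.0.mlp.gate"]
modules = [["self_attn.q_proj", "self_attn.k_proj"], ["mlp.down_proj"]]
print(mark_unquantized_layers(layers, modules, modules_to_not_convert=["lm_head"]))
# {'model.layers.0.mlp.gate': {'bits': 16}, 'lm_head': {'bits': 16}}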
