Skip to content

Commit e4528e9

Browse files
authored
save processor automatically (#372)
1 parent 3ac377b commit e4528e9

File tree

5 files changed

+44
-28
lines changed

5 files changed

+44
-28
lines changed

auto_round/__main__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ def run_lmms():
5353
lmms_eval(args)
5454

5555
def switch():
56-
# if "--lmms" in sys.argv:
57-
# sys.argv.remove("--lmms")
58-
# run_lmms()
5956
if "--mllm" in sys.argv:
6057
sys.argv.remove("--mllm")
6158
run_mllm()

auto_round/autoround.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,9 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
12631263
self.model.save_pretrained(output_dir)
12641264
if self.tokenizer is not None:
12651265
self.tokenizer.save_pretrained(output_dir)
1266+
processor = kwargs.get("processor", None)
1267+
if processor is not None:
1268+
processor.save_pretrained(output_dir)
12661269
return
12671270

12681271
from auto_round.export import EXPORT_FORMAT

auto_round/mllm/autoround_mllm.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def _only_text_test(model, tokenizer, device):
3838
tokenizer.padding_side = 'left'
3939
if tokenizer.pad_token is None:
4040
tokenizer.pad_token = tokenizer.eos_token
41-
if device != model.device.type:
41+
if device.split(':')[0] != model.device.type:
4242
model = model.to(device)
4343
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
4444
model(**inputs)
@@ -150,19 +150,20 @@ def __init__(
150150
self.to_quant_block_names = to_quant_block_names
151151
self.extra_data_dir = extra_data_dir
152152
self.quant_nontext_module = quant_nontext_module
153+
self.processor = processor
153154
self.image_processor = image_processor
154155
self.template = template if template is not None else model.config.model_type
155156
if not isinstance(dataset, torch.utils.data.DataLoader):
156157
self.template = get_template(
157158
self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
158-
159-
dataset = self.template.default_dataset if dataset is None else dataset
159+
dataset = self.template.default_dataset if dataset is None else dataset
160160

161161
from ..calib_dataset import CALIB_DATASETS
162162
from .mllm_dataset import MLLM_DATASET
163163
if isinstance(dataset, str):
164164
if quant_nontext_module or \
165-
(dataset in CALIB_DATASETS.keys() and not _only_text_test(model, tokenizer, device)):
165+
(dataset in CALIB_DATASETS.keys() and not \
166+
_only_text_test(model, tokenizer, device)):
166167
if quant_nontext_module:
167168
logger.warning(f"Text only dataset cannot be used for calibrating non-text modules,"
168169
"switching to liuhaotian/llava_conv_58k")
@@ -372,4 +373,20 @@ def calib(self, nsamples, bs):
372373
m = m.to("meta")
373374
# torch.cuda.empty_cache()
374375

376+
def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
377+
"""Save the quantized model to the specified output directory in the specified format.
378+
379+
Args:
380+
output_dir (str, optional): The directory to save the quantized model. Defaults to None.
381+
format (str, optional): The format in which to save the model. Defaults to "auto_round".
382+
inplace (bool, optional): Whether to modify the model in place. Defaults to True.
383+
**kwargs: Additional keyword arguments specific to the export format.
375384
385+
Returns:
386+
object: The compressed model object.
387+
"""
388+
if self.processor is not None and not hasattr(self.processor, "chat_template"):
389+
self.processor.chat_template = None
390+
compressed_model = super().save_quantized(
391+
output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs)
392+
return compressed_model

auto_round/mllm/template.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -118,24 +118,25 @@ def _register_template(
118118

119119
def load_template(path: str):
120120
"""Load template information from a json file."""
121-
data = json.load(open(path, "r"))
122-
if "model_type" not in data:
123-
data["model_type"] = "user_define"
124-
if "replace_tokens" in data and data["replace_tokens"] is not None:
125-
assert len(data["replace_tokens"]) % 2 == 0, \
126-
"the format of replace_tokens should be [old_tag1, replace_tag1, old_tag2, replace_tag2]"
127-
temp = []
128-
for i in range(0, len(data["replace_tokens"]), 2):
129-
temp.append((data["replace_tokens"][i], data["replace_tokens"][i + 1]))
130-
data["replace_tokens"] = temp
131-
if "processor" in data:
132-
assert data["processor"] in PROCESSORS.keys(), \
133-
"{} is not supported, current support: {}".format(data["processor"], ",".join(PROCESSORS.keys()))
134-
data["processor"] = PROCESSORS[data["processor"]]
135-
template = _register_template(
136-
**data
137-
)
138-
return template
121+
with open(path, "r") as file:
122+
data = json.load(file)
123+
if "model_type" not in data:
124+
data["model_type"] = "user_define"
125+
if "replace_tokens" in data and data["replace_tokens"] is not None:
126+
assert len(data["replace_tokens"]) % 2 == 0, \
127+
"the format of replace_tokens should be [old_tag1, replace_tag1, old_tag2, replace_tag2]"
128+
temp = []
129+
for i in range(0, len(data["replace_tokens"]), 2):
130+
temp.append((data["replace_tokens"][i], data["replace_tokens"][i + 1]))
131+
data["replace_tokens"] = temp
132+
if "processor" in data:
133+
assert data["processor"] in PROCESSORS.keys(), \
134+
"{} is not supported, current support: {}".format(data["processor"], ",".join(PROCESSORS.keys()))
135+
data["processor"] = PROCESSORS[data["processor"]]
136+
template = _register_template(
137+
**data
138+
)
139+
return template
139140

140141

141142
def _load_preset_template():

auto_round/script/mllm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -418,13 +418,11 @@ def tune(args):
418418
inplace = False if len(format_list) > 1 else True
419419
for format_ in format_list:
420420
eval_folder = f'{export_dir}-{format_}'
421-
if processor is not None and not hasattr(processor, "chat_template"):
422-
processor.chat_template = None
423421
safe_serialization = True
424422
if "phi3_v" in model_type:
425423
safe_serialization = False
426424
autoround.save_quantized(
427-
eval_folder, format=format_, inplace=inplace, processor=processor, safe_serialization=safe_serialization)
425+
eval_folder, format=format_, inplace=inplace, safe_serialization=safe_serialization)
428426

429427

430428
def eval(args):

0 commit comments

Comments (0)