
Commit 22ce956

refine mllm API and add help info (#334)
1 parent 8021793 commit 22ce956

File tree (8 files changed: +93 -61 lines):

auto_round/__main__.py
auto_round/mllm/autoround_mllm.py
auto_round/mllm/eval.py
auto_round/mllm/processor.py
auto_round/mllm/template.py
auto_round/script/mllm.py
test/test_basic_usage.py
test/test_mllm.py

auto_round/__main__.py (13 additions, 9 deletions)

@@ -33,26 +33,30 @@ def run_fast():
 
 
 def run_mllm():
-    from auto_round.script.mllm import setup_parser, tune, eval
-    args = setup_parser()
-    if args.eval:
+    if "--eval" in sys.argv:
+        from auto_round.script.mllm import setup_lmeval_parser, eval
+        sys.argv.remove("--eval")
+        args = setup_lmeval_parser()
         eval(args)
+    elif "--lmms" in sys.argv:
+        sys.argv.remove("--lmms")
+        run_lmms()
     else:
+        from auto_round.script.mllm import setup_parser, tune
+        args = setup_parser()
         tune(args)
 
 def run_lmms():
-    from transformers.utils.versions import require_version
-    require_version("lmms_eval", "lmms_eval need to be installed, `pip install lmms_eval`")
     # from auto_round.script.lmms_eval import setup_lmms_args, eval
     from auto_round.script.mllm import setup_lmms_parser, lmms_eval
     args = setup_lmms_parser()
     lmms_eval(args)
 
 def switch():
-    if "--lmms" in sys.argv:
-        sys.argv.remove("--lmms")
-        run_lmms()
-    elif "--mllm" in sys.argv:
+    # if "--lmms" in sys.argv:
+    #     sys.argv.remove("--lmms")
+    #     run_lmms()
+    if "--mllm" in sys.argv:
         sys.argv.remove("--mllm")
         run_mllm()
     else:
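The refactored run_mllm dispatches on raw sys.argv before any parser is constructed, so each sub-mode (--eval, --lmms, default tuning) owns its complete argparse surface without flag collisions, and imports happen only for the chosen path. A minimal, self-contained sketch of the dispatch pattern (the run_* bodies are illustrative stand-ins, not the auto_round API):

import sys

def run_eval():
    print("evaluating with the dedicated eval parser")

def run_lmms():
    print("evaluating with lmms_eval")

def run_tune():
    print("tuning with the default parser")

def switch():
    # Peek at raw argv before argparse runs, then strip the mode flag
    # so the selected sub-parser never sees an unknown argument.
    if "--eval" in sys.argv:
        sys.argv.remove("--eval")
        run_eval()
    elif "--lmms" in sys.argv:
        sys.argv.remove("--lmms")
        run_lmms()
    else:
        run_tune()

if __name__ == "__main__":
    switch()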

auto_round/mllm/autoround_mllm.py (4 additions, 3 deletions)

@@ -96,7 +96,8 @@ def __init__(
             self,
             model,
             tokenizer,
-            image_processor=None,
+            processor = None,
+            image_processor = None,
             bits: int = 4,
             group_size: int = 128,
             sym: bool = False,
@@ -143,8 +144,8 @@ def __init__(
         self.image_processor = image_processor
         self.template = template if template is not None else model.config.model_type
         self.template = get_template(
-            self.template, model=model, tokenizer=tokenizer, image_processor=image_processor)
-
+            self.template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
+
         dataset = self.template.default_dataset if dataset is None else dataset
         from ..calib_dataset import CALIB_DATASETS
         if truncation is None:
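With the new parameter, callers hand the HF processor to AutoRoundMLLM directly instead of monkey-patching it onto the tokenizer. A hedged usage sketch mirroring the updated tests at the bottom of this commit (the Qwen2-VL checkpoint name is illustrative):

from transformers import (AutoProcessor, AutoTokenizer,
                          Qwen2VLForConditionalGeneration)
from auto_round import AutoRoundMLLM

model_name = "Qwen/Qwen2-VL-2B-Instruct"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name, trust_remote_code=True, device_map="auto")

# processor is now a first-class argument; no tokenizer.processor hack
autoround = AutoRoundMLLM(model, tokenizer, processor=processor,
                          bits=4, group_size=128, iters=2, nsamples=1)
autoround.quantize()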

auto_round/mllm/eval.py (2 additions, 1 deletion)

@@ -350,7 +350,8 @@ def lmms_eval(
         apply_chat_template=False
 ):
     from auto_round import AutoRoundConfig
-
+    from transformers.utils.versions import require_version
+    require_version("lmms_eval", "lmms_eval need to be installed, `pip install lmms_eval`")
     if isinstance(tasks, str):
         tasks = tasks.replace(' ', '').split(',')
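Moving the require_version guard out of run_lmms() and into lmms_eval() defers the hard dependency until evaluation actually runs, so plain --mllm tuning no longer needs lmms_eval installed. The guard itself is a stock transformers utility:

from transformers.utils.versions import require_version

# Fails with the hint below if the package is absent; a pinned spec
# such as "lmms_eval>=0.2" is also accepted by require_version.
require_version("lmms_eval",
                "lmms_eval need to be installed, `pip install lmms_eval`")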

auto_round/mllm/processor.py (4 additions, 3 deletions)

@@ -32,10 +32,11 @@ def register(processor):
 class BasicProcessor:
     def __init__(self):
         pass
-
-    def post_init(self, model, tokenizer, image_processor=None, **kwargs):
+
+    def post_init(self, model, tokenizer, processor=None, image_processor=None, **kwargs):
         self.model = model
         self.tokenizer = tokenizer
+        self.processor = processor
         if image_processor is not None:
             self.image_processor = image_processor
         else:
@@ -76,7 +77,7 @@ def get_input(
         if truncation is True and truncation_strategy == "text":
             text = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])
 
-        ret = self.tokenizer.processor(
+        ret = self.processor(
             text=text,
             images=images,
             return_tensors=return_tensors,
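The net effect of these two hunks: get_input now reads the processor from an explicit attribute set in post_init, rather than from an attribute that tune() used to monkey-patch onto the tokenizer. A minimal stand-in illustrating the new flow (not the auto_round class itself):

class ProcessorHolder:
    """Stand-in for BasicProcessor: the processor travels explicitly."""

    def post_init(self, model, tokenizer, processor=None,
                  image_processor=None, **kwargs):
        self.model = model
        self.tokenizer = tokenizer
        self.processor = processor  # explicit attribute, no tokenizer hack
        self.image_processor = image_processor

    def get_input(self, text, images, return_tensors="pt"):
        # Delegate multimodal batching to the processor supplied above.
        return self.processor(text=text, images=images,
                              return_tensors=return_tensors)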

auto_round/mllm/template.py (2 additions, 2 deletions)

@@ -147,7 +147,7 @@ def _load_preset_template():
 _load_preset_template()
 
 
-def get_template(template_or_path: str, model=None, tokenizer=None, image_processor=None):
+def get_template(template_or_path: str, model=None, tokenizer=None, processor=None, image_processor=None):
     """Get template by template name or from a json file.
 
     Args:
@@ -166,6 +166,6 @@ def get_template(template_or_path: str, model=None, tokenizer=None, image_proces
         logger.warning(f"Unable to recognize {template_or_path}, using default template instead.")
         template = TEMPLATES["default"]
 
-    template.processor.post_init(model=model, tokenizer=tokenizer, image_processor=image_processor)
+    template.processor.post_init(model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
 
     return template
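The template layer simply threads the new keyword through to BasicProcessor.post_init. A hedged call sketch, reusing model, tokenizer, and processor from the AutoRoundMLLM example above ("qwen2_vl" is an illustrative template name):

from auto_round.mllm.template import get_template

template = get_template("qwen2_vl",           # template name or json path
                        model=model,
                        tokenizer=tokenizer,
                        processor=processor)  # new keyword in this commit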

auto_round/script/mllm.py (51 additions, 38 deletions)

@@ -160,37 +160,7 @@ def __init__(self, *args, **kwargs):
         self.add_argument("--to_quant_block_names", default=None, type=str,
                           help="Names of quantitative blocks, please use commas to separate them.")
 
-        ## ======================= VLM eval=======================
-        self.add_argument("--tasks", type=str,
-                          default="MMBench_DEV_EN_V11,ScienceQA_VAL,TextVQA_VAL,POPE",
-                          help="eval tasks for VLMEvalKit.")
-        # Args that only apply to Video Dataset
-        self.add_argument("--nframe", type=int, default=8,
-                          help="the number of frames to sample from a video,"
-                               " only applicable to the evaluation of video benchmarks.")
-        self.add_argument("--pack", action='store_true',
-                          help="a video may associate with multiple questions, if pack==True,"
-                               " will ask all questions for a video in a single")
-        self.add_argument("--use-subtitle", action='store_true')
-        self.add_argument("--fps", type=float, default=-1)
-        # Work Dir
-        # Infer + Eval or Infer Only
-        self.add_argument("--mode", type=str, default='all', choices=['all', 'infer'],
-                          help="when mode set to 'all', will perform both inference and evaluation;"
-                               " when set to 'infer' will only perform the inference.")
-        self.add_argument('--eval_data_dir', type=str, default=None,
-                          help='path for VLMEvalKit to store the eval data. Default will store in ~/LMUData')
-        # API Kwargs, Apply to API VLMs and Judge API LLMs
-        self.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
-        # Explicitly Set the Judge Model
-        self.add_argument('--judge', type=str, default=None)
-        # Logging Utils
-        self.add_argument('--verbose', action='store_true')
-        # Configuration for Resume
-        # Ignore: will not rerun failed VLM inference
-        self.add_argument('--ignore', action='store_true', help='ignore failed indices. ')
-        # Rerun: will remove all evaluation temp files
-        self.add_argument('--rerun', action='store_true')
+
 
 
 def setup_parser():
@@ -215,6 +185,50 @@ def setup_parser():
     return args
 
 
+def setup_lmeval_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", "--model_name", "--model_name_or_path",
+                        help="model name or path")
+    parser.add_argument("--tasks", type=str,
+                        default="MMBench_DEV_EN_V11,ScienceQA_VAL,TextVQA_VAL,POPE",
+                        help="eval tasks for VLMEvalKit.")
+    # Args that only apply to Video Dataset
+    parser.add_argument("--nframe", type=int, default=8,
+                        help="the number of frames to sample from a video,"
+                             " only applicable to the evaluation of video benchmarks.")
+    parser.add_argument("--pack", action='store_true',
+                        help="a video may associate with multiple questions, if pack==True,"
+                             " will ask all questions for a video in a single")
+    parser.add_argument("--fps", type=float, default=-1,
+                        help="set the fps for a video.")
+    # Work Dir
+    # Infer + Eval or Infer Only
+    parser.add_argument("--mode", type=str, default='all', choices=['all', 'infer'],
+                        help="when mode set to 'all', will perform both inference and evaluation;"
+                             " when set to 'infer' will only perform the inference.")
+    parser.add_argument('--eval_data_dir', type=str, default=None,
+                        help='path for VLMEvalKit to store the eval data. Default will store in ~/LMUData')
+    # API Kwargs, Apply to API VLMs and Judge API LLMs
+    parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
+    # Explicitly Set the Judge Model
+    parser.add_argument('--judge', type=str, default=None,
+                        help="whether is a judge model.")
+    # Logging Utils
+    parser.add_argument('--verbose', action='store_true',
+                        help="whether to display verbose information.")
+    # Configuration for Resume
+    # Ignore: will not rerun failed VLM inference
+    parser.add_argument('--ignore', action='store_true',
+                        help='ignore failed indices. ')
+    # Rerun: will remove all evaluation temp files
+    parser.add_argument('--rerun', action='store_true',
+                        help="if true, will remove all evaluation temp files and rerun.")
+    parser.add_argument("--output_dir", default="./eval_result", type=str,
+                        help="the directory to save quantized model")
+    args = parser.parse_args()
+    return args
+
+
 def tune(args):
     if args.format is None:
         args.format = "auto_round"
@@ -265,14 +279,14 @@ def tune(args):
     processor, image_processor = None, None
     if "llava" in model_name:
         from llava.model.builder import load_pretrained_model # pylint: disable=E0401
-        tokenizer, model, image_processor, _ = load_pretrained_model(model_name, model_base=None, model_name=model_name,
-                                                                     torch_dtype=torch_dtype)
+        tokenizer, model, image_processor, _ = load_pretrained_model(
+            model_name, model_base=None, model_name=model_name,
+            torch_dtype=torch_dtype)
         model_type = "llava"
     else:
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        tokenizer.processor = processor
         model_type = config.model_type
         if "qwen2_vl" in model_type:
             from transformers import Qwen2VLForConditionalGeneration
@@ -361,7 +375,7 @@ def tune(args):
     if "--truncation" not in sys.argv:
         args.truncation = None
 
-    autoround = round(model, tokenizer, image_processor=image_processor, dataset=args.dataset,
+    autoround = round(model, tokenizer, processor=processor, image_processor=image_processor, dataset=args.dataset,
                       extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size,
                       sym=not args.asym, batch_size=args.batch_size, seqlen=seqlen, nblocks=args.nblocks,
                       iters=args.iters, lr=args.lr, minmax_lr=args.minmax_lr, amp=not args.disable_amp,
@@ -406,7 +420,6 @@ def eval(args):
         data_store_dir=args.eval_data_dir,
         dataset=args.tasks,
         pack=args.pack,
-        use_subtitle=args.use_subtitle,
         fps=args.fps,
         nframe=args.nframe,
         rerun=args.rerun,
@@ -426,8 +439,8 @@ def setup_lmms_parser():
         default="pope,textvqa_val,scienceqa,mmbench_en",
         help="To get full list of tasks, use the command lmms-eval --tasks list",
     )
-    parser.add_argument("--output_dir", default="./tmp_autoround", type=str,
-                        help="the directory to save quantized model")
+    parser.add_argument("--output_dir", default="./eval_result", type=str,
+                        help="the directory to save quantized model")
     parser.add_argument(
         "--num_fewshot",
         type=int,
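Taken together, the commit leaves auto_round with three distinct --mllm entry points. A hedged sketch of the command-line surface, written in the same os.system style as the tests below (model path and task lists are illustrative):

import os

# quantization tuning (default path, parsed by setup_parser)
os.system("python -m auto_round --mllm --iter 2 --nsamples 10 "
          "--format auto_round --output_dir ./saved")

# VLMEvalKit evaluation via the new setup_lmeval_parser
os.system("python -m auto_round --mllm --eval "
          "--model Qwen/Qwen2-VL-2B-Instruct "
          "--tasks MMBench_DEV_EN_V11,POPE --output_dir ./eval_result")

# lmms_eval evaluation via setup_lmms_parser
os.system("python -m auto_round --mllm --lmms "
          "--tasks pope,textvqa_val --output_dir ./eval_result")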

test/test_basic_usage.py (13 additions, 0 deletions)

@@ -32,11 +32,24 @@ def test_auto_round_cmd(self):
 
 
         # test mllm script
+        # test auto_round_mllm help
         res = os.system(
             f"cd .. && {python_path} -m auto_round --mllm -h")
         if res > 0 or res == -1:
             assert False, "cmd line test fail, please have a check"
 
+        # test auto_round_mllm --eval help
+        res = os.system(
+            f"cd .. && {python_path} -m auto_round --mllm --eval -h")
+        if res > 0 or res == -1:
+            assert False, "cmd line test fail, please have a check"
+
+        # test auto_round_mllm --lmms help
+        res = os.system(
+            f"cd .. && {python_path} -m auto_round --mllm --lmms -h")
+        if res > 0 or res == -1:
+            assert False, "cmd line test fail, please have a check"
+
         res = os.system(
             f"cd .. && {python_path} -m auto_round --mllm --iter 2 --nsamples 10 --format auto_round --output_dir ./saved")
         if res > 0 or res == -1:

test/test_mllm.py (4 additions, 5 deletions)

@@ -42,12 +42,12 @@ def tearDownClass(self):
     def test_tune(self):
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
-        tokenizer.processor = processor
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_name, trust_remote_code=True, device_map="auto")
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
-            model, tokenizer, bits=bits, group_size=group_size,
+            model, tokenizer, processor=processor,
+            bits=bits, group_size=group_size,
             nsamples=1,
             batch_size=1, iters=2, dataset=self.dataset,seqlen=256)
         autoround.quantize()
@@ -57,12 +57,12 @@ def test_quant_vision(self): ## bug need to fix
     def test_quant_vision(self): ## bug need to fix
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
-        tokenizer.processor = processor
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_name, trust_remote_code=True, device_map="auto")
         bits, group_size = 4, 128
         autoround = AutoRoundMLLM(
-            model, tokenizer, bits=bits, group_size=group_size,
+            model, tokenizer, processor=processor,
+            bits=bits, group_size=group_size,
             nsamples=5,
             batch_size=3, iters=2, dataset=self.dataset, quant_nontext_module=False,seqlen=256)
         autoround.quantize()
@@ -72,7 +72,6 @@ def test_quant_block_names(self):
         from auto_round.utils import get_multimodal_block_names,find_matching_blocks
         tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         processor = AutoProcessor.from_pretrained(self.model_name, trust_remote_code=True)
-        tokenizer.processor = processor
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             self.model_name, trust_remote_code=True, device_map="auto")
         to_quant_block_names = 'visual.*12,layers.0,model.layers.*9'
