Commit 42d9ffb

fix: corrected minor code errors from PR
Signed-off-by: omobayode.fagbohungbe <[email protected]>
1 parent 3e7f29c commit 42d9ffb

File tree

3 files changed, +14 -6 lines changed


fms_mo/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 
 # Local
 from fms_mo.prep import qmodel_prep
-from fms_mo.utils.qconfig_utils import qconfig_init, qconfig_load
+from fms_mo.utils.qconfig_utils import qconfig_init
 
 VERSION_FALLBACK = "0.0.0"
 
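Note for downstream callers: after this change, qconfig_load is no longer re-exported from the fms_mo package root, so "from fms_mo import qconfig_load" would stop working. Assuming the function itself still lives in fms_mo.utils.qconfig_utils (this commit only drops the unused re-export and an unused call-site import), it would be imported directly:

    from fms_mo.utils.qconfig_utils import qconfig_load
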
fms_mo/dq.py

Lines changed: 1 addition & 3 deletions
@@ -35,7 +35,7 @@
 import torch
 
 # Local
-from fms_mo import qconfig_init, qmodel_prep, qconfig_load
+from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.fx.utils import model_size_Wb
 from fms_mo.quant.ptq import (
     calibration_llm_1GPU,
@@ -214,7 +214,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     q_file = open('qcfg_llama.json', "r", encoding="utf-8")
     saved_qcfg = json.load(q_file)
     qcfg.update(saved_qcfg)
-    print(qcfg)
 
     qmodel_prep(
         model,
@@ -252,7 +251,6 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         model.save_pretrained(opt_args.output_dir, use_safetensors=True)
         tokenizer.save_pretrained(opt_args.output_dir)
     else:
-        pass
         from accelerate import load_checkpoint_and_dispatch
         model = load_checkpoint_and_dispatch( model, checkpoint=opt_args.output_dir, device_map=None, no_split_module_classes=['Block'])
 

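For context on the else branch in the last hunk: load_checkpoint_and_dispatch comes from Hugging Face accelerate and reloads a saved checkpoint into an already-constructed model, optionally sharding modules across devices. A minimal usage sketch follows; the model directory, AutoModelForCausalLM class, and device_map choice are illustrative assumptions, not code from this commit.

    # Sketch only: paths, model class, and device_map are illustrative assumptions.
    from accelerate import init_empty_weights, load_checkpoint_and_dispatch
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("./quantized-model")   # directory written by save_pretrained(...)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)       # build the module tree without allocating weights

    # Load the safetensors checkpoint back into the empty skeleton; no_split_module_classes
    # keeps each transformer block on a single device when a device_map is used.
    model = load_checkpoint_and_dispatch(
        model,
        checkpoint="./quantized-model",
        device_map="auto",                 # the diff passes device_map=None for manual placement
        no_split_module_classes=["Block"],
    )
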
fms_mo/prep.py

Lines changed: 12 additions & 2 deletions
@@ -535,7 +535,18 @@ def has_quantized_module(model):
     """Check if model is already quantized - do not want to quantize twice if so"""
     return any(isinstance(m, quantized_modules) for m in model.modules())
 
-def swap_qbmm(model, qcfg):
+def swap_qbmm(model: nn.Module, qcfg: dict):
+    """Go through all model.named_modules(), try to create an equivalent Qbmm layer to replace each of
+    the existing linear Bmm layers.
+
+    Args:
+        model (nn.Module): input model to be "prepared"
+        qcfg (dict): quant config
+
+    Returns: updated model is returned with the Qbmm added
+
+    """
+
     from fms_mo.modules import QBmm
 
     qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
@@ -650,7 +661,6 @@ def qmodel_prep(
     if mode:
 
         if qcfg.get("QBmm"):
-            pass
             swap_qbmm(model,qcfg)
 
         model = q_any_net_5(model, qcfg, verbose = False)

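The new swap_qbmm docstring describes a standard traverse-and-replace pattern over model.named_modules(). As a rough, generic illustration only (fms_mo's actual QBmm construction and bmm bookkeeping are more involved than this), a module-swap loop of that shape can look like the following; MyQuantizedBmm and swap_modules are hypothetical names, not part of fms_mo.

    # Generic sketch of a named_modules() swap loop; MyQuantizedBmm is a hypothetical
    # stand-in, not fms_mo's QBmm, and its constructor arguments are invented.
    import torch
    import torch.nn as nn

    class MyQuantizedBmm(nn.Module):
        """Placeholder for a quantized batched-matmul module."""
        def __init__(self, num_bits: int = 8):
            super().__init__()
            self.num_bits = num_bits

        def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            # A real implementation would fake-quantize a and b to num_bits first.
            return torch.bmm(a, b)

    def swap_modules(model: nn.Module, target_type: type, build_replacement) -> nn.Module:
        """Replace every child module of target_type with a freshly built replacement."""
        for _, parent in model.named_modules():
            for child_name, child in parent.named_children():
                if isinstance(child, target_type):
                    setattr(parent, child_name, build_replacement(child))
        return model
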