feat: Generic Rotation #117

Draft · wants to merge 2 commits into main

11 changes: 7 additions & 4 deletions fms_mo/modules/bmm.py
@@ -20,6 +20,7 @@

# Local
from fms_mo.quant.quantizers import Qbypass, Qdynamic, get_activation_quantizer
from fms_mo.quant.rotation import RotQuantWrapper


class QBmm(nn.Module):
@@ -131,8 +132,10 @@ def __init__(
)

self.calib_iterator = [] # To simplify update of clipvals in forward()
self.quantize_m1 = Qbypass()
self.quantize_calib_m1 = Qbypass()
quant_m1_def = Qbypass() if "rot_" not in self.qm1_mode else RotQuantWrapper()
Review comment (Collaborator):
Would it make more sense to do the if condition as RotQuantWrapper() if "rot_" in self.qm1_mode else Qbypass()?
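
For concreteness, a sketch of the ordering the comment suggests (module-level toy variables stand in for self.qm1_mode / self.qm2_mode; the imports match the ones added in this PR):

```python
from fms_mo.quant.quantizers import Qbypass
from fms_mo.quant.rotation import RotQuantWrapper

qm1_mode = "rot_maxsym"  # stand-in for self.qm1_mode
qm2_mode = "maxsym"      # stand-in for self.qm2_mode

# Suggested ordering: lead with the positive "rot_" check, fall back to Qbypass.
quant_m1_def = RotQuantWrapper() if "rot_" in qm1_mode else Qbypass()
quant_m2_def = RotQuantWrapper() if "rot_" in qm2_mode else Qbypass()
```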

quant_m2_def = Qbypass() if "rot_" not in self.qm2_mode else RotQuantWrapper()
self.quantize_m1 = quant_m1_def
self.quantize_calib_m1 = quant_m1_def
if self.num_bits_m1 not in [32, 16]:
self.quantize_m1 = get_activation_quantizer(
self.qm1_mode if (not m1_bounded or "fp8" in qm1_mode) else "minmax",
@@ -155,8 +158,8 @@
symmetric=self.symmetric,
)

self.quantize_m2 = Qbypass()
self.quantize_calib_m2 = Qbypass()
self.quantize_m2 = quant_m2_def
self.quantize_calib_m2 = quant_m2_def
if self.num_bits_m2 not in [32, 16]:
self.quantize_m2 = get_activation_quantizer(
self.qm2_mode if (not m2_bounded or "fp8" in qm2_mode) else "minmax",
11 changes: 7 additions & 4 deletions fms_mo/modules/linear.py
@@ -36,6 +36,7 @@
get_weight_quantizer,
mask_fc_kij,
)
from fms_mo.quant.rotation import RotQuantWrapper
from fms_mo.utils.import_utils import available_packages

if available_packages["triton"]:
@@ -158,8 +159,10 @@ def __init__(

self.calib_iterator = []
# To simplify update of clipvals in forward()
self.quantize_feature = Qbypass()
self.quantize_calib_feature = Qbypass()
quantA_default = Qbypass() if "rot_" not in self.qa_mode else RotQuantWrapper()
quantW_default = Qbypass() if "rot_" not in self.qw_mode else RotQuantWrapper()
self.quantize_feature = quantA_default
self.quantize_calib_feature = quantA_default
if self.num_bits_feature not in [32, 16]:
self.quantize_feature = get_activation_quantizer(
self.qa_mode,
@@ -187,8 +190,8 @@
quantizer2sync=self.quantize_feature,
)

self.quantize_weight = Qbypass()
self.quantize_calib_weight = Qbypass()
self.quantize_weight = quantW_default
self.quantize_calib_weight = quantW_default
if self.num_bits_weight not in [32, 16]:
self.quantize_weight = get_weight_quantizer(
self.qw_mode,
8 changes: 7 additions & 1 deletion fms_mo/prep.py
@@ -215,7 +215,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
base_params = {}
if hasattr(module, "__constants__"):
base_params = {k: getattr(module, k) for k in module.__constants__}
base_params["bias"] = module.bias is not None
base_params["bias"] = getattr(module, "bias", None) is not None
base_params["device"] = next(module.parameters()).device # usually cuda

module_output = module
@@ -480,6 +480,12 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
setattr(module_output, k, v)
module_output._all_weights = module._all_weights

# For nn.Embedding
elif isinstance(module, nn.Embedding):
# simplest case, only support rotation for now, no quantization
Qemb = mapping.get(nn.Embedding, nn.Embedding)
module_output = Qemb(module)

return module_output
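
As a side note on the new nn.Embedding branch above, here is a self-contained sketch of how it could be exercised. The EmbeddingRotWrapper class and the contents of mapping are assumptions for illustration only; the PR itself only shows that mapping may associate nn.Embedding with a callable that takes the original module:

```python
import torch
import torch.nn as nn


class EmbeddingRotWrapper(nn.Module):
    """Hypothetical rotation-only wrapper: keeps the original embedding and
    applies a rotation R to its output once R has been attached."""

    def __init__(self, emb: nn.Embedding):
        super().__init__()
        self.emb = emb
        self.R = None  # rotation tensor, expected to be set up later

    def forward(self, idx):
        out = self.emb(idx)
        return out @ self.R if self.R is not None else out


# Mirrors the elif branch above: look up a replacement class for nn.Embedding,
# falling back to plain nn.Embedding if nothing is registered in the mapping.
mapping = {nn.Embedding: EmbeddingRotWrapper}
module = nn.Embedding(1000, 64)
Qemb = mapping.get(nn.Embedding, nn.Embedding)
module_output = Qemb(module)
print(module_output(torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 64])
```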


97 changes: 80 additions & 17 deletions fms_mo/quant/quantizers.py
@@ -40,6 +40,9 @@
import torch.nn as nn # pylint: disable=consider-using-from-import
import torch.nn.functional as F

# Local
from fms_mo.quant.rotation import RotQuantWrapper

logger = logging.getLogger(__name__)


@@ -66,8 +69,16 @@ def get_activation_quantizer(
- pact/pact+/pactsym
- sawb/sawb+
- max

If qa_mode has a "rot_" prefix or "_rot" suffix, the quantizer is wrapped with RotQuantWrapper();
remember to set up the R_left and R_right tensors later.
"""

use_rot = False
if "rot_" in qa_mode or "_rot" in qa_mode:
use_rot = True
qa_mode = qa_mode.replace("rot_", "").replace("_rot", "")

if not use_swcap:
QPACTLUT = {
"pact_uni": PACT,
@@ -123,23 +134,27 @@
)
elif qa_mode == "dorefa":
act_quantizer = dorefa_quantize_activation
elif (
qa_mode == "max"
): # NOTE Need to be careful using this for activation, particular to 1 sided.
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
elif qa_mode == "minmax":
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
elif "max" in qa_mode:
# NOTE: Be careful when using this for activations, particularly the 1-sided case.
if "min" in qa_mode:
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
elif "pertoken" in qa_mode or "perToken" in qa_mode:
act_quantizer = QMaxDynamic(nbits, dim=-1)
elif "per_channel" in qa_mode or "perCh" in qa_mode:
act_quantizer = QMaxDynamic(nbits, dim=-2)
elif "sym" in qa_mode:
act_quantizer = Qmax(
nbits,
align_zero=True,
minmax=False,
extend_act_range=extend_act_range,
)
else:
act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
elif qa_mode == "fix":
act_quantizer = QFixSymmetric(
nbits, init_clip_val=clip_val, align_zero=align_zero
)
elif qa_mode == "maxsym":
act_quantizer = Qmax(
nbits,
align_zero=True,
minmax=False,
extend_act_range=extend_act_range,
)
elif qa_mode == "pactsym":
act_quantizer = PACT2Sym(
nbits,
Expand Down Expand Up @@ -179,8 +194,6 @@ def get_activation_quantizer(
perToken=perToken,
emulate=True,
)
elif qa_mode == "pertokenmax":
act_quantizer = PerTokenMax(nbits)
else:
raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
else: # swcap-compatible activation quantizers
@@ -220,6 +233,9 @@
f"activation quantization mode {qa_mode} is incompatible with swcap"
)

if use_rot:
act_quantizer = RotQuantWrapper(act_quantizer)

return act_quantizer
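
As context for the new rot handling in get_activation_quantizer, a standalone sketch of the same dispatch pattern with toy stand-ins (only the RotQuantWrapper name and the "rot_"/"_rot" convention come from this diff; everything else is illustrative):

```python
import torch
import torch.nn as nn


class FakeQuant(nn.Module):
    """Toy stand-in for a concrete fms-mo quantizer (illustration only)."""

    def forward(self, x):
        return x  # a real quantizer would fake-quantize here


class RotWrapper(nn.Module):
    """Toy stand-in for RotQuantWrapper: rotate the input, then apply the
    wrapped quantizer. R_left is attached after construction, as noted above."""

    def __init__(self, quantizer=None):
        super().__init__()
        self.quantizer = quantizer if quantizer is not None else nn.Identity()
        self.R_left = None

    def forward(self, x):
        if self.R_left is not None:
            x = x @ self.R_left
        return self.quantizer(x)


def build_act_quantizer(qa_mode: str) -> nn.Module:
    # Same flow as the new code above: strip the rotation marker, build the
    # base quantizer from the remaining mode string, then wrap it if requested.
    use_rot = "rot_" in qa_mode or "_rot" in qa_mode
    qa_mode = qa_mode.replace("rot_", "").replace("_rot", "")
    base = FakeQuant()  # the real code dispatches on qa_mode (pact, sawb, max, ...)
    return RotWrapper(base) if use_rot else base


q = build_act_quantizer("rot_maxsym")
q.R_left = torch.eye(8)
print(q(torch.randn(4, 8)).shape)  # torch.Size([4, 8])
```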


@@ -245,7 +261,15 @@
SWCAP quantizers:
- sawb/sawb+
- max
If qw_mode has a "rot_" prefix or "_rot" suffix, the quantizer is wrapped with RotQuantWrapper();
remember to set up the R_left and R_right tensors later.
"""

use_rot = False
if "rot_" in qw_mode or "_rot" in qw_mode:
use_rot = True
qw_mode = qw_mode.replace("rot_", "").replace("_rot", "")

weight_quantizer = None
if not use_swcap:
cggrad = "cgpact" in qw_mode
@@ -367,6 +391,9 @@
f"activation quantized mode {qw_mode} is incompatible with swcap"
)

if use_rot:
weight_quantizer = RotQuantWrapper(weight_quantizer)

return weight_quantizer


@@ -3470,7 +3497,7 @@ def __init__(self, num_bits):
"""
For per-token activation quantization using abs().max() as scale,
Zero is aligned so that the levels are symmetric around zero (lossing one level)
Since the token length is un-known before running, the quatnization is dynamic, meaning
Since the token length is unknown before running, the quantization is dynamic, meaning
no trainable quantization scales and the scales are computed at run time.
"""
super().__init__()
@@ -3487,6 +3514,42 @@ def __repr__(self):
return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


class QMaxDynamic(nn.Module):
def __init__(self, num_bits, dim=-1):
"""
For per-token or per-channel quantization using abs().max() as scale, usually for activation
and could be used for Qbmm M2 as well.
(reduce) dim = -1 -> abs() will output a column vector (if input is 2D) => per token
dim = -2 -> per-channel
Zero is aligned so that the levels are symmetric around zero (losing one level).
Since the token length is unknown before running, the quantizer can only calculate the
scales dynamically at run time, meaning no trainable quantization scales are allowed
(unless the input seq length is always the same, not just padded to a fixed length).
"""
super().__init__()
self.num_bits = num_bits
self.levels = 2 ** (self.num_bits - 1) - 1
if isinstance(dim, str):
if "perCh" in dim or "per_channel" in dim:
dim = -2
elif "perToken" in dim or "per_token" in dim or "per_Token" in dim:
dim = -1
# after normalizing possible string aliases, dim must be -1 or -2
if dim in [-1, -2]:
self.reduce_dim = dim
else:
raise ValueError(
f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}"
)

def forward(self, input_tensor):
amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0]
scales = amax_dim.clamp(min=1e-5).div(self.levels)
return input_tensor.div(scales).round().mul(scales)

def __repr__(self):
return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)"


class Qdynamic(nn.Module):
def __init__(
self,
@@ -4560,7 +4623,7 @@ def forward(self, x_orig):

class Qbypass(nn.Module):
"""
no quantization at all, straight-thru
No quantization at all, output the input_tensor directly.
in place of lambda function when using nbits=32 and 16.
to avoid issue when pickle (ie torch.save) of lambda
(seems to be a problem only for DDP)