diff --git a/fms_mo/modules/bmm.py b/fms_mo/modules/bmm.py
index aa16def..b50cc2a 100644
--- a/fms_mo/modules/bmm.py
+++ b/fms_mo/modules/bmm.py
@@ -20,6 +20,7 @@
 # Local
 from fms_mo.quant.quantizers import Qbypass, Qdynamic, get_activation_quantizer
+from fms_mo.quant.rotation import RotQuantWrapper
 
 
 class QBmm(nn.Module):
@@ -131,8 +132,10 @@ def __init__(
         )
 
         self.calib_iterator = []  # To simplify update of clipvals in forward()
-        self.quantize_m1 = Qbypass()
-        self.quantize_calib_m1 = Qbypass()
+        quant_m1_def = Qbypass() if "rot_" not in self.qm1_mode else RotQuantWrapper()
+        quant_m2_def = Qbypass() if "rot_" not in self.qm2_mode else RotQuantWrapper()
+        self.quantize_m1 = quant_m1_def
+        self.quantize_calib_m1 = quant_m1_def
         if self.num_bits_m1 not in [32, 16]:
             self.quantize_m1 = get_activation_quantizer(
                 self.qm1_mode if (not m1_bounded or "fp8" in qm1_mode) else "minmax",
@@ -155,8 +158,8 @@ def __init__(
                 symmetric=self.symmetric,
             )
 
-        self.quantize_m2 = Qbypass()
-        self.quantize_calib_m2 = Qbypass()
+        self.quantize_m2 = quant_m2_def
+        self.quantize_calib_m2 = quant_m2_def
         if self.num_bits_m2 not in [32, 16]:
             self.quantize_m2 = get_activation_quantizer(
                 self.qm2_mode if (not m2_bounded or "fp8" in qm2_mode) else "minmax",
diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py
index 26b383c..13b89a6 100644
--- a/fms_mo/modules/linear.py
+++ b/fms_mo/modules/linear.py
@@ -36,6 +36,7 @@
     get_weight_quantizer,
     mask_fc_kij,
 )
+from fms_mo.quant.rotation import RotQuantWrapper
 from fms_mo.utils.import_utils import available_packages
 
 if available_packages["triton"]:
@@ -158,8 +159,10 @@ def __init__(
 
         self.calib_iterator = []  # To simplify update of clipvals in forward()
-        self.quantize_feature = Qbypass()
-        self.quantize_calib_feature = Qbypass()
+        quantA_default = Qbypass() if "rot_" not in self.qa_mode else RotQuantWrapper()
+        quantW_default = Qbypass() if "rot_" not in self.qw_mode else RotQuantWrapper()
+        self.quantize_feature = quantA_default
+        self.quantize_calib_feature = quantA_default
         if self.num_bits_feature not in [32, 16]:
             self.quantize_feature = get_activation_quantizer(
                 self.qa_mode,
@@ -187,8 +190,8 @@ def __init__(
                 quantizer2sync=self.quantize_feature,
             )
 
-        self.quantize_weight = Qbypass()
-        self.quantize_calib_weight = Qbypass()
+        self.quantize_weight = quantW_default
+        self.quantize_calib_weight = quantW_default
         if self.num_bits_weight not in [32, 16]:
             self.quantize_weight = get_weight_quantizer(
                 self.qw_mode,
diff --git a/fms_mo/prep.py b/fms_mo/prep.py
index 0e8501f..25fc688 100644
--- a/fms_mo/prep.py
+++ b/fms_mo/prep.py
@@ -215,7 +215,7 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     base_params = {}
     if hasattr(module, "__constants__"):
         base_params = {k: getattr(module, k) for k in module.__constants__}
-    base_params["bias"] = module.bias is not None
+    base_params["bias"] = getattr(module, "bias", None) is not None
     base_params["device"] = next(module.parameters()).device  # usually cuda
 
     module_output = module
@@ -480,6 +480,12 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
             setattr(module_output, k, v)
         module_output._all_weights = module._all_weights
 
+    # For nn.Embedding
+    elif isinstance(module, nn.Embedding):
+        # simplest case, only support rotation for now, no quantization
+        Qemb = mapping.get(nn.Embedding, nn.Embedding)
+        module_output = Qemb(module)
+
     return module_output
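The `Qbypass`-vs-`RotQuantWrapper` default above means a module configured with a `rot_`-prefixed mode still applies its rotation even when the 16/32-bit setting skips quantization entirely. A minimal sketch of that selection logic, assuming the patched `fms_mo` package is importable (not part of the patch itself):

```python
# Minimal sketch: how a "rot_" mode changes the default pass-through module
# used by QLinear/QBmm when nbits is 16 or 32.
from fms_mo.quant.quantizers import Qbypass
from fms_mo.quant.rotation import RotQuantWrapper

def default_passthrough(mode: str):
    # Same expression the patched modules use for their *_def / *_default attributes
    return Qbypass() if "rot_" not in mode else RotQuantWrapper()

print(type(default_passthrough("max_perToken")))      # Qbypass: plain identity
print(type(default_passthrough("rot_max_perToken")))  # RotQuantWrapper, no inner quantizer
```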
diff --git a/fms_mo/quant/quantizers.py b/fms_mo/quant/quantizers.py
index c97dbfa..1c3ee5e 100644
--- a/fms_mo/quant/quantizers.py
+++ b/fms_mo/quant/quantizers.py
@@ -40,6 +40,9 @@
 import torch.nn as nn  # pylint: disable=consider-using-from-import
 import torch.nn.functional as F
 
+# Local
+from fms_mo.quant.rotation import RotQuantWrapper
+
 logger = logging.getLogger(__name__)
 
 
@@ -66,8 +69,16 @@ def get_activation_quantizer(
     - pact/pact+/pactsym
     - sawb/sawb+
     - max
+
+    If qa_mode has a "rot_" prefix or "_rot" suffix, wrap the quantizer with RotQuantWrapper();
+    remember to set up the R_left and R_right tensors later.
     """
 
+    use_rot = False
+    if "rot_" in qa_mode or "_rot" in qa_mode:
+        use_rot = True
+        qa_mode = qa_mode.replace("rot_", "").replace("_rot", "")
+
     if not use_swcap:
         QPACTLUT = {
             "pact_uni": PACT,
@@ -123,23 +134,27 @@ def get_activation_quantizer(
             )
         elif qa_mode == "dorefa":
             act_quantizer = dorefa_quantize_activation
-        elif (
-            qa_mode == "max"
-        ):  # NOTE Need to be careful using this for activation, particular to 1 sided.
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
-        elif qa_mode == "minmax":
-            act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+        elif "max" in qa_mode:
+            # NOTE Be careful when using this for activations, particularly one-sided ones.
+            if "min" in qa_mode:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=True)
+            elif "pertoken" in qa_mode or "perToken" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-1)
+            elif "per_channel" in qa_mode or "perCh" in qa_mode:
+                act_quantizer = QMaxDynamic(nbits, dim=-2)
+            elif "sym" in qa_mode:
+                act_quantizer = Qmax(
+                    nbits,
+                    align_zero=True,
+                    minmax=False,
+                    extend_act_range=extend_act_range,
+                )
+            else:
+                act_quantizer = Qmax(nbits, align_zero=align_zero, minmax=False)
         elif qa_mode == "fix":
             act_quantizer = QFixSymmetric(
                 nbits, init_clip_val=clip_val, align_zero=align_zero
             )
-        elif qa_mode == "maxsym":
-            act_quantizer = Qmax(
-                nbits,
-                align_zero=True,
-                minmax=False,
-                extend_act_range=extend_act_range,
-            )
         elif qa_mode == "pactsym":
             act_quantizer = PACT2Sym(
                 nbits,
@@ -179,8 +194,6 @@ def get_activation_quantizer(
                 perToken=perToken,
                 emulate=True,
             )
-        elif qa_mode == "pertokenmax":
-            act_quantizer = PerTokenMax(nbits)
         else:
             raise ValueError(f"unrecognized activation quantization mode {qa_mode}")
     else:  # swcap-compatible activation quantizers
@@ -220,6 +233,9 @@ def get_activation_quantizer(
             f"activation quantization mode {qa_mode} is incompatible with swcap"
         )
 
+    if use_rot:
+        act_quantizer = RotQuantWrapper(act_quantizer)
+
     return act_quantizer
 
 
@@ -245,7 +261,15 @@ def get_weight_quantizer(
     SWCAP quantizers:
     - sawb/sawb+
    - max
+    If qw_mode has a "rot_" prefix or "_rot" suffix, wrap the quantizer with RotQuantWrapper();
+    remember to set up the R_left and R_right tensors later.
     """
+
+    use_rot = False
+    if "rot_" in qw_mode or "_rot" in qw_mode:
+        use_rot = True
+        qw_mode = qw_mode.replace("rot_", "").replace("_rot", "")
+
     weight_quantizer = None
     if not use_swcap:
         cggrad = "cgpact" in qw_mode
@@ -367,6 +391,9 @@ def get_weight_quantizer(
             f"activation quantized mode {qw_mode} is incompatible with swcap"
         )
 
+    if use_rot:
+        weight_quantizer = RotQuantWrapper(weight_quantizer)
+
     return weight_quantizer
 
 
@@ -3470,7 +3497,7 @@ def __init__(self, num_bits):
         """
         For per-token activation quantization using abs().max() as scale,
         Zero is aligned so that the levels are symmetric around zero (lossing one level)
-        Since the token length is un-known before running, the quatnization is dynamic, meaning
+        Since the token length is unknown before running, the quantization is dynamic, meaning
         no trainable quantization scales and the scales are computed at run time.
""" super().__init__() @@ -3487,6 +3514,42 @@ def __repr__(self): return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)" +class QMaxDynamic(nn.Module): + def __init__(self, num_bits, dim=-1): + """ + For per-token or per-channel quantization using abs().max() as scale, usually for activation + and could be used for Qbmm M2 as well. + (reduce) dim = -1 -> abs() will output a column vector (if input is 2D) => per token + dim = -2 -> per-channel + Zero is aligned so that the levels are symmetric around zero (lossing one level) + Since the token length is un-known before running, the quantizater can only calculate the + scales at the run times dynamically, meaning no trainable quantization scales is allowed. + (unless input seq length is always the same, not just padded to a fixed length.) + """ + super().__init__() + self.num_bits = num_bits + self.levels = 2 ** (self.num_bits - 1) - 1 + if isinstance(dim, str): + if "perCh" in dim or "per_channel" in dim: + dim = -2 + elif "perToken" in dim or "per_token" in dim or "per_Token" in dim: + dim = -1 + elif dim in [-1, -2]: + self.reduce_dim = dim + else: + raise ValueError( + f"Reduce dim can only be [-1, -2] or ['perCh', 'perToken'] but found {dim}" + ) + + def forward(self, input_tensor): + amax_dim = input_tensor.abs().max(dim=self.reduce_dim, keepdim=True)[0] + scales = amax_dim.clamp(min=1e-5).div(self.levels) + return input_tensor.div(scales).round().mul(scales) + + def __repr__(self): + return f"{self.__class__.__name__}(num_bits={self.num_bits}, quantizer=)" + + class Qdynamic(nn.Module): def __init__( self, @@ -4560,7 +4623,7 @@ def forward(self, x_orig): class Qbypass(nn.Module): """ - no quantization at all, straight-thru + No quantization at all, output the input_tensor directly. in place of lambda function when using nbits=32 and 16. to avoid issue when pickle (ie torch.save) of lambda (seems to be a problem only for DDP) diff --git a/fms_mo/quant/rotation.py b/fms_mo/quant/rotation.py new file mode 100644 index 0000000..70cf284 --- /dev/null +++ b/fms_mo/quant/rotation.py @@ -0,0 +1,195 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Util functions related to Hadamard rotation.""" + +# Third Party +import torch + +# Local +from fms_mo.utils.hadamard_util import matmul_hadU, matmul_hadU_cuda + + +class RotQuantWrapper(torch.nn.Module): + """Add a wrapper to fms-mo quantizers. Objects of this class could have two rotation tensors, + and basic formula is: + + quantizer(Rot_left @ input_tensor @ Rot_right) + + But Rot_xxx could be optional, depending on whether it's for weights or activations. + + For weights, two possible use cases in SpinQuant are: + (A^-1 W) and (A^-1 W B). 
diff --git a/fms_mo/quant/rotation.py b/fms_mo/quant/rotation.py
new file mode 100644
index 0000000..70cf284
--- /dev/null
+++ b/fms_mo/quant/rotation.py
@@ -0,0 +1,195 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Util functions related to Hadamard rotation."""
+
+# Third Party
+import torch
+
+# Local
+from fms_mo.utils.hadamard_util import matmul_hadU, matmul_hadU_cuda
+
+
+class RotQuantWrapper(torch.nn.Module):
+    """Add a wrapper to fms-mo quantizers. Objects of this class could have two rotation tensors,
+    and the basic formula is:
+
+        quantizer(Rot_left @ input_tensor @ Rot_right)
+
+    But Rot_xxx could be optional, depending on whether it's for weights or activations.
+
+    For weights, two possible use cases in SpinQuant are:
+        (A^-1 W) and (A^-1 W B).
+    Since linear.weight is already W^T and should stay as (rotated W)^T, these two cases will be
+        (A^-1 W)^T = W^T (A^-1)^T = W^T A, as A is a Hadamard matrix
+        (A^-1 W B)^T = B^T W^T A
+    ** Furthermore, depending on whether R1 is A (v_proj) or B (o_ and down_proj), the computation
+    could be slightly different:
+        if R1 is A (R_left):
+            calc W^T A first -> (W^T A)^T -> reshape -> *B -> .t() then ready for linear
+        else R1 is B (R_right):
+            calc B^T W^T first -> reshape -> *A -> ready for linear
+
+    For activation (online rotation), it will always be (input_tensor @ R_right)
+
+    then return F.linear(qx, qw, bias)
+
+    NOTE
+    0. If online_full_had == False and self.R_left is None => do nothing, apply quantizer ONLY.
+    1. Make sure self.R is pointing to an nn.Parameter() if training on R is needed.
+    2. Because R is a ptr to an nn.Parameter tensor, it CANNOT store a "transposed" copy, hence
+       the use of self.transpose flags if needed.
+    """
+
+    def __init__(self, quantizer=None, *args, **kwargs):
+        self.online_full_had = kwargs.pop("online_full_had", None)
+        self.compute_dtype = kwargs.pop("compute_dtype", torch.float64)
+        super().__init__(*args, **kwargs)
+        self.quantizer = quantizer
+        self.R_left = None
+        self.R_right = None
+        self.K_left = None
+        self.K_right = None
+        self.R1_is_left = True  # see docstring above
+        self.transpose_right = False  # this flag is for online rotation only
+        # if K_xxx == 1, use exact hadamard matrix (R_xxx won't be needed). but if K > 1, R will
+        # be one of the 12 special had matrices. (they are stored in a binary file)
+
+    def forward(self, inp):
+        org_dtype = inp.dtype
+
+        if self.R_left is not None:
+            # Case 1: Weight rotation,
+            # as Activation rotation will only have R_right. If R_left exists for A =>
+            # should have absorbed R_left for A into prev layer's W.
+            # Hence, R_left is not None can only mean weight rotation, not online =>
+            # could be either 1) R_left only or 2) both R_left and R_right.
+
+            in_feat, out_feat = inp.shape[-1], inp.shape[0]  # input is W^T (out, in)
+            if self.R1_is_left:
+                # for q, k, v, up, gate, calc W^T A first. see details in docstring
+                inp = inp.to(self.compute_dtype) @ self.R_left.to(self.compute_dtype)
+
+                if self.R_right is not None:
+                    had_dim = self.R_right.shape[0]
+                    inp = inp.t()  # (W^T A)^T = A^T W, shape is (in, out)
+                    inp = inp.reshape(-1, out_feat // had_dim, had_dim)
+                    inp = inp.to(self.compute_dtype) @ self.R_right.to(
+                        self.compute_dtype
+                    )
+                    inp = inp.reshape((in_feat, out_feat)).t()
+
+            else:
+                assert self.R_right is not None, "R1_is_left is False but R_right is None."
+
+                # for o, down, calc B^T W^T first, where R1 is B
+                inp = self.R_right.t().to(self.compute_dtype) @ inp.to(
+                    self.compute_dtype
+                )
+                had_dim = self.R_left.shape[0]
+                inp = inp.t()  # this will be W, not W^T, i.e. (in, out)
+                w_shape = inp.shape
+                inp = inp.reshape(-1, in_feat // had_dim, had_dim)
+                inp = inp.to(self.compute_dtype) @ self.R_left.to(self.compute_dtype)
+                inp = inp.reshape((out_feat, in_feat))
+
+        elif self.R_right is not None or self.K_right == 1:
+            # Case 2: rotation for activation. should always be (inp @ R_right)
+            if self.online_full_had:
+                # Case 2-1: online, no training to R. When R_right is None (K==1), use exact size
+                if self.compute_dtype in [torch.float, torch.float64]:
+                    # follow SpinQuant paper, use no higher than fp32 for online had
+                    inp = inp.float()
+
+                # matmul_hadU_cuda already includes 1/sqrt(shape[-1])
+                if self.transpose_right and self.R_right is not None:
+                    inp = matmul_hadU_cuda(inp, self.R_right.t(), self.K_right)
+                else:
+                    inp = matmul_hadU_cuda(inp, self.R_right, self.K_right)
+                # inp = matmul_hadU(inp)
+            else:
+                # Case 2-2: offline (such as last R before lm_head)
+                if self.transpose_right:
+                    inp = inp.to(self.compute_dtype) @ self.R_right.t().to(
+                        self.compute_dtype
+                    )
+                else:
+                    inp = inp.to(self.compute_dtype) @ self.R_right.to(
+                        self.compute_dtype
+                    )
+
+        # Case 3: both R_left and R_right are None and K != 1 => no rotation; apply quantizer only.
+
+        inp = inp.to(org_dtype)
+
+        if self.quantizer:
+            # with torch.no_grad():
+            inp = self.quantizer(inp)
+
+        return inp
+
+    def __repr__(self):
+        """Simplified repr for RotQuantWrapper. Shows the wrapped quantizer and rotation sides."""
+        repr_str = "Only("
+        if self.quantizer is not None:
+            repr_str = f"{self.quantizer.__class__.__name__}("
+
+        if self.R_left is not None or self.online_full_had:
+            # will do W or A rotation
+            repr_str = (
+                "Rot"
+                + repr_str
+                + f"{'' if self.R_left is None else 'Rl'},{'' if self.R_right is None else 'Rr'})"
+            )
+
+        return repr_str
+
+
+class EmbeddingRotWrapper(torch.nn.Module):
+    """Simply add a rotation after the input embeddings. The original code looks like
+
+        input_embeds = self.embed_tokens(input_ids)
+
+    This wrapper will be:
+
+        input_embeds = self.embed_tokens(input_ids)
+        dtype = input_embeds.dtype
+        if self.R:
+            input_embeds = (input_embeds @ self.R).to(dtype)
+        return input_embeds
+
+    Also need to make sure self.R is pointing to an nn.Parameter() if training on R is needed.
+    """
+
+    def __init__(self, emb, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.emb = emb
+        self.R = None
+        self.compute_dtype = torch.float64
+
+    def forward(self, inp_ids):
+        inp_embeds = self.emb(inp_ids)
+        org_dtype = inp_embeds.dtype
+        if self.R is not None:
+            inp_embeds = (
+                inp_embeds.to(self.compute_dtype) @ self.R.to(self.compute_dtype)
+            ).to(org_dtype)
+        return inp_embeds
+
+    def __repr__(self):
+        """Simplified repr for RotEmb."""
+        repr_str = f"Rot{str(self.emb)}"
+        if self.R is not None:
+            repr_str = repr_str.replace(")", ", Rr)")
+        return repr_str
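`RotQuantWrapper` relies on the usual SpinQuant/QuaRot identity: if an orthogonal matrix R (a normalized Hadamard matrix in this patch) rotates the activation and its inverse R^T is folded into the weight, the linear layer's output is unchanged and only the quantization error differs. A small self-contained check of that invariance, using a QR-based orthogonal matrix as a stand-in for the exact Hadamard matrices in `hadamard_util`:

```python
# Sketch of the invariance the rotation wrappers rely on:
# (x @ R) @ (W @ R).T == x @ W.T for any orthogonal R, since R @ R.T = I.
import torch

torch.manual_seed(0)
d_in, d_out = 8, 4
x = torch.randn(2, d_in)
W = torch.randn(d_out, d_in)                      # nn.Linear stores weight as (out, in)

R, _ = torch.linalg.qr(torch.randn(d_in, d_in))   # orthogonal stand-in for a Hadamard matrix

y_ref = x @ W.t()
y_rot = (x @ R) @ (W @ R).t()                     # rotated activation times rotated weight
print(torch.allclose(y_ref, y_rot, atol=1e-5))    # True
```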
diff --git a/fms_mo/utils/hadamard_util.py b/fms_mo/utils/hadamard_util.py
new file mode 100644
index 0000000..9f92b2e
--- /dev/null
+++ b/fms_mo/utils/hadamard_util.py
@@ -0,0 +1,184 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This code is based on QuaRot (https://github.com/spcl/QuaRot/tree/main/quarot).
+# Licensed under Apache License 2.0.
+# Adapted from https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py
+# and https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
+"""
+Changes the original "text tensor" implementation into binaries for better efficiency. Only 12
+sizes are available in the safetensors file: [12, 20, 28, 36, 40, 44, 52, 60, 108, 140, 156, 172]
+"""
+
+# Standard
+from pathlib import Path
+
+# Third Party
+from fast_hadamard_transform import hadamard_transform  # pylint: disable=import-error
+from safetensors import safe_open
+import torch
+
+# TODO make sure it's a persistent cache so we don't need to load from file every time
+cwd = Path(__file__).parent
+hadKs = {}
+with safe_open(cwd / "hadk.safetensors", framework="pt", device="cuda") as f:
+    for K_str in f.keys():  # K is a str
+        hadKs[K_str] = f.get_tensor(K_str)
+
+
+class HadamardTransform(torch.autograd.Function):
+    """The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))"""
+
+    # TODO seems redundant, inside hadamard_transform(), backward is already handled...?
+    @staticmethod
+    def forward(_ctx, u):
+        return hadamard_transform(u)
+
+    @staticmethod
+    def backward(_ctx, grad):
+        return hadamard_transform(grad)
+
+
+def get_hadK(n, transpose=False):
+    """Simplify the implementation and use binary tensors instead of text implementation."""
+    hadK = None
+    for K in [172, 156, 140, 108, 60, 52, 44, 40, 36, 28, 20, 12]:
+        if n % K == 0 and is_pow2(n // K):
+            hadK = hadKs[str(K)]
+            if transpose:
+                hadK = hadK.T
+            break
+
+    if hadK is None:
+        if is_pow2(n):
+            K = 1
+        else:
+            raise RuntimeError(
+                f"{n} is not a power of 2 and does not have a special-size Hadamard available."
+            )
+
+    return hadK, K
+
+
+def matmul_hadU(X, transpose=False):
+    """Borrowed from SpinQuant."""
+    n = X.shape[-1]
+    hadK, K = get_hadK(n, transpose)
+    input_ = X.clone().view(-1, n, 1)
+    output = input_.clone()
+    while input_.shape[1] > K:
+        input_ = input_.view(input_.shape[0], input_.shape[1] // 2, 2, input_.shape[2])
+        output = output.view(input_.shape)
+        output[:, :, 0, :] = input_[:, :, 0, :] + input_[:, :, 1, :]
+        output[:, :, 1, :] = input_[:, :, 0, :] - input_[:, :, 1, :]
+        output = output.view(input_.shape[0], input_.shape[1], -1)
+        (input_, output) = (output, input_)
+    del output
+
+    if K > 1:
+        # Do not explicitly repeat - OOM
+        # input_ = torch.bmm(
+        #     hadK.repeat(len(input_), 1, 1).to(input_.device).to(input_.dtype), input_)
+        # Use bcast instead
+        input_ = hadK.view(1, K, K).to(input_) @ input_
+
+    return input_.view(X.shape) / torch.tensor(n).sqrt()
+
+
+def matmul_hadUt(X):
+    """Borrowed from SpinQuant."""
+    return matmul_hadU(X, transpose=True)
+
+
+def random_hadamard_matrix(size, device):
+    """Borrowed from SpinQuant."""
+    # See https://cornell-relaxml.github.io/quip-sharp/
+    # Section "Randomized Hadamard Transformation"
+    Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
+    Q = Q * 2 - 1
+    Q = torch.diag(Q)
+    return matmul_hadU(Q).to(device)
+
+
+def hadamard_matrix(size, device):
+    """Borrowed from SpinQuant."""
+    Q = torch.eye(size)
+    return matmul_hadU(Q).to(device)
+
+
+def matmul_hadU_cuda(X, hadK, K):
+    """Borrowed from SpinQuant."""
+    n = X.shape[-1]
+    if K == 1:
+        return HadamardTransform.apply(X.contiguous()) / torch.tensor(n).sqrt()
+    # if transpose:
+    #     hadK = hadK.T.contiguous()
+    input_ = X.view(-1, K, n // K)
+    input_ = HadamardTransform.apply(input_.contiguous()) / torch.tensor(n).sqrt()
+    input_ = hadK.to(input_.device).to(input_.dtype) @ input_
+    return input_.reshape(X.shape)
+
+
+# def matmul_hadUt_cuda(X, hadK, K):
+#     """Borrowed from SpinQuant."""
+#     return matmul_hadU_cuda(X, hadK, K, transpose=True)
+
+
+def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
+    """Borrowed from SpinQuant."""
+    assert isinstance(module, torch.nn.Linear)
+    in_features, out_features = module.in_features, module.out_features
+
+    if had_dim != -1:
+        assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"
+
+    W_ = module.weight.data
+    dtype = W_.dtype
+    dev = W_.device
+    init_shape = W_.shape
+    W_ = W_.float().cuda()
+
+    if had_dim == -1:
+        if output:
+            had_K, K = get_hadK(out_features)
+            W_ = matmul_hadU_cuda(W_.t(), had_K, K).t()
+        if not output:
+            had_K, K = get_hadK(in_features)
+            W_ = matmul_hadU_cuda(W_, had_K, K)
+    else:
+        hadK = hadamard_matrix(had_dim, "cuda").to(torch.float64)
+        if R2 is not None:
+            hadK = R2.to(torch.float64)
+        if output:
+            W_ = W_.t()
+            transposed_shape = W_.shape
+            temp = W_.reshape(-1, transposed_shape[-1] // had_dim, had_dim)
+            temp = temp.to(torch.float64) @ hadK
+            W_ = temp.reshape(transposed_shape).t()
+        else:
+            init_shape = W_.shape
+            temp = W_.reshape(-1, init_shape[-1] // had_dim, had_dim)
+            temp = temp.to(torch.float64) @ hadK
+            W_ = temp.reshape(init_shape)
+    module.weight.data = W_.to(device=dev, dtype=dtype)
+
+
+def is_pow2(n):
+    """Borrowed from SpinQuant."""
+    return (n & (n - 1) == 0) and (n > 0)
+
+
+# hadamard matrices for had12, had36.pal2, had52.will,
+# had60.pal, had108.pal, had140.pal, had156.will, had172.will:
+# http://www.neilsloane.com/hadamard/index.html
diff --git a/fms_mo/utils/hadk.safetensors b/fms_mo/utils/hadk.safetensors
new file mode 100644
index 0000000..399a8fa
Binary files /dev/null and b/fms_mo/utils/hadk.safetensors differ
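For power-of-2 sizes, `matmul_hadU` is an in-place butterfly version of the normalized Walsh-Hadamard transform; other sizes fall back to one of the 12 special Hadamard kernels stored in `hadk.safetensors`. A dependency-free sketch of what the power-of-2 path computes, built here with the explicit Sylvester construction instead of the butterfly:

```python
# Explicit normalized Hadamard transform for power-of-2 n; matmul_hadU(X) should agree
# with this (up to floating-point error) when n is a power of 2.
import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    H = torch.ones(1, 1)
    while H.shape[0] < n:                        # H_{2k} = [[H_k, H_k], [H_k, -H_k]]
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H

n = 16
X = torch.randn(3, n)
out = X @ sylvester_hadamard(n) / torch.tensor(float(n)).sqrt()   # same 1/sqrt(n) scaling
print(out.shape)
```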
diff --git a/fms_mo/utils/qconfig_utils.py b/fms_mo/utils/qconfig_utils.py
index caafec1..ef8f973 100644
--- a/fms_mo/utils/qconfig_utils.py
+++ b/fms_mo/utils/qconfig_utils.py
@@ -568,18 +568,9 @@ def check_config(config, model_dtype=None):
         )
 
     # Set allowed qa_modes, qw_modes, bmm_modes
-    qa_mode_settings = [
-        "pact",
-        "pact+",
-        "pactsym",
-        "pactsym+",
-        "max",
-        "minmax",
-        "maxsym",
-        "pertokenmax",
-        "lsq+",
-        "fix",
-        "brecq",
+    shared_modes = [
+        "max_perToken",
+        "max_perCh",
         # fp8_e4m3
         "fp8_e4m3_sat",
         "fp8_e4m3_scale",
         "fp8_e4m3_sat_perCh",
         "fp8_e4m3_scale_perCh",
         "fp8_e4m3_sat_perToken",
         "fp8_e4m3_scale_perToken",
         # fp8_e5m2
         "fp8_e5m2_sat",
         "fp8_e5m2_scale",
         "fp8_e5m2_sat_perCh",
         "fp8_e5m2_scale_perCh",
         "fp8_e5m2_sat_perToken",
         "fp8_e5m2_scale_perToken",
+        # others
+        "_only",  # was "rot_only"
+        "no_quant",  # could be used for those nbits = 16 or 32
+    ]
+
+    qa_mode_settings = [
+        "pact",
+        "pact+",
+        "pactsym",
+        "pactsym+",
+        "max",
+        "minmax",
+        "maxsym",
+        "pertokenmax",
+        "lsq+",
+        "fix",
+        "brecq",
     ]
     qw_mode_settings = [
         "sawb",
         "brecq",
         "adaround",
         "pertokenmax",
-        # fp8_e4m3
-        "fp8_e4m3_sat",
-        "fp8_e4m3_scale",
-        "fp8_e4m3_sat_perCh",
-        "fp8_e4m3_scale_perCh",
-        "fp8_e4m3_sat_perToken",
-        "fp8_e4m3_scale_perToken",
-        # fp8_e5m2
-        "fp8_e5m2_sat",
-        "fp8_e5m2_scale",
-        "fp8_e5m2_sat_perCh",
-        "fp8_e5m2_scale_perCh",
-        "fp8_e5m2_sat_perToken",
-        "fp8_e5m2_scale_perToken",
     ]
     bmm_mode_settings = [
         "pact",
         "pactsym",
         "pactsym+",
         "pact+",
         "max",
         "minmax",
         "pertokenmax",
-        "fp8_e4m3_sat",
-        "fp8_e4m3_scale_perToken",
-        "fp8_e5m2_sat",
-        "fp8_e5m2_scale_perToken",
     ]
 
     # Get strings in config for qa_modes, qw_modes, bmm_modes
 
     # Check each for correct ranges
     for qa_mode_str in qa_modes_str:
-        qa_mode = config.get(qa_mode_str, "pact+")
-        if not qa_mode in qa_mode_settings:
+        qa_mode = config.get(qa_mode_str, "pact+").replace("rot_", "")
+        if not (qa_mode in qa_mode_settings or qa_mode in shared_modes):
             raise ValueError(
                 f"{qa_mode_str} = {qa_mode} is not set to one of the following: "
                 f"{qa_mode_settings}"
             )
 
     for qw_mode_str in qw_modes_str:
-        qw_mode = config.get(qw_mode_str, "sawb+")
-        if not qw_mode in qw_mode_settings:
+        qw_mode = config.get(qw_mode_str, "sawb+").replace("rot_", "")
+        if not (qw_mode in qw_mode_settings or qw_mode in shared_modes):
             raise ValueError(
                 f"{qw_mode_str} = {qw_mode} is not set to one of the following: "
                 f"{qw_mode_settings}"
             )
 
     for bmm_mode_str in bmm_modes_str:
-        bmm_mode = config.get(bmm_mode_str, "pactsym+")
-        if not bmm_mode in bmm_mode_settings:
+        bmm_mode = config.get(bmm_mode_str, "pactsym+").replace("rot_", "")
+        if not (bmm_mode in bmm_mode_settings or bmm_mode in shared_modes):
             raise ValueError(
                 f"{bmm_mode_str} = {bmm_mode} is not set to one of the following: "
                 f"{bmm_mode_settings}"
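`check_config` now strips a leading `rot_` before validating and accepts any mode that appears either in the per-kind list or in `shared_modes`. A condensed sketch of that check, with abridged illustrative lists:

```python
# Condensed sketch of the relaxed mode validation in check_config (lists abridged).
shared_modes = ["max_perToken", "max_perCh", "fp8_e4m3_sat", "_only", "no_quant"]
qa_mode_settings = ["pact", "pact+", "max", "minmax", "maxsym", "pertokenmax"]

def validate_qa_mode(mode: str) -> None:
    stripped = mode.replace("rot_", "")
    if not (stripped in qa_mode_settings or stripped in shared_modes):
        raise ValueError(f"qa_mode = {mode} is not set to one of the allowed modes")

validate_qa_mode("rot_max_perToken")   # passes: "max_perToken" is a shared mode
validate_qa_mode("pact+")              # passes
# validate_qa_mode("bogus")            # would raise ValueError
```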
diff --git a/fms_mo/utils/utils.py b/fms_mo/utils/utils.py
index 38e2a1d..e333316 100644
--- a/fms_mo/utils/utils.py
+++ b/fms_mo/utils/utils.py
@@ -23,6 +23,7 @@
 # Standard
 from contextlib import ExitStack, contextmanager
+from functools import partial
 from typing import Any, Callable, Dict, List, Tuple, Union
 from unittest import mock
 import logging
@@ -71,7 +72,12 @@ def move_to(obj, device):
     return obj
 
 
-def mockbmm(mat1, mat2, default_to_torch=False):
+def mockbmm(
+    mat1,
+    mat2,
+    default_to_torch=False,
+    target_line_num=[0],
+):
     """
     This function is used to mock the behavior of the bmm function in PyTorch.
     It is used to work around the fact that the bmm function in PyTorch is not
@@ -87,7 +93,9 @@
     cf = sys._getframe()
     qbmm_mod = None
     qbmm_lineno = cf.f_back.f_lineno
-    while cf.f_back and qbmm_mod is None:
+    if qbmm_lineno not in target_line_num:
+        default_to_torch = True
+    while (not default_to_torch) and cf.f_back and qbmm_mod is None:
         # First frame is QBmm's forward itself, can start searching from previous stack
         cf = cf.f_back
         if (
@@ -102,13 +110,21 @@
     return qbmm_mod(mat1, mat2)
 
 
-def mockmatmul(mat1, mat2, default_to_torch=False):
+def mockmatmul(
+    mat1,
+    mat2,
+    default_to_torch=False,
+    target_line_num=[0],
+):
     """
     Patches torch.matmul() with QBmm( torch.bmm() )
 
     Args:
         mat1 (torch.Tensor): The first matrix to be multiplied.
         mat2 (torch.Tensor): The second matrix to be multiplied.
+        target_line_num: Only patch matmul/bmm on the line numbers previously found by qmodel_prep,
+            i.e., matmuls/bmms other than self-attn will not be patched.
+            => need to make sure qmodel_prep only found 2.
 
     Returns:
         torch.Tensor: The result of the mock matrix multiplication.
@@ -124,7 +140,9 @@
     cf = sys._getframe()
     qbmm_mod = None
     qbmm_lineno = cf.f_back.f_lineno
-    while cf.f_back and qbmm_mod is None:
+    if qbmm_lineno not in target_line_num:
+        default_to_torch = True
+    while (not default_to_torch) and cf.f_back and qbmm_mod is None:
         cf = cf.f_back
         if (
             "forward" in cf.f_code.co_name
             or "_attn" in cf.f_code.co_name
@@ -134,16 +152,17 @@
             qbmm_mod = getattr(mod_calling_bmm_function, f"QBmm{qbmm_lineno}", None)
     del cf
 
-    # Didn't find the corresponding QBmm, default the call to torch.bmm
+    # Didn't find the corresponding QBmm, default the call to torch.bmm, which only accepts 3D
     if qbmm_mod is None and default_to_torch:
-        org_batch_header = mat1.shape[:2]
-        # Need to double check m1/m2 are 3d, otherwise reshape
-        if len(mat1.shape) > 3:
+        # Need to reshape if inputs are 2d or 4d
+        if len(mat1.shape) == len(mat2.shape) and len(mat2.shape) in [2, 4]:
+            tar_shape = [mat1.shape[-2], mat2.shape[-1]]
+            if len(mat1.shape) == 4:
+                tar_shape = list(mat1.shape[:2]) + tar_shape
             mat1 = mat1.reshape([-1, mat1.shape[-2], mat1.shape[-1]])
-        if len(mat2.shape) > 3:
             mat2 = mat2.reshape([-1, mat2.shape[-2], mat2.shape[-1]])
         output = torch.bmm(mat1, mat2)
-        output = output.reshape([*org_batch_header, *output.shape[1:]])
+        output = output.reshape(tar_shape)
         return output
     return qbmm_mod(mat1, mat2)
 
@@ -158,13 +177,14 @@
     if qcfg is not None:
         # could be 'torch.bmm', 'torch.matmul', or None
         ops_to_patch = qcfg.get("which2patch_contextmanager", None)
+        tar_ln = list(qcfg["bmm_prep"]["layers_with_bmm"].values())[0]
         # if qcfg["bmm_prep"]["bmm_only_in_self_attn"] is False, may need to enable default_to_torch
         # in mock functions, e.g. partial(mockmatmul, default_to_torch=True)
         # This is in case a model uses extra matmuls, and QBmmXXX is not found or attached properly.
         new_target = (
-            mockbmm
+            partial(mockbmm, target_line_num=tar_ln)
             if ops_to_patch == "torch.bmm"
-            else mockmatmul
+            else partial(mockmatmul, target_line_num=tar_ln)
             if ops_to_patch == "torch.matmul"
             else None
         )
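The `functools.partial` binding in `patch_torch_bmm` is what carries the recorded call-site line numbers into `mockbmm`/`mockmatmul`, so only the matmuls found by `qmodel_prep` get redirected to a `QBmm` module and everything else falls straight through to `torch`. A minimal sketch of that gating pattern; the line numbers and the dispatch target here are hypothetical:

```python
# Minimal sketch of line-number gating via functools.partial, mirroring the idea
# in mockmatmul: calls from unrecorded line numbers fall back to the original op.
from functools import partial
import sys

import torch


def gated_matmul(mat1, mat2, target_line_num=(0,)):
    caller_lineno = sys._getframe().f_back.f_lineno   # line number of the call site
    if caller_lineno in target_line_num:
        # A recorded call site: this is where the matching QBmm module would be looked up.
        print(f"would dispatch to QBmm{caller_lineno}")
    return torch.matmul(mat1, mat2)                   # otherwise plain torch behavior


patched = partial(gated_matmul, target_line_num=(9999,))   # hypothetical recorded line
out = patched(torch.randn(2, 3, 4), torch.randn(2, 4, 5))  # falls back to torch.matmul here
print(out.shape)
```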