From 4c5ae6723617168eaf42933e66c3725a17bbbde5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E9=93=AE02-ISP=E4=BA=8C=E7=BB=84?=
Date: Thu, 17 Nov 2022 15:47:03 +0800
Subject: [PATCH 1/4] add regularizer

---
 .../cifar10/basecase/main.py                  | 15 ++++--
 .../cifar10/basecase/qconfig_lsq_dampen.yaml  | 14 ++++++
 .../cifar10/basecase/qconfig_pact.yaml        |  3 ++
 sparsebit/quantization/quant_config.py        |  4 ++
 .../quantization/regularizers/__init__.py     | 15 ++++++
 sparsebit/quantization/regularizers/base.py   |  6 +++
 sparsebit/quantization/regularizers/dampen.py | 48 +++++++++++++++++++
 sparsebit/quantization/regularizers/pact.py   | 20 ++++++++
 8 files changed, 120 insertions(+), 5 deletions(-)
 create mode 100644 examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml
 create mode 100644 sparsebit/quantization/regularizers/__init__.py
 create mode 100644 sparsebit/quantization/regularizers/base.py
 create mode 100644 sparsebit/quantization/regularizers/dampen.py
 create mode 100644 sparsebit/quantization/regularizers/pact.py

diff --git a/examples/quantization_aware_training/cifar10/basecase/main.py b/examples/quantization_aware_training/cifar10/basecase/main.py
index a57b659..05a80eb 100644
--- a/examples/quantization_aware_training/cifar10/basecase/main.py
+++ b/examples/quantization_aware_training/cifar10/basecase/main.py
@@ -5,6 +5,7 @@
 import time
 import warnings
 from enum import Enum
+import math
 
 import torch
 import torch.nn as nn
@@ -21,6 +22,7 @@
 from model import resnet20
 
 from sparsebit.quantization import QuantModel, parse_qconfig
+from sparsebit.quantization.regularizers import build_regularizer
 
 
 parser = argparse.ArgumentParser(description="PyTorch Cifar Training")
@@ -147,8 +149,6 @@ def main():
 
     qconfig = parse_qconfig(args.config)
 
-    is_pact = qconfig.A.QUANTIZER.TYPE == "pact"
-
     qmodel = QuantModel(model, qconfig).cuda()  # convert model into a quantized model to support subsequent QAT quantization operations
 
     # set head and tail of model to 8bit
@@ -181,6 +181,11 @@ def main():
         optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1
     )
 
+    if qconfig.REGULARIZER.ENABLE:
+        regularizer = build_regularizer(qconfig)
+    else:
+        regularizer = None
+
     best_acc1 = 0
     for epoch in range(args.start_epoch, args.epochs):
         # train for one epoch
@@ -190,7 +195,7 @@ def main():
             criterion,
             optimizer,
             epoch,
-            is_pact,
+            regularizer,
             args.regularizer_lambda,
             args.print_freq,
         )
@@ -247,7 +252,7 @@ def train(
     model,
     criterion,
     optimizer,
     epoch,
-    is_pact,
+    regularizer,
     regularizer_lambda,
     print_freq,
 ):
@@ -278,7 +283,7 @@
         # compute output
         output = model(images)
         ce_loss = criterion(output, target)
-        regular_loss = get_regularizer_loss(model, is_pact, scale=regularizer_lambda)
+        regular_loss = get_regularizer_loss(model, regularizer, regularizer_lambda)
         loss = ce_loss + regular_loss
 
         # measure accuracy and record loss
diff --git a/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml b/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml
new file mode 100644
index 0000000..93c6127
--- /dev/null
+++ b/examples/quantization_aware_training/cifar10/basecase/qconfig_lsq_dampen.yaml
@@ -0,0 +1,14 @@
+BACKEND: virtual
+W:
+  QSCHEME: per-channel-symmetric
+  QUANTIZER:
+    TYPE: lsq
+    BIT: 4
+A:
+  QSCHEME: per-tensor-affine
+  QUANTIZER:
+    TYPE: lsq
+    BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: dampen
diff --git a/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml b/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
index a191746..72b370e 100644
--- a/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
+++ b/examples/quantization_aware_training/cifar10/basecase/qconfig_pact.yaml
@@ -9,3 +9,6 @@ A:
   QUANTIZER:
     TYPE: pact
     BIT: 4
+REGULARIZER:
+  ENABLE: True
+  TYPE: pact
diff --git a/sparsebit/quantization/quant_config.py b/sparsebit/quantization/quant_config.py
index b46532d..61204a8 100644
--- a/sparsebit/quantization/quant_config.py
+++ b/sparsebit/quantization/quant_config.py
@@ -47,6 +47,10 @@
 _C.A.QADD.ENABLE_QUANT = False
 _C.A.SPECIFIC = []
 
+_C.REGULARIZER = CN()
+_C.REGULARIZER.ENABLE = False
+_C.REGULARIZER.TYPE = ""
+
 
 def parse_qconfig(cfg_file):
     qconfig = _parse_config(cfg_file, default_cfg=_C)
diff --git a/sparsebit/quantization/regularizers/__init__.py b/sparsebit/quantization/regularizers/__init__.py
new file mode 100644
index 0000000..2fda594
--- /dev/null
+++ b/sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,15 @@
+REGULARIZERS_MAP = {}
+
+
+def register_regularizer(regularizer):
+    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
+    return regularizer
+
+
+from .base import Regularizer
+from . import dampen, pact
+
+
+def build_regularizer(config):
+    regularizer = REGULARIZERS_MAP[config.REGULARIZER.TYPE.lower()](config)
+    return regularizer
diff --git a/sparsebit/quantization/regularizers/base.py b/sparsebit/quantization/regularizers/base.py
new file mode 100644
index 0000000..45e3cb9
--- /dev/null
+++ b/sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
+class Regularizer(object):
+    def __init__(self, config):
+        self.config = config
+
+    def __call__(self):
+        pass
diff --git a/sparsebit/quantization/regularizers/dampen.py b/sparsebit/quantization/regularizers/dampen.py
new file mode 100644
index 0000000..f1f7191
--- /dev/null
+++ b/sparsebit/quantization/regularizers/dampen.py
@@ -0,0 +1,48 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Dampen"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+
+    def _get_loss(self, x, quantizer):
+
+        x_q = quantizer(x)
+
+        qmin, qmax = quantizer.qdesc.qrange
+
+        scale, zero_point = quantizer._qparams_preprocess(x)
+
+        scale = scale.detach()
+        zero_point = zero_point.detach()
+
+        min_val = (qmin - zero_point) * scale
+
+        max_val = (qmax - zero_point) * scale
+
+        x_c = torch.min(torch.max(x, min_val), max_val)
+
+        loss = (x_q - x_c) ** 2
+
+        loss = loss.sum()
+
+        return loss
+
+    def __call__(self, model):
+        loss = 0.0
+        for n, m in model.named_modules():
+            if (
+                hasattr(m, "weight")
+                and hasattr(m, "weight_quantizer")
+                and m.weight_quantizer
+                and m.weight_quantizer.is_enable
+            ):
+                loss += self._get_loss(m.weight, m.weight_quantizer)
+        return loss
diff --git a/sparsebit/quantization/regularizers/pact.py b/sparsebit/quantization/regularizers/pact.py
new file mode 100644
index 0000000..18c4cf0
--- /dev/null
+++ b/sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,20 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Pact"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+
+    def __call__(self, model):
+        loss = 0.0
+        for n, p in model.named_parameters():
+            if "alpha" in n:
+                loss += (p ** 2).sum()
+        return loss

From ab4100b9f287ce51249d9d16a28dd4dc1fd367b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E9=93=AE02-ISP=E4=BA=8C=E7=BB=84?=
Date: Thu, 15 Dec 2022 15:47:04 +0800
Subject: [PATCH 2/4] update

---
 .../cifar10/basecase/main.py                  | 15 +++------------
 sparsebit/quantization/regularizers/dampen.py | 13 ++++++-------
 2 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/examples/quantization_aware_training/cifar10/basecase/main.py b/examples/quantization_aware_training/cifar10/basecase/main.py
index 05a80eb..63339e9 100644
--- a/examples/quantization_aware_training/cifar10/basecase/main.py
+++ b/examples/quantization_aware_training/cifar10/basecase/main.py
@@ -230,18 +230,9 @@ def main():
     )
 
 
-# the PACT algorithm adds L2-regularization to alpha
-def get_pact_regularizer_loss(model):
-    loss = 0
-    for n, p in model.named_parameters():
-        if "alpha" in n:
-            loss += (p**2).sum()
-    return loss
-
-
-def get_regularizer_loss(model, is_pact, scale=0):
-    if is_pact:
-        return get_pact_regularizer_loss(model) * scale
+def get_regularizer_loss(model, regularizer, _lambda):
+    if regularizer is not None:
+        return regularizer(model) * _lambda
     else:
         return torch.tensor(0.0).cuda()
 
diff --git a/sparsebit/quantization/regularizers/dampen.py b/sparsebit/quantization/regularizers/dampen.py
index f1f7191..2cc1bf9 100644
--- a/sparsebit/quantization/regularizers/dampen.py
+++ b/sparsebit/quantization/regularizers/dampen.py
@@ -2,6 +2,7 @@
 
 from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
 from sparsebit.quantization.regularizers import register_regularizer
+from sparsebit.quantization.quant_tensor import fake_qrange_factory
 
 
 @register_regularizer
@@ -16,16 +17,14 @@ def _get_loss(self, x, quantizer):
 
         x_q = quantizer(x)
 
-        qmin, qmax = quantizer.qdesc.qrange
-
         scale, zero_point = quantizer._qparams_preprocess(x)
 
-        scale = scale.detach()
-        zero_point = zero_point.detach()
-
-        min_val = (qmin - zero_point) * scale
+        min_val, max_val = fake_qrange_factory[quantizer.backend](
+            scale, zero_point, quantizer.qdesc
+        )
 
-        max_val = (qmax - zero_point) * scale
+        min_val = min_val.detach()
+        max_val = max_val.detach()
 
         x_c = torch.min(torch.max(x, min_val), max_val)
 

From 03ab277aacb53a8322282bf93b884f1243c51b6f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E9=93=AE02-ISP=E4=BA=8C=E7=BB=84?=
Date: Thu, 15 Dec 2022 16:03:43 +0800
Subject: [PATCH 3/4] update

---
 sparsebit/quantization/regularizers/dampen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sparsebit/quantization/regularizers/dampen.py b/sparsebit/quantization/regularizers/dampen.py
index 2cc1bf9..c121d0f 100644
--- a/sparsebit/quantization/regularizers/dampen.py
+++ b/sparsebit/quantization/regularizers/dampen.py
@@ -2,7 +2,7 @@
 
 from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
 from sparsebit.quantization.regularizers import register_regularizer
-from sparsebit.quantization.quant_tensor import fake_qrange_factory
+from sparsebit.quantization.quantizers.quant_tensor import fake_qrange_factory
 
 
 @register_regularizer

From da297e88c761cdaa4078accbbb3c596ec9714c62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E9=93=AE02-ISP=E4=BA=8C=E7=BB=84?=
Date: Thu, 15 Dec 2022 16:35:43 +0800
Subject: [PATCH 4/4] black

---
 sparsebit/quantization/regularizers/pact.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sparsebit/quantization/regularizers/pact.py b/sparsebit/quantization/regularizers/pact.py
index 18c4cf0..89b7cba 100644
--- a/sparsebit/quantization/regularizers/pact.py
+++ b/sparsebit/quantization/regularizers/pact.py
@@ -16,5 +16,5 @@ def __call__(self, model):
         loss = 0.0
         for n, p in model.named_parameters():
             if "alpha" in n:
-                loss += (p ** 2).sum()
+                loss += (p**2).sum()
         return loss
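
Editor's note — usage sketch (not part of the patches). The series routes everything through the qconfig: REGULARIZER.ENABLE and REGULARIZER.TYPE select a registered regularizer, build_regularizer instantiates it, and the training loop adds its weighted output to the task loss. Below is a condensed version of the flow in main.py; my_float_model, train_loader, and the lambda value are placeholders, not names from the patches.

import torch
from sparsebit.quantization import QuantModel, parse_qconfig
from sparsebit.quantization.regularizers import build_regularizer

qconfig = parse_qconfig("qconfig_lsq_dampen.yaml")   # sets REGULARIZER.ENABLE: True
qmodel = QuantModel(my_float_model, qconfig).cuda()  # my_float_model: any nn.Module

regularizer = build_regularizer(qconfig) if qconfig.REGULARIZER.ENABLE else None
regularizer_lambda = 0.01  # placeholder weight; main.py reads args.regularizer_lambda

criterion = torch.nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(qmodel.parameters(), lr=0.1)

for images, target in train_loader:  # train_loader: placeholder DataLoader
    output = qmodel(images.cuda())
    ce_loss = criterion(output, target.cuda())
    if regularizer is not None:
        regular_loss = regularizer(qmodel) * regularizer_lambda
    else:
        regular_loss = torch.tensor(0.0).cuda()
    loss = ce_loss + regular_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Because REGULARIZER.ENABLE defaults to False in quant_config.py, existing configs keep their old behavior and the loss reduces to plain cross-entropy.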
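
Editor's note — what the dampen loss computes. _get_loss fake-quantizes x, clamps x to the representable range [(qmin - zero_point) * scale, (qmax - zero_point) * scale], and sums the squared gap: in-range values contribute their squared rounding error (pulling weights toward the quantization grid), while saturated values contribute nothing, since the fake-quantized and the clamped value both sit on the range boundary. A self-contained toy with a hand-rolled symmetric 4-bit quantizer (an illustration only, not sparsebit's Quantizer class):

import torch

def fake_quant(x, scale, qmin=-8, qmax=7):
    # uniform fake quantization: scale, round to the grid, clamp, rescale
    q = torch.clamp(torch.round(x / scale), qmin, qmax)
    return q * scale

def dampen_loss(x, scale, qmin=-8, qmax=7):
    x_q = fake_quant(x, scale, qmin, qmax)
    # clamp x to the representable range; zero_point is 0 here
    x_c = torch.clamp(x, qmin * scale, qmax * scale)
    # in-range entries contribute their squared rounding error;
    # saturated entries contribute 0 because x_q == x_c == boundary
    return ((x_q - x_c) ** 2).sum()

w = torch.tensor([0.03, 0.26, 1.50])  # 1.50 saturates: range is [-0.8, 0.7]
print(dampen_loss(w, scale=0.1))      # 0.03**2 + 0.04**2 = 0.0025

Here 0.03 rounds to 0.0 (error 0.03) and 0.26 rounds to 0.3 (error 0.04), while the saturated 1.50 adds nothing, which is exactly the "dampening toward the grid without double-penalizing clipped values" behavior of the patch.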
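
Editor's note — the PACT regularizer is ordinary L2 decay restricted to parameters whose name contains "alpha", i.e. PACT's learnable clipping thresholds: penalizing them keeps the clipping range, and hence the quantization step, small, trading a little extra clipping error for finer resolution. As an aside (not something the patches do), under plain SGD the same gradient can be produced by an optimizer parameter group, since the gradient of lambda * alpha**2 is 2 * lambda * alpha; the equivalence does not hold for decoupled-decay optimizers such as AdamW.

# Reuses qmodel / regularizer_lambda from the usage sketch above.
alphas = [p for n, p in qmodel.named_parameters() if "alpha" in n]
others = [p for n, p in qmodel.named_parameters() if "alpha" not in n]
optimizer = torch.optim.SGD(
    [
        {"params": others, "weight_decay": 0.0},
        # weight_decay adds wd * alpha to the gradient, so matching the
        # gradient of regularizer_lambda * alpha**2 needs wd = 2 * lambda
        {"params": alphas, "weight_decay": 2 * regularizer_lambda},
    ],
    lr=0.1,
)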