|
1 | 1 | import torch |
2 | 2 | import math |
3 | | -from .utils import mse_loss |
4 | 3 | from sparsebit.quantization.observers import Observer as BaseObserver |
5 | 4 | from sparsebit.quantization.observers import register_observer |
6 | 5 | from sparsebit.quantization.quantizers.quant_tensor import STE |
@@ -56,6 +55,8 @@ def __init__(self, config, qdesc): |
56 | 55 | 8: 11.16, |
57 | 56 | } |
58 | 57 | self.gaus_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5) |
| 58 | + self.distribution = config.OBSERVER.ACIQ.DISTRIBUTION.lower() |
| 59 | + assert self.distribution in ["gaus", "laplace"] |
59 | 60 |
|
60 | 61 | def calc_laplace_minmax(self, data, is_half_range): |
61 | 62 | if self.is_perchannel: |
@@ -115,53 +116,13 @@ def calc_minmax(self): |
115 | 116 | data = self.get_calibration_data(c_first=True) |
116 | 117 | is_half_range = data.min() >= 0 |
117 | 118 |
|
118 | | - laplace_min_val, laplace_max_val = self.calc_laplace_minmax(data, is_half_range) |
119 | | - scale_laplace, zero_point_laplace = self.calc_qparams_with_minmax( |
120 | | - laplace_min_val, laplace_max_val |
121 | | - ) |
122 | | - mse_laplace = mse_loss( |
123 | | - STE.apply( |
124 | | - data, scale_laplace, zero_point_laplace, self.qdesc, self.backend |
125 | | - ), |
126 | | - data, |
127 | | - self.is_perchannel, |
128 | | - ) |
129 | | - |
130 | | - gaus_min_val, gaus_max_val = self.calc_gaus_minmax( |
131 | | - data, batch_size, is_half_range |
132 | | - ) |
133 | | - scale_gaus, zero_point_gaus = self.calc_qparams_with_minmax( |
134 | | - gaus_min_val, gaus_max_val |
135 | | - ) |
136 | | - |
137 | | - mse_gaus = mse_loss( |
138 | | - STE.apply(data, scale_gaus, zero_point_gaus, self.qdesc, self.backend), |
139 | | - data, |
140 | | - self.is_perchannel, |
141 | | - ) |
142 | | - |
143 | | - naive_min_val, naive_max_val = self.calc_naive_minmax(data) |
144 | | - scale_minmax, zero_point_minmax = self.calc_qparams_with_minmax( |
145 | | - naive_min_val, naive_max_val |
146 | | - ) |
147 | | - mse_minmax = mse_loss( |
148 | | - STE.apply(data, scale_minmax, zero_point_minmax, self.qdesc, self.backend), |
149 | | - data, |
150 | | - self.is_perchannel, |
151 | | - ) |
152 | | - |
153 | | - mse_gaus_laplace = torch.minimum(mse_gaus, mse_laplace) |
154 | | - self.min_val = torch.where( |
155 | | - mse_gaus < mse_laplace, gaus_min_val, laplace_min_val |
156 | | - ) |
157 | | - self.min_val = torch.where( |
158 | | - mse_minmax < mse_gaus_laplace, naive_min_val, self.min_val |
159 | | - ).to(self.device) |
160 | | - self.max_val = torch.where( |
161 | | - mse_gaus < mse_laplace, gaus_max_val, laplace_max_val |
162 | | - ) |
163 | | - self.max_val = torch.where( |
164 | | - mse_minmax < mse_gaus_laplace, naive_max_val, self.max_val |
165 | | - ).to(self.device) |
| 119 | + if self.distribution == "laplace": |
| 120 | + min_val, max_val = self.calc_laplace_minmax(data, is_half_range) |
| 121 | + else: |
| 122 | + min_val, max_val = self.calc_gaus_minmax( |
| 123 | + data, batch_size, is_half_range |
| 124 | + ) |
| 125 | + self.min_val = min_val.to(self.device) |
| 126 | + self.max_val = max_val.to(self.device) |
166 | 127 |
|
167 | 128 | return self.min_val, self.max_val |
0 commit comments