# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization utils."""

import numpy as np
from mindspore._checkparam import Validator
from ... import nn

__all__ = ["load_nonquant_param_into_quant_net", "query_quant_layers"]


def cal_quantization_params(input_min,
                            input_max,
                            quant_min,
                            quant_max,
                            data_type,
                            symmetric=False):
    r"""
    Calculate quantization params for scale and zero point.

    Args:
        input_min (numpy.ndarray): The dimension of channel or 1.
        input_max (numpy.ndarray): The dimension of channel or 1.
        quant_min (int): The minimum quantization integer.
        quant_max (int): The maximum quantization integer.
        data_type (numpy type): Can be numpy int8 or numpy uint8.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    input_max = np.maximum(0.0, input_max)
    input_min = np.minimum(0.0, input_min)

    if input_min.shape != input_max.shape:
        raise ValueError("input min shape should be equal to input max shape.")
    if len(input_min.shape) > 1:
        raise ValueError("input min and max should be one-dimensional.")
    if (input_min > input_max).all():
        raise ValueError("input min should be less than input max.")
    if (input_max == input_min).all():
        return np.ones(input_min.shape), np.zeros(input_min.shape)

    # calculate scale
    if symmetric:
        input_max = np.maximum(-input_min, input_max)
        input_min = -input_max
    scale = (input_max - input_min) / (quant_max - quant_min)

    # calculate zero point
    if data_type == np.int8 and symmetric:
        zp = np.zeros(input_min.shape)
    else:
        zp_double = quant_min - input_min / scale
        zp = np.floor(zp_double + 0.5)

    return scale, zp


def get_quant_min_max(data_type, num_bits=8, narrow_range=False):
    """Calculate the minimum/maximum quantization integer for the given data type."""
    if data_type == np.int8:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    elif data_type == np.uint8:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    else:
        raise ValueError("Unsupported datatype({})".format(data_type))
    if narrow_range:
        quant_min = quant_min + 1
    return quant_min, quant_max
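
# A minimal usage sketch (illustrative only, not part of the public API):
# derive scale/zero-point for an asymmetric uint8 quantizer from observed
# min/max statistics.
#
#   >>> qmin, qmax = get_quant_min_max(np.uint8, num_bits=8)
#   >>> scale, zp = cal_quantization_params(np.array([-1.0]), np.array([3.0]),
#   ...                                     qmin, qmax, np.uint8)
#   >>> # scale == 4/255, zp == floor(0 - (-1.0)/scale + 0.5) == 64
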

def weight2int(data, scale, zero_point, quant_min, quant_max):
    r"""
    Calculate int8/uint8 weight from fp32. The formula is defined as:

    .. math::
        int8/uint8 = round(float / scale) + offset

    Args:
        data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
        scale (numpy.ndarray): The dimension of channel or 1.
        zero_point (numpy.ndarray): The dimension of channel or 1.
        quant_min (int): The minimum quantization integer.
        quant_max (int): The maximum quantization integer.

    Returns:
        weight (numpy.ndarray): The dimension of channel or 1.
    """
    if scale.shape != zero_point.shape:
        raise ValueError("`scale` and `zero_point` should have the same shape.")
    if scale.shape[0] <= 0:
        raise ValueError("`scale` and `zero_point` shape should be greater than zero.")
    if len(scale.shape) >= 1 and scale.shape[0] > 1:
        # for perchannel
        if scale.shape[0] == data.shape[0]:
            # `Conv2d` or `Dense` op weight
            shape_list = [-1] + [1] * len(data.shape[1:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        elif scale.shape[0] == data.shape[1]:
            # `DepthwiseConv2d` op weight
            shape_list = [1, -1] + [1] * len(data.shape[2:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        else:
            raise ValueError("Unsupported weight shape({})".format(data.shape))

    weight_int = np.round((data / scale) + zero_point)
    weight_int[weight_int > quant_max] = quant_max
    weight_int[weight_int < quant_min] = quant_min
    return weight_int


def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
    """Get quantization params (scale, zero point, max and min) from a `FakeQuantWithMinMaxObserver` cell."""
    minq = cell.minq.data.asnumpy()
    maxq = cell.maxq.data.asnumpy()
    # make sure maxq > 0 and minq <= 0
    if cell.mode == 'LEARNED_SCALE':
        maxq = np.abs(maxq)
        minq = -np.abs(minq)
    quant_min, quant_max = get_quant_min_max(data_type, num_bits=cell.num_bits, narrow_range=cell.narrow_range)
    symmetric = cell.symmetric and not cell.neg_trunc
    scale, zp = cal_quantization_params(
        minq, maxq,
        quant_min, quant_max, data_type,
        symmetric=symmetric)
    return scale, zp, maxq, minq


def fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBnFoldQuant` to weight.

    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.
    """
    variance = cell_quant.moving_variance.data.asnumpy()
    mean = cell_quant.moving_mean.data.asnumpy()
    gamma = cell_quant.gamma.data.asnumpy()
    beta = cell_quant.beta.data.asnumpy()
    epsilon = cell_quant.eps
    sigma = np.sqrt(variance + epsilon)

    if gamma.shape[0] == weight.shape[0]:
        # `Conv2d` or `Dense` op weight
        shape_list = [-1] + [1] * len(weight.shape[1:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    elif gamma.shape[0] == weight.shape[1]:
        # `DepthwiseConv2d` op weight
        shape_list = [1, -1] + [1] * len(weight.shape[2:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    else:
        raise ValueError("Unsupported weight shape({})".format(weight.shape))

    weight = weight * _gamma / _sigma
    bias = beta - gamma * mean / sigma
    return weight, bias
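
# The fold above follows the standard BatchNorm-into-convolution identity
# (per-channel broadcasting elided for brevity):
#
#   w_fold = w * gamma / sqrt(moving_variance + eps)
#   b_fold = beta - gamma * moving_mean / sqrt(moving_variance + eps)
#
# so that conv(x, w_fold) + b_fold reproduces BN(conv(x, w)) at inference time.
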

def without_fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.

    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.
    """
    variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
    mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
    gamma = cell_quant.batchnorm.gamma.data.asnumpy()
    beta = cell_quant.batchnorm.beta.data.asnumpy()
    epsilon = cell_quant.batchnorm.eps
    sigma = np.sqrt(variance + epsilon)

    if gamma.shape[0] == weight.shape[0]:
        # `Conv2d` or `Dense` op weight
        shape_list = [-1] + [1] * len(weight.shape[1:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    elif gamma.shape[0] == weight.shape[1]:
        # `DepthwiseConv2d` op weight
        shape_list = [1, -1] + [1] * len(weight.shape[2:])
        _gamma = gamma.reshape(shape_list)
        _sigma = sigma.reshape(shape_list)
    else:
        raise ValueError("Unsupported weight shape({})".format(weight.shape))

    weight = weight * _gamma / _sigma
    bias = beta - gamma * mean / sigma
    return weight, bias


def compute_kl_threshold(data, bitwidth):
    r"""
    Use the KL-J distance to calculate the clip threshold.

    Args:
        data (numpy.ndarray): Data observed to calculate the threshold for quantization.
        bitwidth (QuantDtype): The quantization datatype.

    Returns:
        float, the threshold used to clip the data.
    """
    data_max = np.abs(data).max()
    if data_max < 1e-5:
        return 1e-5
    hist, bin_edges = np.histogram(np.abs(data), bins='sqrt', range=(0, data_max), density=True)
    # For efficiency, limit the maximum number of bins to 1024 in `sqrt` mode;
    # if the histogram exceeds that size, fall back to the default bins config.
    largest_bin_size = 1024
    if hist.shape[0] > largest_bin_size:
        hist, bin_edges = np.histogram(np.abs(data), range=(0, data_max), density=True)
    hist = hist / np.sum(hist)
    cumsum = np.cumsum(hist)
    bit_pow_range = pow(2, int(bitwidth.num_bits) - 1)
    threshold = []
    scaling_factor = []
    kl = []
    if bit_pow_range + 1 > len(bin_edges) - 1:
        th_layer_out = bin_edges[-1]
        return float(th_layer_out)
    for i in range(bit_pow_range + 1, len(bin_edges), 1):
        threshold_tmp = (i + 0.5) * (bin_edges[1] - bin_edges[0])
        threshold = np.concatenate((threshold, [threshold_tmp]))
        scaling_factor_tmp = threshold_tmp / (bit_pow_range - 1)
        scaling_factor = np.concatenate((scaling_factor, [scaling_factor_tmp]))
        # forward interpolation
        cumsum_tmp = np.copy(cumsum)
        cumsum_tmp[(i - 1):] = 1
        fwd_x = np.linspace(0.0, 1.0, bit_pow_range)
        fwd_xp = np.linspace(0.0, 1.0, i)
        fwd_fp = cumsum_tmp[:i]
        forward_interp = np.interp(fwd_x, fwd_xp, fwd_fp)
        # backward interpolation
        bwd_x = np.linspace(0.0, 1.0, i)
        bwd_xp = np.linspace(0.0, 1.0, bit_pow_range)
        bwd_fp = forward_interp
        backward_interp = np.interp(bwd_x, bwd_xp, bwd_fp)
        cumsum_tmp[:i] = backward_interp
        kl_tmp = np.sum((cumsum - cumsum_tmp) * np.log2(cumsum / cumsum_tmp))  # Kullback-Leibler-J
        kl = np.concatenate((kl, [kl_tmp]))
    th_layer_out = threshold[np.argmin(kl)]
    threshold = float(th_layer_out)
    if threshold < 1e-5:
        threshold = 1e-5
    return threshold
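
# A hedged example of the KL threshold search (assumes `QuantDtype` from
# `mindspore.compression.common`, whose `num_bits` field is what
# `compute_kl_threshold` reads):
#
#   >>> from mindspore.compression.common import QuantDtype
#   >>> weights = np.random.randn(256, 64, 3, 3).astype(np.float32)
#   >>> clip = compute_kl_threshold(weights, QuantDtype.INT8)
#   >>> min_init, max_init = -clip, clip  # symmetric clipping range
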

def query_quant_layers(network):
    r"""
    Query the quantization strategy of each quantized layer in the network and print it
    to the screen. Note that the quantization layers are queried before graph-compile
    optimization in graph mode, so some redundant quantized layers that do not exist in
    the practical execution may appear.

    Args:
        network (Cell): input network
    """
    network = Validator.check_isinstance("network", network, nn.Cell)
    tplt = "{0:60}\t{1:10}"
    for cell_and_name in network.cells_and_names():
        cell_name = cell_and_name[0]
        cell = cell_and_name[1]
        if isinstance(cell, nn.FakeQuantWithMinMaxObserver):
            print(tplt.format(cell_name, cell.quant_dtype))


def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
    r"""
    Load fp32 model parameters into a quantization model.

    Args:
        quant_model (Cell): Quantization model.
        params_dict (dict): Parameter dict that stores fp32 parameters.
        quant_new_params (list): Parameters that exist in the quantization network but not
            in the non-quantization network. Default: None.

    Raises:
        TypeError: If `quant_new_params` is not None and is not list.
        ValueError: If there are parameters in the `quant_model` that are neither in `params_dict`
            nor in `quant_new_params`.
    """
    if quant_new_params is not None and not isinstance(quant_new_params, list):
        raise TypeError("quant_new_params must be list or None.")
    iterable_dict = {
        'minq': iter(list(filter(lambda item: item[0].endswith('minq'), params_dict.items()))),
        'maxq': iter(list(filter(lambda item: item[0].endswith('maxq'), params_dict.items()))),
        'quant_max': iter(list(filter(lambda item: item[0].endswith('quant_max'), params_dict.items())))
    }
    for param in params_dict.items():
        key_name = param[0].split(".")[-1]
        if key_name not in iterable_dict:
            iterable_dict[key_name] = iter(list(filter(lambda item, value=key_name: item[0].endswith(value),
                                                       params_dict.items())))

    for name, param in quant_model.parameters_and_names():
        key_name = name.split(".")[-1]
        if key_name not in iterable_dict.keys():
            # `quant_new_params` may be None, so guard the membership test.
            if quant_new_params is None or key_name not in quant_new_params:
                raise ValueError(f"Can't find matching parameter in ckpt, param name = {name}")
            continue
        value_param = next(iterable_dict[key_name], None)
        if value_param:
            param.set_data(value_param[1].data)
            print(f'init model param {name} with checkpoint param {value_param[0]}')

    # Perform KL_init when learned scale quantization is executed.
    for cell_and_name in quant_model.cells_and_names():
        cell = cell_and_name[1]
        if isinstance(cell, (nn.Conv2dBnFoldQuantOneConv, nn.Conv2dBnFoldQuant, nn.Conv2dBnWithoutFoldQuant,
                             nn.Conv2dQuant, nn.DenseQuant)) and cell.fake_quant_weight.mode == "LEARNED_SCALE":
            subcell_weight_para = cell.weight.data.asnumpy()
            if hasattr(cell, 'gamma'):
                scale_factor = (cell.gamma.data.asnumpy() /
                                np.sqrt(cell.moving_variance.data.asnumpy() + 1e-5))
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)

            if cell.fake_quant_weight.per_channel:
                max_init = [compute_kl_threshold(weight_para_each, cell.fake_quant_weight.quant_dtype)
                            for weight_para_each in subcell_weight_para]
                min_init = [-x for x in max_init]
            else:
                max_init = [compute_kl_threshold(subcell_weight_para, cell.fake_quant_weight.quant_dtype)]
                min_init = [-x for x in max_init]

            cell.fake_quant_weight.reset(quant_dtype=cell.fake_quant_weight.quant_dtype,
                                         min_init=min_init, max_init=max_init)
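
# Typical workflow (a sketch; the checkpoint path and `quant_net` are
# placeholders, and `load_checkpoint` is the standard MindSpore loader):
#
#   >>> from mindspore import load_checkpoint
#   >>> params_dict = load_checkpoint("fp32_model.ckpt")
#   >>> load_nonquant_param_into_quant_net(quant_net, params_dict)
#   >>> query_quant_layers(quant_net)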