# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Quantization aware training

Users can apply quantization aware training to train a model. MindSpore supports quantization
aware training, which models quantization errors in both the forward and backward passes using
fake-quantization operations. Note that the entire computation is carried out in floating point.
At the end of quantization aware training, MindSpore provides conversion functions to convert
the trained model into lower precision.
"""

import re
import mindspore.context as context
import numpy as np
from ... import nn, ops
from ..._checkparam import Validator, Rel
from ...nn.layer import quant
from ...ops import functional as F
from ..common import QuantDtype
from .quantizer import Quantizer, OptimizeOption
from .quant_utils import compute_kl_threshold


__all__ = ["QuantizationAwareTraining", "create_quant_config"]


def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
                        quant_delay=(0, 0),
                        quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                        per_channel=(False, False),
                        symmetric=(False, False),
                        narrow_range=(False, False),
                        mode="DEFAULT"):
    r"""
    Configure the observer types of weights and data flow with quant parameters.

    Args:
        quant_observer (Union[Observer, list, tuple]): The types of observer for quantization. The first element
            applies to weights and the second applies to data flow. Currently, only
            :class:`FakeQuantWithMinMaxObserver` is supported.
            Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver).
        quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized
            during train and eval. The first element represents weights and the second element represents data flow.
            Default: (0, 0).
        quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first
            element represents weights and the second element represents data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8).
        per_channel (Union[bool, list, tuple]): Quantization granularity. If `True`, quantization is applied per
            channel; otherwise it is applied per layer. The first element represents weights and the second
            element represents data flow; the second element must currently be `False`.
            Default: (False, False).
        symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric. If `True`,
            symmetric quantization is used; otherwise asymmetric quantization is used. The first element
            represents weights and the second element represents data flow. Default: (False, False).
        narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range.
            The first element represents weights and the second element represents data flow.
            Default: (False, False).
        mode (str): Optional quantization mode. Currently only `DEFAULT` (QAT) and `LEARNED_SCALE` are supported.
            Default: "DEFAULT".

    Returns:
        QuantConfig, contains the observer types of weight and activation.

    Raises:
        ValueError: If the second element of `per_channel` is not `False`.
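
    Examples:
        >>> # A minimal usage sketch (hedged, not from the original source): per-channel
        >>> # symmetric quantization for weights, per-layer asymmetric for data flow,
        >>> # using the default observers.
        >>> from mindspore.compression.quant import create_quant_config
        >>> quant_config = create_quant_config(per_channel=(True, False), symmetric=(True, False))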
    """
    if per_channel[-1]:
        raise ValueError("The second element of `per_channel` must be `False`.")
    weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
                                                     per_channel=per_channel[0], symmetric=symmetric[0],
                                                     narrow_range=narrow_range[0], mode=mode)
    act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
                                                   per_channel=per_channel[-1], symmetric=symmetric[-1],
                                                   narrow_range=narrow_range[-1], mode=mode)
    return quant.QuantConfig(weight=weight_observer, activation=act_observer)


class _AddFakeQuantInput(nn.Cell):
    """
    Add a FakeQuant OP at the input of the network. Only the single-input case is supported.
    """

    def __init__(self, network, quant_delay=0):
        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
        self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
                                                                  quant_delay=quant_delay, ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        data = self.fake_quant_input(data)
        output = self.network(data)
        return output


class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Add a FakeQuant OP after the given sub Cell.
    """

    def __init__(self, subcell, **kwargs):
        super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
        self.subcell = subcell
        self.mode = "DEFAULT"
        self.max_init = 6
        self.min_init = -6

        if OptimizeOption.LEARNED_SCALE in kwargs["optimize_option"]:
            self.mode = "LEARNED_SCALE"
            self.max_init = 16
            self.min_init = -16

        self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=self.min_init,
                                                                max_init=self.max_init,
                                                                ema=True,
                                                                quant_dtype=kwargs["quant_dtype"],
                                                                quant_delay=kwargs["quant_delay"],
                                                                per_channel=kwargs["per_channel"],
                                                                symmetric=kwargs["symmetric"],
                                                                narrow_range=kwargs["narrow_range"],
                                                                mode=self.mode)

    def construct(self, *data):
        output = self.subcell(*data)
        output = self.fake_quant_act(output)
        return output


class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Whether to use bn fold ops for simulating inference operations. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters are fixed to the global mean
            and variance. Default: 1e7.
        quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized
            during train and eval. The first element represents weights and the second element represents data flow.
            Default: (0, 0).
        quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first
            element represents weights and the second element represents data flow. The precision supported by
            hardware devices must be considered in a practical quantized-inference scenario.
            Default: (QuantDtype.INT8, QuantDtype.INT8).
        per_channel (Union[bool, list, tuple]): Quantization granularity. If `True`, quantization is applied per
            channel; otherwise it is applied per layer.
            The first element represents weights and the second element represents data flow; the second element
            must currently be `False`. Default: (False, False).
        symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric. If `True`,
            symmetric quantization is used; otherwise asymmetric quantization is used. The first element
            represents weights and the second element represents data flow. Default: (False, False).
        narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range.
            The first element represents weights and the second element represents data flow.
            Default: (False, False).
        optimize_option (Union[OptimizeOption, list, tuple]): Specifies the quant algorithm and options. Currently
            only `QAT` and `LEARNED_SCALE` are supported. (Note that if both `QAT` and `LEARNED_SCALE` are
            configured, `LEARNED_SCALE` takes priority. `LEARNED_SCALE` currently only works under certain
            constraints: freeze_bn=0, quant_delay=0, symmetric=True and narrow_range=True. More specifically,
            for operators such as ReLU and ReLU6, which only produce positive values, a negative truncation is
            added to optimize this scenario, and narrow_range automatically matches to False.)
            Default: OptimizeOption.QAT.
        one_conv_fold (bool): Whether to use one conv bn fold ops for simulating inference operations. Default: True.

    Raises:
        TypeError: If the element of `quant_delay` or `freeze_bn` is not int.
        TypeError: If `bn_fold`, `one_conv_fold` or the element of `per_channel`, `symmetric`, `narrow_range`
            is not bool.
        TypeError: If the element of `quant_dtype` is not `QuantDtype`.
        ValueError: If the length of `quant_delay`, `quant_dtype`, `per_channel`, `symmetric` or `narrow_range`
            is greater than 2.
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `freeze_bn` is not equal to 0.
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `symmetric` is not (True, True).
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `narrow_range` is not (True, True).
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `quant_delay` is not (0, 0).

    Examples:
        >>> from mindspore.compression.quant import QuantizationAwareTraining
        >>> class LeNet5(nn.Cell):
        ...     def __init__(self, num_class=10, channel=1):
        ...         super(LeNet5, self).__init__()
        ...         self.type = "fusion"
        ...         self.num_class = num_class
        ...
        ...         # change `nn.Conv2d` to `nn.Conv2dBnAct`
        ...         self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        ...         self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        ...         # change `nn.Dense` to `nn.DenseBnAct`
        ...         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        ...         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        ...         self.fc3 = nn.DenseBnAct(84, self.num_class)
        ...
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.conv1(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.conv2(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.flatten(x)
        ...         x = self.fc1(x)
        ...         x = self.fc2(x)
        ...         x = self.fc3(x)
        ...         return x
        ...
        >>> net = LeNet5()
        >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
        >>> net_qat = quantizer.quantize(net)
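        >>> # A hedged sketch (added, not from the original example): constructing a quantizer
        >>> # with the LEARNED_SCALE option, which requires freeze_bn=0, quant_delay=0,
        >>> # symmetric=(True, True) and narrow_range=(True, True).
        >>> from mindspore.compression.quant import OptimizeOption
        >>> quantizer_ls = QuantizationAwareTraining(freeze_bn=0, quant_delay=[0, 0],
        ...                                          symmetric=[True, True],
        ...                                          narrow_range=[True, True],
        ...                                          optimize_option=OptimizeOption.LEARNED_SCALE)
        >>> net_qat_ls = quantizer_ls.quantize(LeNet5())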
    """
    __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv", "ReduceMean"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT,
                 one_conv_fold=True):
        """Init for QuantizationAwareTraining quantizer"""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            if not isinstance(value, list) and not isinstance(value, tuple):
                value = [value]
            elif len(value) > 2:
                raise ValueError("length of input `{}` should be no more than 2".format(name))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold")
        self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
                                    nn.DenseBnAct: self._convert_dense}
        self.mode = "DEFAULT"
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            self.mode = "LEARNED_SCALE"
            if not self.weight_symmetric or not self.act_symmetric:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports "
                                 "symmetric=(True, True) for quant")
            if not self.weight_range or not self.act_range:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports "
                                 "narrow_range=(True, True) for quant")
            if self.freeze_bn != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports freeze_bn equal to 0, "
                                 "but got freeze_bn={}".format(self.freeze_bn))
            if self.weight_qdelay != 0 or self.act_qdelay != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports quant_delay=(0, 0)")
        self.quant_config = create_quant_config(quant_delay=quant_delay,
                                                quant_dtype=quant_dtype,
                                                per_channel=per_channel,
                                                symmetric=symmetric,
                                                narrow_range=narrow_range,
                                                mode=self.mode)
        self.eps = 1e-5

    @staticmethod
    def _convert_op_name(name):
        """Convert a CamelCase op name to snake_case.
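
        A doctest-style illustration (added for clarity; the outputs follow directly from
        the regular expression below):

        >>> QuantizationAwareTraining._convert_op_name("ReduceMean")
        'reduce_mean'
        >>> QuantizationAwareTraining._convert_op_name("Add")
        'add'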
        """
        pattern = re.compile(r'([A-Z]{1})')
        name_new = re.sub(pattern, r'_\1', name).lower()
        if name_new[0] == '_':
            name_new = name_new[1:]
        return name_new

    def quantize(self, network):
        """
        Quant API to convert the input network to a quantization aware training network.

        Note:
            Please refer to the Examples of class: `mindspore.compression.quant.QuantizationAwareTraining`.

        Args:
            network (Cell): network to be quantized.

        Returns:
            Cell, a quantization aware training network.

        Raises:
            KeyError: If the `device_target` set in context is not in `support_device`.
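
        Examples:
            >>> # A hedged sketch: `LeNet5` and `quantizer` are as defined in the class-level
            >>> # examples, and `device_target` must be "Ascend" or "GPU" for this call.
            >>> net_qat = quantizer.quantize(LeNet5())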
        """
        support_device = ["Ascend", "GPU"]
        if context.get_context('device_target') not in support_device:
            raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

        if OptimizeOption.QAT in self.optimize_option or OptimizeOption.LEARNED_SCALE in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        Convert sub cells such as `Conv2dBnAct` and `DenseBnAct` to quant cells.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # add FakeQuant OP after OPs in the white list, but not after those already wrapped
        # in the quantization cells below.
        if isinstance(network, (nn.FakeQuantWithMinMaxObserver,
                                nn.Conv2dBnFoldQuantOneConv,
                                nn.Conv2dBnFoldQuant,
                                nn.Conv2dBnWithoutFoldQuant,
                                nn.Conv2dQuant,
                                nn.DenseQuant,
                                nn.ActQuant,
                                nn.TensorAddQuant,
                                nn.MulQuant)):
            return network

        add_list = []
        for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            prefix = name
            add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                  quant_dtype=self.act_dtype,
                                                  quant_delay=self.act_qdelay,
                                                  per_channel=self.act_channel,
                                                  symmetric=self.act_symmetric,
                                                  narrow_range=self.act_range,
                                                  optimize_option=self.optimize_option)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, prefix])
            add_quant.update_parameters_name(prefix + '.')
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_conv(self, subcell):
        """
        Convert a `Conv2dBnAct` cell to a quant cell.
        """
        min_init = -6
        max_init = 6
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            subcell_weight_para = subcell.conv.weight.data.asnumpy()
            if subcell.has_bn:
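                # Note (added commentary): this mirrors BatchNorm folding. The folded convolution
                # effectively applies weight * gamma / sqrt(moving_variance + eps) per output
                # channel, so the KL threshold is computed on the folded weights and the resulting
                # min_init/max_init match what is actually quantized at inference time.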
                scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
            min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype)
            self.quant_config = self.quant_config._replace(
                weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

        conv_inner = subcell.conv
        if subcell.has_bn:
            bn_inner = subcell.batchnorm
            if self.bn_fold:
                if self.one_conv_fold:
                    conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels,
                                                                conv_inner.out_channels,
                                                                kernel_size=conv_inner.kernel_size,
                                                                stride=conv_inner.stride,
                                                                pad_mode=conv_inner.pad_mode,
                                                                padding=conv_inner.padding,
                                                                dilation=conv_inner.dilation,
                                                                group=conv_inner.group,
                                                                eps=bn_inner.eps,
                                                                momentum=1 - bn_inner.momentum,
                                                                has_bias=conv_inner.has_bias,
                                                                bias_init=conv_inner.bias_init,
                                                                quant_config=self.quant_config,
                                                                quant_dtype=self.weight_dtype,
                                                                fake=True)
                else:
                    conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                         conv_inner.out_channels,
                                                         kernel_size=conv_inner.kernel_size,
                                                         stride=conv_inner.stride,
                                                         pad_mode=conv_inner.pad_mode,
                                                         padding=conv_inner.padding,
                                                         dilation=conv_inner.dilation,
                                                         group=conv_inner.group,
                                                         eps=bn_inner.eps,
                                                         momentum=1 - bn_inner.momentum,
                                                         has_bias=conv_inner.has_bias,
                                                         bias_init=conv_inner.bias_init,
                                                         freeze_bn=self.freeze_bn,
                                                         quant_config=self.quant_config,
                                                         quant_dtype=self.weight_dtype,
                                                         fake=True)
                # carry the original network's BatchNorm parameters over to the quant cell
                conv_inner.gamma = subcell.batchnorm.gamma
                conv_inner.beta = subcell.batchnorm.beta
                conv_inner.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.moving_variance = subcell.batchnorm.moving_variance
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=1 - bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # carry the original network's BatchNorm parameters over to the quant cell
                conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
                conv_inner.batchnorm.beta = subcell.batchnorm.beta
                conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size, stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode, padding=conv_inner.padding,
                                           dilation=conv_inner.dilation, group=conv_inner.group,
                                           has_bias=conv_inner.has_bias, quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # carry the original network's Conv2d parameters over to the quant cell
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay, per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric, narrow_range=self.act_range,
                                                           optimize_option=self.optimize_option)
        return subcell

    def _convert_dense(self, subcell):
        """
        Convert a `DenseBnAct` cell to a quant cell.
        """
        min_init = -6
        max_init = 6
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            subcell_weight_para = subcell.dense.weight.data.asnumpy()
            if subcell.has_bn:
                scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
            min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype)
            self.quant_config = self.quant_config._replace(
                weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # carry the original network's Dense parameters over to the quant cell
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range,
                                                           optimize_option=self.optimize_option)
        return subcell

    def _convert_activation(self, activation):
        """
        Convert an activation cell to a quant cell.
        """
        act_class = activation.__class__
        act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
        act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]

        if act_class in act_list:
            return quant.ActQuant(activation=activation,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        if act_class in act_list_with_fake_before:
            return quant.ActQuant(activation=activation,
                                  ema=True,
                                  fake_before=True,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        raise ValueError("Unsupported activation in auto quant: ", act_class)

    def _kl_init(self, subcell_weight_para, weight_dtype):
        """
        Calculate the values of max_init and min_init with compute_kl_threshold. The range is
        symmetric: min_init is the negation of max_init, with one entry per output channel when
        per-channel weight quantization is enabled, and a single entry otherwise.
        """
        if self.weight_channel:
            max_init = [compute_kl_threshold(weight_para_each, weight_dtype)
                        for weight_para_each in subcell_weight_para]
            min_init = [-x for x in max_init]
        else:
            max_init = [compute_kl_threshold(subcell_weight_para, weight_dtype)]
            min_init = [-x for x in max_init]
        return min_init, max_init

    def _set_mixed_bits(self, network, strategy):
        r"""
        Set the network's quantization strategy. This function is currently only valid for the
        `LEARNED_SCALE` optimize_option.

        Args:
            network (Cell): Input network.
            strategy (list): The quantization strategy for layers that need to be quantized (e.g. [[8], [8],
                ..., [6], [4], [8]]). Currently only the quant_dtype for weights of the dense layer and the
                convolution layer is supported.

        Returns:
            Cell, a network with a mixed-bit strategy configured.

        Raises:
            ValueError: If `OptimizeOption.LEARNED_SCALE` is not in `self.optimize_option`.
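
        Examples:
            >>> # A hypothetical sketch: the strategy values below are illustrative only, and
            >>> # must contain one bit-width entry per quantizable Conv2dBnAct/DenseBnAct layer.
            >>> # `quantizer_ls` and `net_qat_ls` are as in the class-level LEARNED_SCALE example.
            >>> strategy = [[8], [8], [4], [8], [8]]
            >>> net_qat_ls = quantizer_ls._set_mixed_bits(net_qat_ls, strategy)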
        """
        if OptimizeOption.LEARNED_SCALE not in self.optimize_option:
            raise ValueError("The `_set_mixed_bits` function is currently only valid for the "
                             "`LEARNED_SCALE` optimize_option.")

        quantizable_idx = []
        pass_cell = None
        for i, cell_and_name in enumerate(network.cells_and_names()):
            cell = cell_and_name[1]
            if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)) and cell is not pass_cell:
                quantizable_idx.append(i)

        if len(quantizable_idx) != len(strategy):
            raise ValueError("The number of quantizable layers is not consistent with the length of strategy.")

        quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(quantizable_idx, strategy)}
        type_map = {
            QuantDtype.INT2.num_bits: QuantDtype.INT2,
            QuantDtype.INT3.num_bits: QuantDtype.INT3,
            QuantDtype.INT4.num_bits: QuantDtype.INT4,
            QuantDtype.INT5.num_bits: QuantDtype.INT5,
            QuantDtype.INT6.num_bits: QuantDtype.INT6,
            QuantDtype.INT7.num_bits: QuantDtype.INT7,
            QuantDtype.INT8.num_bits: QuantDtype.INT8
        }
        for i, cell_and_name in enumerate(network.cells_and_names()):
            cell = cell_and_name[1]
            if i not in quantizable_idx:
                continue
            if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                cell.weight_dtype = type_map[quantizable_layer_bit_dict[i][0]]
                if isinstance(cell, nn.Conv2dBnAct):
                    subcell_weight_para = cell.conv.weight.data.asnumpy()
                    if hasattr(cell.conv, 'gamma'):
                        scale_factor = (cell.conv.gamma.data.asnumpy() /
                                        np.sqrt(cell.conv.moving_variance.data.asnumpy() + self.eps))
                        subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
                    min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype)
                    cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                      min_init=min_init,
                                                      max_init=max_init)
                elif isinstance(cell, nn.DenseBnAct):
                    subcell_weight_para = cell.dense.weight.data.asnumpy()
                    if hasattr(cell.dense, 'gamma'):
                        scale_factor = (cell.dense.gamma.data.asnumpy() /
                                        np.sqrt(cell.dense.moving_variance.data.asnumpy() + self.eps))
                        subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
                    min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype)
                    cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                       min_init=min_init,
                                                       max_init=max_init)
        return network