# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16Quantization aware training
17
18User can use quantization aware to train a model. MindSpore supports quantization aware training,
19which models quantization errors in both the forward and backward passes using fake-quantization
20operations. Note that the entire computation is carried out in floating point. At the end of quantization
21aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
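
A typical flow is sketched below (`net` is assumed to be a user-defined `nn.Cell`; dataset
handling, training, and checkpoint export are omitted)::

    from mindspore.compression.quant import QuantizationAwareTraining

    quantizer = QuantizationAwareTraining()
    net_qat = quantizer.quantize(net)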
22"""

import re
import mindspore.context as context
import numpy as np
from ... import nn, ops
from ..._checkparam import Validator, Rel
from ...nn.layer import quant
from ...ops import functional as F
from ..common import QuantDtype
from .quantizer import Quantizer, OptimizeOption
from .quant_utils import compute_kl_threshold


__all__ = ["QuantizationAwareTraining", "create_quant_config"]


def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
                        quant_delay=(0, 0),
                        quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                        per_channel=(False, False),
                        symmetric=(False, False),
                        narrow_range=(False, False),
                        mode="DEFAULT"):
    r"""
    Configure the observer types for weights and data flow with the given quantization parameters.

    Args:
        quant_observer (Union[Observer, list, tuple]): The types of observer for quantization. The first element
            applies to weights and the second applies to data flow. Currently, only
            :class:`FakeQuantWithMinMaxObserver` is supported.
            Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver).
        quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized
            during train and eval. The first element represents weights and the second element represents data flow.
            Default: (0, 0).
        quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first
            element represents weights and the second element represents data flow.
            Default: (QuantDtype.INT8, QuantDtype.INT8).
        per_channel (Union[bool, list, tuple]): Quantization granularity. If `True`, quantization is applied per
            channel; otherwise it is applied per layer. The first element represents weights and the second element
            represents data flow; the second element must currently be `False`.
            Default: (False, False).
        symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric. If `True`, the
            algorithm is symmetric; otherwise it is asymmetric. The first element represents weights and the second
            element represents data flow. Default: (False, False).
        narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range. The first
            element represents weights and the second element represents data flow.
            Default: (False, False).
        mode (str): Optional quantization mode. Currently only `DEFAULT` (QAT) and `LEARNED_SCALE` are supported.
            Default: "DEFAULT".

    Returns:
        QuantConfig, contains the observer types of weight and activation.

    Raises:
        ValueError: If the second element of `per_channel` is not `False`.
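
    Examples:
        >>> from mindspore.compression.quant import create_quant_config
        >>> # A minimal illustration: per-channel symmetric quantization for weights,
        >>> # per-layer asymmetric quantization for data flow.
        >>> qconfig = create_quant_config(per_channel=(True, False), symmetric=(True, False))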
    """
    if per_channel[-1]:
        raise ValueError("The second element of 'per_channel' must be False.")
    weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
                                                     per_channel=per_channel[0], symmetric=symmetric[0],
                                                     narrow_range=narrow_range[0], mode=mode)
    act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
                                                   per_channel=per_channel[-1], symmetric=symmetric[-1],
                                                   narrow_range=narrow_range[-1], mode=mode)
    return quant.QuantConfig(weight=weight_observer, activation=act_observer)


class _AddFakeQuantInput(nn.Cell):
    """
    Add FakeQuant OP at the input of the network. Only the single-input case is supported.
    """

    def __init__(self, network, quant_delay=0):
        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
        self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
                                                                  quant_delay=quant_delay, ema=True)
        self.fake_quant_input.update_parameters_name('fake_quant_input.')
        self.network = network

    def construct(self, data):
        data = self.fake_quant_input(data)
        output = self.network(data)
        return output


class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Add FakeQuant OP after the sub Cell.
    """

    def __init__(self, subcell, **kwargs):
        super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
        self.subcell = subcell
        self.mode = "DEFAULT"
        self.max_init = 6
        self.min_init = -6

        if OptimizeOption.LEARNED_SCALE in kwargs["optimize_option"]:
            self.mode = "LEARNED_SCALE"
            self.max_init = 16
            self.min_init = -16

        self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=self.min_init,
                                                                max_init=self.max_init,
                                                                ema=True,
                                                                quant_dtype=kwargs["quant_dtype"],
                                                                quant_delay=kwargs["quant_delay"],
                                                                per_channel=kwargs["per_channel"],
                                                                symmetric=kwargs["symmetric"],
                                                                narrow_range=kwargs["narrow_range"],
                                                                mode=self.mode)

    def construct(self, *data):
        output = self.subcell(*data)
        output = self.fake_quant_act(output)
        return output


class QuantizationAwareTraining(Quantizer):
    r"""
    Quantizer for quantization aware training.

    Args:
        bn_fold (bool): Whether to use bn fold ops to simulate inference. Default: True.
        freeze_bn (int): Number of steps after which BatchNorm OP parameters are fixed to the global mean
            and variance. Default: 1e7.
        quant_delay (Union[int, list, tuple]): Number of steps after which weights and activations are quantized
            during train and eval. The first element represents weights and the second element represents data flow.
            Default: (0, 0).
        quant_dtype (Union[QuantDtype, list, tuple]): Datatype used to quantize weights and activations. The first
            element represents weights and the second element represents data flow. It is necessary to consider the
            precision supported by the hardware devices in the practical quantization inference scenario.
            Default: (QuantDtype.INT8, QuantDtype.INT8).
        per_channel (Union[bool, list, tuple]): Quantization granularity. If `True`, quantization is applied per
            channel; otherwise it is applied per layer. The first element represents weights and the second element
            represents data flow; the second element must currently be `False`. Default: (False, False).
        symmetric (Union[bool, list, tuple]): Whether the quantization algorithm is symmetric. If `True`, the
            algorithm is symmetric; otherwise it is asymmetric. The first element represents weights and the second
            element represents data flow. Default: (False, False).
        narrow_range (Union[bool, list, tuple]): Whether the quantization algorithm uses narrow range. The first
            element represents weights and the second element represents data flow.
            Default: (False, False).
        optimize_option (Union[OptimizeOption, list, tuple]): Specifies the quant algorithm and options. Currently
            only `QAT` and `LEARNED_SCALE` are supported. (Note that if both `QAT` and `LEARNED_SCALE` are
            configured, `LEARNED_SCALE` has the higher priority. `LEARNED_SCALE` currently only works under certain
            constraints: freeze_bn=0, quant_delay=0, symmetric=True and narrow_range=True. More specifically, for
            operators such as ReLU and ReLU6, which only produce positive values, a negative truncation is added to
            optimize this scenario, and narrow_range automatically switches to False.) Default: OptimizeOption.QAT.
        one_conv_fold (bool): Whether to use one conv bn fold ops to simulate inference. Default: True.

    Raises:
        TypeError: If the element of `quant_delay` or `freeze_bn` is not int.
        TypeError: If `bn_fold`, `one_conv_fold` or the element of `per_channel`, `symmetric`, `narrow_range`
            is not bool.
        TypeError: If the element of `quant_dtype` is not `QuantDtype`.
        ValueError: If the length of `quant_delay`, `quant_dtype`, `per_channel`, `symmetric` or `narrow_range` is
            greater than 2.
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `freeze_bn` is not equal to 0.
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `symmetric` is not (True, True).
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `narrow_range` is not (True, True).
        ValueError: If the `optimize_option` is `LEARNED_SCALE` and `quant_delay` is not (0, 0).

    Examples:
        >>> import mindspore.nn as nn
        >>> from mindspore.compression.quant import QuantizationAwareTraining
        >>> class LeNet5(nn.Cell):
        ...     def __init__(self, num_class=10, channel=1):
        ...         super(LeNet5, self).__init__()
        ...         self.type = "fusion"
        ...         self.num_class = num_class
        ...
        ...         # change `nn.Conv2d` to `nn.Conv2dBnAct`
        ...         self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
        ...         self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
        ...         # change `nn.Dense` to `nn.DenseBnAct`
        ...         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        ...         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        ...         self.fc3 = nn.DenseBnAct(84, self.num_class)
        ...
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...
        ...     def construct(self, x):
        ...         x = self.conv1(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.conv2(x)
        ...         x = self.max_pool2d(x)
        ...         x = self.flatten(x)
        ...         x = self.fc1(x)
        ...         x = self.fc2(x)
        ...         x = self.fc3(x)
        ...         return x
        ...
        >>> net = LeNet5()
        >>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
        >>> net_qat = quantizer.quantize(net)
    """
    __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv", "ReduceMean"]

    def __init__(self,
                 bn_fold=True,
                 freeze_bn=10000000,
                 quant_delay=(0, 0),
                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                 per_channel=(False, False),
                 symmetric=(False, False),
                 narrow_range=(False, False),
                 optimize_option=OptimizeOption.QAT,
                 one_conv_fold=True):
        """Init for QuantizationAwareTraining quantizer."""
        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)

        def convert2list(name, value):
            if not isinstance(value, (list, tuple)):
                value = [value]
            elif len(value) > 2:
                raise ValueError("The length of input `{}` should not be greater than 2.".format(name))
            return value

        quant_delay = convert2list("quant delay", quant_delay)
        quant_dtype = convert2list("quant dtype", quant_dtype)
        per_channel = convert2list("per channel", per_channel)
        symmetric = convert2list("symmetric", symmetric)
        narrow_range = convert2list("narrow range", narrow_range)

        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
        self.one_conv_fold = Validator.check_bool(one_conv_fold, "one conv fold")
        self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
                                    nn.DenseBnAct: self._convert_dense}
        self.mode = "DEFAULT"
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            self.mode = "LEARNED_SCALE"
            if not self.weight_symmetric or not self.act_symmetric:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports "
                                 "symmetric=(True, True) for quant.")
            if not self.weight_range or not self.act_range:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports "
                                 "narrow_range=(True, True) for quant.")
            if self.freeze_bn != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports freeze_bn equal to 0, "
                                 "but got freeze_bn={}.".format(self.freeze_bn))
            if self.weight_qdelay != 0 or self.act_qdelay != 0:
                raise ValueError("OptimizeOption.LEARNED_SCALE currently only supports quant_delay=(0, 0).")
        self.quant_config = create_quant_config(quant_delay=quant_delay,
                                                quant_dtype=quant_dtype,
                                                per_channel=per_channel,
                                                symmetric=symmetric,
                                                narrow_range=narrow_range,
                                                mode=self.mode)
        self.eps = 1e-5

    @staticmethod
    def _convert_op_name(name):
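        """Convert a CamelCase op name to snake_case, e.g. 'ReduceMean' -> 'reduce_mean'."""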
        pattern = re.compile(r'([A-Z])')
        name_new = re.sub(pattern, r'_\1', name).lower()
        if name_new[0] == '_':
            name_new = name_new[1:]
        return name_new

    def quantize(self, network):
        """
        Quant API to convert the input network to a quantization aware training network.

        Note:
            Please refer to the Examples of class: `mindspore.compression.quant.QuantizationAwareTraining`.

        Args:
            network (Cell): Network to be quantized.

        Returns:
            Cell, a quantization aware training network.

        Raises:
            KeyError: If the `device_target` set in context is not in `support_device`.
        """
        support_device = ["Ascend", "GPU"]
        if context.get_context('device_target') not in support_device:
            raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

        if OptimizeOption.QAT in self.optimize_option or OptimizeOption.LEARNED_SCALE in self.optimize_option:
            network.update_cell_prefix()
            network = self._convert_subcells2quant(network)
            network.update_cell_type("quant")
        return network

    def _convert_subcells2quant(self, network):
        """
        Convert sub-cells such as `Conv2dBnAct` and `DenseBnAct` to quant cells.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                prefix = subcell.param_prefix
                new_subcell = self._convert_method_map[type(subcell)](subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # Add FakeQuant OP after ops in the white list, but not inside the quantization cells listed below.
        if isinstance(network, (nn.FakeQuantWithMinMaxObserver,
                                nn.Conv2dBnFoldQuantOneConv,
                                nn.Conv2dBnFoldQuant,
                                nn.Conv2dBnWithoutFoldQuant,
                                nn.Conv2dQuant,
                                nn.DenseQuant,
                                nn.ActQuant,
                                nn.TensorAddQuant,
                                nn.MulQuant)):
            return network
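
        # Scan this cell's attributes for arithmetic primitives listed in
        # `__quant_op_name__`; each one found is re-registered below with a
        # FakeQuant cell appended after it.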
        add_list = []
        for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            prefix = name
            add_quant = _AddFakeQuantAfterSubCell(prim_op,
                                                  quant_dtype=self.act_dtype,
                                                  quant_delay=self.act_qdelay,
                                                  per_channel=self.act_channel,
                                                  symmetric=self.act_symmetric,
                                                  narrow_range=self.act_range,
                                                  optimize_option=self.optimize_option)
            if network.param_prefix:
                prefix = '.'.join([network.param_prefix, prefix])
            add_quant.update_parameters_name(prefix + '.')
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_conv(self, subcell):
        """
        Convert a Conv2d cell to a quant cell.
        """
        min_init = -6
        max_init = 6
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            subcell_weight_para = subcell.conv.weight.data.asnumpy()
            if subcell.has_bn:
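                # Fold the BatchNorm scale (gamma / sqrt(moving_variance + eps)) into the
                # weights, so the KL-based init range matches the folded weights used at inference.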
                scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
            min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype)
        self.quant_config = self.quant_config._replace(
            weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

        conv_inner = subcell.conv
        if subcell.has_bn:
            bn_inner = subcell.batchnorm
            if self.bn_fold:
                if self.one_conv_fold:
                    conv_inner = quant.Conv2dBnFoldQuantOneConv(conv_inner.in_channels,
                                                                conv_inner.out_channels,
                                                                kernel_size=conv_inner.kernel_size,
                                                                stride=conv_inner.stride,
                                                                pad_mode=conv_inner.pad_mode,
                                                                padding=conv_inner.padding,
                                                                dilation=conv_inner.dilation,
                                                                group=conv_inner.group,
                                                                eps=bn_inner.eps,
                                                                momentum=1 - bn_inner.momentum,
                                                                has_bias=conv_inner.has_bias,
                                                                bias_init=conv_inner.bias_init,
                                                                quant_config=self.quant_config,
                                                                quant_dtype=self.weight_dtype,
                                                                fake=True)
                else:
                    conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
                                                         conv_inner.out_channels,
                                                         kernel_size=conv_inner.kernel_size,
                                                         stride=conv_inner.stride,
                                                         pad_mode=conv_inner.pad_mode,
                                                         padding=conv_inner.padding,
                                                         dilation=conv_inner.dilation,
                                                         group=conv_inner.group,
                                                         eps=bn_inner.eps,
                                                         momentum=1 - bn_inner.momentum,
                                                         has_bias=conv_inner.has_bias,
                                                         bias_init=conv_inner.bias_init,
                                                         freeze_bn=self.freeze_bn,
                                                         quant_config=self.quant_config,
                                                         quant_dtype=self.weight_dtype,
                                                         fake=True)
                # Reuse the original network's BatchNorm parameters in the quant cell.
                conv_inner.gamma = subcell.batchnorm.gamma
                conv_inner.beta = subcell.batchnorm.beta
                conv_inner.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.moving_variance = subcell.batchnorm.moving_variance
            else:
                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
                                                            conv_inner.out_channels,
                                                            kernel_size=conv_inner.kernel_size,
                                                            stride=conv_inner.stride,
                                                            pad_mode=conv_inner.pad_mode,
                                                            padding=conv_inner.padding,
                                                            dilation=conv_inner.dilation,
                                                            group=conv_inner.group,
                                                            eps=bn_inner.eps,
                                                            momentum=1 - bn_inner.momentum,
                                                            has_bias=conv_inner.has_bias,
                                                            bias_init=conv_inner.bias_init,
                                                            quant_config=self.quant_config,
                                                            quant_dtype=self.weight_dtype)
                # Reuse the original network's BatchNorm parameters in the quant cell.
                conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
                conv_inner.batchnorm.beta = subcell.batchnorm.beta
                conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
                conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size, stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode, padding=conv_inner.padding,
                                           dilation=conv_inner.dilation, group=conv_inner.group,
                                           has_bias=conv_inner.has_bias, quant_config=self.quant_config,
                                           quant_dtype=self.weight_dtype)
        # Reuse the original network's Conv2d parameters in the quant cell.
        conv_inner.weight = subcell.conv.weight
        if subcell.conv.has_bias:
            conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay, per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric, narrow_range=self.act_range,
                                                           optimize_option=self.optimize_option)
        return subcell

    def _convert_dense(self, subcell):
        """
        Convert a Dense cell to a quant cell.
        """
        min_init = -6
        max_init = 6
        if OptimizeOption.LEARNED_SCALE in self.optimize_option:
            subcell_weight_para = subcell.dense.weight.data.asnumpy()
            if subcell.has_bn:
                # Fold the BatchNorm scale into the 2-D dense weight (one scale per output channel).
                scale_factor = (subcell.batchnorm.gamma.data.asnumpy() /
                                np.sqrt(subcell.batchnorm.moving_variance.data.asnumpy() + self.eps))
                subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1)
            min_init, max_init = self._kl_init(subcell_weight_para, self.weight_dtype)
        self.quant_config = self.quant_config._replace(
            weight=self.quant_config.weight.partial_init(min_init=min_init, max_init=max_init))

        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_config=self.quant_config,
                                       quant_dtype=self.weight_dtype)
        # Reuse the original network's Dense parameters in the quant cell.
        dense_inner.weight = subcell.dense.weight
        if subcell.dense.has_bias:
            dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        elif subcell.after_fake:
            subcell.has_act = True
            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
                                                           quant_dtype=self.act_dtype,
                                                           quant_delay=self.act_qdelay,
                                                           per_channel=self.act_channel,
                                                           symmetric=self.act_symmetric,
                                                           narrow_range=self.act_range,
                                                           optimize_option=self.optimize_option)
        return subcell

    def _convert_activation(self, activation):
        """
        Convert an activation cell to a quant cell.
        """
        act_class = activation.__class__
        act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
        act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]

        if act_class in act_list:
            return quant.ActQuant(activation=activation,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        if act_class in act_list_with_fake_before:
            return quant.ActQuant(activation=activation,
                                  ema=True,
                                  fake_before=True,
                                  quant_config=self.quant_config,
                                  quant_dtype=self.act_dtype)
        raise ValueError("Unsupported activation in auto quant: {}".format(act_class))

    def _kl_init(self, subcell_weight_para, weight_dtype):
        """
        Calculate the values of max_init and min_init with compute_kl_threshold.
        """
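        # Per-channel mode derives one symmetric [-threshold, threshold] range per output
        # channel; per-layer mode derives a single range for the whole weight tensor.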
        if self.weight_channel:
            max_init = [compute_kl_threshold(weight_para_each, weight_dtype)
                        for weight_para_each in subcell_weight_para]
            min_init = [-x for x in max_init]
        else:
            max_init = [compute_kl_threshold(subcell_weight_para, weight_dtype)]
            min_init = [-x for x in max_init]
        return min_init, max_init

    def _set_mixed_bits(self, network, strategy):
        r"""
        Set the network's quantization strategy. This function is currently only valid for the `LEARNED_SCALE`
        optimize_option.

        Args:
            network (Cell): Input network.
            strategy (list): The quantization strategy for the layers to be quantized (e.g. [[8], [8],
                ..., [6], [4], [8]]). Currently only the quant_dtype for weights of the dense layer and the
                convolution layer is supported.

        Returns:
            Cell, a network with the mixed-bit strategy configured.

        Raises:
            ValueError: If `OptimizeOption.LEARNED_SCALE` is not in `self.optimize_option`.
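
        A minimal sketch (assuming `quantizer` is a `QuantizationAwareTraining` instance created
        with the `LEARNED_SCALE` optimize_option, and `net_qat` is the network it returned from
        `quantize`, containing exactly three quantizable layers)::

            net_qat = quantizer._set_mixed_bits(net_qat, strategy=[[8], [4], [8]])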
        """
        if OptimizeOption.LEARNED_SCALE not in self.optimize_option:
            raise ValueError("The `_set_mixed_bits` function is currently only valid for the `LEARNED_SCALE` "
                             "optimize_option.")

        quantizable_idx = []
        pass_cell = None
        for i, cell_and_name in enumerate(network.cells_and_names()):
            cell = cell_and_name[1]
            if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)) and cell is not pass_cell:
                quantizable_idx.append(i)

        if len(quantizable_idx) != len(strategy):
            raise ValueError("The number of quantizable layers is not consistent with the length of `strategy`.")

        quantizable_layer_bit_dict = {idx: bit for idx, bit in zip(quantizable_idx, strategy)}
        type_map = {
            QuantDtype.INT2.num_bits: QuantDtype.INT2,
            QuantDtype.INT3.num_bits: QuantDtype.INT3,
            QuantDtype.INT4.num_bits: QuantDtype.INT4,
            QuantDtype.INT5.num_bits: QuantDtype.INT5,
            QuantDtype.INT6.num_bits: QuantDtype.INT6,
            QuantDtype.INT7.num_bits: QuantDtype.INT7,
            QuantDtype.INT8.num_bits: QuantDtype.INT8
        }
        for i, cell_and_name in enumerate(network.cells_and_names()):
            cell = cell_and_name[1]
            if i not in quantizable_idx:
                continue
            if isinstance(cell, (nn.Conv2dBnAct, nn.DenseBnAct)):
                cell.weight_dtype = type_map[quantizable_layer_bit_dict[i][0]]
                if isinstance(cell, nn.Conv2dBnAct):
                    subcell_weight_para = cell.conv.weight.data.asnumpy()
                    if hasattr(cell.conv, 'gamma'):
                        scale_factor = (cell.conv.gamma.data.asnumpy() /
                                        np.sqrt(cell.conv.moving_variance.data.asnumpy() + self.eps))
                        subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1, 1, 1)
                    min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype)
                    cell.conv.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                      min_init=min_init,
                                                      max_init=max_init)
                elif isinstance(cell, nn.DenseBnAct):
                    subcell_weight_para = cell.dense.weight.data.asnumpy()
                    if hasattr(cell.dense, 'gamma'):
                        scale_factor = (cell.dense.gamma.data.asnumpy() /
                                        np.sqrt(cell.dense.moving_variance.data.asnumpy() + self.eps))
                        subcell_weight_para = subcell_weight_para * scale_factor.reshape(-1, 1)
                    min_init, max_init = self._kl_init(subcell_weight_para, cell.weight_dtype)
                    cell.dense.fake_quant_weight.reset(quant_dtype=cell.weight_dtype,
                                                       min_init=min_init,
                                                       max_init=max_init)
        return network