# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Quantization aware training."""
16
17from functools import partial
18from collections import namedtuple
19import numpy as np
20import mindspore.common.dtype as mstype
21from mindspore.ops.primitive import Primitive
22from mindspore.ops import operations as P
23from mindspore.common.parameter import Parameter
24from mindspore.common.initializer import initializer
25from mindspore.common.tensor import Tensor
26from mindspore._checkparam import Validator, twice
27from mindspore.compression.common import QuantDtype
28import mindspore.context as context
29from .normalization import BatchNorm2d
30from .activation import get_activation
31from ..cell import Cell
32from ... import nn
33from ...ops.operations import _quant_ops as Q
34
35__all__ = [
36    'FakeQuantWithMinMaxObserver',
37    'Conv2dBnFoldQuantOneConv',
38    'Conv2dBnFoldQuant',
39    'Conv2dBnWithoutFoldQuant',
40    'Conv2dQuant',
41    'DenseQuant',
42    'ActQuant',
43    'TensorAddQuant',
44    'MulQuant',
45]


class BatchNormFoldCell(Cell):
    """
    Batch Normalization folded.

    Args:
        momentum (float): Momentum value, must be in [0, 1]. Default: 0.9.
        epsilon (float): A small float number added to the variance to avoid dividing by 0;
            1e-5 if the dtype is float32, otherwise 1e-3. Default: 1e-5.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        Tuple of 4 Tensors, the batch statistics and the running statistics.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
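
    Examples:
        >>> # Illustrative construction sketch only; the folded batch-norm primitives
        >>> # need a GPU or Ascend device context to actually run.
        >>> bn_fold = BatchNormFoldCell(momentum=0.9, epsilon=1e-5, freeze_bn=0)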
72    """
73
74    def __init__(self, momentum=0.9, epsilon=1e-5, freeze_bn=0):
75        """Initialize batch norm fold layer"""
76        super(BatchNormFoldCell, self).__init__()
77        self.epsilon = epsilon
78        self.is_gpu = context.get_context('device_target') == "GPU"
79        if self.is_gpu:
80            self.bn_train = Q.BatchNormFold(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
81            self.bn_infer = Q.BatchNormFold(momentum, epsilon, is_training=False, freeze_bn=freeze_bn)
82        else:
83            self.bn_reduce = P.BNTrainingReduce()
84            self.bn_update = Q.BatchNormFoldD(momentum, epsilon, is_training=True, freeze_bn=freeze_bn)
85
86    def construct(self, x, mean, variance, global_step):
87        if self.is_gpu:
88            if self.training:
89                batch_mean, batch_std, running_mean, running_std = self.bn_train(x, mean, variance, global_step)
90            else:
91                batch_mean, batch_std, running_mean, running_std = self.bn_infer(x, mean, variance, global_step)
92        else:
93            if self.training:
94                x_sum, x_square_sum = self.bn_reduce(x)
95                _, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated = \
96                    self.bn_update(x, x_sum, x_square_sum, mean, variance)
97                P.Assign()(mean, mean_updated)
98                P.Assign()(variance, variance_updated)
99            else:
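                # In evaluation mode the batch statistics collapse to identity
                # (zero mean, unit std), so only the running statistics are applied.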
                batch_mean = P.ZerosLike()(variance)
                batch_std = P.OnesLike()(variance)
                running_mean = P.Add()(mean, 0.)
                running_std = P.Sqrt()(P.Add()(variance, self.epsilon))
        return batch_mean, batch_std, running_mean, running_std


def _partial_init(cls_or_self, **kwargs):
    """
    Wrapper that allows creation of class factories.

    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.

    Examples:
        >>> class Foo:
        ...     def __init__(self, a, b, answer):
        ...         pass
        >>> Foo.partial_init = classmethod(_partial_init)
        >>> foo_builder = Foo.partial_init(a=3, b=4).partial_init(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> result = (id(foo_instance1) == id(foo_instance2))
        >>> print(result)
        False
    """

    class _PartialWrapper:
        r"""
        Wrapper class that allows creation of class factories.
130        """
131
132        def __init__(self, p):
133            self.p = p
134
135        def __call__(self, *args, **keywords):
136            return self.p(*args, **keywords)
137
138        def __repr__(self):
139            return self.p.__repr__()
140
141        partial_init = _partial_init
142
143    r = _PartialWrapper(partial(cls_or_self, **kwargs))
144    return r
145
146
147class _Observer(Cell):
148    """
149    Base class of Observer. Observer is used to calculate the statistics of specific layer.
150
151    Notes:
152        This class is an abstract class.
153
154    Args:
155        quant_dtype (QuantDtype): The type of FakeQuant data.
156    """
157
158    def __init__(self, quant_dtype):
159        """Initialize _Observer."""
160        super(_Observer, self).__init__()
161        self.quant_dtype = quant_dtype
162
163    def extend_repr(self):
164        s = f"quant_dtype={self.quant_dtype}"
165        return s
166
167    def construct(self):
168        pass
169
170    partial_init = classmethod(_partial_init)
171
172
173class UniformQuantObserver(_Observer):
174    """
175    The base class of Uniform Quantization Observer.
176
177    Args:
178        quant_dtype (QuantDtype): The type of FakeQuant data. Default: QuantDtype.INT8.
179        per_channel (bool):  Quantization granularity based on layer or on channel. Default: False.
180        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
181        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        num_channels (int): Declares the channel size of the min and max arrays. Default: 1.

    Returns:
        Tensor.
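
    Examples:
        >>> # Illustrative lookup of the integer range associated with a quantization dtype.
        >>> UniformQuantObserver.min_max_map[QuantDtype.INT8]
        (-128, 127)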
186    """
187
188    min_max_map = {
189        QuantDtype.INT2: (-2, 1),
190        QuantDtype.INT3: (-4, 3),
191        QuantDtype.INT4: (-8, 7),
192        QuantDtype.INT5: (-16, 15),
193        QuantDtype.INT6: (-32, 31),
194        QuantDtype.INT7: (-64, 63),
195        QuantDtype.INT8: (-128, 127),
196
197        QuantDtype.UINT2: (0, 3),
198        QuantDtype.UINT3: (0, 7),
199        QuantDtype.UINT4: (0, 15),
200        QuantDtype.UINT5: (0, 31),
201        QuantDtype.UINT6: (0, 63),
202        QuantDtype.UINT7: (0, 127),
203        QuantDtype.UINT8: (0, 255)
204    }
205
206    def __init__(self, quant_dtype=QuantDtype.INT8, per_channel=False, symmetric=False, narrow_range=False,
207                 num_channels=1):
208        """Initialize UniformQuantObserver."""
209        super(UniformQuantObserver, self).__init__(quant_dtype)
210        self.per_channel = per_channel
211        self.symmetric = symmetric
212        self.narrow_range = narrow_range
213        self.num_channels = num_channels
214
215
216class FakeQuantWithMinMaxObserver(UniformQuantObserver):
217    r"""
218    Quantization aware operation which provides the fake quantization observer function on data with min and max.
219
    The detail of the quantization mode `DEFAULT` is described below:

    The running min/max :math:`x_{min}` and :math:`x_{max}` are computed as:

    .. math::

        \begin{array}{ll} \\
            x_{min} =
            \begin{cases}
                \min(\min(X), 0)
                  & \text{ if } ema = \text{False} \\
                \min((1 - c) \min(X) + \text{c } x_{min}, 0)
                  & \text{ if } \text{otherwise}
            \end{cases}\\
            x_{max} =
            \begin{cases}
                \max(\max(X), 0)
                  & \text{ if } ema = \text{False} \\
                \max((1 - c) \max(X) + \text{c } x_{max}, 0)
                  & \text{ if } \text{otherwise}
            \end{cases}
        \end{array}

    where X is the input tensor, and :math:`c` is the `ema_decay`.

    The scale and zero point zp are computed as:

    .. math::

        \begin{array}{ll} \\
            scale =
            \begin{cases}
                \frac{x_{max} - x_{min}}{Q_{max} - Q_{min}}
                  & \text{ if } symmetric = \text{False} \\
                \frac{2\max(x_{max}, \left | x_{min} \right |) }{Q_{max} - Q_{min}}
                  & \text{ if } \text{otherwise}
            \end{cases}\\
            zp\_min = Q_{min} - \frac{x_{min}}{scale} \\
            zp = \left \lfloor \min(Q_{max}, \max(Q_{min}, zp\_min)) + 0.5 \right \rfloor
        \end{array}

    where :math:`Q_{max}` and :math:`Q_{min}` are decided by quant_dtype; for example, if quant_dtype=INT8,
    then :math:`Q_{max} = 127` and :math:`Q_{min} = -128`.

    The fake quant output is computed as:

    .. math::

        \begin{array}{ll} \\
            u_{min} = (Q_{min} - zp) * scale \\
            u_{max} = (Q_{max} - zp) * scale \\
            u_X = \left \lfloor \frac{\min(u_{max}, \max(u_{min}, X)) - u_{min}}{scale}
            + 0.5 \right \rfloor \\
            output = u_X * scale + u_{min}
        \end{array}

    The detail of the quantization mode `LEARNED_SCALE` is described below:

    The fake quant output is computed as:

    .. math::

        \bar{X}=\left\{\begin{matrix}
        clip\left ( \frac{X}{maxq},0,1\right ) \qquad \quad if\quad neg\_trunc\\
        clip\left ( \frac{X}{maxq},-1,1\right )\qquad \ if\quad otherwise
        \end{matrix}\right. \\

        output=\frac{floor\left ( \bar{X}\ast  Q_{max}+0.5  \right ) \ast scale }{Q_{max}}

    where X is the input tensor, and :math:`Q_{max}` (quant_max) is decided by quant_dtype and neg_trunc;
    for example, if quant_dtype=INT8 and neg_trunc is enabled, :math:`Q_{max} = 255`, otherwise
    :math:`Q_{max} = 127`.

    The maxq parameter is updated during training, and its gradient is calculated as follows:

    .. math::

        \frac{\partial \ output}{\partial \ maxq} = \left\{\begin{matrix}
        -\frac{X}{maxq}+\left \lfloor \frac{X}{maxq} \right \rceil \qquad if\quad bound_{lower}< \frac{X}{maxq}< 1\\
        -1 \qquad \quad \qquad \quad if\quad \frac{X}{maxq}\le bound_{lower}\\
         1  \qquad \quad \qquad \quad if\quad \frac{X}{maxq}\ge  1 \qquad \quad
        \end{matrix}\right. \\

        bound_{lower}=
        \left\{\begin{matrix}
         0\qquad \quad if\quad neg\_trunc\\
        -1\qquad if\quad otherwise
        \end{matrix}\right.

    Then minq is computed as:

    .. math::

        minq=\left\{\begin{matrix}
        0  \qquad \qquad \quad if\quad neg\_trunc\\
        -maxq\qquad if\quad otherwise
        \end{matrix}\right.

    When exporting, the scale and zero point zp are computed as:

    .. math::

        scale=\frac{maxq}{quant\_max} ,\quad zp=0 \\

    zp is always equal to 0, due to the symmetric nature of the `LEARNED_SCALE` mode.

    Args:
        min_init (int, float, list): The initialized min value. Default: -6.
        max_init (int, float, list): The initialized max value. Default: 6.
        ema (bool): Whether the exponential Moving Average algorithm updates min and max. Default: False.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        per_channel (bool):  Quantization granularity based on layer or on channel. Default: False.
        channel_axis (int): Quantization by channel axis. Default: 1.
        num_channels (int): Declares the channel size of the min and max arrays. Default: 1.
        quant_dtype (QuantDtype): The datatype of quantization, supporting 4 and 8 bits. Default: QuantDtype.INT8.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
        neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
        mode (str): Optional quantization mode; currently only `DEFAULT` (QAT) and `LEARNED_SCALE` are supported.
            Default: "DEFAULT".

    Inputs:
        - **x** (Tensor) - The input of FakeQuantWithMinMaxObserver. The input dimension is preferably 2D or 4D.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Raises:
        TypeError: If `min_init` or `max_init` is not int, float or list.
        TypeError: If `quant_delay` is not an int.
        ValueError: If `quant_delay` is less than 0.
        ValueError: If `min_init` is not less than `max_init`.
        ValueError: If `mode` is neither `DEFAULT` nor `LEARNED_SCALE`.
        ValueError: If `mode` is `LEARNED_SCALE` and `symmetric` is not `True`.
        ValueError: If `mode` is `LEARNED_SCALE` and `narrow_range` is not `True`, unless `neg_trunc` is `True`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, nn
        >>> fake_quant = nn.FakeQuantWithMinMaxObserver()
        >>> x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> result = fake_quant(x)
        >>> print(result)
        [[ 0.9882355  1.9764705  0.9882355]
         [-1.9764705  0.        -0.9882355]]
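        >>> # A LEARNED_SCALE sketch (illustrative only); this mode requires symmetric=True and,
        >>> # unless neg_trunc=True, narrow_range=True.
        >>> lsq_fake_quant = nn.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6, symmetric=True,
        ...                                                 narrow_range=True, mode="LEARNED_SCALE")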
368    """
369
370    def __init__(self,
371                 min_init=-6,
372                 max_init=6,
373                 ema=False,
374                 ema_decay=0.999,
375                 per_channel=False,
376                 channel_axis=1,
377                 num_channels=1,
378                 quant_dtype=QuantDtype.INT8,
379                 symmetric=False,
380                 narrow_range=False,
381                 quant_delay=0,
382                 neg_trunc=False,
383                 mode="DEFAULT"):
384        """Initialize FakeQuantWithMinMaxObserver"""
385        super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
386                                                          symmetric=symmetric, narrow_range=narrow_range,
387                                                          num_channels=num_channels)
388        Validator.check_value_type("min_init", min_init, [int, float, list], type(self).__name__)
389        Validator.check_value_type("max_init", max_init, [int, float, list], type(self).__name__)
390        Validator.check_non_negative_int(quant_delay, 'quant_delay', self.cls_name)
391        self.min_init = min_init
392        self.max_init = max_init
393        self.quant_dtype = quant_dtype
394        self.num_bits = quant_dtype.num_bits
395        self.ema = ema
396        self.ema_decay = ema_decay
397        self.per_channel = per_channel
398        self.num_channels = num_channels
399        self.channel_axis = channel_axis
400        self.quant_delay = quant_delay
401        self.symmetric = symmetric
402        self.narrow_range = narrow_range
403        self.neg_trunc = neg_trunc
404        self.mode = mode
405        self.is_ascend = context.get_context('device_target') == "Ascend"
406        self.Neg = P.Neg()
407
408        min_array = self._get_init_array(self.min_init)
409        max_array = self._get_init_array(self.max_init)
410        if not np.greater(max_array, min_array).all():
411            raise ValueError(f"For '{self.cls_name}', the 'max_init' should be greater than 'min_init', "
412                             f"but got 'max_init': {max_init}, 'min_init': {min_init}.")
413        if self.mode == "DEFAULT":
414            self._default_init(min_array, max_array)
415        elif self.mode == "LEARNED_SCALE":
416            self._learned_scale_init(min_array, max_array)
417        else:
418            raise ValueError(f"For '{self.cls_name}', only `DEFAULT` and `LEARNED_SCALE` mode are valid, but got "
419                             f"'mode': {self.mode}.")
420
421    def reset(self, quant_dtype=QuantDtype.INT8, min_init=-6, max_init=6):
422        r"""
        Reset the quant max parameter (e.g. 255) and the initial values of the minq and maxq parameters.
        This function is currently only valid for `LEARNED_SCALE` mode.
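
        A minimal usage sketch (illustrative; only meaningful for a cell created with
        `mode="LEARNED_SCALE"`):

        >>> fake_quant = FakeQuantWithMinMaxObserver(symmetric=True, narrow_range=True,
        ...                                          mode="LEARNED_SCALE")
        >>> fake_quant.reset(quant_dtype=QuantDtype.INT4, min_init=-6, max_init=6)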
425        """
426        if self.mode == "LEARNED_SCALE":
427            self.quant_dtype = quant_dtype
428            self.num_bits = quant_dtype.num_bits
429            self._calculate_quant_max()
430            if self.neg_trunc:
431                min_init = 0
432
433            self.min_init = min_init
434            self.max_init = max_init
435            min_array = self._get_init_array(self.min_init)
436            max_array = self._get_init_array(self.max_init)
437            if not np.greater(max_array, min_array).all():
438                raise ValueError(f"For '{self.cls_name}', the 'max_init' should be greater than 'min_init', "
439                                 f"but got 'max_init': {max_init}, 'min_init': {min_init}.")
440
441            self.minq.set_data(Tensor(min_array))
442            self.maxq.set_data(Tensor(max_array))
443            self.quant_max.set_data(Tensor(np.array([self._quant_max]).astype(np.float32)))
444        else:
445            raise ValueError(f"For '{self.cls_name}', only `LEARNED_SCALE` mode is valid, but got 'mode': {self.mode}.")
446
447    def _default_init(self, min_array, max_array):
448        """
449        Initialization of `DEFAULT`(QAT) mode.
450        """
451        # init tensor min and max for fake quantized operation
452        self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
453        self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
454
455        # init fake quant relative op
456        if self.per_channel:
457            quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
458            ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
459        else:
460            quant_fun = Q.FakeQuantPerLayer
461            ema_fun = Q.MinMaxUpdatePerLayer
462
463        self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
464        if self.is_ascend:
465            self.fake_quant_train = quant_fun(num_bits=self.quant_dtype.num_bits,
466                                              symmetric=self.symmetric,
467                                              narrow_range=self.narrow_range,
468                                              quant_delay=self.quant_delay)
469            self.fake_quant_infer = self.fake_quant_train
470        else:
471            quant_fun = partial(quant_fun,
472                                ema=self.ema,
473                                ema_decay=self.ema_decay,
474                                num_bits=self.quant_dtype.num_bits,
475                                symmetric=self.symmetric,
476                                narrow_range=self.narrow_range,
477                                quant_delay=self.quant_delay)
478            self.fake_quant_train = quant_fun(training=True)
479            self.fake_quant_infer = quant_fun(training=False)
480
481    def _learned_scale_init(self, min_array, max_array):
482        """
483        Initialization of `LEARNED_SCALE` mode.
484        """
485        if not self.symmetric:
486            raise ValueError(f"For '{self.cls_name}', the 'LEARNED_SCALE' mode only support 'symmetric' quant, "
487                             f"but got 'symmetric': {self.symmetric}. Please set 'symmetric' to True.")
488        if self.neg_trunc:
489            min_array = self._get_init_array(0)
490            if self.narrow_range:
491                raise ValueError(f"For '{self.cls_name}', the 'LEARNED_SCALE' mode only support the combination of "
492                                 f"'neg_trunc=True and narrow_range=False' config scenario, but got 'narrow_range': "
493                                 f"{self.narrow_range}.")
494        elif not self.narrow_range:
495            raise ValueError(f"For '{self.cls_name}', the 'LEARNED_SCALE' mode only support 'narrow_range=True' "
496                             f"config, except for 'neg_trunc=True' scenario. But got 'narrow_range': "
497                             f"{self.narrow_range}.")
498
499        self._calculate_quant_max()
500
501        self.minq = Parameter(Tensor(min_array), name='minq')
502        self.maxq = Parameter(Tensor(max_array), name='maxq')
503        self.quant_max = Parameter(Tensor(np.array([self._quant_max]).astype(np.float32)),
504                                   name="quant_max", requires_grad=False)
505
506        # init fake quant relative op
507        if self.per_channel:
508            quant_fun = partial(Q.FakeLearnedScaleQuantPerChannel, channel_axis=self.channel_axis)
509        else:
510            quant_fun = Q.FakeLearnedScaleQuantPerLayer
511
512        quant_fun = partial(quant_fun,
513                            quant_delay=self.quant_delay,
514                            neg_trunc=self.neg_trunc)
515        self.fake_quant_train = quant_fun(training=True)
516        self.fake_quant_infer = quant_fun(training=False)
517
518    def _get_init_array(self, init_date):
519        """
520        Convert the initial value to array.
521        """
522        if isinstance(init_date, list) and self.per_channel and len(init_date) != self.num_channels:
523            raise ValueError(f"For '{self.cls_name}', the length of 'min_init/max_init' list should be equal to "
524                             f"'num_channels' for perchannel quant scenario, but got 'min_init/max_init': {init_date} "
525                             f"and num_channels: {self.num_channels}.")
526        if isinstance(init_date, list) and not self.per_channel and len(init_date) != 1:
527            raise ValueError(f"For '{self.cls_name}', the length of the 'min_init/max_init' list should be 1 for "
528                             f"perlayer quant scenario, but got {len(init_date)}.")
529
530        if isinstance(init_date, list):
531            min_max_array = np.array(init_date).astype(np.float32)
532        elif self.per_channel and not isinstance(init_date, list):
533            min_max_array = np.array([init_date] * self.num_channels).astype(np.float32)
534        else:
535            min_max_array = np.array([init_date]).astype(np.float32)
536        return min_max_array
537
538    def _calculate_quant_max(self):
539        """
540        The quantization range is calculated according to num_bits.
541        """
        if not self.neg_trunc:
            self._quant_max = (1 << (self.num_bits - 1)) - 1
        else:
            self._quant_max = (1 << self.num_bits) - 1

    def extend_repr(self):
        """Display instance object as string."""
        s = 'quant_dtype={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
            'quant_delay={}, min_init={}, max_init={}'.format(self.quant_dtype, self.symmetric, self.narrow_range,
                                                              self.ema, self.ema_decay, self.per_channel,
                                                              self.channel_axis, self.num_channels, self.quant_delay,
                                                              self.min_init, self.max_init)
        return s

    def construct(self, x):
        if self.mode == "LEARNED_SCALE":
            if self.training:
                out = self.fake_quant_train(x, self.maxq, self.quant_max)
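                # Keep minq in sync as the mirror of the learned maxq (symmetric range);
                # with neg_trunc it stays at its initial value of 0.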
                if not self.neg_trunc:
                    self.minq = self.Neg(self.maxq)
            else:
                out = self.fake_quant_infer(x, self.maxq, self.quant_max)
        else:
            if self.training:
                min_up, max_up = self.ema_update(x, self.minq, self.maxq)
                self.minq = min_up
                self.maxq = max_up
                out = self.fake_quant_train(x, self.minq, self.maxq)
            else:
                out = self.fake_quant_infer(x, self.minq, self.maxq)
        return out


QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])

quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver.partial_init(),
                                   activation=FakeQuantWithMinMaxObserver.partial_init())
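
# A custom QuantConfig can be built in the same way; the sketch below is illustrative only
# (`custom_quant_config` is not part of this module) and configures per-channel symmetric
# weight quantization while keeping the default activation observer:
#
#     custom_quant_config = QuantConfig(
#         weight=FakeQuantWithMinMaxObserver.partial_init(per_channel=True, channel_axis=0,
#                                                         symmetric=True, narrow_range=True),
#         activation=FakeQuantWithMinMaxObserver.partial_init())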


class Conv2dBnFoldQuantOneConv(Cell):
    r"""
    2D convolution which uses the convolution layer statistics once to calculate the folded
    Batch Normalization construct.

    This part is a more detailed overview of the Conv2d operation. For more details about Quantization,
    please refer to the implementation of the class :class:`FakeQuantWithMinMaxObserver`.

    .. math::
        w_{q}=quant(\frac{w}{\sqrt{var_{G}+\epsilon}}*\gamma )

        b=\frac{-\mu _{G} }{\sqrt{var_{G}+\epsilon }}*\gamma +\beta

        y=w_{q}\times x+b

    where :math:`quant` is the continuous execution of quant and dequant; you can refer to the
    implementation of the subclass of `FakeQuantWithMinMaxObserver`, :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.
    :math:`\mu_{G}` and :math:`var_{G}` represent the global mean and variance respectively.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): Specifies the height and width of the 2D convolution window.
        stride (Union[int, tuple[int]]): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the `x`. Default: 0.
        dilation (Union[int, tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        eps (float): Parameters for Batch Normalization. Default: 1e-5.
        momentum (float): Parameters for Batch Normalization op. Default: 0.997.
        has_bias (bool): Specifies whether the layer uses a bias vector, which is temporarily invalid. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            bias vector. Default: 'zeros'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: 'zeros'.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: 'ones'.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: 'zeros'.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: 'ones'.
        fake (bool): Whether Conv2dBnFoldQuantOneConv Cell adds FakeQuantWithMinMaxObserver. Default: True.
        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
            activation. Note that, QuantConfig is a special namedtuple, which is designed for quantization
            and can be generated by :func:`mindspore.compression.quant.create_quant_config` method.
            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels` or `group` is not an int.
        TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
        TypeError: If `has_bias` or `fake` is not a bool.
        TypeError: If `data_format` is not a string.
        ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
        ValueError: If `padding` is less than 0.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore.compression import quant
        >>> from mindspore import Tensor, nn
        >>> qconfig = quant.create_quant_config()
        >>> conv2d_bnfold = nn.Conv2dBnFoldQuantOneConv(1, 1, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                                             weight_init="ones", quant_config=qconfig)
        >>> x = Tensor(np.array([[[[1, 0, 3], [1, 4, 7], [2, 5, 2]]]]), mindspore.float32)
        >>> result = conv2d_bnfold(x)
        >>> print(result)
        [[[[5.9296875 13.8359375]
           [11.859375 17.78125]]]]
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 beta_init='zeros',
                 gamma_init='ones',
                 mean_init='zeros',
                 var_init='ones',
                 fake=True,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8):
686        """Initialize Conv2dBnFoldQuant layer"""
        super(Conv2dBnFoldQuantOneConv, self).__init__()
        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
        self.kernel_size = twice(kernel_size)
        self.stride = twice(stride)
        self.dilation = twice(dilation)
        for kernel_size_elem in self.kernel_size:
            Validator.check_positive_int(kernel_size_elem, 'kernel_size item', self.cls_name)
        for stride_elem in self.stride:
            Validator.check_positive_int(stride_elem, 'stride item', self.cls_name)
        for dilation_elem in self.dilation:
            Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
        if pad_mode not in ('valid', 'same', 'pad'):
            raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values "
                             f"in ('valid', 'same', 'pad'), but got {pad_mode}.")
        self.pad_mode = pad_mode
        if isinstance(padding, int):
            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
            self.padding = padding
        elif isinstance(padding, tuple):
            for pad in padding:
                Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
            self.padding = padding
        else:
            raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), but got "
                            f"{type(padding).__name__}!")
        self.group = Validator.check_positive_int(group, "group", self.cls_name)
        self.eps = eps
        self.momentum = 1 - momentum
        self.has_bias = has_bias
        self.fake = Validator.check_bool(fake, "fake", self.cls_name)
        self.quant_config = quant_config
        self.quant_dtype = quant_dtype
        data_format = 'NCHW'
        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        self.is_ge_backend = False
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        self.enable_default_train = self.is_graph_mode and \
                                    (self.is_ge_backend or self._target == "Ascend")

        # initialize convolution op and Parameter
        self.conv = P.Conv2D(out_channel=out_channels,
                             kernel_size=self.kernel_size,
                             pad_mode=pad_mode,
                             pad=padding,
                             stride=self.stride,
                             dilation=self.dilation,
                             group=group)
        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
        channel_axis = 0
        self.channel_axis = channel_axis
        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
        self.bias_add = P.BiasAdd()
        self.bias = None
        if Validator.check_bool(has_bias, "has_bias", self.cls_name):
            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')

        # initialize BatchNorm Parameter
        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
                                         requires_grad=False)

        # initialize fake ops
        self.fake_quant_weight = quant_config.weight(ema=False,
                                                     channel_axis=channel_axis,
                                                     num_channels=out_channels,
                                                     quant_dtype=quant_dtype)
        self.freeze_bn = False
        if self.fake_quant_weight.mode == "LEARNED_SCALE":
            self.freeze_bn = True
        self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps,
                                    momentum=self.momentum, data_format=self.format)

        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
        self.sub_mean = P.Sub()
        self.sub_var = P.Sub()
        self.mul_mean = P.Mul()
        self.mul_var = P.Mul()
        self.assign_sub_mean = P.AssignSub()
        self.assign_sub_var = P.AssignSub()
        self.reshape = P.Reshape()

    def extend_repr(self):
        """Display instance object as string."""
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'fake={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
                                                          self.kernel_size, self.stride,
                                                          self.pad_mode, self.padding, self.dilation,
                                                          self.group,
                                                          self.fake, self.momentum,
                                                          self.fake_quant_weight.quant_delay)
        return s

    def construct(self, x):
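        # Fold BN into the convolution: scale the kernel by gamma / running_std so the
        # fake-quant observer sees the folded (inference-time) weights, then undo the
        # scaling on the conv output before the real BatchNorm runs in training.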
        running_std = P.Sqrt()(P.Add()(self.moving_variance, self.eps))
        scale_factor = self.gamma / running_std
        if self.channel_axis:
            scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
        else:
            scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
        weight = self.weight * scale_factor
        if self.fake:
            weight = self.fake_quant_weight(weight)
        conv = self.conv(x, weight)

        if self.freeze_bn:
            return conv + self.reshape((self.beta - self.gamma * self.moving_mean / running_std), (1, -1, 1, 1))
        scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
        if self.enable_default_train:
            scale_factor = P.Reciprocal()(scale_factor)
            conv_orig = conv * scale_factor
        else:
            conv_orig = conv / scale_factor
        if self.training:
            return self.bn_train(conv_orig,
                                 self.gamma,
                                 self.beta,
                                 self.moving_mean,
                                 self.moving_variance)[0]

        return self.bn_infer(conv_orig,
                             self.gamma,
                             self.beta,
                             self.moving_mean,
                             self.moving_variance)[0]


class Conv2dBnFoldQuant(Cell):
    r"""
    2D convolution with Batch Normalization operation folded construct.

    This part is a more detailed overview of the Conv2d operation. For more details about Quantization,
    please refer to the implementation of the class :class:`FakeQuantWithMinMaxObserver`.

    .. math::
        y = x\times w+  b

        w_{q}=quant(\frac{w}{\sqrt{Var[y]+\epsilon}}*\gamma )

        y_{out}= w_{q}\times x+\frac{b-E[y]}{\sqrt{Var[y]+\epsilon}}*\gamma +\beta

    where :math:`quant` is the continuous execution of quant and dequant. Two convolution
    and Batch Normalization operations are used here; the purpose of the first convolution and Batch
    Normalization is to count the mean `E[y]` and variance `Var[y]` of the current batch output for
    quantization.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): Specifies the height and width of the 2D convolution window.
        stride (Union[int, tuple[int]]): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the `x`. Default: 0.
        dilation (Union[int, tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        eps (float): Parameters for Batch Normalization. Default: 1e-5.
        momentum (float): Parameters for Batch Normalization op. Default: 0.997.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            convolution kernel. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            bias vector. Default: 'zeros'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            beta vector. Default: 'zeros'.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            gamma vector. Default: 'ones'.
        mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            mean vector. Default: 'zeros'.
        var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
            variance vector. Default: 'ones'.
        fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
            activation. Note that, QuantConfig is a special namedtuple, which is designed for quantization
            and can be generated by :func:`mindspore.compression.quant.create_quant_config` method.
            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
        freeze_bn (int): The step (according to the global step) at which the Batch Normalization op is
            frozen for quantization. Default: 100000.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels` or `group` is not an int.
        TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
        TypeError: If `has_bias` or `fake` is not a bool.
        ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
        ValueError: If `padding` is less than 0.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
        ValueError: If `device_target` in context is neither `Ascend` nor `GPU`.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore.compression import quant
        >>> from mindspore import Tensor, nn
        >>> qconfig = quant.create_quant_config()
        >>> conv2d_bnfold = nn.Conv2dBnFoldQuant(1, 1, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
        ...                                      weight_init="ones", quant_config=qconfig)
        >>> x = Tensor(np.array([[[[1, 0, 3], [1, 4, 7], [2, 5, 2]]]]), mindspore.float32)
        >>> result = conv2d_bnfold(x)
        >>> print(result)
        [[[[5.9296875 13.8359375]
           [11.859375 17.78125]]]]
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 eps=1e-5,
                 momentum=0.997,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 beta_init='zeros',
                 gamma_init='ones',
                 mean_init='zeros',
                 var_init='ones',
                 fake=True,
                 quant_config=quant_config_default,
                 quant_dtype=QuantDtype.INT8,
                 freeze_bn=100000):
927        """Initialize Conv2dBnFoldQuant layer"""
928        super(Conv2dBnFoldQuant, self).__init__()
929        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
930        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
931        self.kernel_size = twice(kernel_size)
932        self.stride = twice(stride)
933        self.dilation = twice(dilation)
934        for kernel_size_elem in self.kernel_size:
935            Validator.check_positive_int(kernel_size_elem, 'kernel_size item', self.cls_name)
936        for stride_elem in self.stride:
937            Validator.check_positive_int(stride_elem, 'stride item', self.cls_name)
938        for dilation_elem in self.dilation:
939            Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
940        if pad_mode not in ('valid', 'same', 'pad'):
941            raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values in "
942                             f"('valid', 'same', 'pad'), but got {pad_mode}.")
943        self.pad_mode = pad_mode
944        if isinstance(padding, int):
945            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
946            self.padding = padding
947        elif isinstance(padding, tuple):
948            for pad in padding:
949                Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
950            self.padding = padding
951        else:
952            raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), "
953                            f"but got {type(padding).__name__}!")
954        self.group = Validator.check_positive_int(group, "group", self.cls_name)
955        self.eps = eps
956        self.momentum = momentum
957        self.has_bias = has_bias
958        self.freeze_bn = freeze_bn
959        self.fake = Validator.check_bool(fake, "fake", self.cls_name)
960        self.quant_config = quant_config
961        self.quant_dtype = quant_dtype
962        self.is_gpu = context.get_context('device_target') == "GPU"
963
964        # initialize convolution op and Parameter
965        self.conv = P.Conv2D(out_channel=out_channels,
966                             kernel_size=self.kernel_size,
967                             pad_mode=pad_mode,
968                             pad=padding,
969                             stride=self.stride,
970                             dilation=self.dilation,
971                             group=group)
972        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
973        channel_axis = 0
974        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
975        self.bias_add = P.BiasAdd()
976        self.bias = None
977        if Validator.check_bool(has_bias, "has_bias", self.cls_name):
978            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
979
980        # initialize BatchNorm Parameter
981        self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
982        self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
983        self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
984        self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
985                                         requires_grad=False)
986
987        # initialize fake ops
988        self.fake_quant_weight = quant_config.weight(ema=False,
989                                                     channel_axis=channel_axis,
990                                                     num_channels=out_channels,
991                                                     quant_dtype=quant_dtype)
992        self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
993        self.correct_mul = Q.CorrectionMul(channel_axis)
994        if context.get_context('device_target') == "Ascend":
995            self.batchnorm_fold2_train = Q.BatchNormFold2D(freeze_bn=freeze_bn)
996            self.batchnorm_fold2_infer = Q.BatchNormFold2D(freeze_bn=0)
997        elif context.get_context('device_target') == "GPU":
998            self.batchnorm_fold2_train = Q.BatchNormFold2(freeze_bn=freeze_bn)
999            self.batchnorm_fold2_infer = Q.BatchNormFold2(freeze_bn=0)
1000        else:
1001            raise ValueError(f"For '{self.cls_name}', only the 'Ascend' and 'GPU' platforms"
1002                             f" are supported, but got {context.get_context('device_target')}.")
1003        self.step = Parameter(initializer('normal', [1], dtype=mstype.int32), name='step', requires_grad=False)
1004        self.one = Tensor(1, mstype.int32)
1005        self.assignadd = P.AssignAdd()
1006
    def extend_repr(self):
        """Display instance object as string."""
        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
            'pad_mode={}, padding={}, dilation={}, group={}, ' \
            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
                                                                        self.kernel_size, self.stride,
                                                                        self.pad_mode, self.padding, self.dilation,
                                                                        self.group,
                                                                        self.fake, self.freeze_bn, self.momentum,
                                                                        self.fake_quant_weight.quant_delay)
        return s

    def construct(self, x):
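        # Two-pass scheme: the first convolution only supplies the statistics for BN
        # folding; the second convolution runs with the corrected, fake-quantized weights.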
        out_conv = self.conv(x, self.weight)
        if self.has_bias:
            out_conv = self.bias_add(out_conv, self.bias)
        # BN fold1
        batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold(out_conv,
                                                                               self.moving_mean,
                                                                               self.moving_variance,
                                                                               self.step)
        # fake weight
        weight = self.correct_mul(self.weight, self.gamma, running_std)
        if self.fake:
            weight = self.fake_quant_weight(weight)
        out = self.conv(x, weight)
        if self.has_bias:
            out = self.bias_add(out, self.bias)
        # BN fold2
        if self.is_gpu:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
                self.assignadd(self.step, self.one)
            else:
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma,
                                                 batch_std, batch_mean, running_std, running_mean, self.step)
        else:
            if self.training:
                out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std)
                self.assignadd(self.step, self.one)
            else:
                out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std)
        return out


class Conv2dBnWithoutFoldQuant(Cell):
    r"""
    2D convolution and batch normalization without fold, with fake quantization.

    This part is a more detailed overview of the Conv2d operation. For more details about Quantization,
    please refer to the implementation of the class :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    .. math::
        y =x\times quant(w)+  b

        y_{bn} =\frac{y-E[y] }{\sqrt{Var[y]+  \epsilon  } } *\gamma +  \beta

    where :math:`quant` is the continuous execution of quant and dequant; you can refer to the
    implementation of the class `FakeQuantWithMinMaxObserver`, :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple[int]]): Specifies the height and width of the 2D convolution window.
        stride (Union[int, tuple[int]]): Specifies stride for all spatial dimensions with the same value. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the `x`. Default: 0.
        dilation (Union[int, tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1.
        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        eps (float): Parameters for Batch Normalization. Default: 1e-5.
        momentum (float): Parameters for Batch Normalization op. Default: 0.997.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
            activation. Note that, QuantConfig is a special namedtuple, which is designed for quantization
            and can be generated by :func:`mindspore.compression.quant.create_quant_config` method.
            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `in_channels`, `out_channels` or `group` is not an int.
        TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
        TypeError: If `has_bias` is not a bool.
        ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
        ValueError: If `padding` is less than 0.
        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore.compression import quant
        >>> from mindspore import Tensor, nn
1112        >>> qconfig = quant.create_quant_config()
1113        >>> conv2d_no_bnfold = nn.Conv2dBnWithoutFoldQuant(1, 1, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
1114        ...                                                weight_init='ones', quant_config=qconfig)
1115        >>> x = Tensor(np.array([[[[1, 0, 3], [1, 4, 7], [2, 5, 2]]]]), mindspore.float32)
1116        >>> result = conv2d_no_bnfold(x)
1117        >>> print(result)
1118        [[[[5.929658  13.835868]
1119           [11.859316  17.78116]]]]
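
        A further minimal sketch (reusing the imports and `qconfig` above; the layer sizes are
        illustrative only) that checks just the output shape: with pad_mode "same" and stride 2,
        the spatial size is halved.

        >>> conv_bn = nn.Conv2dBnWithoutFoldQuant(3, 8, kernel_size=3, stride=2, pad_mode="same",
        ...                                       quant_config=qconfig)
        >>> x2 = Tensor(np.ones((1, 3, 32, 32)), mindspore.float32)
        >>> print(conv_bn(x2).shape)
        (1, 8, 16, 16)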
1120    """
1121
1122    def __init__(self,
1123                 in_channels,
1124                 out_channels,
1125                 kernel_size,
1126                 stride=1,
1127                 pad_mode='same',
1128                 padding=0,
1129                 dilation=1,
1130                 group=1,
1131                 has_bias=False,
1132                 eps=1e-5,
1133                 momentum=0.997,
1134                 weight_init='normal',
1135                 bias_init='zeros',
1136                 quant_config=quant_config_default,
1137                 quant_dtype=QuantDtype.INT8):
1138        """Initialize Conv2dBnWithoutFoldQuant."""
1139        super(Conv2dBnWithoutFoldQuant, self).__init__()
1140        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
1141        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
1142        self.has_bias = has_bias
1143        self.kernel_size = twice(kernel_size)
1144        self.stride = twice(stride)
1145        self.dilation = twice(dilation)
1146        for kernel_size_elem in self.kernel_size:
1147            Validator.check_positive_int(kernel_size_elem, 'kernel_size item', self.cls_name)
1148        for stride_elem in self.stride:
1149            Validator.check_positive_int(stride_elem, 'stride item', self.cls_name)
1150        for dilation_elem in self.dilation:
1151            Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
1152        if pad_mode not in ('valid', 'same', 'pad'):
1153            raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values in "
1154                             f"('valid', 'same', 'pad'), but got {pad_mode}.")
1155        self.pad_mode = pad_mode
1156        if isinstance(padding, int):
1157            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
1158            self.padding = padding
1159        elif isinstance(padding, tuple):
1160            for pad in padding:
1161                Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
1162            self.padding = padding
1163        else:
1164            raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), "
1165                            f"but got {type(padding).__name__}!")
1166        self.group = Validator.check_positive_int(group, "group", self.cls_name)
1167        self.bias_add = P.BiasAdd()
1168        if Validator.check_bool(has_bias, "has_bias", self.cls_name):
1169            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
1170        else:
1171            self.bias = None
1172        # initialize convolution op and Parameter
1173        self.conv = P.Conv2D(out_channel=self.out_channels,
1174                             kernel_size=self.kernel_size,
1175                             mode=1,
1176                             pad_mode=self.pad_mode,
1177                             pad=self.padding,
1178                             stride=self.stride,
1179                             dilation=self.dilation,
1180                             group=self.group)
1181        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
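        # The weight layout is (out_channels, in_channels // group, kH, kW), so channel_axis = 0 lets
        # the weight observer collect min/max statistics per output channel when per-channel
        # quantization is enabled in `quant_config`.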
1182        channel_axis = 0
1183        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
1184        self.fake_quant_weight = quant_config.weight(ema=False,
1185                                                     channel_axis=channel_axis,
1186                                                     num_channels=out_channels,
1187                                                     quant_dtype=quant_dtype)
1188        self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
1189
1190    def construct(self, x):
1191        weight = self.fake_quant_weight(self.weight)
1192        out = self.conv(x, weight)
1193        if self.has_bias:
1194            out = self.bias_add(out, self.bias)
1195        out = self.batchnorm(out)
1196        return out
1197
1198    def extend_repr(self):
1199        """Display instance object as string."""
1200        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
1201            'pad_mode={}, padding={}, dilation={}, group={}, ' \
1202            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
1203                                                 self.pad_mode, self.padding, self.dilation, self.group,
1204                                                 self.has_bias, self.fake_quant_weight.quant_delay)
1205        return s
1206
1207
1208class Conv2dQuant(Cell):
1209    r"""
1210    2D convolution with fake quantized operation layer.
1211
1212    This part is a more detailed overview of the Conv2d operation. For more details about Quantization,
1213    please refer to the implementation of the class `FakeQuantWithMinMaxObserver`,
1214    :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.
1215
1216    Args:
1217        in_channels (int): The number of input channels :math:`C_{in}`.
1218        out_channels (int): The number of output channels :math:`C_{out}`.
1219        kernel_size (Union[int, tuple[int]]): Specifies the height and width of the 2D convolution window.
1220        stride (Union[int, tuple[int]]): Specifies stride for all spatial dimensions with the same value. Default: 1.
1221        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
1222        padding (Union[int, tuple[int]]): Implicit paddings on both sides of the `x`. Default: 0.
1223        dilation (Union[int, tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1.
1224        group (int): Splits filter into groups, `in_channels` and `out_channels` must be
1225            divisible by the number of groups. Default: 1.
1226        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
1227        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
1228            Default: 'normal'.
1229        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
1230        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
1231            activation. Note that QuantConfig is a special namedtuple, which is designed for quantization
1232            and can be generated by the :func:`mindspore.compression.quant.create_quant_config` method.
1233            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
1234        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
1235
1236    Inputs:
1237        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
1238          The input should be a 4D tensor.
1239
1240    Outputs:
1241        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
1242
1243    Raises:
1244        TypeError: If `in_channels`, `out_channels` or `group` is not an int.
1245        TypeError: If `kernel_size`, `stride`, `padding` or `dilation` is neither an int nor a tuple.
1246        TypeError: If `has_bias` is not a bool.
1247        ValueError: If `in_channels`, `out_channels`, `kernel_size`, `stride` or `dilation` is less than 1.
1248        ValueError: If `padding` is less than 0.
1249        ValueError: If `pad_mode` is not one of 'same', 'valid', 'pad'.
1250
1251    Supported Platforms:
1252        ``Ascend`` ``GPU``
1253
1254    Examples:
1255        >>> import mindspore
1256        >>> from mindspore.compression import quant
1257        >>> from mindspore import nn, Tensor
        >>> import numpy as np
1258        >>> qconfig = quant.create_quant_config()
1259        >>> conv2d_quant = nn.Conv2dQuant(1, 1, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
1260        ...                               weight_init='ones', quant_config=qconfig)
1261        >>> x = Tensor(np.array([[[[1, 0, 3], [1, 4, 7], [2, 5, 2]]]]), mindspore.float32)
1262        >>> result = conv2d_quant(x)
1263        >>> print(result)
1264        [[[[5.9296875  13.8359375]
1265           [11.859375  17.78125]]]]
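
        A further minimal sketch (reusing the imports and `qconfig` above; the layer sizes are
        illustrative and assume the backend supports grouped convolution) showing a grouped
        convolution with a bias, checking only the output shape:

        >>> conv2d_group = nn.Conv2dQuant(4, 4, kernel_size=3, group=2, pad_mode="same",
        ...                               has_bias=True, quant_config=qconfig)
        >>> x2 = Tensor(np.ones((1, 4, 8, 8)), mindspore.float32)
        >>> print(conv2d_group(x2).shape)
        (1, 4, 8, 8)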
1266    """
1267
1268    def __init__(self,
1269                 in_channels,
1270                 out_channels,
1271                 kernel_size,
1272                 stride=1,
1273                 pad_mode='same',
1274                 padding=0,
1275                 dilation=1,
1276                 group=1,
1277                 has_bias=False,
1278                 weight_init='normal',
1279                 bias_init='zeros',
1280                 quant_config=quant_config_default,
1281                 quant_dtype=QuantDtype.INT8):
1282        """Initialize Conv2dQuant."""
1283        super(Conv2dQuant, self).__init__()
1284        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
1285        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
1286        self.has_bias = has_bias
1287        self.kernel_size = twice(kernel_size)
1288        self.stride = twice(stride)
1289        self.dilation = twice(dilation)
1290        for kernel_size_elem in self.kernel_size:
1291            Validator.check_positive_int(kernel_size_elem, 'kernel_size item', self.cls_name)
1292        for stride_elem in self.stride:
1293            Validator.check_positive_int(stride_elem, 'stride item', self.cls_name)
1294        for dilation_elem in self.dilation:
1295            Validator.check_positive_int(dilation_elem, 'dilation item', self.cls_name)
1296        if pad_mode not in ('valid', 'same', 'pad'):
1297            raise ValueError(f"For '{self.cls_name}', the 'pad_mode' should be one of values "
1298                             f"in ('valid', 'same', 'pad'), but got {pad_mode}.")
1299        self.pad_mode = pad_mode
1300        if isinstance(padding, int):
1301            Validator.check_non_negative_int(padding, 'padding', self.cls_name)
1302            self.padding = padding
1303        elif isinstance(padding, tuple):
1304            for pad in padding:
1305                Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
1306            self.padding = padding
1307        else:
1308            raise TypeError(f"For '{self.cls_name}', the type of 'padding' must be int/tuple(int), "
1309                            f"but got {type(padding).__name__}!")
1310        self.group = Validator.check_positive_int(group, "group", self.cls_name)
1311
1312        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
1313        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
1314
1315        self.bias_add = P.BiasAdd()
1316        if Validator.check_bool(has_bias, "has_bias", self.cls_name):
1317            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
1318        else:
1319            self.bias = None
1320
1321        self.conv = P.Conv2D(out_channel=self.out_channels,
1322                             kernel_size=self.kernel_size,
1323                             mode=1,
1324                             pad_mode=self.pad_mode,
1325                             pad=self.padding,
1326                             stride=self.stride,
1327                             dilation=self.dilation,
1328                             group=self.group)
1329        channel_axis = 0
1330        self.fake_quant_weight = quant_config.weight(ema=False,
1331                                                     channel_axis=channel_axis,
1332                                                     num_channels=out_channels,
1333                                                     quant_dtype=quant_dtype)
1334
1335    def construct(self, x):
1336        weight = self.fake_quant_weight(self.weight)
1337        out = self.conv(x, weight)
1338        if self.has_bias:
1339            return self.bias_add(out, self.bias)
1340        return out
1341
1342    def extend_repr(self):
1343        """Display instance object as string."""
1344        s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
1345            'pad_mode={}, padding={}, dilation={}, group={}, ' \
1346            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
1347                                                 self.pad_mode, self.padding, self.dilation, self.group,
1348                                                 self.has_bias, self.fake_quant_weight.quant_delay)
1349        return s
1350
1351
1352class DenseQuant(Cell):
1353    r"""
1354    The fully connected layer with fake quantized operation.
1355
1356    This part is a more detailed overview of the Dense operation. For more details about Quantization,
1357    please refer to the implementation of the class `FakeQuantWithMinMaxObserver`,
1358    :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.
1359
1360    Args:
1361        in_channels (int): The dimension of the input space.
1362        out_channels (int): The dimension of the output space.
1363        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
1364            is the same as `x`. The values of str refer to the function `initializer`. Default: 'normal'.
1365        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
1366            the same as `x`. The values of str refer to the function `initializer`. Default: 'zeros'.
1367        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
1368        activation (Union[str, Cell, Primitive]): The activation function applied to the output of the layer,
1369            e.g. 'relu'. Default: None.
1370        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
1371            activation. Note that QuantConfig is a special namedtuple, which is designed for quantization
1372            and can be generated by the :func:`mindspore.compression.quant.create_quant_config` method.
1373            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
1374        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
1375
1376    Inputs:
1377        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in})`.
1378          The input dimension should be 2D.
1379
1380    Outputs:
1381        Tensor of shape :math:`(N, C_{out})`.
1382
1383    Raises:
1384        TypeError: If `in_channels` or `out_channels` is not an int.
1385        TypeError: If `has_bias` is not a bool.
1386        TypeError: If `activation` is not one of str, Cell or Primitive.
1387        ValueError: If `in_channels` or `out_channels` is less than 1.
1388        ValueError: If the dims of `weight_init` is not equal to 2 or the first element of `weight_init` is not equal
1389            to `out_channels` or the second element of `weight_init` is not equal to `in_channels`.
1390        ValueError: If the dims of `bias_init` is not equal to 1 or the element of `bias_init` is not equal
1391            to `out_channels`.
1392
1393    Supported Platforms:
1394        ``Ascend`` ``GPU``
1395
1396    Examples:
1397        >>> import mindspore
1398        >>> from mindspore.compression import quant
1399        >>> from mindspore import nn, Tensor
        >>> import numpy as np
1400        >>> qconfig = quant.create_quant_config()
1401        >>> dense_quant = nn.DenseQuant(2, 1, weight_init='ones', quant_config=qconfig)
1402        >>> x = Tensor(np.array([[1, 5], [3, 4]]), mindspore.float32)
1403        >>> result = dense_quant(x)
1404        >>> print(result)
1405        [[5.929413]
1406         [6.9176483]]
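
        A further minimal sketch (reusing the imports and `qconfig` above; the layer sizes are
        illustrative) showing the string form of `activation`, checking only the output shape:

        >>> dense_relu = nn.DenseQuant(3, 4, activation='relu', quant_config=qconfig)
        >>> x2 = Tensor(np.ones((2, 3)), mindspore.float32)
        >>> print(dense_relu(x2).shape)
        (2, 4)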
1407    """
1408
1409    def __init__(self,
1410                 in_channels,
1411                 out_channels,
1412                 weight_init='normal',
1413                 bias_init='zeros',
1414                 has_bias=True,
1415                 activation=None,
1416                 quant_config=quant_config_default,
1417                 quant_dtype=QuantDtype.INT8):
1418        """Initialize DenseQuant."""
1419        super(DenseQuant, self).__init__()
1420        self.in_channels = Validator.check_positive_int(in_channels, "in_channels", self.cls_name)
1421        self.out_channels = Validator.check_positive_int(out_channels, "out_channels", self.cls_name)
1422        self.has_bias = Validator.check_bool(has_bias, "has_bias", self.cls_name)
1423
1424        if isinstance(weight_init, Tensor):
1425            if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
1426                    weight_init.shape[1] != in_channels:
1427                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' should "
1428                                 f"be equal to 2, and the first dim should be equal to 'out_channels', and the "
1429                                 f"second dim should be equal to 'in_channels'. But got 'weight_init': {weight_init}, "
1430                                 f"'out_channels': {out_channels}, 'in_channels': {in_channels}.")
1431
1432        self.weight = Parameter(initializer(
1433            weight_init, [out_channels, in_channels]), name="weight")
1434
1435        if self.has_bias:
1436            if isinstance(bias_init, Tensor):
1437                if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
1438                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' should "
1439                                     f"be equal to 1, and the first dim should be equal to 'out_channels'. But got "
1440                                     f"'bias_init': {bias_init}, 'out_channels': {out_channels}.")
1441
1442            self.bias = Parameter(initializer(
1443                bias_init, [out_channels]), name="bias")
1444
1445        self.matmul = P.MatMul(transpose_b=True)
1446        self.bias_add = P.BiasAdd()
1447
1448        self.activation = get_activation(activation) if isinstance(activation, str) else activation
1449        if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
1450            raise TypeError(f"For '{self.cls_name}', the 'activation' must be str or Cell or Primitive, "
1451                            f"but got {activation}.")
1452
1453        self.activation_flag = self.activation is not None
1454        self.fake_quant_weight = quant_config.weight(ema=False,
1455                                                     channel_axis=0,
1456                                                     num_channels=out_channels,
1457                                                     quant_dtype=quant_dtype)
1458
1459    def construct(self, x):
1460        """Use operators to construct the Dense layer."""
1461        output = self.fake_quant_weight(self.weight)
1462        output = self.matmul(x, output)
1463        if self.has_bias:
1464            output = self.bias_add(output, self.bias)
1465        if self.activation_flag:
1466            return self.activation(output)
1467        return output
1468
1469    def extend_repr(self):
1470        """A pretty print for Dense layer."""
1471        s = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format(
1472            self.in_channels, self.out_channels, self.weight, self.has_bias)
1473        if self.has_bias:
1474            s += ', bias={}'.format(self.bias)
1475        if self.activation_flag:
1476            s += ', activation={}'.format(self.activation)
1477        return s
1478
1479
1480class _QuantActivation(Cell):
1481    r"""
1482    Base class for quantization aware training activation function. Adds fake quantized operation
1483    after activation operation.
1484    """
1485
1486    def get_origin(self):
1487        raise NotImplementedError
1488
1489
1490class ActQuant(_QuantActivation):
1491    r"""
1492    Quantization aware training activation function.
1493
1494    Adds the fake quantized operation to the end of the activation operation, by which the output of the
1495    activation operation will be truncated. For more details about Quantization, please refer to the
1496    implementation of the class `FakeQuantWithMinMaxObserver`, :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.
1497
1498    Args:
1499        activation (Cell): Activation cell.
1500        ema (bool): Whether the Exponential Moving Average algorithm is used to update min and max. Default: False.
1501        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
1502        fake_before (bool): Whether to add the fake quantized operation before the activation. Default: False.
1503        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
1504            activation. Note that QuantConfig is a special namedtuple, which is designed for quantization
1505            and can be generated by the :func:`mindspore.compression.quant.create_quant_config` method.
1506            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
1507        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
1508
1509    Inputs:
1510        - **x** (Tensor) - The input of ActQuant. The input dimension is preferably 2D or 4D.
1511
1512    Outputs:
1513        Tensor, with the same type and shape as the `x`.
1514
1515    Raises:
1516        TypeError: If `activation` is not an instance of Cell.
1517        TypeError: If `fake_before` is not a bool.
1518
1519    Supported Platforms:
1520        ``Ascend`` ``GPU``
1521
1522    Examples:
1523        >>> import mindspore
1524        >>> from mindspore.compression import quant
1525        >>> from mindspore import nn, Tensor
        >>> import numpy as np
1526        >>> qconfig = quant.create_quant_config()
1527        >>> act_quant = nn.ActQuant(nn.ReLU(), quant_config=qconfig)
1528        >>> x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
1529        >>> result = act_quant(x)
1530        >>> print(result)
1531        [[0.9882355 1.9764705 0.       ]
1532         [0.        0.        0.       ]]
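
        A further minimal sketch (reusing the imports and `qconfig` above) with `fake_before=True`,
        so the input is fake quantized both before and after the activation; only the output shape
        is checked:

        >>> act_quant_fb = nn.ActQuant(nn.ReLU6(), fake_before=True, quant_config=qconfig)
        >>> y = act_quant_fb(Tensor(np.array([[1.0, -2.0, 8.0]]), mindspore.float32))
        >>> print(y.shape)
        (1, 3)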
1533    """
1534
1535    def __init__(self,
1536                 activation,
1537                 ema=False,
1538                 ema_decay=0.999,
1539                 fake_before=False,
1540                 quant_config=quant_config_default,
1541                 quant_dtype=QuantDtype.INT8):
1542        """Initialize ActQuant."""
1543        super(ActQuant, self).__init__()
1544        act_class = activation.__class__
1545        act_list = [nn.ReLU, nn.ReLU6]
1546        self.act = Validator.check_isinstance("activation", activation, Cell)
1547        self.fake_before = Validator.check_bool(fake_before, "fake_before", self.cls_name)
1548        if self.fake_before:
1549            self.fake_quant_act_before = quant_config.activation(min_init=-6,
1550                                                                 max_init=6,
1551                                                                 ema=ema,
1552                                                                 ema_decay=ema_decay,
1553                                                                 quant_dtype=quant_dtype)
1554        self.neg_trunc = False
1555        self.narrow_range = False
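        # Inspect the keyword arguments preset on the activation observer partial: in LEARNED_SCALE
        # mode a ReLU/ReLU6 activation only needs the positive range, so negative values are
        # truncated (neg_trunc); otherwise any preset narrow_range setting is reused.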
1556        preset_dict = quant_config.activation.p.keywords
1557        if 'mode' in preset_dict and preset_dict['mode'] == "LEARNED_SCALE" and act_class in act_list:
1558            self.neg_trunc = True
1559        elif 'narrow_range' in preset_dict:
1560            self.narrow_range = preset_dict['narrow_range']
1561
1562        self.fake_quant_act = quant_config.activation(min_init=-6,
1563                                                      max_init=6,
1564                                                      ema=ema,
1565                                                      ema_decay=ema_decay,
1566                                                      quant_dtype=quant_dtype,
1567                                                      neg_trunc=self.neg_trunc,
1568                                                      narrow_range=self.narrow_range)
1569
1570    def construct(self, x):
1571        if self.fake_before:
1572            x = self.fake_quant_act_before(x)
1573        x = self.act(x)
1574        x = self.fake_quant_act(x)
1575        return x
1576
1577    def get_origin(self):
1578        return self.act
1579
1580
1581class TensorAddQuant(Cell):
1582    r"""
1583    Adds the fake quantized operation after the TensorAdd operation.
1584
1585    This part is a more detailed overview of the TensorAdd operation. For more details about Quantization,
1586    please refer to the implementation of the class `FakeQuantWithMinMaxObserver`,
1587    :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.
1588
1589    Args:
1590        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
1591        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
1592            activation. Note that QuantConfig is a special namedtuple, which is designed for quantization
1593            and can be generated by the :func:`mindspore.compression.quant.create_quant_config` method.
1594            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
1595        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
1596
1597    Inputs:
1598        - **x1** (Tensor) - The first tensor of TensorAddQuant. The input dimension is preferably 2D or 4D.
1599        - **x2** (Tensor) - The second tensor of TensorAddQuant. Has the same shape as `x1`.
1600
1601    Outputs:
1602        Tensor, with the same type and shape as the `x1`.
1603
1604    Raises:
1605        TypeError: If `ema_decay` is not a float.
1606        ValueError: If the shape of `x2` is different from that of `x1`.
1607
1608    Supported Platforms:
1609        ``Ascend`` ``GPU``
1610
1611    Examples:
1612        >>> import mindspore
1613        >>> from mindspore.compression import quant
1614        >>> from mindspore import nn, Tensor
        >>> import numpy as np
1615        >>> qconfig = quant.create_quant_config()
1616        >>> add_quant = nn.TensorAddQuant(quant_config=qconfig)
1617        >>> x1 = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
1618        >>> x2 = Tensor(np.ones((2, 3)), mindspore.float32)
1619        >>> output = add_quant(x1, x2)
1620        >>> print(output)
1621        [[ 1.9764705  3.011765   1.9764705]
1622         [-0.9882355  0.9882355  0.       ]]
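
        A further minimal sketch (reusing the imports and `qconfig` above) with 4D inputs, as
        typically used for residual additions in convolutional networks; only the output shape
        is checked:

        >>> add_quant_4d = nn.TensorAddQuant(quant_config=qconfig)
        >>> a = Tensor(np.ones((1, 2, 3, 3)), mindspore.float32)
        >>> b = Tensor(np.ones((1, 2, 3, 3)), mindspore.float32)
        >>> print(add_quant_4d(a, b).shape)
        (1, 2, 3, 3)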
1623    """
1624
1625    def __init__(self,
1626                 ema_decay=0.999,
1627                 quant_config=quant_config_default,
1628                 quant_dtype=QuantDtype.INT8):
1629        """Initialize TensorAddQuant."""
1630        super(TensorAddQuant, self).__init__()
1631        self.fake_quant_act = quant_config.activation(min_init=-6,
1632                                                      max_init=6,
1633                                                      ema=True,
1634                                                      ema_decay=ema_decay,
1635                                                      quant_dtype=quant_dtype)
1636        self.add = P.Add()
1637
1638    def construct(self, x1, x2):
1639        x = self.add(x1, x2)
1640        x = self.fake_quant_act(x)
1641        return x
1642
1643
1644class MulQuant(Cell):
1645    r"""
1646    Adds the fake quantized operation after the `Mul` operation.
1647
1648    This part is a more detailed overview of the `Mul` operation. For more details about Quantization,
1649    please refer to the implementation of the class `FakeQuantWithMinMaxObserver`,
1650    :class:`mindspore.nn.FakeQuantWithMinMaxObserver`.
1651
1652    Args:
1653        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
1654        quant_config (QuantConfig): Configures the types of quant observer and quant settings of weight and
1655            activation. Note that QuantConfig is a special namedtuple, which is designed for quantization
1656            and can be generated by the :func:`mindspore.compression.quant.create_quant_config` method.
1657            Default: QuantConfig with both items set to default :class:`FakeQuantWithMinMaxObserver`.
1658        quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
1659
1660    Inputs:
1661        - **x1** (Tensor) - The first tensor of MulQuant. The input dimension is preferably 2D or 4D.
1662        - **x2** (Tensor) - The second tensor of MulQuant. Has the same shape as `x1`.
1663
1664    Outputs:
1665        Tensor, with the same type and shape as the `x1`.
1666
1667    Raises:
1668        TypeError: If `ema_decay` is not a float.
1669        ValueError: If the shape of `x2` is different from that of `x1`.
1670
1671    Supported Platforms:
1672        ``Ascend`` ``GPU``
1673
1674    Examples:
1675        >>> import mindspore
1676        >>> from mindspore.compression import quant
1677        >>> from mindspore import nn, Tensor
        >>> import numpy as np
1678        >>> qconfig = quant.create_quant_config()
1679        >>> mul_quant = nn.MulQuant(quant_config=qconfig)
1680        >>> x1 = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
1681        >>> x2 = Tensor(np.ones((2, 3)) * 2, mindspore.float32)
1682        >>> output = mul_quant(x1, x2)
1683        >>> print(output)
1684        [[ 1.9764705  4.0000005  1.9764705]
1685         [-4.         0.        -1.9764705]]
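
        A further minimal sketch (reusing `mul_quant`, the imports and `qconfig` above; values are
        illustrative) multiplying by a constant scale tensor of the same shape, checking only the
        output shape:

        >>> scale = Tensor(np.full((2, 3), 0.5), mindspore.float32)
        >>> x3 = Tensor(np.array([[2, 4, 6], [8, 10, 12]]), mindspore.float32)
        >>> print(mul_quant(x3, scale).shape)
        (2, 3)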
1686    """
1687
1688    def __init__(self,
1689                 ema_decay=0.999,
1690                 quant_config=quant_config_default,
1691                 quant_dtype=QuantDtype.INT8):
1692        """Initialize MulQuant."""
1693        super(MulQuant, self).__init__()
1694        self.fake_quant_act = quant_config.activation(min_init=-6,
1695                                                      max_init=6,
1696                                                      ema=True,
1697                                                      ema_decay=ema_decay,
1698                                                      quant_dtype=quant_dtype)
1699        self.mul = P.Mul()
1700
1701    def construct(self, x1, x2):
1702        x = self.mul(x1, x2)
1703        x = self.fake_quant_act(x)
1704        return x
1705