# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Operators for quantization."""
from functools import partial

import mindspore.context as context
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype

if context.get_context('device_target') == "Ascend":
    import mindspore.ops._op_impl._custom_op

__all__ = ["MinMaxUpdatePerLayer",
           "MinMaxUpdatePerChannel",
           "FakeLearnedScaleQuantPerLayer",
           "FakeLearnedScaleQuantPerLayerGrad",
           "FakeLearnedScaleQuantPerLayerGradD",
           "FakeLearnedScaleQuantPerLayerGradDReduce",
           "FakeLearnedScaleQuantPerChannel",
           "FakeLearnedScaleQuantPerChannelGrad",
           "FakeLearnedScaleQuantPerChannelGradD",
           "FakeLearnedScaleQuantPerChannelGradDReduce",
           "FakeQuantWithMinMaxVars",
           "FakeQuantWithMinMaxVarsGradient",
           "FakeQuantWithMinMaxVarsPerChannel",
           "FakeQuantWithMinMaxVarsPerChannelGradient",
           "FakeQuantPerLayer",
           "FakeQuantPerLayerGrad",
           "FakeQuantPerChannel",
           "FakeQuantPerChannelGrad",
           "BatchNormFold",
           "BatchNormFoldGrad",
           "CorrectionMul",
           "CorrectionMulGrad",
           "CorrectionMulGradReduce",
           "BatchNormFold2",
           "BatchNormFold2Grad",
           "BatchNormFoldD",
           "BatchNormFoldGradD",
           "BatchNormFold2D",
           "BatchNormFold2GradD",
           "BatchNormFold2GradReduce",
           "IFMR",
           "ActsULQ",
           "ActsULQInputGrad",
           "ActULQClampMinGrad",
           "ActULQClampMaxGrad",
           "WtsARQ"
           ]

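# The operators in this module simulate ("fake") quantization: values are clipped
# to a [min, max] range, mapped onto a uniform integer grid, and immediately mapped
# back to float so that quantization error is visible during training while the
# graph stays differentiable. The helper below is only an illustrative NumPy sketch
# of that idea (`_fake_quant_sketch` is not part of this module's API, and the
# device kernels may differ in rounding and boundary details).
def _fake_quant_sketch(x, min_val, max_val, num_bits=8, narrow_range=False):
    import numpy as np
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    # Affine (asymmetric) quantization parameters derived from the clip range.
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point = quant_min - np.round(min_val / scale)
    # Quantize onto the integer grid, then dequantize back to float.
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale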

class MinMaxUpdatePerLayer(PrimitiveWithInfer):
    r"""
    Updates min and max per layer.

    Args:
        ema (bool): Uses the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.

    Inputs:
        - **x** (Tensor) : The input tensor, with data type float32.
        - **min** (Tensor) : Value of the min range of the input data x.
        - **max** (Tensor) : Value of the max range of the input data x.

    Outputs:
        - **min_up** (Tensor) : The updated min value, with the same type as the input `min`.
        - **max_up** (Tensor) : The updated max value, with the same type as the input `max`.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> min_up, max_up = MinMaxUpdatePerLayer(ema=False)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self, ema=False, ema_decay=0.999):
        """Initialize FakeQuantMinMaxPerLayerUpdate OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return min_type, max_type

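# The min/max update above is, conceptually, either a plain running min/max or an
# exponential-moving-average (EMA) of the observed data range. A minimal NumPy
# sketch of the EMA form is given for orientation only; `_ema_minmax_sketch` is a
# hypothetical helper, not part of this module, and the exact update used by the
# device kernels may differ.
def _ema_minmax_sketch(x, running_min, running_max, ema_decay=0.999):
    import numpy as np
    # Blend the previous range estimate with the range of the current batch.
    new_min = ema_decay * running_min + (1.0 - ema_decay) * np.min(x)
    new_max = ema_decay * running_max + (1.0 - ema_decay) * np.max(x)
    return new_min, new_max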

class MinMaxUpdatePerChannel(PrimitiveWithInfer):
    r"""
    Updates min and max per channel.

    Args:
        ema (bool): Uses the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **x** (Tensor) : The input tensor, with data type float32.
        - **min** (Tensor) : Value of the min range of the input data x.
        - **max** (Tensor) : Value of the max range of the input data x.

    Outputs:
        - **min_up** (Tensor) : The updated per-channel min values, with the same type as the input `min`.
        - **max_up** (Tensor) : The updated per-channel max values, with the same type as the input `max`.

    Examples:
        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> max_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> min_up, max_up = MinMaxUpdatePerChannel(channel_axis=1)(x, min_value, max_value)
    """
    support_quant_bit = [4, 7, 8]
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
        """Initialize FakeQuantPerChannelUpdate OP for Ascend"""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(
            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return min_type, max_type


class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations of the fake learned scale quant per-layer case in training time.

    Args:
        quant_delay (int): Quantization delay parameter. The simulated quantization is not applied during
            the first `quant_delay` training steps; after that, the fake quantize function takes
            effect. Default: 0.
        neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
        training (bool): Training the network or not. Default: True.

    Inputs:
        - **input_x** (Tensor) : Input tensor that needs to be quantized.
        - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`.
        - **quant_max** (Tensor) : Value of the quantization range.

    Outputs:
        - Tensor: Simulated quantized tensor of `input_x`, with the same type and shape as `input_x`.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> alpha_tensor = Tensor(np.array([6]), mstype.float32)
        >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
        >>> output_tensor = FakeLearnedScaleQuantPerLayer()(input_tensor, alpha_tensor, quant_max_tensor)
    """
    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False,
                 training=True):
        """init FakeLearnedScaleQuantPerLayer OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer

        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
                                outputs=['out'])

    def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
        validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return input_x_shape

    def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("input_x", "alpha", "quant_max"),
                  (input_x_type, alpha_type, quant_max_type)))
        return input_x_type

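# FakeLearnedScaleQuantPerLayer learns the clipping value `alpha` instead of using
# fixed min/max statistics (an LSQ-style scheme). Roughly, the input is clipped to
# [-alpha, alpha] (or [0, alpha] with `neg_trunc=True`), scaled onto the integer
# grid defined by `quant_max`, rounded, and scaled back. The NumPy sketch below is
# illustrative only (`_learned_scale_fake_quant_sketch` is a hypothetical helper,
# not this operator's kernel, which may differ in details):
def _learned_scale_fake_quant_sketch(x, alpha, quant_max, neg_trunc=False):
    import numpy as np
    lower = 0.0 if neg_trunc else -alpha
    x_clip = np.clip(x, lower, alpha)
    # Normalize by the learned scale, round on the integer grid, and rescale.
    return np.round(x_clip / alpha * quant_max) / quant_max * alpha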

class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeLearnedScaleQuantPerLayer operation.

    Examples:
        >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerLayerGrad()
        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
        >>> _alpha = Tensor(np.array([6]), mindspore.float32)
        >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
        >>> result = fake_learned_scale_grad(dout, input_x, _alpha, _quant_max)
    """

    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False):
        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return dout_shape, alpha_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, alpha_type


class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer):
    r"""
    Performs input grad of FakeLearnedScaleQuantPerLayer operation.
    """

    @prim_attr_register
    def __init__(self,
                 neg_trunc=False):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return dout_shape, dout_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, dout_type


class FakeLearnedScaleQuantPerLayerGradDReduce(PrimitiveWithInfer):
    r"""
    Performs alpha grad reduce of FakeLearnedScaleQuantPerLayer operation.
    """

    @prim_attr_register
    def __init__(self):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad_reduce
        self.init_prim_io_names(
            inputs=['dout_alpha'], outputs=['dalpha'])

    def infer_shape(self, dout_alpha_shape):
        return (1,)

    def infer_dtype(self, dout_alpha_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
        return dout_alpha_type


class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations of the fake learned scale quant per-channel case in training time.

    Args:
        quant_delay (int): Quantization delay parameter. The simulated quantization is not applied during
            the first `quant_delay` training steps; after that, the fake quantize function takes
            effect. Default: 0.
        neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
        training (bool): Training the network or not. Default: True.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **input_x** (Tensor) : Input tensor that needs to be quantized.
        - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`.
        - **quant_max** (Tensor) : Value of the quantization range.

    Outputs:
        - Tensor: Simulated quantized tensor of `input_x`, with the same type and shape as `input_x`.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> alpha_tensor = Tensor(np.array([6]*3), mstype.float32)
        >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
        >>> output_tensor = FakeLearnedScaleQuantPerChannel()(input_tensor, alpha_tensor, quant_max_tensor)
    """
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False,
                 training=True,
                 channel_axis=1):
        """init FakeLearnedScaleQuantPerChannel OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel
        self.is_ascend = context.get_context('device_target') == "Ascend"
        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
                                outputs=['out'])

    def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
        if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
        if len(input_x_shape) == 1:
            self.channel_axis = 0

        validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return input_x_shape

    def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("input_x", "alpha", "quant_max"),
                  (input_x_type, alpha_type, quant_max_type)))
        return input_x_type


class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeLearnedScaleQuantPerChannel operation.

    Examples:
        >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerChannelGrad()
        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
        >>> _alpha = Tensor(np.array([6]*2), mindspore.float32)
        >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
        >>> result = fake_learned_scale_grad(dout, input_x, _alpha, _quant_max)
    """

    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False,
                 channel_axis=1):
        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        return dout_shape, alpha_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, alpha_type


class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer):
    r"""
    Performs input grad of FakeLearnedScaleQuantPerChannel operation.
    """

    @prim_attr_register
    def __init__(self,
                 neg_trunc=False,
                 channel_axis=1):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return dout_shape, dout_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, dout_type


class FakeLearnedScaleQuantPerChannelGradDReduce(PrimitiveWithInfer):
    r"""
    Performs alpha grad reduce of FakeLearnedScaleQuantPerChannel operation.
    """

    @prim_attr_register
    def __init__(self, channel_axis=1):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad_reduce
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout_alpha'], outputs=['dalpha'])

    def infer_shape(self, dout_alpha_shape):
        return (dout_alpha_shape[self.channel_axis],)

    def infer_dtype(self, dout_alpha_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
        return dout_alpha_type


class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
    r"""
    Fake-quantize the input by min and max.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]; otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **x** (Tensor) - The float32 input tensor to be fake-quantized.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - Tensor, the data type and shape of output tensor is the same as input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
        ...                 input_tensor, min_tensor, max_tensor)
        >>> output_tensor # shape: (3, 16, 5, 5)  data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return x_type


class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVars operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]; otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVars.
        - **x** (Tensor) - The float32 input tensor of the forward operation.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8, narrow_range=False)
        ...                                          (gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient   # shape: (3, 16, 5, 5)  data type: mstype.float32
        >>> min_gradient # shape: (1,)           data type: mstype.float32
        >>> max_gradient # shape: (1,)           data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ('dout', "x", "min", "max"),
                  (dout_type, x_type, min_type, max_type)))
        return x_type, min_type, max_type

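# The gradients of the min/max fake-quantize ops follow the usual straight-through
# estimator pattern: the incoming gradient passes through where x lies inside the
# clip range, and gradient mass outside the range is attributed to the range
# parameters. The NumPy sketch below is illustrative only (`_fake_quant_grad_ste_sketch`
# is a hypothetical helper; boundary handling in the real kernels may differ):
def _fake_quant_grad_ste_sketch(gradients, x, min_val, max_val):
    import numpy as np
    inside = (x >= min_val) & (x <= max_val)
    d_x = gradients * inside                       # pass-through inside the range
    d_min = np.sum(gradients * (x < min_val))      # gradient w.r.t. the lower bound
    d_max = np.sum(gradients * (x > max_val))      # gradient w.r.t. the upper bound
    return d_x, d_min, d_max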

class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
    r"""
    Fake-quantize the input of shape [d], [b, d], or [b, h, w, d] by per-channel min and max.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]; otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **x** (Tensor) - The float32 input tensor to be fake-quantized.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - Tensor, the data type and shape of output tensor is the same as input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(
        ...                 input_tensor, min_tensor, max_tensor)
        >>> output_tensor # shape: (3, 16, 3, 4)  data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return x_type


class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVarsPerChannel operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]; otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVarsPerChannel.
        - **x** (Tensor) - The float32 input tensor of the forward operation.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
        ...                                          num_bits=8, narrow_range=False)(
        ...                                          gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient   # shape: (3, 16, 3, 4)  data type: mstype.float32
        >>> min_gradient # shape: (4,)           data type: mstype.float32
        >>> max_gradient # shape: (4,)           data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("dout", "x", "min", "max"),
                  (dout_type, x_type, min_type, max_type)))
        return x_type, min_type, max_type


def _fake_quant_per_infer_dtype(prim_name, x_type, min_type, max_type):
    if context.get_context('device_target') == "GPU":
        valid_dtypes = (mstype.float32,)
    else:
        valid_dtypes = (mstype.float16, mstype.float32)
    tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
              ("x", "min", "max"),
              (x_type, min_type, max_type)))
    return x_type


def _fake_quant_per_grad_infer_dtype(prim_name, dout_type, x_type, min_type, max_type):
    if context.get_context('device_target') == "GPU":
        valid_dtypes = (mstype.float32,)
    else:
        valid_dtypes = (mstype.float16, mstype.float32)
    tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
              ("dout", "x", "min", "max"),
              (dout_type, x_type, min_type, max_type)))
    return dout_type


class FakeQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations in training time.

    Args:
        num_bits (int) : Number of bits for quantization. Default: 8.
        ema (bool): Uses the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. The simulated quantization is not applied during
            the first `quant_delay` training steps; after that, the fake quantize function takes
            effect. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.

    Inputs:
        - **x** (Tensor) : The input tensor, with data type float32.
        - **min** (Tensor) : Value of the min range of the input data x.
        - **max** (Tensor) : Value of the max range of the input data x.

    Outputs:
        - Tensor: The fake-quantized tensor of x, with the same shape and data type as x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True):
        """Initialize FakeQuantPerLayer OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type('training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)


class FakeQuantPerLayerGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantPerLayer operation.

    Examples:
        >>> fake_min_max_grad = FakeQuantPerLayerGrad()
        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
        >>> _min = Tensor(np.array([-4]), mindspore.float32)
        >>> _max = Tensor(np.array([2]), mindspore.float32)
        >>> result = fake_min_max_grad(dout, input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")

        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_value_type(
            'quant_delay', quant_delay, (int,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape",
                        x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)


class FakeQuantPerChannel(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations in training time on a per-channel basis.

    Args:
        num_bits (int) : Number of bits for quantization. Default: 8.
        ema (bool): Uses the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. The weight data is not fake-quantized during
            the first `quant_delay` training steps; after that, the simulated quantize operation takes
            effect. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **x** (Tensor) : The input tensor, with data type float32.
        - **min** (Tensor) : Value of the min range of the input data, per channel.
        - **max** (Tensor) : Value of the max range of the input data, per channel.

    Outputs:
        - Tensor, has the same type as input.

    Examples:
        >>> fake_quant = FakeQuantPerChannel(channel_axis=1)
        >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
        >>> _min = Tensor(np.array([-2.0, -1.0]), mindspore.float32)
        >>> _max = Tensor(np.array([8.0, 12.0]), mindspore.float32)
        >>> result = fake_quant(input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True,
                 channel_axis=1):
        """Initialize FakeQuantPerChannel OP"""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        if len(x_shape) == 1:
            self.channel_axis = 0
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
        validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)

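# For per-channel fake quantization, min/max hold one clip range per channel and are
# broadcast along `channel_axis` before the same quantize/dequantize round trip as in
# the per-layer case. Illustrative NumPy sketch only (`_fake_quant_per_channel_sketch`
# is a hypothetical helper, not the kernel used by FakeQuantPerChannel):
def _fake_quant_per_channel_sketch(x, min_vals, max_vals, channel_axis=1,
                                   num_bits=8, narrow_range=False):
    import numpy as np
    # Reshape the per-channel ranges so they broadcast against x along channel_axis.
    bshape = [1] * x.ndim
    bshape[channel_axis] = -1
    min_vals = np.reshape(min_vals, bshape)
    max_vals = np.reshape(max_vals, bshape)
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    scale = (max_vals - min_vals) / (quant_max - quant_min)
    zero_point = quant_min - np.round(min_vals / scale)
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale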

class FakeQuantPerChannelGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantPerChannel operation.

    Examples:
        >>> fqmmpc_grad = FakeQuantPerChannelGrad()
        >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
        >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
        >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
        >>> result = fqmmpc_grad(dout, input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 channel_axis=1):
        """Initialize FakeQuantPerChannelGrad"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")

        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_value_type(
            'quant_delay', quant_delay, (int,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape", x_shape)
        validator.check("min shape", min_shape, "max shape", max_shape)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)


class BatchNormFold(PrimitiveWithInfer):
    """
    Batch Normalization folded.

    Args:
        momentum (float): Momentum value, which must be in [0, 1]. Default: 0.9.
        epsilon (float): A small float number added to the variance to avoid dividing by 0. 1e-5 if dtype in
            float32 else 1e-3. Default: 1e-5.
        is_training (bool): In training mode set True, else set False. Default: True.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        Tuple of 4 Tensors, the batch statistics and the updated running statistics.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Examples:
        >>> batch_norm_fold = P.BatchNormFold()
        >>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
        >>> mean = Tensor(np.array([0.5, -1, 1]), mindspore.float32)
        >>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
        >>> global_step = Tensor(np.arange(6), mindspore.int32)
        >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize batch norm fold layer"""
        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)

        self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
                                outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])

    def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
        validator.check("mean shape", mean_shape, "variance shape", variance_shape, Rel.EQ, self.name)
        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return mean_shape, mean_shape, mean_shape, mean_shape

    def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
        validator.check("input type", x_type, "mean type", mean_type)
        validator.check("input type", x_type, "variance type", variance_type)
        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type, x_type, x_type, x_type

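# BatchNormFold exposes the statistics needed to fold batch normalization into the
# preceding convolution for quantization-aware training: per-channel batch statistics
# plus running statistics blended with `momentum`. A rough NumPy sketch of those
# quantities is given below for orientation only; `_batch_norm_fold_stats_sketch` is a
# hypothetical helper, and the kernel's exact update rule and freeze_bn handling differ.
def _batch_norm_fold_stats_sketch(x, running_mean, running_var, momentum=0.9, epsilon=1e-5):
    import numpy as np
    axes = tuple(i for i in range(x.ndim) if i != 1)  # reduce over all axes except channel (axis 1)
    batch_mean = np.mean(x, axis=axes)
    batch_var = np.var(x, axis=axes)
    batch_std = np.sqrt(batch_var + epsilon)
    # Blend running statistics with the current batch using the momentum factor.
    new_running_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_running_var = momentum * running_var + (1.0 - momentum) * batch_var
    running_std = np.sqrt(new_running_var + epsilon)
    return batch_mean, batch_std, new_running_mean, running_std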

class BatchNormFoldGrad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold operation.

    Examples:
        >>> batch_norm_fold_grad = ops.BatchNormFoldGrad()
        >>> d_batch_mean = Tensor(np.random.randint(-2., 2., (1, 2, 2, 3)), mindspore.float32)
        >>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-8., 8., (1, 2, 2, 3)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
        >>> global_step = Tensor([2], mindspore.int32)
        >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize BatchNormGrad layer"""
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                    global_step_shape):
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                        "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                    global_step_type):
        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type


1117class CorrectionMul(PrimitiveWithInfer):
1118    """
    Scales the weights with a correction factor based on the long-term statistics
    prior to quantization. This ensures that there is no jitter in the quantized weights
    due to batch-to-batch variation.
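
    Conceptually (a sketch of the intended math, not necessarily the exact kernel), the output is the
    input scaled per channel by the ratio of the batch statistic to the running statistic,
    :math:`out = x \cdot batch\_std / running\_std`, broadcast along `channel_axis`.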
1122
1123    Inputs:
1124        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
1125        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
1126        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
1127
1128    Outputs:
1129        - **out** (Tensor) - Tensor has the same shape as x.
1130
1131    Examples:
1132        >>> correction_mul = ops.CorrectionMul()
1133        >>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
1134        >>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
1135        >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
1136        >>> out = correction_mul(input_x, batch_std, running_std)
1137    """
1138
1139    @prim_attr_register
1140    def __init__(self, channel_axis=0):
1141        """Initialize correction mul layer"""
1142        if context.get_context('device_target') == "Ascend":
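            # The import is only for its side effect: it registers the Ascend (TBE) custom-op
            # implementation of CorrectionMul.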
1143            from mindspore.ops._op_impl._custom_op import correction_mul
1144        self.channel_axis = channel_axis
1145        self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
1146                                outputs=['out'])
1147
1148    def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
1149        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1150        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1151                        Rel.EQ, self.name)
1152        return x_shape
1153
1154    def infer_dtype(self, x_type, batch_std_type, running_std_type):
1155        args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
1156        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1157        return x_type
1158
1159
1160class CorrectionMulGrad(PrimitiveWithInfer):
1161    r"""
1162    Performs grad of CorrectionMul operation.
1163
1164    Examples:
1165        >>> correction_mul_grad = ops.CorrectionMulGrad()
1166        >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
1167        >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
1168        >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
1169        >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
1170        >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
1171    """
1172
1173    @prim_attr_register
1174    def __init__(self, channel_axis=0):
1175        """Initialize correction mul layer"""
1176        if context.get_context('device_target') == "Ascend":
1177            from mindspore.ops._op_impl._custom_op import correction_mul_grad
1178        self.channel_axis = channel_axis
1179        self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
1180                                outputs=['dx', 'mul_dx'])
1181
1182    def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
1183        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
1184        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
1185                        Rel.EQ, self.name)
1186        validator.check("running_std_shape[0]", running_std_shape[0],
1187                        "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
1188        if context.get_context('device_target') == "Ascend":
1189            return x_shape, x_shape
1190        return x_shape, gamma_shape
1191
1192    def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
1193        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
1194        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1195        if context.get_context('device_target') == "Ascend":
1196            return x_type, x_type
1197        return x_type, gamma_type
1198
1199
1200class CorrectionMulGradReduce(PrimitiveWithInfer):
1201    r"""
1202    Performs grad reduce of CorrectionMul operation.
1203
1204    Examples:
1205        >>> correction_mul_grad_rd = ops.CorrectionMulGradReduce()
        >>> mul_dx = Tensor(np.random.uniform(-4, 4, (2, 1, 1, 3)), mindspore.float32)
        >>> d_gamma = correction_mul_grad_rd(mul_dx)
1211    """
1212
1213    @prim_attr_register
1214    def __init__(self, channel_axis=0):
1215        """Initialize correction mul reduce layer"""
1216        if context.get_context('device_target') == "Ascend":
1217            from mindspore.ops._op_impl._custom_op import correction_mul_grad
1218        self.channel_axis = channel_axis
1219        self.init_prim_io_names(inputs=['mul_dx'],
1220                                outputs=['d_gamma'])
1221
1222    def infer_shape(self, mul_dx_shape):
1223        return [mul_dx_shape[self.channel_axis]]
1224
1225    def infer_dtype(self, mul_dx_type):
1226        return mul_dx_type
1227
1228
1229class BatchNormFold2(PrimitiveWithInfer):
1230    """
    Scales the bias with a correction factor based on the long-term statistics
    prior to quantization. This ensures that there is no jitter in the quantized bias
    due to batch-to-batch variation.
1234
1235    Inputs:
1236        - **x** (Tensor)  - Tensor of shape :math:`(N, C)`.
1237        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
1238        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
1239        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
1240        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1241        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
1242        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1243        - **global_step** (Tensor) - Tensor to record current global step.
1244
1245    Outputs:
1246        - **y** (Tensor) - Tensor has the same shape as x.
1247
1248    Examples:
1249        >>> batch_norm_fold2 = ops.BatchNormFold2()
1250        >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
1251        >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
1252        >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
1253        >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
1254        >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
1255        >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
1256        >>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
1257        >>> global_step = Tensor(np.random.randint(1, 8, (8, )), mindspore.int32)
1258        >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
        ...                           running_std, running_mean, global_step)
1260    """
1261    channel_axis = 1
1262
1263    @prim_attr_register
1264    def __init__(self, freeze_bn=0):
1265        """Initialize conv2d fold layer"""
1266        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1267        self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
1268                                        'running_std', 'running_mean', 'global_step'],
1269                                outputs=['y'])
1270
1271    def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
1272                    running_mean_shape, global_step_shape):
1273        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1274        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1275        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
1276        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1277                        Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1279        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1280                        Rel.EQ, self.name)
1281        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1282        return x_shape
1283
1284    def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
1285                    running_mean_type, global_step_type):
1286        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
1287                "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
1288        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1289        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
1290        return x_type
1291
1292
1293class BatchNormFold2Grad(PrimitiveWithInfer):
1294    r"""
1295    Performs grad of BatchNormFold2 operation.
1296
1297    Examples:
1298        >>> bnf2_grad = ops.BatchNormFold2Grad()
1299        >>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
1300        >>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
1301        >>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
1302        >>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
1303        >>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
1304        >>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
1305        >>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
1306        >>> global_step = Tensor(np.array([-2]), mindspore.int32)
1307        >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
1308    """
1309    channel_axis = 1
1310
1311    @prim_attr_register
1312    def __init__(self, freeze_bn=0):
1313        """Initialize MulFold layer"""
1314        self.freeze_bn = freeze_bn
1315        self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
1316                                        'batch_std', 'batch_mean',
1317                                        'running_std', 'running_mean', 'global_step'],
1318                                outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])
1319
1320    def infer_shape(self, dout_shape, x_shape, gamma_shape,
1321                    batch_std_shape, batch_mean_shape,
1322                    running_std_shape, running_mean_shape, global_step_shape):
1323        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1324        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1325        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1326                        Rel.EQ, self.name)
1327        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1328        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1329                        Rel.EQ, self.name)
1330        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1331        return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
1332
1333    def infer_dtype(self, dout_type, x_type, gamma_type,
1334                    batch_std_type, batch_mean_type,
1335                    running_std_type, running_mean_type, global_step_type):
1336        validator.check("batch_std type", batch_std_type,
1337                        "batch_mean type", batch_mean_type)
1338        validator.check("batch_std type", batch_std_type,
1339                        "gamma type", gamma_type)
1340        validator.check("batch_std type", batch_std_type,
1341                        "running_std type", running_std_type)
1342        validator.check("batch_std type", batch_std_type,
1343                        "running_mean type", running_mean_type)
1344        validator.check("batch_std_type", batch_std_type,
1345                        "dout type", dout_type)
1346        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
1347                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
1348        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1349        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
1350        return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
1351
1352
1353class BatchNormFoldD(PrimitiveWithInfer):
1354    """Performs grad of _BatchNormFold operation."""
1355
1356    @prim_attr_register
1357    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
1358        """Initialize _BatchNormFold layer"""
1359        from mindspore.ops._op_impl._custom_op import batchnorm_fold
1360        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
1361        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1362        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1363        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1364        self.data_format = "NCHW"
1365        self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
1366                                outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std',
1367                                         'mean_updated', 'variance_updated'])
1368
1369    def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
        validator.check("mean shape", mean_shape, "variance shape", variance_shape, Rel.EQ, self.name)
1371        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
1372        return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
1373
1374    def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
1375        validator.check("input type", x_type, "mean type", mean_type)
1376        validator.check("input type", x_type, "variance type", variance_type)
1377        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
1378        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1379        return x_type, x_type, x_type, x_type, x_type, x_type, x_type
1380
1381
1382class BatchNormFoldGradD(PrimitiveWithInfer):
1383    """Performs grad of BatchNormFold operation."""
1384
1385    @prim_attr_register
1386    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
1387        """Initialize _BatchNormFoldGrad layer"""
1388        from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
1389        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1390        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1391        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1392        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
1393                                outputs=['dx'])
1394
1395    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
1396        validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
1397        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
1398        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[1])
1400        return x_shape
1401
1402    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
1403        validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
1404        validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
1405        validator.check("input type", x_type, "batch_mean type", batch_mean_type)
1406        validator.check("input type", x_type, "batch_std type", batch_std_type)
1407        validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
1408        return x_type
1409
1410
1411class BatchNormFold2D(PrimitiveWithInfer):
1412    """
    Scales the bias with a correction factor based on the long-term statistics
    prior to quantization. This ensures that there is no jitter in the quantized bias
    due to batch-to-batch variation.
1416
1417    Inputs:
1418        - **x** (Tensor)  - Tensor of shape :math:`(N, C)`.
1419        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
1420        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
1421        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
1422        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1423        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
1426
1427    Outputs:
1428        - **y** (Tensor) - Tensor has the same shape as x.
1429
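    Examples:
        >>> # Illustrative sketch only: this is an Ascend-internal primitive and the shapes/values
        >>> # below are assumed placeholders.
        >>> batch_norm_fold2_d = Q.BatchNormFold2D()
        >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3, 2, 2)), mindspore.float32)
        >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
        >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
        >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
        >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
        >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
        >>> y = batch_norm_fold2_d(input_x, beta, gamma, batch_std, batch_mean, running_std)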
1430    """
1431    channel_axis = 1
1432
1433    @prim_attr_register
1434    def __init__(self, freeze_bn=0):
1435        """Initialize conv2d fold layer"""
1436        from mindspore.ops._op_impl._custom_op import batchnorm_fold2
1437        self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
1438                                outputs=['y'])
1439
1440    def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
1441        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1442        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1443        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1445        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1446                        Rel.EQ, self.name)
1447        return x_shape
1448
1449    def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
1450        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
1451                "beta": beta_type, "gamma": gamma_type, "x": x_type}
1452        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1453        return x_type
1454
1455
1456class BatchNormFold2GradD(PrimitiveWithInfer):
1457    """Performs grad of BatchNormFold2 operation."""
1458    channel_axis = 1
1459
1460    @prim_attr_register
1461    def __init__(self, freeze_bn=False):
1462        """Initialize MulFold layer"""
1463        from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad
1464        self.freeze_bn = freeze_bn
1465        self.init_prim_io_names(
1466            inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
1467            outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx'])
1468
1469    def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
1470                    batch_mean_shape, running_std_shape):
1471        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
1472        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
1473        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
1474        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1475                        Rel.EQ, self.name)
1476        return gamma_shape, gamma_shape, gamma_shape, dout_shape
1477
1478    def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
1479                    batch_mean_type, running_std_type):
1480        validator.check("batch_std type", batch_std_type,
1481                        "batch_mean type", batch_mean_type)
1482        validator.check("batch_std type", batch_std_type,
1483                        "gamma type", gamma_type)
1484        validator.check("batch_std type", batch_std_type,
1485                        "running_std type", running_std_type)
1486        validator.check("batch_std_type", batch_std_type,
1487                        "dout type", dout_type)
1488        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
1489                "running_std": running_std_type, "dout": dout_type}
1490        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1491        return gamma_type, gamma_type, gamma_type, gamma_type
1492
1493
1494class BatchNormFold2GradReduce(PrimitiveWithInfer):
1495    """Performs grad of CorrectionAddGrad operation."""
1496    channel_axis = 1
1497
1498    @prim_attr_register
1499    def __init__(self, freeze_bn=False):
1500        """Initialize MulFold layer"""
1501        from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce
1502        self.freeze_bn = freeze_bn
1503        self.init_prim_io_names(inputs=['dout', 'x'],
1504                                outputs=['dout_reduce', 'dout_x_reduce'])
1505
1506    def infer_shape(self, dout_shape, x_shape):
1507        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
1508        return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
1509
1510    def infer_dtype(self, dout_type, x_type):
1511        validator.check("dout type", dout_type, "x type", x_type)
1512        return dout_type, dout_type
1513
1514
1515class ActsULQ(PrimitiveWithInfer):
1516    """
    The ActsULQ (Activation Universal Learnable Quantization).
1518
1519    Args:
        fixed_min (bool): Whether to fix the clamp minimum to zero. Default: False.
        num_bits (int): The number of bits used for quantization. Only 8 is supported. Default: 8.
1522
1523    Inputs:
1524        - **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
1525        - **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x.
1526        - **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x.
1527
1528    Outputs:
        - **y** (Tensor) - A tensor of the fake-quantized feature map, with the same type as `x`.
        - **clamp_min_mask** (Tensor) - A tensor of boolean masks indicating whether each element of the
          feature map is >= clamp_min.
        - **clamp_max_mask** (Tensor) - A tensor of boolean masks indicating whether each element of the
          feature map is <= clamp_max.
        - **x_clamped_loss** (Tensor) - A tensor of clamped loss.
1533
1534    Examples:
1535        >>> data_type = np.float32
        >>> x = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1537        >>> clamp_max = 0.7 * np.max(x)
1538        >>> clamp_min = 0.7 * np.min(x)
1539        >>> clamp_max = np.array([clamp_max], dtype=data_type)
1540        >>> clamp_min = np.array([clamp_min], dtype=data_type)
        >>> acts_ulq = Q.ActsULQ(fixed_min=True, num_bits=8)
        >>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor(clamp_min),
        ...                                                                    Tensor(clamp_max))
1544    """
1545    @prim_attr_register
1546    def __init__(self, fixed_min=False, num_bits=8):
1547        validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
1548        validator.check_value_type("num_bits", num_bits, [int], self.name)
1549        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
1550
1551    def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
1552        """infer shape of primitive"""
1553        validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
1554        validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)
1555
1556        x_shape_len = len(x_shape)
1557        for i in range(x_shape_len):
1558            validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
1559            validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)
1560
1561        return x_shape, x_shape, x_shape, x_shape
1562
1563    def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
1564        """infer dtype of primitive"""
1565        valid_types = [mstype.float32, mstype.float16]
1566        validator.check_tensor_dtype_valid("x", x_dtype, valid_types, self.name)
1567        validator.check_tensor_dtype_valid("clamp_min", clamp_min_dtype, valid_types, self.name)
1568        validator.check_tensor_dtype_valid("clamp_max", clamp_max_dtype, valid_types, self.name)
1569
1570        return x_dtype, mstype.bool_, mstype.bool_, x_dtype
1571
1572
1573class ActsULQInputGrad(PrimitiveWithInfer):
1574    """
1575    The ActsULQInputGrad(grad of ActsULQ).
1576
1577    Inputs:
        - **y_grad** (Tensor) - A Tensor of grad. With float16 or float32 data type.
        - **clamp_min_mask** (Tensor) - A Tensor of the clamp min mask produced by ActsULQ.
        - **clamp_max_mask** (Tensor) - A Tensor of the clamp max mask produced by ActsULQ.
1579
1580    Outputs:
1581        - **x_grad** (Tensor) - A tensor of data grad with the same type as `y_grad`.
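
    Examples:
        >>> # Illustrative sketch only: this grad primitive is normally driven by autodiff, and the
        >>> # boolean masks below are assumed placeholders for the masks produced by ActsULQ.
        >>> y_grad = Tensor(np.random.uniform(-10, 10, (32, 120)).astype(np.float32))
        >>> clamp_min_mask = Tensor(np.random.rand(32, 120) >= 0.1)
        >>> clamp_max_mask = Tensor(np.random.rand(32, 120) <= 0.9)
        >>> acts_ulq_input_grad = Q.ActsULQInputGrad()
        >>> x_grad = acts_ulq_input_grad(y_grad, clamp_min_mask, clamp_max_mask)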
1582    """
1583    @prim_attr_register
1584    def __init__(self):
1585        pass
1586
1587    def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape):
1588        return y_grad_shape
1589
1590    def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
1591        valid_types = [mstype.float32, mstype.float16]
1592        validator.check_tensor_dtype_valid("y_grad", y_grad_type, valid_types, self.name)
1593        return y_grad_type
1594
1595
1596class ActULQClampMinGrad(PrimitiveWithInfer):
1597    """
1598    The ActULQClampMinGrad(Activation Universal Linear Quantization on Clamp Minimum Gradient)
1599
1600    Inputs:
1601        - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
1602        - **clamp_min_mask** - A tensor of mask, only support int8 type.
1603        - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
1604
1605    Outputs:
1606        - **clamp_min_grad** - A tensor of clamp minimum gradient, with the same type as "y_grad".
1607          The length of tensor is 1.
1608
1609    Examples:
1610        >>> data_type = np.float32
1611        >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1612        >>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
1613        >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1614        >>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad()
1615        >>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_),
        ...                                         Tensor(x_clamped_loss))
1617    """
1618    @prim_attr_register
1619    def __init__(self):
1620        pass
1621
1622    def infer_shape(self, input_x, input_y, input_z):
1623        input_x_len = len(input_x)
1624        output_shape = []
1625        for _ in range(input_x_len):
1626            output_shape.append(1)
1627        return tuple(output_shape)
1628
1629    def infer_dtype(self, input_x, input_y, input_z):
1630        return mstype.float32
1631
1632
1633class ActULQClampMaxGrad(PrimitiveWithInfer):
1634    """
1635    The ActULQClampMaxGrad(Activation Universal Linear Quantization on Clamp Maximum Gradient)
1636
1637    Inputs:
1638        - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
1639        - **clamp_max_mask** - A tensor of mask, only support int8 type.
1640        - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
1641
1642    Outputs:
1643        - **clamp_max_grad** - A tensor of clamp maximum gradient, with the same type as "y_grad".
1644          The length of tensor is 1.
1645
1646    Examples:
1647        >>> data_type = np.float32
1648        >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1649        >>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
1650        >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1651        >>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad()
1652        >>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_),
        ...                                         Tensor(x_clamped_loss))
1654    """
1655    @prim_attr_register
1656    def __init__(self):
1657        pass
1658
1659    def infer_shape(self, input_x, input_y, input_z):
1660        input_x_len = len(input_x)
1661        output_shape = []
1662        for _ in range(input_x_len):
1663            output_shape.append(1)
1664        return tuple(output_shape)
1665
1666    def infer_dtype(self, input_x, input_y, input_z):
1667        return mstype.float32
1668
1669
1670class WtsARQ(PrimitiveWithInfer):
1671    """
    The WtsARQ (Weights Adaptive Range Quantization).
1673
1674    Args:
        num_bits (int): The number of bits used for quantization. Only 8 is supported.
        offset_flag (bool): Whether to use an offset for quantization.
1677
1678    Inputs:
        - **w** (Tensor) - A Tensor of weights. With float16 or float32 data type.
        - **w_min** (Tensor) - A Tensor holding the minimum values of `w`, with the same data type and
          number of dimensions as `w`.
        - **w_max** (Tensor) - A Tensor holding the maximum values of `w`, with the same data type and
          number of dimensions as `w`.
1680
1681    Outputs:
        - **y** (Tensor) - A tensor of fake quant weights, has the same type and shape as `w`.
1691
1692    Examples:
        >>> # Note: w_min/w_max below are illustrative per-tensor reductions kept at the same rank as w.
        >>> w = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
        >>> w_min = Tensor(np.min(w.asnumpy(), axis=(1, 2, 3), keepdims=True))
        >>> w_max = Tensor(np.max(w.asnumpy(), axis=(1, 2, 3), keepdims=True))
        >>> wts_arq = Q.WtsARQ(num_bits=8, offset_flag=False)
        >>> y = wts_arq(w, w_min, w_max)
1696    """
1697    @prim_attr_register
1698    def __init__(self, num_bits, offset_flag):
1699        validator.check_value_type("num_bits", num_bits, [int], self.name)
1700        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
1701        validator.check_value_type("offset_flag", offset_flag, [bool], self.name)
1702
1703    def infer_shape(self, w_shape, w_min_shape, w_max_shape):
1704        validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
1705        validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
1706        return w_shape
1707
1708    def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
1709        valid_types = [mstype.float32, mstype.float16]
1710        validator.check_tensor_dtype_valid("w", w_dtype, valid_types, self.name)
1711        validator.check_tensor_dtype_valid("w_min", w_min_dtype, valid_types, self.name)
1712        validator.check_tensor_dtype_valid("w_max", w_max_dtype, valid_types, self.name)
1713        return w_dtype
1714
1715
1716class IFMR(PrimitiveWithInfer):
1717    """
    The IFMR (Input Feature Map Reconstruction).
1719
1720    Args:
1721        min_percentile (float): Min init percentile. Default: 0.999999.
1722        max_percentile (float): Max init percentile. Default: 0.999999.
        search_range (Union[list(float), tuple(float)]): Range of searching. Default: (0.7, 1.3).
1724        search_step (float): Step size of searching. Default: 0.01.
1725        with_offset (bool): Whether using offset. Default: True.
1726
1727    Inputs:
1728        - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
1729        - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`.
1730          With float16 or float32 data type.
1731        - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`.
1732          With float16 or float32 data type.
1733        - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type.
1734
1735    Outputs:
1736        - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32.
1737        - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32.
1738
1739    Examples:
1740        >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
1741        >>> data_min = Tensor([0.1], mindspore.float32)
1742        >>> data_max = Tensor([0.5], mindspore.float32)
1743        >>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
1744        >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
1745        ...               search_step=1.0, with_offset=False)
1746        >>> output = ifmr(data, data_min, data_max, cumsum)
1747        >>> print(output)
1748        (Tensor(shape=[1], dtype=Float32, value= [7.87401572e-03]),
1749         Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]))
1750    """
1751
1752    @prim_attr_register
1753    def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
1754                 with_offset=True):
1755        validator.check_value_type("min_percentile", min_percentile, [float], self.name)
1756        validator.check_value_type("max_percentile", max_percentile, [float], self.name)
1757        validator.check_value_type("search_range", search_range, [list, tuple], self.name)
1758        for item in search_range:
1759            validator.check_positive_float(item, "item of search_range", self.name)
1760        validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
1761        validator.check_value_type("search_step", search_step, [float], self.name)
1762        validator.check_value_type("offset_flag", with_offset, [bool], self.name)
1763
1764    def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
1765        validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
1766        validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
1767        validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
1768        validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
1769        validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
1770        return (1,), (1,)
1771
1772    def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
1773        tuple(map(partial(validator.check_tensor_dtype_valid,
1774                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
1775                  ("input_value", "input_min", "input_max"),
1776                  (data_dtype, data_min_dtype, data_max_dtype)))
1777        validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
1778        return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)
1779