# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Operators for quantization."""
from functools import partial

import mindspore.context as context
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype

if context.get_context('device_target') == "Ascend":
    import mindspore.ops._op_impl._custom_op

__all__ = ["MinMaxUpdatePerLayer",
           "MinMaxUpdatePerChannel",
           "FakeLearnedScaleQuantPerLayer",
           "FakeLearnedScaleQuantPerLayerGrad",
           "FakeLearnedScaleQuantPerLayerGradD",
           "FakeLearnedScaleQuantPerLayerGradDReduce",
           "FakeLearnedScaleQuantPerChannel",
           "FakeLearnedScaleQuantPerChannelGrad",
           "FakeLearnedScaleQuantPerChannelGradD",
           "FakeLearnedScaleQuantPerChannelGradDReduce",
           "FakeQuantWithMinMaxVars",
           "FakeQuantWithMinMaxVarsGradient",
           "FakeQuantWithMinMaxVarsPerChannel",
           "FakeQuantWithMinMaxVarsPerChannelGradient",
           "FakeQuantPerLayer",
           "FakeQuantPerLayerGrad",
           "FakeQuantPerChannel",
           "FakeQuantPerChannelGrad",
           "BatchNormFold",
           "BatchNormFoldGrad",
           "CorrectionMul",
           "CorrectionMulGrad",
           "CorrectionMulGradReduce",
           "BatchNormFold2",
           "BatchNormFold2Grad",
           "BatchNormFoldD",
           "BatchNormFoldGradD",
           "BatchNormFold2D",
           "BatchNormFold2GradD",
           "BatchNormFold2GradReduce",
           "IFMR",
           "ActsULQ",
           "ActsULQInputGrad",
           "ActULQClampMinGrad",
           "ActULQClampMaxGrad",
           "WtsARQ"
           ]


class MinMaxUpdatePerLayer(PrimitiveWithInfer):
    r"""
    Updates min and max per layer.

    Args:
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.

    Inputs:
        - **x** (Tensor) : Float32 tensor containing the input data.
        - **min** (Tensor) : Value of the min range of the input data x.
        - **max** (Tensor) : Value of the max range of the input data x.

    Outputs:
        - **min_up** (Tensor) : Updated value of the min range, with the same shape and type as `min`.
        - **max_up** (Tensor) : Updated value of the max range, with the same shape and type as `max`.
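
    Note:
        Illustrative sketch of the update rule (assuming the usual EMA convention; the
        kernel may differ in detail): with `ema=False` the range is simply replaced by the
        batch statistics, `min_up = min(x)` and `max_up = max(x)`; with `ema=True` it is
        smoothed as `min_up = ema_decay * min + (1 - ema_decay) * min(x)` and
        `max_up = ema_decay * max + (1 - ema_decay) * max(x)`.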

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = MinMaxUpdatePerLayer(ema=False)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self, ema=False, ema_decay=0.999):
        """Initialize FakeQuantMinMaxPerLayerUpdate OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return min_type, max_type


class MinMaxUpdatePerChannel(PrimitiveWithInfer):
    r"""
    Updates min and max per channel.

    Args:
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **x** (Tensor) : Float32 tensor containing the input data.
        - **min** (Tensor) : Value of the min range of the input data x.
        - **max** (Tensor) : Value of the max range of the input data x.

    Outputs:
        - **min_up** (Tensor) : Updated value of the min range, with the same shape and type as `min`.
        - **max_up** (Tensor) : Updated value of the max range, with the same shape and type as `max`.
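
    Note:
        Illustrative sketch: the statistics are computed per channel along `channel_axis`,
        so `min_up[c]` and `max_up[c]` are derived from the slice of `x` at channel `c`
        (optionally smoothed with EMA as in `MinMaxUpdatePerLayer`), and the outputs carry
        one value per channel.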

    Examples:
        >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> max_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
        >>> output_tensor = MinMaxUpdatePerChannel(channel_axis=1)(x, min_value, max_value)
    """
    support_quant_bit = [4, 7, 8]
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
        """Initialize FakeQuantPerChannelUpdate OP for Ascend"""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(
            inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return min_shape, max_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return min_type, max_type


class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations of the fake learned scale quant per-layer case in training time.

    Args:
        quant_delay (int): Quantization delay parameter. Simulated quantization is not applied
            during the first `quant_delay` training steps and begins after that. Default: 0.
        neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
        training (bool): Training the network or not. Default: True.

    Inputs:
        - **input_x** (Tensor) : Input tensor that needs to be quantized.
        - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`.
        - **quant_max** (Tensor) : Value of the quantization range.

    Outputs:
        - Tensor: Simulated quantized tensor of `input_x`, with the same type and shape as `input_x`.
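
    Note:
        A minimal sketch of the simulated quantization, assuming the common learned-scale
        (LSQ-style) formulation; the actual kernel may differ in detail:
        `xc = clip(input_x / alpha, 0 if neg_trunc else -1, 1)` and
        `out = round(xc * quant_max) / quant_max * alpha`.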

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> alpha_tensor = Tensor(np.array([6]), mstype.float32)
        >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
        >>> output_tensor = FakeLearnedScaleQuantPerLayer()(input_tensor, alpha_tensor, quant_max_tensor)
    """
    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False,
                 training=True):
        """init FakeLearnedScaleQuantPerLayer OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer

        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
                                outputs=['out'])

    def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
        validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return input_x_shape

    def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("input_x", "alpha", "quant_max"),
                  (input_x_type, alpha_type, quant_max_type)))
        return input_x_type


class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeLearnedScaleQuantPerLayer operation.
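
    Note:
        Illustrative sketch (assuming the standard learned-scale/LSQ gradients; the kernel
        may differ): `dx` passes `dout` straight through for values inside the clipping
        range and is zero for clipped values, while `dalpha` accumulates each element of
        `dout` weighted by the quantization error of `x / alpha`.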

    Examples:
        >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerLayerGrad()
        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
        >>> _alpha = Tensor(np.array([6]), mindspore.float32)
        >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
        >>> result = fake_learned_scale_grad(dout, input_x, _alpha, _quant_max)
    """

    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False):
        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return dout_shape, alpha_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, alpha_type


class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer):
    r"""
    Performs input grad of FakeLearnedScaleQuantPerLayer operation.
    """

    @prim_attr_register
    def __init__(self,
                 neg_trunc=False):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return dout_shape, dout_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, dout_type


class FakeLearnedScaleQuantPerLayerGradDReduce(PrimitiveWithInfer):
    r"""
    Performs alpha grad reduce of FakeLearnedScaleQuantPerLayer operation.
    """

    @prim_attr_register
    def __init__(self):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad_reduce
        self.init_prim_io_names(
            inputs=['dout_alpha'], outputs=['dalpha'])

    def infer_shape(self, dout_alpha_shape):
        return (1,)

    def infer_dtype(self, dout_alpha_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
        return dout_alpha_type


class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations of the fake learned scale quant per-channel case in training time.

    Args:
        quant_delay (int): Quantization delay parameter. Simulated quantization is not applied
            during the first `quant_delay` training steps and begins after that. Default: 0.
        neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: False.
        training (bool): Training the network or not. Default: True.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **input_x** (Tensor) : Input tensor that needs to be quantized.
        - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`, one element per channel.
        - **quant_max** (Tensor) : Value of the quantization range.

    Outputs:
        - Tensor: Simulated quantized tensor of `input_x`, with the same type and shape as `input_x`.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> alpha_tensor = Tensor(np.array([6]*16), mstype.float32)
        >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
        >>> output_tensor = FakeLearnedScaleQuantPerChannel()(input_tensor, alpha_tensor, quant_max_tensor)
    """
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False,
                 training=True,
                 channel_axis=1):
        """init FakeLearnedScaleQuantPerChannel OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel
        self.is_ascend = context.get_context('device_target') == "Ascend"
        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
                                outputs=['out'])

    def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
        if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(input_x_shape), 1, Rel.GE, "input_x rank", self.name)
        if len(input_x_shape) == 1:
            self.channel_axis = 0

        validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return input_x_shape

    def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("input_x", "alpha", "quant_max"),
                  (input_x_type, alpha_type, quant_max_type)))
        return input_x_type


class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeLearnedScaleQuantPerChannel operation.

    Examples:
        >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerChannelGrad()
        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
        >>> _alpha = Tensor(np.array([6]*2), mindspore.float32)
        >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
        >>> result = fake_learned_scale_grad(dout, input_x, _alpha, _quant_max)
    """

    @prim_attr_register
    def __init__(self,
                 quant_delay=0,
                 neg_trunc=False,
                 channel_axis=1):
        self.quant_delay = validator.check_non_negative_int(
            quant_delay, 'quant_delay', self.name)
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        return dout_shape, alpha_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        if context.get_context('device_target') == "GPU":
            valid_dtypes = (mstype.float32,)
        else:
            valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, alpha_type


class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer):
    r"""
    Performs input grad of FakeLearnedScaleQuantPerChannel operation.
    """

    @prim_attr_register
    def __init__(self,
                 neg_trunc=False,
                 channel_axis=1):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad
        self.neg_trunc = validator.check_value_type(
            'neg_trunc', neg_trunc, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])

    def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
        validator.check("dout shape", dout_shape, "x_shape", x_shape, Rel.EQ, self.name)
        validator.check_int(len(alpha_shape), 1, Rel.GE, "alpha rank", self.name)
        validator.check_int(len(quant_max_shape), 1, Rel.GE, "quant max rank", self.name)
        return dout_shape, dout_shape

    def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
                  ("dout", "x", "alpha", "quant_max"),
                  (dout_type, x_type, alpha_type, quant_max_type)))
        return dout_type, dout_type


class FakeLearnedScaleQuantPerChannelGradDReduce(PrimitiveWithInfer):
    r"""
    Performs alpha grad reduce of FakeLearnedScaleQuantPerChannel operation.
    """

    @prim_attr_register
    def __init__(self, channel_axis=1):
        from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad_reduce
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout_alpha'], outputs=['dalpha'])

    def infer_shape(self, dout_alpha_shape):
        return (dout_alpha_shape[self.channel_axis],)

    def infer_dtype(self, dout_alpha_type):
        valid_dtypes = (mstype.float16, mstype.float32)
        validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
        return dout_alpha_type


class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
    r"""
    Fake-quantize the input by min and max.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **x** (Tensor) - Float32 tensor containing the input data.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - Tensor, the data type and shape of output tensor is the same as input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
        ...     input_tensor, min_tensor, max_tensor)
        >>> output_tensor # shape: (3, 16, 5, 5) data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return x_type


class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVars operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVars.
        - **x** (Tensor) - Float32 tensor containing the input data.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8, narrow_range=False)
        ...     (gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient # shape: (3, 16, 5, 5) data type: mstype.float32
        >>> min_gradient # shape: (1,) data type: mstype.float32
        >>> max_gradient # shape: (1,) data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def check_broadcast(self, min_shape, input_shape):
        shape_val = 1
        for shape in input_shape:
            shape_val = shape_val * shape
        if min_shape[0] > 1 and min_shape[0] != shape_val:
            raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        self.check_broadcast(min_shape, x_shape)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ('dout', "x", "min", "max"),
                  (dout_type, x_type, min_type, max_type)))
        return x_type, min_type, max_type


class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
    r"""
    Fake-quantize the input of shape [d], [b, d], or [b, h, w, d] by per-channel min and max.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **x** (Tensor) - Float32 tensor containing the input data.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - Tensor, the data type and shape of output tensor is the same as input x.

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> output_tensor = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(
        ...     input_tensor, min_tensor, max_tensor)
        >>> output_tensor # shape: (3, 16, 3, 4) data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("x", "min", "max"),
                  (x_type, min_type, max_type)))
        return x_type


class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantWithMinMaxVarsPerChannel operation.

    Args:
        num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
            If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
            range is [0, 2^num_bits-1]. Default: False.

    Inputs:
        - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVarsPerChannel.
        - **x** (Tensor) - Float32 tensor containing the input data.
        - **min** (Tensor) - Value of the min range of the input data x.
        - **max** (Tensor) - Value of the max range of the input data x.

    Outputs:
        - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
        - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
        - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.

    Examples:
        >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
        >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
        ...     num_bits=8, narrow_range=False)(
        ...     gradients, input_tensor, min_tensor, max_tensor)
        >>> x_gradient # shape: (3, 16, 3, 4) data type: mstype.float32
        >>> min_gradient # shape: (4,) data type: mstype.float32
        >>> max_gradient # shape: (4,) data type: mstype.float32
    """

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 narrow_range=False):
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.num_bits = validator.check_int_range(self.num_bits, 2, 16, Rel.INC_BOTH, 'num_bits', self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_int(len(min_shape), 1, Rel.EQ, "min shape", self.name)
        validator.check("min shape", min_shape[0], "x shape", x_shape[-1], Rel.EQ, self.name)
        return x_shape, min_shape, max_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("dout", "x", "min", "max"),
                  (dout_type, x_type, min_type, max_type)))
        return x_type, min_type, max_type


def _fake_quant_per_infer_dtype(prim_name, x_type, min_type, max_type):
    if context.get_context('device_target') == "GPU":
        valid_dtypes = (mstype.float32,)
    else:
        valid_dtypes = (mstype.float16, mstype.float32)
    tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
              ("x", "min", "max"),
              (x_type, min_type, max_type)))
    return x_type


def _fake_quant_per_grad_infer_dtype(prim_name, dout_type, x_type, min_type, max_type):
    if context.get_context('device_target') == "GPU":
        valid_dtypes = (mstype.float32,)
    else:
        valid_dtypes = (mstype.float16, mstype.float32)
    tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
              ("dout", "x", "min", "max"),
              (dout_type, x_type, min_type, max_type)))
    return dout_type


class FakeQuantPerLayer(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations in training time.

    Args:
        num_bits (int) : Number of bits for quantization. Default: 8.
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. Simulated quantization is not applied
            during the first `quant_delay` training steps and begins after that. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.

    Inputs:
        - **x** (Tensor) : Float32 tensor containing the input data.
        - **min** (Tensor) : Value of the min range of the input data x.
        - **max** (Tensor) : Value of the max range of the input data x.

    Outputs:
        - Tensor: Simulated quantized tensor of x, with the same shape and type as x.
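
    Note:
        A minimal sketch of the simulated quantization, assuming the conventional uniform
        fake-quant formulation and ignoring the `symmetric` and zero-point nudging details;
        the kernel may differ: with `quant_min = 1 if narrow_range else 0` and
        `quant_max = 2 ** num_bits - 1`, the scale is
        `scale = (max - min) / (quant_max - quant_min)` and
        `out = round((clip(x, min, max) - min) / scale) * scale + min`.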

    Examples:
        >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
        >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
        >>> max_tensor = Tensor(np.array([6]), mstype.float32)
        >>> output_tensor = FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True):
        """Initialize FakeQuantPerLayer OP"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type('training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)


class FakeQuantPerLayerGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantPerLayer operation.
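
    Note:
        Illustrative sketch: the gradient follows the straight-through estimator, i.e.
        `dx = dout` where `min <= x <= max` and `dx = 0` for clipped values.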

    Examples:
        >>> fake_min_max_grad = FakeQuantPerLayerGrad()
        >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
        >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
        >>> _min = Tensor(np.array([-4]), mindspore.float32)
        >>> _max = Tensor(np.array([2]), mindspore.float32)
        >>> result = fake_min_max_grad(dout, input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False):
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")

        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_value_type(
            'quant_delay', quant_delay, (int,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape",
                        x_shape, Rel.EQ, self.name)
        validator.check("min shape", min_shape, "max shape",
                        max_shape, Rel.EQ, self.name)
        validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)


class FakeQuantPerChannel(PrimitiveWithInfer):
    r"""
    Simulates the quantize and dequantize operations in training time based on per channel.

    Args:
        num_bits (int) : Number of bits for quantization. Default: 8.
        ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: False.
        ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
        quant_delay (int): Quantization delay parameter. The weight data is not simulated-quantized
            during the first `quant_delay` training steps and the simulation begins after that. Default: 0.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
        training (bool): Training the network or not. Default: True.
        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

    Inputs:
        - **x** (Tensor) : Float32 tensor containing the input data.
        - **min** (Tensor) : Value of the min range of the input data, one element per channel.
        - **max** (Tensor) : Value of the max range of the input data, one element per channel.

    Outputs:
        - Tensor, has the same type as input.
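
    Note:
        Illustrative sketch: this is the per-channel variant of `FakeQuantPerLayer`; each
        index `c` along `channel_axis` is quantized with its own pair `(min[c], max[c])`
        using the same per-layer formula.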

    Examples:
        >>> fake_quant = FakeQuantPerChannel()
        >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
        >>> _min = Tensor(np.linspace(-2, 2, 2), mindspore.float32)
        >>> _max = Tensor(np.linspace(8, 12, 2), mindspore.float32)
        >>> result = fake_quant(input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]
    ascend_support_x_rank = [2, 4]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 ema=False,
                 ema_decay=0.999,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 training=True,
                 channel_axis=1):
        """Initialize FakeQuantPerChannel OP"""
        self.is_ascend = context.get_context('device_target') == "Ascend"
        if self.is_ascend:
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")
        if ema and not ema_decay:
            raise ValueError(
                f"For '{self.name}' attr \'ema\' and \'ema_decay\' should be set together.")

        self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.training = validator.check_value_type(
            'training', training, (bool,), self.name)
        self.ema_decay = validator.check_float_range(ema_decay, 0, 1, Rel.INC_BOTH, 'ema_decay', self.name)
        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
        if self.is_ascend:
            self.channel_axis = validator.check_int_range(channel_axis, 0, 1, Rel.INC_BOTH, 'channel_axis', self.name)
        else:
            self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
        self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

    def infer_shape(self, x_shape, min_shape, max_shape):
        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
        if not self.is_ascend:
            validator.check_int(len(x_shape), 1, Rel.GE, "x rank", self.name)
        if len(x_shape) == 1:
            self.channel_axis = 0
        validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
        validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
        validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
        return x_shape

    def infer_dtype(self, x_type, min_type, max_type):
        return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)


class FakeQuantPerChannelGrad(PrimitiveWithInfer):
    r"""
    Performs grad of FakeQuantPerChannel operation.

    Examples:
        >>> fqmmpc_grad = FakeQuantPerChannelGrad()
        >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
        >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
        >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
        >>> result = fqmmpc_grad(dout, input_x, _min, _max)
    """
    support_quant_bit = [4, 7, 8]

    @prim_attr_register
    def __init__(self,
                 num_bits=8,
                 quant_delay=0,
                 symmetric=False,
                 narrow_range=False,
                 channel_axis=1):
        """Initialize FakeQuantPerChannelGrad Fill"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
        if num_bits not in self.support_quant_bit:
            raise ValueError(
                f"For '{self.name}' attr \'num_bits\' is not supported.")

        self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
        self.quant_delay = validator.check_value_type(
            'quant_delay', quant_delay, (int,), self.name)
        self.symmetric = validator.check_value_type(
            'symmetric', symmetric, (bool,), self.name)
        self.narrow_range = validator.check_value_type(
            'narrow_range', narrow_range, (bool,), self.name)
        self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
        self.init_prim_io_names(
            inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])

    def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
        validator.check("dout shape", dout_shape, "x shape", x_shape)
        validator.check("min shape", min_shape, "max shape", max_shape)
        return dout_shape

    def infer_dtype(self, dout_type, x_type, min_type, max_type):
        return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)


class BatchNormFold(PrimitiveWithInfer):
    """
    Batch Normalization folded.

    Args:
        momentum (float): Momentum value; must be in [0, 1]. Default: 0.9.
        epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
            float32 else 1e-3. Default: 1e-5.
        is_training (bool): In training mode set True, else set False. Default: True.
        freeze_bn (int): Delay in steps at which computation switches from regular batch
            norm to frozen mean and std. Default: 0.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        Tuple of 4 Tensors, the batch statistics and the updated running statistics.

        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
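
    Note:
        A minimal sketch of the folded statistics, assuming the usual batch-normalization
        conventions (the exact kernel may differ): `batch_std = sqrt(batch_var + epsilon)`,
        `running_std = sqrt(running_var + epsilon)`, and the running statistics are smoothed
        with `momentum`, e.g. `running_mean = momentum * running_mean + (1 - momentum) * batch_mean`.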

    Examples:
        >>> batch_norm_fold = P.BatchNormFold()
        >>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
        >>> mean = Tensor(np.array([0.5, -1, 1]), mindspore.float32)
        >>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
        >>> global_step = Tensor(np.arange(6), mindspore.int32)
        >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize batch norm fold layer"""
        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)

        self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
                                outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])

    def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
        validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name)
        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return mean_shape, mean_shape, mean_shape, mean_shape

    def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
        validator.check("input type", x_type, "mean type", mean_type)
        validator.check("input type", x_type, "variance type", variance_type)
        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type, x_type, x_type, x_type


class BatchNormFoldGrad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold operation.

    Examples:
        >>> batch_norm_fold_grad = ops.BatchNormFoldGrad()
        >>> d_batch_mean = Tensor(np.random.randint(-2., 2., (1, 2, 2, 3)), mindspore.float32)
        >>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-8., 8., (1, 2, 2, 3)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
        >>> global_step = Tensor([2], mindspore.int32)
        >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize BatchNormGrad layer"""
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
                    global_step_shape):
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean shape", d_batch_mean_shape,
                        "batch_std shape", batch_std_shape, Rel.EQ, self.name)
        validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
                        "input channel", x_shape[self.channel_axis], Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
                    global_step_type):
        args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
                "batch_mean": batch_mean_type, "batch_std": batch_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type


class CorrectionMul(PrimitiveWithInfer):
    """
    Scales the weights with a correction factor to the long term statistics
    prior to quantization. This ensures that there is no jitter in the quantized weights
    due to batch to batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Outputs:
        - **out** (Tensor) - Tensor has the same shape as x.

    Examples:
        >>> correction_mul = ops.CorrectionMul()
        >>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
        >>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
        >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
        >>> out = correction_mul(input_x, batch_std, running_std)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul layer"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import correction_mul
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
                                outputs=['out'])

    def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                        Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, batch_std_type, running_std_type):
        args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return x_type


class CorrectionMulGrad(PrimitiveWithInfer):
    r"""
    Performs grad of CorrectionMul operation.

    Examples:
        >>> correction_mul_grad = ops.CorrectionMulGrad()
        >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
        >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
        >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
        >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
        >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul layer"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import correction_mul_grad
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
                                outputs=['dx', 'mul_dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
        validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name)
        validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check("running_std_shape[0]", running_std_shape[0],
                        "dout channel size", dout_shape[self.channel_axis], Rel.EQ, self.name)
        if context.get_context('device_target') == "Ascend":
            return x_shape, x_shape
        return x_shape, gamma_shape

    def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
        args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        if context.get_context('device_target') == "Ascend":
            return x_type, x_type
        return x_type, gamma_type


class CorrectionMulGradReduce(PrimitiveWithInfer):
    r"""
    Performs grad reduce of CorrectionMul operation.

    Examples:
        >>> correction_mul_grad_rd = ops.CorrectionMulGradReduce()
        >>> mul_dx = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
        >>> d_gamma = correction_mul_grad_rd(mul_dx)
    """

    @prim_attr_register
    def __init__(self, channel_axis=0):
        """Initialize correction mul reduce layer"""
        if context.get_context('device_target') == "Ascend":
            from mindspore.ops._op_impl._custom_op import correction_mul_grad
        self.channel_axis = channel_axis
        self.init_prim_io_names(inputs=['mul_dx'],
                                outputs=['d_gamma'])

    def infer_shape(self, mul_dx_shape):
        return [mul_dx_shape[self.channel_axis]]

    def infer_dtype(self, mul_dx_type):
        return mul_dx_type


class BatchNormFold2(PrimitiveWithInfer):
    """
    Scales the bias with a correction factor to the long term statistics
    prior to quantization. This ensures that there is no jitter in the quantized bias
    due to batch to batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **global_step** (Tensor) - Tensor to record current global step.

    Outputs:
        - **y** (Tensor) - Tensor has the same shape as x.

    Examples:
        >>> batch_norm_fold2 = ops.BatchNormFold2()
        >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
        >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
        >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
        >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
        >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
        >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
        >>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
        >>> global_step = Tensor(np.random.randint(1, 8, (8, )), mindspore.int32)
        >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
        ...                           running_std, running_mean, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize conv2d fold layer"""
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['y'])

    def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
                    running_mean_shape, global_step_shape):
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
                        Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return x_shape

    def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
                    running_mean_type, global_step_type):
        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
                "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return x_type


class BatchNormFold2Grad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold2 operation.
class BatchNormFold2Grad(PrimitiveWithInfer):
    r"""
    Performs grad of BatchNormFold2 operation.

    Examples:
        >>> bnf2_grad = ops.BatchNormFold2Grad()
        >>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
        >>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
        >>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
        >>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
        >>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
        >>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
        >>> global_step = Tensor(np.array([-2]), mindspore.int32)
        >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize MulFold layer"""
        self.freeze_bn = freeze_bn
        self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
                                        'batch_std', 'batch_mean',
                                        'running_std', 'running_mean', 'global_step'],
                                outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])

    def infer_shape(self, dout_shape, x_shape, gamma_shape,
                    batch_std_shape, batch_mean_shape,
                    running_std_shape, running_mean_shape, global_step_shape):
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
                        Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
                        Rel.EQ, self.name)
        validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
        return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape

    def infer_dtype(self, dout_type, x_type, gamma_type,
                    batch_std_type, batch_mean_type,
                    running_std_type, running_mean_type, global_step_type):
        validator.check("batch_std type", batch_std_type,
                        "batch_mean type", batch_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "gamma type", gamma_type)
        validator.check("batch_std type", batch_std_type,
                        "running_std type", running_std_type)
        validator.check("batch_std type", batch_std_type,
                        "running_mean type", running_mean_type)
        validator.check("batch_std type", batch_std_type,
                        "dout type", dout_type)
        args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
                "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
        return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type


class BatchNormFoldD(PrimitiveWithInfer):
    """Performs the BatchNormFold operation (Ascend TBE implementation)."""

    @prim_attr_register
    def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize _BatchNormFold layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold
        self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.data_format = "NCHW"
        self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
                                outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std',
                                         'mean_updated', 'variance_updated'])

    def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
        validator.check("mean shape", mean_shape, "variance shape", variance_shape, Rel.EQ, self.name)
        validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], Rel.EQ, self.name)
        return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape

    def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
        validator.check("input type", x_type, "mean type", mean_type)
        validator.check("input type", x_type, "variance type", variance_type)
        args = {"x": x_type, "mean": mean_type, "variance": variance_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return x_type, x_type, x_type, x_type, x_type, x_type, x_type


class BatchNormFoldGradD(PrimitiveWithInfer):
    """Performs grad of BatchNormFold operation."""

    @prim_attr_register
    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
        """Initialize _BatchNormFoldGrad layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
        self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
        self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
        self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
        self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
                                outputs=['dx'])

    def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
        validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
        validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
        validator.check("d_batch_mean shape[0]", d_batch_mean_shape[0], "input channel", x_shape[1])
        return x_shape

    def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
        validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
        validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
        validator.check("input type", x_type, "batch_mean type", batch_mean_type)
        validator.check("input type", x_type, "batch_std type", batch_std_type)
        validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
        return x_type

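
# Illustrative usage sketch for BatchNormFoldD (hypothetical shapes; this assumes an Ascend
# device where the batchnorm_fold TBE kernel is registered, and that `numpy as np`, `Tensor`
# and `mindspore` are available as in the docstring examples above):
# >>> x = np.random.rand(2, 3, 4, 4).astype(np.float32)
# >>> bn_fold_d = BatchNormFoldD(momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0)
# >>> outputs = bn_fold_d(Tensor(x),
# ...                     Tensor(x.sum(axis=(0, 2, 3))),         # x_sum, shape (C,)
# ...                     Tensor((x ** 2).sum(axis=(0, 2, 3))),  # x_square_sum, shape (C,)
# ...                     Tensor(np.zeros(3, np.float32)),       # running mean, shape (C,)
# ...                     Tensor(np.ones(3, np.float32)))        # running variance, shape (C,)
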

class BatchNormFold2D(PrimitiveWithInfer):
    """
    Scales the bias with a correction factor to the long term statistics
    prior to quantization. This ensures that there is no jitter in the quantized bias
    due to batch to batch variation.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
        - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
        - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
        - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
        - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.

    Outputs:
        - **y** (Tensor) - Tensor has the same shape as x.
    """
    channel_axis = 1

    @prim_attr_register
    def __init__(self, freeze_bn=0):
        """Initialize conv2d fold layer"""
        from mindspore.ops._op_impl._custom_op import batchnorm_fold2
        self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
                                outputs=['y'])

    def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
        validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name)
        validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name)
        validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
                        Rel.EQ, self.name)
        return x_shape

    def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
        args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
                "beta": beta_type, "gamma": gamma_type, "x": x_type}
        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
        return x_type

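
# Illustrative usage sketch for BatchNormFold2D (hypothetical values, mirroring the
# BatchNormFold2 example above; assumes an Ascend device where the batchnorm_fold2 TBE
# kernel is registered):
# >>> batch_norm_fold2_d = BatchNormFold2D()
# >>> x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
# >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
# >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
# >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
# >>> batch_mean = Tensor(np.array([0.0, 0.05, 0.2]), mindspore.float32)
# >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
# >>> y = batch_norm_fold2_d(x, beta, gamma, batch_std, batch_mean, running_std)
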
validator.check("batch_std type", batch_std_type, 1481 "batch_mean type", batch_mean_type) 1482 validator.check("batch_std type", batch_std_type, 1483 "gamma type", gamma_type) 1484 validator.check("batch_std type", batch_std_type, 1485 "running_std type", running_std_type) 1486 validator.check("batch_std_type", batch_std_type, 1487 "dout type", dout_type) 1488 args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, 1489 "running_std": running_std_type, "dout": dout_type} 1490 validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name) 1491 return gamma_type, gamma_type, gamma_type, gamma_type 1492 1493 1494class BatchNormFold2GradReduce(PrimitiveWithInfer): 1495 """Performs grad of CorrectionAddGrad operation.""" 1496 channel_axis = 1 1497 1498 @prim_attr_register 1499 def __init__(self, freeze_bn=False): 1500 """Initialize MulFold layer""" 1501 from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce 1502 self.freeze_bn = freeze_bn 1503 self.init_prim_io_names(inputs=['dout', 'x'], 1504 outputs=['dout_reduce', 'dout_x_reduce']) 1505 1506 def infer_shape(self, dout_shape, x_shape): 1507 validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) 1508 return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],) 1509 1510 def infer_dtype(self, dout_type, x_type): 1511 validator.check("dout type", dout_type, "x type", x_type) 1512 return dout_type, dout_type 1513 1514 1515class ActsULQ(PrimitiveWithInfer): 1516 """ 1517 The ActsULQ(Activation universal learnable quantization). 1518 1519 Args: 1520 fixed_min (bool): whether fix clamp min to zero. 1521 num_bits (int): The bits num used for quantize. 1522 1523 Inputs: 1524 - **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type. 1525 - **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x. 1526 - **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x. 1527 1528 Outputs: 1529 - **y** (Tensor) - A tensor of fake quant of feature map with the same type as `w`. 1530 - **clamp_min** (Tensor) - A tensor of boolean masks if data in feature map >= clamp_min. 1531 - **clamp_max** (Tensor) - A tensor of boolean masks if data in feature map <= clamp_max. 1532 - **x_clamped_loss** (Tensor) - A tensor of clamped loss. 

class ActsULQ(PrimitiveWithInfer):
    """
    The ActsULQ (Activation Universal Learnable Quantization).

    Args:
        fixed_min (bool): Whether to fix the clamp minimum to zero.
        num_bits (int): The number of bits used for quantization.

    Inputs:
        - **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
        - **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x.
        - **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x.

    Outputs:
        - **y** (Tensor) - A tensor of fake quant of feature map with the same type as `x`.
        - **clamp_min_mask** (Tensor) - A tensor of boolean masks indicating whether the data
          in the feature map is >= clamp_min.
        - **clamp_max_mask** (Tensor) - A tensor of boolean masks indicating whether the data
          in the feature map is <= clamp_max.
        - **x_clamped_loss** (Tensor) - A tensor of clamped loss.

    Examples:
        >>> data_type = np.float32
        >>> x = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
        >>> clamp_max = 0.7 * np.max(x)
        >>> clamp_min = 0.7 * np.min(x)
        >>> clamp_max = np.array([[clamp_max]], dtype=data_type)
        >>> clamp_min = np.array([[clamp_min]], dtype=data_type)
        >>> acts_ulq = Q.ActsULQ(fixed_min=True, num_bits=8)
        >>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor(clamp_min),
        ...                                                                    Tensor(clamp_max))
    """
    @prim_attr_register
    def __init__(self, fixed_min=False, num_bits=8):
        validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
        validator.check_value_type("num_bits", num_bits, [int], self.name)
        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)

    def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
        """infer shape of primitive"""
        validator.check_int(len(clamp_min_shape), len(x_shape), Rel.EQ, "dims of clamp_min", self.name)
        validator.check_int(len(clamp_max_shape), len(x_shape), Rel.EQ, "dims of clamp_max", self.name)

        x_shape_len = len(x_shape)
        for i in range(x_shape_len):
            validator.check_int(clamp_min_shape[i], 1, Rel.EQ, "dims of clamp_min", self.name)
            validator.check_int(clamp_max_shape[i], 1, Rel.EQ, "dims of clamp_max", self.name)

        return x_shape, x_shape, x_shape, x_shape

    def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
        """infer dtype of primitive"""
        valid_types = [mstype.float32, mstype.float16]
        validator.check_tensor_dtype_valid("x", x_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("clamp_min", clamp_min_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("clamp_max", clamp_max_dtype, valid_types, self.name)

        return x_dtype, mstype.bool_, mstype.bool_, x_dtype


class ActsULQInputGrad(PrimitiveWithInfer):
    """
    The ActsULQInputGrad (grad of ActsULQ).

    Inputs:
        - **y_grad** (Tensor) - A Tensor of grad. With float16 or float32 data type.
        - **clamp_min_mask** (Tensor) - A boolean mask produced by ActsULQ, with the same shape as `y_grad`.
        - **clamp_max_mask** (Tensor) - A boolean mask produced by ActsULQ, with the same shape as `y_grad`.

    Outputs:
        - **x_grad** (Tensor) - A tensor of data grad with the same type as `y_grad`.
    """
    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape):
        return y_grad_shape

    def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
        valid_types = [mstype.float32, mstype.float16]
        validator.check_tensor_dtype_valid("y_grad", y_grad_type, valid_types, self.name)
        return y_grad_type

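
# Illustrative usage sketch for ActsULQInputGrad (hypothetical shapes): the boolean masks
# would typically be the `clamp_min_mask` / `clamp_max_mask` outputs of ActsULQ shown above.
# >>> y_grad = Tensor(np.random.uniform(-1, 1, (32, 120)).astype(np.float32))
# >>> clamp_min_mask = Tensor(np.random.rand(32, 120) >= 0.5)
# >>> clamp_max_mask = Tensor(np.random.rand(32, 120) < 0.5)
# >>> x_grad = ActsULQInputGrad()(y_grad, clamp_min_mask, clamp_max_mask)
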

class ActULQClampMinGrad(PrimitiveWithInfer):
    """
    The ActULQClampMinGrad (Activation Universal Linear Quantization on Clamp Minimum Gradient).

    Inputs:
        - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
        - **clamp_min_mask** (Tensor) - A tensor of mask, only support int8 type.
        - **x_clamped_loss** (Tensor) - A tensor of loss, with the same type as `y_grad`.

    Outputs:
        - **clamp_min_grad** (Tensor) - A tensor of clamp minimum gradient, with the same type as `y_grad`.
          The length of the tensor is 1.

    Examples:
        >>> data_type = np.float32
        >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
        >>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
        >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
        >>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad()
        >>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_),
        ...                                         Tensor(x_clamped_loss))
    """
    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, input_x, input_y, input_z):
        input_x_len = len(input_x)
        output_shape = []
        for _ in range(input_x_len):
            output_shape.append(1)
        return tuple(output_shape)

    def infer_dtype(self, input_x, input_y, input_z):
        return mstype.float32


class ActULQClampMaxGrad(PrimitiveWithInfer):
    """
    The ActULQClampMaxGrad (Activation Universal Linear Quantization on Clamp Maximum Gradient).

    Inputs:
        - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
        - **clamp_max_mask** (Tensor) - A tensor of mask, only support int8 type.
        - **x_clamped_loss** (Tensor) - A tensor of loss, with the same type as `y_grad`.

    Outputs:
        - **clamp_max_grad** (Tensor) - A tensor of clamp maximum gradient, with the same type as `y_grad`.
          The length of the tensor is 1.

    Examples:
        >>> data_type = np.float32
        >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
        >>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
        >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
        >>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad()
        >>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_),
        ...                                         Tensor(x_clamped_loss))
    """
    @prim_attr_register
    def __init__(self):
        pass

    def infer_shape(self, input_x, input_y, input_z):
        input_x_len = len(input_x)
        output_shape = []
        for _ in range(input_x_len):
            output_shape.append(1)
        return tuple(output_shape)

    def infer_dtype(self, input_x, input_y, input_z):
        return mstype.float32


class WtsARQ(PrimitiveWithInfer):
    """
    The WtsARQ (Weights Adaptive Range Quantization).

    Args:
        num_bits (int): The number of bits used for quantization.
        offset_flag (bool): Whether to use an offset for quantization.

    Inputs:
        - **w** (Tensor) - A Tensor of weights. With float16 or float32 data type.
        - **w_min** (Tensor) - A Tensor of minimums of `w`, with the same rank and type as `w`.
        - **w_max** (Tensor) - A Tensor of maximums of `w`, with the same rank and type as `w`.

    Outputs:
        - **y** (Tensor) - A tensor of fake-quantized weights, with the same type and shape as `w`.

    Examples:
        >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
        >>> data_min = Tensor(np.min(data.asnumpy(), keepdims=True), mindspore.float32)
        >>> data_max = Tensor(np.max(data.asnumpy(), keepdims=True), mindspore.float32)
        >>> wts_arq = Q.WtsARQ(num_bits=8, offset_flag=False)
        >>> y = wts_arq(data, data_min, data_max)
    """
    @prim_attr_register
    def __init__(self, num_bits, offset_flag):
        validator.check_value_type("num_bits", num_bits, [int], self.name)
        validator.check_int(num_bits, 8, Rel.EQ, "value of num_bits", self.name)
        validator.check_value_type("offset_flag", offset_flag, [bool], self.name)

    def infer_shape(self, w_shape, w_min_shape, w_max_shape):
        validator.check_int(len(w_min_shape), len(w_shape), Rel.EQ, "dims of w_min", self.name)
        validator.check_int(len(w_max_shape), len(w_shape), Rel.EQ, "dims of w_max", self.name)
        return w_shape

    def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
        valid_types = [mstype.float32, mstype.float16]
        validator.check_tensor_dtype_valid("w", w_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("w_min", w_min_dtype, valid_types, self.name)
        validator.check_tensor_dtype_valid("w_max", w_max_dtype, valid_types, self.name)
        return w_dtype

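
# Illustrative sketch of per-channel ranges for WtsARQ (an assumption about how w_min / w_max
# might be prepared; the operator only requires them to have the same rank as `w`):
# >>> w = np.random.randn(8, 3, 3, 3).astype(np.float32)
# >>> w_min = Tensor(w.min(axis=(1, 2, 3), keepdims=True), mindspore.float32)  # shape (8, 1, 1, 1)
# >>> w_max = Tensor(w.max(axis=(1, 2, 3), keepdims=True), mindspore.float32)  # shape (8, 1, 1, 1)
# >>> y = Q.WtsARQ(num_bits=8, offset_flag=False)(Tensor(w), w_min, w_max)
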

class IFMR(PrimitiveWithInfer):
    """
    The IFMR (Input Feature Map Reconstruction).

    Args:
        min_percentile (float): Min init percentile. Default: 0.999999.
        max_percentile (float): Max init percentile. Default: 0.999999.
        search_range (Union[list(float), tuple(float)]): Range of searching. Default: [0.7, 1.3].
        search_step (float): Step size of searching. Default: 0.01.
        with_offset (bool): Whether to use an offset. Default: True.

    Inputs:
        - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
        - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`.
          With float16 or float32 data type.
        - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`.
          With float16 or float32 data type.
        - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type.

    Outputs:
        - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32.
        - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32.

    Examples:
        >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
        >>> data_min = Tensor([0.1], mindspore.float32)
        >>> data_max = Tensor([0.5], mindspore.float32)
        >>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
        >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
        ...               search_step=1.0, with_offset=False)
        >>> output = ifmr(data, data_min, data_max, cumsum)
        >>> print(output)
        (Tensor(shape=[1], dtype=Float32, value= [7.87401572e-03]),
        Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]))
    """

    @prim_attr_register
    def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
                 with_offset=True):
        validator.check_value_type("min_percentile", min_percentile, [float], self.name)
        validator.check_value_type("max_percentile", max_percentile, [float], self.name)
        validator.check_value_type("search_range", search_range, [list, tuple], self.name)
        for item in search_range:
            validator.check_positive_float(item, "item of search_range", self.name)
        validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], Rel.GE, self.name)
        validator.check_value_type("search_step", search_step, [float], self.name)
        validator.check_value_type("with_offset", with_offset, [bool], self.name)

    def infer_shape(self, data_shape, data_min_shape, data_max_shape, cumsum_shape):
        validator.check_equal_int(len(data_min_shape), 1, "dims of data_min", self.name)
        validator.check_equal_int(data_min_shape[0], 1, "data_min[0]", self.name)
        validator.check_equal_int(len(data_max_shape), 1, "dims of data_max", self.name)
        validator.check_equal_int(data_max_shape[0], 1, "data_max[0]", self.name)
        validator.check_equal_int(len(cumsum_shape), 1, "dims of cumsum", self.name)
        return (1,), (1,)

    def infer_dtype(self, data_dtype, data_min_dtype, data_max_dtype, cumsum_dtype):
        tuple(map(partial(validator.check_tensor_dtype_valid,
                          valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
                  ("input_value", "input_min", "input_max"),
                  (data_dtype, data_min_dtype, data_max_dtype)))
        validator.check_tensor_dtype_valid("input_bins", cumsum_dtype, [mstype.int32], self.name)
        return mstype.tensor_type(mstype.float32), mstype.tensor_type(mstype.float32)
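
# Illustrative sketch of one plausible way to build the `cumsum` input for IFMR (an assumption,
# not mandated by the operator itself): take a histogram of the feature map and accumulate the
# bin counts into an int32 vector.
# >>> feature = np.random.rand(1, 3, 6, 4).astype(np.float32)
# >>> hist, _ = np.histogram(feature, bins=128, range=(feature.min(), feature.max()))
# >>> cumsum = Tensor(np.cumsum(hist).astype(np.int32))
# >>> data_min = Tensor([float(feature.min())], mindspore.float32)
# >>> data_max = Tensor([float(feature.max())], mindspore.float32)
# >>> scale, offset = IFMR()(Tensor(feature), data_min, data_max, cumsum)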