1# Copyright 2020-2021 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""normalization""" 16import itertools 17import numbers 18 19from mindspore.ops import operations as P 20from mindspore.ops import functional as F 21from mindspore.ops.operations import _inner_ops as inner 22from mindspore.common.parameter import Parameter 23from mindspore.common.initializer import initializer, Initializer 24from mindspore.common.tensor import Tensor 25from mindspore.common._decorator import deprecated 26from mindspore.ops.primitive import constexpr 27import mindspore.context as context 28from mindspore._checkparam import Rel 29from mindspore._checkparam import Validator as validator 30from mindspore._extends import cell_attr_register 31from mindspore.communication.management import get_group_size, get_rank 32from mindspore.communication import management 33from mindspore.common import dtype as mstype 34from mindspore.parallel._utils import _is_in_auto_parallel_mode 35from ..cell import Cell 36 37__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm', 38 'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d'] 39 40SYNC_BN_GROUP_NAME = "" 41 42 43class _BatchNorm(Cell): 44 """Batch Normalization base class.""" 45 46 @cell_attr_register 47 def __init__(self, 48 num_features, 49 eps=1e-5, 50 momentum=0.9, 51 affine=True, 52 gamma_init='ones', 53 beta_init='zeros', 54 moving_mean_init='zeros', 55 moving_var_init='ones', 56 use_batch_statistics=None, 57 device_num_each_group=1, 58 process_groups=0, 59 input_dims='2d', 60 data_format='NCHW'): 61 """Initialize _BatchNorm.""" 62 super(_BatchNorm, self).__init__() 63 validator.check_value_type('num_features', num_features, [int], self.cls_name) 64 if num_features < 1: 65 raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.") 66 67 if momentum < 0 or momentum > 1: 68 raise ValueError(f"For '{self.cls_name}', the 'momentum' should be a number in range [0, 1], " 69 f"but got {momentum}.") 70 self.input_dims = input_dims 71 self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) 72 if context.get_context("device_target") != "GPU" and self.format == "NHWC": 73 raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device " 74 f"target {context.get_context('device_target')}.") 75 self.use_batch_statistics = use_batch_statistics 76 if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool): 77 raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' should be a boolean value or None," 78 f" but got {use_batch_statistics}.") 79 self.num_features = num_features 80 self.eps = eps 81 self.moving_mean = Parameter(initializer( 82 moving_mean_init, num_features), name="mean", requires_grad=False) 83 self.moving_variance = Parameter(initializer( 84 moving_var_init, num_features), name="variance", requires_grad=False) 85 self.gamma = Parameter(initializer( 86 gamma_init, num_features), name="gamma", requires_grad=affine) 87 self.beta = Parameter(initializer( 88 beta_init, num_features), name="beta", requires_grad=affine) 89 self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group", 90 self.cls_name) 91 self.process_groups = process_groups 92 self.is_global = False 93 self.parallel_mode = context.get_auto_parallel_context("parallel_mode") 94 global SYNC_BN_GROUP_NAME 95 # for GlobalBatchNorm 96 if self.group_device_num != 1: 97 self.rank_id = get_rank() 98 self.rank_size = get_group_size() 99 self.device_list = [i for i in range(0, self.rank_size)] 100 self.rank_list = self.list_group(self.device_list, self.group_device_num) 101 self.rank_list_idx = len(self.rank_list) 102 self._create_global_groups() 103 # for SyncBatchNorm 104 if self.process_groups != 0: 105 self.rank_id = get_rank() 106 self.rank_size = get_group_size() 107 if self.process_groups is not None: 108 validator.check_isinstance("process_groups", self.process_groups, list) 109 self._check_rank_ids(self.process_groups, self.rank_size) 110 self._create_sync_groups() 111 elif self.rank_size > 1: 112 self.is_global = True 113 self.group_device_num = self.rank_size 114 self.device_list = [i for i in range(0, self.rank_size)] 115 if context.get_context("device_target") == "Ascend": 116 if SYNC_BN_GROUP_NAME == "": 117 SYNC_BN_GROUP_NAME = "sync_bn_group0" 118 management.create_group(SYNC_BN_GROUP_NAME, self.device_list) 119 elif context.get_context("device_target") == "GPU": 120 if SYNC_BN_GROUP_NAME == "": 121 SYNC_BN_GROUP_NAME = "nccl_world_group" 122 123 self.shape = P.Shape() 124 self.reduce_mean = P.ReduceMean(keep_dims=True) 125 self.square = P.Square() 126 self.sqrt = P.Sqrt() 127 self.cast = P.Cast() 128 self.dtype = P.DType() 129 self.reshape = P.Reshape() 130 self._target = context.get_context("device_target") 131 self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE 132 self.momentum = 1.0 - momentum 133 if context.get_context("enable_ge"): 134 self.is_ge_backend = True 135 else: 136 self.is_ge_backend = False 137 138 self.bn_train = P.BatchNorm(is_training=True, 139 epsilon=self.eps, 140 momentum=self.momentum, 141 data_format=self.format) 142 if self.is_global: 143 self.bn_train = inner.SyncBatchNorm(epsilon=self.eps, 144 momentum=self.momentum, 145 group=SYNC_BN_GROUP_NAME, 146 device_num=self.group_device_num) 147 148 self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format) 149 if _is_in_auto_parallel_mode(): 150 data_parallel_strategy = ((1,), (1,)) 151 data_parallel_strategy_one = ((1,), ()) 152 else: 153 data_parallel_strategy = None 154 data_parallel_strategy_one = None 155 self.sub_mean = P.Sub().shard(data_parallel_strategy) 156 self.sub_var = P.Sub().shard(data_parallel_strategy) 157 self.mul_mean = P.Mul().shard(data_parallel_strategy_one) 158 self.mul_var = P.Mul().shard(data_parallel_strategy_one) 159 self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy) 160 self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy) 161 162 def _check_data_dim(self, x): 163 raise NotImplementedError 164 165 def list_group(self, world_rank, group_size): 166 """ Check whether world_rank and group_size are valid. """ 167 if group_size > get_group_size(): 168 raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' cannot be greater than " 169 f"local rank size, but got 'device_num_each_group': {group_size}, " 170 f"local rank size: {get_group_size()}.") 171 if len(world_rank) % group_size != 0: 172 raise ValueError(f"For '{self.cls_name}', the dimension of device_list should be divisible by " 173 f"'device_num_each_group', but got the length of device_list: {len(world_rank)}, " 174 f"'device_num_each_group': {group_size}.") 175 world_rank_list = zip(*(iter(world_rank),) * group_size) 176 group_list = [list(i) for i in world_rank_list] 177 return group_list 178 179 def _check_rank_ids(self, process_groups, rank_size): 180 seen = set() 181 for rid in itertools.chain(*process_groups): 182 validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name) 183 if rid in seen: 184 raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' should not be duplicated, " 185 f"but got {process_groups}.") 186 seen.add(rid) 187 188 def _create_global_groups(self): 189 for i in range(self.rank_list_idx): 190 if self.rank_id in self.rank_list[i]: 191 self.is_global = True 192 global SYNC_BN_GROUP_NAME 193 if SYNC_BN_GROUP_NAME == "": 194 SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) 195 management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i]) 196 197 def _create_sync_groups(self): 198 for i in range(len(self.process_groups)): 199 validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list) 200 self.group_device_num = len(self.process_groups[i]) 201 if self.rank_id in self.process_groups[i] and self.group_device_num > 1: 202 self.is_global = True 203 global SYNC_BN_GROUP_NAME 204 if SYNC_BN_GROUP_NAME == "": 205 SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i) 206 management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i]) 207 208 def construct(self, x): 209 _shape_check_bn(self.shape(x), self.input_dims, self.cls_name) 210 if self.use_batch_statistics is None: 211 if self.training: 212 return self.bn_train(x, 213 self.gamma, 214 self.beta, 215 self.moving_mean, 216 self.moving_variance)[0] 217 if not self.training: 218 return self.bn_infer(x, 219 self.gamma, 220 self.beta, 221 self.moving_mean, 222 self.moving_variance)[0] 223 224 if self.use_batch_statistics is True: 225 return self.bn_train(x, 226 self.gamma, 227 self.beta, 228 self.moving_mean, 229 self.moving_variance)[0] 230 231 return self.bn_infer(x, 232 self.gamma, 233 self.beta, 234 self.moving_mean, 235 self.moving_variance)[0] 236 237 def extend_repr(self): 238 return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( 239 self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance) 240 241 242@constexpr 243def _channel_check(channel, num_channel, prim_name=None): 244 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 245 if channel != num_channel: 246 raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') should be equal to num_channels, " 247 f"but got channel: {channel}, num_channels: {num_channel}.") 248 249 250@constexpr 251def _shape_check(in_shape, prim_name=None): 252 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 253 if len(in_shape) != 4: 254 raise ValueError(f"{msg_prefix} in_shape must has 4 dims, but got the length of in_shape: {len(in_shape)}.") 255 256 257@constexpr 258def _shape_check_bn(in_shape, in_dims, prim_name=None): 259 """check input dims of batch norm.""" 260 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 261 dim = len(in_shape) 262 if in_dims == '1d' and dim != 2: 263 raise ValueError(f"{msg_prefix} in_shape must have 2 dims, but got {len(in_shape)}.") 264 if in_dims == '2d' and dim != 4: 265 raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.") 266 if in_dims == '3d' and dim != 5: 267 raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.") 268 if in_dims == 'both' and dim != 2 and dim != 4: 269 raise ValueError(f"{msg_prefix} in_shape must have 2 dims or 4 dims, but got {len(in_shape)}.") 270 271 272@constexpr 273def _shape_infer(x_shape, num_feature): 274 """global Batch Normalization shape and axes infer""" 275 if len(x_shape) == 4: 276 axes = (0, 2, 3) 277 re_shape = (1, num_feature, 1, 1) 278 else: 279 axes = (0,) 280 re_shape = (1, num_feature) 281 return axes, re_shape 282 283 284class BatchNorm1d(_BatchNorm): 285 r""" 286 Batch Normalization layer over a 2D input. 287 288 Batch Normalization is widely used in convolutional networks. This layer 289 applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to 290 reduce internal covariate shift as described in the paper 291 `Batch Normalization: Accelerating Deep Network Training by 292 Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It 293 rescales and recenters the feature using a mini-batch of data and 294 the learned parameters which can be described in the following formula. 295 296 .. math:: 297 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 298 299 Note: 300 The implementation of BatchNorm is different in graph mode and pynative mode, therefore the mode is not 301 recommended to be changed after net was initialized. 302 303 Args: 304 num_features (int): `C` from an expected input of size (N, C). 305 eps (float): A value added to the denominator for numerical stability. Default: 1e-5. 306 momentum (float): A floating hyperparameter of the momentum for the 307 running_mean and running_var computation. Default: 0.9. 308 affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. 309 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 310 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. 311 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 312 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. 313 moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. 314 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. 315 moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. 316 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. 317 use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, 318 use the mean value and variance value of specified value. If None, the training process will use the mean 319 and variance of current batch data and track the running mean and variance, the evaluation process will use 320 the running mean and variance. Default: None. 321 322 Inputs: 323 - **x** (Tensor) - Tensor of shape :math:`(N, C_{in})`. 324 325 Outputs: 326 Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`. 327 328 Supported Platforms: 329 ``Ascend`` ``GPU`` ``CPU`` 330 331 Raises: 332 TypeError: If `num_features` is not an int. 333 TypeError: If `eps` is not a float. 334 ValueError: If `num_features` is less than 1. 335 ValueError: If `momentum` is not in range [0, 1]. 336 337 Examples: 338 >>> import numpy as np 339 >>> import mindspore.nn as nn 340 >>> from mindspore import Tensor 341 >>> net = nn.BatchNorm1d(num_features=4) 342 >>> x = Tensor(np.array([[0.7, 0.5, 0.5, 0.6], 343 ... [0.5, 0.4, 0.6, 0.9]]).astype(np.float32)) 344 >>> output = net(x) 345 >>> print(output) 346 [[ 0.6999965 0.4999975 0.4999975 0.59999704 ] 347 [ 0.4999975 0.399998 0.59999704 0.89999545 ]] 348 """ 349 350 def __init__(self, 351 num_features, 352 eps=1e-5, 353 momentum=0.9, 354 affine=True, 355 gamma_init='ones', 356 beta_init='zeros', 357 moving_mean_init='zeros', 358 moving_var_init='ones', 359 use_batch_statistics=None): 360 """Initialize BatchNorm1d.""" 361 super(BatchNorm1d, self).__init__(num_features, 362 eps, 363 momentum, 364 affine, 365 gamma_init, 366 beta_init, 367 moving_mean_init, 368 moving_var_init, 369 use_batch_statistics, 370 input_dims='1d') 371 372 def _check_data_dim(self, x): 373 if x.ndim != 2: 374 pass 375 376 377class BatchNorm2d(_BatchNorm): 378 r""" 379 Batch Normalization layer over a 4D input. 380 381 Batch Normalization is widely used in convolutional networks. This layer 382 applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with 383 additional channel dimension) to avoid internal covariate shift as described 384 in the paper `Batch Normalization: Accelerating Deep Network Training by 385 Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It 386 rescales and recenters the feature using a mini-batch of data and 387 the learned parameters which can be described in the following formula. 388 389 .. math:: 390 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 391 392 Note: 393 The implementation of BatchNorm is different in graph mode and pynative mode, therefore that mode can not be 394 changed after net was initialized. 395 Note that the formula for updating the running_mean and running_var is 396 :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`, 397 where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. 398 399 Args: 400 num_features (int): `C` from an expected input of size (N, C, H, W). 401 eps (float): A value added to the denominator for numerical stability. Default: 1e-5. 402 momentum (float): A floating hyperparameter of the momentum for the 403 running_mean and running_var computation. Default: 0.9. 404 affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. 405 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 406 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. 407 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 408 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. 409 moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. 410 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. 411 moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. 412 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. 413 use_batch_statistics (bool): 414 415 - If true, use the mean value and variance value of current batch data and track running mean 416 and running varance. 417 - If false, use the mean value and variance value of specified value, and not track statistical value. 418 - If None, The use_batch_statistics is automatically assigned process according to 419 the training and eval mode. During training, batchnorm2d process will be the same 420 with use_batch_statistics=True. Contrarily, in eval, batchnorm2d process will be the same 421 with use_batch_statistics=False. Default: None. 422 423 data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. 424 Default: 'NCHW'. 425 426 Inputs: 427 - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. 428 429 Outputs: 430 Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. 431 432 Raises: 433 TypeError: If `num_features` is not an int. 434 TypeError: If `eps` is not a float. 435 ValueError: If `num_features` is less than 1. 436 ValueError: If `momentum` is not in range [0, 1]. 437 ValueError: If `data_format` is neither 'NHWC' not 'NCHW'. 438 439 Supported Platforms: 440 ``Ascend`` ``GPU`` ``CPU`` 441 442 Examples: 443 >>> import numpy as np 444 >>> import mindspore.nn as nn 445 >>> from mindspore import Tensor 446 >>> net = nn.BatchNorm2d(num_features=3) 447 >>> x = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32)) 448 >>> output = net(x) 449 >>> print(output) 450 [[[[ 0.999995 0.999995 ] 451 [ 0.999995 0.999995 ]] 452 [[ 0.999995 0.999995 ] 453 [ 0.999995 0.999995 ]] 454 [[ 0.999995 0.999995 ] 455 [ 0.999995 0.999995 ]]]] 456 """ 457 458 def __init__(self, 459 num_features, 460 eps=1e-5, 461 momentum=0.9, 462 affine=True, 463 gamma_init='ones', 464 beta_init='zeros', 465 moving_mean_init='zeros', 466 moving_var_init='ones', 467 use_batch_statistics=None, 468 data_format='NCHW'): 469 """Initialize BatchNorm2d.""" 470 super(BatchNorm2d, self).__init__(num_features, 471 eps, 472 momentum, 473 affine, 474 gamma_init, 475 beta_init, 476 moving_mean_init, 477 moving_var_init, 478 use_batch_statistics, 479 input_dims='2d', 480 data_format=data_format) 481 482 def _check_data_dim(self, x): 483 if x.ndim != 4: 484 pass 485 486 487@constexpr 488def _check_3d_shape(input_shape, prim_name=None): 489 msg_prefix = f"For '{prim_name}', the" if prim_name else "The" 490 if len(input_shape) != 5: 491 raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: " 492 f"{len(input_shape)}.") 493 494 495class BatchNorm3d(Cell): 496 r""" 497 Batch Normalization layer over a 5D input. 498 499 Batch Normalization is widely used in convolutional networks. This layer 500 applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with 501 additional channel dimension) to avoid internal covariate shift. 502 503 .. math:: 504 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 505 506 Note: 507 The implementation of BatchNorm is different in graph mode and pynative mode, therefore that mode can not be 508 changed after net was initialized. 509 Note that the formula for updating the running_mean and running_var is 510 :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`, 511 where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. 512 513 Args: 514 num_features (int): `C` from an expected input of size (N, C, D, H, W). 515 eps (float): A value added to the denominator for numerical stability. Default: 1e-5. 516 momentum (float): A floating hyperparameter of the momentum for the 517 running_mean and running_var computation. Default: 0.9. 518 affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. 519 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 520 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. 521 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 522 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. 523 moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. 524 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. 525 moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. 526 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. 527 use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, 528 use the mean value and variance value of specified value. If None, the training process will use the mean 529 and variance of current batch data and track the running mean and variance, the evaluation process will use 530 the running mean and variance. Default: None. 531 data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'. 532 533 Inputs: 534 - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`. 535 536 Outputs: 537 Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out},H_{out}, W_{out})`. 538 539 Raises: 540 TypeError: If `num_features` is not an int. 541 TypeError: If `eps` is not a float. 542 ValueError: If `num_features` is less than 1. 543 ValueError: If `momentum` is not in range [0, 1]. 544 ValueError: If `data_format` is not 'NCDHW'. 545 546 Supported Platforms: 547 ``Ascend`` ``GPU`` ``CPU`` 548 549 Examples: 550 >>> import numpy as np 551 >>> import mindspore.nn as nn 552 >>> from mindspore import Tensor 553 >>> net = nn.BatchNorm3d(num_features=3) 554 >>> x = Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32)) 555 >>> output = net(x) 556 >>> print(output.shape) 557 (16, 3, 10, 32, 32) 558 """ 559 560 def __init__(self, 561 num_features, 562 eps=1e-5, 563 momentum=0.9, 564 affine=True, 565 gamma_init='ones', 566 beta_init='zeros', 567 moving_mean_init='zeros', 568 moving_var_init='ones', 569 use_batch_statistics=None, 570 data_format='NCDHW'): 571 """Initialize BatchNorm3d.""" 572 super(BatchNorm3d, self).__init__() 573 self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name) 574 self.reshape = P.Reshape() 575 self.bn2d = BatchNorm2d(num_features=num_features, 576 eps=eps, 577 momentum=momentum, 578 affine=affine, 579 gamma_init=gamma_init, 580 beta_init=beta_init, 581 moving_mean_init=moving_mean_init, 582 moving_var_init=moving_var_init, 583 use_batch_statistics=use_batch_statistics, 584 data_format="NCHW") 585 586 def construct(self, input_x): 587 x_shape = F.shape(input_x) 588 _check_3d_shape(x_shape, self.cls_name) 589 input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4])) 590 bn2d_out = self.bn2d(input_x) 591 bn3d_out = self.reshape(bn2d_out, x_shape) 592 return bn3d_out 593 594 595class GlobalBatchNorm(_BatchNorm): 596 r""" 597 Global Batch Normalization layer over a N-dimension input. 598 599 Global Batch Normalization is cross device synchronized Batch Normalization. The implementation of 600 Batch Normalization only normalizes the data within each device. Global Normalization will normalize 601 the input within the group.It has been described in the paper `Batch Normalization: Accelerating Deep Network 602 Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the 603 feature using a mini-batch of data and the learned parameters which can be described in the following formula. 604 605 .. math:: 606 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 607 608 Note: 609 Currently, GlobalBatchNorm only supports 2D and 4D inputs. 610 611 Args: 612 num_features (int): `C` from an expected input of size (N, C, H, W). 613 device_num_each_group (int): The number of devices in each group. Default: 2. 614 eps (float): A value added to the denominator for numerical stability. Default: 1e-5. 615 momentum (float): A floating hyperparameter of the momentum for the 616 running_mean and running_var computation. Default: 0.9. 617 affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. 618 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 619 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 620 'he_uniform', etc. Default: 'ones'. 621 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 622 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 623 'he_uniform', etc. Default: 'zeros'. 624 moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. 625 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 626 'he_uniform', etc. Default: 'zeros'. 627 moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. 628 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 629 'he_uniform', etc. Default: 'ones'. 630 use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, 631 use the mean value and variance value of specified value. If None, training process will use the mean and 632 variance of current batch data and track the running mean and variance, eval process will use the running 633 mean and variance. Default: None. 634 635 Inputs: 636 - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. 637 638 Outputs: 639 Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. 640 641 Raises: 642 TypeError: If `num_features` or `device_num_each_group` is not an int. 643 TypeError: If `eps` is not a float. 644 ValueError: If `num_features` is less than 1. 645 ValueError: If `momentum` is not in range [0, 1]. 646 ValueError: If `device_num_each_group` is less than 2. 647 648 Supported Platforms: 649 ``Ascend`` 650 651 Examples: 652 >>> # This example should be run with multiple processes. 653 >>> # Please refer to the tutorial > Distributed Training on mindspore.cn. 654 >>> import numpy as np 655 >>> from mindspore.communication import init 656 >>> from mindspore import context 657 >>> from mindspore.context import ParallelMode 658 >>> from mindspore import nn 659 >>> from mindspore import Tensor 660 >>> 661 >>> context.set_context(mode=context.GRAPH_MODE) 662 >>> init() 663 >>> context.reset_auto_parallel_context() 664 >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) 665 >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=2) 666 >>> x = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32)) 667 >>> output = global_bn_op(x) 668 >>> print(output) 669 [[[[ 0.999995 0.999995 ] 670 [ 0.999995 0.999995 ]] 671 [[ 0.999995 0.999995 ] 672 [ 0.999995 0.999995 ]] 673 [[ 0.999995 0.999995 ] 674 [ 0.999995 0.999995 ]]]] 675 """ 676 677 @deprecated("1.2", "SyncBatchNorm", True) 678 def __init__(self, 679 num_features, 680 eps=1e-5, 681 momentum=0.9, 682 affine=True, 683 gamma_init='ones', 684 beta_init='zeros', 685 moving_mean_init='zeros', 686 moving_var_init='ones', 687 use_batch_statistics=None, 688 device_num_each_group=2): 689 """Initialize GlobalBatchNorm.""" 690 super(GlobalBatchNorm, self).__init__(num_features, 691 eps, 692 momentum, 693 affine, 694 gamma_init, 695 beta_init, 696 moving_mean_init, 697 moving_var_init, 698 use_batch_statistics, 699 device_num_each_group, 700 input_dims='both') 701 self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group", 702 self.cls_name) 703 if self.group_device_num <= 1: 704 raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' must be greater than 1, " 705 f"but got {self.group_device_num}.") 706 707 def _check_data_dim(self, x): 708 if x.dim == 0: 709 pass 710 711 712class SyncBatchNorm(_BatchNorm): 713 r""" 714 Sync Batch Normalization layer over a N-dimension input. 715 716 Sync Batch Normalization is cross device synchronized Batch Normalization. The implementation of Batch 717 Normalization only normalizes the data within each device. Sync Batch Normalization will normalize the input 718 within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by 719 Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the 720 feature using a mini-batch of data and the learned parameters which can be described in the following formula. 721 722 .. math:: 723 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 724 725 Note: 726 Currently, SyncBatchNorm only supports 2D and 4D inputs. 727 728 Args: 729 num_features (int): `C` from an expected input of size (N, C, H, W). 730 eps (float): A value added to the denominator for numerical stability. Default: 1e-5. 731 momentum (float): A floating hyperparameter of the momentum for the 732 running_mean and running_var computation. Default: 0.9. 733 affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. 734 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 735 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 736 'he_uniform', etc. Default: 'ones'. 737 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 738 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 739 'he_uniform', etc. Default: 'zeros'. 740 moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean. 741 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 742 'he_uniform', etc. Default: 'zeros'. 743 moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance. 744 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 745 'he_uniform', etc. Default: 'ones'. 746 use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false, 747 use the mean value and variance value of specified value. If None, training process will use the mean and 748 variance of current batch data and track the running mean and variance, eval process will use the running 749 mean and variance. Default: None. 750 process_groups (list): A list to divide devices into different sync groups, containing N subtraction lists. 751 Each subtraction list contains int numbers identifying rank ids which need to be synchronized in the same 752 group. All int values must be in [0, rank_size) and different from each other. Default: None, indicating 753 synchronization across all devices. 754 755 Inputs: 756 - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. 757 758 Outputs: 759 Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`. 760 761 Raises: 762 TypeError: If `num_features` is not an int. 763 TypeError: If `eps` is not a float. 764 TypeError: If `process_groups` is not a list. 765 ValueError: If `num_features` is less than 1. 766 ValueError: If `momentum` is not in range [0, 1]. 767 ValueError: If rank_id in `process_groups` is not in range [0, rank_size). 768 769 Supported Platforms: 770 ``Ascend`` 771 772 Examples: 773 >>> # This example should be run with multiple processes. 774 >>> # Please refer to the tutorial > Distributed Training on mindspore.cn. 775 >>> import numpy as np 776 >>> from mindspore.communication import init 777 >>> from mindspore import context 778 >>> from mindspore.context import ParallelMode 779 >>> from mindspore import Tensor 780 >>> from mindspore import nn 781 >>> from mindspore import dtype as mstype 782 >>> 783 >>> context.set_context(mode=context.GRAPH_MODE) 784 >>> init() 785 >>> context.reset_auto_parallel_context() 786 >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) 787 >>> sync_bn_op = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]]) 788 >>> x = Tensor(np.ones([1, 3, 2, 2]), mstype.float32) 789 >>> output = sync_bn_op(x) 790 >>> print(output) 791 [[[[ 0.999995 0.999995 ] 792 [ 0.999995 0.999995 ]] 793 [[ 0.999995 0.999995 ] 794 [ 0.999995 0.999995 ]] 795 [[ 0.999995 0.999995 ] 796 [ 0.999995 0.999995 ]]]] 797 """ 798 799 def __init__(self, 800 num_features, 801 eps=1e-5, 802 momentum=0.9, 803 affine=True, 804 gamma_init='ones', 805 beta_init='zeros', 806 moving_mean_init='zeros', 807 moving_var_init='ones', 808 use_batch_statistics=None, 809 process_groups=None): 810 """Initialize SyncBatchNorm.""" 811 super(SyncBatchNorm, self).__init__(num_features, 812 eps, 813 momentum, 814 affine, 815 gamma_init, 816 beta_init, 817 moving_mean_init, 818 moving_var_init, 819 use_batch_statistics, 820 process_groups=process_groups, 821 input_dims='both') 822 823 def _check_data_dim(self, x): 824 if x.dim == 0: 825 pass 826 827 828class LayerNorm(Cell): 829 r""" 830 Applies Layer Normalization over a mini-batch of inputs. 831 832 Layer Normalization is widely used in recurrent neural networks. It applies 833 normalization on a mini-batch of inputs for each single training case as described 834 in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike Batch 835 Normalization, Layer Normalization performs exactly the same computation at training and 836 testing time. It can be described using the following formula. It is applied across all channels 837 and pixel but only one batch size. 838 839 .. math:: 840 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 841 842 Args: 843 normalized_shape (Union(tuple[int], list[int]): The normalization is performed over axis 844 `begin_norm_axis ... R - 1`. 845 begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions 846 `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1. 847 begin_params_axis (int): The first parameter(beta, gamma)dimension: scale and centering parameters 848 will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with 849 the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1. 850 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 851 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 852 'he_uniform', etc. Default: 'ones'. 853 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 854 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 855 'he_uniform', etc. Default: 'zeros'. 856 epsilon (float): A value added to the denominator for numerical stability. Default: 1e-7. 857 858 Inputs: 859 - **x** (Tensor) - The shape of 'x' is :math:`(x_1, x_2, ..., x_R)`, 860 and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`. 861 862 Outputs: 863 Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`. 864 865 Raises: 866 TypeError: If `normalized_shape` is neither a list nor tuple. 867 TypeError: If `begin_norm_axis` or `begin_params_axis` is not an int. 868 TypeError: If `epsilon` is not a float. 869 870 Supported Platforms: 871 ``Ascend`` ``GPU`` ``CPU`` 872 873 Examples: 874 >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32) 875 >>> shape1 = x.shape[1:] 876 >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1) 877 >>> output = m(x).shape 878 >>> print(output) 879 (20, 5, 10, 10) 880 """ 881 882 def __init__(self, 883 normalized_shape, 884 begin_norm_axis=-1, 885 begin_params_axis=-1, 886 gamma_init='ones', 887 beta_init='zeros', 888 epsilon=1e-7 889 ): 890 """Initialize LayerNorm.""" 891 super(LayerNorm, self).__init__() 892 if not isinstance(normalized_shape, (tuple, list)): 893 raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' should be tuple[int] or list[int], " 894 f"but got {normalized_shape} and the type is {type(normalized_shape)}.") 895 self.normalized_shape = normalized_shape 896 self.begin_norm_axis = begin_norm_axis 897 self.begin_params_axis = begin_params_axis 898 self.epsilon = epsilon 899 self.gamma = Parameter(initializer( 900 gamma_init, normalized_shape), name="gamma") 901 self.beta = Parameter(initializer( 902 beta_init, normalized_shape), name="beta") 903 self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis, 904 begin_params_axis=self.begin_params_axis, 905 epsilon=self.epsilon) 906 907 def construct(self, input_x): 908 y, _, _ = self.layer_norm(input_x, self.gamma, self.beta) 909 return y 910 911 def extend_repr(self): 912 """Display instance object as string.""" 913 return 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma{}, beta={}'.format( 914 self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta) 915 916 917class InstanceNorm2d(Cell): 918 r""" 919 Instance Normalization layer over a 4D input. 920 921 This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with 922 additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for 923 Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch 924 of data and the learned parameters which can be described in the following formula. 925 926 .. math:: 927 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 928 929 \gamma and \beta are learnable parameter vectors of size num_features if affine is True. The standard-deviation 930 is calculated via the biased estimator. 931 932 This layer uses instance statistics computed from input data in both training and evaluation modes. 933 934 InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each 935 channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data. 936 937 Note: 938 Note that the formula for updating the running_mean and running_var is 939 :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`, 940 where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value. 941 942 Args: 943 num_features (int): `C` from an expected input of size (N, C, H, W). 944 eps (float): A value added to the denominator for numerical stability. Default: 1e-5. 945 momentum (float): A floating hyperparameter of the momentum for the 946 running_mean and running_var computation. Default: 0.1. 947 affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True. 948 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 949 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'. 950 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 951 The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'. 952 953 Inputs: 954 - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32. 955 956 Outputs: 957 Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`. Same type and 958 shape as the `x`. 959 960 Supported Platforms: 961 ``GPU`` 962 963 Raises: 964 TypeError: If `num_features` is not an int. 965 TypeError: If `eps` is not a float. 966 TypeError: If `momentum` is not a float. 967 TypeError: If `affine` is not a bool. 968 TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not 969 float32. 970 ValueError: If `num_features` is less than 1. 971 ValueError: If `momentum` is not in range [0, 1]. 972 KeyError: If any of `gamma_init`/`beta_init` is str and the homonymous class inheriting from `Initializer` not 973 exists. 974 975 Examples: 976 >>> import mindspore 977 >>> import numpy as np 978 >>> import mindspore.nn as nn 979 >>> from mindspore import Tensor 980 >>> net = nn.InstanceNorm2d(3) 981 >>> x = Tensor(np.ones([2, 3, 2, 2]), mindspore.float32) 982 >>> output = net(x) 983 >>> print(output.shape) 984 (2, 3, 2, 2) 985 """ 986 987 @cell_attr_register 988 def __init__(self, 989 num_features, 990 eps=1e-5, 991 momentum=0.1, 992 affine=True, 993 gamma_init='ones', 994 beta_init='zeros'): 995 """Initialize InstanceNorm2d.""" 996 super(InstanceNorm2d, self).__init__() 997 validator.check_value_type('num_features', num_features, [int], self.cls_name) 998 validator.check_value_type('eps', eps, [float], self.cls_name) 999 validator.check_value_type('momentum', momentum, [float], self.cls_name) 1000 validator.check_value_type('affine', affine, [bool], self.cls_name) 1001 args_input = {"gamma_init": gamma_init, "beta_init": beta_init} 1002 self.check_types_valid(args_input, 'InstanceNorm2d') 1003 if num_features < 1: 1004 raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.") 1005 1006 if momentum < 0 or momentum > 1: 1007 raise ValueError(f"For '{self.cls_name}', the 'momentum' should be a number in range [0, 1], " 1008 f"but got {momentum}.") 1009 self.num_features = num_features 1010 self.eps = eps 1011 self.input_dims = '2d' 1012 self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False) 1013 self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False) 1014 self.gamma = Parameter(initializer( 1015 gamma_init, num_features), name="gamma", requires_grad=affine) 1016 self.beta = Parameter(initializer( 1017 beta_init, num_features), name="beta", requires_grad=affine) 1018 1019 self.shape = P.Shape() 1020 self.momentum = momentum 1021 self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum) 1022 1023 def _check_data_dim(self, x): 1024 raise NotImplementedError 1025 1026 def construct(self, x): 1027 _shape_check_bn(self.shape(x), self.input_dims, self.cls_name) 1028 return self.instance_bn(x, 1029 self.gamma, 1030 self.beta, 1031 self.moving_mean, 1032 self.moving_variance)[0] 1033 1034 def extend_repr(self): 1035 return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( 1036 self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance) 1037 1038 def check_types_valid(self, args_dict, name): 1039 for key, _ in args_dict.items(): 1040 val = args_dict[key] 1041 if not isinstance(val, (Tensor, numbers.Number, str, Initializer)): 1042 raise TypeError(f"For '{self.cls_name}', the type of args_dict['{key}'] should be in " 1043 f"[Tensor, numbers.Number, str, Initializer], but got type {type(val).__name__}.") 1044 if isinstance(val, Tensor) and val.dtype != mstype.float32: 1045 raise TypeError(f"For '{self.cls_name}', the type of args_dict['{key}'] should be float32, " 1046 f"but got {val.dtype}.") 1047 1048 1049class GroupNorm(Cell): 1050 r""" 1051 Group Normalization over a mini-batch of inputs. 1052 1053 Group Normalization is widely used in recurrent neural networks. It applies 1054 normalization on a mini-batch of inputs for each single training case as described 1055 in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group Normalization 1056 divides the channels into groups and computes within each group the mean and variance for normalization, 1057 and it performs very stable over a wide range of batch size. It can be described using the following formula. 1058 1059 .. math:: 1060 y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta 1061 1062 Args: 1063 num_groups (int): The number of groups to be divided along the channel dimension. 1064 num_channels (int): The number of channels per group. 1065 eps (float): A value added to the denominator for numerical stability. Default: 1e-5. 1066 affine (bool): A bool value, this layer will have learnable affine parameters when set to true. Default: True. 1067 gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight. 1068 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 1069 'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be [num_channels]. 1070 beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight. 1071 The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform', 1072 'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be [num_channels]. 1073 1074 Inputs: 1075 - **x** (Tensor) - The input feature with shape [N, C, H, W]. 1076 1077 Outputs: 1078 Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`. 1079 1080 Raises: 1081 TypeError: If `num_groups` or `num_channels` is not an int. 1082 TypeError: If `eps` is not a float. 1083 TypeError: If `affine` is not a bool. 1084 ValueError: If `num_groups` or `num_channels` is less than 1. 1085 ValueError: If `num_channels` is not divided by `num_groups`. 1086 1087 Supported Platforms: 1088 ``Ascend`` ``GPU`` ``CPU`` 1089 1090 Examples: 1091 >>> group_norm_op = nn.GroupNorm(2, 2) 1092 >>> x = Tensor(np.ones([1, 2, 4, 4], np.float32)) 1093 >>> output = group_norm_op(x) 1094 >>> print(output) 1095 [[[[0. 0. 0. 0.] 1096 [0. 0. 0. 0.] 1097 [0. 0. 0. 0.] 1098 [0. 0. 0. 0.]] 1099 [[0. 0. 0. 0.] 1100 [0. 0. 0. 0.] 1101 [0. 0. 0. 0.] 1102 [0. 0. 0. 0.]]]] 1103 """ 1104 1105 def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'): 1106 """Initialize GroupNorm.""" 1107 super(GroupNorm, self).__init__() 1108 self.num_groups = validator.check_positive_int(num_groups, "num_groups", self.cls_name) 1109 self.num_channels = validator.check_positive_int(num_channels, "num_channels", self.cls_name) 1110 if num_channels % num_groups != 0: 1111 raise ValueError(f"For '{self.cls_name}', the 'num_channels' should be divided by 'num_groups', " 1112 f"but got 'num_channels': {num_channels}, 'num_groups': {num_groups}.") 1113 self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__) 1114 self.affine = validator.check_bool(affine, arg_name="affine", prim_name=self.cls_name) 1115 1116 gamma = initializer(gamma_init, num_channels) 1117 beta = initializer(beta_init, num_channels) 1118 if self.affine: 1119 self.gamma = Parameter(gamma, name='gamma') 1120 self.beta = Parameter(beta, name='beta') 1121 else: 1122 self.gamma = gamma 1123 self.beta = beta 1124 self.shape = F.shape 1125 self.reshape = F.reshape 1126 self.reduce_mean = P.ReduceMean(keep_dims=True) 1127 self.square = F.square 1128 self.reduce_sum = P.ReduceSum(keep_dims=True) 1129 self.sqrt = P.Sqrt() 1130 1131 def _cal_output(self, x): 1132 """calculate groupnorm output""" 1133 batch, channel, height, width = self.shape(x) 1134 _channel_check(channel, self.num_channels, self.cls_name) 1135 x = self.reshape(x, (batch, self.num_groups, -1)) 1136 mean = self.reduce_mean(x, 2) 1137 var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups) 1138 std = self.sqrt(var + self.eps) 1139 x = (x - mean) / std 1140 x = self.reshape(x, (batch, channel, height, width)) 1141 output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1)) 1142 return output 1143 1144 def construct(self, x): 1145 _shape_check(self.shape(x), self.cls_name) 1146 output = self._cal_output(x) 1147 return output 1148 1149 def extend_repr(self): 1150 """Display instance object as string.""" 1151 return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels) 1152