# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""normalization"""
import itertools
import numbers

from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Initializer
from mindspore.common.tensor import Tensor
from mindspore.common._decorator import deprecated
from mindspore.ops.primitive import constexpr
import mindspore.context as context
from mindspore._checkparam import Rel
from mindspore._checkparam import Validator as validator
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore.common import dtype as mstype
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from ..cell import Cell

__all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'GroupNorm',
           'GlobalBatchNorm', 'SyncBatchNorm', 'InstanceNorm2d']

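# Name of the communication group shared by the synchronized batch-norm cells below
# (GlobalBatchNorm / SyncBatchNorm); it is created once and then reused.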
SYNC_BN_GROUP_NAME = ""


class _BatchNorm(Cell):
    """Batch Normalization base class."""

    @cell_attr_register
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=1,
                 process_groups=0,
                 input_dims='2d',
                 data_format='NCHW'):
        """Initialize _BatchNorm."""
        super(_BatchNorm, self).__init__()
        validator.check_value_type('num_features', num_features, [int], self.cls_name)
        if num_features < 1:
            raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")

        if momentum < 0 or momentum > 1:
            raise ValueError(f"For '{self.cls_name}', the 'momentum' should be a number in range [0, 1], "
                             f"but got {momentum}.")
        self.input_dims = input_dims
        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
            raise ValueError(f"For '{self.cls_name}', the 'NHWC' format is only supported on the GPU target, "
                             f"but got device target {context.get_context('device_target')}.")
        self.use_batch_statistics = use_batch_statistics
        if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
            raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' should be a boolean value or None,"
                             f" but got {use_batch_statistics}.")
        self.num_features = num_features
        self.eps = eps
        self.moving_mean = Parameter(initializer(
            moving_mean_init, num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer(
            moving_var_init, num_features), name="variance", requires_grad=False)
        self.gamma = Parameter(initializer(
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)
        self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group",
                                                             self.cls_name)
        self.process_groups = process_groups
        self.is_global = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        global SYNC_BN_GROUP_NAME
        # for GlobalBatchNorm
        if self.group_device_num != 1:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            self.device_list = [i for i in range(0, self.rank_size)]
            self.rank_list = self.list_group(self.device_list, self.group_device_num)
            self.rank_list_idx = len(self.rank_list)
            self._create_global_groups()
        # for SyncBatchNorm
        if self.process_groups != 0:
            self.rank_id = get_rank()
            self.rank_size = get_group_size()
            if self.process_groups is not None:
                validator.check_isinstance("process_groups", self.process_groups, list)
                self._check_rank_ids(self.process_groups, self.rank_size)
                self._create_sync_groups()
            elif self.rank_size > 1:
                self.is_global = True
                self.group_device_num = self.rank_size
                self.device_list = [i for i in range(0, self.rank_size)]
                if context.get_context("device_target") == "Ascend":
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "sync_bn_group0"
                        management.create_group(SYNC_BN_GROUP_NAME, self.device_list)
                elif context.get_context("device_target") == "GPU":
                    if SYNC_BN_GROUP_NAME == "":
                        SYNC_BN_GROUP_NAME = "nccl_world_group"

        self.shape = P.Shape()
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.dtype = P.DType()
        self.reshape = P.Reshape()
        self._target = context.get_context("device_target")
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        self.momentum = 1.0 - momentum
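        # The user-facing `momentum` weights the running statistics (see the running_mean/running_var
        # update formula in the BatchNorm docstrings below), so the operators receive its complement,
        # i.e. the weight of the current batch statistics.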
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
        else:
            self.is_ge_backend = False

        self.bn_train = P.BatchNorm(is_training=True,
                                    epsilon=self.eps,
                                    momentum=self.momentum,
                                    data_format=self.format)
        if self.is_global:
            self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
                                                momentum=self.momentum,
                                                group=SYNC_BN_GROUP_NAME,
                                                device_num=self.group_device_num)

        self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
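        # In auto parallel mode, give the running-statistics update ops an explicit non-split shard
        # strategy; otherwise leave the strategy unset and let the framework decide.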
        if _is_in_auto_parallel_mode():
            data_parallel_strategy = ((1,), (1,))
            data_parallel_strategy_one = ((1,), ())
        else:
            data_parallel_strategy = None
            data_parallel_strategy_one = None
        self.sub_mean = P.Sub().shard(data_parallel_strategy)
        self.sub_var = P.Sub().shard(data_parallel_strategy)
        self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
        self.mul_var = P.Mul().shard(data_parallel_strategy_one)
        self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
        self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)

    def _check_data_dim(self, x):
        raise NotImplementedError

    def list_group(self, world_rank, group_size):
        """Check that world_rank and group_size are valid, then split the ranks into groups."""
        if group_size > get_group_size():
            raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' cannot be greater than "
                             f"local rank size, but got 'device_num_each_group': {group_size}, "
                             f"local rank size: {get_group_size()}.")
        if len(world_rank) % group_size != 0:
            raise ValueError(f"For '{self.cls_name}', the length of device_list must be divisible by "
                             f"'device_num_each_group', but got the length of device_list: {len(world_rank)}, "
                             f"'device_num_each_group': {group_size}.")
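        # Chunk the flat rank list into consecutive groups of `group_size`,
        # e.g. 8 ranks with group_size=4 -> [[0, 1, 2, 3], [4, 5, 6, 7]].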
        world_rank_list = zip(*(iter(world_rank),) * group_size)
        group_list = [list(i) for i in world_rank_list]
        return group_list

    def _check_rank_ids(self, process_groups, rank_size):
        seen = set()
        for rid in itertools.chain(*process_groups):
            validator.check_int_range(rid, 0, rank_size, Rel.INC_LEFT, "rank id in process_groups", self.cls_name)
            if rid in seen:
                raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' should not be duplicated, "
                                 f"but got {process_groups}.")
            seen.add(rid)

    def _create_global_groups(self):
        for i in range(self.rank_list_idx):
            if self.rank_id in self.rank_list[i]:
                self.is_global = True
                global SYNC_BN_GROUP_NAME
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                    management.create_group(SYNC_BN_GROUP_NAME, self.rank_list[i])

    def _create_sync_groups(self):
        for i in range(len(self.process_groups)):
            validator.check_isinstance("process_groups[" + str(i) + "]", self.process_groups[i], list)
            self.group_device_num = len(self.process_groups[i])
            if self.rank_id in self.process_groups[i] and self.group_device_num > 1:
                self.is_global = True
                global SYNC_BN_GROUP_NAME
                if SYNC_BN_GROUP_NAME == "":
                    SYNC_BN_GROUP_NAME = "sync_bn_group" + str(i)
                    management.create_group(SYNC_BN_GROUP_NAME, self.process_groups[i])

    def construct(self, x):
        _shape_check_bn(self.shape(x), self.input_dims, self.cls_name)
        if self.use_batch_statistics is None:
            if self.training:
                return self.bn_train(x,
                                     self.gamma,
                                     self.beta,
                                     self.moving_mean,
                                     self.moving_variance)[0]
            if not self.training:
                return self.bn_infer(x,
                                     self.gamma,
                                     self.beta,
                                     self.moving_mean,
                                     self.moving_variance)[0]

        if self.use_batch_statistics is True:
            return self.bn_train(x,
                                 self.gamma,
                                 self.beta,
                                 self.moving_mean,
                                 self.moving_variance)[0]

        return self.bn_infer(x,
                             self.gamma,
                             self.beta,
                             self.moving_mean,
                             self.moving_variance)[0]

    def extend_repr(self):
        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)


@constexpr
def _channel_check(channel, num_channel, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if channel != num_channel:
        raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') should be equal to num_channels, "
                         f"but got channel: {channel}, num_channels: {num_channel}.")


@constexpr
def _shape_check(in_shape, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(in_shape) != 4:
        raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got the length of in_shape: {len(in_shape)}.")


@constexpr
def _shape_check_bn(in_shape, in_dims, prim_name=None):
    """check input dims of batch norm."""
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    dim = len(in_shape)
    if in_dims == '1d' and dim != 2:
        raise ValueError(f"{msg_prefix} in_shape must have 2 dims, but got {len(in_shape)}.")
    if in_dims == '2d' and dim != 4:
        raise ValueError(f"{msg_prefix} in_shape must have 4 dims, but got {len(in_shape)}.")
    if in_dims == '3d' and dim != 5:
        raise ValueError(f"{msg_prefix} in_shape must have 5 dims, but got {len(in_shape)}.")
    if in_dims == 'both' and dim != 2 and dim != 4:
        raise ValueError(f"{msg_prefix} in_shape must have 2 dims or 4 dims, but got {len(in_shape)}.")


@constexpr
def _shape_infer(x_shape, num_feature):
    """global Batch Normalization shape and axes infer"""
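    # For 4D NCHW inputs, statistics are reduced over (N, H, W) and the per-channel parameters are
    # broadcast as (1, C, 1, 1); for 2D inputs, the reduction is over the batch axis only.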
    if len(x_shape) == 4:
        axes = (0, 2, 3)
        re_shape = (1, num_feature, 1, 1)
    else:
        axes = (0,)
        re_shape = (1, num_feature)
    return axes, re_shape


class BatchNorm1d(_BatchNorm):
    r"""
    Batch Normalization layer over a 2D input.

    Batch Normalization is widely used in convolutional networks. This layer
    applies Batch Normalization over a 2D input (a mini-batch of 1D inputs) to
    reduce internal covariate shift as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
    rescales and recenters the feature using a mini-batch of data and
    the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        The implementation of BatchNorm differs between graph mode and PyNative mode, so it is not
        recommended to change the mode after the network has been initialized.

    Args:
        num_features (int): `C` from an expected input of size (N, C).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
            use the specified values of the moving mean and moving variance. If None, the training process will
            use the mean and variance of the current batch data and track the moving statistics, while the
            evaluation process will use the moving mean and variance. Default: None.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out})`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If `num_features` is not an int.
        TypeError: If `eps` is not a float.
        ValueError: If `num_features` is less than 1.
        ValueError: If `momentum` is not in range [0, 1].

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.BatchNorm1d(num_features=4)
        >>> x = Tensor(np.array([[0.7, 0.5, 0.5, 0.6],
        ...                      [0.5, 0.4, 0.6, 0.9]]).astype(np.float32))
        >>> output = net(x)
        >>> print(output)
        [[ 0.6999965   0.4999975  0.4999975  0.59999704 ]
         [ 0.4999975   0.399998   0.59999704 0.89999545 ]]
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None):
        """Initialize BatchNorm1d."""
        super(BatchNorm1d, self).__init__(num_features,
                                          eps,
                                          momentum,
                                          affine,
                                          gamma_init,
                                          beta_init,
                                          moving_mean_init,
                                          moving_var_init,
                                          use_batch_statistics,
                                          input_dims='1d')

    def _check_data_dim(self, x):
        if x.ndim != 2:
            pass


class BatchNorm2d(_BatchNorm):
    r"""
    Batch Normalization layer over a 4D input.

    Batch Normalization is widely used in convolutional networks. This layer
    applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
    additional channel dimension) to avoid internal covariate shift as described
    in the paper `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
    rescales and recenters the feature using a mini-batch of data and
    the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        The implementation of BatchNorm differs between graph mode and PyNative mode, therefore the mode cannot be
        changed after the network has been initialized.
        Note that the formula for updating the running_mean and running_var is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
        use_batch_statistics (bool):

            - If true, use the mean and variance of the current batch data and track the running mean
              and running variance.
            - If false, use the specified values of the mean and variance, and do not track the statistics.
            - If None, use_batch_statistics is assigned automatically according to the training or eval
              mode. During training, the layer behaves as if use_batch_statistics=True; in eval mode,
              it behaves as if use_batch_statistics=False. Default: None.

        data_format (str): The optional value for data format, which can be 'NHWC' or 'NCHW'.
            Default: 'NCHW'.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `num_features` is not an int.
        TypeError: If `eps` is not a float.
        ValueError: If `num_features` is less than 1.
        ValueError: If `momentum` is not in range [0, 1].
        ValueError: If `data_format` is neither 'NHWC' nor 'NCHW'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.BatchNorm2d(num_features=3)
        >>> x = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
        >>> output = net(x)
        >>> print(output)
        [[[[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]
          [[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]
          [[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]]]
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 data_format='NCHW'):
        """Initialize BatchNorm2d."""
        super(BatchNorm2d, self).__init__(num_features,
                                          eps,
                                          momentum,
                                          affine,
                                          gamma_init,
                                          beta_init,
                                          moving_mean_init,
                                          moving_var_init,
                                          use_batch_statistics,
                                          input_dims='2d',
                                          data_format=data_format)

    def _check_data_dim(self, x):
        if x.ndim != 4:
            pass


@constexpr
def _check_3d_shape(input_shape, prim_name=None):
    msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
    if len(input_shape) != 5:
        raise ValueError(f"{msg_prefix} input_shape must be 5-dimensional, but got the length of input_shape: "
                         f"{len(input_shape)}.")


class BatchNorm3d(Cell):
    r"""
    Batch Normalization layer over a 5D input.

    Batch Normalization is widely used in convolutional networks. This layer
    applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
    additional channel dimension) to avoid internal covariate shift.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        The implementation of BatchNorm differs between graph mode and PyNative mode, therefore the mode cannot be
        changed after the network has been initialized.
        Note that the formula for updating the running_mean and running_var is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

    Args:
        num_features (int): `C` from an expected input of size (N, C, D, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
            use the specified values of the moving mean and moving variance. If None, the training process will
            use the mean and variance of the current batch data and track the moving statistics, while the
            evaluation process will use the moving mean and variance. Default: None.
        data_format (str): The optional value for data format is 'NCDHW'. Default: 'NCDHW'.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `num_features` is not an int.
        TypeError: If `eps` is not a float.
        ValueError: If `num_features` is less than 1.
        ValueError: If `momentum` is not in range [0, 1].
        ValueError: If `data_format` is not 'NCDHW'.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.BatchNorm3d(num_features=3)
        >>> x = Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
        >>> output = net(x)
        >>> print(output.shape)
        (16, 3, 10, 32, 32)
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 data_format='NCDHW'):
        """Initialize BatchNorm3d."""
        super(BatchNorm3d, self).__init__()
        self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
        self.reshape = P.Reshape()
        self.bn2d = BatchNorm2d(num_features=num_features,
                                eps=eps,
                                momentum=momentum,
                                affine=affine,
                                gamma_init=gamma_init,
                                beta_init=beta_init,
                                moving_mean_init=moving_mean_init,
                                moving_var_init=moving_var_init,
                                use_batch_statistics=use_batch_statistics,
                                data_format="NCHW")

    def construct(self, input_x):
        x_shape = F.shape(input_x)
        _check_3d_shape(x_shape, self.cls_name)
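        # Fold the depth dimension into the height dimension so the underlying 2D BatchNorm can
        # normalize the 5D input; the per-channel statistics are unchanged by this reshape.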
        input_x = self.reshape(input_x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
        bn2d_out = self.bn2d(input_x)
        bn3d_out = self.reshape(bn2d_out, x_shape)
        return bn3d_out


class GlobalBatchNorm(_BatchNorm):
    r"""
    Global Batch Normalization layer over an N-dimension input.

    Global Batch Normalization is cross device synchronized Batch Normalization. The implementation of
    Batch Normalization only normalizes the data within each device. Global Batch Normalization will normalize
    the input within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network
    Training by Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
    feature using a mini-batch of data and the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        Currently, GlobalBatchNorm only supports 2D and 4D inputs.

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        device_num_each_group (int): The number of devices in each group. Default: 2.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
            use the specified values of the moving mean and moving variance. If None, the training process will
            use the mean and variance of the current batch data and track the moving statistics, while the eval
            process will use the moving mean and variance. Default: None.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `num_features` or `device_num_each_group` is not an int.
        TypeError: If `eps` is not a float.
        ValueError: If `num_features` is less than 1.
        ValueError: If `momentum` is not in range [0, 1].
        ValueError: If `device_num_each_group` is less than 2.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import context
        >>> from mindspore.context import ParallelMode
        >>> from mindspore import nn
        >>> from mindspore import Tensor
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> context.reset_auto_parallel_context()
        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=2)
        >>> x = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
        >>> output = global_bn_op(x)
        >>> print(output)
        [[[[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]
          [[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]
          [[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]]]
    """

    @deprecated("1.2", "SyncBatchNorm", True)
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 device_num_each_group=2):
        """Initialize GlobalBatchNorm."""
        super(GlobalBatchNorm, self).__init__(num_features,
                                              eps,
                                              momentum,
                                              affine,
                                              gamma_init,
                                              beta_init,
                                              moving_mean_init,
                                              moving_var_init,
                                              use_batch_statistics,
                                              device_num_each_group,
                                              input_dims='both')
        self.group_device_num = validator.check_positive_int(device_num_each_group, "device_num_each_group",
                                                             self.cls_name)
        if self.group_device_num <= 1:
            raise ValueError(f"For '{self.cls_name}', the 'device_num_each_group' must be greater than 1, "
                             f"but got {self.group_device_num}.")

    def _check_data_dim(self, x):
        if x.ndim == 0:
            pass


class SyncBatchNorm(_BatchNorm):
    r"""
    Sync Batch Normalization layer over an N-dimension input.

    Sync Batch Normalization is cross device synchronized Batch Normalization. The implementation of Batch
    Normalization only normalizes the data within each device. Sync Batch Normalization will normalize the input
    within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
    feature using a mini-batch of data and the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Note:
        Currently, SyncBatchNorm only supports 2D and 4D inputs.

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.9.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false,
            use the specified values of the moving mean and moving variance. If None, the training process will
            use the mean and variance of the current batch data and track the moving statistics, while the eval
            process will use the moving mean and variance. Default: None.
        process_groups (list): A list dividing the devices into different sync groups, containing N sublists.
            Each sublist contains int numbers identifying the rank ids which need to be synchronized in the same
            group. All int values must be in [0, rank_size) and different from each other. Default: None, indicating
            synchronization across all devices.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `num_features` is not an int.
        TypeError: If `eps` is not a float.
        TypeError: If `process_groups` is not a list.
        ValueError: If `num_features` is less than 1.
        ValueError: If `momentum` is not in range [0, 1].
        ValueError: If rank_id in `process_groups` is not in range [0, rank_size).

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> # This example should be run with multiple processes.
        >>> # Please refer to the tutorial > Distributed Training on mindspore.cn.
        >>> import numpy as np
        >>> from mindspore.communication import init
        >>> from mindspore import context
        >>> from mindspore.context import ParallelMode
        >>> from mindspore import Tensor
        >>> from mindspore import nn
        >>> from mindspore import dtype as mstype
        >>>
        >>> context.set_context(mode=context.GRAPH_MODE)
        >>> init()
        >>> context.reset_auto_parallel_context()
        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
        >>> sync_bn_op = nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])
        >>> x = Tensor(np.ones([1, 3, 2, 2]), mstype.float32)
        >>> output = sync_bn_op(x)
        >>> print(output)
        [[[[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]
          [[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]
          [[ 0.999995 0.999995 ]
           [ 0.999995 0.999995 ]]]]
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.9,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros',
                 moving_mean_init='zeros',
                 moving_var_init='ones',
                 use_batch_statistics=None,
                 process_groups=None):
        """Initialize SyncBatchNorm."""
        super(SyncBatchNorm, self).__init__(num_features,
                                            eps,
                                            momentum,
                                            affine,
                                            gamma_init,
                                            beta_init,
                                            moving_mean_init,
                                            moving_var_init,
                                            use_batch_statistics,
                                            process_groups=process_groups,
                                            input_dims='both')

    def _check_data_dim(self, x):
        if x.ndim == 0:
            pass


class LayerNorm(Cell):
    r"""
    Applies Layer Normalization over a mini-batch of inputs.

    Layer Normalization is widely used in recurrent neural networks. It applies
    normalization on a mini-batch of inputs for each single training case as described
    in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike Batch
    Normalization, Layer Normalization performs exactly the same computation at training and
    testing time. It can be described using the following formula. It is applied across all channels
    and pixels of a single sample, rather than across the batch.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axes
            `begin_norm_axis ... R - 1`.
        begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
            `begin_norm_axis: rank(inputs)`, the value should be in [-1, rank(input)). Default: -1.
        begin_params_axis (int): The first parameter (beta, gamma) dimension: scale and centering parameters
            will have dimensions `begin_params_axis: rank(inputs)` and will be broadcast with
            the normalized inputs accordingly, the value should be in [-1, rank(input)). Default: -1.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'.
        epsilon (float): A value added to the denominator for numerical stability. Default: 1e-7.

    Inputs:
        - **x** (Tensor) - The shape of 'x' is :math:`(x_1, x_2, ..., x_R)`,
          and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`.

    Raises:
        TypeError: If `normalized_shape` is neither a list nor tuple.
        TypeError: If `begin_norm_axis` or `begin_params_axis` is not an int.
        TypeError: If `epsilon` is not a float.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> x = Tensor(np.ones([20, 5, 10, 10]), mindspore.float32)
        >>> shape1 = x.shape[1:]
        >>> m = nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
        >>> output = m(x).shape
        >>> print(output)
        (20, 5, 10, 10)
    """

    def __init__(self,
                 normalized_shape,
                 begin_norm_axis=-1,
                 begin_params_axis=-1,
                 gamma_init='ones',
                 beta_init='zeros',
                 epsilon=1e-7
                 ):
        """Initialize LayerNorm."""
        super(LayerNorm, self).__init__()
        if not isinstance(normalized_shape, (tuple, list)):
            raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' should be tuple[int] or list[int], "
                            f"but got {normalized_shape} and the type is {type(normalized_shape)}.")
        self.normalized_shape = normalized_shape
        self.begin_norm_axis = begin_norm_axis
        self.begin_params_axis = begin_params_axis
        self.epsilon = epsilon
        self.gamma = Parameter(initializer(
            gamma_init, normalized_shape), name="gamma")
        self.beta = Parameter(initializer(
            beta_init, normalized_shape), name="beta")
        self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
                                      begin_params_axis=self.begin_params_axis,
                                      epsilon=self.epsilon)

    def construct(self, input_x):
        y, _, _ = self.layer_norm(input_x, self.gamma, self.beta)
        return y

    def extend_repr(self):
        """Display instance object as string."""
        return 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
            self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)


class InstanceNorm2d(Cell):
    r"""
    Instance Normalization layer over a 4D input.

    This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
    additional channel dimension) as described in the paper `Instance Normalization: The Missing Ingredient for
    Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
    of data and the learned parameters which can be described in the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size `num_features` if `affine` is True.
    The standard deviation is calculated via the biased estimator.

    This layer uses instance statistics computed from input data in both training and evaluation modes.

    InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each
    channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data.

    Note:
        Note that the formula for updating the running_mean and running_var is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.

    Args:
        num_features (int): `C` from an expected input of size (N, C, H, W).
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        momentum (float): A floating hyperparameter of the momentum for the
            running_mean and running_var computation. Default: 0.1.
        affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.

    Outputs:
        Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`. Same type and
        shape as the `x`.

    Supported Platforms:
        ``GPU``

    Raises:
        TypeError: If `num_features` is not an int.
        TypeError: If `eps` is not a float.
        TypeError: If `momentum` is not a float.
        TypeError: If `affine` is not a bool.
        TypeError: If the types of `gamma_init`/`beta_init` are not the same, or if the initialized element type is
            not float32.
        ValueError: If `num_features` is less than 1.
        ValueError: If `momentum` is not in range [0, 1].
        KeyError: If any of `gamma_init`/`beta_init` is a str and the homonymous class inheriting from `Initializer`
            does not exist.

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> net = nn.InstanceNorm2d(3)
        >>> x = Tensor(np.ones([2, 3, 2, 2]), mindspore.float32)
        >>> output = net(x)
        >>> print(output.shape)
        (2, 3, 2, 2)
    """

    @cell_attr_register
    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 gamma_init='ones',
                 beta_init='zeros'):
        """Initialize InstanceNorm2d."""
        super(InstanceNorm2d, self).__init__()
        validator.check_value_type('num_features', num_features, [int], self.cls_name)
        validator.check_value_type('eps', eps, [float], self.cls_name)
        validator.check_value_type('momentum', momentum, [float], self.cls_name)
        validator.check_value_type('affine', affine, [bool], self.cls_name)
        args_input = {"gamma_init": gamma_init, "beta_init": beta_init}
        self.check_types_valid(args_input, 'InstanceNorm2d')
        if num_features < 1:
            raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")

        if momentum < 0 or momentum > 1:
            raise ValueError(f"For '{self.cls_name}', the 'momentum' should be a number in range [0, 1], "
                             f"but got {momentum}.")
        self.num_features = num_features
        self.eps = eps
        self.input_dims = '2d'
        self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
        self.gamma = Parameter(initializer(
            gamma_init, num_features), name="gamma", requires_grad=affine)
        self.beta = Parameter(initializer(
            beta_init, num_features), name="beta", requires_grad=affine)

        self.shape = P.Shape()
        self.momentum = momentum
        self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)

    def _check_data_dim(self, x):
        raise NotImplementedError

    def construct(self, x):
        _shape_check_bn(self.shape(x), self.input_dims, self.cls_name)
        return self.instance_bn(x,
                                self.gamma,
                                self.beta,
                                self.moving_mean,
                                self.moving_variance)[0]

    def extend_repr(self):
        return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
            self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)

    def check_types_valid(self, args_dict, name):
        for key, _ in args_dict.items():
            val = args_dict[key]
            if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
                raise TypeError(f"For '{self.cls_name}', the type of args_dict['{key}'] should be in "
                                f"[Tensor, numbers.Number, str, Initializer], but got type {type(val).__name__}.")
            if isinstance(val, Tensor) and val.dtype != mstype.float32:
                raise TypeError(f"For '{self.cls_name}', the dtype of args_dict['{key}'] should be float32, "
                                f"but got {val.dtype}.")


class GroupNorm(Cell):
    r"""
    Group Normalization over a mini-batch of inputs.

    Group Normalization is widely used in recurrent neural networks. It applies
    normalization on a mini-batch of inputs for each single training case as described
    in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group Normalization
    divides the channels into groups and computes within each group the mean and variance for normalization,
    and it performs stably over a wide range of batch sizes. It can be described using the following formula.

    .. math::
        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Args:
        num_groups (int): The number of groups to be divided along the channel dimension.
        num_channels (int): The number of input channels.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
        affine (bool): A bool value. When set to True, this layer will have learnable affine parameters. Default: True.
        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'ones'. If gamma_init is a Tensor, the shape must be [num_channels].
        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
            'he_uniform', etc. Default: 'zeros'. If beta_init is a Tensor, the shape must be [num_channels].

    Inputs:
        - **x** (Tensor) - The input feature with shape [N, C, H, W].

    Outputs:
        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`.

    Raises:
        TypeError: If `num_groups` or `num_channels` is not an int.
        TypeError: If `eps` is not a float.
        TypeError: If `affine` is not a bool.
        ValueError: If `num_groups` or `num_channels` is less than 1.
        ValueError: If `num_channels` is not divisible by `num_groups`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor
        >>> group_norm_op = nn.GroupNorm(2, 2)
        >>> x = Tensor(np.ones([1, 2, 4, 4], np.float32))
        >>> output = group_norm_op(x)
        >>> print(output)
        [[[[0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]]
          [[0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]
           [0. 0. 0. 0.]]]]
    """

    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros'):
        """Initialize GroupNorm."""
        super(GroupNorm, self).__init__()
        self.num_groups = validator.check_positive_int(num_groups, "num_groups", self.cls_name)
        self.num_channels = validator.check_positive_int(num_channels, "num_channels", self.cls_name)
        if num_channels % num_groups != 0:
            raise ValueError(f"For '{self.cls_name}', the 'num_channels' must be divisible by 'num_groups', "
                             f"but got 'num_channels': {num_channels}, 'num_groups': {num_groups}.")
        self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
        self.affine = validator.check_bool(affine, arg_name="affine", prim_name=self.cls_name)

        gamma = initializer(gamma_init, num_channels)
        beta = initializer(beta_init, num_channels)
        if self.affine:
            self.gamma = Parameter(gamma, name='gamma')
            self.beta = Parameter(beta, name='beta')
        else:
            self.gamma = gamma
            self.beta = beta
        self.shape = F.shape
        self.reshape = F.reshape
        self.reduce_mean = P.ReduceMean(keep_dims=True)
        self.square = F.square
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.sqrt = P.Sqrt()

    def _cal_output(self, x):
        """calculate groupnorm output"""
        batch, channel, height, width = self.shape(x)
        _channel_check(channel, self.num_channels, self.cls_name)
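        # Reshape to (N, num_groups, -1) so the mean and variance are computed within each group of
        # channels (and their spatial positions) independently for every sample.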
        x = self.reshape(x, (batch, self.num_groups, -1))
        mean = self.reduce_mean(x, 2)
        var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups)
        std = self.sqrt(var + self.eps)
        x = (x - mean) / std
        x = self.reshape(x, (batch, channel, height, width))
        output = x * self.reshape(self.gamma, (-1, 1, 1)) + self.reshape(self.beta, (-1, 1, 1))
        return output

    def construct(self, x):
        _shape_check(self.shape(x), self.cls_name)
        output = self._cal_output(x)
        return output

    def extend_repr(self):
        """Display instance object as string."""
        return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)