# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The V2 implementation of Normalization layers."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.control_flow_ops import get_enclosing_xla_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


class BatchNormalizationBase(Layer):
  r"""Layer that normalizes its inputs.

  Batch normalization applies a transformation that maintains the mean output
  close to 0 and the output standard deviation close to 1.

  Importantly, batch normalization works differently during training and
  during inference.

  **During training** (i.e. when using `fit()` or when calling the layer/model
  with the argument `training=True`), the layer normalizes its output using
  the mean and standard deviation of the current batch of inputs. That is to
  say, for each channel being normalized, the layer returns
  `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:

  - `epsilon` is a small constant (configurable as part of the constructor
  arguments).
  - `gamma` is a learned scaling factor (initialized as 1), which
  can be disabled by passing `scale=False` to the constructor.
  - `beta` is a learned offset factor (initialized as 0), which
  can be disabled by passing `center=False` to the constructor.

  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
  calling the layer/model with the argument `training=False`, which is the
  default), the layer normalizes its output using a moving average of the
  mean and standard deviation of the batches it has seen during training. That
  is to say, it returns
  `gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta`.

  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as such:

  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`

  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has statistics similar to those of
  the inference data*.
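
  For example, a minimal sketch of the behavior described above (illustrative
  only; the exact numbers depend on the random batch, and TensorFlow 2.x eager
  execution is assumed):

  ```python
  import numpy as np
  import tensorflow as tf

  layer = tf.keras.layers.BatchNormalization(momentum=0.9)
  # A batch whose statistics are far from the initial moving statistics
  # (moving_mean=0, moving_var=1).
  batch = np.random.normal(loc=5., scale=2., size=(64, 3)).astype('float32')

  # Training mode: outputs are normalized with the *batch* moments, and the
  # moving statistics take one step of the update rule above.
  _ = layer(batch, training=True)
  print(layer.moving_mean.numpy())  # roughly 0.9 * 0 + 0.1 * mean(batch)

  # Inference mode (the default): the *moving* statistics are used instead.
  _ = layer(batch, training=False)
  ```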

  Args:
    axis: Integer or a list of integers, the axis that should be normalized
      (typically the features axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
      next layer is linear (e.g. `nn.relu`), this can be disabled since the
      scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.
    renorm: Whether to use [Batch Renormalization](
      https://arxiv.org/abs/1702.03275). This adds extra variables during
        training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction `(r,
      d)` is used as `corrected_value = normalized_value * r + d`, with `r`
      clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_momentum: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training and
      should be neither too small (which would add noise) nor too large (which
      would give stale estimates). Note that `momentum` is still applied to get
      the means and variances for inference.
    fused: If `True`, use a faster, fused implementation, or raise a ValueError
      if the fused implementation cannot be used. If `None`, use the faster
      implementation if possible. If `False`, do not use the fused
      implementation.
      Note that in TensorFlow 1.x, the meaning of `fused=True` is different: if
        `False`, the layer uses the system-recommended implementation.
    trainable: Boolean, if `True` the variables will be marked as trainable.
    virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
      which means batch normalization is performed across the whole batch. When
      `virtual_batch_size` is not `None`, instead perform "Ghost Batch
      Normalization", which creates virtual sub-batches which are each
      normalized separately (with shared gamma, beta, and moving statistics).
      Must divide the actual batch size during execution.
    adjustment: A function taking the `Tensor` containing the (dynamic) shape of
      the input tensor and returning a pair (scale, bias) to apply to the
      normalized values (before gamma and beta), only during training. For
      example, if `axis=-1`,
        `adjustment = lambda shape: (
          tf.random.uniform(shape[-1:], 0.93, 1.07),
          tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized
            value by up to 7% up or down, then shift the result by up to 0.1
            (with independent scaling and bias for each feature but shared
            across all examples), and finally apply gamma and/or beta. If
            `None`, no adjustment is applied. Cannot be specified if
            virtual_batch_size is specified.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean and
        variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean and
        variance of its moving statistics, learned during training.

  Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape: Same shape as input.

  Reference:
    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
  """

  # By default, the base class uses V2 behavior. The BatchNormalization V1
  # subclass sets this to False to use the V1 behavior.
  _USE_V2_BEHAVIOR = True

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               renorm=False,
               renorm_clipping=None,
               renorm_momentum=0.99,
               fused=None,
               trainable=True,
               virtual_batch_size=None,
               adjustment=None,
               name=None,
               **kwargs):
    super(BatchNormalizationBase, self).__init__(name=name, **kwargs)
    if isinstance(axis, (list, tuple)):
      self.axis = axis[:]
    elif isinstance(axis, int):
      self.axis = axis
    else:
      raise TypeError('Expected an int or a list/tuple of ints for the '
                      'argument \'axis\', but received: %r' % axis)
    self.momentum = momentum
    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.moving_mean_initializer = initializers.get(moving_mean_initializer)
    self.moving_variance_initializer = initializers.get(
        moving_variance_initializer)
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)
    self.renorm = renorm
    self.virtual_batch_size = virtual_batch_size
    self.adjustment = adjustment
    if self._USE_V2_BEHAVIOR:
      if fused:
        self._raise_if_fused_cannot_be_used()
      # We leave fused as None if self._fused_can_be_used()==True, since we
      # still may set it to False in self.build() if the input rank is not 4.
      elif fused is None and not self._fused_can_be_used():
        fused = False
    elif fused is None:
      fused = True
    self.supports_masking = True

    self.fused = fused
    self._bessels_correction_test_only = True
    self.trainable = trainable

    if renorm:
      renorm_clipping = renorm_clipping or {}
      keys = ['rmax', 'rmin', 'dmax']
      if set(renorm_clipping) - set(keys):
        raise ValueError('renorm_clipping %s contains keys not in %s' %
                         (renorm_clipping, keys))
      self.renorm_clipping = renorm_clipping
      self.renorm_momentum = renorm_momentum

  def _raise_if_fused_cannot_be_used(self):
    """Raises a ValueError if the fused implementation cannot be used.

    In addition to the checks done in this function, the input tensor's rank
    must be 4 or 5. The input rank check can only be done once the input shape
    is known.
    """
    # Note the ValueErrors in this function are caught and not reraised in
    # _fused_can_be_used(). No other exception besides ValueError should be
    # raised here.

    # Currently fused batch norm doesn't support renorm. It also only supports a
    # channel dimension on axis 1 or 3 (rank=4) / 1 or 4 (rank=5), when no
    # virtual batch size or adjustment is used.
    if self.renorm:
      raise ValueError('Passing both `fused=True` and `renorm=True` is '
                       'not supported')
    axis = [self.axis] if isinstance(self.axis, int) else self.axis
    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, when the
    # input rank is 4. Similarly, the valid axis is -4, -1, 1, 4 when the rank
    # is 5. The combination of ranks and axes will be checked later.
    if len(axis) > 1 or axis[0] not in (-4, -3, -1, 1, 3, 4):
      raise ValueError('Passing `fused=True` is only supported when axis is 1 '
                       'or 3 for input rank = 4 or 1 or 4 for input rank = 5. '
                       'Got axis %s' % (axis,))
    if self.virtual_batch_size is not None:
      raise ValueError('Passing `fused=True` is not supported when '
                       '`virtual_batch_size` is specified.')
    if self.adjustment is not None:
      raise ValueError('Passing `fused=True` is not supported when '
                       '`adjustment` is specified.')
    # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check.
    if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None):
      raise ValueError(
          'Passing `fused=True` is only supported when the compute '
          'dtype is float16, bfloat16, or float32. Got dtype: %s' %
          (self._compute_dtype,))

  def _fused_can_be_used(self):
    try:
      self._raise_if_fused_cannot_be_used()
      return True
    except ValueError:
      return False

  @property
  def trainable(self):
    return self._trainable

  @trainable.setter
  def trainable(self, value):
    self._trainable = value

  @property
  def _param_dtype(self):
    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      return dtypes.float32
    else:
      return self.dtype or dtypes.float32

  def _support_zero_size_input(self):
    return distribution_strategy_context.has_strategy() and getattr(
        distribution_strategy_context.get_strategy().extended,
        'experimental_enable_get_next_as_optional', False)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank.')
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %s' % (self.axis,))
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % (self.axis,))

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0. '
                         'Received axis=%s' % (self.axis,))
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused in (None, True):
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      if self._USE_V2_BEHAVIOR:
        if self.fused is None:
          self.fused = ndims in (4, 5)
        elif self.fused and ndims not in (4, 5):
          raise ValueError('Batch normalization layers with `fused=True` only '
                           'support 4D or 5D input tensors. '
                           'Received tensor with shape: %s' %
                           (tuple(input_shape),))
      else:
        assert self.fused is not None
        self.fused = (ndims in (4, 5) and self._fused_can_be_used())
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require reshaping
      # the Tensor to 4D with the axis in 1 or 3 (preferred 1) which is
      # particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1] and ndims == 4:
        self._data_format = 'NCHW'
      elif self.axis == [1] and ndims == 5:
        self._data_format = 'NCDHW'
      elif self.axis == [3] and ndims == 4:
        self._data_format = 'NHWC'
      elif self.axis == [4] and ndims == 5:
        self._data_format = 'NDHWC'
      elif ndims == 5:
        # 5D tensors that can be passed in but should not use fused batch norm
        # due to unsupported axis.
        self.fused = False
      else:
        if ndims == 4:
          raise ValueError(
              'Unsupported axis. The use of `fused=True` is only possible with '
              '`axis=1` or `axis=3` for 4D input tensors. Received '
              'axis=%s' % (self.axis,))
        else:
          raise ValueError(
              'Unsupported axis. The use of `fused=True` is only possible with '
              '`axis=1` or `axis=4` for 5D input tensors. Received '
              'axis=%s' % (self.axis,))

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Received input '
                         'with shape %s. Axis value: %s' %
                         (tuple(input_shape), self.axis))
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [
          axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
      ]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1  # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = backend.constant(
            1.0, dtype=self._param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = backend.constant(
            0.0, dtype=self._param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      if self.renorm:
        # In batch renormalization we track the inference moving stddev instead
        # of the moving variance to more closely align with the paper.
        def moving_stddev_initializer(*args, **kwargs):
          return math_ops.sqrt(
              self.moving_variance_initializer(*args, **kwargs))

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_variance):
          self.moving_stddev = self.add_weight(
              name='moving_stddev',
              shape=param_shape,
              dtype=self._param_dtype,
              initializer=moving_stddev_initializer,
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)

        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_stddev via `colocate_vars_with` below.
        def _renorm_variable(name,
                             shape,
                             initializer=init_ops.zeros_initializer()):
          """Create a renorm variable."""
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=self._param_dtype,
              initializer=initializer,
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)
          return var

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape,
                                              self.moving_mean_initializer)
        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_stddev):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape,
                                                moving_stddev_initializer)
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True

  def _assign_moving_average(self, variable, value, momentum, inputs_size):
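    """Updates `variable` towards `value` with the given `momentum`.

    This implements `variable = variable * momentum + value * (1 - momentum)`,
    written as `variable -= (variable - value) * (1 - momentum)` so that it can
    be applied with a single `assign_sub`. When `inputs_size` is provided and
    equals 0 (an empty batch), the update becomes a no-op.
    """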

    def calculate_update_delta():
      decay = ops.convert_to_tensor_v2_with_dispatch(
          1.0 - momentum, name='decay')
      if decay.dtype != variable.dtype.base_dtype:
        decay = math_ops.cast(decay, variable.dtype.base_dtype)
      update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
      if inputs_size is not None:
        update_delta = array_ops.where(inputs_size > 0, update_delta,
                                       backend.zeros_like(update_delta))
      return update_delta

    with backend.name_scope('AssignMovingAvg') as scope:
      if ops.executing_eagerly_outside_functions():
        return variable.assign_sub(calculate_update_delta(), name=scope)
      else:
        with ops._colocate_with(variable):  # pylint: disable=protected-access
          return state_ops.assign_sub(
              variable, calculate_update_delta(), name=scope)

  def _assign_new_value(self, variable, value):
    with backend.name_scope('AssignNewValue') as scope:
      if ops.executing_eagerly_outside_functions():
        return variable.assign(value, name=scope)
      else:
        with ops._colocate_with(variable):  # pylint: disable=protected-access
          return state_ops.assign(variable, value, name=scope)

  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if self._support_zero_size_input():
      # Keras assumes that batch dimension is the first dimension for Batch
      # Normalization.
      input_batch_size = array_ops.shape(inputs)[0]
    else:
      input_batch_size = None

    # TODO(rmlarsen): Support using fused avg updates for non-eager execution
    # after fixing graph pattern matching and enabling fused_batch_norm to
    # take exponential_avg_factor as a tensor input.
    use_fused_avg_updates = (
        ops.executing_eagerly_outside_functions() and
        isinstance(self.momentum,
                   (float, int)) and get_enclosing_xla_context() is None)
    if use_fused_avg_updates:
      exponential_avg_factor = 1.0 - self.momentum
    else:
      exponential_avg_factor = None
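    # With fused average updates, the fused kernel itself blends the batch
    # moments into the returned mean/variance using `exponential_avg_factor`,
    # so the moving statistics below can simply be assigned the returned
    # values instead of applying the moving-average update in Python.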

    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
      r"""Add or remove Bessel's correction."""
      # Removes Bessel's correction if remove == True, adds it otherwise.
      # This is to be consistent with non-fused batch norm. Note that the
      # variance computed by fused batch norm is with Bessel's correction.
      # This is only used in legacy V1 batch norm tests.
      if self._bessels_correction_test_only:
        return variance
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      if remove:
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
      else:
        factor = sample_size / (
            sample_size - math_ops.cast(1.0, variance.dtype))
      return variance * factor

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=_maybe_add_or_remove_bessels_correction(
              self.moving_variance, remove=False),
          epsilon=self.epsilon,
          is_training=True,
          data_format=self._data_format,
          exponential_avg_factor=exponential_avg_factor)

    def _fused_batch_norm_training_empty():
      return inputs, self.moving_mean, self.moving_variance

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    train_op = _fused_batch_norm_training
    if use_fused_avg_updates and input_batch_size is not None:
      # pylint: disable=g-long-lambda
      train_op = lambda: control_flow_util.smart_cond(
          input_batch_size > 0, _fused_batch_norm_training,
          _fused_batch_norm_training_empty)
      # pylint: enable=g-long-lambda

    output, mean, variance = control_flow_util.smart_cond(
        training, train_op, _fused_batch_norm_inference)
    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

    training_value = control_flow_util.constant_value(training)
    if training_value or training_value is None:
      if not use_fused_avg_updates:
        if training_value is None:
          momentum = control_flow_util.smart_cond(training,
                                                  lambda: self.momentum,
                                                  lambda: 1.0)
        else:
          momentum = ops.convert_to_tensor_v2_with_dispatch(self.momentum)

      def mean_update():
        """Update self.moving_mean with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_mean, mean)
        else:
          return self._assign_moving_average(self.moving_mean, mean, momentum,
                                             input_batch_size)

      def variance_update():
        """Update self.moving_variance with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_variance, variance)
        else:
          return self._assign_moving_average(self.moving_variance, variance,
                                             momentum, input_batch_size)

      self.add_update(mean_update)
      self.add_update(variance_update)

    return output

  def _renorm_correction_and_moments(self, mean, variance, training,
                                     inputs_size):
    """Returns the correction and update values for renorm."""
    stddev = math_ops.sqrt(variance + self.epsilon)
    # Compute the average mean and standard deviation, as if they were
    # initialized with this batch's moments.
    renorm_mean = self.renorm_mean
    # Avoid divide by zero early on in training.
    renorm_stddev = math_ops.maximum(self.renorm_stddev,
                                     math_ops.sqrt(self.epsilon))
    # Compute the corrections for batch renorm.
    r = stddev / renorm_stddev
    d = (mean - renorm_mean) / renorm_stddev
    # Ensure the corrections use pre-update moving averages.
    with ops.control_dependencies([r, d]):
      mean = array_ops.identity(mean)
      stddev = array_ops.identity(stddev)
    rmin, rmax, dmax = [
        self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax']
    ]
    if rmin is not None:
      r = math_ops.maximum(r, rmin)
    if rmax is not None:
      r = math_ops.minimum(r, rmax)
    if dmax is not None:
      d = math_ops.maximum(d, -dmax)
      d = math_ops.minimum(d, dmax)
    # When not training, use r=1, d=0.
    r = control_flow_util.smart_cond(training, lambda: r,
                                     lambda: array_ops.ones_like(r))
    d = control_flow_util.smart_cond(training, lambda: d,
                                     lambda: array_ops.zeros_like(d))

    def _update_renorm_variable(var, value, inputs_size):
692      """Updates a moving average and weight, returns the unbiased value."""
      value = array_ops.identity(value)

      def _do_update():
        """Updates the var, returns the updated value."""
        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
                                              inputs_size)
        return new_var

      def _fake_update():
        return array_ops.identity(var)

      return control_flow_util.smart_cond(training, _do_update, _fake_update)

    # TODO(yuefengz): colocate the operations
    update_new_mean = _update_renorm_variable(self.renorm_mean, mean,
                                              inputs_size)
    update_new_stddev = _update_renorm_variable(self.renorm_stddev, stddev,
                                                inputs_size)

    # Update the inference mode moving averages with the batch value.
    with ops.control_dependencies([update_new_mean, update_new_stddev]):
      out_mean = array_ops.identity(mean)
      out_variance = array_ops.identity(variance)

    return (r, d, out_mean, out_variance)

  def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

  def _moments(self, inputs, reduction_axes, keep_dims):
    mean, variance = self._calculate_mean_and_var(inputs, reduction_axes,
                                                  keep_dims)
    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if self._support_zero_size_input():
      input_batch_size = array_ops.shape(inputs)[0]
      mean = array_ops.where(input_batch_size > 0, mean,
                             backend.zeros_like(mean))
      variance = array_ops.where(input_batch_size > 0, variance,
                                 backend.zeros_like(variance))
    return mean, variance

  def _get_training_value(self, training=None):
    if training is None:
      training = backend.learning_phase()
    if self._USE_V2_BEHAVIOR:
      if isinstance(training, int):
        training = bool(training)
      if not self.trainable:
        # When the layer is not trainable, it overrides the value passed from
        # the model.
        training = False
    return training

  def call(self, inputs, training=None):
    training = self._get_training_value(training)

    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = array_ops.shape(inputs)
      original_shape = array_ops.concat(
          [constant_op.constant([-1]), original_shape[1:]], axis=0)
      expanded_shape = array_ops.concat([
          constant_op.constant([self.virtual_batch_size, -1]),
          original_shape[1:]
      ],
                                        axis=0)

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    inputs_dtype = inputs.dtype.base_dtype
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
      # Do all math in float32 if given 16-bit inputs for numeric stability.
      # In particular, it's very easy for variance to overflow in float16 and
      # for safety we also choose to cast bfloat16 to float32.
      inputs = math_ops.cast(inputs, dtypes.float32)

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
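      # Returns the (scale, offset) of the affine transform that is equivalent
      # to applying `x * scale + offset` first and then
      # `x * then_scale + then_offset`.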
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = control_flow_util.constant_value(training)
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      mean, variance = self.moving_mean, self.moving_variance
    else:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = control_flow_util.smart_cond(
            training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale))
        adj_bias = control_flow_util.smart_cond(
            training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = control_flow_util.smart_cond(
          training, lambda: mean,
          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_mean))
      variance = control_flow_util.smart_cond(
          training, lambda: variance,
          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_variance))

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self._support_zero_size_input():
        # Keras assumes that batch dimension is the first dimension for Batch
        # Normalization.
        input_batch_size = array_ops.shape(inputs)[0]
      else:
        input_batch_size = None

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training, input_batch_size)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return self._assign_moving_average(var, value, self.momentum,
                                           input_batch_size)

      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: self.moving_mean
        return control_flow_util.smart_cond(training, true_branch, false_branch)

      def variance_update():
        """Update the moving variance."""

        def true_branch_renorm():
          # We apply epsilon as part of the moving_stddev to mirror the training
          # code path.
          moving_stddev = _do_update(self.moving_stddev,
                                     math_ops.sqrt(new_variance + self.epsilon))
          return self._assign_new_value(
              self.moving_variance,
              # Apply relu in case floating point rounding causes it to go
              # negative.
              backend.relu(moving_stddev * moving_stddev - self.epsilon))

        if self.renorm:
          true_branch = true_branch_renorm
        else:
          true_branch = lambda: _do_update(self.moving_variance, new_variance)

        false_branch = lambda: self.moving_variance
        return control_flow_util.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update)
      self.add_update(variance_update)

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                     _broadcast(variance), offset, scale,
                                     self.epsilon)
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
      outputs = math_ops.cast(outputs, inputs_dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'axis':
            self.axis,
        'momentum':
            self.momentum,
        'epsilon':
            self.epsilon,
        'center':
            self.center,
        'scale':
            self.scale,
        'beta_initializer':
            initializers.serialize(self.beta_initializer),
        'gamma_initializer':
            initializers.serialize(self.gamma_initializer),
        'moving_mean_initializer':
            initializers.serialize(self.moving_mean_initializer),
        'moving_variance_initializer':
            initializers.serialize(self.moving_variance_initializer),
        'beta_regularizer':
            regularizers.serialize(self.beta_regularizer),
        'gamma_regularizer':
            regularizers.serialize(self.gamma_regularizer),
        'beta_constraint':
            constraints.serialize(self.beta_constraint),
        'gamma_constraint':
            constraints.serialize(self.gamma_constraint)
    }
    # Only add TensorFlow-specific parameters if they are set, so as to preserve
    # model compatibility with external Keras.
    if self.renorm:
      config['renorm'] = True
      config['renorm_clipping'] = self.renorm_clipping
      config['renorm_momentum'] = self.renorm_momentum
    if self.virtual_batch_size is not None:
      config['virtual_batch_size'] = self.virtual_batch_size
    # Note: adjustment is not serializable.
    if self.adjustment is not None:
      logging.warning('The `adjustment` function of this `BatchNormalization` '
                      'layer cannot be serialized and has been omitted from '
                      'the layer config. It will not be included when '
                      're-creating the layer from the saved config.')
    base_config = super(BatchNormalizationBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


# pylint: disable=g-classes-have-attributes
@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[])
class SyncBatchNormalization(BatchNormalizationBase):
  r"""Normalize and scale inputs or activations synchronously across replicas.

  Applies batch normalization to activations of the previous layer at each batch
  by synchronizing the global batch statistics across all devices that are
  training the model. For specific details about batch normalization please
  refer to the `tf.keras.layers.BatchNormalization` layer docs.

  If this layer is used with a `tf.distribute` strategy to train models
  across devices/workers, there will be an allreduce call to aggregate batch
  statistics across all replicas at every training step. Without a
  `tf.distribute` strategy, this layer behaves as a regular
  `tf.keras.layers.BatchNormalization` layer.

  Example usage:

  ```python
  strategy = tf.distribute.MirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(16))
    model.add(tf.keras.layers.experimental.SyncBatchNormalization())
  ```

  Args:
    axis: Integer, the axis that should be normalized
      (typically the features axis).
      For instance, after a `Conv2D` layer with
      `data_format="channels_first"`,
      set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor.
      If False, `beta` is ignored.
    scale: If True, multiply by `gamma`.
      If False, `gamma` is not used.
      When the next layer is linear (e.g. `nn.relu`),
      this can be disabled since the scaling
      will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the
        mean and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the
        mean and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  """

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    if kwargs.pop('fused', None):
      raise ValueError(
          '`fused` argument cannot be True for SyncBatchNormalization.')

    # Currently we only support aggregating over the global batch size.
    super(SyncBatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        fused=False,
        **kwargs)

  def _calculate_mean_and_var(self, x, axes, keep_dims):

    with backend.name_scope('moments'):
      # The dynamic range of fp16 is too limited to support the collection of
      # sufficient statistics. As a workaround we simply perform the operations
      # on 32-bit floats before converting the mean and variance back to fp16
      y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
      replica_ctx = distribution_strategy_context.get_replica_context()
      if replica_ctx:
        local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
        local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
                                                keepdims=True)
        batch_size = math_ops.cast(array_ops.shape_v2(y)[axes[0]],
                                   dtypes.float32)
        # TODO(b/163099951): batch the all-reduces once we sort out the ordering
        # issue for NCCL. We don't have a mechanism to launch NCCL in the same
        # order in each replica nowadays, so we limit NCCL to batch all-reduces.
        y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
        y_squared_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_squared_sum)
        global_batch_size = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                                   batch_size)

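        # `multiplier` below is the total number of elements that were summed
        # per feature across all replicas: the product of the non-batch
        # reduction axes times the global batch size.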
        axes_vals = [(array_ops.shape_v2(y))[axes[i]]
                     for i in range(1, len(axes))]
        multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                   dtypes.float32)
        multiplier = multiplier * global_batch_size

        mean = y_sum / multiplier
        y_squared_mean = y_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        variance = y_squared_mean - math_ops.square(mean)
      else:
        # Compute true mean while keeping the dims for proper broadcasting.
        mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
        # sample variance, not unbiased variance
        # Note: stop_gradient does not change the gradient that gets
        #       backpropagated to the mean from the variance calculation,
        #       because that gradient is zero
        variance = math_ops.reduce_mean(
            math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
            axes,
            keepdims=True,
            name='variance')
      if not keep_dims:
        mean = array_ops.squeeze(mean, axes)
        variance = array_ops.squeeze(variance, axes)
      if x.dtype == dtypes.float16:
        return (math_ops.cast(mean, dtypes.float16),
                math_ops.cast(variance, dtypes.float16))
      else:
        return (mean, variance)


@keras_export('keras.layers.BatchNormalization', v1=[])
class BatchNormalization(BatchNormalizationBase):
  """Layer that normalizes its inputs.

  Batch normalization applies a transformation that maintains the mean output
  close to 0 and the output standard deviation close to 1.

  Importantly, batch normalization works differently during training and
  during inference.

  **During training** (i.e. when using `fit()` or when calling the layer/model
  with the argument `training=True`), the layer normalizes its output using
  the mean and standard deviation of the current batch of inputs. That is to
  say, for each channel being normalized, the layer returns
  `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:

  - `epsilon` is a small constant (configurable as part of the constructor
  arguments).
  - `gamma` is a learned scaling factor (initialized as 1), which
  can be disabled by passing `scale=False` to the constructor.
  - `beta` is a learned offset factor (initialized as 0), which
  can be disabled by passing `center=False` to the constructor.

  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
  calling the layer/model with the argument `training=False`, which is the
  default), the layer normalizes its output using a moving average of the
  mean and standard deviation of the batches it has seen during training. That
  is to say, it returns
  `gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta`.

  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as such:

  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`

  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has statistics similar to those of
  the inference data*.

  Args:
    axis: Integer, the axis that should be normalized (typically the features
      axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
      next layer is linear (e.g. `nn.relu`), this can be disabled since the
      scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean and
        variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean and
        variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape:
    Same shape as input.

  Reference:
    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).

  **About setting `layer.trainable = False` on a `BatchNormalization` layer:**

  The meaning of setting `layer.trainable = False` is to freeze the layer,
  i.e. its internal state will not change during training:
  its trainable weights will not be updated
  during `fit()` or `train_on_batch()`, and its state updates will not be run.

  Usually, this does not necessarily mean that the layer is run in inference
  mode (which is normally controlled by the `training` argument that can
  be passed when calling a layer). "Frozen state" and "inference mode"
  are two separate concepts.

  However, in the case of the `BatchNormalization` layer, **setting
  `trainable = False` on the layer means that the layer will be
  subsequently run in inference mode** (meaning that it will use
  the moving mean and the moving variance to normalize the current batch,
  rather than using the mean and variance of the current batch).

  This behavior has been introduced in TensorFlow 2.0, in order
  to enable `layer.trainable = False` to produce the most commonly
  expected behavior in the convnet fine-tuning use case.

  Note that:
    - Setting `trainable` on a model containing other layers will
      recursively set the `trainable` value of all inner layers.
    - If the value of the `trainable`
      attribute is changed after calling `compile()` on a model,
      the new value doesn't take effect for this model
      until `compile()` is called again.
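
  For instance, a minimal fine-tuning sketch of the pattern described above
  (illustrative only; `MobileNetV2` merely stands in for any pretrained model
  that contains `BatchNormalization` layers):

  ```python
  import tensorflow as tf

  base_model = tf.keras.applications.MobileNetV2(
      weights=None, include_top=False)
  # Freezes the weights *and* switches the BatchNormalization layers inside
  # `base_model` to inference mode.
  base_model.trainable = False

  inputs = tf.keras.Input(shape=(224, 224, 3))
  x = base_model(inputs, training=False)  # keep batch norm statistics frozen
  x = tf.keras.layers.GlobalAveragePooling2D()(x)
  outputs = tf.keras.layers.Dense(1)(x)
  model = tf.keras.Model(inputs, outputs)
  ```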
1247  """
1248  _USE_V2_BEHAVIOR = True
1249
1250  def __init__(self,
1251               axis=-1,
1252               momentum=0.99,
1253               epsilon=1e-3,
1254               center=True,
1255               scale=True,
1256               beta_initializer='zeros',
1257               gamma_initializer='ones',
1258               moving_mean_initializer='zeros',
1259               moving_variance_initializer='ones',
1260               beta_regularizer=None,
1261               gamma_regularizer=None,
1262               beta_constraint=None,
1263               gamma_constraint=None,
1264               **kwargs):
1265    super(BatchNormalization, self).__init__(
1266        axis=axis,
1267        momentum=momentum,
1268        epsilon=epsilon,
1269        center=center,
1270        scale=scale,
1271        beta_initializer=beta_initializer,
1272        gamma_initializer=gamma_initializer,
1273        moving_mean_initializer=moving_mean_initializer,
1274        moving_variance_initializer=moving_variance_initializer,
1275        beta_regularizer=beta_regularizer,
1276        gamma_regularizer=gamma_regularizer,
1277        beta_constraint=beta_constraint,
1278        gamma_constraint=gamma_constraint,
1279        **kwargs)
1280