# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The V2 implementation of Normalization layers."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_spec import InputSpec
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.control_flow_ops import get_enclosing_xla_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export


class BatchNormalizationBase(Layer):
  r"""Layer that normalizes its inputs.

  Batch normalization applies a transformation that maintains the mean output
  close to 0 and the output standard deviation close to 1.

  Importantly, batch normalization works differently during training and
  during inference.

  **During training** (i.e. when using `fit()` or when calling the layer/model
  with the argument `training=True`), the layer normalizes its output using
  the mean and standard deviation of the current batch of inputs. That is to
  say, for each channel being normalized, the layer returns
  `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:

  - `epsilon` is a small constant (configurable as part of the constructor
    arguments)
  - `gamma` is a learned scaling factor (initialized as 1), which
    can be disabled by passing `scale=False` to the constructor.
  - `beta` is a learned offset factor (initialized as 0), which
    can be disabled by passing `center=False` to the constructor.

  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
  calling the layer/model with the argument `training=False`, which is the
  default), the layer normalizes its output using a moving average of the
  mean and standard deviation of the batches it has seen during training. That
  is to say, it returns
  `gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta`.

  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as such:

  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`

  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has statistics similar to those of
  the inference data*.

  Args:
    axis: Integer or a list of integers, the axis that should be normalized
      (typically the features axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
      next layer is linear (also e.g. `nn.relu`), this can be disabled since the
      scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.
    renorm: Whether to use [Batch Renormalization](
      https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction `(r,
      d)` is used as `corrected_value = normalized_value * r + d`, with `r`
      clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_momentum: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training and
      should be neither too small (which would add noise) nor too large (which
      would give stale estimates). Note that `momentum` is still applied to get
      the means and variances for inference.
    fused: If `True`, use a faster, fused implementation, or raise a ValueError
      if the fused implementation cannot be used. If `None`, use the faster
      implementation if possible. If `False`, do not use the fused
      implementation.
      Note that in TensorFlow 1.x, the meaning of `fused=True` is different: if
      `False`, the layer uses the system-recommended implementation.
    trainable: Boolean, if `True` the variables will be marked as trainable.
    virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`,
      which means batch normalization is performed across the whole batch. When
      `virtual_batch_size` is not `None`, instead perform "Ghost Batch
      Normalization", which creates virtual sub-batches which are each
      normalized separately (with shared gamma, beta, and moving statistics).
      Must divide the actual batch size during execution.
    adjustment: A function taking the `Tensor` containing the (dynamic) shape
      of the input tensor and returning a pair (scale, bias) to apply to the
      normalized values (before gamma and beta), only during training. For
      example, if `axis=-1`,
      `adjustment = lambda shape: (
        tf.random.uniform(shape[-1:], 0.93, 1.07),
        tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized
      value by up to 7% up or down, then shift the result by up to 0.1
      (with independent scaling and bias for each feature but shared
      across all examples), and finally apply gamma and/or beta. If
      `None`, no adjustment is applied. Cannot be specified if
      virtual_batch_size is specified.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean and
        variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean and
        variance of its moving statistics, learned during training.

  Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape: Same shape as input.

  Reference:
    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
  """

  # By default, the base class uses V2 behavior. The BatchNormalization V1
  # subclass sets this to False to use the V1 behavior.
  _USE_V2_BEHAVIOR = True

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               renorm=False,
               renorm_clipping=None,
               renorm_momentum=0.99,
               fused=None,
               trainable=True,
               virtual_batch_size=None,
               adjustment=None,
               name=None,
               **kwargs):
    super(BatchNormalizationBase, self).__init__(name=name, **kwargs)
    if isinstance(axis, (list, tuple)):
      self.axis = axis[:]
    elif isinstance(axis, int):
      self.axis = axis
    else:
      raise TypeError('Expected an int or a list/tuple of ints for the '
                      'argument \'axis\', but received: %r' % axis)
    self.momentum = momentum
    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.moving_mean_initializer = initializers.get(moving_mean_initializer)
    self.moving_variance_initializer = initializers.get(
        moving_variance_initializer)
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)
    self.renorm = renorm
    self.virtual_batch_size = virtual_batch_size
    self.adjustment = adjustment
    if self._USE_V2_BEHAVIOR:
      if fused:
        self._raise_if_fused_cannot_be_used()
      # We leave fused as None if self._fused_can_be_used()==True, since we
      # still may set it to False in self.build() if the input rank is not 4.
      elif fused is None and not self._fused_can_be_used():
        fused = False
    elif fused is None:
      fused = True
    self.supports_masking = True

    self.fused = fused
    self._bessels_correction_test_only = True
    self.trainable = trainable

    if renorm:
      renorm_clipping = renorm_clipping or {}
      keys = ['rmax', 'rmin', 'dmax']
      if set(renorm_clipping) - set(keys):
        raise ValueError('renorm_clipping %s contains keys not in %s' %
                         (renorm_clipping, keys))
      self.renorm_clipping = renorm_clipping
      self.renorm_momentum = renorm_momentum

  def _raise_if_fused_cannot_be_used(self):
    """Raises a ValueError if fused implementation cannot be used.

    In addition to the checks done in this function, the input tensor's rank
    must be 4 or 5. The input rank check can only be done once the input shape
    is known.
    """
    # Note the ValueErrors in this function are caught and not reraised in
    # _fused_can_be_used(). No other exception besides ValueError should be
    # raised here.

    # Currently fused batch norm doesn't support renorm. It also only supports
    # a channel dimension on axis 1 or 3 (rank=4) / 1 or 4 (rank=5), when no
    # virtual batch size or adjustment is used.
    if self.renorm:
      raise ValueError('Passing both `fused=True` and `renorm=True` is '
                       'not supported')
    axis = [self.axis] if isinstance(self.axis, int) else self.axis
    # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, when the
    # input rank is 4. Similarly, the valid axis is -4, -1, 1, 4 when the rank
    # is 5. The combination of ranks and axes will be checked later.
    if len(axis) > 1 or axis[0] not in (-4, -3, -1, 1, 3, 4):
      raise ValueError('Passing `fused=True` is only supported when axis is 1 '
                       'or 3 for input rank = 4 or 1 or 4 for input rank = 5. '
                       'Got axis %s' % (axis,))
    if self.virtual_batch_size is not None:
      raise ValueError('Passing `fused=True` is not supported when '
                       '`virtual_batch_size` is specified.')
    if self.adjustment is not None:
      raise ValueError('Passing `fused=True` is not supported when '
                       '`adjustment` is specified.')
    # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check.
    if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None):
      raise ValueError(
          'Passing `fused=True` is only supported when the compute '
          'dtype is float16, bfloat16, or float32. Got dtype: %s' %
          (self._compute_dtype,))

  def _fused_can_be_used(self):
    try:
      self._raise_if_fused_cannot_be_used()
      return True
    except ValueError:
      return False

  @property
  def trainable(self):
    return self._trainable

  @trainable.setter
  def trainable(self, value):
    self._trainable = value

  @property
  def _param_dtype(self):
    # Raise parameters of fp16 batch norm to fp32
    if self.dtype == dtypes.float16 or self.dtype == dtypes.bfloat16:
      return dtypes.float32
    else:
      return self.dtype or dtypes.float32

  def _support_zero_size_input(self):
    return distribution_strategy_context.has_strategy() and getattr(
        distribution_strategy_context.get_strategy().extended,
        'experimental_enable_get_next_as_optional', False)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
    if not input_shape.ndims:
      raise ValueError('Input has undefined rank.')
    ndims = len(input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]

    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %s' % (self.axis,))
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: %s' % (self.axis,))

    if self.virtual_batch_size is not None:
      if self.virtual_batch_size <= 0:
        raise ValueError('virtual_batch_size must be a positive integer that '
                         'divides the true batch size of the input tensor')
      # If using virtual batches, the first dimension must be the batch
      # dimension and cannot be the batch norm axis
      if 0 in self.axis:
        raise ValueError('When using virtual_batch_size, the batch dimension '
                         'must be 0 and thus axis cannot include 0. '
                         'Received axis=%s' % (self.axis,))
      if self.adjustment is not None:
        raise ValueError('When using virtual_batch_size, adjustment cannot '
                         'be specified')

    if self.fused in (None, True):
      # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
      # output back to its original shape accordingly.
      if self._USE_V2_BEHAVIOR:
        if self.fused is None:
          self.fused = ndims in (4, 5)
        elif self.fused and ndims not in (4, 5):
          raise ValueError('Batch normalization layers with `fused=True` only '
                           'support 4D or 5D input tensors. '
                           'Received tensor with shape: %s' %
                           (tuple(input_shape),))
      else:
        assert self.fused is not None
        self.fused = (ndims in (4, 5) and self._fused_can_be_used())
      # TODO(chrisying): fused batch norm is currently not supported for
      # multi-axis batch norm and by extension virtual batches. In some cases,
      # it might be possible to use fused batch norm but would require
      # reshaping the Tensor to 4D with the axis in 1 or 3 (preferred 1) which
      # is particularly tricky. A compromise might be to just support the most
      # common use case (turning 5D w/ virtual batch to NCHW)

    if self.fused:
      if self.axis == [1] and ndims == 4:
        self._data_format = 'NCHW'
      elif self.axis == [1] and ndims == 5:
        self._data_format = 'NCDHW'
      elif self.axis == [3] and ndims == 4:
        self._data_format = 'NHWC'
      elif self.axis == [4] and ndims == 5:
        self._data_format = 'NDHWC'
      elif ndims == 5:
        # 5D tensors that can be passed in but should not use fused batch norm
        # due to unsupported axis.
        self.fused = False
      else:
        if ndims == 4:
          raise ValueError(
              'Unsupported axis. The use of `fused=True` is only possible '
              'with `axis=1` or `axis=3` for 4D input tensors. Received '
              'axis=%s' % (self.axis,))
        else:
          raise ValueError(
              'Unsupported axis. The use of `fused=True` is only possible '
              'with `axis=1` or `axis=4` for 5D input tensors. Received '
              'axis=%s' % (self.axis,))

    axis_to_dim = {x: input_shape.dims[x].value for x in self.axis}
    for x in axis_to_dim:
      if axis_to_dim[x] is None:
        raise ValueError('Input has undefined `axis` dimension. Received input '
                         'with shape %s. Axis value: %s' %
                         (tuple(input_shape), self.axis))
    self.input_spec = InputSpec(ndim=ndims, axes=axis_to_dim)

    if len(axis_to_dim) == 1 and self.virtual_batch_size is None:
      # Single axis batch norm (most common/default use-case)
      param_shape = (list(axis_to_dim.values())[0],)
    else:
      # Parameter shape is the original shape but with 1 in all non-axis dims
      param_shape = [
          axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
      ]
      if self.virtual_batch_size is not None:
        # When using virtual batches, add an extra dim at index 1
        param_shape.insert(1, 1)
        for idx, x in enumerate(self.axis):
          self.axis[idx] = x + 1  # Account for added dimension

    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None
      if self.fused:
        self._gamma_const = backend.constant(
            1.0, dtype=self._param_dtype, shape=param_shape)

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None
      if self.fused:
        self._beta_const = backend.constant(
            0.0, dtype=self._param_dtype, shape=param_shape)

    try:
      # Disable variable partitioning when creating the moving mean and
      # variance
      if hasattr(self, '_scope') and self._scope:
        partitioner = self._scope.partitioner
        self._scope.set_partitioner(None)
      else:
        partitioner = None
      self.moving_mean = self.add_weight(
          name='moving_mean',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_mean_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      self.moving_variance = self.add_weight(
          name='moving_variance',
          shape=param_shape,
          dtype=self._param_dtype,
          initializer=self.moving_variance_initializer,
          synchronization=tf_variables.VariableSynchronization.ON_READ,
          trainable=False,
          aggregation=tf_variables.VariableAggregation.MEAN,
          experimental_autocast=False)

      if self.renorm:
        # In batch renormalization we track the inference moving stddev instead
        # of the moving variance to more closely align with the paper.
        def moving_stddev_initializer(*args, **kwargs):
          return math_ops.sqrt(
              self.moving_variance_initializer(*args, **kwargs))

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_variance):
          self.moving_stddev = self.add_weight(
              name='moving_stddev',
              shape=param_shape,
              dtype=self._param_dtype,
              initializer=moving_stddev_initializer,
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)

        # Create variables to maintain the moving mean and standard deviation.
        # These are used in training and thus are different from the moving
        # averages above. The renorm variables are colocated with moving_mean
        # and moving_stddev.
        # NOTE: below, the outer `with device` block causes the current device
        # stack to be cleared. The nested ones use a `lambda` to set the
        # desired device and ignore any devices that may be set by the custom
        # getter.
        def _renorm_variable(name,
                             shape,
                             initializer=init_ops.zeros_initializer()):
          """Create a renorm variable."""
          var = self.add_weight(
              name=name,
              shape=shape,
              dtype=self._param_dtype,
              initializer=initializer,
              synchronization=tf_variables.VariableSynchronization.ON_READ,
              trainable=False,
              aggregation=tf_variables.VariableAggregation.MEAN,
              experimental_autocast=False)
          return var

        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_mean):
          self.renorm_mean = _renorm_variable('renorm_mean', param_shape,
                                              self.moving_mean_initializer)
        with distribution_strategy_context.get_strategy(
        ).extended.colocate_vars_with(self.moving_stddev):
          self.renorm_stddev = _renorm_variable('renorm_stddev', param_shape,
                                                moving_stddev_initializer)
    finally:
      if partitioner:
        self._scope.set_partitioner(partitioner)
    self.built = True

  def _assign_moving_average(self, variable, value, momentum, inputs_size):

    def calculate_update_delta():
      decay = ops.convert_to_tensor_v2_with_dispatch(
          1.0 - momentum, name='decay')
      if decay.dtype != variable.dtype.base_dtype:
        decay = math_ops.cast(decay, variable.dtype.base_dtype)
      update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
      if inputs_size is not None:
        update_delta = array_ops.where(inputs_size > 0, update_delta,
                                       backend.zeros_like(update_delta))
      return update_delta

    with backend.name_scope('AssignMovingAvg') as scope:
      if ops.executing_eagerly_outside_functions():
        return variable.assign_sub(calculate_update_delta(), name=scope)
      else:
        with ops._colocate_with(variable):  # pylint: disable=protected-access
          return state_ops.assign_sub(
              variable, calculate_update_delta(), name=scope)

  def _assign_new_value(self, variable, value):
    with backend.name_scope('AssignNewValue') as scope:
      if ops.executing_eagerly_outside_functions():
        return variable.assign(value, name=scope)
      else:
        with ops._colocate_with(variable):  # pylint: disable=protected-access
          return state_ops.assign(variable, value, name=scope)

  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm."""
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
    if self._support_zero_size_input():
      # Keras assumes that batch dimension is the first dimension for Batch
      # Normalization.
      input_batch_size = array_ops.shape(inputs)[0]
    else:
      input_batch_size = None

    # TODO(rmlarsen): Support using fused avg updates for non-eager execution
    # after fixing graph pattern matching and enabling fused_batch_norm to
    # take exponential_avg_factor as a tensor input.
    use_fused_avg_updates = (
        ops.executing_eagerly_outside_functions() and
        isinstance(self.momentum, (float, int)) and
        get_enclosing_xla_context() is None)
    if use_fused_avg_updates:
      exponential_avg_factor = 1.0 - self.momentum
    else:
      exponential_avg_factor = None

    def _maybe_add_or_remove_bessels_correction(variance, remove=True):
      r"""Add or remove Bessel's correction."""
      # Removes Bessel's correction if remove == True, adds it otherwise.
      # This is to be consistent with non-fused batch norm. Note that the
      # variance computed by fused batch norm is with Bessel's correction.
      # This is only used in legacy V1 batch norm tests.
      if self._bessels_correction_test_only:
        return variance
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      if remove:
        factor = (sample_size -
                  math_ops.cast(1.0, variance.dtype)) / sample_size
      else:
        factor = sample_size / (
            sample_size - math_ops.cast(1.0, variance.dtype))
      return variance * factor

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=_maybe_add_or_remove_bessels_correction(
              self.moving_variance, remove=False),
          epsilon=self.epsilon,
          is_training=True,
          data_format=self._data_format,
          exponential_avg_factor=exponential_avg_factor)

    def _fused_batch_norm_training_empty():
      return inputs, self.moving_mean, self.moving_variance

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    train_op = _fused_batch_norm_training
    if use_fused_avg_updates and input_batch_size is not None:
      # pylint: disable=g-long-lambda
      train_op = lambda: control_flow_util.smart_cond(
          input_batch_size > 0, _fused_batch_norm_training,
          _fused_batch_norm_training_empty)
      # pylint: enable=g-long-lambda

    output, mean, variance = control_flow_util.smart_cond(
        training, train_op, _fused_batch_norm_inference)
    variance = _maybe_add_or_remove_bessels_correction(variance, remove=True)

    training_value = control_flow_util.constant_value(training)
    if training_value or training_value is None:
      if not use_fused_avg_updates:
        if training_value is None:
          momentum = control_flow_util.smart_cond(training,
                                                  lambda: self.momentum,
                                                  lambda: 1.0)
        else:
          momentum = ops.convert_to_tensor_v2_with_dispatch(self.momentum)
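
      # When `use_fused_avg_updates` is True, `fused_batch_norm` has already
      # blended the batch statistics into the moving averages it returns (via
      # `exponential_avg_factor`), so the updates below assign those values
      # directly; otherwise the exponential moving average is applied here.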

      def mean_update():
        """Update self.moving_mean with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_mean, mean)
        else:
          return self._assign_moving_average(self.moving_mean, mean, momentum,
                                             input_batch_size)

      def variance_update():
        """Update self.moving_variance with the most recent data point."""
        if use_fused_avg_updates:
          return self._assign_new_value(self.moving_variance, variance)
        else:
          return self._assign_moving_average(self.moving_variance, variance,
                                             momentum, input_batch_size)

      self.add_update(mean_update)
      self.add_update(variance_update)

    return output

  def _renorm_correction_and_moments(self, mean, variance, training,
                                     inputs_size):
    """Returns the correction and update values for renorm."""
    stddev = math_ops.sqrt(variance + self.epsilon)
    # Compute the average mean and standard deviation, as if they were
    # initialized with this batch's moments.
    renorm_mean = self.renorm_mean
    # Avoid divide by zero early on in training.
    renorm_stddev = math_ops.maximum(self.renorm_stddev,
                                     math_ops.sqrt(self.epsilon))
    # Compute the corrections for batch renorm.
    r = stddev / renorm_stddev
    d = (mean - renorm_mean) / renorm_stddev
    # Ensure the corrections use pre-update moving averages.
    with ops.control_dependencies([r, d]):
      mean = array_ops.identity(mean)
      stddev = array_ops.identity(stddev)
    rmin, rmax, dmax = [
        self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax']
    ]
    if rmin is not None:
      r = math_ops.maximum(r, rmin)
    if rmax is not None:
      r = math_ops.minimum(r, rmax)
    if dmax is not None:
      d = math_ops.maximum(d, -dmax)
      d = math_ops.minimum(d, dmax)
    # When not training, use r=1, d=0.
    r = control_flow_util.smart_cond(training, lambda: r,
                                     lambda: array_ops.ones_like(r))
    d = control_flow_util.smart_cond(training, lambda: d,
                                     lambda: array_ops.zeros_like(d))

    def _update_renorm_variable(var, value, inputs_size):
      """Updates a moving average and weight, returns the unbiased value."""
      value = array_ops.identity(value)

      def _do_update():
        """Updates the var, returns the updated value."""
        new_var = self._assign_moving_average(var, value, self.renorm_momentum,
                                              inputs_size)
        return new_var

      def _fake_update():
        return array_ops.identity(var)

      return control_flow_util.smart_cond(training, _do_update, _fake_update)

    # TODO(yuefengz): colocate the operations
    update_new_mean = _update_renorm_variable(self.renorm_mean, mean,
                                              inputs_size)
    update_new_stddev = _update_renorm_variable(self.renorm_stddev, stddev,
                                                inputs_size)

    # Update the inference mode moving averages with the batch value.
    with ops.control_dependencies([update_new_mean, update_new_stddev]):
      out_mean = array_ops.identity(mean)
      out_variance = array_ops.identity(variance)

    return (r, d, out_mean, out_variance)

  def _calculate_mean_and_var(self, inputs, reduction_axes, keep_dims):
    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)

  def _moments(self, inputs, reduction_axes, keep_dims):
    mean, variance = self._calculate_mean_and_var(inputs, reduction_axes,
                                                  keep_dims)
    # TODO(b/129279393): Support zero batch input in non DistributionStrategy
    # code as well.
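    # Under tf.distribute, a replica can receive an empty (zero-sized) batch;
    # in that case the computed moments are replaced with zeros so that the
    # moving-average updates downstream stay well-defined.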
    if self._support_zero_size_input():
      input_batch_size = array_ops.shape(inputs)[0]
      mean = array_ops.where(input_batch_size > 0, mean,
                             backend.zeros_like(mean))
      variance = array_ops.where(input_batch_size > 0, variance,
                                 backend.zeros_like(variance))
    return mean, variance

  def _get_training_value(self, training=None):
    if training is None:
      training = backend.learning_phase()
    if self._USE_V2_BEHAVIOR:
      if isinstance(training, int):
        training = bool(training)
      if not self.trainable:
        # When the layer is not trainable, it overrides the value passed from
        # model.
        training = False
    return training

  def call(self, inputs, training=None):
    training = self._get_training_value(training)

    if self.virtual_batch_size is not None:
      # Virtual batches (aka ghost batches) can be simulated by reshaping the
      # Tensor and reusing the existing batch norm implementation
      original_shape = array_ops.shape(inputs)
      original_shape = array_ops.concat(
          [constant_op.constant([-1]), original_shape[1:]], axis=0)
      expanded_shape = array_ops.concat([
          constant_op.constant([self.virtual_batch_size, -1]),
          original_shape[1:]
      ], axis=0)

      # Will cause errors if virtual_batch_size does not divide the batch size
      inputs = array_ops.reshape(inputs, expanded_shape)

      def undo_virtual_batching(outputs):
        outputs = array_ops.reshape(outputs, original_shape)
        return outputs

    if self.fused:
      outputs = self._fused_batch_norm(inputs, training=training)
      if self.virtual_batch_size is not None:
        # Currently never reaches here since fused_batch_norm does not support
        # virtual batching
        outputs = undo_virtual_batching(outputs)
      return outputs

    inputs_dtype = inputs.dtype.base_dtype
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
      # Do all math in float32 if given 16-bit inputs for numeric stability.
      # In particular, it's very easy for variance to overflow in float16 and
      # for safety we also choose to cast bfloat16 to float32.
      inputs = math_ops.cast(inputs, dtypes.float32)

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
      del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
      if then_scale is not None:
        scale *= then_scale
        offset *= then_scale
      if then_offset is not None:
        offset += then_offset
      return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or None.
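    # A static True/False selects batch vs. moving statistics below, while
    # None (e.g. a symbolic learning-phase tensor) defers the choice to graph
    # construction time via `smart_cond`.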
    training_value = control_flow_util.constant_value(training)
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      mean, variance = self.moving_mean, self.moving_variance
    else:
      if self.adjustment:
        adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
        # Adjust only during training.
        adj_scale = control_flow_util.smart_cond(
            training, lambda: adj_scale, lambda: array_ops.ones_like(adj_scale))
        adj_bias = control_flow_util.smart_cond(
            training, lambda: adj_bias, lambda: array_ops.zeros_like(adj_bias))
        scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      mean = control_flow_util.smart_cond(
          training, lambda: mean,
          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_mean))
      variance = control_flow_util.smart_cond(
          training, lambda: variance,
          lambda: ops.convert_to_tensor_v2_with_dispatch(moving_variance))

      if self.virtual_batch_size is not None:
        # This isn't strictly correct since in ghost batch norm, you are
        # supposed to sequentially update the moving_mean and moving_variance
        # with each sub-batch. However, since the moving statistics are only
        # used during evaluation, it is more efficient to just update in one
        # step and should not make a significant difference in the result.
        new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
        new_variance = math_ops.reduce_mean(variance, axis=1, keepdims=True)
      else:
        new_mean, new_variance = mean, variance

      if self._support_zero_size_input():
        # Keras assumes that batch dimension is the first dimension for Batch
        # Normalization.
        input_batch_size = array_ops.shape(inputs)[0]
      else:
        input_batch_size = None

      if self.renorm:
        r, d, new_mean, new_variance = self._renorm_correction_and_moments(
            new_mean, new_variance, training, input_batch_size)
        # When training, the normalized values (say, x) will be transformed as
        # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
        # = x * (r * gamma) + (d * gamma + beta) with renorm.
        r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
        d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
        scale, offset = _compose_transforms(r, d, scale, offset)

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return self._assign_moving_average(var, value, self.momentum,
                                           input_batch_size)

      def mean_update():
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: self.moving_mean
        return control_flow_util.smart_cond(training, true_branch, false_branch)

      def variance_update():
        """Update the moving variance."""

        def true_branch_renorm():
          # We apply epsilon as part of the moving_stddev to mirror the
          # training code path.
          moving_stddev = _do_update(self.moving_stddev,
                                     math_ops.sqrt(new_variance + self.epsilon))
          return self._assign_new_value(
              self.moving_variance,
              # Apply relu in case floating point rounding causes it to go
              # negative.
              backend.relu(moving_stddev * moving_stddev - self.epsilon))

        if self.renorm:
          true_branch = true_branch_renorm
        else:
          true_branch = lambda: _do_update(self.moving_variance, new_variance)

        false_branch = lambda: self.moving_variance
        return control_flow_util.smart_cond(training, true_branch, false_branch)

      self.add_update(mean_update)
      self.add_update(variance_update)

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                     _broadcast(variance), offset, scale,
                                     self.epsilon)
    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
      outputs = math_ops.cast(outputs, inputs_dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
      outputs = undo_virtual_batching(outputs)
    return outputs

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'axis': self.axis,
        'momentum': self.momentum,
        'epsilon': self.epsilon,
        'center': self.center,
        'scale': self.scale,
        'beta_initializer': initializers.serialize(self.beta_initializer),
        'gamma_initializer': initializers.serialize(self.gamma_initializer),
        'moving_mean_initializer':
            initializers.serialize(self.moving_mean_initializer),
        'moving_variance_initializer':
            initializers.serialize(self.moving_variance_initializer),
        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
        'beta_constraint': constraints.serialize(self.beta_constraint),
        'gamma_constraint': constraints.serialize(self.gamma_constraint)
    }
    # Only add TensorFlow-specific parameters if they are set, so as to
    # preserve model compatibility with external Keras.
    if self.renorm:
      config['renorm'] = True
      config['renorm_clipping'] = self.renorm_clipping
      config['renorm_momentum'] = self.renorm_momentum
    if self.virtual_batch_size is not None:
      config['virtual_batch_size'] = self.virtual_batch_size
    # Note: adjustment is not serializable.
    if self.adjustment is not None:
      logging.warning('The `adjustment` function of this `BatchNormalization` '
                      'layer cannot be serialized and has been omitted from '
                      'the layer config. It will not be included when '
                      're-creating the layer from the saved config.')
    base_config = super(BatchNormalizationBase, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))


# pylint: disable=g-classes-have-attributes
@keras_export('keras.layers.experimental.SyncBatchNormalization', v1=[])
class SyncBatchNormalization(BatchNormalizationBase):
  r"""Normalize and scale inputs or activations synchronously across replicas.

  Applies batch normalization to activations of the previous layer at each
  batch by synchronizing the global batch statistics across all devices that
  are training the model. For specific details about batch normalization
  please refer to the `tf.keras.layers.BatchNormalization` layer docs.

  When this layer is used with a tf.distribute strategy to train models across
  devices/workers, there will be an allreduce call to aggregate batch
  statistics across all replicas at every training step. Without a
  tf.distribute strategy, this layer behaves as a regular
  `tf.keras.layers.BatchNormalization` layer.

  Example usage:

  ```python
  strategy = tf.distribute.MirroredStrategy()

  with strategy.scope():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(16))
    model.add(tf.keras.layers.experimental.SyncBatchNormalization())
  ```

  Args:
    axis: Integer, the axis that should be normalized (typically the features
      axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False,
      `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When
      the next layer is linear (also e.g. `nn.relu`), this can be disabled
      since the scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the
        mean and variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the
        mean and variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.
  """

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    if kwargs.pop('fused', None):
      raise ValueError(
          '`fused` argument cannot be True for SyncBatchNormalization.')

    # Currently we only support aggregating over the global batch size.
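    # `fused=False` is passed below because the fused kernel computes
    # statistics per replica only and cannot participate in the cross-replica
    # reduction performed in `_calculate_mean_and_var`.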
    super(SyncBatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        fused=False,
        **kwargs)

  def _calculate_mean_and_var(self, x, axes, keep_dims):

    with backend.name_scope('moments'):
      # The dynamic range of fp16 is too limited to support the collection of
      # sufficient statistics. As a workaround we simply perform the operations
      # on 32-bit floats before converting the mean and variance back to fp16
      y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
      replica_ctx = distribution_strategy_context.get_replica_context()
      if replica_ctx:
        local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
        local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes,
                                                keepdims=True)
        batch_size = math_ops.cast(array_ops.shape_v2(y)[axes[0]],
                                   dtypes.float32)
        # TODO(b/163099951): batch the all-reduces once we sort out the
        # ordering issue for NCCL. We don't have a mechanism to launch NCCL in
        # the same order in each replica nowadays, so we limit NCCL to batch
        # all-reduces.
        y_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, local_sum)
        y_squared_sum = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                               local_squared_sum)
        global_batch_size = replica_ctx.all_reduce(reduce_util.ReduceOp.SUM,
                                                   batch_size)

        axes_vals = [(array_ops.shape_v2(y))[axes[i]]
                     for i in range(1, len(axes))]
        multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals),
                                   dtypes.float32)
        multiplier = multiplier * global_batch_size

        mean = y_sum / multiplier
        y_squared_mean = y_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        variance = y_squared_mean - math_ops.square(mean)
      else:
        # Compute true mean while keeping the dims for proper broadcasting.
        mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
        # sample variance, not unbiased variance
        # Note: stop_gradient does not change the gradient that gets
        # backpropagated to the mean from the variance calculation,
        # because that gradient is zero
        variance = math_ops.reduce_mean(
            math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
            axes,
            keepdims=True,
            name='variance')
      if not keep_dims:
        mean = array_ops.squeeze(mean, axes)
        variance = array_ops.squeeze(variance, axes)
      if x.dtype == dtypes.float16:
        return (math_ops.cast(mean, dtypes.float16),
                math_ops.cast(variance, dtypes.float16))
      else:
        return (mean, variance)


@keras_export('keras.layers.BatchNormalization', v1=[])
class BatchNormalization(BatchNormalizationBase):
  """Layer that normalizes its inputs.

  Batch normalization applies a transformation that maintains the mean output
  close to 0 and the output standard deviation close to 1.

  Importantly, batch normalization works differently during training and
  during inference.

  **During training** (i.e. when using `fit()` or when calling the layer/model
  with the argument `training=True`), the layer normalizes its output using
  the mean and standard deviation of the current batch of inputs. That is to
  say, for each channel being normalized, the layer returns
  `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:

  - `epsilon` is a small constant (configurable as part of the constructor
    arguments)
  - `gamma` is a learned scaling factor (initialized as 1), which
    can be disabled by passing `scale=False` to the constructor.
  - `beta` is a learned offset factor (initialized as 0), which
    can be disabled by passing `center=False` to the constructor.

  **During inference** (i.e. when using `evaluate()` or `predict()`, or when
  calling the layer/model with the argument `training=False`, which is the
  default), the layer normalizes its output using a moving average of the
  mean and standard deviation of the batches it has seen during training. That
  is to say, it returns
  `gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta`.

  `self.moving_mean` and `self.moving_var` are non-trainable variables that
  are updated each time the layer is called in training mode, as such:

  - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`
  - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`

  As such, the layer will only normalize its inputs during inference
  *after having been trained on data that has statistics similar to those of
  the inference data*.

  Args:
    axis: Integer, the axis that should be normalized (typically the features
      axis). For instance, after a `Conv2D` layer with
      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
    momentum: Momentum for the moving average.
    epsilon: Small float added to variance to avoid dividing by zero.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
      next layer is linear (also e.g. `nn.relu`), this can be disabled since the
      scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight.
    gamma_initializer: Initializer for the gamma weight.
    moving_mean_initializer: Initializer for the moving mean.
    moving_variance_initializer: Initializer for the moving variance.
    beta_regularizer: Optional regularizer for the beta weight.
    gamma_regularizer: Optional regularizer for the gamma weight.
    beta_constraint: Optional constraint for the beta weight.
    gamma_constraint: Optional constraint for the gamma weight.

  Call arguments:
    inputs: Input tensor (of any rank).
    training: Python boolean indicating whether the layer should behave in
      training mode or in inference mode.
      - `training=True`: The layer will normalize its inputs using the mean and
        variance of the current batch of inputs.
      - `training=False`: The layer will normalize its inputs using the mean and
        variance of its moving statistics, learned during training.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape:
    Same shape as input.
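
  Example usage (a minimal sketch; any layer sizes would do):

  ```python
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(16, input_shape=(8,)),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.Activation('relu'),
  ])
  ```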

  Reference:
    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).

  **About setting `layer.trainable = False` on a `BatchNormalization` layer:**

  The meaning of setting `layer.trainable = False` is to freeze the layer,
  i.e. its internal state will not change during training:
  its trainable weights will not be updated
  during `fit()` or `train_on_batch()`, and its state updates will not be run.

  Usually, this does not necessarily mean that the layer is run in inference
  mode (which is normally controlled by the `training` argument that can
  be passed when calling a layer). "Frozen state" and "inference mode"
  are two separate concepts.

  However, in the case of the `BatchNormalization` layer, **setting
  `trainable = False` on the layer means that the layer will be
  subsequently run in inference mode** (meaning that it will use
  the moving mean and the moving variance to normalize the current batch,
  rather than using the mean and variance of the current batch).

  This behavior has been introduced in TensorFlow 2.0, in order
  to enable `layer.trainable = False` to produce the most commonly
  expected behavior in the convnet fine-tuning use case.

  Note that:
  - Setting `trainable` on a model containing other layers will
    recursively set the `trainable` value of all inner layers.
  - If the value of the `trainable`
    attribute is changed after calling `compile()` on a model,
    the new value doesn't take effect for this model
    until `compile()` is called again.
  """
  _USE_V2_BEHAVIOR = True

  def __init__(self,
               axis=-1,
               momentum=0.99,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               moving_mean_initializer='zeros',
               moving_variance_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    super(BatchNormalization, self).__init__(
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        beta_initializer=beta_initializer,
        gamma_initializer=gamma_initializer,
        moving_mean_initializer=moving_mean_initializer,
        moving_variance_initializer=moving_variance_initializer,
        beta_regularizer=beta_regularizer,
        gamma_regularizer=gamma_regularizer,
        beta_constraint=beta_constraint,
        gamma_constraint=gamma_constraint,
        **kwargs)
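

# The sketch below is a minimal usage example; it assumes a standard
# TensorFlow 2.x installation exposing the public `tf.keras` API and only runs
# when this module is executed directly, so importing the library is
# unaffected.
if __name__ == '__main__':
  import numpy as np  # pylint: disable=g-import-not-at-top
  import tensorflow as tf  # pylint: disable=g-import-not-at-top

  layer = tf.keras.layers.BatchNormalization()
  data = np.random.normal(loc=5.0, scale=3.0, size=(32, 4)).astype('float32')

  # With training=True the batch statistics are used, so the output is
  # approximately zero-mean and unit-variance per feature; with
  # training=False (the default) the moving averages are used instead.
  train_out = layer(data, training=True)
  infer_out = layer(data, training=False)
  print('train mode std per feature:', train_out.numpy().std(axis=0))
  print('inference mode std per feature:', infer_out.numpy().std(axis=0))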