1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Layer Normalization layer.""" 16# pylint: disable=g-classes-have-attributes 17 18from tensorflow.python.keras import constraints 19from tensorflow.python.keras import initializers 20from tensorflow.python.keras import regularizers 21from tensorflow.python.keras.engine.base_layer import Layer 22from tensorflow.python.ops import array_ops 23from tensorflow.python.ops import math_ops 24from tensorflow.python.ops import nn 25 26from tensorflow.python.util.tf_export import keras_export 27 28 29@keras_export('keras.layers.LayerNormalization') 30class LayerNormalization(Layer): 31 """Layer normalization layer (Ba et al., 2016). 32 33 Normalize the activations of the previous layer for each given example in a 34 batch independently, rather than across a batch like Batch Normalization. 35 i.e. applies a transformation that maintains the mean activation within each 36 example close to 0 and the activation standard deviation close to 1. 37 38 Given a tensor `inputs`, moments are calculated and normalization 39 is performed across the axes specified in `axis`. 40 41 Example: 42 43 >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32) 44 >>> print(data) 45 tf.Tensor( 46 [[ 0. 10.] 47 [20. 30.] 48 [40. 50.] 49 [60. 70.] 50 [80. 90.]], shape=(5, 2), dtype=float32) 51 52 >>> layer = tf.keras.layers.LayerNormalization(axis=1) 53 >>> output = layer(data) 54 >>> print(output) 55 tf.Tensor( 56 [[-1. 1.] 57 [-1. 1.] 58 [-1. 1.] 59 [-1. 1.] 60 [-1. 1.]], shape=(5, 2), dtype=float32) 61 62 Notice that with Layer Normalization the normalization happens across the 63 axes *within* each example, rather than across different examples in the 64 batch. 65 66 If `scale` or `center` are enabled, the layer will scale the normalized 67 outputs by broadcasting them with a trainable variable `gamma`, and center 68 the outputs by broadcasting with a trainable variable `beta`. `gamma` will 69 default to a ones tensor and `beta` will default to a zeros tensor, so that 70 centering and scaling are no-ops before training has begun. 71 72 So, with scaling and centering enabled the normalization equations 73 are as follows: 74 75 Let the intermediate activations for a mini-batch to be the `inputs`. 76 77 For each sample `x_i` in `inputs` with `k` features, we compute the mean and 78 variance of the sample: 79 80 ```python 81 mean_i = sum(x_i[j] for j in range(k)) / k 82 var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k 83 ``` 84 85 and then compute a normalized `x_i_normalized`, including a small factor 86 `epsilon` for numerical stability. 87 88 ```python 89 x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon) 90 ``` 91 92 And finally `x_i_normalized ` is linearly transformed by `gamma` and `beta`, 93 which are learned parameters: 94 95 ```python 96 output_i = x_i_normalized * gamma + beta 97 ``` 98 99 `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and 100 this part of the inputs' shape must be fully defined. 101 102 For example: 103 104 >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3]) 105 >>> layer.build([5, 20, 30, 40]) 106 >>> print(layer.beta.shape) 107 (20, 30, 40) 108 >>> print(layer.gamma.shape) 109 (20, 30, 40) 110 111 Note that other implementations of layer normalization may choose to define 112 `gamma` and `beta` over a separate set of axes from the axes being 113 normalized across. For example, Group Normalization 114 ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1 115 corresponds to a Layer Normalization that normalizes across height, width, 116 and channel and has `gamma` and `beta` span only the channel dimension. 117 So, this Layer Normalization implementation will not match a Group 118 Normalization layer with group size set to 1. 119 120 Args: 121 axis: Integer or List/Tuple. The axis or axes to normalize across. Typically 122 this is the features axis/axes. The left-out axes are typically the batch 123 axis/axes. This argument defaults to `-1`, the last dimension in the 124 input. 125 epsilon: Small float added to variance to avoid dividing by zero. Defaults 126 to 1e-3 127 center: If True, add offset of `beta` to normalized tensor. If False, `beta` 128 is ignored. Defaults to True. 129 scale: If True, multiply by `gamma`. If False, `gamma` is not used. Defaults 130 to True. When the next layer is linear (also e.g. `nn.relu`), this can be 131 disabled since the scaling will be done by the next layer. 132 beta_initializer: Initializer for the beta weight. Defaults to zeros. 133 gamma_initializer: Initializer for the gamma weight. Defaults to ones. 134 beta_regularizer: Optional regularizer for the beta weight. None by default. 135 gamma_regularizer: Optional regularizer for the gamma weight. None by 136 default. 137 beta_constraint: Optional constraint for the beta weight. None by default. 138 gamma_constraint: Optional constraint for the gamma weight. None by default. 139 140 Input shape: 141 Arbitrary. Use the keyword argument `input_shape` (tuple of 142 integers, does not include the samples axis) when using this layer as the 143 first layer in a model. 144 145 Output shape: 146 Same shape as input. 147 148 Reference: 149 - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450). 150 """ 151 152 def __init__(self, 153 axis=-1, 154 epsilon=1e-3, 155 center=True, 156 scale=True, 157 beta_initializer='zeros', 158 gamma_initializer='ones', 159 beta_regularizer=None, 160 gamma_regularizer=None, 161 beta_constraint=None, 162 gamma_constraint=None, 163 **kwargs): 164 super(LayerNormalization, self).__init__(**kwargs) 165 if isinstance(axis, (list, tuple)): 166 self.axis = axis[:] 167 elif isinstance(axis, int): 168 self.axis = axis 169 else: 170 raise TypeError('Expected an int or a list/tuple of ints for the ' 171 'argument \'axis\', but received: %r' % axis) 172 173 self.epsilon = epsilon 174 self.center = center 175 self.scale = scale 176 self.beta_initializer = initializers.get(beta_initializer) 177 self.gamma_initializer = initializers.get(gamma_initializer) 178 self.beta_regularizer = regularizers.get(beta_regularizer) 179 self.gamma_regularizer = regularizers.get(gamma_regularizer) 180 self.beta_constraint = constraints.get(beta_constraint) 181 self.gamma_constraint = constraints.get(gamma_constraint) 182 183 self.supports_masking = True 184 185 # Indicates whether a faster fused implementation can be used. This will be 186 # set to True or False in build()" 187 self._fused = None 188 189 def _fused_can_be_used(self, ndims): 190 """Returns false if fused implementation cannot be used. 191 192 Check if the axis is contiguous and can be collapsed into the last axis. 193 The self.axis is assumed to have no duplicates. 194 """ 195 axis = sorted(self.axis) 196 can_use_fused = False 197 198 if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1: 199 can_use_fused = True 200 201 # fused_batch_norm will silently raise epsilon to be at least 1.001e-5, so 202 # we cannot used the fused version if epsilon is below that value. Also, the 203 # variable dtype must be float32, as fused_batch_norm only supports float32 204 # variables. 205 if self.epsilon < 1.001e-5 or self.dtype != 'float32': 206 can_use_fused = False 207 208 return can_use_fused 209 210 def build(self, input_shape): 211 ndims = len(input_shape) 212 if ndims is None: 213 raise ValueError('Input shape %s has undefined rank.' % input_shape) 214 215 # Convert axis to list and resolve negatives 216 if isinstance(self.axis, int): 217 self.axis = [self.axis] 218 elif isinstance(self.axis, tuple): 219 self.axis = list(self.axis) 220 for idx, x in enumerate(self.axis): 221 if x < 0: 222 self.axis[idx] = ndims + x 223 224 # Validate axes 225 for x in self.axis: 226 if x < 0 or x >= ndims: 227 raise ValueError('Invalid axis: %d' % x) 228 if len(self.axis) != len(set(self.axis)): 229 raise ValueError('Duplicate axis: {}'.format(tuple(self.axis))) 230 231 param_shape = [input_shape[dim] for dim in self.axis] 232 if self.scale: 233 self.gamma = self.add_weight( 234 name='gamma', 235 shape=param_shape, 236 initializer=self.gamma_initializer, 237 regularizer=self.gamma_regularizer, 238 constraint=self.gamma_constraint, 239 trainable=True, 240 experimental_autocast=False) 241 else: 242 self.gamma = None 243 244 if self.center: 245 self.beta = self.add_weight( 246 name='beta', 247 shape=param_shape, 248 initializer=self.beta_initializer, 249 regularizer=self.beta_regularizer, 250 constraint=self.beta_constraint, 251 trainable=True, 252 experimental_autocast=False) 253 else: 254 self.beta = None 255 256 self._fused = self._fused_can_be_used(ndims) 257 258 self.built = True 259 260 def call(self, inputs): 261 # Compute the axes along which to reduce the mean / variance 262 input_shape = inputs.shape 263 ndims = len(input_shape) 264 265 # Broadcasting only necessary for norm when the axis is not just 266 # the last dimension 267 broadcast_shape = [1] * ndims 268 for dim in self.axis: 269 broadcast_shape[dim] = input_shape.dims[dim].value 270 271 def _broadcast(v): 272 if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]): 273 return array_ops.reshape(v, broadcast_shape) 274 return v 275 276 if not self._fused: 277 input_dtype = inputs.dtype 278 if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32': 279 # If mixed precision is used, cast inputs to float32 so that this is at 280 # least as numerically stable as the fused version. 281 inputs = math_ops.cast(inputs, 'float32') 282 283 # Calculate the moments on the last axis (layer activations). 284 mean, variance = nn.moments(inputs, self.axis, keep_dims=True) 285 286 scale, offset = _broadcast(self.gamma), _broadcast(self.beta) 287 288 # Compute layer normalization using the batch_normalization function. 289 outputs = nn.batch_normalization( 290 inputs, 291 mean, 292 variance, 293 offset=offset, 294 scale=scale, 295 variance_epsilon=self.epsilon) 296 outputs = math_ops.cast(outputs, input_dtype) 297 else: 298 # Collapse dims before self.axis, and dims in self.axis 299 pre_dim, in_dim = (1, 1) 300 axis = sorted(self.axis) 301 tensor_shape = array_ops.shape(inputs) 302 for dim in range(0, ndims): 303 dim_tensor = tensor_shape[dim] 304 if dim < axis[0]: 305 pre_dim = pre_dim * dim_tensor 306 else: 307 assert dim in axis 308 in_dim = in_dim * dim_tensor 309 310 squeezed_shape = [1, pre_dim, in_dim, 1] 311 # This fused operation requires reshaped inputs to be NCHW. 312 data_format = 'NCHW' 313 314 inputs = array_ops.reshape(inputs, squeezed_shape) 315 316 # self.gamma and self.beta have the wrong shape for fused_batch_norm, so 317 # we cannot pass them as the scale and offset parameters. Therefore, we 318 # create two constant tensors in correct shapes for fused_batch_norm and 319 # later construct a separate calculation on the scale and offset. 320 scale = array_ops.ones([pre_dim], dtype=self.dtype) 321 offset = array_ops.zeros([pre_dim], dtype=self.dtype) 322 323 # Compute layer normalization using the fused_batch_norm function. 324 outputs, _, _ = nn.fused_batch_norm( 325 inputs, 326 scale=scale, 327 offset=offset, 328 epsilon=self.epsilon, 329 data_format=data_format) 330 331 outputs = array_ops.reshape(outputs, tensor_shape) 332 333 scale, offset = _broadcast(self.gamma), _broadcast(self.beta) 334 335 if scale is not None: 336 outputs = outputs * math_ops.cast(scale, outputs.dtype) 337 if offset is not None: 338 outputs = outputs + math_ops.cast(offset, outputs.dtype) 339 340 # If some components of the shape got lost due to adjustments, fix that. 341 outputs.set_shape(input_shape) 342 343 return outputs 344 345 def compute_output_shape(self, input_shape): 346 return input_shape 347 348 def get_config(self): 349 config = { 350 'axis': self.axis, 351 'epsilon': self.epsilon, 352 'center': self.center, 353 'scale': self.scale, 354 'beta_initializer': initializers.serialize(self.beta_initializer), 355 'gamma_initializer': initializers.serialize(self.gamma_initializer), 356 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 357 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), 358 'beta_constraint': constraints.serialize(self.beta_constraint), 359 'gamma_constraint': constraints.serialize(self.gamma_constraint) 360 } 361 base_config = super(LayerNormalization, self).get_config() 362 return dict(list(base_config.items()) + list(config.items())) 363