# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layer Normalization layer."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn

from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.LayerNormalization')
class LayerNormalization(Layer):
  """Layer normalization layer (Ba et al., 2016).

  Normalizes the activations of the previous layer for each given example in
  a batch independently, rather than across a batch like Batch Normalization.
  That is, it applies a transformation that maintains the mean activation
  within each example close to 0 and the activation standard deviation close
  to 1.

  Given a tensor `inputs`, moments are calculated and normalization
  is performed across the axes specified in `axis`.

  Example:

  >>> data = tf.constant(np.arange(10).reshape(5, 2) * 10, dtype=tf.float32)
  >>> print(data)
  tf.Tensor(
  [[ 0. 10.]
   [20. 30.]
   [40. 50.]
   [60. 70.]
   [80. 90.]], shape=(5, 2), dtype=float32)

  >>> layer = tf.keras.layers.LayerNormalization(axis=1)
  >>> output = layer(data)
  >>> print(output)
  tf.Tensor(
  [[-1. 1.]
   [-1. 1.]
   [-1. 1.]
   [-1. 1.]
   [-1. 1.]], shape=(5, 2), dtype=float32)

  Notice that with Layer Normalization the normalization happens across the
  axes *within* each example, rather than across different examples in the
  batch.

  If `scale` or `center` is enabled, the layer will scale the normalized
  outputs by broadcasting them with a trainable variable `gamma`, and center
  the outputs by broadcasting with a trainable variable `beta`. `gamma` will
  default to a ones tensor and `beta` will default to a zeros tensor, so that
  centering and scaling are no-ops before training has begun.

  So, with scaling and centering enabled, the normalization equations
  are as follows:

  Let the intermediate activations for a mini-batch be the `inputs`.

  For each sample `x_i` in `inputs` with `k` features, we compute the mean and
  variance of the sample:

  ```python
  mean_i = sum(x_i[j] for j in range(k)) / k
  var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k
  ```

  and then compute a normalized `x_i_normalized`, including a small factor
  `epsilon` for numerical stability.

  ```python
  x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)
  ```

  Finally, `x_i_normalized` is linearly transformed by `gamma` and `beta`,
  which are learned parameters:

  ```python
  output_i = x_i_normalized * gamma + beta
  ```
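
  As an illustration only (a minimal sketch, assuming NumPy is available as
  `np`), the three steps above reproduce the layer's output for the `data`
  tensor from the first example, since `gamma` defaults to ones and `beta` to
  zeros:

  ```python
  import numpy as np

  x = np.arange(10).reshape(5, 2) * 10.0     # same values as `data` above
  mean = x.mean(axis=1, keepdims=True)       # per-example mean
  var = x.var(axis=1, keepdims=True)         # per-example variance
  x_normalized = (x - mean) / np.sqrt(var + 1e-3)
  output = x_normalized * 1.0 + 0.0          # ~[[-1., 1.], ...]
  ```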

  `gamma` and `beta` will span the axes of `inputs` specified in `axis`, and
  this part of the inputs' shape must be fully defined.

  For example:

  >>> layer = tf.keras.layers.LayerNormalization(axis=[1, 2, 3])
  >>> layer.build([5, 20, 30, 40])
  >>> print(layer.beta.shape)
  (20, 30, 40)
  >>> print(layer.gamma.shape)
  (20, 30, 40)

  Note that other implementations of layer normalization may choose to define
  `gamma` and `beta` over a separate set of axes from the axes being
  normalized across. For example, Group Normalization
  ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with a group size of 1
  corresponds to a Layer Normalization that normalizes across height, width,
  and channel and has `gamma` and `beta` span only the channel dimension.
  So, this Layer Normalization implementation will not match a Group
  Normalization layer with group size set to 1.

  Args:
    axis: Integer or List/Tuple. The axis or axes to normalize across.
      Typically this is the features axis/axes. The left-out axes are
      typically the batch axis/axes. This argument defaults to `-1`, the last
      dimension in the input.
    epsilon: Small float added to the variance to avoid dividing by zero.
      Defaults to 1e-3.
    center: If True, add offset of `beta` to the normalized tensor. If False,
      `beta` is ignored. Defaults to True.
    scale: If True, multiply by `gamma`. If False, `gamma` is not used.
      Defaults to True. When the next layer is linear (also e.g. `nn.relu`),
      this can be disabled since the scaling will be done by the next layer.
    beta_initializer: Initializer for the beta weight. Defaults to zeros.
    gamma_initializer: Initializer for the gamma weight. Defaults to ones.
    beta_regularizer: Optional regularizer for the beta weight. None by
      default.
    gamma_regularizer: Optional regularizer for the gamma weight. None by
      default.
    beta_constraint: Optional constraint for the beta weight. None by default.
    gamma_constraint: Optional constraint for the gamma weight. None by
      default.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape` (tuple of
    integers, does not include the samples axis) when using this layer as the
    first layer in a model.

  Output shape:
    Same shape as input.

  Reference:
    - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
  """

  def __init__(self,
               axis=-1,
               epsilon=1e-3,
               center=True,
               scale=True,
               beta_initializer='zeros',
               gamma_initializer='ones',
               beta_regularizer=None,
               gamma_regularizer=None,
               beta_constraint=None,
               gamma_constraint=None,
               **kwargs):
    super(LayerNormalization, self).__init__(**kwargs)
    if isinstance(axis, (list, tuple)):
      self.axis = axis[:]
    elif isinstance(axis, int):
      self.axis = axis
    else:
      raise TypeError('Expected an int or a list/tuple of ints for the '
                      'argument \'axis\', but received: %r' % axis)

    self.epsilon = epsilon
    self.center = center
    self.scale = scale
    self.beta_initializer = initializers.get(beta_initializer)
    self.gamma_initializer = initializers.get(gamma_initializer)
    self.beta_regularizer = regularizers.get(beta_regularizer)
    self.gamma_regularizer = regularizers.get(gamma_regularizer)
    self.beta_constraint = constraints.get(beta_constraint)
    self.gamma_constraint = constraints.get(gamma_constraint)

    self.supports_masking = True

    # Indicates whether a faster fused implementation can be used. This will be
    # set to True or False in build().
    self._fused = None

  def _fused_can_be_used(self, ndims):
    """Returns False if the fused implementation cannot be used.

    Checks if the axes are contiguous and can be collapsed into the last axis.
    The self.axis is assumed to have no duplicates.
    """
    axis = sorted(self.axis)
    can_use_fused = False

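    # For example, with ndims = 4, axis = [1, 2, 3] is contiguous and ends at
    # the last axis, so the fused path can be used; axis = [1, 2] or
    # axis = [0, 2, 3] cannot be fused.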
    if axis[-1] == ndims - 1 and axis[-1] - axis[0] == len(axis) - 1:
      can_use_fused = True

    # fused_batch_norm will silently raise epsilon to be at least 1.001e-5, so
    # we cannot use the fused version if epsilon is below that value. Also, the
    # variable dtype must be float32, as fused_batch_norm only supports float32
    # variables.
    if self.epsilon < 1.001e-5 or self.dtype != 'float32':
      can_use_fused = False

    return can_use_fused

  def build(self, input_shape):
    ndims = len(input_shape)
    if ndims is None:
      raise ValueError('Input shape %s has undefined rank.' % input_shape)

    # Convert axis to list and resolve negatives
    if isinstance(self.axis, int):
      self.axis = [self.axis]
    elif isinstance(self.axis, tuple):
      self.axis = list(self.axis)
    for idx, x in enumerate(self.axis):
      if x < 0:
        self.axis[idx] = ndims + x

    # Validate axes
    for x in self.axis:
      if x < 0 or x >= ndims:
        raise ValueError('Invalid axis: %d' % x)
    if len(self.axis) != len(set(self.axis)):
      raise ValueError('Duplicate axis: {}'.format(tuple(self.axis)))

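    # For example, input_shape = [5, 20, 30, 40] with self.axis = [1, 2, 3]
    # yields param_shape = [20, 30, 40], matching the gamma/beta shapes shown
    # in the class docstring.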
    param_shape = [input_shape[dim] for dim in self.axis]
    if self.scale:
      self.gamma = self.add_weight(
          name='gamma',
          shape=param_shape,
          initializer=self.gamma_initializer,
          regularizer=self.gamma_regularizer,
          constraint=self.gamma_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.gamma = None

    if self.center:
      self.beta = self.add_weight(
          name='beta',
          shape=param_shape,
          initializer=self.beta_initializer,
          regularizer=self.beta_regularizer,
          constraint=self.beta_constraint,
          trainable=True,
          experimental_autocast=False)
    else:
      self.beta = None

    self._fused = self._fused_can_be_used(ndims)

    self.built = True

  def call(self, inputs):
    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Broadcasting is only necessary for the norm when the axis is not just
    # the last dimension.
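    # For example, with ndims = 3 and axis = [1], broadcast_shape becomes
    # [1, d1, 1] so that gamma and beta can be reshaped to broadcast against
    # the inputs.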
    broadcast_shape = [1] * ndims
    for dim in self.axis:
      broadcast_shape[dim] = input_shape.dims[dim].value

    def _broadcast(v):
      if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]):
        return array_ops.reshape(v, broadcast_shape)
      return v

    if not self._fused:
      input_dtype = inputs.dtype
      if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
        # If mixed precision is used, cast inputs to float32 so that this is at
        # least as numerically stable as the fused version.
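        # For example, under a 'mixed_float16' policy the inputs arrive as
        # float16 while the layer's variables (self.dtype) are float32.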
        inputs = math_ops.cast(inputs, 'float32')

      # Calculate the moments over the normalization axes (layer activations).
      mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      # Compute layer normalization using the batch_normalization function.
      outputs = nn.batch_normalization(
          inputs,
          mean,
          variance,
          offset=offset,
          scale=scale,
          variance_epsilon=self.epsilon)
      outputs = math_ops.cast(outputs, input_dtype)
    else:
      # Collapse dims before self.axis, and dims in self.axis
      pre_dim, in_dim = (1, 1)
      axis = sorted(self.axis)
      tensor_shape = array_ops.shape(inputs)
      for dim in range(0, ndims):
        dim_tensor = tensor_shape[dim]
        if dim < axis[0]:
          pre_dim = pre_dim * dim_tensor
        else:
          assert dim in axis
          in_dim = in_dim * dim_tensor

      squeezed_shape = [1, pre_dim, in_dim, 1]
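      # For example, inputs of shape [N, H, W, C] with axis = [1, 2, 3] give
      # pre_dim = N, in_dim = H * W * C, and squeezed_shape = [1, N, H*W*C, 1].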
      # This fused operation requires reshaped inputs to be NCHW.
      data_format = 'NCHW'

      inputs = array_ops.reshape(inputs, squeezed_shape)

      # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
      # we cannot pass them as the scale and offset parameters. Therefore, we
      # create two constant tensors with the correct shapes for
      # fused_batch_norm and apply the actual scale and offset in a separate
      # calculation afterwards.
      scale = array_ops.ones([pre_dim], dtype=self.dtype)
      offset = array_ops.zeros([pre_dim], dtype=self.dtype)

      # Compute layer normalization using the fused_batch_norm function.
      outputs, _, _ = nn.fused_batch_norm(
          inputs,
          scale=scale,
          offset=offset,
          epsilon=self.epsilon,
          data_format=data_format)

      outputs = array_ops.reshape(outputs, tensor_shape)

      scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

      if scale is not None:
        outputs = outputs * math_ops.cast(scale, outputs.dtype)
      if offset is not None:
        outputs = outputs + math_ops.cast(offset, outputs.dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    config = {
        'axis': self.axis,
        'epsilon': self.epsilon,
        'center': self.center,
        'scale': self.scale,
        'beta_initializer': initializers.serialize(self.beta_initializer),
        'gamma_initializer': initializers.serialize(self.gamma_initializer),
        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
        'beta_constraint': constraints.serialize(self.beta_constraint),
        'gamma_constraint': constraints.serialize(self.gamma_constraint)
    }
    base_config = super(LayerNormalization, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))