# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalization preprocessing layer."""
# pylint: disable=g-classes-have-attributes

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.experimental.preprocessing.Normalization')
class Normalization(base_preprocessing_layer.PreprocessingLayer):
36  """Feature-wise normalization of the data.
37
38  This layer will coerce its inputs into a distribution centered around
39  0 with standard deviation 1. It accomplishes this by precomputing the mean and
40  variance of the data, and calling (input-mean)/sqrt(var) at runtime.
41
42  What happens in `adapt`: Compute mean and variance of the data and store them
43    as the layer's weights. `adapt` should be called before `fit`, `evaluate`,
44    or `predict`.
45
46  Args:
47      axis: Integer or tuple of integers, the axis or axes that should be
48        "kept". These axes are not be summed over when calculating the
49        normalization statistics. By default the last axis, the `features` axis
50        is kept and any `space` or `time` axes are summed. Each element in the
51        the axes that are kept is normalized independently. If `axis` is set to
52        'None', the layer will perform scalar normalization (dividing the input
53        by a single scalar value). The `batch` axis, 0, is always summed over
54        (`axis=0` is not allowed).
55      mean: The mean value(s) to use during normalization. The passed value(s)
56        will be broadcast to the shape of the kept axes above; if the value(s)
57        cannot be broadcast, an error will be raised when this layer's build()
58        method is called.
59      variance: The variance value(s) to use during normalization. The passed
60        value(s) will be broadcast to the shape of the kept axes above; if the
61        value(s) cannot be broadcast, an error will be raised when this layer's
62        build() method is called.
63
64  Examples:
65
66  Calculate the mean and variance by analyzing the dataset in `adapt`.
67
68  >>> adapt_data = np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32)
69  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
70  >>> layer = Normalization()
71  >>> layer.adapt(adapt_data)
72  >>> layer(input_data)
73  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
74  array([[-1.4142135 ],
75         [-0.70710677],
76         [ 0.        ]], dtype=float32)>
77
78  Pass the mean and variance directly.
79
80  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
81  >>> layer = Normalization(mean=3., variance=2.)
82  >>> layer(input_data)
83  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
84  array([[-1.4142135 ],
85         [-0.70710677],
86         [ 0.        ]], dtype=float32)>
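
  Pass `axis=None` to perform scalar normalization, applying a single mean
  and variance across every axis of the input (here the same statistics as
  above, so the output is identical).

  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
  >>> layer = Normalization(axis=None, mean=3., variance=2.)
  >>> layer(input_data)
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[-1.4142135 ],
         [-0.70710677],
         [ 0.        ]], dtype=float32)>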
87  """
88
89  def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
90    super().__init__(streaming=True, **kwargs)
91
92    # Standardize `axis` to a tuple.
93    if axis is None:
94      axis = ()
95    elif isinstance(axis, int):
96      axis = (axis,)
97    else:
98      axis = tuple(axis)
99    if 0 in axis:
100      raise ValueError('The argument \'axis\' may not be 0.')
101    self.axis = axis
102
103    # Set `mean` and `variance` if passed.
104    if isinstance(mean, variables.Variable):
105      raise ValueError('Normalization does not support passing a Variable '
106                       'for the `mean` init arg.')
107    if isinstance(variance, variables.Variable):
108      raise ValueError('Normalization does not support passing a Variable '
109                       'for the `variance` init arg.')
110    if (mean is not None) != (variance is not None):
111      raise ValueError(
112          'When setting values directly, both `mean` and `variance` '
113          'must be set. Got mean: {} and variance: {}'.format(mean, variance))
114    self.input_mean = mean
115    self.input_variance = variance
116
117  def build(self, input_shape):
118    super().build(input_shape)
119
120    input_shape = tensor_shape.TensorShape(input_shape).as_list()
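    # Rank-1 inputs are treated as batches of scalars: append a size-1
    # features axis so the layer always sees at least rank 2.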
    if len(input_shape) == 1:
      input_shape = input_shape + [1]
    ndim = len(input_shape)

    if any(a < 1 - ndim or a >= ndim for a in self.axis):
      raise ValueError('All `axis` values must be in the range '
                       '[1 - ndim, ndim - 1]. Found '
                       'ndim: `{}`, axis: {}'.format(ndim, self.axis))

    # Axes to be kept, replacing negative values with positive equivalents.
    # Sorted to avoid transposing axes.
    self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])
    # Axes to be reduced.
    self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]
    # 1 if an axis should be reduced, 0 otherwise.
    self._reduce_axis_mask = [
        0 if d in self._keep_axis else 1 for d in range(ndim)
    ]
    # Broadcast any reduced axes.
    self._broadcast_shape = [
        input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)
    ]
    mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)

    if self.input_mean is None:
      self.adapt_mean = self.add_weight(
          name='mean',
          shape=mean_and_var_shape,
          dtype=self.dtype,
          initializer=init_ops.zeros_initializer,
          trainable=False)
      self.adapt_variance = self.add_weight(
          name='variance',
          shape=mean_and_var_shape,
          dtype=self.dtype,
          initializer=init_ops.ones_initializer,
          trainable=False)
      self.count = self.add_weight(
          name='count',
          shape=(),
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer,
          trainable=False)
      self.finalize_state()
    else:
      # In the no-adapt case, make constant tensors for mean and variance with
      # proper broadcast shape for use during call.
      mean = self.input_mean * np.ones(mean_and_var_shape)
      variance = self.input_variance * np.ones(mean_and_var_shape)
      mean = array_ops.reshape(mean, self._broadcast_shape)
      variance = array_ops.reshape(variance, self._broadcast_shape)
      self.mean = math_ops.cast(mean, self.compute_dtype)
      self.variance = math_ops.cast(variance, self.compute_dtype)

  def update_state(self, data):
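    """Accumulate the running mean and variance statistics from `data`."""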
    if self.input_mean is not None:
      raise ValueError(
          'Cannot `adapt` a Normalization layer that is initialized with '
          'static `mean` and `variance`. You passed mean {} and variance {}.'
          .format(self.input_mean, self.input_variance))

    if not self.built:
      raise RuntimeError('`build` must be called before `update_state`.')

    data = self._standardize_inputs(data)
    data = math_ops.cast(data, self.adapt_mean.dtype)
    batch_mean, batch_variance = nn_impl.moments_v2(
        data, axes=self._reduce_axis)
    batch_shape = array_ops.shape(data, out_type=self.count.dtype)
    batch_reduce_shape = array_ops.gather(batch_shape, self._reduce_axis)
    batch_count = math_ops.reduce_prod(batch_reduce_shape)
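
    # `batch_count` is the number of elements reduced per kept position: the
    # batch dimension times any other reduced axes. The running statistics
    # below weight each batch by how many elements it contributed.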
    total_count = batch_count + self.count
    batch_weight = (
        math_ops.cast(batch_count, dtype=self.dtype) /
        math_ops.cast(total_count, dtype=self.dtype))
    existing_weight = 1. - batch_weight

    total_mean = self.adapt_mean * existing_weight + batch_mean * batch_weight
    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
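    # Concretely, with `existing_weight + batch_weight == 1`, the pooled
    # variance is the weighted sum of each group's variance plus the squared
    # distance of its mean from the pooled mean:
    #   total_var = existing_weight * (var_old + (mean_old - total_mean)**2)
    #             + batch_weight * (var_batch + (mean_batch - total_mean)**2)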
    total_variance = ((self.adapt_variance +
                       (self.adapt_mean - total_mean)**2) * existing_weight +
                      (batch_variance +
                       (batch_mean - total_mean)**2) * batch_weight)
    self.adapt_mean.assign(total_mean)
    self.adapt_variance.assign(total_variance)
    self.count.assign(total_count)

  def merge_state(self, layers):
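    """Merge the adapted statistics of `layers` and this layer."""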
    layers = layers + [self]
    for l in layers:
      if l.input_mean is not None:
        raise ValueError(
            'Cannot merge Normalization layer {} that has been initialized '
            'with static `mean` and `variance`. You passed `mean={}` and '
            '`variance={}`.'.format(l.name, l.input_mean, l.input_variance))
      if not l.built:
        raise ValueError(
            'Cannot merge Normalization layer {} because it has no state. '
            'You need to call `adapt` on this layer before merging.'.format(
                l.name))

    layer_counts = [l.count for l in layers]
    layer_means = [l.adapt_mean for l in layers]
    layer_variances = [l.adapt_variance for l in layers]

    total_count = math_ops.reduce_sum(layer_counts)
    layer_weightings = (
        math_ops.cast(layer_counts, self.dtype) /
        math_ops.cast(total_count, self.dtype))
    layer_weightings = array_ops.reshape(
        layer_weightings,
        shape=[len(layers)] + [1] * self.adapt_mean.shape.rank)
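
    # Pool across layers: the merged mean is the count-weighted average of
    # the layer means, and the merged variance applies the same lack-of-fit
    # formula as `update_state` across all layers at once.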
    total_mean = math_ops.reduce_sum(layer_means * layer_weightings, axis=0)
    inter_layer_variances = (layer_means - total_mean)**2
    total_variance = math_ops.reduce_sum(
        ((layer_variances + inter_layer_variances) * layer_weightings), axis=0)

    self.adapt_mean.assign(total_mean)
    self.adapt_variance.assign(total_variance)
    self.count.assign(total_count)
    self.finalize_state()

  def reset_state(self):  # pylint: disable=method-hidden
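    """Reset the adapted statistics to their initial values."""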
    if self.input_mean is not None or not self.built:
      return

    self.adapt_mean.assign(array_ops.zeros_like(self.adapt_mean))
    self.adapt_variance.assign(array_ops.ones_like(self.adapt_variance))
    self.count.assign(array_ops.zeros_like(self.count))

  def finalize_state(self):
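    """Snapshot the adapted statistics into the tensors used by `call`."""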
    if self.input_mean is not None or not self.built:
      return

    # In the adapt case, we make constant tensors for mean and variance with
    # proper broadcast shape and dtype each time `finalize_state` is called.
    self.mean = array_ops.reshape(self.adapt_mean, self._broadcast_shape)
    self.mean = math_ops.cast(self.mean, self.compute_dtype)
    self.variance = array_ops.reshape(self.adapt_variance,
                                      self._broadcast_shape)
    self.variance = math_ops.cast(self.variance, self.compute_dtype)

  def call(self, inputs):
    inputs = self._standardize_inputs(inputs)
    # The base layer automatically casts floating-point inputs, but we
    # explicitly cast here to also allow integer inputs to be passed.
    inputs = math_ops.cast(inputs, self.compute_dtype)
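    # Clamp the denominator to the backend epsilon so a zero variance does
    # not produce inf or nan outputs.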
    return ((inputs - self.mean) /
            math_ops.maximum(math_ops.sqrt(self.variance), backend.epsilon()))

  def compute_output_shape(self, input_shape):
    return input_shape

  def compute_output_signature(self, input_spec):
    return input_spec

  def get_config(self):
    config = super().get_config()
    config.update({
        'axis': self.axis,
        'mean': self._convert_to_list(self.input_mean),
        'variance': self._convert_to_list(self.input_variance),
    })
    return config

  def _standardize_inputs(self, inputs):
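    """Convert `inputs` to a tensor with a rank of at least 2."""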
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
    if inputs.shape.rank == 0:
      inputs = array_ops.reshape(inputs, [1, 1])
    elif inputs.shape.rank == 1:
      inputs = array_ops.expand_dims(inputs, 1)
    return inputs

  def _convert_to_list(self, inputs):
    if tensor_util.is_tensor(inputs):
      inputs = inputs.numpy()
    if isinstance(inputs, np.ndarray):
      # `tolist` already yields nested Python lists (or a plain scalar for a
      # 0-d array), so no further conversion is needed.
      inputs = inputs.tolist()
    return inputs
303