# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Normalization preprocessing layer."""
# pylint: disable=g-classes-have-attributes

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_preprocessing_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import variables
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.layers.experimental.preprocessing.Normalization')
class Normalization(base_preprocessing_layer.PreprocessingLayer):
  """Feature-wise normalization of the data.

  This layer coerces its inputs into a distribution centered around 0 with
  standard deviation 1. It accomplishes this by precomputing the mean and
  variance of the data, and calling `(input - mean) / sqrt(var)` at runtime.

  What happens in `adapt`: Compute the mean and variance of the data and
  store them as the layer's weights. `adapt` should be called before `fit`,
  `evaluate`, or `predict`.

  Args:
    axis: Integer or tuple of integers, the axis or axes that should be
      "kept". These axes are not summed over when calculating the
      normalization statistics. By default the last axis, the `features`
      axis, is kept and any `space` or `time` axes are summed. Each element
      in the kept axes is normalized independently. If `axis` is set to
      `None`, the layer will perform scalar normalization (normalizing every
      element of the input with a single mean and variance). The `batch`
      axis, 0, is always summed over (`axis=0` is not allowed).
    mean: The mean value(s) to use during normalization. The passed value(s)
      will be broadcast to the shape of the kept axes above; if the value(s)
      cannot be broadcast, an error will be raised when this layer's
      `build()` method is called.
    variance: The variance value(s) to use during normalization. The passed
      value(s) will be broadcast to the shape of the kept axes above; if the
      value(s) cannot be broadcast, an error will be raised when this
      layer's `build()` method is called.

  Examples:

  Calculate the mean and variance by analyzing the dataset in `adapt`.

  >>> adapt_data = np.array([[1.], [2.], [3.], [4.], [5.]], dtype=np.float32)
  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
  >>> layer = Normalization()
  >>> layer.adapt(adapt_data)
  >>> layer(input_data)
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[-1.4142135 ],
         [-0.70710677],
         [ 0.        ]], dtype=float32)>
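
  With the default `axis=-1`, each feature (column) is normalized with its
  own mean and variance. A small sketch with values chosen so the adapted
  statistics are exact (each column has mean offset by a constant and
  variance 1):

  >>> adapt_data = np.array([[0., 7., 4.],
  ...                        [2., 9., 6.],
  ...                        [0., 7., 4.],
  ...                        [2., 9., 6.]], dtype=np.float32)
  >>> layer = Normalization(axis=-1)
  >>> layer.adapt(adapt_data)
  >>> layer(adapt_data)
  <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
  array([[-1., -1., -1.],
         [ 1.,  1.,  1.],
         [-1., -1., -1.],
         [ 1.,  1.,  1.]], dtype=float32)>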

  Pass the mean and variance directly.

  >>> input_data = np.array([[1.], [2.], [3.]], np.float32)
  >>> layer = Normalization(mean=3., variance=2.)
  >>> layer(input_data)
  <tf.Tensor: shape=(3, 1), dtype=float32, numpy=
  array([[-1.4142135 ],
         [-0.70710677],
         [ 0.        ]], dtype=float32)>
  """

  def __init__(self, axis=-1, mean=None, variance=None, **kwargs):
    super().__init__(streaming=True, **kwargs)

    # Standardize `axis` to a tuple.
    if axis is None:
      axis = ()
    elif isinstance(axis, int):
      axis = (axis,)
    else:
      axis = tuple(axis)
    if 0 in axis:
      raise ValueError('The argument \'axis\' may not be 0.')
    self.axis = axis

    # Set `mean` and `variance` if passed.
    if isinstance(mean, variables.Variable):
      raise ValueError('Normalization does not support passing a Variable '
                       'for the `mean` init arg.')
    if isinstance(variance, variables.Variable):
      raise ValueError('Normalization does not support passing a Variable '
                       'for the `variance` init arg.')
    if (mean is not None) != (variance is not None):
      raise ValueError(
          'When setting values directly, both `mean` and `variance` '
          'must be set. Got mean: {} and variance: {}'.format(mean, variance))
    self.input_mean = mean
    self.input_variance = variance

  def build(self, input_shape):
    super().build(input_shape)

    input_shape = tensor_shape.TensorShape(input_shape).as_list()
    if len(input_shape) == 1:
      input_shape = input_shape + [1]
    ndim = len(input_shape)

    if any(a < 1 - ndim or a >= ndim for a in self.axis):
      raise ValueError('All `axis` values must be in the range '
                       '[1 - ndim, ndim - 1]. Found '
                       'ndim: `{}`, axis: {}'.format(ndim, self.axis))

    # Axes to be kept, with negative values replaced by their positive
    # equivalents. Sorted to avoid transposing axes.
    self._keep_axis = sorted([d if d >= 0 else d + ndim for d in self.axis])
    # Axes to be reduced.
    self._reduce_axis = [d for d in range(ndim) if d not in self._keep_axis]
    # 1 if an axis should be reduced, 0 otherwise.
    self._reduce_axis_mask = [
        0 if d in self._keep_axis else 1 for d in range(ndim)
    ]
    # Broadcast any reduced axes.
    self._broadcast_shape = [
        input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)
    ]
    mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)

    if self.input_mean is None:
      self.adapt_mean = self.add_weight(
          name='mean',
          shape=mean_and_var_shape,
          dtype=self.dtype,
          initializer=init_ops.zeros_initializer,
          trainable=False)
      self.adapt_variance = self.add_weight(
          name='variance',
          shape=mean_and_var_shape,
          dtype=self.dtype,
          initializer=init_ops.ones_initializer,
          trainable=False)
      self.count = self.add_weight(
          name='count',
          shape=(),
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer,
          trainable=False)
      self.finalize_state()
    else:
      # In the no-adapt case, make constant tensors for mean and variance
      # with the proper broadcast shape for use during `call`.
      mean = self.input_mean * np.ones(mean_and_var_shape)
      variance = self.input_variance * np.ones(mean_and_var_shape)
      mean = array_ops.reshape(mean, self._broadcast_shape)
      variance = array_ops.reshape(variance, self._broadcast_shape)
      self.mean = math_ops.cast(mean, self.compute_dtype)
      self.variance = math_ops.cast(variance, self.compute_dtype)
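
  # The streaming update below folds each batch into the running moments
  # using count-weighted averages. A worked check with illustrative numbers:
  # merging a batch [1., 3.] (mean=2, variance=1, count=2) with a batch [5.]
  # (mean=5, variance=0, count=1) gives weights 2/3 and 1/3, so
  #   total_mean = 2. * 2/3 + 5. * 1/3 = 3.
  #   total_variance = ((1. + (2. - 3.)**2) * 2/3 +
  #                     (0. + (5. - 3.)**2) * 1/3) = 8/3
  # which matches np.mean([1., 3., 5.]) and np.var([1., 3., 5.]).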
  def update_state(self, data):
    if self.input_mean is not None:
      raise ValueError(
          'Cannot `adapt` a Normalization layer that is initialized with '
          'static `mean` and `variance`. Got mean: {} and variance: {}.'
          .format(self.input_mean, self.input_variance))

    if not self.built:
      raise RuntimeError('`build` must be called before `update_state`.')

    data = self._standardize_inputs(data)
    data = math_ops.cast(data, self.adapt_mean.dtype)
    batch_mean, batch_variance = nn_impl.moments_v2(
        data, axes=self._reduce_axis)
    batch_shape = array_ops.shape(data, out_type=self.count.dtype)
    batch_reduce_shape = array_ops.gather(batch_shape, self._reduce_axis)
    batch_count = math_ops.reduce_prod(batch_reduce_shape)

    total_count = batch_count + self.count
    batch_weight = (
        math_ops.cast(batch_count, dtype=self.dtype) /
        math_ops.cast(total_count, dtype=self.dtype))
    existing_weight = 1. - batch_weight

    total_mean = self.adapt_mean * existing_weight + batch_mean * batch_weight
    # The variance is computed using the lack-of-fit sum of squares
    # formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    total_variance = ((self.adapt_variance +
                       (self.adapt_mean - total_mean)**2) * existing_weight +
                      (batch_variance +
                       (batch_mean - total_mean)**2) * batch_weight)
    self.adapt_mean.assign(total_mean)
    self.adapt_variance.assign(total_variance)
    self.count.assign(total_count)
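
  # `merge_state` folds the statistics of several adapted layers into this
  # one, weighting each layer's mean and variance by its sample count. A
  # sketch with hypothetical names (all layers must have been adapted on
  # compatibly shaped data):
  #
  #   layer_a.adapt(shard_a)
  #   layer_b.adapt(shard_b)
  #   layer_a.merge_state([layer_b])  # layer_a now reflects both shards.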
  def merge_state(self, layers):
    layers = layers + [self]
    for l in layers:
      if l.input_mean is not None:
        raise ValueError(
            'Cannot merge Normalization layer {} that has been initialized '
            'with static `mean` and `variance`. Got `mean={}` and '
            '`variance={}`.'.format(l.name, l.input_mean, l.input_variance))
      if not l.built:
        raise ValueError(
            'Cannot merge Normalization layer {} because it has no state. '
            'You need to call `adapt` on this layer before merging.'
            .format(l.name))

    layer_counts = [l.count for l in layers]
    layer_means = [l.adapt_mean for l in layers]
    layer_variances = [l.adapt_variance for l in layers]

    total_count = math_ops.reduce_sum(layer_counts)
    layer_weightings = (
        math_ops.cast(layer_counts, self.dtype) /
        math_ops.cast(total_count, self.dtype))
    layer_weightings = array_ops.reshape(
        layer_weightings,
        shape=[len(layers)] + [1] * self.adapt_mean.shape.rank)

    total_mean = math_ops.reduce_sum(layer_means * layer_weightings, axis=0)
    inter_layer_variances = (layer_means - total_mean)**2
    total_variance = math_ops.reduce_sum(
        ((layer_variances + inter_layer_variances) * layer_weightings), axis=0)

    self.adapt_mean.assign(total_mean)
    self.adapt_variance.assign(total_variance)
    self.count.assign(total_count)
    self.finalize_state()

  def reset_state(self):  # pylint: disable=method-hidden
    if self.input_mean is not None or not self.built:
      return

    self.adapt_mean.assign(array_ops.zeros_like(self.adapt_mean))
    self.adapt_variance.assign(array_ops.ones_like(self.adapt_variance))
    self.count.assign(array_ops.zeros_like(self.count))

  def finalize_state(self):
    if self.input_mean is not None or not self.built:
      return

    # In the adapt case, we make constant tensors for mean and variance with
    # the proper broadcast shape and dtype each time `finalize_state` is
    # called.
    self.mean = array_ops.reshape(self.adapt_mean, self._broadcast_shape)
    self.mean = math_ops.cast(self.mean, self.compute_dtype)
    self.variance = array_ops.reshape(self.adapt_variance,
                                      self._broadcast_shape)
    self.variance = math_ops.cast(self.variance, self.compute_dtype)

  def call(self, inputs):
    inputs = self._standardize_inputs(inputs)
    # The base layer automatically casts floating-point inputs, but we
    # explicitly cast here so that integer inputs can also be passed.
    inputs = math_ops.cast(inputs, self.compute_dtype)
    return ((inputs - self.mean) /
            math_ops.maximum(math_ops.sqrt(self.variance), backend.epsilon()))

  def compute_output_shape(self, input_shape):
    return input_shape

  def compute_output_signature(self, input_spec):
    return input_spec

  def get_config(self):
    config = super().get_config()
    config.update({
        'axis': self.axis,
        'mean': self._convert_to_list(self.input_mean),
        'variance': self._convert_to_list(self.input_variance),
    })
    return config

  def _standardize_inputs(self, inputs):
    inputs = ops.convert_to_tensor_v2_with_dispatch(inputs)
    if inputs.shape.rank == 0:
      inputs = array_ops.reshape(inputs, [1, 1])
    elif inputs.shape.rank == 1:
      inputs = array_ops.expand_dims(inputs, 1)
    return inputs

  def _convert_to_list(self, inputs):
    if tensor_util.is_tensor(inputs):
      inputs = inputs.numpy()
    if isinstance(inputs, np.ndarray):
      inputs = inputs.tolist()
      inputs = list(inputs)
    return inputs
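

# A minimal usage sketch (illustrative, not part of the library API; assumes
# TF2 eager execution, the default). Running this module directly adapts the
# layer to a tiny dataset and standardizes that same data.
if __name__ == '__main__':
  data = np.array([[1.], [2.], [3.]], dtype=np.float32)
  layer = Normalization(axis=-1)
  layer.adapt(data)  # Stores mean=2 and variance=2/3 as layer weights.
  print(layer(data))  # Approximately [[-1.2247], [0.], [1.2247]].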