1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras' base preprocessing layer.""" 16 17import json 18import os 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.python import keras 24 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.eager import context 27from tensorflow.python.framework import constant_op 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.keras import keras_parameterized 31from tensorflow.python.keras import testing_utils 32from tensorflow.python.keras.engine import base_preprocessing_layer 33from tensorflow.python.ops import init_ops 34from tensorflow.python.ops import sparse_ops 35from tensorflow.python.ops import variables 36from tensorflow.python.ops.ragged import ragged_factory_ops 37from tensorflow.python.platform import test 38from tensorflow.python.util import compat 39 40 41# Define a test-only implementation of CombinerPreprocessingLayer to validate 42# its correctness directly. 
class AddingPreprocessingLayer(
    base_preprocessing_layer.CombinerPreprocessingLayer):
  """Test-only combiner layer that adds a stored scalar total to its inputs.

  The layer keeps one state variable of shape (1,) named `_SUM_NAME`. Its
  combiner reduces the data it is given with `np.sum`, and `call()` returns
  `inputs + sum`.
  """
  _SUM_NAME = "sum"

  def __init__(self, **kwargs):
    super(AddingPreprocessingLayer, self).__init__(
        combiner=self.AddingCombiner(), **kwargs)

  def build(self, input_shape):
    """Creates the float32 state variable of shape (1,) holding the sum."""
    super(AddingPreprocessingLayer, self).build(input_shape)
    self._sum = self._add_state_variable(
        name=self._SUM_NAME,
        shape=(1,),
        dtype=dtypes.float32,
        initializer=init_ops.zeros_initializer)

  def reset_state(self):  # pylint: disable=method-hidden
    """Resets the stored sum to zero."""
    self._sum.assign([0.])

  def set_total(self, sum_value):
    """This is an example of how a subclass would implement a direct setter.

    These methods should generally just create a dict mapping the correct names
    to the relevant passed values, and call self._set_state_variables() with the
    dict of data.

    Args:
      sum_value: The total to set.
    """
    self._set_state_variables({self._SUM_NAME: [sum_value]})

  def call(self, inputs):
    """Returns `inputs` with the stored sum added to it."""
    return inputs + self._sum

  # Define a Combiner for this layer class.
  class AddingCombiner(base_preprocessing_layer.Combiner):
    """Combiner whose accumulator is the scalar sum of all values seen."""

    def compute(self, batch_values, accumulator=None):
      """Compute a step in this computation, returning a new accumulator.

      Args:
        batch_values: The values to sum, or None (treated as a sum of 0).
        accumulator: An optional existing accumulator to fold the new sum into.

      Returns:
        The sum of `batch_values`, merged with `accumulator` if one was given.
      """
      new_accumulator = 0 if batch_values is None else np.sum(batch_values)
      if accumulator is None:
        return new_accumulator
      else:
        return self.merge([accumulator, new_accumulator])

    def merge(self, accumulators):
      """Merge several accumulators to a single accumulator.

      Args:
        accumulators: A non-empty sequence of accumulators.

      Returns:
        The first accumulator if it is the only one, otherwise the total of
        all accumulators.
      """
      # Combine accumulators and return the result.
      result = accumulators[0]
      for accumulator in accumulators[1:]:
        result = np.sum([np.sum(result), np.sum(accumulator)])
      return result

    def extract(self, accumulator):
      """Convert an accumulator into a dict of output values."""
      # We have to add an additional dimension here because the weight shape
      # is (1,) not None.
      return {AddingPreprocessingLayer._SUM_NAME: [accumulator]}

    def restore(self, output):
      """Create an accumulator based on 'output'."""
      # There is no special internal state here, so we just return the relevant
      # internal value. We take the [0] value here because the weight itself
      # is of the shape (1,) and we want the scalar contained inside it.
      return output[AddingPreprocessingLayer._SUM_NAME][0]

    def serialize(self, accumulator):
      """Serialize an accumulator for a remote call.

      Args:
        accumulator: The accumulator to encode; may be a NumPy scalar.

      Returns:
        The JSON encoding of the accumulator, as bytes.
      """
      # compute()/merge() produce NumPy scalars via np.sum. NumPy integer
      # scalars are not JSON serializable on Python 3, so convert to the
      # equivalent native Python value before encoding.
      native_value = np.asarray(accumulator).tolist()
      return compat.as_bytes(json.dumps(native_value))

    def deserialize(self, encoded_accumulator):
      """Deserialize an accumulator received from 'serialize()'."""
      return json.loads(compat.as_text(encoded_accumulator))


@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class PreprocessingLayerTest(keras_parameterized.TestCase):
  """Tests adapt(), state injection and state transfer on the test layer.

  Most tests adapt on (or inject the sum of) [1, 2, 3, 4, 5], i.e. 15, and
  then expect predictions on [1., 2., 3.] to be [[16], [17], [18]].
  """

  def test_adapt_bad_input_fails(self):
    """Test that non-Dataset/Numpy inputs cause a reasonable error."""
    input_dataset = {"foo": 0}

    layer = AddingPreprocessingLayer()
    if context.executing_eagerly():
      with self.assertRaisesRegex(ValueError, "Failed to find data adapter"):
        layer.adapt(input_dataset)
    else:
      with self.assertRaisesRegex(ValueError, "requires a"):
        layer.adapt(input_dataset)

  def test_adapt_infinite_dataset_fails(self):
    """Test that preproc layers fail if an infinite dataset is passed."""
    input_dataset = dataset_ops.Dataset.from_tensor_slices(
        np.array([[1], [2], [3], [4], [5], [0]])).repeat()

    layer = AddingPreprocessingLayer()
    if context.executing_eagerly():
      with self.assertRaisesRegex(ValueError, "infinite dataset"):
        layer.adapt(input_dataset)
    else:
      with self.assertRaisesRegex(ValueError,
                                  ".*infinite number of elements.*"):
        layer.adapt(input_dataset)

  def test_pre_build_injected_update_with_no_build_fails(self):
    """Test external update injection before build() is called fails."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))

    with self.assertRaisesRegex(RuntimeError, ".*called after build.*"):
      layer._set_state_variables(updates)

  def test_setter_update(self):
    """Test the prototyped setter method."""
    input_data = keras.Input(shape=(1,))
    layer = AddingPreprocessingLayer()
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    layer.set_total(15)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_pre_build_adapt_update_numpy(self):
    """Test that preproc layers can adapt() before build() is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    layer.adapt(input_dataset)

    input_data = keras.Input(shape=(1,))
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_post_build_adapt_update_numpy(self):
    """Test that preproc layers can adapt() after build() is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    input_data = keras.Input(shape=(1,))
    layer = AddingPreprocessingLayer()
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    layer.adapt(input_dataset)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_pre_build_injected_update(self):
    """Test external update injection before the layer is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))

    # build() is invoked explicitly so the state variable exists before the
    # update is injected; the layer is only called on an input afterwards.
    layer.build((1,))
    layer._set_state_variables(updates)

    input_data = keras.Input(shape=(1,))
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_post_build_injected_update(self):
    """Test external update injection after build() is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])
    input_data = keras.Input(shape=(1,))
    layer = AddingPreprocessingLayer()
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))
    layer._set_state_variables(updates)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_pre_build_adapt_update_dataset(self):
    """Test that preproc layers can adapt() on a Dataset before build()."""
    input_dataset = dataset_ops.Dataset.from_tensor_slices(
        np.array([[1], [2], [3], [4], [5], [0]]))

    layer = AddingPreprocessingLayer()
    layer.adapt(input_dataset)

    input_data = keras.Input(shape=(1,))
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_post_build_adapt_update_dataset(self):
    """Test that preproc layers can adapt() on a Dataset after build()."""
    input_dataset = dataset_ops.Dataset.from_tensor_slices(
        np.array([[1], [2], [3], [4], [5], [0]]))

    input_data = keras.Input(shape=(1,))
    layer = AddingPreprocessingLayer()
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    layer.adapt(input_dataset)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_further_tuning(self):
    """Test that models can be tuned with multiple calls to 'adapt'."""

    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    layer.adapt(input_dataset)

    input_data = keras.Input(shape=(1,))
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    # With reset_state=False the new data (sum 3) is added to the existing 15.
    layer.adapt(np.array([1, 2]), reset_state=False)
    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))

  def test_further_tuning_post_injection(self):
    """Test that 'adapt' can further tune state that was injected directly."""

    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()

    input_data = keras.Input(shape=(1,))
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()

    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))
    layer._set_state_variables(updates)
    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    # With reset_state=False the new data (sum 3) is added to the injected 15.
    layer.adapt(np.array([1, 2]), reset_state=False)
    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))

  def test_weight_based_state_transfer(self):
    """Test that preproc layers can transfer state via get/set weights."""

    def get_model():
      input_data = keras.Input(shape=(1,))
      layer = AddingPreprocessingLayer()
      output = layer(input_data)
      model = keras.Model(input_data, output)
      model._run_eagerly = testing_utils.should_run_eagerly()
      return (model, layer)

    input_dataset = np.array([1, 2, 3, 4, 5])
    model, layer = get_model()
    layer.adapt(input_dataset)
    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    # Create a new model and verify it has no state carryover.
    weights = model.get_weights()
    model_2, _ = get_model()
    self.assertAllEqual([[1], [2], [3]], model_2.predict([1., 2., 3.]))

    # Transfer state from model to model_2 via get/set weights.
    model_2.set_weights(weights)
    self.assertAllEqual([[16], [17], [18]], model_2.predict([1., 2., 3.]))

  def test_weight_based_state_transfer_with_further_tuning(self):
    """Test that transferred state can be used to further tune a model."""

    def get_model():
      input_data = keras.Input(shape=(1,))
      layer = AddingPreprocessingLayer()
      output = layer(input_data)
      model = keras.Model(input_data, output)
      model._run_eagerly = testing_utils.should_run_eagerly()
      return (model, layer)

    input_dataset = np.array([1, 2, 3, 4, 5])
    model, layer = get_model()
    layer.adapt(input_dataset)
    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    # Transfer state from model to model_2 via get/set weights.
    weights = model.get_weights()
    model_2, layer_2 = get_model()
    model_2.set_weights(weights)

    # Further adapt this layer based on the transferred weights.
    layer_2.adapt(np.array([1, 2]), reset_state=False)
    self.assertAllEqual([[19], [20], [21]], model_2.predict([1., 2., 3.]))

  def test_loading_without_providing_class_fails(self):
    """Test that loading a saved model without the layer class fails."""
    input_data = keras.Input(shape=(1,))
    layer = AddingPreprocessingLayer()
    output = layer(input_data)
    model = keras.Model(input_data, output)

    if not context.executing_eagerly():
      self.evaluate(variables.variables_initializer(model.variables))

    output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
    model.save(output_path, save_format="tf")

    # load_model() is deliberately called without passing the custom layer
    # class, which is expected to raise.
    with self.assertRaisesRegex(RuntimeError, "Unable to restore a layer of"):
      _ = keras.models.load_model(output_path)

  def test_adapt_sets_input_shape_rank(self):
    """Check that `.adapt()` sets the `input_shape`'s rank."""
    # Shape: (3,1,2)
    adapt_dataset = np.array([[[1., 2.]],
                              [[3., 4.]],
                              [[5., 6.]]], dtype=np.float32)

    layer = AddingPreprocessingLayer()
    layer.adapt(adapt_dataset)

    # Shape: (2,2,2) - same rank as the adapt data, different dimensions.
    input_dataset = np.array([[[1., 2.], [3., 4.]],
                              [[3., 4.], [5., 6.]]], dtype=np.float32)
    layer(input_dataset)

    model = keras.Sequential([layer])
    self.assertTrue(model.built)
    self.assertEqual(model.input_shape, (None, None, None))

  def test_adapt_doesnt_overwrite_input_shape(self):
    """Check that `.adapt()` doesn't change the `input_shape`."""
    # Shape: (3, 1, 2)
    adapt_dataset = np.array([[[1., 2.]],
                              [[3., 4.]],
                              [[5., 6.]]], dtype=np.float32)

    layer = AddingPreprocessingLayer(input_shape=[1, 2])
    layer.adapt(adapt_dataset)

    model = keras.Sequential([layer])
    self.assertTrue(model.built)
    self.assertEqual(model.input_shape, (None, 1, 2))


class PreprocessingLayerV1Test(keras_parameterized.TestCase):
  """Tests the layer's behavior inside an explicitly created graph."""

  def test_adapt_fails(self):
    """Test that calling adapt leads to a runtime error."""
    input_dataset = {"foo": 0}

    with ops.Graph().as_default():
      layer = AddingPreprocessingLayer()
      with self.assertRaisesRegex(RuntimeError,
                                  "`adapt` is only supported in tensorflow v2"):
        layer.adapt(input_dataset)


@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class ConvertToListTest(keras_parameterized.TestCase):
  """Tests `base_preprocessing_layer.convert_to_list` on several input kinds."""

  # Note: We need the inputs to be lambdas below to avoid some strangeness with
  # TF1.x graph mode - specifically, if the inputs are created outside the test
  # function body, the graph inside the test body will not contain the tensors
  # that were created in the parameters.
  @parameterized.named_parameters(
      {
          "testcase_name": "ndarray",
          "inputs": lambda: np.array([[1, 2, 3], [4, 5, 6]]),
          "expected": [[1, 2, 3], [4, 5, 6]]
      }, {
          "testcase_name": "list",
          "inputs": lambda: [[1, 2, 3], [4, 5, 6]],
          "expected": [[1, 2, 3], [4, 5, 6]]
      }, {
          "testcase_name": "tensor",
          "inputs": lambda: constant_op.constant([[1, 2, 3], [4, 5, 6]]),
          "expected": [[1, 2, 3], [4, 5, 6]]
      }, {
          "testcase_name":
              "ragged_tensor",
          "inputs":
              lambda: ragged_factory_ops.constant([[1, 2, 3, 4], [4, 5, 6]]),
          "expected": [[1, 2, 3, 4], [4, 5, 6]]
      }, {
          "testcase_name": "sparse_tensor",
          "inputs": lambda: sparse_ops.from_dense([[1, 2, 0, 4], [4, 5, 6, 0]]),
          "expected": [[1, 2, -1, 4], [4, 5, 6, -1]]
      })
  def test_conversion(self, inputs, expected):
    """Test that convert_to_list() yields the expected nested list."""
    values = base_preprocessing_layer.convert_to_list(inputs())
    self.assertAllEqual(expected, values)


if __name__ == "__main__":
  test.main()