• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras' base preprocessing layer."""
16
17import json
18import os
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python import keras
24
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.eager import context
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.keras import keras_parameterized
31from tensorflow.python.keras import testing_utils
32from tensorflow.python.keras.engine import base_preprocessing_layer
33from tensorflow.python.ops import init_ops
34from tensorflow.python.ops import sparse_ops
35from tensorflow.python.ops import variables
36from tensorflow.python.ops.ragged import ragged_factory_ops
37from tensorflow.python.platform import test
38from tensorflow.python.util import compat
39
40
# Define a test-only implementation of CombinerPreprocessingLayer to validate
# its correctness directly.
class AddingPreprocessingLayer(
    base_preprocessing_layer.CombinerPreprocessingLayer):
  """Test layer that adds a single adapted scalar to its inputs.

  The layer's combiner (`AddingCombiner`) sums every element it is given. That
  total is held in a float32 state variable of shape (1,), which `call()` adds
  to its inputs.
  """
  _SUM_NAME = "sum"

  def __init__(self, **kwargs):
    super(AddingPreprocessingLayer, self).__init__(
        combiner=self.AddingCombiner(), **kwargs)

  def build(self, input_shape):
    super(AddingPreprocessingLayer, self).build(input_shape)
    # The total starts at zero until adapt() or a direct setter updates it.
    self._sum = self._add_state_variable(
        name=self._SUM_NAME,
        shape=(1,),
        dtype=dtypes.float32,
        initializer=init_ops.zeros_initializer)

  def reset_state(self):  # pylint: disable=method-hidden
    """Resets the stored total back to zero."""
    self._sum.assign([0.])

  def set_total(self, sum_value):
    """This is an example of how a subclass would implement a direct setter.

    These methods should generally just create a dict mapping the correct names
    to the relevant passed values, and call self._set_state_variables() with the
    dict of data.

    Args:
      sum_value: The total to set.
    """
    self._set_state_variables({self._SUM_NAME: [sum_value]})

  def call(self, inputs):
    """Returns `inputs` plus the stored total (a shape-(1,) variable)."""
    return inputs + self._sum

  # Define a Combiner for this layer class.
  class AddingCombiner(base_preprocessing_layer.Combiner):
    """Combiner whose accumulator is the running sum of all values seen."""

    def compute(self, batch_values, accumulator=None):
      """Compute a step in this computation, returning a new accumulator.

      A `None` batch contributes 0 to the sum.
      """
      new_accumulator = 0 if batch_values is None else np.sum(batch_values)
      if accumulator is None:
        return new_accumulator
      else:
        return self.merge([accumulator, new_accumulator])

    def merge(self, accumulators):
      """Merge several accumulators to a single accumulator."""
      # Combine accumulators and return the result.
      result = accumulators[0]
      for accumulator in accumulators[1:]:
        result = np.sum([np.sum(result), np.sum(accumulator)])
      return result

    def extract(self, accumulator):
      """Convert an accumulator into a dict of output values."""
      # We have to add an additional dimension here because the weight shape
      # is (1,) not None.
      return {AddingPreprocessingLayer._SUM_NAME: [accumulator]}

    def restore(self, output):
      """Create an accumulator based on 'output'."""
      # There is no special internal state here, so we just return the relevant
      # internal value. We take the [0] value here because the weight itself
      # is of the shape (1,) and we want the scalar contained inside it.
      return output[AddingPreprocessingLayer._SUM_NAME][0]

    def serialize(self, accumulator):
      """Serialize an accumulator for a remote call.

      `np.sum` over an integer array returns a NumPy integer scalar, which
      `json.dumps` refuses to encode. Converting through `tolist()` yields a
      native Python number (int or float) that JSON can represent.
      """
      return compat.as_bytes(json.dumps(np.asarray(accumulator).tolist()))

    def deserialize(self, encoded_accumulator):
      """Deserialize an accumulator received from 'serialize()'."""
      return json.loads(compat.as_text(encoded_accumulator))
116
117
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class PreprocessingLayerTest(keras_parameterized.TestCase):
  """Tests adapt(), state injection and state transfer for a combiner layer.

  Most tests give the layer a total of 15 (e.g. by adapting on [1, 2, 3, 4, 5])
  and then expect predictions on [1., 2., 3.] to be [[16], [17], [18]].
  """

  def _wrap_in_model(self, layer):
    """Returns a functional model that feeds a (None, 1) input through `layer`.

    The model's `_run_eagerly` flag is set from
    `testing_utils.should_run_eagerly()` so predictions follow the Keras test
    mode currently being exercised.

    Args:
      layer: The preprocessing layer to wrap.

    Returns:
      A `keras.Model` mapping the new input through `layer`.
    """
    input_data = keras.Input(shape=(1,))
    output = layer(input_data)
    model = keras.Model(input_data, output)
    model._run_eagerly = testing_utils.should_run_eagerly()
    return model

  def test_adapt_bad_input_fails(self):
    """Test that non-Dataset/Numpy inputs cause a reasonable error."""
    input_dataset = {"foo": 0}

    layer = AddingPreprocessingLayer()
    if context.executing_eagerly():
      with self.assertRaisesRegex(ValueError, "Failed to find data adapter"):
        layer.adapt(input_dataset)
    else:
      with self.assertRaisesRegex(ValueError, "requires a"):
        layer.adapt(input_dataset)

  def test_adapt_infinite_dataset_fails(self):
    """Test that preproc layers fail if an infinite dataset is passed."""
    input_dataset = dataset_ops.Dataset.from_tensor_slices(
        np.array([[1], [2], [3], [4], [5], [0]])).repeat()

    layer = AddingPreprocessingLayer()
    if context.executing_eagerly():
      with self.assertRaisesRegex(ValueError, "infinite dataset"):
        layer.adapt(input_dataset)
    else:
      with self.assertRaisesRegex(ValueError,
                                  ".*infinite number of elements.*"):
        layer.adapt(input_dataset)

  def test_pre_build_injected_update_with_no_build_fails(self):
    """Test external update injection before build() is called fails."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))

    with self.assertRaisesRegex(RuntimeError, ".*called after build.*"):
      layer._set_state_variables(updates)

  def test_setter_update(self):
    """Test the prototyped setter method."""
    layer = AddingPreprocessingLayer()
    model = self._wrap_in_model(layer)

    layer.set_total(15)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_pre_build_adapt_update_numpy(self):
    """Test that preproc layers can adapt() before build() is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    layer.adapt(input_dataset)
    model = self._wrap_in_model(layer)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_post_build_adapt_update_numpy(self):
    """Test that preproc layers can adapt() after build() is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    model = self._wrap_in_model(layer)

    layer.adapt(input_dataset)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_pre_build_injected_update(self):
    """Test external update injection before build() is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))

    layer.build((1,))
    layer._set_state_variables(updates)

    model = self._wrap_in_model(layer)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_post_build_injected_update(self):
    """Test external update injection after build() is called."""
    input_dataset = np.array([1, 2, 3, 4, 5])
    layer = AddingPreprocessingLayer()
    model = self._wrap_in_model(layer)

    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))
    layer._set_state_variables(updates)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_pre_build_adapt_update_dataset(self):
    """Test that preproc layers can adapt() before build() is called."""
    input_dataset = dataset_ops.Dataset.from_tensor_slices(
        np.array([[1], [2], [3], [4], [5], [0]]))

    layer = AddingPreprocessingLayer()
    layer.adapt(input_dataset)
    model = self._wrap_in_model(layer)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_post_build_adapt_update_dataset(self):
    """Test that preproc layers can adapt() after build() is called."""
    input_dataset = dataset_ops.Dataset.from_tensor_slices(
        np.array([[1], [2], [3], [4], [5], [0]]))

    layer = AddingPreprocessingLayer()
    model = self._wrap_in_model(layer)

    layer.adapt(input_dataset)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

  def test_further_tuning(self):
    """Test that models can be tuned with multiple calls to 'adapt'."""

    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    layer.adapt(input_dataset)
    model = self._wrap_in_model(layer)

    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    # With reset_state=False the new data (sum 3) is added to the prior 15.
    layer.adapt(np.array([1, 2]), reset_state=False)
    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))

  def test_further_tuning_post_injection(self):
    """Test that state injected after build() can be further tuned by adapt."""

    input_dataset = np.array([1, 2, 3, 4, 5])

    layer = AddingPreprocessingLayer()
    model = self._wrap_in_model(layer)

    combiner = layer._combiner
    updates = combiner.extract(combiner.compute(input_dataset))
    layer._set_state_variables(updates)
    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    layer.adapt(np.array([1, 2]), reset_state=False)
    self.assertAllEqual([[19], [20], [21]], model.predict([1., 2., 3.]))

  def test_weight_based_state_transfer(self):
    """Test that preproc layers can transfer state via get/set weights."""
    input_dataset = np.array([1, 2, 3, 4, 5])
    layer = AddingPreprocessingLayer()
    model = self._wrap_in_model(layer)
    layer.adapt(input_dataset)
    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    # Create a new model and verify it has no state carryover.
    weights = model.get_weights()
    model_2 = self._wrap_in_model(AddingPreprocessingLayer())
    self.assertAllEqual([[1], [2], [3]], model_2.predict([1., 2., 3.]))

    # Transfer state from model to model_2 via get/set weights.
    model_2.set_weights(weights)
    self.assertAllEqual([[16], [17], [18]], model_2.predict([1., 2., 3.]))

  def test_weight_based_state_transfer_with_further_tuning(self):
    """Test that transferred state can be used to further tune a model."""
    input_dataset = np.array([1, 2, 3, 4, 5])
    layer = AddingPreprocessingLayer()
    model = self._wrap_in_model(layer)
    layer.adapt(input_dataset)
    self.assertAllEqual([[16], [17], [18]], model.predict([1., 2., 3.]))

    # Transfer state from model to model_2 via get/set weights.
    weights = model.get_weights()
    layer_2 = AddingPreprocessingLayer()
    model_2 = self._wrap_in_model(layer_2)
    model_2.set_weights(weights)

    # Further adapt this layer based on the transferred weights.
    layer_2.adapt(np.array([1, 2]), reset_state=False)
    self.assertAllEqual([[19], [20], [21]], model_2.predict([1., 2., 3.]))

  def test_loading_without_providing_class_fails(self):
    """Test that loading fails when the custom layer class isn't provided."""
    input_data = keras.Input(shape=(1,))
    layer = AddingPreprocessingLayer()
    output = layer(input_data)
    model = keras.Model(input_data, output)

    if not context.executing_eagerly():
      self.evaluate(variables.variables_initializer(model.variables))

    output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
    model.save(output_path, save_format="tf")

    with self.assertRaisesRegex(RuntimeError, "Unable to restore a layer of"):
      _ = keras.models.load_model(output_path)

  def test_adapt_sets_input_shape_rank(self):
    """Check that `.adapt()` sets the `input_shape`'s rank."""
    # Shape: (3, 1, 2)
    adapt_dataset = np.array([[[1., 2.]],
                              [[3., 4.]],
                              [[5., 6.]]], dtype=np.float32)

    layer = AddingPreprocessingLayer()
    layer.adapt(adapt_dataset)

    input_dataset = np.array([[[1., 2.], [3., 4.]],
                              [[3., 4.], [5., 6.]]], dtype=np.float32)
    layer(input_dataset)

    model = keras.Sequential([layer])
    self.assertTrue(model.built)
    self.assertEqual(model.input_shape, (None, None, None))

  def test_adapt_doesnt_overwrite_input_shape(self):
    """Check that `.adapt()` doesn't change the `input_shape`."""
    # Shape: (3, 1, 2)
    adapt_dataset = np.array([[[1., 2.]],
                              [[3., 4.]],
                              [[5., 6.]]], dtype=np.float32)

    layer = AddingPreprocessingLayer(input_shape=[1, 2])
    layer.adapt(adapt_dataset)

    model = keras.Sequential([layer])
    self.assertTrue(model.built)
    self.assertEqual(model.input_shape, (None, 1, 2))
395
396
class PreprocessingLayerV1Test(keras_parameterized.TestCase):
  """Checks `adapt` while a TF1-style graph is the default graph."""

  def test_adapt_fails(self):
    """`adapt` raises a RuntimeError when a `Graph` is set as the default."""
    expected_error = "`adapt` is only supported in tensorflow v2"
    bad_input = {"foo": 0}

    with ops.Graph().as_default():
      preprocessing_layer = AddingPreprocessingLayer()
      with self.assertRaisesRegex(RuntimeError, expected_error):
        preprocessing_layer.adapt(bad_input)
408
409
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class ConvertToListTest(keras_parameterized.TestCase):
  """Checks `base_preprocessing_layer.convert_to_list` on several input kinds."""

  # Every input is produced by a lambda so it is only constructed inside the
  # test body. Under TF1.x graph mode, tensors created while the decorator
  # arguments are evaluated would not belong to the graph that is active when
  # the test actually runs.
  @parameterized.named_parameters(
      ("ndarray",
       lambda: np.array([[1, 2, 3], [4, 5, 6]]),
       [[1, 2, 3], [4, 5, 6]]),
      ("list",
       lambda: [[1, 2, 3], [4, 5, 6]],
       [[1, 2, 3], [4, 5, 6]]),
      ("tensor",
       lambda: constant_op.constant([[1, 2, 3], [4, 5, 6]]),
       [[1, 2, 3], [4, 5, 6]]),
      ("ragged_tensor",
       lambda: ragged_factory_ops.constant([[1, 2, 3, 4], [4, 5, 6]]),
       [[1, 2, 3, 4], [4, 5, 6]]),
      # The expected list holds -1 wherever the dense source held 0
      # (presumably the positions left unset in the sparse tensor).
      ("sparse_tensor",
       lambda: sparse_ops.from_dense([[1, 2, 0, 4], [4, 5, 6, 0]]),
       [[1, 2, -1, 4], [4, 5, 6, -1]]),
  )
  def test_conversion(self, inputs, expected):
    converted = base_preprocessing_layer.convert_to_list(inputs())
    self.assertAllEqual(expected, converted)
444
445
if __name__ == "__main__":
  # Run every test case in this module through TensorFlow's test entry point.
  test.main()
448