• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#,============================================================================
15"""Tests for InputLayer construction."""
16
17from tensorflow.python.eager import def_function
18from tensorflow.python.framework import composite_tensor
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.framework import tensor_spec
23from tensorflow.python.framework import type_spec
24from tensorflow.python.keras import backend
25from tensorflow.python.keras import combinations
26from tensorflow.python.keras import keras_parameterized
27from tensorflow.python.keras.engine import functional
28from tensorflow.python.keras.engine import input_layer as input_layer_lib
29from tensorflow.python.keras.layers import core
30from tensorflow.python.keras.saving import model_config
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.ops.ragged import ragged_tensor
34from tensorflow.python.platform import test
35
36
class TwoTensors(composite_tensor.CompositeTensor):
  """A simple value type to test TypeSpec.

  Contains two tensors (x, y) and a string (color).  The color value is a
  stand-in for any extra type metadata we might need to store.

  This value type contains no single dtype: a `dtype` attribute (set to
  `dtypes.variant`) exists only when `assign_variant_dtype` is True.

  Args:
    x: Value converted to a tensor and stored as `self.x`.
    y: Value converted to a tensor and stored as `self.y`.
    color: String metadata carried into this value's TypeSpec.
    assign_variant_dtype: If True, expose `self.dtype = dtypes.variant`.
  """

  def __init__(self, x, y, color='red', assign_variant_dtype=False):
    assert isinstance(color, str)
    self.x = ops.convert_to_tensor_v2_with_dispatch(x)
    self.y = ops.convert_to_tensor_v2_with_dispatch(y)
    self.color = color
    # The value as a whole has no single shape, so report unknown rank.
    self.shape = tensor_shape.TensorShape(None)
    self._shape = tensor_shape.TensorShape(None)
    if assign_variant_dtype:
      self.dtype = dtypes.variant
    self._assign_variant_dtype = assign_variant_dtype

  @property
  def _type_spec(self):
    """The `TwoTensorsSpecNoOneDtype` describing this value.

    `CompositeTensor` declares `_type_spec` as a property; exposing it as a
    plain method would make `value._type_spec` a bound method instead of a
    `TypeSpec`.
    """
    return TwoTensorsSpecNoOneDtype(
        self.x.shape, self.x.dtype, self.y.shape,
        self.y.dtype, color=self.color,
        assign_variant_dtype=self._assign_variant_dtype)
62
63
def as_shape(shape):
  """Returns `shape` as a `tensor_shape.TensorShape`.

  An existing `TensorShape` is returned as the very same object; any other
  value is passed to the `TensorShape` constructor.
  """
  if not isinstance(shape, tensor_shape.TensorShape):
    shape = tensor_shape.TensorShape(shape)
  return shape
70
71
@type_spec.register('tf.TwoTensorsSpec')
class TwoTensorsSpecNoOneDtype(type_spec.TypeSpec):
  """A TypeSpec for the TwoTensors value type.

  Like `TwoTensors`, this spec has no single dtype: it only gains a `dtype`
  attribute (set to `dtypes.variant`) when `assign_variant_dtype` is True.

  Args:
    x_shape: Shape of the `x` component (a `TensorShape` or anything its
      constructor accepts).
    x_dtype: Dtype of the `x` component (anything `dtypes.as_dtype` accepts).
    y_shape: Shape of the `y` component.
    y_dtype: Dtype of the `y` component.
    color: String metadata forwarded to reconstructed `TwoTensors` values.
    assign_variant_dtype: If True, expose `self.dtype = dtypes.variant`.
  """

  def __init__(
      self, x_shape, x_dtype, y_shape, y_dtype, color='red',
      assign_variant_dtype=False):
    self.x_shape = as_shape(x_shape)
    self.x_dtype = dtypes.as_dtype(x_dtype)
    self.y_shape = as_shape(y_shape)
    self.y_dtype = dtypes.as_dtype(y_dtype)
    self.color = color
    # The spec as a whole has no single shape, so report unknown rank.
    self.shape = tensor_shape.TensorShape(None)
    self._shape = tensor_shape.TensorShape(None)
    if assign_variant_dtype:
      self.dtype = dtypes.variant
    self._assign_variant_dtype = assign_variant_dtype

  value_type = property(lambda self: TwoTensors)

  @property
  def _component_specs(self):
    return (tensor_spec.TensorSpec(self.x_shape, self.x_dtype),
            tensor_spec.TensorSpec(self.y_shape, self.y_dtype))

  def _to_components(self, value):
    return (value.x, value.y)

  def _from_components(self, components):
    x, y = components
    # Carry the variant-dtype flag over so rebuilt values agree with the spec.
    return TwoTensors(x, y, self.color,
                      assign_variant_dtype=self._assign_variant_dtype)

  def _serialize(self):
    # The tuple mirrors the constructor's positional parameters, so including
    # the flag lets the spec round-trip through (de)serialization.
    return (self.x_shape, self.x_dtype, self.y_shape, self.y_dtype, self.color,
            self._assign_variant_dtype)

  @classmethod
  def from_value(cls, value):
    """Builds the spec describing the `TwoTensors` instance `value`."""
    return cls(value.x.shape, value.x.dtype, value.y.shape, value.y.dtype,
               value.color,
               assign_variant_dtype=value._assign_variant_dtype)
109    return cls(value.x.shape, value.x.dtype, value.y.shape, value.y.dtype,
110               value.color)
111
112
# Register `TwoTensorsSpecNoOneDtype.from_value` with `type_spec` as the
# converter that derives a TypeSpec from a `TwoTensors` value.
type_spec.register_type_spec_from_value_converter(
    TwoTensors, TwoTensorsSpecNoOneDtype.from_value)
115
116
class InputLayerTest(keras_parameterized.TestCase):
  """Tests for building Keras `Input`s and functional models on top of them."""

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasicOutputShapeNoBatchSize(self):
    """`shape` alone gives an Input whose batch dimension is None."""
    # Create a Keras Input
    x = input_layer_lib.Input(shape=(32,), name='input_a')
    self.assertAllEqual(x.shape.as_list(), [None, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    inputs = array_ops.ones((3, 32))
    self.assertAllEqual(model(inputs), inputs * 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testBasicOutputShapeWithBatchSize(self):
    """`batch_size` fixes the leading dimension of the Input's shape."""
    # Create a Keras Input
    x = input_layer_lib.Input(batch_size=6, shape=(32,), name='input_b')
    self.assertAllEqual(x.shape.as_list(), [6, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    inputs = array_ops.ones(x.shape)
    self.assertAllEqual(model(inputs), inputs * 2.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testBasicOutputShapeNoBatchSizeInTFFunction(self):
    """A shape-only Input and its model can be built inside a tf.function."""
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        x = input_layer_lib.Input(shape=(8,), name='input_a')
        self.assertAllEqual(x.shape.as_list(), [None, 8])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 2.0)
      return model(inp)

    self.assertAllEqual(run_model(array_ops.ones((10, 8))),
                        array_ops.ones((10, 8)) * 2.0)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInputTensorArg(self):
    """`tensor=` makes the Input take that tensor's full shape."""
    # Create a Keras Input
    x = input_layer_lib.Input(tensor=array_ops.zeros((7, 32)))
    self.assertAllEqual(x.shape.as_list(), [7, 32])

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    inputs = array_ops.ones(x.shape)
    self.assertAllEqual(model(inputs), inputs * 2.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testInputTensorArgInTFFunction(self):
    """A `tensor=` Input and its model can be built inside a tf.function."""
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        x = input_layer_lib.Input(tensor=array_ops.zeros((10, 16)))
        self.assertAllEqual(x.shape.as_list(), [10, 16])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3.0)
      return model(inp)

    self.assertAllEqual(run_model(array_ops.ones((10, 16))),
                        array_ops.ones((10, 16)) * 3.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeInputTensorArg(self):
    """`tensor=` accepts a RaggedTensor and the model runs on ragged input."""
    # Create a Keras Input
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    x = input_layer_lib.Input(tensor=rt)

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2)

    # And that the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(model(rt), rt * 2)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeInputTensorArgInTFFunction(self):
    """A ragged `tensor=` Input and its model can be built in a tf.function."""
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        rt = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        x = input_layer_lib.Input(tensor=rt)

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3)
      return model(inp)

    # And verify the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(run_model(rt), rt * 3)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testNoMixingArgsWithTypeSpecArg(self):
    """`type_spec` cannot be combined with any other arg except `name`."""
    spec = tensor_spec.TensorSpec((7, 32), dtypes.float32)
    conflicting_kwargs = [
        {'shape': (4, 7),
         'type_spec': tensor_spec.TensorSpec((2, 7, 32), dtypes.float32)},
        {'batch_size': 4, 'type_spec': spec},
        {'dtype': dtypes.int64, 'type_spec': spec},
        {'sparse': True, 'type_spec': spec},
        {'ragged': True, 'type_spec': spec},
    ]
    for kwargs in conflicting_kwargs:
      with self.assertRaisesRegex(
          ValueError, 'all other args except `name` must be None'):
        input_layer_lib.Input(**kwargs)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testTypeSpecArg(self):
    """A TensorSpec Input works and survives config/JSON round-trips."""
    # Create a Keras Input
    x = input_layer_lib.Input(
        type_spec=tensor_spec.TensorSpec((7, 32), dtypes.float32))
    self.assertAllEqual(x.shape.as_list(), [7, 32])

    inputs = array_ops.ones(x.shape)
    expected = inputs * 2.0

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2.0)
    self.assertAllEqual(model(inputs), expected)

    # Test serialization / deserialization
    model = functional.Functional.from_config(model.get_config())
    self.assertAllEqual(model(inputs), expected)

    model = model_config.model_from_json(model.to_json())
    self.assertAllEqual(model(inputs), expected)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testTypeSpecArgInTFFunction(self):
    """A TensorSpec Input and its model can be built inside a tf.function."""
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        x = input_layer_lib.Input(
            type_spec=tensor_spec.TensorSpec((10, 16), dtypes.float32))
        self.assertAllEqual(x.shape.as_list(), [10, 16])

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3.0)
      return model(inp)

    self.assertAllEqual(run_model(array_ops.ones((10, 16))),
                        array_ops.ones((10, 16)) * 3.0)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeTypeSpecArg(self):
    """A RaggedTensor type spec Input works and survives round-trips."""
    # Create a Keras Input
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    x = input_layer_lib.Input(type_spec=rt._type_spec)

    # Verify you can construct and use a model w/ this input
    model = functional.Functional(x, x * 2)

    # And that the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(model(rt), rt * 2)

    # Test serialization / deserialization
    model = functional.Functional.from_config(model.get_config())
    self.assertAllEqual(model(rt), rt * 2)
    model = model_config.model_from_json(model.to_json())
    self.assertAllEqual(model(rt), rt * 2)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeTypeSpecArgInTFFunction(self):
    """A ragged type spec Input and its model can be built in a tf.function."""
    model = None

    @def_function.function
    def run_model(inp):
      nonlocal model
      if model is None:
        # Create a Keras Input
        rt = ragged_tensor.RaggedTensor.from_row_splits(
            values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
        x = input_layer_lib.Input(type_spec=rt._type_spec)

        # Verify you can construct and use a model w/ this input
        model = functional.Functional(x, x * 3)
      return model(inp)

    # And verify the model works
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 21, 4, 1, 53, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertAllEqual(run_model(rt), rt * 3)

  @combinations.generate(combinations.combine(mode=['eager']))
  def testCompositeTypeSpecArgWithoutDtype(self):
    """A composite spec with no single dtype works, with or without a
    variant `dtype` attribute, including after (de)serialization."""

    def lambda_fn(tensors):
      return (math_ops.cast(tensors.x, dtypes.float64)
              + math_ops.cast(tensors.y, dtypes.float64))

    for assign_variant_dtype in [False, True]:
      # Create a Keras Input
      spec = TwoTensorsSpecNoOneDtype(
          (1, 2, 3), dtypes.float32, (1, 2, 3), dtypes.int64,
          assign_variant_dtype=assign_variant_dtype)
      x = input_layer_lib.Input(type_spec=spec)

      # Verify you can construct and use a model w/ this input
      model = functional.Functional(x, core.Lambda(lambda_fn)(x))

      # And that the model works on a value whose components match `spec`
      # (x: float32, y: int64, both of shape (1, 2, 3)).
      two_tensors = TwoTensors(array_ops.ones((1, 2, 3)) * 2.0,
                               array_ops.ones((1, 2, 3), dtype=dtypes.int64))
      expected = lambda_fn(two_tensors)
      self.assertAllEqual(model(two_tensors), expected)

      # Test serialization / deserialization
      model = functional.Functional.from_config(model.get_config())
      self.assertAllEqual(model(two_tensors), expected)
      model = model_config.model_from_json(model.to_json())
      self.assertAllEqual(model(two_tensors), expected)

  def test_serialize_with_unknown_rank(self):
    """An InputLayer over an unknown-rank tensor reloads with no batch shape."""
    inp = backend.placeholder(shape=None, dtype=dtypes.string)
    x = input_layer_lib.InputLayer(input_tensor=inp, dtype=dtypes.string)
    loaded = input_layer_lib.InputLayer.from_config(x.get_config())
    self.assertIsNone(loaded._batch_input_shape)
367
368  def test_serialize_with_unknown_rank(self):
369    inp = backend.placeholder(shape=None, dtype=dtypes.string)
370    x = input_layer_lib.InputLayer(input_tensor=inp, dtype=dtypes.string)
371    loaded = input_layer_lib.InputLayer.from_config(x.get_config())
372    self.assertIsNone(loaded._batch_input_shape)
373
374
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
377