• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for KerasTensor."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import sparse_tensor
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.framework import tensor_spec
22from tensorflow.python.keras import layers
23from tensorflow.python.keras.engine import keras_tensor
24from tensorflow.python.ops import array_ops
25from tensorflow.python.platform import test
26
27
class KerasTensorTest(test.TestCase):
  """Tests for the string formatting of `keras_tensor.KerasTensor`."""

  def test_repr_and_string(self):
    """Checks `str()` and `repr()` of KerasTensors built in several ways.

    Covers a `TensorSpec`-backed tensor, one carrying an `inferred_value`, a
    `SparseTensorSpec`-backed tensor, a Keras layer output, and the outputs of
    the TF ops `reshape` and `unstack` applied to a KerasTensor.
    """
    # Plain TensorSpec: repr() summarizes shape and dtype.
    kt = keras_tensor.KerasTensor(
        type_spec=tensor_spec.TensorSpec(shape=(1, 2, 3), dtype=dtypes.float32))
    expected_str = ("KerasTensor(type_spec=TensorSpec(shape=(1, 2, 3), "
                    "dtype=tf.float32, name=None))")
    expected_repr = "<KerasTensor: shape=(1, 2, 3) dtype=float32>"
    self.assertEqual(expected_str, str(kt))
    self.assertEqual(expected_repr, repr(kt))

    # An `inferred_value` is appended to both str() and repr().
    kt = keras_tensor.KerasTensor(
        type_spec=tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32),
        inferred_value=[2, 3])
    expected_str = ("KerasTensor(type_spec=TensorSpec(shape=(2,), "
                    "dtype=tf.int32, name=None), inferred_value=[2, 3])")
    expected_repr = (
        "<KerasTensor: shape=(2,) dtype=int32 inferred_value=[2, 3]>")
    self.assertEqual(expected_str, str(kt))
    self.assertEqual(expected_repr, repr(kt))

    # SparseTensorSpec: repr() shows the full type_spec rather than a
    # shape/dtype summary.
    kt = keras_tensor.KerasTensor(
        type_spec=sparse_tensor.SparseTensorSpec(
            shape=(1, 2, 3), dtype=dtypes.float32))
    expected_str = ("KerasTensor(type_spec=SparseTensorSpec("
                    "TensorShape([1, 2, 3]), tf.float32))")
    expected_repr = (
        "<KerasTensor: type_spec=SparseTensorSpec("
        "TensorShape([1, 2, 3]), tf.float32)>")
    self.assertEqual(expected_str, str(kt))
    self.assertEqual(expected_repr, repr(kt))

    # Tensors produced by a layer also report the graph tensor name and a
    # description naming the producing layer.
    # NOTE(review): the expected layer names ('dense', 'tf.reshape',
    # 'tf.unstack') assume no earlier layers of those kinds were created in
    # this process; auto-generated Keras layer names are presumably
    # uniquified (e.g. 'dense_1') otherwise.
    inp = layers.Input(shape=(3, 5))
    kt = layers.Dense(10)(inp)
    expected_str = (
        "KerasTensor(type_spec=TensorSpec(shape=(None, 3, 10), "
        "dtype=tf.float32, name=None), name='dense/BiasAdd:0', "
        "description=\"created by layer 'dense'\")")
    expected_repr = (
        "<KerasTensor: shape=(None, 3, 10) dtype=float32 (created "
        "by layer 'dense')>")
    self.assertEqual(expected_str, str(kt))
    self.assertEqual(expected_repr, repr(kt))

    # A TF op applied to a KerasTensor is reported as created by a
    # 'tf.<op>' layer.
    kt = array_ops.reshape(kt, shape=(3, 5, 2))
    expected_str = (
        "KerasTensor(type_spec=TensorSpec(shape=(3, 5, 2), dtype=tf.float32, "
        "name=None), name='tf.reshape/Reshape:0', description=\"created "
        "by layer 'tf.reshape'\")")
    expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (created "
                     "by layer 'tf.reshape')>")
    self.assertEqual(expected_str, str(kt))
    self.assertEqual(expected_repr, repr(kt))

    # Unstacking the (3, 5, 2) tensor yields (5, 2) KerasTensors whose
    # names carry the output index.
    kts = array_ops.unstack(kt)
    # Check the count explicitly: indexing a fixed range(3) would raise an
    # opaque IndexError on too few outputs and silently ignore any extras.
    self.assertLen(kts, 3)
    # Every unstacked output has the same repr, so build it once.
    expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
                     "(created by layer 'tf.unstack')>")
    for i, unstacked in enumerate(kts):
      expected_str = (
          "KerasTensor(type_spec=TensorSpec(shape=(5, 2), dtype=tf.float32, "
          "name=None), name='tf.unstack/unstack:%s', description=\"created "
          "by layer 'tf.unstack'\")" % (i,))
      self.assertEqual(expected_str, str(unstacked))
      self.assertEqual(expected_repr, repr(unstacked))
92
def _main():
  """Switches on eager execution and V2 TensorShape, then runs the tests."""
  ops.enable_eager_execution()
  tensor_shape.enable_v2_tensorshape()
  test.main()


if __name__ == "__main__":
  _main()
97