# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""KerasTensor tests."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import layers
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class KerasTensorTest(test.TestCase):
  """Tests for the `str()` and `repr()` forms of `KerasTensor`."""

  def _assert_str_and_repr(self, kt, expected_str, expected_repr):
    """Asserts that `str(kt)` and `repr(kt)` equal the expected strings.

    Args:
      kt: The `KerasTensor` under test.
      expected_str: The exact expected value of `str(kt)`.
      expected_repr: The exact expected value of `repr(kt)`.
    """
    self.assertEqual(expected_str, str(kt))
    self.assertEqual(expected_repr, repr(kt))

  def test_repr_and_string(self):
    """Checks str/repr for dense, inferred-value, sparse and layer outputs."""
    # Plain dense TensorSpec.
    kt = keras_tensor.KerasTensor(
        type_spec=tensor_spec.TensorSpec(shape=(1, 2, 3), dtype=dtypes.float32))
    expected_str = ("KerasTensor(type_spec=TensorSpec(shape=(1, 2, 3), "
                    "dtype=tf.float32, name=None))")
    expected_repr = "<KerasTensor: shape=(1, 2, 3) dtype=float32>"
    self._assert_str_and_repr(kt, expected_str, expected_repr)

    # An inferred_value shows up in both representations.
    kt = keras_tensor.KerasTensor(
        type_spec=tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32),
        inferred_value=[2, 3])
    expected_str = ("KerasTensor(type_spec=TensorSpec(shape=(2,), "
                    "dtype=tf.int32, name=None), inferred_value=[2, 3])")
    expected_repr = (
        "<KerasTensor: shape=(2,) dtype=int32 inferred_value=[2, 3]>")
    self._assert_str_and_repr(kt, expected_str, expected_repr)

    # Non-TensorSpec type specs: the repr shows the full type_spec rather
    # than a shape/dtype pair.
    kt = keras_tensor.KerasTensor(
        type_spec=sparse_tensor.SparseTensorSpec(
            shape=(1, 2, 3), dtype=dtypes.float32))
    expected_str = ("KerasTensor(type_spec=SparseTensorSpec("
                    "TensorShape([1, 2, 3]), tf.float32))")
    expected_repr = (
        "<KerasTensor: type_spec=SparseTensorSpec("
        "TensorShape([1, 2, 3]), tf.float32)>")
    self._assert_str_and_repr(kt, expected_str, expected_repr)

    # Outputs of layers carry a name and a "created by layer" description.
    # NOTE: the expected names ('dense', 'tf.reshape', 'tf.unstack') rely on
    # auto-generated layer names, so they assume no earlier layer of the same
    # kind was created in this process.
    inp = layers.Input(shape=(3, 5))
    kt = layers.Dense(10)(inp)
    expected_str = (
        "KerasTensor(type_spec=TensorSpec(shape=(None, 3, 10), "
        "dtype=tf.float32, name=None), name='dense/BiasAdd:0', "
        "description=\"created by layer 'dense'\")")
    expected_repr = (
        "<KerasTensor: shape=(None, 3, 10) dtype=float32 (created "
        "by layer 'dense')>")
    self._assert_str_and_repr(kt, expected_str, expected_repr)

    kt = array_ops.reshape(kt, shape=(3, 5, 2))
    expected_str = (
        "KerasTensor(type_spec=TensorSpec(shape=(3, 5, 2), dtype=tf.float32, "
        "name=None), name='tf.reshape/Reshape:0', description=\"created "
        "by layer 'tf.reshape'\")")
    expected_repr = ("<KerasTensor: shape=(3, 5, 2) dtype=float32 (created "
                     "by layer 'tf.reshape')>")
    self._assert_str_and_repr(kt, expected_str, expected_repr)

    kts = array_ops.unstack(kt)
    # A (3, 5, 2) input unstacks into three (5, 2) outputs; check the count
    # explicitly so extra or missing outputs are caught.
    self.assertLen(kts, 3)
    # Every unstacked output shares the same repr; only the str name index
    # differs.
    expected_repr = ("<KerasTensor: shape=(5, 2) dtype=float32 "
                     "(created by layer 'tf.unstack')>")
    for i, unstacked in enumerate(kts):
      expected_str = (
          "KerasTensor(type_spec=TensorSpec(shape=(5, 2), dtype=tf.float32, "
          "name=None), name='tf.unstack/unstack:%s', description=\"created "
          "by layer 'tf.unstack'\")" % (i,))
      self._assert_str_and_repr(unstacked, expected_str, expected_repr)


if __name__ == "__main__":
  ops.enable_eager_execution()
  tensor_shape.enable_v2_tensorshape()
  test.main()