# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""InputSpec tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.engine import input_spec
from tensorflow.python.platform import test


class InputSpecTest(test.TestCase):
  """Tests for how `InputSpec` validates its `axes` argument."""

  def test_axes_initialization(self):
    """Checks one accepted `axes` mapping and two rejected ones."""
    # This call sits outside any assertRaises context, so it is expected to
    # succeed even though one key is the string '2'.
    # NOTE(review): presumably integer-like string keys are coerced to int;
    # confirm against the InputSpec implementation.
    input_spec.InputSpec(shape=[1, None, 2, 3], axes={3: 5, '2': 2})

    # (expected exception type, message regex, offending axes mapping)
    rejected = (
        (ValueError, 'Axis 4 is greater than', {4: 5}),
        (TypeError, 'keys in axes must be integers', {'string': 5}),
    )
    for error_type, message_regex, bad_axes in rejected:
      with self.assertRaisesRegex(error_type, message_regex):
        input_spec.InputSpec(shape=[1, None, 2, 3], axes=bad_axes)


class InputSpecToTensorShapeTest(test.TestCase):
  """Tests for `input_spec.to_tensor_shape`."""

  def test_defined_shape(self):
    """A spec built from an explicit shape yields that same shape."""
    expected_dims = [1, None, 2, 3]
    # Hand the spec its own copy so the expected list is never shared with it.
    shaped_spec = input_spec.InputSpec(shape=list(expected_dims))
    converted = input_spec.to_tensor_shape(shaped_spec)
    self.assertAllEqual(expected_dims, converted.as_list())

  def test_defined_ndims(self):
    """A spec built from `ndim` (optionally with `axes`) yields a known rank."""
    # (InputSpec keyword arguments, expected `as_list()` result)
    scenarios = (
        ({'ndim': 5}, [None] * 5),
        ({'ndim': 0}, []),
        # The negative key -1 is expected to land on the last dimension.
        ({'ndim': 3, 'axes': {1: 3, -1: 2}}, [None, 3, 2]),
    )
    for spec_kwargs, expected_dims in scenarios:
      ranked_spec = input_spec.InputSpec(**spec_kwargs)
      converted = input_spec.to_tensor_shape(ranked_spec)
      self.assertAllEqual(expected_dims, converted.as_list())

  def test_undefined_shapes(self):
    """Specs given only min/max rank bounds convert to an unknown shape."""
    # Neither spec sets `shape` or `ndim`; the test expects `as_list()` on the
    # converted result to fail, even when min_ndim equals max_ndim.
    bounds_only = (
        {'max_ndim': 5},
        {'min_ndim': 5, 'max_ndim': 5},
    )
    for spec_kwargs in bounds_only:
      # Build the spec outside the context manager so that only the
      # conversion and `as_list()` call are covered by the expected error.
      loose_spec = input_spec.InputSpec(**spec_kwargs)
      with self.assertRaisesRegex(ValueError, 'unknown TensorShape'):
        input_spec.to_tensor_shape(loose_spec).as_list()


if __name__ == '__main__':
  test.main()