# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for losses_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import combinations
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RemoveSqueezableTest(test_util.TensorFlowTestCase):
  """Test remove_squeezable_dimensions."""

  def test_ragged_3d_same_shape(self):
    """Equal-rank ragged inputs keep their rank.

    shape: (2, (sequence={1, 2}), 3)
    """
    x = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
    rank = x.shape.ndims
    x_p, y_p = losses_utils.remove_squeezable_dimensions(x, x)
    # Neither output should lose a dimension when the ranks already match.
    self.assertEqual(x_p.shape.ndims, rank)
    self.assertEqual(y_p.shape.ndims, rank)

  def test_ragged_3d_4d_squeezable(self):
    """A trailing size-1 axis is removed from the higher-rank ragged input.

    shapes:
      x: (2, (sequence={1, 2}), 3)
      y: (2, (sequence={1, 2}), 3, 1)
    """
    x = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
    y = array_ops.expand_dims(x, axis=-1)
    self.assertEqual(x.shape.ndims, 3)
    self.assertEqual(y.shape.ndims, 4)

    # y passed as predictions: the predictions output is squeezed.
    _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
    y_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(y_p.shape.ndims, 3)

    # y passed as labels: the labels output is squeezed.
    x_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
    x_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(x_p.shape.ndims, 3)

  def test_dense_2d_3d_squeezable(self):
    """A trailing size-1 axis is removed from the higher-rank dense input.

    shapes:
      x: (2, 2)
      y: (2, 2, 1)
    """
    x = constant_op.constant([[1, 2], [3, 4]])
    y = constant_op.constant([[[1], [2]], [[3], [4]]])

    # y passed as predictions: the predictions output is squeezed.
    _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
    y_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(y_p.shape.ndims, x.shape.ndims)

    # y passed as labels: the labels output is squeezed. Check the rank too,
    # mirroring the forward direction above and the ragged test.
    x_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
    x_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(x_p.shape.ndims, x.shape.ndims)


class RemoveSqueezableTestGraphOnly(test_util.TensorFlowTestCase):
  """Test remove_squeezable_dimensions (graph-mode only)."""

  def test_placeholder(self):
    """Test dynamic rank tensors."""
    with ops.Graph().as_default():
      # shape=None leaves the static shape (including rank) unknown, so the
      # squeeze decision has to be made from the runtime shapes.
      x = array_ops.placeholder_with_default([1., 2., 3.], shape=None)
      y = array_ops.placeholder_with_default([[1.], [2.], [3.]], shape=None)

      _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
      y_p.shape.assert_is_compatible_with(x.shape)
      self.assertAllEqual(array_ops.shape(x), array_ops.shape(y_p))

      x_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
      x_p.shape.assert_is_compatible_with(x.shape)
      # x.shape is fully unknown, so the static check above always passes;
      # compare runtime shapes to actually verify the squeeze happened.
      self.assertAllEqual(array_ops.shape(x), array_ops.shape(x_p))


if __name__ == '__main__':
  test.main()