• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for losses_utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import test_util
23from tensorflow.python.framework import ops
24from tensorflow.python.keras import combinations
25from tensorflow.python.keras.utils import losses_utils
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops.ragged import ragged_factory_ops
28from tensorflow.python.platform import test
29
30
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class RemoveSqueezableTest(test_util.TensorFlowTestCase):
  """Tests for `losses_utils.remove_squeezable_dimensions`.

  The squeezable cases call the function in both argument orders. In each
  order, the tests expect the trailing size-1 dimension of the higher-rank
  tensor to be removed, whether that tensor is passed first or second.
  """

  def test_ragged_3d_same_shape(self):
    """Equal-rank ragged inputs come back with their rank unchanged.

    x: (2, (sequence={1, 2}), 3)
    """
    x = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
    rank = x.shape.ndims
    x_p, y_p = losses_utils.remove_squeezable_dimensions(x, x)
    # Neither output should be squeezed when the ranks already match.
    self.assertEqual(x_p.shape.ndims, rank)
    self.assertEqual(y_p.shape.ndims, rank)

  def test_ragged_3d_4d_squeezable(self):
    """A ragged tensor with an extra trailing dim is squeezed to match.

    x: (2, (sequence={1, 2}), 3)
    y: (2, (sequence={1, 2}), 3, 1)
    """
    x = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5, 6], [7, 8, 9]]])
    y = array_ops.expand_dims(x, axis=-1)
    self.assertEqual(x.shape.ndims, 3)
    self.assertEqual(y.shape.ndims, 4)

    # Higher-rank tensor passed second: the second output is the squeezed y.
    _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
    y_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(y_p.shape.ndims, 3)

    # Higher-rank tensor passed first: the first output is the squeezed y.
    y_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
    y_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(y_p.shape.ndims, 3)

  def test_dense_2d_3d_squeezable(self):
    """A dense tensor with an extra trailing dim is squeezed to match.

    x: (2, 2)
    y: (2, 2, 1)
    """
    x = constant_op.constant([[1, 2], [3, 4]])
    y = constant_op.constant([[[1], [2]], [[3], [4]]])

    # Higher-rank tensor passed second: the second output is the squeezed y.
    _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
    y_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(y_p.shape.ndims, x.shape.ndims)

    # Higher-rank tensor passed first: the first output is the squeezed y.
    y_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
    y_p.shape.assert_is_compatible_with(x.shape)
    self.assertEqual(y_p.shape.ndims, x.shape.ndims)
62    y = constant_op.constant([[[1], [2]], [[3], [4]]])
63    _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
64    y_p.shape.assert_is_compatible_with(x.shape)
65    self.assertEqual(y_p.shape.ndims, x.shape.ndims)
66    x_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
67    x_p.shape.assert_is_compatible_with(x.shape)
68
69
class RemoveSqueezableTestGraphOnly(test_util.TensorFlowTestCase):
  """Test remove_squeezable_dimensions (graph-mode only)."""

  def test_placeholder(self):
    """Test dynamic rank tensors.

    `placeholder_with_default(..., shape=None)` gives tensors whose static
    shape is unknown, so the result is checked by evaluating its run-time
    shape rather than relying on static-shape assertions alone.
    """
    # Built in an explicit graph; presumably because in eager mode the
    # placeholders would resolve to concrete tensors with known shapes.
    with ops.Graph().as_default():
      x = array_ops.placeholder_with_default([1., 2., 3.], shape=None)
      y = array_ops.placeholder_with_default([[1.], [2.], [3.]], shape=None)

      # Higher-rank tensor passed second: the second output is the squeezed y.
      _, y_p = losses_utils.remove_squeezable_dimensions(x, y)
      y_p.shape.assert_is_compatible_with(x.shape)
      # The static check above is weak for unknown shapes; compare the
      # evaluated run-time shapes as well.
      self.assertAllEqual(array_ops.shape(x), array_ops.shape(y_p))

      # Higher-rank tensor passed first: the first output is the squeezed y.
      y_p, _ = losses_utils.remove_squeezable_dimensions(y, x)
      y_p.shape.assert_is_compatible_with(x.shape)
      self.assertAllEqual(array_ops.shape(x), array_ops.shape(y_p))
83
84
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()
87