• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Keras TF utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl.testing import parameterized
22
23from tensorflow.python import keras
24from tensorflow.python.eager import context
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import sparse_tensor
28from tensorflow.python.keras import combinations
29from tensorflow.python.keras.utils import tf_utils
30from tensorflow.python.ops import sparse_ops
31from tensorflow.python.ops import variables
32from tensorflow.python.ops.ragged import ragged_factory_ops
33from tensorflow.python.ops.ragged import ragged_tensor
34from tensorflow.python.platform import test
35
36try:
37  import attr  # pylint:disable=g-import-not-at-top
38except ImportError:
39  attr = None
40
41
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase):
  """Checks `tf_utils.is_symbolic_tensor` under both graph and eager modes.

  In eager mode none of the probed objects is expected to be symbolic; in
  graph mode every one of them is.
  """

  def _symbolic_check(self):
    """Returns the assertion that matches the current execution mode.

    `assertFalse` while executing eagerly, `assertTrue` otherwise.
    """
    if context.executing_eagerly():
      return self.assertFalse
    return self.assertTrue

  def _check_builtin_types(self, check):
    """Applies `check` to a Variable, a dense Tensor and a SparseTensor."""
    check(tf_utils.is_symbolic_tensor(
        variables.Variable(name='blah', initial_value=0.)))
    check(tf_utils.is_symbolic_tensor(
        ops.convert_to_tensor_v2_with_dispatch(0.)))
    check(tf_utils.is_symbolic_tensor(
        sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])))

  def test_default_behavior(self):
    self._check_builtin_types(self._symbolic_check())

  def test_works_with_registered(self):
    """A user type registered as symbolic follows the built-in types."""

    class CustomClass(object):

      def value(self):
        return ops.convert_to_tensor_v2_with_dispatch(42.)

    ops.register_tensor_conversion_function(
        CustomClass, lambda value, **_: value.value())
    tf_utils.register_symbolic_tensor_type(CustomClass)

    check = self._symbolic_check()
    self._check_builtin_types(check)
    check(tf_utils.is_symbolic_tensor(CustomClass()))

  def test_enables_nontensor_plumbing(self):
    """A registered non-Tensor type flows through a model and into a loss."""
    if context.executing_eagerly():
      self.skipTest('`compile` functionality changed.')

    # A non-Tensor object that converts to a constant tensor and is
    # registered as a symbolic tensor type.
    class Foo(object):

      def __init__(self, input_):
        self._input = input_
        self.value = ops.convert_to_tensor_v2_with_dispatch([[42.]])

      @property
      def dtype(self):
        return self.value.dtype

    ops.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value)
    tf_utils.register_symbolic_tensor_type(Foo)

    class PlumbingLayer(keras.layers.Lambda):
      """Lambda layer that surfaces `fn`'s raw (non-Tensor) result."""

      def __init__(self, fn, **kwargs):

        def _make_outputs(*call_args, **call_kwargs):
          # Copy shape metadata from the converted tensor onto the raw object
          # and hand back both of them.
          raw = fn(*call_args, **call_kwargs)
          as_tensor = ops.convert_to_tensor_v2_with_dispatch(raw)
          raw.shape = as_tensor.shape
          raw.get_shape = as_tensor.get_shape
          return raw, as_tensor

        super(PlumbingLayer, self).__init__(_make_outputs, **kwargs)
        self._enter_dunder_call = False

      def __call__(self, inputs, *args, **kwargs):
        # While this flag is set, `call` returns the (raw, tensor) pair;
        # only the raw object is passed on to the caller.
        self._enter_dunder_call = True
        raw, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
        self._enter_dunder_call = False
        return raw

      def call(self, inputs, *args, **kwargs):
        raw, as_tensor = super(PlumbingLayer, self).call(
            inputs, *args, **kwargs)
        return (raw, as_tensor) if self._enter_dunder_call else raw

    # User-land.
    model = keras.Sequential([
        keras.layers.InputLayer((1,)),
        PlumbingLayer(Foo),  # Produces a `Foo` object.
    ])
    # Compose the model with itself so that Keras graph history must be
    # preserved across the composition.
    model = keras.Model(model.inputs, model(model.outputs))
    # Calling the model on a concrete tensor must yield a `Foo`, not a
    # `Tensor`.
    output = model(ops.convert_to_tensor_v2_with_dispatch([[7.]]))
    self.assertIsInstance(output, Foo)

    # The custom loss must also receive the `Foo` instance.
    seen_predictions = [None]

    def custom_loss(y_obs, y_pred):
      del y_obs  # Unused.
      seen_predictions[0] = y_pred
      return y_pred

    # In graph mode `compile` appears to invoke the loss, which records the
    # prediction through the side effect above.
    model.compile('SGD', loss=custom_loss)
    self.assertIsInstance(seen_predictions[0], Foo)


class ConvertInnerNodeDataTest(test.TestCase):
  """Tests for `tf_utils.convert_inner_node_data`."""

  def test_convert_inner_node_data(self):
    # Unwrapping: `ListWrapper` entries come back as plain lists.
    data = tf_utils.convert_inner_node_data((tf_utils.ListWrapper(['l', 2, 3]),
                                             tf_utils.ListWrapper(['l', 5, 6])))
    self.assertEqual(data, (['l', 2, 3], ['l', 5, 6]))

    # Wrapping: with `wrap=True` each inner list becomes a `ListWrapper`.
    data = tf_utils.convert_inner_node_data((['l', 2, 3], ['l', 5, 6]),
                                            wrap=True)
    # Check the length first: `all()` over an empty result would otherwise
    # pass vacuously.
    self.assertEqual(len(data), 2)
    self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))
    # Unwrapping the wrapped result must restore the original contents.
    self.assertEqual(tf_utils.convert_inner_node_data(data),
                     (['l', 2, 3], ['l', 5, 6]))


class AttrsTest(test.TestCase):
  """Checks that `map_structure_with_atomic` recurses into attrs classes."""

  def test_map_structure_with_atomic_accept_attr(self):
    # `attr` is an optional import; skip when it is not installed.
    if attr is None:
      self.skipTest('attr module is unavailable.')

    @attr.s(frozen=True)
    class Foo(object):

      bar = attr.ib()

    def is_int(value):
      return isinstance(value, int)

    def increment(value):
      return value + 1

    mapped = tf_utils.map_structure_with_atomic(
        is_atomic_fn=is_int, map_fn=increment, nested=Foo(1))
    # The int field is treated as atomic and incremented, and the result is
    # expected to be an equal `Foo` instance.
    self.assertEqual(Foo(2), mapped)


class TestIsRagged(test.TestCase):
  """Tests for `tf_utils.is_ragged`."""

  def test_is_ragged_return_true_for_ragged_tensor(self):
    ragged = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertTrue(tf_utils.is_ragged(ragged))

  def test_is_ragged_return_false_for_list(self):
    # A plain Python list must not be reported as ragged.
    self.assertFalse(tf_utils.is_ragged([1., 2., 3.]))


class TestIsExtensionType(test.TestCase):
  """Tests for `tf_utils.is_extension_type`.

  Ragged and sparse tensors count as extension types; dense tensors and
  plain Python lists do not.
  """

  def test_is_extension_type_return_true_for_ragged_tensor(self):
    ragged = ragged_factory_ops.constant([[1, 2], [3]])
    self.assertTrue(tf_utils.is_extension_type(ragged))

  def test_is_extension_type_return_true_for_sparse_tensor(self):
    sparse = sparse_ops.from_dense([[1, 2], [3, 4]])
    self.assertTrue(tf_utils.is_extension_type(sparse))

  def test_is_extension_type_return_false_for_dense_tensor(self):
    dense = constant_op.constant([[1, 2], [3, 4]])
    self.assertFalse(tf_utils.is_extension_type(dense))

  def test_is_extension_type_return_false_for_list(self):
    self.assertFalse(tf_utils.is_extension_type([1., 2., 3.]))


if __name__ == '__main__':
  # Hand off to TensorFlow's test entry point when run as a script.
  test.main()