# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras TF utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import combinations
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test

try:
  import attr  # pylint:disable=g-import-not-at-top
except ImportError:
  # `AttrsTest` skips itself when the optional `attr` package is missing.
  attr = None


def _symbolic_candidates():
  """Builds the built-in tensor-like values probed by `is_symbolic_tensor`.

  Must be called from inside a test method so the values are created under
  the graph/eager context that the test combination set up.

  Returns:
    A list holding, in this order, a `Variable`, a dense `Tensor` and a
    `SparseTensor`.
  """
  return [
      variables.Variable(name='blah', initial_value=0.),
      ops.convert_to_tensor_v2_with_dispatch(0.),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]),
  ]


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestIsSymbolicTensor(test.TestCase, parameterized.TestCase):
  """Tests `tf_utils.is_symbolic_tensor` under both graph and eager modes."""

  def _assert_symbolic_only_in_graph_mode(self, values):
    """Asserts each value is symbolic in graph mode and non-symbolic eagerly.

    Not prefixed with `test`, so neither unittest nor `combinations.generate`
    treats it as a test case.

    Args:
      values: Iterable of objects to pass to `tf_utils.is_symbolic_tensor`.
    """
    # assertTrue/assertFalse check only truthiness of the result, not that it
    # is literally `True`/`False`.
    check = (self.assertFalse if context.executing_eagerly()
             else self.assertTrue)
    for value in values:
      check(tf_utils.is_symbolic_tensor(value))

  def test_default_behavior(self):
    self._assert_symbolic_only_in_graph_mode(_symbolic_candidates())

  def test_works_with_registered(self):

    class CustomClass(object):

      def value(self):
        return ops.convert_to_tensor_v2_with_dispatch(42.)

    # NOTE(review): neither registration below is undone in this test, so
    # `CustomClass` presumably stays registered for the rest of the process.
    ops.register_tensor_conversion_function(
        CustomClass, lambda value, **_: value.value())

    tf_utils.register_symbolic_tensor_type(CustomClass)

    self._assert_symbolic_only_in_graph_mode(
        _symbolic_candidates() + [CustomClass()])

  def test_enables_nontensor_plumbing(self):
    if context.executing_eagerly():
      self.skipTest('`compile` functionality changed.')
    # Setup.

    class Foo(object):
      """Non-tensor object that carries a constant tensor in `value`."""

      def __init__(self, input_):
        self._input = input_
        self.value = ops.convert_to_tensor_v2_with_dispatch([[42.]])

      @property
      def dtype(self):
        return self.value.dtype

    ops.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value)
    tf_utils.register_symbolic_tensor_type(Foo)

    class PlumbingLayer(keras.layers.Lambda):
      """Lambda layer whose output is the raw object returned by `fn`."""

      def __init__(self, fn, **kwargs):
        def _fn(*fargs, **fkwargs):
          d = fn(*fargs, **fkwargs)
          x = ops.convert_to_tensor_v2_with_dispatch(d)
          # Give the non-tensor object the shape accessors of its tensor form.
          d.shape = x.shape
          d.get_shape = x.get_shape
          return d, x
        super(PlumbingLayer, self).__init__(_fn, **kwargs)
        # True only while `__call__` is running; tells `call` whether to hand
        # back the (object, tensor) pair or just the object.
        self._enter_dunder_call = False

      def __call__(self, inputs, *args, **kwargs):
        self._enter_dunder_call = True
        try:
          d, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
        finally:
          # Reset even if the base `__call__` raises, so the flag cannot leak
          # into later direct `call` invocations.
          self._enter_dunder_call = False
        return d

      def call(self, inputs, *args, **kwargs):
        d, v = super(PlumbingLayer, self).call(inputs, *args, **kwargs)
        if self._enter_dunder_call:
          return d, v
        return d

    # User-land.
    model = keras.Sequential([
        keras.layers.InputLayer((1,)),
        PlumbingLayer(Foo),  # Makes a `Foo` object.
    ])
    # Let's ensure Keras graph history is preserved by composing the models.
    model = keras.Model(model.inputs, model(model.outputs))
    # Now we instantiate the model and verify we have a `Foo` object, not a
    # `Tensor`.
    y = model(ops.convert_to_tensor_v2_with_dispatch([[7.]]))
    self.assertIsInstance(y, Foo)
    # Confirm that (custom) loss sees `Foo` instance, not Tensor.
    # A one-element list lets the nested loss function record its argument
    # without rebinding an outer name.
    obtained_prediction_box = [None]

    def custom_loss(y_obs, y_pred):
      del y_obs
      obtained_prediction_box[0] = y_pred
      return y_pred
    # Apparently `compile` calls the loss function enough to trigger the
    # side-effect.
    model.compile('SGD', loss=custom_loss)
    self.assertIsInstance(obtained_prediction_box[0], Foo)


class ConvertInnerNodeDataTest(test.TestCase):
  """Tests `tf_utils.convert_inner_node_data` wrapping and unwrapping."""

  def test_convert_inner_node_data(self):
    # Without `wrap`, `ListWrapper` instances come back as plain lists.
    data = tf_utils.convert_inner_node_data((tf_utils.ListWrapper(['l', 2, 3]),
                                             tf_utils.ListWrapper(['l', 5, 6])))
    self.assertEqual(data, (['l', 2, 3], ['l', 5, 6]))

    # With `wrap=True`, plain lists are converted into `ListWrapper`s.
    data = tf_utils.convert_inner_node_data((['l', 2, 3], ['l', 5, 6]),
                                            wrap=True)
    self.assertTrue(all(isinstance(ele, tf_utils.ListWrapper) for ele in data))


class AttrsTest(test.TestCase):
  """Tests that `map_structure_with_atomic` traverses `attr.s` classes."""

  def test_map_structure_with_atomic_accept_attr(self):
    if attr is None:
      self.skipTest('attr module is unavailable.')

    @attr.s(frozen=True)
    class Foo(object):

      bar = attr.ib()

    # The int field inside the attrs instance is mapped, yielding Foo(2).
    self.assertEqual(
        Foo(2),
        tf_utils.map_structure_with_atomic(
            is_atomic_fn=lambda x: isinstance(x, int),
            map_fn=lambda x: x + 1,
            nested=Foo(1)))


class TestIsRagged(test.TestCase):
  """Tests `tf_utils.is_ragged`."""

  def test_is_ragged_return_true_for_ragged_tensor(self):
    tensor = ragged_tensor.RaggedTensor.from_row_splits(
        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
    self.assertTrue(tf_utils.is_ragged(tensor))

  def test_is_ragged_return_false_for_list(self):
    not_a_tensor = [1., 2., 3.]
    self.assertFalse(tf_utils.is_ragged(not_a_tensor))


class TestIsExtensionType(test.TestCase):
  """Tests `tf_utils.is_extension_type` on composite, dense and list inputs."""

  def test_is_extension_type_return_true_for_ragged_tensor(self):
    self.assertTrue(tf_utils.is_extension_type(
        ragged_factory_ops.constant([[1, 2], [3]])))

  def test_is_extension_type_return_true_for_sparse_tensor(self):
    self.assertTrue(tf_utils.is_extension_type(
        sparse_ops.from_dense([[1, 2], [3, 4]])))

  def test_is_extension_type_return_false_for_dense_tensor(self):
    self.assertFalse(tf_utils.is_extension_type(
        constant_op.constant([[1, 2], [3, 4]])))

  def test_is_extension_type_return_false_for_list(self):
    not_a_tensor = [1., 2., 3.]
    self.assertFalse(tf_utils.is_extension_type(not_a_tensor))


if __name__ == '__main__':
  test.main()