1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the 'License'); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an 'AS IS' BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras utilities to split v1 and v2 classes.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import numpy as np 24import six 25 26from tensorflow.python import keras 27from tensorflow.python.framework import ops 28from tensorflow.python.keras import keras_parameterized 29from tensorflow.python.keras.engine import base_layer 30from tensorflow.python.keras.engine import base_layer_v1 31from tensorflow.python.keras.engine import training 32from tensorflow.python.keras.engine import training_v1 33from tensorflow.python.platform import test 34 35 36@keras_parameterized.run_all_keras_modes 37class SplitUtilsTest(keras_parameterized.TestCase): 38 39 def _check_model_class(self, model_class): 40 if ops.executing_eagerly_outside_functions(): 41 self.assertEqual(model_class, training.Model) 42 else: 43 self.assertEqual(model_class, training_v1.Model) 44 45 def _check_layer_class(self, layer): 46 if ops.executing_eagerly_outside_functions(): 47 self.assertIsInstance(layer, base_layer.Layer) 48 self.assertNotIsInstance(layer, base_layer_v1.Layer) 49 else: 50 self.assertIsInstance(layer, base_layer_v1.Layer) 51 52 def test_functional_model(self): 53 inputs = keras.Input(10) 54 outputs = keras.layers.Dense(1)(inputs) 55 model = keras.Model(inputs, outputs) 56 self._check_model_class(model.__class__.__bases__[0]) 57 self._check_layer_class(model) 58 59 def test_subclass_model_with_functional_init(self): 60 inputs = keras.Input(10) 61 outputs = keras.layers.Dense(1)(inputs) 62 63 class MyModel(keras.Model): 64 pass 65 66 model = MyModel(inputs, outputs) 67 model_class = model.__class__.__bases__[0].__bases__[0] 68 self._check_model_class(model_class) 69 self._check_layer_class(model) 70 71 def test_subclass_model_with_functional_init_interleaved_v1_functional(self): 72 with ops.Graph().as_default(): 73 inputs = keras.Input(10) 74 outputs = keras.layers.Dense(1)(inputs) 75 _ = keras.Model(inputs, outputs) 76 77 inputs = keras.Input(10) 78 outputs = keras.layers.Dense(1)(inputs) 79 80 class MyModel(keras.Model): 81 pass 82 83 model = MyModel(inputs, outputs) 84 model_class = model.__class__.__bases__[0].__bases__[0] 85 self._check_model_class(model_class) 86 self._check_layer_class(model) 87 88 def test_sequential_model(self): 89 model = keras.Sequential([keras.layers.Dense(1)]) 90 model_class = model.__class__.__bases__[0].__bases__[0] 91 self._check_model_class(model_class) 92 self._check_layer_class(model) 93 94 def test_subclass_model(self): 95 96 class MyModel(keras.Model): 97 98 def call(self, x): 99 return 2 * x 100 101 model = MyModel() 102 model_class = model.__class__.__bases__[0] 103 self._check_model_class(model_class) 104 self._check_layer_class(model) 105 106 def test_layer(self): 107 class IdentityLayer(base_layer.Layer): 108 """A layer that returns it's input. 109 110 Useful for testing a layer without a variable. 111 """ 112 113 def call(self, inputs): 114 return inputs 115 116 layer = IdentityLayer() 117 self._check_layer_class(layer) 118 119 def test_multiple_subclass_model(self): 120 121 class Model1(keras.Model): 122 pass 123 124 class Model2(Model1): 125 126 def call(self, x): 127 return 2 * x 128 129 model = Model2() 130 model_class = model.__class__.__bases__[0].__bases__[0] 131 self._check_model_class(model_class) 132 self._check_layer_class(model) 133 134 def test_user_provided_metaclass(self): 135 136 @six.add_metaclass(abc.ABCMeta) 137 class AbstractModel(keras.Model): 138 139 @abc.abstractmethod 140 def call(self, inputs): 141 """Calls the model.""" 142 143 class MyModel(AbstractModel): 144 145 def call(self, inputs): 146 return 2 * inputs 147 148 with self.assertRaisesRegex(TypeError, 'instantiate abstract class'): 149 AbstractModel() 150 151 model = MyModel() 152 model_class = model.__class__.__bases__[0].__bases__[0] 153 self._check_model_class(model_class) 154 self._check_layer_class(model) 155 156 def test_multiple_inheritance(self): 157 158 class Return2(object): 159 160 def return_2(self): 161 return 2 162 163 class MyModel(keras.Model, Return2): 164 165 def call(self, x): 166 return self.return_2() * x 167 168 model = MyModel() 169 bases = model.__class__.__bases__ 170 self._check_model_class(bases[0]) 171 self.assertEqual(bases[1], Return2) 172 self.assertEqual(model.return_2(), 2) 173 self._check_layer_class(model) 174 175 def test_fit_error(self): 176 if not ops.executing_eagerly_outside_functions(): 177 # Error only appears on the v2 class. 178 return 179 180 model = keras.Sequential([keras.layers.Dense(1)]) 181 model.compile('sgd', 'mse') 182 x, y = np.ones((10, 10)), np.ones((10, 1)) 183 with ops.get_default_graph().as_default(): 184 with self.assertRaisesRegex( 185 ValueError, 'instance was constructed with eager mode enabled'): 186 model.fit(x, y, batch_size=2) 187 188 189if __name__ == '__main__': 190 test.main() 191