1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the 'License'); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an 'AS IS' BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras utilities to split v1 and v2 classes.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22 23import numpy as np 24import six 25 26from tensorflow.python import keras 27from tensorflow.python.eager import context 28from tensorflow.python.framework import ops 29from tensorflow.python.keras import keras_parameterized 30from tensorflow.python.keras.engine import training 31from tensorflow.python.keras.engine import training_v1 32from tensorflow.python.platform import test 33 34 35@keras_parameterized.run_all_keras_modes 36class SplitUtilsTest(keras_parameterized.TestCase): 37 38 def _check_model_class(self, model_class): 39 if ops.executing_eagerly_outside_functions(): 40 self.assertEqual(model_class, training.Model) 41 else: 42 self.assertEqual(model_class, training_v1.Model) 43 44 def test_functional_model(self): 45 inputs = keras.Input(10) 46 outputs = keras.layers.Dense(1)(inputs) 47 model = keras.Model(inputs, outputs) 48 self._check_model_class(model.__class__) 49 50 def test_sequential_model(self): 51 model = keras.Sequential([keras.layers.Dense(1)]) 52 model_class = model.__class__.__bases__[0] 53 self._check_model_class(model_class) 54 55 def test_subclass_model(self): 56 57 class MyModel(keras.Model): 58 59 def call(self, x): 60 return 2 * x 61 62 model = MyModel() 63 model_class = model.__class__.__bases__[0] 64 self._check_model_class(model_class) 65 66 def test_multiple_subclass_model(self): 67 68 class Model1(keras.Model): 69 pass 70 71 class Model2(Model1): 72 73 def call(self, x): 74 return 2 * x 75 76 model = Model2() 77 model_class = model.__class__.__bases__[0].__bases__[0] 78 self._check_model_class(model_class) 79 80 def test_user_provided_metaclass(self): 81 82 @six.add_metaclass(abc.ABCMeta) 83 class AbstractModel(keras.Model): 84 85 @abc.abstractmethod 86 def call(self, inputs): 87 """Calls the model.""" 88 89 class MyModel(AbstractModel): 90 91 def call(self, inputs): 92 return 2 * inputs 93 94 with self.assertRaisesRegexp(TypeError, 'instantiate abstract class'): 95 AbstractModel() 96 97 model = MyModel() 98 model_class = model.__class__.__bases__[0].__bases__[0] 99 self._check_model_class(model_class) 100 101 def test_multiple_inheritance(self): 102 103 class Return2(object): 104 105 def return_2(self): 106 return 2 107 108 class MyModel(keras.Model, Return2): 109 110 def call(self, x): 111 return self.return_2() * x 112 113 model = MyModel() 114 bases = model.__class__.__bases__ 115 self._check_model_class(bases[0]) 116 self.assertEqual(bases[1], Return2) 117 self.assertEqual(model.return_2(), 2) 118 119 def test_fit_error(self): 120 if not ops.executing_eagerly_outside_functions(): 121 # Error only appears on the v2 class. 122 return 123 124 model = keras.Sequential([keras.layers.Dense(1)]) 125 model.compile('sgd', 'mse') 126 x, y = np.ones((10, 10)), np.ones((10, 1)) 127 with context.graph_mode(): 128 with self.assertRaisesRegexp( 129 ValueError, 'instance was constructed with eager mode enabled'): 130 model.fit(x, y, batch_size=2) 131 132 133if __name__ == '__main__': 134 test.main() 135