• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Keras utilities to split v1 and v2 classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import numpy as np
24import six
25
26from tensorflow.python import keras
27from tensorflow.python.framework import ops
28from tensorflow.python.keras import keras_parameterized
29from tensorflow.python.keras.engine import base_layer
30from tensorflow.python.keras.engine import base_layer_v1
31from tensorflow.python.keras.engine import training
32from tensorflow.python.keras.engine import training_v1
33from tensorflow.python.platform import test
34
35
@keras_parameterized.run_all_keras_modes
class SplitUtilsTest(keras_parameterized.TestCase):
  """Checks that Keras picks the v1 or v2 base classes for the active mode.

  Each test runs under every mode generated by `run_all_keras_modes`. When
  `ops.executing_eagerly_outside_functions()` is true, the helpers expect the
  v2 classes (`training.Model`, `base_layer.Layer`). Otherwise they expect the
  v1 classes (`training_v1.Model`, `base_layer_v1.Layer`).
  """

  def _check_model_class(self, model_class):
    """Asserts `model_class` is the Model class expected for the current mode.

    Args:
      model_class: A class object, recovered by the caller by walking the
        `__bases__` of a model instance's class.
    """
    if ops.executing_eagerly_outside_functions():
      self.assertEqual(model_class, training.Model)
    else:
      self.assertEqual(model_class, training_v1.Model)

  def _check_layer_class(self, layer):
    """Asserts `layer` is an instance of the Layer class for the current mode.

    In v2 mode the layer must additionally NOT be a v1 Layer instance.

    Args:
      layer: A layer or model instance.
    """
    if ops.executing_eagerly_outside_functions():
      self.assertIsInstance(layer, base_layer.Layer)
      self.assertNotIsInstance(layer, base_layer_v1.Layer)
    else:
      self.assertIsInstance(layer, base_layer_v1.Layer)

  def test_functional_model(self):
    """A plain functional `keras.Model` gets the mode-appropriate bases."""
    inputs = keras.Input(10)
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    self._check_model_class(model.__class__.__bases__[0])
    self._check_layer_class(model)

  def test_subclass_model_with_functional_init(self):
    """A `keras.Model` subclass built with (inputs, outputs) is checked."""
    inputs = keras.Input(10)
    outputs = keras.layers.Dense(1)(inputs)

    class MyModel(keras.Model):
      pass

    model = MyModel(inputs, outputs)
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)
    self._check_layer_class(model)

  def test_subclass_model_with_functional_init_interleaved_v1_functional(self):
    """Building a model inside a separate Graph first must not affect later.

    A functional model is constructed under `ops.Graph().as_default()`. The
    subclassed model constructed afterwards is still expected to get the base
    classes of the surrounding mode.
    """
    with ops.Graph().as_default():
      inputs = keras.Input(10)
      outputs = keras.layers.Dense(1)(inputs)
      _ = keras.Model(inputs, outputs)

    inputs = keras.Input(10)
    outputs = keras.layers.Dense(1)(inputs)

    class MyModel(keras.Model):
      pass

    model = MyModel(inputs, outputs)
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)
    self._check_layer_class(model)

  def test_sequential_model(self):
    """A `keras.Sequential` model gets the mode-appropriate bases."""
    model = keras.Sequential([keras.layers.Dense(1)])
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)
    self._check_layer_class(model)

  def test_subclass_model(self):
    """A `keras.Model` subclass defining only `call` is checked."""

    class MyModel(keras.Model):

      def call(self, x):
        return 2 * x

    model = MyModel()
    model_class = model.__class__.__bases__[0]
    self._check_model_class(model_class)
    self._check_layer_class(model)

  def test_layer(self):
    """A direct `base_layer.Layer` subclass gets the mode-appropriate base."""

    class IdentityLayer(base_layer.Layer):
      """A layer that returns its input.

      Useful for testing a layer without a variable.
      """

      def call(self, inputs):
        return inputs

    layer = IdentityLayer()
    self._check_layer_class(layer)

  def test_multiple_subclass_model(self):
    """Two levels of user subclassing of `keras.Model` are checked."""

    class Model1(keras.Model):
      pass

    class Model2(Model1):

      def call(self, x):
        return 2 * x

    model = Model2()
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)
    self._check_layer_class(model)

  def test_user_provided_metaclass(self):
    """A user-supplied `abc.ABCMeta` metaclass keeps working.

    The abstract class must refuse instantiation. Its concrete subclass must
    still get the mode-appropriate bases.
    """

    @six.add_metaclass(abc.ABCMeta)
    class AbstractModel(keras.Model):

      @abc.abstractmethod
      def call(self, inputs):
        """Calls the model."""

    class MyModel(AbstractModel):

      def call(self, inputs):
        return 2 * inputs

    with self.assertRaisesRegex(TypeError, 'instantiate abstract class'):
      AbstractModel()

    model = MyModel()
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)
    self._check_layer_class(model)

  def test_multiple_inheritance(self):
    """A mixin listed after `keras.Model` stays in the bases and is usable."""

    class Return2(object):

      def return_2(self):
        return 2

    class MyModel(keras.Model, Return2):

      def call(self, x):
        return self.return_2() * x

    model = MyModel()
    bases = model.__class__.__bases__
    self._check_model_class(bases[0])
    self.assertEqual(bases[1], Return2)
    self.assertEqual(model.return_2(), 2)
    self._check_layer_class(model)

  def test_fit_error(self):
    """Fitting an eagerly-built model inside a graph context raises."""
    if not ops.executing_eagerly_outside_functions():
      # Report the v1 variant as skipped rather than as a vacuous pass.
      self.skipTest('Error only appears on the v2 class.')

    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile('sgd', 'mse')
    x, y = np.ones((10, 10)), np.ones((10, 1))
    with ops.get_default_graph().as_default():
      with self.assertRaisesRegex(
          ValueError, 'instance was constructed with eager mode enabled'):
        model.fit(x, y, batch_size=2)
187
188
def _main():
  """Runs this module's test cases through TensorFlow's test entry point."""
  test.main()


if __name__ == '__main__':
  _main()
191