• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Keras utilities to split v1 and v2 classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22
23import numpy as np
24import six
25
26from tensorflow.python import keras
27from tensorflow.python.eager import context
28from tensorflow.python.framework import ops
29from tensorflow.python.keras import keras_parameterized
30from tensorflow.python.keras.engine import training
31from tensorflow.python.keras.engine import training_v1
32from tensorflow.python.platform import test
33
34
@keras_parameterized.run_all_keras_modes
class SplitUtilsTest(keras_parameterized.TestCase):
  """Tests that Keras models resolve to the v1 or v2 `Model` base class.

  Each test locates the Keras `Model` base of some model and verifies it with
  `_check_model_class`. When ops execute eagerly outside of functions, the v2
  `training.Model` is expected; otherwise the legacy `training_v1.Model` is.
  The class is decorated with `run_all_keras_modes`, so each test is exercised
  under the Keras execution modes and both branches get covered.
  """

  def _check_model_class(self, model_class):
    """Asserts that `model_class` is the `Model` class for the current mode.

    Args:
      model_class: The class under test, usually obtained by walking a
        model's `__bases__` up to its Keras `Model` base.
    """
    if ops.executing_eagerly_outside_functions():
      expected_class = training.Model
    else:
      expected_class = training_v1.Model
    # Classes compare by identity, so `assertIs` states the intent exactly.
    self.assertIs(model_class, expected_class)

  def test_functional_model(self):
    """A functional model's own class is the mode-appropriate `Model`."""
    inputs = keras.Input(shape=(10,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    self._check_model_class(model.__class__)

  def test_sequential_model(self):
    """The first base of a `Sequential` model's class is the right `Model`."""
    model = keras.Sequential([keras.layers.Dense(1)])
    model_class = model.__class__.__bases__[0]
    self._check_model_class(model_class)

  def test_subclass_model(self):
    """A direct user subclass of `keras.Model` gets the right base."""

    class MyModel(keras.Model):

      def call(self, x):
        return 2 * x

    model = MyModel()
    model_class = model.__class__.__bases__[0]
    self._check_model_class(model_class)

  def test_multiple_subclass_model(self):
    """An intermediate user subclass still resolves to the right `Model`."""

    class Model1(keras.Model):
      pass

    class Model2(Model1):

      def call(self, x):
        return 2 * x

    model = Model2()
    # Walk Model2 -> Model1 -> Keras `Model` base.
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)

  def test_user_provided_metaclass(self):
    """A user-supplied `ABCMeta` metaclass keeps working on Keras models.

    Abstract-method enforcement must survive: the abstract base cannot be
    instantiated, while a concrete subclass resolves to the right `Model`.
    """

    @six.add_metaclass(abc.ABCMeta)
    class AbstractModel(keras.Model):

      @abc.abstractmethod
      def call(self, inputs):
        """Calls the model."""

    class MyModel(AbstractModel):

      def call(self, inputs):
        return 2 * inputs

    # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12).
    with self.assertRaisesRegex(TypeError, 'instantiate abstract class'):
      AbstractModel()

    model = MyModel()
    # Walk MyModel -> AbstractModel -> Keras `Model` base.
    model_class = model.__class__.__bases__[0].__bases__[0]
    self._check_model_class(model_class)

  def test_multiple_inheritance(self):
    """A non-Keras mixin stays in the bases and its methods stay callable."""

    class Return2(object):

      def return_2(self):
        return 2

    class MyModel(keras.Model, Return2):

      def call(self, x):
        return self.return_2() * x

    model = MyModel()
    bases = model.__class__.__bases__
    self._check_model_class(bases[0])
    self.assertIs(bases[1], Return2)
    self.assertEqual(model.return_2(), 2)

  def test_fit_error(self):
    """Fitting an eagerly constructed model inside graph mode raises."""
    if not ops.executing_eagerly_outside_functions():
      # Report this mode as skipped instead of silently passing.
      self.skipTest('Error only appears on the v2 class.')

    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile('sgd', 'mse')
    x, y = np.ones((10, 10)), np.ones((10, 1))
    with context.graph_mode():
      with self.assertRaisesRegex(
          ValueError, 'instance was constructed with eager mode enabled'):
        model.fit(x, y, batch_size=2)
131
132
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test runner.
  test.main()
135