• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Scikit-learn API wrapper."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python import keras
24from tensorflow.python.keras import testing_utils
25from tensorflow.python.keras.wrappers import scikit_learn
26from tensorflow.python.platform import test
27
# Shared fixture sizes for the wrapper tests in this module.
# Width of each input sample; also used as the unit count of the first Dense
# layer in both build functions.
INPUT_DIM = 5
# Passed to the wrappers as `hidden_dim` and forwarded to the build functions.
HIDDEN_DIM = 5
# Number of generated training / test samples.
TRAIN_SAMPLES = 10
TEST_SAMPLES = 5
# Number of target classes; also the width of the classifier's output layer.
NUM_CLASSES = 2
# Batch size and epoch count used for fit/score/predict calls.
BATCH_SIZE = 5
EPOCHS = 1
35
36
def build_fn_clf(hidden_dim):
  """Builds and compiles a small softmax classification model.

  Args:
    hidden_dim: Number of units in the second `Dense` layer.

  Returns:
    A compiled `keras.models.Sequential` model taking inputs of shape
    `(INPUT_DIM,)` and ending in a `NUM_CLASSES`-wide softmax.
  """
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,)))
  model.add(keras.layers.Activation('relu'))
  model.add(keras.layers.Dense(hidden_dim))
  model.add(keras.layers.Activation('relu'))
  model.add(keras.layers.Dense(NUM_CLASSES))
  model.add(keras.layers.Activation('softmax'))
  model.compile(
      optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
  return model
48
49
def assert_classification_works(clf):
  """Fits `clf` on generated data and sanity-checks its outputs.

  Args:
    clf: An object exposing `fit`, `score`, `predict` and `predict_proba`;
      the tests in this module pass `scikit_learn.KerasClassifier` instances.

  Raises:
    AssertionError: If the score is not a finite scalar, or if the predictions
      or probabilities have an unexpected shape or value range.
  """
  # Seed NumPy before generating data, as the original test does.
  np.random.seed(42)
  (x_train, y_train), (x_test, _) = testing_utils.get_test_data(
      train_samples=TRAIN_SAMPLES,
      test_samples=TEST_SAMPLES,
      input_shape=(INPUT_DIM,),
      num_classes=NUM_CLASSES)

  clf.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS)

  score = clf.score(x_train, y_train, batch_size=BATCH_SIZE)
  assert np.isscalar(score) and np.isfinite(score)

  # One predicted label per test sample, each a valid class index.
  preds = clf.predict(x_test, batch_size=BATCH_SIZE)
  assert preds.shape == (TEST_SAMPLES,)
  for prediction in np.unique(preds):
    assert prediction in range(NUM_CLASSES)

  # One probability row per test sample; each row must sum to 1.
  proba = clf.predict_proba(x_test, batch_size=BATCH_SIZE)
  assert proba.shape == (TEST_SAMPLES, NUM_CLASSES)
  assert np.allclose(np.sum(proba, axis=1), np.ones(TEST_SAMPLES))
71
72
def build_fn_reg(hidden_dim):
  """Builds and compiles a small single-output regression model.

  Args:
    hidden_dim: Number of units in the second `Dense` layer.

  Returns:
    A compiled `keras.models.Sequential` model taking inputs of shape
    `(INPUT_DIM,)` and ending in one linear output unit.
  """
  model = keras.models.Sequential()
  model.add(keras.layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,)))
  model.add(keras.layers.Activation('relu'))
  model.add(keras.layers.Dense(hidden_dim))
  model.add(keras.layers.Activation('relu'))
  model.add(keras.layers.Dense(1))
  model.add(keras.layers.Activation('linear'))
  model.compile(
      optimizer='sgd', loss='mean_absolute_error', metrics=['accuracy'])
  return model
84
85
def assert_regression_works(reg):
  """Fits `reg` on generated data and sanity-checks its outputs.

  Args:
    reg: An object exposing `fit`, `score` and `predict`; the tests in this
      module pass `scikit_learn.KerasRegressor` instances.

  Raises:
    AssertionError: If the score is not a finite scalar, or if the predictions
      do not have shape `(TEST_SAMPLES,)`.
  """
  # Seed NumPy before generating data, as the original test does.
  np.random.seed(42)
  (x_train, y_train), (x_test, _) = testing_utils.get_test_data(
      train_samples=TRAIN_SAMPLES,
      test_samples=TEST_SAMPLES,
      input_shape=(INPUT_DIM,),
      num_classes=NUM_CLASSES)

  reg.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS)

  score = reg.score(x_train, y_train, batch_size=BATCH_SIZE)
  assert np.isscalar(score) and np.isfinite(score)

  # One predicted value per test sample.
  preds = reg.predict(x_test, batch_size=BATCH_SIZE)
  assert preds.shape == (TEST_SAMPLES,)
101
102
class ScikitLearnAPIWrapperTest(test.TestCase):
  """Tests the Keras scikit-learn wrappers with three kinds of `build_fn`.

  Each wrapper (classifier and regressor) is built from a plain function, from
  an instance of a class defining `__call__`, and from a wrapper subclass that
  defines `__call__` and is constructed with `build_fn=None`.
  """

  def test_classify_build_fn(self):
    """Classifier built from a plain function."""
    with self.cached_session():
      clf = scikit_learn.KerasClassifier(
          build_fn=build_fn_clf,
          hidden_dim=HIDDEN_DIM,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)

      assert_classification_works(clf)

  def test_classify_class_build_fn(self):
    """Classifier built from an instance of a callable class."""

    class ClassBuildFnClf(object):

      def __call__(self, hidden_dim):
        return build_fn_clf(hidden_dim)

    with self.cached_session():
      clf = scikit_learn.KerasClassifier(
          build_fn=ClassBuildFnClf(),
          hidden_dim=HIDDEN_DIM,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)

      assert_classification_works(clf)

  def test_classify_inherit_class_build_fn(self):
    """Classifier subclass defining `__call__`, with `build_fn=None`."""

    class InheritClassBuildFnClf(scikit_learn.KerasClassifier):

      def __call__(self, hidden_dim):
        return build_fn_clf(hidden_dim)

    with self.cached_session():
      clf = InheritClassBuildFnClf(
          build_fn=None,
          hidden_dim=HIDDEN_DIM,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)

      assert_classification_works(clf)

  def test_regression_build_fn(self):
    """Regressor built from a plain function."""
    with self.cached_session():
      reg = scikit_learn.KerasRegressor(
          build_fn=build_fn_reg,
          hidden_dim=HIDDEN_DIM,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)

      assert_regression_works(reg)

  def test_regression_class_build_fn(self):
    """Regressor built from an instance of a callable class."""

    class ClassBuildFnReg(object):

      def __call__(self, hidden_dim):
        return build_fn_reg(hidden_dim)

    with self.cached_session():
      reg = scikit_learn.KerasRegressor(
          build_fn=ClassBuildFnReg(),
          hidden_dim=HIDDEN_DIM,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)

      assert_regression_works(reg)

  def test_regression_inherit_class_build_fn(self):
    """Regressor subclass defining `__call__`, with `build_fn=None`."""

    class InheritClassBuildFnReg(scikit_learn.KerasRegressor):

      def __call__(self, hidden_dim):
        return build_fn_reg(hidden_dim)

    with self.cached_session():
      reg = InheritClassBuildFnReg(
          build_fn=None,
          hidden_dim=HIDDEN_DIM,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS)

      assert_regression_works(reg)
188
189
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  test.main()
192