1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Scikit-learn API wrapper.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python import keras 24from tensorflow.python.keras import testing_utils 25from tensorflow.python.keras.wrappers import scikit_learn 26from tensorflow.python.platform import test 27 28INPUT_DIM = 5 29HIDDEN_DIM = 5 30TRAIN_SAMPLES = 10 31TEST_SAMPLES = 5 32NUM_CLASSES = 2 33BATCH_SIZE = 5 34EPOCHS = 1 35 36 37def build_fn_clf(hidden_dim): 38 model = keras.models.Sequential() 39 model.add(keras.layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,))) 40 model.add(keras.layers.Activation('relu')) 41 model.add(keras.layers.Dense(hidden_dim)) 42 model.add(keras.layers.Activation('relu')) 43 model.add(keras.layers.Dense(NUM_CLASSES)) 44 model.add(keras.layers.Activation('softmax')) 45 model.compile( 46 optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy']) 47 return model 48 49 50def assert_classification_works(clf): 51 np.random.seed(42) 52 (x_train, y_train), (x_test, _) = testing_utils.get_test_data( 53 train_samples=TRAIN_SAMPLES, 54 test_samples=TEST_SAMPLES, 55 input_shape=(INPUT_DIM,), 56 num_classes=NUM_CLASSES) 57 58 clf.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS) 59 60 score = clf.score(x_train, y_train, batch_size=BATCH_SIZE) 61 assert np.isscalar(score) and np.isfinite(score) 62 63 preds = clf.predict(x_test, batch_size=BATCH_SIZE) 64 assert preds.shape == (TEST_SAMPLES,) 65 for prediction in np.unique(preds): 66 assert prediction in range(NUM_CLASSES) 67 68 proba = clf.predict_proba(x_test, batch_size=BATCH_SIZE) 69 assert proba.shape == (TEST_SAMPLES, NUM_CLASSES) 70 assert np.allclose(np.sum(proba, axis=1), np.ones(TEST_SAMPLES)) 71 72 73def build_fn_reg(hidden_dim): 74 model = keras.models.Sequential() 75 model.add(keras.layers.Dense(INPUT_DIM, input_shape=(INPUT_DIM,))) 76 model.add(keras.layers.Activation('relu')) 77 model.add(keras.layers.Dense(hidden_dim)) 78 model.add(keras.layers.Activation('relu')) 79 model.add(keras.layers.Dense(1)) 80 model.add(keras.layers.Activation('linear')) 81 model.compile( 82 optimizer='sgd', loss='mean_absolute_error', metrics=['accuracy']) 83 return model 84 85 86def assert_regression_works(reg): 87 np.random.seed(42) 88 (x_train, y_train), (x_test, _) = testing_utils.get_test_data( 89 train_samples=TRAIN_SAMPLES, 90 test_samples=TEST_SAMPLES, 91 input_shape=(INPUT_DIM,), 92 num_classes=NUM_CLASSES) 93 94 reg.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS) 95 96 score = reg.score(x_train, y_train, batch_size=BATCH_SIZE) 97 assert np.isscalar(score) and np.isfinite(score) 98 99 preds = reg.predict(x_test, batch_size=BATCH_SIZE) 100 assert preds.shape == (TEST_SAMPLES,) 101 102 103class ScikitLearnAPIWrapperTest(test.TestCase): 104 105 def test_classify_build_fn(self): 106 with self.cached_session(): 107 clf = scikit_learn.KerasClassifier( 108 build_fn=build_fn_clf, 109 hidden_dim=HIDDEN_DIM, 110 batch_size=BATCH_SIZE, 111 epochs=EPOCHS) 112 113 assert_classification_works(clf) 114 115 def test_classify_class_build_fn(self): 116 117 class ClassBuildFnClf(object): 118 119 def __call__(self, hidden_dim): 120 return build_fn_clf(hidden_dim) 121 122 with self.cached_session(): 123 clf = scikit_learn.KerasClassifier( 124 build_fn=ClassBuildFnClf(), 125 hidden_dim=HIDDEN_DIM, 126 batch_size=BATCH_SIZE, 127 epochs=EPOCHS) 128 129 assert_classification_works(clf) 130 131 def test_classify_inherit_class_build_fn(self): 132 133 class InheritClassBuildFnClf(scikit_learn.KerasClassifier): 134 135 def __call__(self, hidden_dim): 136 return build_fn_clf(hidden_dim) 137 138 with self.cached_session(): 139 clf = InheritClassBuildFnClf( 140 build_fn=None, 141 hidden_dim=HIDDEN_DIM, 142 batch_size=BATCH_SIZE, 143 epochs=EPOCHS) 144 145 assert_classification_works(clf) 146 147 def test_regression_build_fn(self): 148 with self.cached_session(): 149 reg = scikit_learn.KerasRegressor( 150 build_fn=build_fn_reg, 151 hidden_dim=HIDDEN_DIM, 152 batch_size=BATCH_SIZE, 153 epochs=EPOCHS) 154 155 assert_regression_works(reg) 156 157 def test_regression_class_build_fn(self): 158 159 class ClassBuildFnReg(object): 160 161 def __call__(self, hidden_dim): 162 return build_fn_reg(hidden_dim) 163 164 with self.cached_session(): 165 reg = scikit_learn.KerasRegressor( 166 build_fn=ClassBuildFnReg(), 167 hidden_dim=HIDDEN_DIM, 168 batch_size=BATCH_SIZE, 169 epochs=EPOCHS) 170 171 assert_regression_works(reg) 172 173 def test_regression_inherit_class_build_fn(self): 174 175 class InheritClassBuildFnReg(scikit_learn.KerasRegressor): 176 177 def __call__(self, hidden_dim): 178 return build_fn_reg(hidden_dim) 179 180 with self.cached_session(): 181 reg = InheritClassBuildFnReg( 182 build_fn=None, 183 hidden_dim=HIDDEN_DIM, 184 batch_size=BATCH_SIZE, 185 epochs=EPOCHS) 186 187 assert_regression_works(reg) 188 189 190if __name__ == '__main__': 191 test.main() 192