# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for using the Scikit-Learn API with Keras models."""
# pylint: disable=g-classes-have-attributes

import copy
import types

import numpy as np

from tensorflow.python.keras import losses
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.utils.generic_utils import has_arg
from tensorflow.python.keras.utils.np_utils import to_categorical
from tensorflow.python.util.tf_export import keras_export


class BaseWrapper(object):
  """Base class for the Keras scikit-learn wrapper.

  Warning: This class should not be used directly.
  Use descendant classes instead.

  Args:
    build_fn: callable function or class instance
    **sk_params: model parameters & fitting parameters

  The `build_fn` should construct, compile and return a Keras model, which
  will then be used to fit/predict. One of the following
  three values could be passed to `build_fn`:
  1. A function
  2. An instance of a class that implements the `__call__` method
  3. None. This means you implement a class that inherits from either
  `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
  present class will then be treated as the default `build_fn`.

  `sk_params` takes both model parameters and fitting parameters. Legal model
  parameters are the arguments of `build_fn`. Note that like all other
  estimators in scikit-learn, `build_fn` should provide default values for
  its arguments, so that you could create the estimator without passing any
  values to `sk_params`.

  `sk_params` could also accept parameters for calling `fit`, `predict`,
  `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`).
  Fitting (predicting) parameters are selected in the following order:

  1. Values passed to the dictionary arguments of
  `fit`, `predict`, `predict_proba`, and `score` methods
  2. Values passed to `sk_params`
  3. The default values of the `keras.models.Sequential`
  `fit`, `predict`, `predict_proba` and `score` methods

  When using scikit-learn's hyperparameter search API (e.g. `GridSearchCV`),
  legal tunable parameters are those you could pass to `sk_params`, including
  fitting parameters. In other words, you could use it to search for the best
  `batch_size` or `epochs` as well as the model parameters.

  Attributes:
    model: the Keras model built by `build_fn`; set by `fit`.
  """

  def __init__(self, build_fn=None, **sk_params):
    self.build_fn = build_fn
    self.sk_params = sk_params
    self.check_params(sk_params)

  def check_params(self, params):
    """Checks for user typos in `params`.

    A name is legal if it is an argument of `Sequential.fit`,
    `Sequential.predict`, `Sequential.predict_classes`, `Sequential.evaluate`,
    or of the model-building callable (see the class docstring).

    Args:
      params: dictionary; the parameters to be checked

    Raises:
      ValueError: if any member of `params` is not a valid argument.
    """
    legal_params_fns = [
        Sequential.fit, Sequential.predict, Sequential.predict_classes,
        Sequential.evaluate
    ]
    if self.build_fn is None:
      legal_params_fns.append(self.__call__)
    elif (not isinstance(self.build_fn, types.FunctionType) and
          not isinstance(self.build_fn, types.MethodType)):
      # A callable object: inspect the signature of its `__call__`.
      legal_params_fns.append(self.build_fn.__call__)
    else:
      legal_params_fns.append(self.build_fn)

    for params_name in params:
      for fn in legal_params_fns:
        if has_arg(fn, params_name):
          break
      else:
        # `nb_epoch` is explicitly tolerated; presumably the legacy name for
        # `epochs` kept for backward compatibility -- TODO confirm.
        if params_name != 'nb_epoch':
          raise ValueError('{} is not a legal parameter'.format(params_name))

  def get_params(self, **params):  # pylint: disable=unused-argument
    """Gets parameters for this estimator.

    Args:
      **params: ignored (exists for API compatibility).

    Returns:
      Dictionary of parameter names mapped to their values.
    """
    res = self.sk_params.copy()
    res.update({'build_fn': self.build_fn})
    return res

  def set_params(self, **params):
    """Sets the parameters of this estimator.

    Args:
      **params: Dictionary of parameter names mapped to their values.

    Returns:
      self

    Raises:
      ValueError: if any member of `params` is not a valid argument.
    """
    self.check_params(params)
    self.sk_params.update(params)
    return self

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Args:
      x : array-like, shape `(n_samples, n_features)`
        Training samples where `n_samples` is the number of samples
        and `n_features` is the number of features.
      y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
        True labels for `x`.
      **kwargs: dictionary arguments
        Legal arguments are the arguments of `Sequential.fit`

    Returns:
      history : object
        details about the training history at each epoch.
    """
    if self.build_fn is None:
      self.model = self.__call__(**self.filter_sk_params(self.__call__))
    elif (not isinstance(self.build_fn, types.FunctionType) and
          not isinstance(self.build_fn, types.MethodType)):
      self.model = self.build_fn(
          **self.filter_sk_params(self.build_fn.__call__))
    else:
      self.model = self.build_fn(**self.filter_sk_params(self.build_fn))

    # Categorical cross-entropy needs one-hot targets; convert integer labels.
    # `np.ndim` also works when `y` is a plain list (no `.shape` attribute).
    if (losses.is_categorical_crossentropy(self.model.loss) and
        np.ndim(y) != 2):
      y = to_categorical(y)

    # Deep-copy so that `Sequential.fit` cannot mutate objects held in
    # `sk_params` across repeated fits.
    fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
    fit_args.update(kwargs)

    history = self.model.fit(x, y, **fit_args)

    return history

  def filter_sk_params(self, fn, override=None):
    """Filters `sk_params` and returns those in `fn`'s arguments.

    Args:
      fn : arbitrary function
      override: dictionary, values to override `sk_params`

    Returns:
      res : dictionary containing variables
        in both `sk_params` and `fn`'s arguments.
    """
    override = override or {}
    res = {}
    for name, value in self.sk_params.items():
      if has_arg(fn, name):
        res.update({name: value})
    res.update(override)
    return res


@keras_export('keras.wrappers.scikit_learn.KerasClassifier')
class KerasClassifier(BaseWrapper):
  """Implementation of the scikit-learn classifier API for Keras.

  Attributes:
    classes_: array of class labels seen by `fit`.
    n_classes_: number of entries in `classes_`.
  """

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Args:
      x : array-like, shape `(n_samples, n_features)`
        Training samples where `n_samples` is the number of samples
        and `n_features` is the number of features.
      y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
        True labels for `x`.
      **kwargs: dictionary arguments
        Legal arguments are the arguments of `Sequential.fit`

    Returns:
      history : object
        details about the training history at each epoch.

    Raises:
      ValueError: In case of invalid shape for `y` argument.
    """
    y = np.array(y)
    if len(y.shape) == 2 and y.shape[1] > 1:
      # Multi-column targets: classes are the column indices.
      self.classes_ = np.arange(y.shape[1])
    elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
      # Label vector: map arbitrary sorted labels to indices 0..n_classes-1.
      self.classes_ = np.unique(y)
      y = np.searchsorted(self.classes_, y)
    else:
      raise ValueError('Invalid shape for y: ' + str(y.shape))
    self.n_classes_ = len(self.classes_)
    return super(KerasClassifier, self).fit(x, y, **kwargs)

  def predict(self, x, **kwargs):
    """Returns the class predictions for the given test data.

    Args:
      x: array-like, shape `(n_samples, n_features)`
        Test samples where `n_samples` is the number of samples
        and `n_features` is the number of features.
      **kwargs: dictionary arguments
        Legal arguments are the arguments
        of `Sequential.predict_classes`.

    Returns:
      preds: array-like, shape `(n_samples,)`
        Class predictions.
    """
    kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
    classes = self.model.predict_classes(x, **kwargs)
    # Translate predicted indices back to the original labels.
    return self.classes_[classes]

  def predict_proba(self, x, **kwargs):
    """Returns class probability estimates for the given test data.

    Args:
      x: array-like, shape `(n_samples, n_features)`
        Test samples where `n_samples` is the number of samples
        and `n_features` is the number of features.
      **kwargs: dictionary arguments
        Legal arguments are the arguments
        of `Sequential.predict_proba`.

    Returns:
      proba: array-like, shape `(n_samples, n_outputs)`
        Class probability estimates.
        In the case of binary classification,
        to match the scikit-learn API,
        will return an array of shape `(n_samples, 2)`
        (instead of `(n_samples, 1)` as in Keras).
    """
    kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
    probs = self.model.predict(x, **kwargs)

    # check if binary classification
    if probs.shape[1] == 1:
      # first column is probability of class 0 and second is of class 1
      probs = np.hstack([1 - probs, probs])
    return probs

  def score(self, x, y, **kwargs):
    """Returns the mean accuracy on the given test data and labels.

    Args:
      x: array-like, shape `(n_samples, n_features)`
        Test samples where `n_samples` is the number of samples
        and `n_features` is the number of features.
      y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
        True labels for `x`.
      **kwargs: dictionary arguments
        Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
      score: float
        Mean accuracy of predictions on `x` wrt. `y`.

    Raises:
      ValueError: If the underlying model isn't configured to
        compute accuracy. You should pass `metrics=["accuracy"]` to
        the `.compile()` method of the model.
    """
    # Map labels to the same index space used during `fit`.
    y = np.searchsorted(self.classes_, y)
    kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)

    # Use the same check as `BaseWrapper.fit`, so that targets are one-hot
    # encoded here exactly when they were during training (this also covers
    # loss objects such as `losses.CategoricalCrossentropy()`).
    if (losses.is_categorical_crossentropy(self.model.loss) and
        np.ndim(y) != 2):
      y = to_categorical(y)

    outputs = self.model.evaluate(x, y, **kwargs)
    if not isinstance(outputs, list):
      outputs = [outputs]
    for name, output in zip(self.model.metrics_names, outputs):
      if name in ['accuracy', 'acc']:
        return output
    raise ValueError('The model is not configured to compute accuracy. '
                     'You should pass `metrics=["accuracy"]` to '
                     'the `model.compile()` method.')


@keras_export('keras.wrappers.scikit_learn.KerasRegressor')
class KerasRegressor(BaseWrapper):
  """Implementation of the scikit-learn regressor API for Keras.
  """

  def predict(self, x, **kwargs):
    """Returns predictions for the given test data.

    Args:
      x: array-like, shape `(n_samples, n_features)`
        Test samples where `n_samples` is the number of samples
        and `n_features` is the number of features.
      **kwargs: dictionary arguments
        Legal arguments are the arguments of `Sequential.predict`.

    Returns:
      preds: array-like, shape `(n_samples,)`
        Predictions. All size-1 axes are squeezed away, so a single-sample
        input yields a 0-d array.
    """
    kwargs = self.filter_sk_params(Sequential.predict, kwargs)
    return np.squeeze(self.model.predict(x, **kwargs))

  def score(self, x, y, **kwargs):
    """Returns the negated mean loss on the given test data and labels.

    Args:
      x: array-like, shape `(n_samples, n_features)`
        Test samples where `n_samples` is the number of samples
        and `n_features` is the number of features.
      y: array-like, shape `(n_samples,)`
        True labels for `x`.
      **kwargs: dictionary arguments
        Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
      score: float
        Negated loss of the model on `x` wrt. `y`. The loss is negated so
        that larger values are better, in line with scikit-learn's
        convention for `score`.
    """
    kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
    loss = self.model.evaluate(x, y, **kwargs)
    # With extra metrics compiled in, `evaluate` returns a list whose first
    # entry is the loss.
    if isinstance(loss, list):
      return -loss[0]
    return -loss