• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Wrapper for using the Scikit-Learn API with Keras models."""
16# pylint: disable=g-classes-have-attributes
17
18import copy
19import types
20
21import numpy as np
22
23from tensorflow.python.keras import losses
24from tensorflow.python.keras.models import Sequential
25from tensorflow.python.keras.utils.generic_utils import has_arg
26from tensorflow.python.keras.utils.np_utils import to_categorical
27from tensorflow.python.util.tf_export import keras_export
28
29
class BaseWrapper(object):
  """Base class for the Keras scikit-learn wrapper.

  Warning: This class should not be used directly.
  Use descendant classes instead.

  Args:
      build_fn: callable function or class instance
      **sk_params: model parameters & fitting parameters

  The `build_fn` should construct, compile and return a Keras model, which
  will then be used to fit/predict. One of the following
  three values could be passed to `build_fn`:
  1. A function
  2. An instance of a class that implements the `__call__` method
  3. None. This means you implement a class that inherits from either
  `KerasClassifier` or `KerasRegressor`. The `__call__` method of the
  present class will then be treated as the default `build_fn`.

  `sk_params` takes both model parameters and fitting parameters. Legal model
  parameters are the arguments of `build_fn`. Note that like all other
  estimators in scikit-learn, `build_fn` should provide default values for
  its arguments, so that you could create the estimator without passing any
  values to `sk_params`.

  `sk_params` could also accept parameters for calling `fit`, `predict`,
  `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`).
  Fitting (predicting) parameters are selected in the following order:

  1. Values passed to the dictionary arguments of
  `fit`, `predict`, `predict_proba`, and `score` methods
  2. Values passed to `sk_params`
  3. The default values of the `keras.models.Sequential`
  `fit`, `predict`, `predict_proba` and `score` methods

  When using scikit-learn's `grid_search` API, legal tunable parameters are
  those you could pass to `sk_params`, including fitting parameters.
  In other words, you could use `grid_search` to search for the best
  `batch_size` or `epochs` as well as the model parameters.
  """

  def __init__(self, build_fn=None, **sk_params):
    self.build_fn = build_fn
    self.sk_params = sk_params
    # Fail fast on misspelled parameter names.
    self.check_params(sk_params)

  def _build_fn_signature_source(self):
    """Returns the callable whose arguments define the legal model parameters.

    Returns:
        `self.__call__` when `build_fn` is None, `build_fn` itself when it is
        a plain function or method, and `build_fn.__call__` otherwise.
    """
    if self.build_fn is None:
      return self.__call__
    if isinstance(self.build_fn, (types.FunctionType, types.MethodType)):
      return self.build_fn
    return self.build_fn.__call__

  def check_params(self, params):
    """Checks for user typos in `params`.

    Args:
        params: dictionary; the parameters to be checked

    Raises:
        ValueError: if any member of `params` is not a valid argument.
    """
    legal_params_fns = [
        Sequential.fit, Sequential.predict, Sequential.predict_classes,
        Sequential.evaluate
    ]
    legal_params_fns.append(self._build_fn_signature_source())

    for params_name in params:
      if any(has_arg(fn, params_name) for fn in legal_params_fns):
        continue
      # `nb_epoch` is tolerated even though no legal function accepts it.
      if params_name != 'nb_epoch':
        raise ValueError('{} is not a legal parameter'.format(params_name))

  def get_params(self, **params):  # pylint: disable=unused-argument
    """Gets parameters for this estimator.

    Args:
        **params: ignored (exists for API compatibility).

    Returns:
        Dictionary of parameter names mapped to their values. It is a
        shallow copy of `sk_params` with `build_fn` added.
    """
    res = self.sk_params.copy()
    res['build_fn'] = self.build_fn
    return res

  def set_params(self, **params):
    """Sets the parameters of this estimator.

    Args:
        **params: Dictionary of parameter names mapped to their values.

    Returns:
        self

    Raises:
        ValueError: if any member of `params` is not a valid argument.
    """
    self.check_params(params)
    self.sk_params.update(params)
    return self

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Args:
        x : array-like, shape `(n_samples, n_features)`
            Training samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.fit`

    Returns:
        history : object
            details about the training history at each epoch.
    """
    # Model arguments are the `sk_params` entries accepted by the builder's
    # signature; the builder itself is always invoked as a plain call.
    model_args = self.filter_sk_params(self._build_fn_signature_source())
    builder = self.__call__ if self.build_fn is None else self.build_fn
    self.model = builder(**model_args)

    if losses.is_categorical_crossentropy(self.model.loss):
      # `y` is documented as array-like, so it may be a list without `.shape`.
      y_rank = len(y.shape) if hasattr(y, 'shape') else np.ndim(y)
      if y_rank != 2:
        y = to_categorical(y)

    # Deep-copy so per-call kwargs never leak back into `sk_params` values.
    fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
    fit_args.update(kwargs)

    return self.model.fit(x, y, **fit_args)

  def filter_sk_params(self, fn, override=None):
    """Filters `sk_params` and returns those in `fn`'s arguments.

    Args:
        fn : arbitrary function
        override: dictionary, values to override `sk_params`

    Returns:
        res : dictionary containing variables
            in both `sk_params` and `fn`'s arguments, updated with all
            entries of `override`.
    """
    res = {
        name: value
        for name, value in self.sk_params.items()
        if has_arg(fn, name)
    }
    res.update(override or {})
    return res
185
186
@keras_export('keras.wrappers.scikit_learn.KerasClassifier')
class KerasClassifier(BaseWrapper):
  """Implementation of the scikit-learn classifier API for Keras.

  After `fit`, the instance exposes `classes_` (the sorted unique labels, or
  the column indices when `y` has more than one column) and `n_classes_`.
  """

  def fit(self, x, y, **kwargs):
    """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

    Args:
        x : array-like, shape `(n_samples, n_features)`
            Training samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.fit`

    Returns:
        history : object
            details about the training history at each epoch.

    Raises:
        ValueError: In case of invalid shape for `y` argument.
    """
    y = np.array(y)
    if len(y.shape) == 2 and y.shape[1] > 1:
      # Multi-column targets: classes are the column indices.
      self.classes_ = np.arange(y.shape[1])
    elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
      # Label targets: map each label to its index in the sorted unique set.
      self.classes_ = np.unique(y)
      y = np.searchsorted(self.classes_, y)
    else:
      raise ValueError('Invalid shape for y: ' + str(y.shape))
    self.n_classes_ = len(self.classes_)
    return super(KerasClassifier, self).fit(x, y, **kwargs)

  def predict(self, x, **kwargs):
    """Returns the class predictions for the given test data.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments
            of `Sequential.predict_classes`.

    Returns:
        preds: array-like, shape `(n_samples,)`
            Class predictions.
    """
    kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
    classes = self.model.predict_classes(x, **kwargs)
    # Translate predicted indices back to the labels recorded by `fit`.
    return self.classes_[classes]

  def predict_proba(self, x, **kwargs):
    """Returns class probability estimates for the given test data.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments
            of `Sequential.predict_proba`.

    Returns:
        proba: array-like, shape `(n_samples, n_outputs)`
            Class probability estimates.
            In the case of binary classification,
            to match the scikit-learn API,
            will return an array of shape `(n_samples, 2)`
            (instead of `(n_sample, 1)` as in Keras).
    """
    kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
    probs = self.model.predict(x, **kwargs)

    # check if binary classification
    if probs.shape[1] == 1:
      # first column is probability of class 0 and second is of class 1
      probs = np.hstack([1 - probs, probs])
    return probs

  def score(self, x, y, **kwargs):
    """Returns the mean accuracy on the given test data and labels.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
        score: float
            Mean accuracy of predictions on `x` wrt. `y`.

    Raises:
        ValueError: If the underlying model isn't configured to
            compute accuracy. You should pass `metrics=["accuracy"]` to
            the `.compile()` method of the model.
    """
    y = np.searchsorted(self.classes_, y)
    kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)

    # Use the same loss detection as `BaseWrapper.fit` so targets are encoded
    # identically for training and evaluation.
    if (losses.is_categorical_crossentropy(self.model.loss) and
        len(y.shape) != 2):
      y = to_categorical(y)

    outputs = self.model.evaluate(x, y, **kwargs)
    if not isinstance(outputs, list):
      outputs = [outputs]
    for name, output in zip(self.model.metrics_names, outputs):
      if name in ['accuracy', 'acc']:
        return output
    raise ValueError('The model is not configured to compute accuracy. '
                     'You should pass `metrics=["accuracy"]` to '
                     'the `model.compile()` method.')
308
309
@keras_export('keras.wrappers.scikit_learn.KerasRegressor')
class KerasRegressor(BaseWrapper):
  """Scikit-learn style regressor backed by a Keras model."""

  def predict(self, x, **kwargs):
    """Computes predictions for the given test data.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.predict`.

    Returns:
        preds: array-like
            Output of the model's `predict`, passed through `np.squeeze`.
    """
    predict_args = self.filter_sk_params(Sequential.predict, kwargs)
    raw_preds = self.model.predict(x, **predict_args)
    return np.squeeze(raw_preds)

  def score(self, x, y, **kwargs):
    """Evaluates the model and returns the negated loss.

    Args:
        x: array-like, shape `(n_samples, n_features)`
            Test samples where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y: array-like, shape `(n_samples,)`
            True labels for `x`.
        **kwargs: dictionary arguments
            Legal arguments are the arguments of `Sequential.evaluate`.

    Returns:
        score: float
            Negative of the value returned by the model's `evaluate`; when
            `evaluate` returns a list, its first entry is used.
    """
    evaluate_args = self.filter_sk_params(Sequential.evaluate, kwargs)
    result = self.model.evaluate(x, y, **evaluate_args)
    # `evaluate` may return a list; only its first entry is scored.
    loss = result[0] if isinstance(result, list) else result
    return -loss
353