• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""sklearn cross-support (deprecated)."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import os
24
25import numpy as np
26import six
27
28
29def _pprint(d):
30  return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
31
32
33class _BaseEstimator(object):
34  """This is a cross-import when sklearn is not available.
35
36  Adopted from sklearn.BaseEstimator implementation.
37  https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/base.py
38  """
39
40  def get_params(self, deep=True):
41    """Get parameters for this estimator.
42
43    Args:
44      deep: boolean, optional
45
46        If `True`, will return the parameters for this estimator and
47        contained subobjects that are estimators.
48
49    Returns:
50      params : mapping of string to any
51      Parameter names mapped to their values.
52    """
53    out = dict()
54    param_names = [name for name in self.__dict__ if not name.startswith('_')]
55    for key in param_names:
56      value = getattr(self, key, None)
57
58      if isinstance(value, collections.Callable):
59        continue
60
61      # XXX: should we rather test if instance of estimator?
62      if deep and hasattr(value, 'get_params'):
63        deep_items = value.get_params().items()
64        out.update((key + '__' + k, val) for k, val in deep_items)
65      out[key] = value
66    return out
67
68  def set_params(self, **params):
69    """Set the parameters of this estimator.
70
71    The method works on simple estimators as well as on nested objects
72    (such as pipelines). The former have parameters of the form
73    ``<component>__<parameter>`` so that it's possible to update each
74    component of a nested object.
75
76    Args:
77      **params: Parameters.
78
79    Returns:
80      self
81
82    Raises:
83      ValueError: If params contain invalid names.
84    """
85    if not params:
86      # Simple optimisation to gain speed (inspect is slow)
87      return self
88    valid_params = self.get_params(deep=True)
89    for key, value in six.iteritems(params):
90      split = key.split('__', 1)
91      if len(split) > 1:
92        # nested objects case
93        name, sub_name = split
94        if name not in valid_params:
95          raise ValueError('Invalid parameter %s for estimator %s. '
96                           'Check the list of available parameters '
97                           'with `estimator.get_params().keys()`.' %
98                           (name, self))
99        sub_object = valid_params[name]
100        sub_object.set_params(**{sub_name: value})
101      else:
102        # simple objects case
103        if key not in valid_params:
104          raise ValueError('Invalid parameter %s for estimator %s. '
105                           'Check the list of available parameters '
106                           'with `estimator.get_params().keys()`.' %
107                           (key, self.__class__.__name__))
108        setattr(self, key, value)
109    return self
110
111  def __repr__(self):
112    class_name = self.__class__.__name__
113    return '%s(%s)' % (class_name,
114                       _pprint(self.get_params(deep=False)),)
115
116
117# pylint: disable=old-style-class
118class _ClassifierMixin():
119  """Mixin class for all classifiers."""
120  pass
121
122
123class _RegressorMixin():
124  """Mixin class for all regression estimators."""
125  pass
126
127
128class _TransformerMixin():
129  """Mixin class for all transformer estimators."""
130
131
class NotFittedError(ValueError, AttributeError):
  """Exception class to raise if estimator is used before fitting.

  USE OF THIS EXCEPTION IS DEPRECATED.

  This class inherits from both ValueError and AttributeError to help with
  exception handling and backward compatibility.

  Examples (require scikit-learn; the exact message depends on its version):
  >>> from sklearn.svm import LinearSVC
  >>> from sklearn.exceptions import NotFittedError
  >>> try:
  ...     LinearSVC().predict([[1, 2], [2, 3], [3, 4]])
  ... except NotFittedError as e:
  ...     print(repr(e))
  ...                        # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
  NotFittedError('This LinearSVC instance is not fitted yet',)

  Copied from
  https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/exceptions.py
  """
153
154# pylint: enable=old-style-class
155
156
157def _accuracy_score(y_true, y_pred):
158  score = y_true == y_pred
159  return np.average(score)
160
161
162def _mean_squared_error(y_true, y_pred):
163  if len(y_true.shape) > 1:
164    y_true = np.squeeze(y_true)
165  if len(y_pred.shape) > 1:
166    y_pred = np.squeeze(y_pred)
167  return np.average((y_true - y_pred)**2)
168
169
170def _train_test_split(*args, **options):
171  # pylint: disable=missing-docstring
172  test_size = options.pop('test_size', None)
173  train_size = options.pop('train_size', None)
174  random_state = options.pop('random_state', None)
175
176  if test_size is None and train_size is None:
177    train_size = 0.75
178  elif train_size is None:
179    train_size = 1 - test_size
180  train_size = int(train_size * args[0].shape[0])
181
182  np.random.seed(random_state)
183  indices = np.random.permutation(args[0].shape[0])
184  train_idx, test_idx = indices[:train_size], indices[train_size:]
185  result = []
186  for x in args:
187    result += [x.take(train_idx, axis=0), x.take(test_idx, axis=0)]
188  return tuple(result)
189
190
# When the "TENSORFLOW_SKLEARN" environment variable is set to a non-empty
# value, the real sklearn implementations are imported (sklearn must then be
# installed); otherwise the naive fallbacks defined in this module are used.
# NOTE(review): the check is plain truthiness of the string, so even "0" or
# "false" enables the sklearn import.
TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
if not TRY_IMPORT_SKLEARN:
  # Naive implementations of sklearn classes and functions.
  BaseEstimator = _BaseEstimator
  ClassifierMixin = _ClassifierMixin
  RegressorMixin = _RegressorMixin
  TransformerMixin = _TransformerMixin
  accuracy_score = _accuracy_score
  log_loss = None  # No fallback implementation is provided.
  mean_squared_error = _mean_squared_error
  train_test_split = _train_test_split
else:
  # pylint: disable=g-import-not-at-top,g-multiple-import,unused-import
  from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
  from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
  from sklearn.model_selection import train_test_split
  # Prefer sklearn.exceptions, then sklearn.utils.validation; if neither
  # provides NotFittedError, the local class defined in this module is kept.
  try:
    from sklearn.exceptions import NotFittedError
  except ImportError:
    try:
      from sklearn.utils.validation import NotFittedError
    except ImportError:
      pass
215