• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support Vector Machine (SVM) Estimator (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26from tensorflow.contrib import layers
27from tensorflow.contrib.framework import deprecated
28from tensorflow.contrib.framework import deprecated_arg_values
29from tensorflow.contrib.learn.python.learn.estimators import estimator
30from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
31from tensorflow.contrib.learn.python.learn.estimators import linear
32from tensorflow.contrib.learn.python.learn.estimators import prediction_key
33from tensorflow.contrib.linear_optimizer.python import sdca_optimizer
34
35
def _as_iterable(preds, output):
  """Lazily pulls one named output out of every prediction.

  This is a generator function, so `preds` is not touched until the caller
  starts iterating over the result.

  Args:
    preds: An iterable of per-example predictions, each indexable by `output`.
    output: The key to extract from every prediction.

  Yields:
    `prediction[output]` for each prediction in `preds`, in order.
  """
  for single_prediction in preds:
    extracted = single_prediction[output]
    yield extracted
39
40
class SVM(estimator.Estimator):
  """Support Vector Machine (SVM) model for binary classification.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Currently, only linear SVMs are supported. For the underlying optimization
  problem, the `SDCAOptimizer` is used. For performance and convergence tuning,
  the `num_loss_partitions` parameter passed to `SDCAOptimizer` (see the
  `__init__()` method) should be set to
  (#concurrent train ops per worker) x (#workers). If `num_loss_partitions` is
  greater than or equal to this value, convergence is guaranteed but becomes
  slower as `num_loss_partitions` increases. If it is set to a smaller value,
  the optimizer is more aggressive in reducing the global loss but convergence
  is not guaranteed. The recommended value in an `Estimator` (where there is one
  process per worker) is the number of workers running the train steps. It
  defaults to 1 (single machine).

  Example:

  ```python
  real_feature_column = real_valued_column(...)
  sparse_feature_column = sparse_column_with_hash_bucket(...)

  estimator = SVM(
      example_id_column='example_id',
      feature_columns=[real_feature_column, sparse_feature_column],
      l2_regularization=10.0)

  # Input builders
  def input_fn_train():  # returns x, y
    ...
  def input_fn_eval():  # returns x, y
    ...

  estimator.fit(input_fn=input_fn_train)
  estimator.evaluate(input_fn=input_fn_eval)
  estimator.predict(x=x)
  ```

  The input of `fit` and `evaluate` should have the following features,
  otherwise there will be a `KeyError`:
    * a feature with `key=example_id_column` whose value is a `Tensor` of dtype
      string.
    * if `weight_column_name` is not `None`, a feature with
      `key=weight_column_name` whose value is a `Tensor`.
    * for each `column` in `feature_columns`:
      - if `column` is a `SparseColumn`, a feature with `key=column.name`
        whose `value` is a `SparseTensor`.
      - if `column` is a `RealValuedColumn`, a feature with `key=column.name`
        whose `value` is a `Tensor`.
  """

  def __init__(self,
               example_id_column,
               feature_columns,
               weight_column_name=None,
               model_dir=None,
               l1_regularization=0.0,
               l2_regularization=0.0,
               num_loss_partitions=1,
               kernels=None,
               config=None,
               feature_engineering_fn=None):
    """Constructs an `SVM` estimator object.

    Args:
      example_id_column: A string defining the feature column name representing
        example ids. Used to initialize the underlying optimizer.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `FeatureColumn`.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      l1_regularization: L1-regularization parameter. Refers to global L1
        regularization (across all examples).
      l2_regularization: L2-regularization parameter. Refers to global L2
        regularization (across all examples).
      num_loss_partitions: Number of partitions of the (global) loss function
        optimized by the underlying optimizer (`SDCAOptimizer`).
      kernels: A list of kernels for the SVM. Currently, no kernels are
        supported. Reserved for future use for non-linear SVMs.
      config: `RunConfig` object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
        labels which are the output of `input_fn` and returns features and
        labels which will be fed into the model.

    Raises:
      ValueError: If `kernels` is not None.
    """
    if kernels is not None:
      raise ValueError("Kernel SVMs are not currently supported.")
    # The regularization strengths are forwarded to the SDCA optimizer as its
    # symmetric L1/L2 terms; the optimizer is handed to the model_fn via params.
    optimizer = sdca_optimizer.SDCAOptimizer(
        example_id_column=example_id_column,
        num_loss_partitions=num_loss_partitions,
        symmetric_l1_regularization=l1_regularization,
        symmetric_l2_regularization=l2_regularization)

    # Kept so export_with_defaults() can parse serialized examples against the
    # same columns the model was built with.
    self._feature_columns = feature_columns
    chief_hook = linear._SdcaUpdateWeightsHook()  # pylint: disable=protected-access
    super(SVM, self).__init__(
        model_fn=linear.sdca_model_fn,
        model_dir=model_dir,
        config=config,
        params={
            "head": head_lib.binary_svm_head(
                weight_column_name=weight_column_name,
                enable_centered_bias=False),
            "feature_columns": feature_columns,
            "optimizer": optimizer,
            "weight_column_name": weight_column_name,
            "update_weights_hook": chief_hook,
        },
        feature_engineering_fn=feature_engineering_fn)

  def _predict_single_output(self, key, x, input_fn, batch_size, as_iterable):
    """Runs base-class `predict` restricted to one output and unwraps it.

    Shared by `predict_classes` and `predict_proba` so both request and unwrap
    predictions identically.

    Args:
      key: The prediction key to request (a `PredictionKey` value).
      x: Passed through unchanged to `Estimator.predict`.
      input_fn: Passed through unchanged to `Estimator.predict`.
      batch_size: Passed through unchanged to `Estimator.predict`.
      as_iterable: If True, return a lazy generator over the `key` entry of
        each prediction; otherwise return the `key` entry of the result.

    Returns:
      A generator (when `as_iterable`) or the value stored under `key`.
    """
    preds = super(SVM, self).predict(
        x=x,
        input_fn=input_fn,
        batch_size=batch_size,
        outputs=[key],
        as_iterable=as_iterable)
    if as_iterable:
      return _as_iterable(preds, output=key)
    return preds[key]

  @deprecated_arg_values(
      estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
      as_iterable=False)
  def predict_classes(self, x=None, input_fn=None, batch_size=None,
                      as_iterable=True):
    """Runs inference to determine the predicted class.

    Args:
      x: Passed through to `Estimator.predict`.
      input_fn: Passed through to `Estimator.predict`.
      batch_size: Passed through to `Estimator.predict`.
      as_iterable: If True (default), return a generator yielding the
        `PredictionKey.CLASSES` entry of each prediction; otherwise return the
        `PredictionKey.CLASSES` entry of the full result.

    Returns:
      The predicted classes, as a generator or a single value (see
      `as_iterable`).
    """
    return self._predict_single_output(
        prediction_key.PredictionKey.CLASSES,
        x=x,
        input_fn=input_fn,
        batch_size=batch_size,
        as_iterable=as_iterable)

  @deprecated_arg_values(
      estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
      as_iterable=False)
  def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
                    as_iterable=True):
    """Runs inference to determine the class probability predictions.

    Args:
      x: Passed through to `Estimator.predict`.
      input_fn: Passed through to `Estimator.predict`.
      batch_size: Passed through to `Estimator.predict`.
      outputs: Ignored. Only `PredictionKey.PROBABILITIES` is ever requested;
        the parameter is kept for signature compatibility.
      as_iterable: If True (default), return a generator yielding the
        `PredictionKey.PROBABILITIES` entry of each prediction; otherwise
        return that entry of the full result.

    Returns:
      The class probabilities, as a generator or a single value (see
      `as_iterable`).
    """
    del outputs  # Unused; see docstring.
    # NOTE(review): the head configured in __init__ is binary_svm_head; confirm
    # it actually emits PredictionKey.PROBABILITIES, otherwise this lookup fails.
    return self._predict_single_output(
        prediction_key.PredictionKey.PROBABILITIES,
        x=x,
        input_fn=input_fn,
        batch_size=batch_size,
        as_iterable=as_iterable)

  @deprecated("2017-03-25", "Please use Estimator.export_savedmodel() instead.")
  def export(self, export_dir, signature_fn=None,
             input_fn=None, default_batch_size=1,
             exports_to_keep=None):
    """See BaseEstimator.export.

    Delegates to `export_with_defaults`, which supplies a default `input_fn`
    when none is given.
    """
    return self.export_with_defaults(
        export_dir=export_dir,
        signature_fn=signature_fn,
        input_fn=input_fn,
        default_batch_size=default_batch_size,
        exports_to_keep=exports_to_keep)

  @deprecated("2017-03-25", "Please use Estimator.export_savedmodel() instead.")
  def export_with_defaults(
      self,
      export_dir,
      signature_fn=None,
      input_fn=None,
      default_batch_size=1,
      exports_to_keep=None):
    """Same as BaseEstimator.export, but uses some defaults.

    If `input_fn` is not provided (or is falsy), serialized examples are parsed
    with `layers.parse_feature_columns_from_examples` using the feature columns
    this estimator was constructed with.
    """
    def default_input_fn(unused_estimator, examples):
      return layers.parse_feature_columns_from_examples(
          examples, self._feature_columns)
    return super(SVM, self).export(export_dir=export_dir,
                                   signature_fn=signature_fn,
                                   input_fn=input_fn or default_input_fn,
                                   default_batch_size=default_batch_size,
                                   exports_to_keep=exports_to_keep)
225