# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Support Vector Machine (SVM) Estimator (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import linear
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.linear_optimizer.python import sdca_optimizer


def _as_iterable(preds, output):
  """Lazily extracts a single output from each prediction.

  Args:
    preds: An iterable of prediction mappings, as produced by
      `Estimator.predict(..., as_iterable=True)`.
    output: The key to look up in each prediction.

  Yields:
    `pred[output]` for each `pred` in `preds`.
  """
  for pred in preds:
    yield pred[output]


class SVM(estimator.Estimator):
  """Support Vector Machine (SVM) model for binary classification.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Currently, only linear SVMs are supported. For the underlying optimization
  problem, the `SDCAOptimizer` is used. For performance and convergence tuning,
  the num_loss_partitions parameter passed to `SDCAOptimizer` (see the
  `__init__()` method) should be set to (#concurrent train ops per worker) x
  (#workers). If num_loss_partitions is larger than or equal to this value,
  convergence is guaranteed but becomes slower as num_loss_partitions
  increases. If it is set to a smaller value, the optimizer is more aggressive
  in reducing the global loss but convergence is not guaranteed. The
  recommended value in an `Estimator` (where there is one process per worker)
  is the number of workers running the train steps. It defaults to 1 (single
  machine).

  Example:

  ```python
  real_feature_column = real_valued_column(...)
  sparse_feature_column = sparse_column_with_hash_bucket(...)

  estimator = SVM(
      example_id_column='example_id',
      feature_columns=[real_feature_column, sparse_feature_column],
      l2_regularization=10.0)

  # Input builders
  def input_fn_train():  # returns x, y
    ...
  def input_fn_eval():  # returns x, y
    ...

  estimator.fit(input_fn=input_fn_train)
  estimator.evaluate(input_fn=input_fn_eval)
  estimator.predict(x=x)
  ```

  The input to `fit` and `evaluate` must contain the following features,
  otherwise there will be a `KeyError`:
    a feature with `key=example_id_column` whose value is a `Tensor` of dtype
      string.
    if `weight_column_name` is not `None`, a feature with
      `key=weight_column_name` whose value is a `Tensor`.
    for each `column` in `feature_columns`:
      - if `column` is a `SparseColumn`, a feature with `key=column.name`
        whose `value` is a `SparseTensor`.
      - if `column` is a `RealValuedColumn`, a feature with `key=column.name`
        whose `value` is a `Tensor`.
  """

  def __init__(self,
               example_id_column,
               feature_columns,
               weight_column_name=None,
               model_dir=None,
               l1_regularization=0.0,
               l2_regularization=0.0,
               num_loss_partitions=1,
               kernels=None,
               config=None,
               feature_engineering_fn=None):
    """Constructs an `SVM` estimator object.

    Args:
      example_id_column: A string defining the feature column name representing
        example ids. Used to initialize the underlying optimizer.
      feature_columns: An iterable containing all the feature columns used by
        the model. All items in the set should be instances of classes derived
        from `FeatureColumn`.
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model.
      l1_regularization: L1-regularization parameter. Refers to global L1
        regularization (across all examples).
      l2_regularization: L2-regularization parameter. Refers to global L2
        regularization (across all examples).
      num_loss_partitions: number of partitions of the (global) loss function
        optimized by the underlying optimizer (SDCAOptimizer).
      kernels: A list of kernels for the SVM. Currently, no kernels are
        supported. Reserved for future use for non-linear SVMs.
      config: RunConfig object to configure the runtime settings.
      feature_engineering_fn: Feature engineering function. Takes features and
                        labels which are the output of `input_fn` and
                        returns features and labels which will be fed
                        into the model.

    Raises:
      ValueError: if kernels passed is not None.
    """
    if kernels is not None:
      raise ValueError("Kernel SVMs are not currently supported.")
    optimizer = sdca_optimizer.SDCAOptimizer(
        example_id_column=example_id_column,
        num_loss_partitions=num_loss_partitions,
        symmetric_l1_regularization=l1_regularization,
        symmetric_l2_regularization=l2_regularization)

    # Kept so that export_with_defaults() can build a default parsing input_fn.
    self._feature_columns = feature_columns
    # Handed to linear.sdca_model_fn via params["update_weights_hook"];
    # presumably applies the SDCA weight updates on the chief (see linear.py).
    chief_hook = linear._SdcaUpdateWeightsHook()  # pylint: disable=protected-access
    super(SVM, self).__init__(
        model_fn=linear.sdca_model_fn,
        model_dir=model_dir,
        config=config,
        params={
            "head": head_lib.binary_svm_head(
                weight_column_name=weight_column_name,
                enable_centered_bias=False),
            "feature_columns": feature_columns,
            "optimizer": optimizer,
            "weight_column_name": weight_column_name,
            "update_weights_hook": chief_hook,
        },
        feature_engineering_fn=feature_engineering_fn)

  @deprecated_arg_values(
      estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
      as_iterable=False)
  def predict_classes(self, x=None, input_fn=None, batch_size=None,
                      as_iterable=True):
    """Runs inference to determine the predicted class.

    Args:
      x: Input features, forwarded to `Estimator.predict`.
      input_fn: Input function, forwarded to `Estimator.predict`.
      batch_size: Batch size, forwarded to `Estimator.predict`.
      as_iterable: If True, return a generator over per-example classes;
        otherwise return the `CLASSES` entry of the predictions result.

    Returns:
      The predicted classes, lazily (generator) or eagerly, per `as_iterable`.
    """
    key = prediction_key.PredictionKey.CLASSES
    preds = super(SVM, self).predict(
        x=x,
        input_fn=input_fn,
        batch_size=batch_size,
        outputs=[key],
        as_iterable=as_iterable)
    if as_iterable:
      return _as_iterable(preds, output=key)
    return preds[key]

  @deprecated_arg_values(
      estimator.AS_ITERABLE_DATE, estimator.AS_ITERABLE_INSTRUCTIONS,
      as_iterable=False)
  def predict_proba(self, x=None, input_fn=None, batch_size=None, outputs=None,
                    as_iterable=True):
    """Runs inference to determine the class probability predictions.

    Args:
      x: Input features, forwarded to `Estimator.predict`.
      input_fn: Input function, forwarded to `Estimator.predict`.
      batch_size: Batch size, forwarded to `Estimator.predict`.
      outputs: Ignored; only the `PROBABILITIES` output is ever requested.
        Kept for signature compatibility.
      as_iterable: If True, return a generator over per-example probabilities;
        otherwise return the `PROBABILITIES` entry of the predictions result.

    Returns:
      The class probabilities, lazily (generator) or eagerly, per
      `as_iterable`.
    """
    key = prediction_key.PredictionKey.PROBABILITIES
    preds = super(SVM, self).predict(
        x=x,
        input_fn=input_fn,
        batch_size=batch_size,
        outputs=[key],
        as_iterable=as_iterable)
    if as_iterable:
      return _as_iterable(preds, output=key)
    return preds[key]

  @deprecated("2017-03-25", "Please use Estimator.export_savedmodel() instead.")
  def export(self, export_dir, signature_fn=None,
             input_fn=None, default_batch_size=1,
             exports_to_keep=None):
    """See BaseEstimator.export.

    Delegates to `export_with_defaults`, which supplies a default `input_fn`
    when none is given.
    """
    return self.export_with_defaults(
        export_dir=export_dir,
        signature_fn=signature_fn,
        input_fn=input_fn,
        default_batch_size=default_batch_size,
        exports_to_keep=exports_to_keep)

  @deprecated("2017-03-25", "Please use Estimator.export_savedmodel() instead.")
  def export_with_defaults(
      self,
      export_dir,
      signature_fn=None,
      input_fn=None,
      default_batch_size=1,
      exports_to_keep=None):
    """Same as BaseEstimator.export, but uses some defaults.

    If `input_fn` is falsy, a default is used that parses `examples` with the
    feature columns this estimator was constructed with.
    """
    def default_input_fn(unused_estimator, examples):
      return layers.parse_feature_columns_from_examples(
          examples, self._feature_columns)
    return super(SVM, self).export(export_dir=export_dir,
                                   signature_fn=signature_fn,
                                   input_fn=input_fn or default_input_fn,
                                   default_batch_size=default_batch_size,
                                   exports_to_keep=exports_to_keep)