# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for estimators.SVM."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators import svm
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test

class SVMTest(test.TestCase):
  """Unit tests for the contrib.learn SVM estimator.

  Each test builds a tiny 3-example dataset via an `input_fn`, fits the SVM
  for a few steps, and checks the (unregularized) hinge loss and accuracy
  reported by `evaluate`.
  """

  def testRealValuedFeaturesPerfectlySeparable(self):
    """Tests SVM classifier with real valued features."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([[0.0], [1.0], [3.0]]),
          'feature2': constant_op.constant([[1.0], [-1.2], [1.0]]),
      }, constant_op.constant([[1], [0], [1]])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=0.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    # The points are not only separable but there exist weights (for instance
    # w1=0.0, w2=1.0) that satisfy the margin inequalities (y_i* w^T*x_i >=1).
    # The unregularized loss should therefore be 0.0.
    self.assertAlmostEqual(loss, 0.0, places=3)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testRealValuedFeaturesWithL2Regularization(self):
    """Tests SVM classifier with real valued features and L2 regularization."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([0.5, 1.0, 1.0]),
          'feature2': constant_op.constant([1.0, -1.0, 0.5]),
      }, constant_op.constant([1, 0, 1])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    # The points are in general separable. Also, if there was no regularization,
    # the margin inequalities would be satisfied too (for instance by w1=1.0,
    # w2=5.0). Due to regularization, smaller weights are chosen. This results
    # in a small but non-zero unregularized loss. Still, all the predictions
    # will be correct resulting in perfect accuracy.
    self.assertLess(loss, 0.1)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testMultiDimensionalRealValuedFeaturesWithL2Regularization(self):
    """Tests SVM with multi-dimensional real features and L2 regularization."""

    # This is identical to the one in testRealValuedFeaturesWithL2Regularization
    # where 2 tensors (dense features) of shape [3, 1] have been replaced by a
    # single tensor (dense feature) of shape [3, 2].
    def input_fn():
      return {
          'example_id':
              constant_op.constant(['1', '2', '3']),
          'multi_dim_feature':
              constant_op.constant([[0.5, 1.0], [1.0, -1.0], [1.0, 0.5]]),
      }, constant_op.constant([[1], [0], [1]])

    multi_dim_feature = feature_column.real_valued_column(
        'multi_dim_feature', dimension=2)
    svm_classifier = svm.SVM(feature_columns=[multi_dim_feature],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    self.assertLess(loss, 0.1)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testRealValuedFeaturesWithMildL1Regularization(self):
    """Tests SVM classifier with real valued features and mild L1 regularization."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([[0.5], [1.0], [1.0]]),
          'feature2': constant_op.constant([[1.0], [-1.0], [0.5]]),
      }, constant_op.constant([[1], [0], [1]])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=0.5,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']

    # Adding small L1 regularization favors even smaller weights. This results
    # in somewhat moderate unregularized loss (bigger than the one when there is
    # no L1 regularization). Still, since L1 is small, all the predictions will
    # be correct resulting in perfect accuracy.
    self.assertGreater(loss, 0.1)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testRealValuedFeaturesWithBigL1Regularization(self):
    """Tests SVM classifier with real valued features and big L1 regularization."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([0.5, 1.0, 1.0]),
          'feature2': constant_op.constant([[1.0], [-1.0], [0.5]]),
      }, constant_op.constant([[1], [0], [1]])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=3.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']

    # When the L1 regularization parameter is large, the loss due to
    # regularization outweighs the unregularized loss. In this case, the
    # classifier will favor very small weights (in the current case 0) resulting
    # in both big unregularized loss and bad accuracy.
    self.assertAlmostEqual(loss, 1.0, places=3)
    self.assertAlmostEqual(accuracy, 1 / 3, places=3)

  def testSparseFeatures(self):
    """Tests SVM classifier with (hashed) sparse features."""

    def input_fn():
      return {
          'example_id':
              constant_op.constant(['1', '2', '3']),
          'price':
              constant_op.constant([[0.8], [0.6], [0.3]]),
          'country':
              sparse_tensor.SparseTensor(
                  values=['IT', 'US', 'GB'],
                  indices=[[0, 0], [1, 0], [2, 0]],
                  dense_shape=[3, 1]),
      }, constant_op.constant([[0], [1], [1]])

    price = feature_column.real_valued_column('price')
    country = feature_column.sparse_column_with_hash_bucket(
        'country', hash_bucket_size=5)
    svm_classifier = svm.SVM(feature_columns=[price, country],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    accuracy = svm_classifier.evaluate(input_fn=input_fn, steps=1)['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testBucketizedFeatures(self):
    """Tests SVM classifier with bucketized features."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'price': constant_op.constant([[600.0], [800.0], [400.0]]),
          'sq_footage': constant_op.constant([[1000.0], [800.0], [500.0]]),
          'weights': constant_op.constant([[1.0], [1.0], [1.0]])
      }, constant_op.constant([[1], [0], [1]])

    price_bucket = feature_column.bucketized_column(
        feature_column.real_valued_column('price'), boundaries=[500.0, 700.0])
    sq_footage_bucket = feature_column.bucketized_column(
        feature_column.real_valued_column('sq_footage'), boundaries=[650.0])

    svm_classifier = svm.SVM(feature_columns=[price_bucket, sq_footage_bucket],
                             example_id_column='example_id',
                             l1_regularization=0.1,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    accuracy = svm_classifier.evaluate(input_fn=input_fn, steps=1)['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testMixedFeatures(self):
    """Tests SVM classifier with a mix of features."""

    def input_fn():
      return {
          'example_id':
              constant_op.constant(['1', '2', '3']),
          'price':
              constant_op.constant([0.6, 0.8, 0.3]),
          'sq_footage':
              constant_op.constant([[900.0], [700.0], [600.0]]),
          'country':
              sparse_tensor.SparseTensor(
                  values=['IT', 'US', 'GB'],
                  indices=[[0, 0], [1, 3], [2, 1]],
                  dense_shape=[3, 5]),
          'weights':
              constant_op.constant([[3.0], [1.0], [1.0]])
      }, constant_op.constant([[1], [0], [1]])

    price = feature_column.real_valued_column('price')
    sq_footage_bucket = feature_column.bucketized_column(
        feature_column.real_valued_column('sq_footage'),
        boundaries=[650.0, 800.0])
    country = feature_column.sparse_column_with_hash_bucket(
        'country', hash_bucket_size=5)
    sq_footage_country = feature_column.crossed_column(
        [sq_footage_bucket, country], hash_bucket_size=10)
    svm_classifier = svm.SVM(
        feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
        example_id_column='example_id',
        weight_column_name='weights',
        l1_regularization=0.1,
        l2_regularization=1.0)

    svm_classifier.fit(input_fn=input_fn, steps=30)
    accuracy = svm_classifier.evaluate(input_fn=input_fn, steps=1)['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)

# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  test.main()