# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for estimators.SVM."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators import svm
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test


class SVMTest(test.TestCase):
  """Tests for the linear SVM estimator over various feature-column types."""

  def testRealValuedFeaturesPerfectlySeparable(self):
    """Tests SVM classifier with real valued features."""

    def input_fn():
      # Three examples with two scalar real-valued features each; the
      # 'example_id' column is required by the SDCA optimizer used by SVM.
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([[0.0], [1.0], [3.0]]),
          'feature2': constant_op.constant([[1.0], [-1.2], [1.0]]),
      }, constant_op.constant([[1], [0], [1]])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=0.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    # The points are not only separable but there exist weights (for instance
    # w1=0.0, w2=1.0) that satisfy the margin inequalities (y_i* w^T*x_i >=1).
    # The unregularized loss should therefore be 0.0.
    self.assertAlmostEqual(loss, 0.0, places=3)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testRealValuedFeaturesWithL2Regularization(self):
    """Tests SVM classifier with real valued features and L2 regularization."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([0.5, 1.0, 1.0]),
          'feature2': constant_op.constant([1.0, -1.0, 0.5]),
      }, constant_op.constant([1, 0, 1])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    # The points are in general separable. Also, if there was no
    # regularization, the margin inequalities would be satisfied too (for
    # instance by w1=1.0, w2=5.0). Due to regularization, smaller weights are
    # chosen. This results in a small but non-zero unregularized loss. Still,
    # all the predictions will be correct, resulting in perfect accuracy.
    self.assertLess(loss, 0.1)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testMultiDimensionalRealValuedFeaturesWithL2Regularization(self):
    """Tests SVM with multi-dimensional real features and L2 regularization."""

    # This is identical to the one in testRealValuedFeaturesWithL2Regularization
    # where 2 tensors (dense features) of shape [3, 1] have been replaced by a
    # single tensor (dense feature) of shape [3, 2].
    def input_fn():
      return {
          'example_id':
              constant_op.constant(['1', '2', '3']),
          'multi_dim_feature':
              constant_op.constant([[0.5, 1.0], [1.0, -1.0], [1.0, 0.5]]),
      }, constant_op.constant([[1], [0], [1]])

    multi_dim_feature = feature_column.real_valued_column(
        'multi_dim_feature', dimension=2)
    svm_classifier = svm.SVM(feature_columns=[multi_dim_feature],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']
    # Same expectations as the 2-column variant above: small unregularized
    # loss, perfect accuracy.
    self.assertLess(loss, 0.1)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testRealValuedFeaturesWithMildL1Regularization(self):
    """Tests SVM classifier with real valued features and mild L1 regularization."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([[0.5], [1.0], [1.0]]),
          'feature2': constant_op.constant([[1.0], [-1.0], [0.5]]),
      }, constant_op.constant([[1], [0], [1]])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=0.5,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']

    # Adding small L1 regularization favors even smaller weights. This results
    # in a somewhat moderate unregularized loss (bigger than the one when there
    # is no L1 regularization). Still, since L1 is small, all the predictions
    # will be correct, resulting in perfect accuracy.
    self.assertGreater(loss, 0.1)
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testRealValuedFeaturesWithBigL1Regularization(self):
    """Tests SVM classifier with real valued features and large L1 regularization."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'feature1': constant_op.constant([0.5, 1.0, 1.0]),
          'feature2': constant_op.constant([[1.0], [-1.0], [0.5]]),
      }, constant_op.constant([[1], [0], [1]])

    feature1 = feature_column.real_valued_column('feature1')
    feature2 = feature_column.real_valued_column('feature2')
    svm_classifier = svm.SVM(feature_columns=[feature1, feature2],
                             example_id_column='example_id',
                             l1_regularization=3.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    metrics = svm_classifier.evaluate(input_fn=input_fn, steps=1)
    loss = metrics['loss']
    accuracy = metrics['accuracy']

    # When the L1 regularization parameter is large, the loss due to
    # regularization outweighs the unregularized loss. In this case, the
    # classifier will favor very small weights (in the current case 0),
    # resulting in both a big unregularized loss and bad accuracy.
    self.assertAlmostEqual(loss, 1.0, places=3)
    self.assertAlmostEqual(accuracy, 1 / 3, places=3)

  def testSparseFeatures(self):
    """Tests SVM classifier with (hashed) sparse features."""

    def input_fn():
      return {
          'example_id':
              constant_op.constant(['1', '2', '3']),
          'price':
              constant_op.constant([[0.8], [0.6], [0.3]]),
          'country':
              sparse_tensor.SparseTensor(
                  values=['IT', 'US', 'GB'],
                  indices=[[0, 0], [1, 0], [2, 0]],
                  dense_shape=[3, 1]),
      }, constant_op.constant([[0], [1], [1]])

    price = feature_column.real_valued_column('price')
    country = feature_column.sparse_column_with_hash_bucket(
        'country', hash_bucket_size=5)
    svm_classifier = svm.SVM(feature_columns=[price, country],
                             example_id_column='example_id',
                             l1_regularization=0.0,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    accuracy = svm_classifier.evaluate(input_fn=input_fn, steps=1)['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testBucketizedFeatures(self):
    """Tests SVM classifier with bucketized features."""

    def input_fn():
      return {
          'example_id': constant_op.constant(['1', '2', '3']),
          'price': constant_op.constant([[600.0], [800.0], [400.0]]),
          'sq_footage': constant_op.constant([[1000.0], [800.0], [500.0]]),
          'weights': constant_op.constant([[1.0], [1.0], [1.0]])
      }, constant_op.constant([[1], [0], [1]])

    price_bucket = feature_column.bucketized_column(
        feature_column.real_valued_column('price'), boundaries=[500.0, 700.0])
    sq_footage_bucket = feature_column.bucketized_column(
        feature_column.real_valued_column('sq_footage'), boundaries=[650.0])

    svm_classifier = svm.SVM(feature_columns=[price_bucket, sq_footage_bucket],
                             example_id_column='example_id',
                             l1_regularization=0.1,
                             l2_regularization=1.0)
    svm_classifier.fit(input_fn=input_fn, steps=30)
    accuracy = svm_classifier.evaluate(input_fn=input_fn, steps=1)['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)

  def testMixedFeatures(self):
    """Tests SVM classifier with a mix of features."""

    def input_fn():
      # Mixes dense (rank-1 and rank-2), sparse, and per-example weight
      # columns in a single feature dict.
      return {
          'example_id':
              constant_op.constant(['1', '2', '3']),
          'price':
              constant_op.constant([0.6, 0.8, 0.3]),
          'sq_footage':
              constant_op.constant([[900.0], [700.0], [600.0]]),
          'country':
              sparse_tensor.SparseTensor(
                  values=['IT', 'US', 'GB'],
                  indices=[[0, 0], [1, 3], [2, 1]],
                  dense_shape=[3, 5]),
          'weights':
              constant_op.constant([[3.0], [1.0], [1.0]])
      }, constant_op.constant([[1], [0], [1]])

    price = feature_column.real_valued_column('price')
    sq_footage_bucket = feature_column.bucketized_column(
        feature_column.real_valued_column('sq_footage'),
        boundaries=[650.0, 800.0])
    country = feature_column.sparse_column_with_hash_bucket(
        'country', hash_bucket_size=5)
    sq_footage_country = feature_column.crossed_column(
        [sq_footage_bucket, country], hash_bucket_size=10)
    svm_classifier = svm.SVM(
        feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
        example_id_column='example_id',
        weight_column_name='weights',
        l1_regularization=0.1,
        l2_regularization=1.0)

    svm_classifier.fit(input_fn=input_fn, steps=30)
    accuracy = svm_classifier.evaluate(input_fn=input_fn, steps=1)['accuracy']
    self.assertAlmostEqual(accuracy, 1.0, places=3)


if __name__ == '__main__':
  test.main()