# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ComposableModel classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.training import training_util
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.datasets import base
from tensorflow.contrib.learn.python.learn.estimators import composable_model
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test


def _iris_input_fn():
  """Input fn that feeds the whole Iris dataset as constant tensors.

  Returns:
    A `(features, labels)` pair. `features` maps 'feature' to a float32
    tensor built from `iris.data`; `labels` is an int32 tensor built from
    `iris.target`, reshaped into a single column (one label per row).
  """
  iris = base.load_iris()
  features = {
      'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
  }
  # Derive the row count from the data instead of hard-coding it.
  labels = constant_op.constant(
      iris.target, shape=[len(iris.target), 1], dtype=dtypes.int32)
  return features, labels


def _base_model_fn(features, labels, mode, params):
  """Model fn that connects a ComposableModel to a head.

  Args:
    features: Feature tensors produced by the input fn.
    labels: Label tensor produced by the input fn.
    mode: A `model_fn_lib.ModeKeys` value. Only TRAIN and EVAL are supported.
    params: Dict with keys 'model' (an object exposing `build_model` and
      `get_train_step`), 'feature_columns' and 'head'.

  Returns:
    Whatever `params['head'].create_model_fn_ops(...)` returns.

  Raises:
    NotImplementedError: If `mode` is neither TRAIN nor EVAL.
  """
  model = params['model']
  feature_columns = params['feature_columns']
  head = params['head']

  if mode == model_fn_lib.ModeKeys.TRAIN:
    logits = model.build_model(features, feature_columns, is_training=True)
  elif mode == model_fn_lib.ModeKeys.EVAL:
    logits = model.build_model(features, feature_columns, is_training=False)
  else:
    raise NotImplementedError

  def _train_op_fn(loss):
    """Runs the model's train step, then increments the global step."""
    global_step = training_util.get_global_step()
    # Check for presence explicitly: a truth test on a variable object may
    # read its value (failing at step 0) or raise in graph mode.
    assert global_step is not None
    # Passed to control_dependencies below; presumably a list of ops -
    # confirm against ComposableModel.get_train_step.
    train_step = model.get_train_step(loss)

    # The increment only runs after the model's train ops, and is placed on
    # the same device as the global step variable.
    with ops.control_dependencies(train_step):
      with ops.get_default_graph().colocate_with(global_step):
        return state_ops.assign_add(global_step, 1).op

  return head.create_model_fn_ops(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)


def _composable_estimator(head, feature_columns, model):
  """Builds an Estimator that trains `model` under `head` via _base_model_fn."""
  return estimator.Estimator(
      model_fn=_base_model_fn,
      params={
          'model': model,
          'feature_columns': feature_columns,
          'head': head,
      })


def _linear_estimator(head, feature_columns):
  """Estimator backed by a LinearComposableModel sized to `head`'s logits."""
  return _composable_estimator(
      head, feature_columns,
      composable_model.LinearComposableModel(
          num_label_columns=head.logits_dimension))


def _joint_linear_estimator(head, feature_columns):
  """Like `_linear_estimator`, but with `_joint_weights=True`."""
  return _composable_estimator(
      head, feature_columns,
      composable_model.LinearComposableModel(
          num_label_columns=head.logits_dimension, _joint_weights=True))


def _dnn_estimator(head, feature_columns, hidden_units):
  """Estimator backed by a DNNComposableModel with the given hidden layers."""
  return _composable_estimator(
      head, feature_columns,
      composable_model.DNNComposableModel(
          num_label_columns=head.logits_dimension,
          hidden_units=hidden_units))


class ComposableModelTest(test.TestCase):

  def _assertLossDecreases(self, classifier, input_fn):
    """Trains `classifier` twice and checks eval loss drops to near zero."""
    classifier.fit(input_fn=input_fn, steps=1000)
    loss1 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
    classifier.fit(input_fn=input_fn, steps=2000)
    loss2 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
    self.assertLess(loss2, loss1)
    self.assertLess(loss2, 0.01)

  def testLinearModel(self):
    """Tests that loss goes down with training."""

    def input_fn():
      return {
          'age':
              constant_op.constant([1]),
          'language':
              sparse_tensor.SparseTensor(
                  values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
      }, constant_op.constant([[1]])

    language = feature_column.sparse_column_with_hash_bucket('language', 100)
    age = feature_column.real_valued_column('age')

    head = head_lib.multi_class_head(n_classes=2)
    classifier = _linear_estimator(head, feature_columns=[age, language])

    self._assertLossDecreases(classifier, input_fn)

  def testJointLinearModel(self):
    """Tests that loss goes down with training."""

    def input_fn():
      return {
          'age':
              sparse_tensor.SparseTensor(
                  values=['1'], indices=[[0, 0]], dense_shape=[1, 1]),
          'language':
              sparse_tensor.SparseTensor(
                  values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
      }, constant_op.constant([[1]])

    language = feature_column.sparse_column_with_hash_bucket('language', 100)
    age = feature_column.sparse_column_with_hash_bucket('age', 2)

    head = head_lib.multi_class_head(n_classes=2)
    classifier = _joint_linear_estimator(head, feature_columns=[age, language])

    self._assertLossDecreases(classifier, input_fn)

  def testDNNModel(self):
    """Tests multi-class classification using matrix data as input."""
    cont_features = [feature_column.real_valued_column('feature', dimension=4)]

    head = head_lib.multi_class_head(n_classes=3)
    classifier = _dnn_estimator(
        head, feature_columns=cont_features, hidden_units=[3, 3])

    classifier.fit(input_fn=_iris_input_fn, steps=1000)
    scores = classifier.evaluate(input_fn=_iris_input_fn, steps=100)
    # Smoke check: evaluation ran and reported a loss.
    self.assertIn('loss', scores)


if __name__ == '__main__':
  test.main()