# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for ComposableModel classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.training import training_util
22from tensorflow.contrib.layers.python.layers import feature_column
23from tensorflow.contrib.learn.python.learn.datasets import base
24from tensorflow.contrib.learn.python.learn.estimators import composable_model
25from tensorflow.contrib.learn.python.learn.estimators import estimator
26from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
27from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.ops import state_ops
33from tensorflow.python.platform import test
34
35
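# Input fn for the DNN test: returns the full Iris data as a single dense
# 'feature' column plus the 150 integer class labels, in the
# (features, labels) form expected by Estimator.fit/evaluate.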
def _iris_input_fn():
  iris = base.load_iris()
  return {
      'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
  }, constant_op.constant(iris.target, shape=[150, 1], dtype=dtypes.int32)


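# Generic model_fn shared by all tests: builds logits with the composable
# model passed in via params, then lets the head turn them into ModelFnOps,
# using the model's own train step plus a global-step increment as the
# training op.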
def _base_model_fn(features, labels, mode, params):
  model = params['model']
  feature_columns = params['feature_columns']
  head = params['head']

  if mode == model_fn_lib.ModeKeys.TRAIN:
    logits = model.build_model(features, feature_columns, is_training=True)
  elif mode == model_fn_lib.ModeKeys.EVAL:
    logits = model.build_model(features, feature_columns, is_training=False)
  else:
    raise NotImplementedError

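  # Runs the composable model's training ops first, then increments the
  # global step (colocated with the global-step variable) and returns that
  # increment as the op handed back to the head.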
  def _train_op_fn(loss):
    global_step = training_util.get_global_step()
    assert global_step
    train_step = model.get_train_step(loss)

    with ops.control_dependencies(train_step):
      with ops.get_default_graph().colocate_with(global_step):
        return state_ops.assign_add(global_step, 1).op

  return head.create_model_fn_ops(
      features=features,
      mode=mode,
      labels=labels,
      train_op_fn=_train_op_fn,
      logits=logits)


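# Wraps a LinearComposableModel in a plain Estimator via _base_model_fn.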
def _linear_estimator(head, feature_columns):
  return estimator.Estimator(
      model_fn=_base_model_fn,
      params={
          'model': composable_model.LinearComposableModel(
              num_label_columns=head.logits_dimension),
          'feature_columns': feature_columns,
          'head': head
      })


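# Same as _linear_estimator, but with _joint_weights=True so the linear model
# stores its weights in a single joint variable, which requires every feature
# column to be sparse (see testJointLinearModel below).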
def _joint_linear_estimator(head, feature_columns):
  return estimator.Estimator(
      model_fn=_base_model_fn,
      params={
          'model': composable_model.LinearComposableModel(
              num_label_columns=head.logits_dimension, _joint_weights=True),
          'feature_columns': feature_columns,
          'head': head
      })


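# Wraps a DNNComposableModel with the given hidden layer sizes in a plain
# Estimator via _base_model_fn.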
def _dnn_estimator(head, feature_columns, hidden_units):
  return estimator.Estimator(
      model_fn=_base_model_fn,
      params={
          'model': composable_model.DNNComposableModel(
              num_label_columns=head.logits_dimension,
              hidden_units=hidden_units),
          'feature_columns': feature_columns,
          'head': head
      })


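# End-to-end tests: each test fits one of the composable models through the
# Estimators built above and checks its training/eval behaviour.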
class ComposableModelTest(test.TestCase):

  def testLinearModel(self):
    """Tests that loss goes down with training."""

    def input_fn():
      return {
          'age': constant_op.constant([1]),
          'language': sparse_tensor.SparseTensor(
              values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
      }, constant_op.constant([[1]])

    language = feature_column.sparse_column_with_hash_bucket('language', 100)
    age = feature_column.real_valued_column('age')

    head = head_lib.multi_class_head(n_classes=2)
    classifier = _linear_estimator(head, feature_columns=[age, language])

    classifier.fit(input_fn=input_fn, steps=1000)
    loss1 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
    classifier.fit(input_fn=input_fn, steps=2000)
    loss2 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
    self.assertLess(loss2, loss1)
    self.assertLess(loss2, 0.01)

  def testJointLinearModel(self):
    """Tests that loss goes down with training."""

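    # With joint weights every column must be sparse, so 'age' is fed as a
    # hashed string column here rather than as a real-valued column.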
    def input_fn():
      return {
          'age': sparse_tensor.SparseTensor(
              values=['1'], indices=[[0, 0]], dense_shape=[1, 1]),
          'language': sparse_tensor.SparseTensor(
              values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
      }, constant_op.constant([[1]])

    language = feature_column.sparse_column_with_hash_bucket('language', 100)
    age = feature_column.sparse_column_with_hash_bucket('age', 2)

    head = head_lib.multi_class_head(n_classes=2)
    classifier = _joint_linear_estimator(head, feature_columns=[age, language])

    classifier.fit(input_fn=input_fn, steps=1000)
    loss1 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
    classifier.fit(input_fn=input_fn, steps=2000)
    loss2 = classifier.evaluate(input_fn=input_fn, steps=1)['loss']
    self.assertLess(loss2, loss1)
    self.assertLess(loss2, 0.01)

  def testDNNModel(self):
    """Tests multi-class classification using matrix data as input."""
    cont_features = [feature_column.real_valued_column('feature', dimension=4)]

    head = head_lib.multi_class_head(n_classes=3)
    classifier = _dnn_estimator(
        head, feature_columns=cont_features, hidden_units=[3, 3])

    classifier.fit(input_fn=_iris_input_fn, steps=1000)
    classifier.evaluate(input_fn=_iris_input_fn, steps=100)


if __name__ == '__main__':
  test.main()