• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Estimator input."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import tempfile
23
24import numpy as np
25
26from tensorflow.python.training import training_util
27from tensorflow.contrib.layers.python.layers import optimizers
28from tensorflow.contrib.learn.python.learn import metric_spec
29from tensorflow.contrib.learn.python.learn import models
30from tensorflow.contrib.learn.python.learn.datasets import base
31from tensorflow.contrib.learn.python.learn.estimators import _sklearn
32from tensorflow.contrib.learn.python.learn.estimators import estimator
33from tensorflow.contrib.learn.python.learn.estimators import model_fn
34from tensorflow.contrib.metrics.python.ops import metric_ops
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import data_flow_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.platform import test
41from tensorflow.python.training import input as input_lib
42from tensorflow.python.training import queue_runner_impl
43
# Per-example feature widths used when reshaping the raw dataset arrays into
# [batch, features] tensors.
_BOSTON_INPUT_DIM, _IRIS_INPUT_DIM = 13, 4
46
47
def boston_input_fn(num_epochs=None):
  """Builds Boston housing `(features, labels)` tensors for an Estimator.

  Args:
    num_epochs: Forwarded to `input_lib.limit_epochs`, which wraps the
      features tensor only; the labels tensor is not epoch-limited.

  Returns:
    A `(features, labels)` tuple. Features are reshaped to
    `[-1, _BOSTON_INPUT_DIM]` and labels to `[-1, 1]`.
  """
  dataset = base.load_boston()
  flat_features = array_ops.reshape(
      constant_op.constant(dataset.data), [-1, _BOSTON_INPUT_DIM])
  features = input_lib.limit_epochs(flat_features, num_epochs=num_epochs)
  targets = constant_op.constant(dataset.target)
  labels = array_ops.reshape(targets, [-1, 1])
  return features, labels
56
57
def boston_input_fn_with_queue(num_epochs=None):
  """Same tensors as `boston_input_fn`, plus a registered dummy queue runner.

  The runner is built over an int32 `FIFOQueue` of capacity 30. Its op list
  holds only a constant rather than an enqueue op on that queue, so it does
  not feed any data. It exists only so the graph has a queue runner
  registered via `queue_runner_impl.add_queue_runner`.

  Args:
    num_epochs: Forwarded to `boston_input_fn`.

  Returns:
    The `(features, labels)` tuple produced by `boston_input_fn`.
  """
  result = boston_input_fn(num_epochs=num_epochs)

  dummy_queue = data_flow_ops.FIFOQueue(30, dtypes.int32)
  dummy_runner = queue_runner_impl.QueueRunner(
      dummy_queue, [constant_op.constant(0)])
  queue_runner_impl.add_queue_runner(dummy_runner)

  return result
68
69
def iris_input_fn():
  """Returns Iris `(features, labels)` tensors.

  Features are reshaped to `[-1, _IRIS_INPUT_DIM]`. Labels are flattened to
  a rank-1 tensor of shape `[-1]`.
  """
  dataset = base.load_iris()
  data_tensor = constant_op.constant(dataset.data)
  features = array_ops.reshape(data_tensor, [-1, _IRIS_INPUT_DIM])
  target_tensor = constant_op.constant(dataset.target)
  labels = array_ops.reshape(target_tensor, [-1])
  return features, labels
76
77
def iris_input_fn_labels_dict():
  """Like `iris_input_fn`, but returns the labels as `{'labels': labels}`."""
  features, flat_labels = iris_input_fn()
  return features, {'labels': flat_labels}
86
87
def boston_eval_fn():
  """Returns the Boston dataset duplicated along the batch axis.

  Features and labels are first reshaped using the explicit example count
  (`[n, _BOSTON_INPUT_DIM]` and `[n, 1]`). Each is then concatenated with
  itself along axis 0, so both tensors hold two back-to-back copies of the
  data.
  """
  dataset = base.load_boston()
  num_rows = len(dataset.target)
  features = array_ops.reshape(
      constant_op.constant(dataset.data), [num_rows, _BOSTON_INPUT_DIM])
  labels = array_ops.reshape(
      constant_op.constant(dataset.target), [num_rows, 1])
  doubled_features = array_ops.concat([features] * 2, 0)
  doubled_labels = array_ops.concat([labels] * 2, 0)
  return doubled_features, doubled_labels
97
98
def extract(data, key):
  """Unwraps `data[key]` when `data` is a dict; otherwise returns `data`.

  This lets the model functions accept either bare tensors/arrays or dicts
  that wrap them under a known key (e.g. `{'input': ...}`,
  `{'labels': ...}`).

  Args:
    data: A dict, or any other value, which is returned unchanged.
    key: Key to look up when `data` is a dict.

  Returns:
    `data[key]` for dict input, `data` itself otherwise.

  Raises:
    KeyError: If `data` is a dict that does not contain `key`. This is raised
      explicitly rather than via `assert`, so the check still runs under
      `python -O` and the message names the missing key.
  """
  if not isinstance(data, dict):
    return data
  if key not in data:
    raise KeyError('Expected key %r in dict with keys %r.' % (key, list(data)))
  return data[key]
105
106
def linear_model_params_fn(features, labels, mode, params):
  """Linear-regression model_fn whose learning rate comes from `params`.

  Args:
    features: Feature tensor/array, or a dict holding it under `'input'`.
    labels: Label tensor/array, or a dict holding it under `'labels'`.
    mode: One of `model_fn.ModeKeys.TRAIN`, `EVAL` or `INFER`.
    params: Dict; `params['learning_rate']` is passed to the Adagrad
      optimizer.

  Returns:
    A `(prediction, loss, train_op)` tuple.
  """
  x = extract(features, 'input')
  y = extract(labels, 'labels')

  assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                  model_fn.ModeKeys.INFER)
  predictions, loss = models.linear_regression_zero_init(x, y)
  step = training_util.get_global_step()
  train_op = optimizers.optimize_loss(
      loss, step, optimizer='Adagrad', learning_rate=params['learning_rate'])
  return predictions, loss, train_op
120
121
def linear_model_fn(features, labels, mode):
  """Linear-regression model_fn trained with Adagrad at learning rate 0.1.

  Args:
    features: Feature tensor/array, or a dict holding it under `'input'`.
      If the unwrapped value is itself a dict, it must have exactly one
      entry, and that entry's value is used as the features.
    labels: Label tensor/array, or a dict holding it under `'labels'`.
    mode: One of `model_fn.ModeKeys.TRAIN`, `EVAL` or `INFER`.

  Returns:
    A `(prediction, loss, train_op)` tuple.
  """
  x = extract(features, 'input')
  y = extract(labels, 'labels')
  assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                  model_fn.ModeKeys.INFER)
  if isinstance(x, dict):
    # Take the sole value; unpacking raises ValueError unless len(x) == 1.
    [x] = x.values()

  predictions, loss = models.linear_regression_zero_init(x, y)
  train_op = optimizers.optimize_loss(
      loss, training_util.get_global_step(), optimizer='Adagrad',
      learning_rate=0.1)
  return predictions, loss, train_op
136
137
def linear_model_fn_with_model_fn_ops(features, labels, mode):
  """Same as linear_model_fn, but returns `ModelFnOps`.

  Dict-wrapped inputs are unwrapped the way `linear_model_fn` does it:
  features come from `'input'` and labels from `'labels'`. This lets both
  model functions be driven by the same dict inputs. Non-dict inputs pass
  through unchanged.

  Args:
    features: Feature tensor/array, or a dict holding it under `'input'`.
    labels: Label tensor/array, or a dict holding it under `'labels'`.
    mode: One of `model_fn.ModeKeys.TRAIN`, `EVAL` or `INFER`.

  Returns:
    A `model_fn.ModelFnOps` carrying `mode`, the predictions, the loss and
    an Adagrad train op with learning rate 0.1.
  """
  # Unwrap dict inputs exactly as linear_model_fn does; no-op for non-dicts.
  features = extract(features, 'input')
  labels = extract(labels, 'labels')
  assert mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                  model_fn.ModeKeys.INFER)
  prediction, loss = (models.linear_regression_zero_init(features, labels))
  train_op = optimizers.optimize_loss(
      loss,
      training_util.get_global_step(),
      optimizer='Adagrad',
      learning_rate=0.1)
  return model_fn.ModelFnOps(
      mode=mode, predictions=prediction, loss=loss, train_op=train_op)
150
151
def logistic_model_no_mode_fn(features, labels):
  """Three-class logistic-regression model_fn that takes no `mode` argument.

  Args:
    features: Feature tensor/array, or a dict holding it under `'input'`.
    labels: Class indices, or a dict holding them under `'labels'`. They are
      one-hot encoded with depth 3 before being handed to the model.

  Returns:
    A `(predictions, loss, train_op)` tuple. `predictions` is a dict with
    `'class'` (argmax over axis 1 of the model output) and `'prob'` (the
    output of `logistic_regression_zero_init` itself).
  """
  x = extract(features, 'input')
  class_ids = extract(labels, 'labels')
  one_hot_labels = array_ops.one_hot(class_ids, 3, 1, 0)

  probs, loss = models.logistic_regression_zero_init(x, one_hot_labels)
  train_op = optimizers.optimize_loss(
      loss, training_util.get_global_step(), optimizer='Adagrad',
      learning_rate=0.1)

  predictions = {'class': math_ops.argmax(probs, 1), 'prob': probs}
  return predictions, loss, train_op
166
167
# Newline-terminated word lists, one word per line. Neither is referenced by
# the tests in this module.
VOCAB_FILE_CONTENT, EXTRA_FILE_CONTENT = (
    'emerson\nlake\npalmer\n',
    'kermit\npiggy\nralph\n',
)
170
171
class EstimatorInputTest(test.TestCase):
  """Feeds `Estimator`/`SKCompat` via arrays, dicts and input_fns."""

  def testContinueTrainingDictionaryInput(self):
    # Train and evaluate one Estimator. Then check that a fresh Estimator
    # pointed at the same model_dir reproduces the evaluation, and that its
    # predictions agree with the evaluated MSE.
    boston = base.load_boston()
    # Nest the model dir under the test framework's temp dir, as recommended
    # by `TestCase.get_temp_dir`. A bare `tempfile.mkdtemp()` leaks a new
    # checkpoint directory into the system temp location on every run.
    output_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    est = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)
    boston_input = {'input': boston.data}
    float64_target = {'labels': boston.target.astype(np.float64)}
    est.fit(x=boston_input, y=float64_target, steps=50)
    scores = est.evaluate(
        x=boston_input,
        y=float64_target,
        metrics={
            'MSE': metric_ops.streaming_mean_squared_error
        })
    del est
    # Create another estimator object with the same output dir.
    est2 = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)

    # Check we can evaluate and predict.
    scores2 = est2.evaluate(
        x=boston_input,
        y=float64_target,
        metrics={
            'MSE': metric_ops.streaming_mean_squared_error
        })
    self.assertAllClose(scores2['MSE'], scores['MSE'])
    predictions = np.array(list(est2.predict(x=boston_input)))
    other_score = _sklearn.mean_squared_error(predictions,
                                              float64_target['labels'])
    self.assertAllClose(other_score, scores['MSE'])

  def testBostonAll(self):
    # SKCompat wrapper with array input: the MSE from score() must match the
    # MSE recomputed from predict(), and global_step must equal the number
    # of fit steps.
    boston = base.load_boston()
    est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn))
    float64_labels = boston.target.astype(np.float64)
    est.fit(x=boston.data, y=float64_labels, steps=100)
    scores = est.score(
        x=boston.data,
        y=float64_labels,
        metrics={
            'MSE': metric_ops.streaming_mean_squared_error
        })
    predictions = np.array(list(est.predict(x=boston.data)))
    other_score = _sklearn.mean_squared_error(predictions, boston.target)
    self.assertAllClose(scores['MSE'], other_score)
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testBostonAllDictionaryInput(self):
    # Same checks as testBostonAll, using a plain Estimator with dict inputs.
    boston = base.load_boston()
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston_input = {'input': boston.data}
    float64_target = {'labels': boston.target.astype(np.float64)}
    est.fit(x=boston_input, y=float64_target, steps=100)
    scores = est.evaluate(
        x=boston_input,
        y=float64_target,
        metrics={
            'MSE': metric_ops.streaming_mean_squared_error
        })
    predictions = np.array(list(est.predict(x=boston_input)))
    other_score = _sklearn.mean_squared_error(predictions, boston.target)
    self.assertAllClose(other_score, scores['MSE'])
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testIrisAll(self):
    # SKCompat predict() returns a dict of arrays. Check the 'class' output
    # against both a restricted-outputs predict and argmax('prob'), then
    # check the accuracy metric against an independent computation.
    iris = base.load_iris()
    est = estimator.SKCompat(
        estimator.Estimator(model_fn=logistic_model_no_mode_fn))
    est.fit(iris.data, iris.target, steps=100)
    scores = est.score(
        x=iris.data,
        y=iris.target,
        metrics={
            ('accuracy', 'class'): metric_ops.streaming_accuracy
        })
    predictions = est.predict(x=iris.data)
    predictions_class = est.predict(x=iris.data, outputs=['class'])['class']
    self.assertEqual(iris.target.shape[0], predictions['prob'].shape[0])
    self.assertAllClose(predictions['class'], predictions_class)
    self.assertAllClose(predictions['class'],
                        np.argmax(predictions['prob'], axis=1))
    other_score = _sklearn.accuracy_score(iris.target, predictions['class'])
    self.assertAllClose(scores['accuracy'], other_score)
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testIrisAllDictionaryInput(self):
    # Dict-input variant of testIrisAll. Here predict() yields one dict per
    # example, so the per-example outputs are stacked before comparing.
    iris = base.load_iris()
    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    iris_data = {'input': iris.data}
    iris_target = {'labels': iris.target}
    est.fit(iris_data, iris_target, steps=100)
    scores = est.evaluate(
        x=iris_data,
        y=iris_target,
        metrics={
            ('accuracy', 'class'): metric_ops.streaming_accuracy
        })
    predictions = list(est.predict(x=iris_data))
    predictions_class = list(est.predict(x=iris_data, outputs=['class']))
    self.assertEqual(iris.target.shape[0], len(predictions))
    classes_batch = np.array([p['class'] for p in predictions])
    self.assertAllClose(classes_batch,
                        np.array([p['class'] for p in predictions_class]))
    self.assertAllClose(classes_batch,
                        np.argmax(
                            np.array([p['prob'] for p in predictions]), axis=1))
    other_score = _sklearn.accuracy_score(iris.target, classes_batch)
    self.assertAllClose(other_score, scores['accuracy'])
    self.assertIn('global_step', scores)
    self.assertEqual(100, scores['global_step'])

  def testIrisInputFn(self):
    # fit/evaluate through an input_fn; predict from arrays afterwards.
    iris = base.load_iris()
    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    est.fit(input_fn=iris_input_fn, steps=100)
    _ = est.evaluate(input_fn=iris_input_fn, steps=1)
    predictions = list(est.predict(x=iris.data))
    self.assertEqual(iris.target.shape[0], len(predictions))

  def testIrisInputFnLabelsDict(self):
    # Labels arrive as a dict; the MetricSpec selects them via label_key.
    iris = base.load_iris()
    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
    est.fit(input_fn=iris_input_fn_labels_dict, steps=100)
    _ = est.evaluate(
        input_fn=iris_input_fn_labels_dict,
        steps=1,
        metrics={
            'accuracy':
                metric_spec.MetricSpec(
                    metric_fn=metric_ops.streaming_accuracy,
                    prediction_key='class',
                    label_key='labels')
        })
    predictions = list(est.predict(x=iris.data))
    self.assertEqual(iris.target.shape[0], len(predictions))

  def testTrainInputFn(self):
    # Smoke test: one training step and one evaluation step via input_fns.
    est = estimator.Estimator(model_fn=linear_model_fn)
    est.fit(input_fn=boston_input_fn, steps=1)
    _ = est.evaluate(input_fn=boston_eval_fn, steps=1)

  def testPredictInputFn(self):
    # With num_epochs=1, predict yields exactly one pass over the data.
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston = base.load_boston()
    est.fit(input_fn=boston_input_fn, steps=1)
    input_fn = functools.partial(boston_input_fn, num_epochs=1)
    output = list(est.predict(input_fn=input_fn))
    self.assertEqual(boston.target.shape[0], len(output))

  def testPredictInputFnWithQueue(self):
    # An extra registered queue runner must not disturb prediction. With
    # num_epochs=2, predict yields two passes over the data.
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston = base.load_boston()
    est.fit(input_fn=boston_input_fn, steps=1)
    input_fn = functools.partial(boston_input_fn_with_queue, num_epochs=2)
    output = list(est.predict(input_fn=input_fn))
    self.assertEqual(boston.target.shape[0] * 2, len(output))

  def testPredictConstInputFn(self):
    # An input_fn built only from constants (no epoch limit) is expected to
    # yield exactly one prediction per example.
    est = estimator.Estimator(model_fn=linear_model_fn)
    boston = base.load_boston()
    est.fit(input_fn=boston_input_fn, steps=1)

    def input_fn():
      features = array_ops.reshape(
          constant_op.constant(boston.data), [-1, _BOSTON_INPUT_DIM])
      labels = array_ops.reshape(constant_op.constant(boston.target), [-1, 1])
      return features, labels

    output = list(est.predict(input_fn=input_fn))
    self.assertEqual(boston.target.shape[0], len(output))
346
347
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner
  # (tensorflow.python.platform.test.main).
  test.main()
350