• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Estimator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import itertools
23import json
24import os
25import tempfile
26
27import numpy as np
28import six
29from six.moves import xrange  # pylint: disable=redefined-builtin
30
31from google.protobuf import text_format
32
33from tensorflow.contrib import learn
34from tensorflow.contrib import lookup
35from tensorflow.python.training import training_util
36from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
37from tensorflow.contrib.layers.python.layers import optimizers
38from tensorflow.contrib.learn.python.learn import experiment
39from tensorflow.contrib.learn.python.learn import models
40from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
41from tensorflow.contrib.learn.python.learn.datasets import base
42from tensorflow.contrib.learn.python.learn.estimators import _sklearn
43from tensorflow.contrib.learn.python.learn.estimators import constants
44from tensorflow.contrib.learn.python.learn.estimators import estimator
45from tensorflow.contrib.learn.python.learn.estimators import linear
46from tensorflow.contrib.learn.python.learn.estimators import model_fn
47from tensorflow.contrib.learn.python.learn.estimators import run_config
48from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
49from tensorflow.contrib.metrics.python.ops import metric_ops
50from tensorflow.contrib.testing.python.framework import util_test
51from tensorflow.python.client import session as session_lib
52from tensorflow.python.framework import constant_op
53from tensorflow.python.framework import dtypes
54from tensorflow.python.framework import ops
55from tensorflow.python.lib.io import file_io
56from tensorflow.python.ops import array_ops
57from tensorflow.python.ops import check_ops
58from tensorflow.python.ops import control_flow_ops
59from tensorflow.python.ops import math_ops
60from tensorflow.python.ops import parsing_ops
61from tensorflow.python.ops import variables as variables_lib
62from tensorflow.python.platform import gfile
63from tensorflow.python.platform import test
64from tensorflow.python.saved_model import loader
65from tensorflow.python.saved_model import tag_constants
66from tensorflow.python.summary import summary
67from tensorflow.python.training import basic_session_run_hooks
68from tensorflow.python.training import checkpoint_state_pb2
69from tensorflow.python.training import input as input_lib
70from tensorflow.python.training import monitored_session
71from tensorflow.python.training import saver as saver_lib
72from tensorflow.python.training import session_run_hook
73from tensorflow.python.util import compat
74
# Number of feature columns the Boston input fns below reshape their data to.
_BOSTON_INPUT_DIM = 13
# Number of feature columns the Iris input fns below reshape their data to.
_IRIS_INPUT_DIM = 4
77
78
def boston_input_fn(num_epochs=None):
  """Builds feature and label tensors from the Boston dataset.

  Args:
    num_epochs: Forwarded to `input_lib.limit_epochs` for the features tensor.

  Returns:
    A `(features, labels)` tuple. Features are reshaped to
    `[-1, _BOSTON_INPUT_DIM]` and wrapped in `limit_epochs`; labels are
    reshaped to `[-1, 1]`.
  """
  dataset = base.load_boston()
  reshaped_data = array_ops.reshape(
      constant_op.constant(dataset.data), [-1, _BOSTON_INPUT_DIM])
  features = input_lib.limit_epochs(reshaped_data, num_epochs=num_epochs)
  target_tensor = constant_op.constant(dataset.target)
  labels = array_ops.reshape(target_tensor, [-1, 1])
  return features, labels
87
88
def iris_input_fn():
  """Builds feature and label tensors from the Iris dataset.

  Returns:
    A `(features, labels)` tuple. Features are reshaped to
    `[-1, _IRIS_INPUT_DIM]`; labels are flattened to `[-1]`.
  """
  dataset = base.load_iris()
  data_tensor = constant_op.constant(dataset.data)
  features = array_ops.reshape(data_tensor, [-1, _IRIS_INPUT_DIM])
  target_tensor = constant_op.constant(dataset.target)
  labels = array_ops.reshape(target_tensor, [-1])
  return features, labels
95
96
def iris_input_fn_labels_dict():
  """Same as `iris_input_fn`, but wraps the labels in a dict.

  Returns:
    A `(features, labels)` tuple where `labels` is a dict holding the
    flattened target tensor under the key `'labels'`.
  """
  dataset = base.load_iris()
  data_tensor = constant_op.constant(dataset.data)
  features = array_ops.reshape(data_tensor, [-1, _IRIS_INPUT_DIM])
  flat_target = array_ops.reshape(constant_op.constant(dataset.target), [-1])
  return features, {'labels': flat_target}
105
106
def boston_eval_fn():
  """Builds Boston evaluation tensors with every example repeated twice.

  Returns:
    A `(features, labels)` tuple in which the reshaped dataset tensors are
    concatenated with themselves along axis 0.
  """
  dataset = base.load_boston()
  num_rows = len(dataset.target)
  features = array_ops.reshape(
      constant_op.constant(dataset.data), [num_rows, _BOSTON_INPUT_DIM])
  labels = array_ops.reshape(
      constant_op.constant(dataset.target), [num_rows, 1])
  doubled_features = array_ops.concat([features, features], 0)
  doubled_labels = array_ops.concat([labels, labels], 0)
  return doubled_features, doubled_labels
116
117
def extract(data, key):
  """Returns `data[key]` when `data` is a dict, otherwise `data` unchanged.

  Args:
    data: Either a dict or any other object.
    key: Key that must be present when `data` is a dict.

  Returns:
    The value stored under `key` for dict input; `data` itself otherwise.
  """
  if not isinstance(data, dict):
    return data
  assert key in data
  return data[key]
124
125
def linear_model_params_fn(features, labels, mode, params):
  """Linear regression model_fn whose learning rate comes from `params`.

  Args:
    features: Feature tensor, or a dict holding it under `'input'`.
    labels: Label tensor, or a dict holding it under `'labels'`.
    mode: One of the `model_fn.ModeKeys` values.
    params: Dict providing `'learning_rate'`.

  Returns:
    A `(prediction, loss, train_op)` tuple.
  """
  features = extract(features, 'input')
  labels = extract(labels, 'labels')

  valid_modes = (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                 model_fn.ModeKeys.INFER)
  assert mode in valid_modes
  prediction, loss = models.linear_regression_zero_init(features, labels)
  global_step = training_util.get_global_step()
  train_op = optimizers.optimize_loss(
      loss,
      global_step,
      optimizer='Adagrad',
      learning_rate=params['learning_rate'])
  return prediction, loss, train_op
139
140
def linear_model_fn(features, labels, mode):
  """Linear regression model_fn with a fixed Adagrad learning rate of 0.1.

  Args:
    features: Feature tensor, a dict holding it under `'input'`, or (after
      extraction) a single-entry dict whose only value is used.
    labels: Label tensor, or a dict holding it under `'labels'`.
    mode: One of the `model_fn.ModeKeys` values.

  Returns:
    A `(prediction, loss, train_op)` tuple.
  """
  features = extract(features, 'input')
  labels = extract(labels, 'labels')
  valid_modes = (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                 model_fn.ModeKeys.INFER)
  assert mode in valid_modes
  if isinstance(features, dict):
    # Unpacking enforces that the dict holds exactly one feature.
    [features] = features.values()
  prediction, loss = models.linear_regression_zero_init(features, labels)
  global_step = training_util.get_global_step()
  train_op = optimizers.optimize_loss(
      loss, global_step, optimizer='Adagrad', learning_rate=0.1)
  return prediction, loss, train_op
155
156
def linear_model_fn_with_model_fn_ops(features, labels, mode):
  """Linear regression model_fn that returns a `ModelFnOps`.

  Unlike `linear_model_fn`, the features and labels are used as given.

  Args:
    features: Feature tensor passed to the regression helper.
    labels: Label tensor passed to the regression helper.
    mode: One of the `model_fn.ModeKeys` values.

  Returns:
    A `model_fn.ModelFnOps` holding the prediction, loss and train op.
  """
  valid_modes = (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                 model_fn.ModeKeys.INFER)
  assert mode in valid_modes
  prediction, loss = models.linear_regression_zero_init(features, labels)
  global_step = training_util.get_global_step()
  train_op = optimizers.optimize_loss(
      loss, global_step, optimizer='Adagrad', learning_rate=0.1)
  return model_fn.ModelFnOps(
      mode=mode, predictions=prediction, loss=loss, train_op=train_op)
169
170
def logistic_model_no_mode_fn(features, labels):
  """Logistic regression model_fn that takes no `mode` argument.

  Args:
    features: Feature tensor, or a dict holding it under `'input'`.
    labels: Label tensor, or a dict holding it under `'labels'`; it is
      one-hot encoded with depth 3.

  Returns:
    A `(predictions, loss, train_op)` tuple where `predictions` is a dict
    with the argmax under `'class'` and the raw prediction under `'prob'`.
  """
  features = extract(features, 'input')
  labels = extract(labels, 'labels')
  one_hot_labels = array_ops.one_hot(labels, 3, 1, 0)
  prediction, loss = models.logistic_regression_zero_init(
      features, one_hot_labels)
  global_step = training_util.get_global_step()
  train_op = optimizers.optimize_loss(
      loss, global_step, optimizer='Adagrad', learning_rate=0.1)
  predictions = {
      'class': math_ops.argmax(prediction, 1),
      'prob': prediction
  }
  return predictions, loss, train_op
185
186
# Contents written to the vocabulary asset file by the export-test serving
# input fn below.
VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
# Not referenced in this part of the file; presumably the contents of an extra
# asset file used by export tests further down -- confirm against its users.
EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'
189
190
def _build_estimator_for_export_tests(tmpdir):
  """Fits a small `LinearRegressor` and builds a serving input fn with an asset.

  Args:
    tmpdir: Directory in which the vocabulary file is written each time the
      returned serving input fn is called.

  Returns:
    A `(estimator, serving_input_fn)` tuple. The estimator has been fit for 20
    steps on the Iris data; the serving input fn adds a table lookup backed by
    the vocabulary file to the parsed features.
  """

  def _input_fn():
    iris = base.load_iris()
    return {
        'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
    }, constant_op.constant(
        iris.target, shape=[150], dtype=dtypes.int32)

  feature_columns = [
      feature_column_lib.real_valued_column('feature', dimension=4)
  ]

  est = linear.LinearRegressor(feature_columns)
  est.fit(input_fn=_input_fn, steps=20)

  feature_spec = feature_column_lib.create_feature_spec_for_parsing(
      feature_columns)
  serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)

  # hack in an op that uses an asset, in order to test asset export.
  # this is not actually valid, of course.
  def serving_input_fn_with_asset():
    features, labels, inputs = serving_input_fn()

    vocab_file_name = os.path.join(tmpdir, 'my_vocab_file')
    # The context manager closes the file even if the write fails.
    with gfile.GFile(vocab_file_name, mode='w') as vocab_file:
      vocab_file.write(VOCAB_FILE_CONTENT)
    hashtable = lookup.HashTable(
        lookup.TextFileStringTableInitializer(vocab_file_name), 'x')
    features['bogus_lookup'] = hashtable.lookup(
        math_ops.cast(features['feature'], dtypes.int64))

    return input_fn_utils.InputFnOps(features, labels, inputs)

  return est, serving_input_fn_with_asset
228
229
def _build_estimator_for_resource_export_test():
  """Fits an `Estimator` whose model_fn is backed by mutable hash tables.

  Returns:
    A `(estimator, serving_input_fn)` tuple. The estimator has been fit for one
    step; the serving input fn parses a single 4-dimensional `'feature'`
    column.
  """

  def _input_fn():
    iris = base.load_iris()
    feature_dict = {
        'feature': constant_op.constant(iris.data, dtype=dtypes.float32)
    }
    target = constant_op.constant(iris.target, shape=[150], dtype=dtypes.int32)
    return feature_dict, target

  feature_columns = [
      feature_column_lib.real_valued_column('feature', dimension=4)
  ]

  def resource_constant_model_fn(unused_features, unused_labels, mode):
    """A model_fn that loads a constant from a resource and serves it."""
    valid_modes = (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
                   model_fn.ModeKeys.INFER)
    assert mode in valid_modes

    const = constant_op.constant(-1, dtype=dtypes.int64)
    model_table = lookup.MutableHashTable(
        dtypes.string, dtypes.int64, const, name='LookupTableModel')
    update_global_step = training_util.get_global_step().assign_add(1)
    if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL):
      key = constant_op.constant(['key'])
      value = constant_op.constant([42], dtype=dtypes.int64)
      insert_model_op = model_table.insert(key, value)
      # A second table that is only created in the TRAIN/EVAL graphs.
      training_state_table = lookup.MutableHashTable(
          dtypes.string, dtypes.int64, const, name='LookupTableTrainingState')
      insert_state_op = training_state_table.insert(key, value)
      train_op = control_flow_ops.group(insert_model_op, insert_state_op,
                                        update_global_step)
      return const, const, train_op
    if mode == model_fn.ModeKeys.INFER:
      key = constant_op.constant(['key'])
      prediction = model_table.lookup(key)
      return prediction, const, update_global_step

  est = estimator.Estimator(model_fn=resource_constant_model_fn)
  est.fit(input_fn=_input_fn, steps=1)

  feature_spec = feature_column_lib.create_feature_spec_for_parsing(
      feature_columns)
  serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
  return est, serving_input_fn
274
275
class CheckCallsMonitor(monitors_lib.BaseMonitor):
  """Monitor that counts step callbacks and checks the totals in `end`."""

  def __init__(self, expect_calls):
    """Stores the number of `step_begin`/`step_end` calls expected by `end`."""
    super(CheckCallsMonitor, self).__init__()
    self.expect_calls = expect_calls
    # Counters stay `None` until `begin` resets them to zero.
    self.begin_calls = None
    self.end_calls = None

  def begin(self, max_steps):
    """Resets both call counters to zero."""
    self.begin_calls = 0
    self.end_calls = 0

  def step_begin(self, step):
    """Counts the call and returns an empty dict."""
    self.begin_calls += 1
    return dict()

  def step_end(self, step, outputs):
    """Counts the call and returns False."""
    self.end_calls += 1
    return False

  def end(self):
    """Asserts that both counters equal `expect_calls`."""
    assert self.end_calls == self.expect_calls
    assert self.begin_calls == self.expect_calls
299
300
def _model_fn_ops(expected_features, expected_labels, actual_features,
                  actual_labels, mode):
  """Builds a trivial `ModelFnOps` guarded by equality asserts on the inputs.

  Args:
    expected_features: Dict of expected feature values, keyed by name.
    expected_labels: Expected labels value.
    actual_features: Dict of features received by the model_fn.
    actual_labels: Labels received by the model_fn.
    mode: Mode to store on the returned `ModelFnOps`.

  Returns:
    A `model_fn.ModelFnOps` with constant predictions and loss, and a train
    op that increments the global step, all created under control
    dependencies on the assert ops.
  """
  feature_asserts = [
      check_ops.assert_equal(
          expected_features[name],
          actual_features[name],
          name='assert_%s' % name) for name in expected_features
  ]
  labels_assert = check_ops.assert_equal(
      expected_labels, actual_labels, name='assert_labels')
  assert_ops = tuple(feature_asserts + [labels_assert])
  with ops.control_dependencies(assert_ops):
    predictions = constant_op.constant(0.)
    loss = constant_op.constant(0.)
    train_op = training_util.get_global_step().assign_add(1)
    return model_fn.ModelFnOps(
        mode=mode, predictions=predictions, loss=loss, train_op=train_op)
317
318
def _make_input_fn(features, labels):
  """Returns an input fn that wraps the given values in constant tensors.

  Args:
    features: Dict mapping feature names to values.
    labels: Labels value.

  Returns:
    A zero-argument callable producing a `(features, labels)` tuple of
    constant tensors.
  """

  def _input_fn():
    feature_tensors = {}
    for name, value in six.iteritems(features):
      feature_tensors[name] = constant_op.constant(value)
    return feature_tensors, constant_op.constant(labels)

  return _input_fn
326
327
class EstimatorModelFnTest(test.TestCase):
  """Tests how `Estimator` calls `model_fn` and validates what it returns."""

  def testModelFnArgs(self):
    """`model_fn` receives features, labels, mode, params and config."""
    features = {'x': 42., 'y': 43.}
    labels = 44.
    expected_params = {'some_param': 'some_value'}
    expected_config = run_config.RunConfig()
    expected_config.i_am_test = True

    # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments
    # doesn't work with mock fns.
    model_fn_call_count = [0]

    # `features` and `labels` are passed by position, `arg0` and `arg1` here.
    def _model_fn(arg0, arg1, mode, params, config):
      model_fn_call_count[0] += 1
      self.assertItemsEqual(features.keys(), arg0.keys())
      self.assertEqual(model_fn.ModeKeys.TRAIN, mode)
      self.assertEqual(expected_params, params)
      self.assertTrue(config.i_am_test)
      return _model_fn_ops(features, labels, arg0, arg1, mode)

    est = estimator.Estimator(
        model_fn=_model_fn, params=expected_params, config=expected_config)
    self.assertEqual(0, model_fn_call_count[0])
    est.fit(input_fn=_make_input_fn(features, labels), steps=1)
    self.assertEqual(1, model_fn_call_count[0])

  def testPartialModelFnArgs(self):
    """A `functools.partial` model_fn keeps its bound keyword arguments."""
    features = {'x': 42., 'y': 43.}
    labels = 44.
    expected_params = {'some_param': 'some_value'}
    expected_config = run_config.RunConfig()
    expected_config.i_am_test = True
    expected_foo = 45.
    expected_bar = 46.

    # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments
    # doesn't work with mock fns.
    model_fn_call_count = [0]

    # `features` and `labels` are passed by position, `arg0` and `arg1` here.
    def _model_fn(arg0, arg1, foo, mode, params, config, bar):
      model_fn_call_count[0] += 1
      self.assertEqual(expected_foo, foo)
      self.assertEqual(expected_bar, bar)
      self.assertItemsEqual(features.keys(), arg0.keys())
      self.assertEqual(model_fn.ModeKeys.TRAIN, mode)
      self.assertEqual(expected_params, params)
      self.assertTrue(config.i_am_test)
      return _model_fn_ops(features, labels, arg0, arg1, mode)

    partial_model_fn = functools.partial(
        _model_fn, foo=expected_foo, bar=expected_bar)

    est = estimator.Estimator(
        model_fn=partial_model_fn,
        params=expected_params,
        config=expected_config)
    self.assertEqual(0, model_fn_call_count[0])
    est.fit(input_fn=_make_input_fn(features, labels), steps=1)
    self.assertEqual(1, model_fn_call_count[0])

  def testModelFnWithModelDir(self):
    """`model_fn` receives the estimator's `model_dir` when it asks for it."""
    expected_param = {'some_param': 'some_value'}
    expected_model_dir = tempfile.mkdtemp()

    def _argument_checker(features,
                          labels,
                          mode,
                          params,
                          config=None,
                          model_dir=None):
      _, _, _ = features, labels, config
      self.assertEqual(model_fn.ModeKeys.TRAIN, mode)
      self.assertEqual(expected_param, params)
      self.assertEqual(model_dir, expected_model_dir)
      return (constant_op.constant(0.), constant_op.constant(0.),
              training_util.get_global_step().assign_add(1))

    est = estimator.Estimator(
        model_fn=_argument_checker,
        params=expected_param,
        model_dir=expected_model_dir)
    est.fit(input_fn=boston_input_fn, steps=1)

  def testInvalidModelFn_no_train_op(self):
    """`fit` raises when the model_fn returns no train op."""

    def _invalid_model_fn(features, labels):
      # pylint: disable=unused-argument
      # The name must be passed by keyword: the second positional parameter
      # of `Variable` is `trainable`, not `name`.
      w = variables_lib.Variable(42.0, name='weight')
      update_global_step = training_util.get_global_step().assign_add(1)
      with ops.control_dependencies([update_global_step]):
        loss = 100.0 - w
      return None, loss, None

    est = estimator.Estimator(model_fn=_invalid_model_fn)
    with self.assertRaisesRegexp(ValueError, 'Missing train_op'):
      est.fit(input_fn=boston_input_fn, steps=1)

  def testInvalidModelFn_no_loss(self):
    """`evaluate` raises when the model_fn returns no loss in EVAL mode."""

    def _invalid_model_fn(features, labels, mode):
      # pylint: disable=unused-argument
      # The name must be passed by keyword: the second positional parameter
      # of `Variable` is `trainable`, not `name`.
      w = variables_lib.Variable(42.0, name='weight')
      loss = 100.0 - w
      update_global_step = training_util.get_global_step().assign_add(1)
      with ops.control_dependencies([update_global_step]):
        train_op = w.assign_add(loss / 100.0)
      predictions = loss
      if mode == model_fn.ModeKeys.EVAL:
        loss = None
      return predictions, loss, train_op

    est = estimator.Estimator(model_fn=_invalid_model_fn)
    est.fit(input_fn=boston_input_fn, steps=1)
    with self.assertRaisesRegexp(ValueError, 'Missing loss'):
      est.evaluate(input_fn=boston_eval_fn, steps=1)

  def testInvalidModelFn_no_prediction(self):
    """`evaluate` and `predict` raise when the model_fn has no predictions."""

    def _invalid_model_fn(features, labels):
      # pylint: disable=unused-argument
      # The name must be passed by keyword: the second positional parameter
      # of `Variable` is `trainable`, not `name`.
      w = variables_lib.Variable(42.0, name='weight')
      loss = 100.0 - w
      update_global_step = training_util.get_global_step().assign_add(1)
      with ops.control_dependencies([update_global_step]):
        train_op = w.assign_add(loss / 100.0)
      return None, loss, train_op

    est = estimator.Estimator(model_fn=_invalid_model_fn)
    est.fit(input_fn=boston_input_fn, steps=1)
    with self.assertRaisesRegexp(ValueError, 'Missing prediction'):
      est.evaluate(input_fn=boston_eval_fn, steps=1)
    with self.assertRaisesRegexp(ValueError, 'Missing prediction'):
      est.predict(input_fn=boston_input_fn)
    with self.assertRaisesRegexp(ValueError, 'Missing prediction'):
      est.predict(
          input_fn=functools.partial(boston_input_fn, num_epochs=1),
          as_iterable=True)

  def testModelFnScaffoldInTraining(self):
    """The `init_fn` of a scaffold returned by model_fn runs during `fit`."""
    self.is_init_fn_called = False

    def _init_fn(scaffold, session):
      _, _ = scaffold, session
      self.is_init_fn_called = True

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      return model_fn.ModelFnOps(
          mode=mode,
          predictions=constant_op.constant(0.),
          loss=constant_op.constant(0.),
          train_op=training_util.get_global_step().assign_add(1),
          scaffold=monitored_session.Scaffold(init_fn=_init_fn))

    est = estimator.Estimator(model_fn=_model_fn_scaffold)
    est.fit(input_fn=boston_input_fn, steps=1)
    self.assertTrue(self.is_init_fn_called)

  def testModelFnScaffoldSaverUsage(self):
    """The saver of a scaffold returned by model_fn is used to save/restore."""

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      # The name must be passed by keyword: the second positional parameter
      # of `Variable` is `trainable`, not `name`.
      variables_lib.Variable(1., name='weight')
      real_saver = saver_lib.Saver()
      # A fresh mock wrapping the real saver is stored on each model_fn call.
      self.mock_saver = test.mock.Mock(
          wraps=real_saver, saver_def=real_saver.saver_def)
      return model_fn.ModelFnOps(
          mode=mode,
          predictions=constant_op.constant([[1.]]),
          loss=constant_op.constant(0.),
          train_op=training_util.get_global_step().assign_add(1),
          scaffold=monitored_session.Scaffold(saver=self.mock_saver))

    def input_fn():
      return {
          'x': constant_op.constant([[1.]]),
      }, constant_op.constant([[1.]])

    est = estimator.Estimator(model_fn=_model_fn_scaffold)
    est.fit(input_fn=input_fn, steps=1)
    self.assertTrue(self.mock_saver.save.called)
    est.evaluate(input_fn=input_fn, steps=1)
    self.assertTrue(self.mock_saver.restore.called)
    est.predict(input_fn=input_fn)
    self.assertTrue(self.mock_saver.restore.called)

    def serving_input_fn():
      serialized_tf_example = array_ops.placeholder(
          dtype=dtypes.string, shape=[None], name='input_example_tensor')
      features, labels = input_fn()
      return input_fn_utils.InputFnOps(features, labels, {
          'examples': serialized_tf_example
      })

    est.export_savedmodel(
        os.path.join(est.model_dir, 'export'), serving_input_fn)
    self.assertTrue(self.mock_saver.restore.called)
528
529
530class EstimatorTest(test.TestCase):
531
532  def testExperimentIntegration(self):
533    exp = experiment.Experiment(
534        estimator=estimator.Estimator(model_fn=linear_model_fn),
535        train_input_fn=boston_input_fn,
536        eval_input_fn=boston_input_fn)
537    exp.test()
538
539  def testCheckpointSaverHookSuppressesTheDefaultOne(self):
540    saver_hook = test.mock.Mock(
541        spec=basic_session_run_hooks.CheckpointSaverHook)
542    saver_hook.before_run.return_value = None
543    est = estimator.Estimator(model_fn=linear_model_fn)
544    est.fit(input_fn=boston_input_fn, steps=1, monitors=[saver_hook])
545    # test nothing is saved, due to suppressing default saver
546    with self.assertRaises(learn.NotFittedError):
547      est.evaluate(input_fn=boston_input_fn, steps=1)
548
549  def testCustomConfig(self):
550    test_random_seed = 5783452
551
552    class TestInput(object):
553
554      def __init__(self):
555        self.random_seed = 0
556
557      def config_test_input_fn(self):
558        self.random_seed = ops.get_default_graph().seed
559        return constant_op.constant([[1.]]), constant_op.constant([1.])
560
561    config = run_config.RunConfig(tf_random_seed=test_random_seed)
562    test_input = TestInput()
563    est = estimator.Estimator(model_fn=linear_model_fn, config=config)
564    est.fit(input_fn=test_input.config_test_input_fn, steps=1)
565    # If input_fn ran, it will have given us the random seed set on the graph.
566    self.assertEquals(test_random_seed, test_input.random_seed)
567
568  def testRunConfigModelDir(self):
569    config = run_config.RunConfig(model_dir='test_dir')
570    est = estimator.Estimator(model_fn=linear_model_fn, config=config)
571    self.assertEqual('test_dir', est.config.model_dir)
572    self.assertEqual('test_dir', est.model_dir)
573
574  def testModelDirAndRunConfigModelDir(self):
575    config = run_config.RunConfig(model_dir='test_dir')
576    est = estimator.Estimator(
577        model_fn=linear_model_fn, config=config, model_dir='test_dir')
578    self.assertEqual('test_dir', est.config.model_dir)
579
580    with self.assertRaisesRegexp(
581        ValueError, 'model_dir are set both in constructor and RunConfig, '
582        'but with different'):
583      estimator.Estimator(
584          model_fn=linear_model_fn, config=config, model_dir='different_dir')
585
586  def testModelDirIsCopiedToRunConfig(self):
587    config = run_config.RunConfig()
588    self.assertIsNone(config.model_dir)
589
590    est = estimator.Estimator(
591        model_fn=linear_model_fn, model_dir='test_dir', config=config)
592    self.assertEqual('test_dir', est.config.model_dir)
593    self.assertEqual('test_dir', est.model_dir)
594
595  def testModelDirAsTempDir(self):
596    with test.mock.patch.object(tempfile, 'mkdtemp', return_value='temp_dir'):
597      est = estimator.Estimator(model_fn=linear_model_fn)
598      self.assertEqual('temp_dir', est.config.model_dir)
599      self.assertEqual('temp_dir', est.model_dir)
600
601  def testCheckInputs(self):
602    est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn))
603    # Lambdas so we have to different objects to compare
604    right_features = lambda: np.ones(shape=[7, 8], dtype=np.float32)
605    right_labels = lambda: np.ones(shape=[7, 10], dtype=np.int32)
606    est.fit(right_features(), right_labels(), steps=1)
607    # TODO(wicke): This does not fail for np.int32 because of data_feeder magic.
608    wrong_type_features = np.ones(shape=[7, 8], dtype=np.int64)
609    wrong_size_features = np.ones(shape=[7, 10])
610    wrong_type_labels = np.ones(shape=[7, 10], dtype=np.float32)
611    wrong_size_labels = np.ones(shape=[7, 11])
612    est.fit(x=right_features(), y=right_labels(), steps=1)
613    with self.assertRaises(ValueError):
614      est.fit(x=wrong_type_features, y=right_labels(), steps=1)
615    with self.assertRaises(ValueError):
616      est.fit(x=wrong_size_features, y=right_labels(), steps=1)
617    with self.assertRaises(ValueError):
618      est.fit(x=right_features(), y=wrong_type_labels, steps=1)
619    with self.assertRaises(ValueError):
620      est.fit(x=right_features(), y=wrong_size_labels, steps=1)
621
622  def testBadInput(self):
623    est = estimator.Estimator(model_fn=linear_model_fn)
624    self.assertRaisesRegexp(
625        ValueError,
626        'Either x or input_fn must be provided.',
627        est.fit,
628        x=None,
629        input_fn=None,
630        steps=1)
631    self.assertRaisesRegexp(
632        ValueError,
633        'Can not provide both input_fn and x or y',
634        est.fit,
635        x='X',
636        input_fn=iris_input_fn,
637        steps=1)
638    self.assertRaisesRegexp(
639        ValueError,
640        'Can not provide both input_fn and x or y',
641        est.fit,
642        y='Y',
643        input_fn=iris_input_fn,
644        steps=1)
645    self.assertRaisesRegexp(
646        ValueError,
647        'Can not provide both input_fn and batch_size',
648        est.fit,
649        input_fn=iris_input_fn,
650        batch_size=100,
651        steps=1)
652    self.assertRaisesRegexp(
653        ValueError,
654        'Inputs cannot be tensors. Please provide input_fn.',
655        est.fit,
656        x=constant_op.constant(1.),
657        steps=1)
658
659  def testUntrained(self):
660    boston = base.load_boston()
661    est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn))
662    with self.assertRaises(learn.NotFittedError):
663      _ = est.score(x=boston.data, y=boston.target.astype(np.float64))
664    with self.assertRaises(learn.NotFittedError):
665      est.predict(x=boston.data)
666
667  def testContinueTraining(self):
668    boston = base.load_boston()
669    output_dir = tempfile.mkdtemp()
670    est = estimator.SKCompat(
671        estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir))
672    float64_labels = boston.target.astype(np.float64)
673    est.fit(x=boston.data, y=float64_labels, steps=50)
674    scores = est.score(
675        x=boston.data,
676        y=float64_labels,
677        metrics={
678            'MSE': metric_ops.streaming_mean_squared_error
679        })
680    del est
681    # Create another estimator object with the same output dir.
682    est2 = estimator.SKCompat(
683        estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir))
684
685    # Check we can evaluate and predict.
686    scores2 = est2.score(
687        x=boston.data,
688        y=float64_labels,
689        metrics={
690            'MSE': metric_ops.streaming_mean_squared_error
691        })
692    self.assertAllClose(scores['MSE'], scores2['MSE'])
693    predictions = np.array(list(est2.predict(x=boston.data)))
694    other_score = _sklearn.mean_squared_error(predictions, float64_labels)
695    self.assertAllClose(scores['MSE'], other_score)
696
697    # Check we can keep training.
698    est2.fit(x=boston.data, y=float64_labels, steps=100)
699    scores3 = est2.score(
700        x=boston.data,
701        y=float64_labels,
702        metrics={
703            'MSE': metric_ops.streaming_mean_squared_error
704        })
705    self.assertLess(scores3['MSE'], scores['MSE'])
706
707  def test_checkpoint_contains_relative_paths(self):
708    tmpdir = tempfile.mkdtemp()
709    est = estimator.Estimator(
710        model_dir=tmpdir, model_fn=linear_model_fn_with_model_fn_ops)
711    est.fit(input_fn=boston_input_fn, steps=5)
712
713    checkpoint_file_content = file_io.read_file_to_string(
714        os.path.join(tmpdir, 'checkpoint'))
715    ckpt = checkpoint_state_pb2.CheckpointState()
716    text_format.Merge(checkpoint_file_content, ckpt)
717    self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
718    # TODO(b/78461127): Please modify tests to not directly rely on names of
719    # checkpoints.
720    self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],
721                        ckpt.all_model_checkpoint_paths)
722
723  def test_train_save_copy_reload(self):
724    tmpdir = tempfile.mkdtemp()
725    model_dir1 = os.path.join(tmpdir, 'model_dir1')
726    est1 = estimator.Estimator(
727        model_dir=model_dir1, model_fn=linear_model_fn_with_model_fn_ops)
728    est1.fit(input_fn=boston_input_fn, steps=5)
729
730    model_dir2 = os.path.join(tmpdir, 'model_dir2')
731    os.renames(model_dir1, model_dir2)
732    est2 = estimator.Estimator(
733        model_dir=model_dir2, model_fn=linear_model_fn_with_model_fn_ops)
734    self.assertEqual(5, est2.get_variable_value('global_step'))
735    est2.fit(input_fn=boston_input_fn, steps=5)
736    self.assertEqual(10, est2.get_variable_value('global_step'))
737
738  def testEstimatorParams(self):
739    boston = base.load_boston()
740    est = estimator.SKCompat(
741        estimator.Estimator(
742            model_fn=linear_model_params_fn, params={
743                'learning_rate': 0.01
744            }))
745    est.fit(x=boston.data, y=boston.target, steps=100)
746
747  def testHooksNotChanged(self):
748    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
749    # We pass empty array and expect it to remain empty after calling
750    # fit and evaluate. Requires inside to copy this array if any hooks were
751    # added.
752    my_array = []
753    est.fit(input_fn=iris_input_fn, steps=100, monitors=my_array)
754    _ = est.evaluate(input_fn=iris_input_fn, steps=1, hooks=my_array)
755    self.assertEqual(my_array, [])
756
757  def testIrisIterator(self):
758    iris = base.load_iris()
759    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
760    x_iter = itertools.islice(iris.data, 100)
761    y_iter = itertools.islice(iris.target, 100)
762    estimator.SKCompat(est).fit(x_iter, y_iter, steps=20)
763    eval_result = est.evaluate(input_fn=iris_input_fn, steps=1)
764    x_iter_eval = itertools.islice(iris.data, 100)
765    y_iter_eval = itertools.islice(iris.target, 100)
766    score_result = estimator.SKCompat(est).score(x_iter_eval, y_iter_eval)
767    print(score_result)
768    self.assertItemsEqual(eval_result.keys(), score_result.keys())
769    self.assertItemsEqual(['global_step', 'loss'], score_result.keys())
770    predictions = estimator.SKCompat(est).predict(x=iris.data)['class']
771    self.assertEqual(len(predictions), iris.target.shape[0])
772
773  def testIrisIteratorArray(self):
774    iris = base.load_iris()
775    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
776    x_iter = itertools.islice(iris.data, 100)
777    y_iter = (np.array(x) for x in iris.target)
778    est.fit(x_iter, y_iter, steps=100)
779    _ = est.evaluate(input_fn=iris_input_fn, steps=1)
780    _ = six.next(est.predict(x=iris.data))['class']
781
782  def testIrisIteratorPlainInt(self):
783    iris = base.load_iris()
784    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
785    x_iter = itertools.islice(iris.data, 100)
786    y_iter = (v for v in iris.target)
787    est.fit(x_iter, y_iter, steps=100)
788    _ = est.evaluate(input_fn=iris_input_fn, steps=1)
789    _ = six.next(est.predict(x=iris.data))['class']
790
791  def testIrisTruncatedIterator(self):
792    iris = base.load_iris()
793    est = estimator.Estimator(model_fn=logistic_model_no_mode_fn)
794    x_iter = itertools.islice(iris.data, 50)
795    y_iter = ([np.int32(v)] for v in iris.target)
796    est.fit(x_iter, y_iter, steps=100)
797
798  def testTrainStepsIsIncremental(self):
799    est = estimator.Estimator(model_fn=linear_model_fn)
800    est.fit(input_fn=boston_input_fn, steps=10)
801    self.assertEqual(10, est.get_variable_value('global_step'))
802    est.fit(input_fn=boston_input_fn, steps=15)
803    self.assertEqual(25, est.get_variable_value('global_step'))
804
805  def testTrainMaxStepsIsNotIncremental(self):
806    est = estimator.Estimator(model_fn=linear_model_fn)
807    est.fit(input_fn=boston_input_fn, max_steps=10)
808    self.assertEqual(10, est.get_variable_value('global_step'))
809    est.fit(input_fn=boston_input_fn, max_steps=15)
810    self.assertEqual(15, est.get_variable_value('global_step'))
811
812  def testPredict(self):
813    est = estimator.Estimator(model_fn=linear_model_fn)
814    boston = base.load_boston()
815    est.fit(input_fn=boston_input_fn, steps=1)
816    output = list(est.predict(x=boston.data, batch_size=10))
817    self.assertEqual(len(output), boston.target.shape[0])
818
819  def testWithModelFnOps(self):
820    """Test for model_fn that returns `ModelFnOps`."""
821    est = estimator.Estimator(model_fn=linear_model_fn_with_model_fn_ops)
822    boston = base.load_boston()
823    est.fit(input_fn=boston_input_fn, steps=1)
824    input_fn = functools.partial(boston_input_fn, num_epochs=1)
825    scores = est.evaluate(input_fn=input_fn, steps=1)
826    self.assertIn('loss', scores.keys())
827    output = list(est.predict(input_fn=input_fn))
828    self.assertEqual(len(output), boston.target.shape[0])
829
830  def testWrongInput(self):
831
832    def other_input_fn():
833      return {
834          'other': constant_op.constant([0, 0, 0])
835      }, constant_op.constant([0, 0, 0])
836
837    est = estimator.Estimator(model_fn=linear_model_fn)
838    est.fit(input_fn=boston_input_fn, steps=1)
839    with self.assertRaises(ValueError):
840      est.fit(input_fn=other_input_fn, steps=1)
841
842  def testMonitorsForFit(self):
843    est = estimator.Estimator(model_fn=linear_model_fn)
844    est.fit(
845        input_fn=boston_input_fn,
846        steps=21,
847        monitors=[CheckCallsMonitor(expect_calls=21)])
848
849  def testHooksForEvaluate(self):
850
851    class CheckCallHook(session_run_hook.SessionRunHook):
852
853      def __init__(self):
854        self.run_count = 0
855
856      def after_run(self, run_context, run_values):
857        self.run_count += 1
858
859    est = learn.Estimator(model_fn=linear_model_fn)
860    est.fit(input_fn=boston_input_fn, steps=1)
861    hook = CheckCallHook()
862    est.evaluate(input_fn=boston_eval_fn, steps=3, hooks=[hook])
863
864    self.assertEqual(3, hook.run_count)
865
866  def testSummaryWriting(self):
867    est = estimator.Estimator(model_fn=linear_model_fn)
868    est.fit(input_fn=boston_input_fn, steps=200)
869    est.evaluate(input_fn=boston_input_fn, steps=200)
870    loss_summary = util_test.simple_values_from_events(
871        util_test.latest_events(est.model_dir), ['OptimizeLoss/loss'])
872    self.assertEqual(1, len(loss_summary))
873
874  def testSummaryWritingWithSummaryProto(self):
875
876    def _streaming_mean_squared_error_histogram(predictions,
877                                                labels,
878                                                weights=None,
879                                                metrics_collections=None,
880                                                updates_collections=None,
881                                                name=None):
882      metrics, update_ops = metric_ops.streaming_mean_squared_error(
883          predictions,
884          labels,
885          weights=weights,
886          metrics_collections=metrics_collections,
887          updates_collections=updates_collections,
888          name=name)
889      return summary.histogram('histogram', metrics), update_ops
890
891    est = estimator.Estimator(model_fn=linear_model_fn)
892    est.fit(input_fn=boston_input_fn, steps=200)
893    est.evaluate(
894        input_fn=boston_input_fn,
895        steps=200,
896        metrics={
897            'MSE': _streaming_mean_squared_error_histogram
898        })
899    events = util_test.latest_events(est.model_dir + '/eval')
900    output_values = {}
901    for e in events:
902      if e.HasField('summary'):
903        for v in e.summary.value:
904          output_values[v.tag] = v
905    self.assertTrue('MSE' in output_values)
906    self.assertTrue(output_values['MSE'].HasField('histo'))
907
908  def testSummaryWritingWithTensor(self):
909
910    def _streaming_precition_mean_tensor(predictions,
911                                         weights=None,
912                                         metrics_collections=None,
913                                         updates_collections=None,
914                                         name=None):
915      return metric_ops.streaming_mean_tensor(
916          predictions,
917          weights=weights,
918          metrics_collections=metrics_collections,
919          updates_collections=updates_collections,
920          name=name)
921
922    est = estimator.Estimator(model_fn=linear_model_fn)
923    est.fit(input_fn=boston_input_fn, steps=200)
924    est.evaluate(
925        input_fn=boston_input_fn,
926        steps=200,
927        metrics={
928            'PMT': _streaming_precition_mean_tensor
929        })
930    events = util_test.latest_events(est.model_dir + '/eval')
931    output_values = {}
932    for e in events:
933      if e.HasField('summary'):
934        for v in e.summary.value:
935          output_values[v.tag] = v
936    self.assertTrue('PMT' in output_values)
937    self.assertTrue(output_values['PMT'].HasField('tensor'))
938
939  def testLossInGraphCollection(self):
940
941    class _LossCheckerHook(session_run_hook.SessionRunHook):
942
943      def begin(self):
944        self.loss_collection = ops.get_collection(ops.GraphKeys.LOSSES)
945
946    hook = _LossCheckerHook()
947    est = estimator.Estimator(model_fn=linear_model_fn)
948    est.fit(input_fn=boston_input_fn, steps=200, monitors=[hook])
949    self.assertTrue(hook.loss_collection)
950
951  def test_export_returns_exported_dirname(self):
952    expected = '/path/to/some_dir'
953    with test.mock.patch.object(estimator, 'export') as mock_export_module:
954      mock_export_module._export_estimator.return_value = expected
955
956      est = estimator.Estimator(model_fn=linear_model_fn)
957      actual = est.export('/path/to')
958
959    self.assertEquals(expected, actual)
960
961  def test_export_savedmodel(self):
962    tmpdir = tempfile.mkdtemp()
963    est, serving_input_fn = _build_estimator_for_export_tests(tmpdir)
964
965    extra_file_name = os.path.join(
966        compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file'))
967    extra_file = gfile.GFile(extra_file_name, mode='w')
968    extra_file.write(EXTRA_FILE_CONTENT)
969    extra_file.close()
970    assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
971
972    export_dir_base = os.path.join(
973        compat.as_bytes(tmpdir), compat.as_bytes('export'))
974    export_dir = est.export_savedmodel(
975        export_dir_base, serving_input_fn, assets_extra=assets_extra)
976
977    self.assertTrue(gfile.Exists(export_dir_base))
978    self.assertTrue(gfile.Exists(export_dir))
979    self.assertTrue(
980        gfile.Exists(
981            os.path.join(
982                compat.as_bytes(export_dir),
983                compat.as_bytes('saved_model.pb'))))
984    self.assertTrue(
985        gfile.Exists(
986            os.path.join(
987                compat.as_bytes(export_dir), compat.as_bytes('variables'))))
988    self.assertTrue(
989        gfile.Exists(
990            os.path.join(
991                compat.as_bytes(export_dir),
992                compat.as_bytes('variables/variables.index'))))
993    self.assertTrue(
994        gfile.Exists(
995            os.path.join(
996                compat.as_bytes(export_dir),
997                compat.as_bytes('variables/variables.data-00000-of-00001'))))
998
999    self.assertTrue(
1000        gfile.Exists(
1001            os.path.join(
1002                compat.as_bytes(export_dir), compat.as_bytes('assets'))))
1003    self.assertTrue(
1004        gfile.Exists(
1005            os.path.join(
1006                compat.as_bytes(export_dir),
1007                compat.as_bytes('assets/my_vocab_file'))))
1008    self.assertEqual(
1009        compat.as_bytes(VOCAB_FILE_CONTENT),
1010        compat.as_bytes(
1011            gfile.GFile(
1012                os.path.join(
1013                    compat.as_bytes(export_dir),
1014                    compat.as_bytes('assets/my_vocab_file'))).read()))
1015
1016    expected_extra_path = os.path.join(
1017        compat.as_bytes(export_dir),
1018        compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
1019    self.assertTrue(
1020        gfile.Exists(
1021            os.path.join(
1022                compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))))
1023    self.assertTrue(gfile.Exists(expected_extra_path))
1024    self.assertEqual(
1025        compat.as_bytes(EXTRA_FILE_CONTENT),
1026        compat.as_bytes(gfile.GFile(expected_extra_path).read()))
1027
1028    expected_vocab_file = os.path.join(
1029        compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
1030    # Restore, to validate that the export was well-formed.
1031    with ops.Graph().as_default() as graph:
1032      with session_lib.Session(graph=graph) as sess:
1033        loader.load(sess, [tag_constants.SERVING], export_dir)
1034        assets = [
1035            x.eval()
1036            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
1037        ]
1038        self.assertItemsEqual([expected_vocab_file], assets)
1039        graph_ops = [x.name for x in graph.get_operations()]
1040        self.assertTrue('input_example_tensor' in graph_ops)
1041        self.assertTrue('ParseExample/ParseExample' in graph_ops)
1042        self.assertTrue('linear/linear/feature/matmul' in graph_ops)
1043        self.assertItemsEqual(['bogus_lookup', 'feature'], [
1044            compat.as_str_any(x)
1045            for x in graph.get_collection(
1046                constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)
1047        ])
1048
1049    # cleanup
1050    gfile.DeleteRecursively(tmpdir)
1051
1052  def test_export_savedmodel_with_resource(self):
1053    tmpdir = tempfile.mkdtemp()
1054    est, serving_input_fn = _build_estimator_for_resource_export_test()
1055
1056    export_dir_base = os.path.join(
1057        compat.as_bytes(tmpdir), compat.as_bytes('export'))
1058    export_dir = est.export_savedmodel(export_dir_base, serving_input_fn)
1059
1060    self.assertTrue(gfile.Exists(export_dir_base))
1061    self.assertTrue(gfile.Exists(export_dir))
1062    self.assertTrue(
1063        gfile.Exists(
1064            os.path.join(
1065                compat.as_bytes(export_dir),
1066                compat.as_bytes('saved_model.pb'))))
1067    self.assertTrue(
1068        gfile.Exists(
1069            os.path.join(
1070                compat.as_bytes(export_dir), compat.as_bytes('variables'))))
1071    self.assertTrue(
1072        gfile.Exists(
1073            os.path.join(
1074                compat.as_bytes(export_dir),
1075                compat.as_bytes('variables/variables.index'))))
1076    self.assertTrue(
1077        gfile.Exists(
1078            os.path.join(
1079                compat.as_bytes(export_dir),
1080                compat.as_bytes('variables/variables.data-00000-of-00001'))))
1081
1082    # Restore, to validate that the export was well-formed.
1083    with ops.Graph().as_default() as graph:
1084      with session_lib.Session(graph=graph) as sess:
1085        loader.load(sess, [tag_constants.SERVING], export_dir)
1086        graph_ops = [x.name for x in graph.get_operations()]
1087        self.assertTrue('input_example_tensor' in graph_ops)
1088        self.assertTrue('ParseExample/ParseExample' in graph_ops)
1089        self.assertTrue('LookupTableModel' in graph_ops)
1090        self.assertFalse('LookupTableTrainingState' in graph_ops)
1091
1092    # cleanup
1093    gfile.DeleteRecursively(tmpdir)
1094
1095  def test_export_savedmodel_with_graph_transforms(self):
1096    tmpdir = tempfile.mkdtemp()
1097    est, serving_input_fn = _build_estimator_for_export_tests(tmpdir)
1098
1099    extra_file_name = os.path.join(
1100        compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file'))
1101    extra_file = gfile.GFile(extra_file_name, mode='w')
1102    extra_file.write(EXTRA_FILE_CONTENT)
1103    extra_file.close()
1104    assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
1105
1106    export_dir_base = os.path.join(
1107        compat.as_bytes(tmpdir), compat.as_bytes('export'))
1108    export_dir = est.export_savedmodel(
1109        export_dir_base,
1110        serving_input_fn,
1111        assets_extra=assets_extra,
1112        graph_rewrite_specs=[
1113            estimator.GraphRewriteSpec(['tag_1'], []),
1114            estimator.GraphRewriteSpec(['tag_2', 'tag_3'],
1115                                       ['strip_unused_nodes'])
1116        ])
1117
1118    self.assertTrue(gfile.Exists(export_dir_base))
1119    self.assertTrue(gfile.Exists(export_dir))
1120    self.assertTrue(
1121        gfile.Exists(
1122            os.path.join(
1123                compat.as_bytes(export_dir),
1124                compat.as_bytes('saved_model.pb'))))
1125    self.assertTrue(
1126        gfile.Exists(
1127            os.path.join(
1128                compat.as_bytes(export_dir), compat.as_bytes('variables'))))
1129    self.assertTrue(
1130        gfile.Exists(
1131            os.path.join(
1132                compat.as_bytes(export_dir),
1133                compat.as_bytes('variables/variables.index'))))
1134    self.assertTrue(
1135        gfile.Exists(
1136            os.path.join(
1137                compat.as_bytes(export_dir),
1138                compat.as_bytes('variables/variables.data-00000-of-00001'))))
1139
1140    self.assertTrue(
1141        gfile.Exists(
1142            os.path.join(
1143                compat.as_bytes(export_dir), compat.as_bytes('assets'))))
1144    self.assertTrue(
1145        gfile.Exists(
1146            os.path.join(
1147                compat.as_bytes(export_dir),
1148                compat.as_bytes('assets/my_vocab_file'))))
1149    self.assertEqual(
1150        compat.as_bytes(VOCAB_FILE_CONTENT),
1151        compat.as_bytes(
1152            gfile.GFile(
1153                os.path.join(
1154                    compat.as_bytes(export_dir),
1155                    compat.as_bytes('assets/my_vocab_file'))).read()))
1156
1157    expected_extra_path = os.path.join(
1158        compat.as_bytes(export_dir),
1159        compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
1160    self.assertTrue(
1161        gfile.Exists(
1162            os.path.join(
1163                compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))))
1164    self.assertTrue(gfile.Exists(expected_extra_path))
1165    self.assertEqual(
1166        compat.as_bytes(EXTRA_FILE_CONTENT),
1167        compat.as_bytes(gfile.GFile(expected_extra_path).read()))
1168
1169    expected_vocab_file = os.path.join(
1170        compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
1171
1172    # Restore, to validate that the export was well-formed.
1173    # tag_1 is untransformed.
1174    tags = ['tag_1']
1175    with ops.Graph().as_default() as graph:
1176      with session_lib.Session(graph=graph) as sess:
1177        loader.load(sess, tags, export_dir)
1178        assets = [
1179            x.eval()
1180            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
1181        ]
1182        self.assertItemsEqual([expected_vocab_file], assets)
1183        graph_ops = [x.name for x in graph.get_operations()]
1184        self.assertIn('input_example_tensor', graph_ops)
1185        self.assertIn('ParseExample/ParseExample', graph_ops)
1186        self.assertIn('linear/linear/feature/matmul', graph_ops)
1187        # Since there were no transforms, both save ops are still present.
1188        self.assertIn('save/SaveV2/tensor_names', graph_ops)
1189        self.assertIn('save_1/SaveV2/tensor_names', graph_ops)
1190        # Since there were no transforms, the hash table lookup is still there.
1191        self.assertIn('hash_table_Lookup/LookupTableFindV2', graph_ops)
1192
1193    # Restore, to validate that the export was well-formed.
1194    # tag_2, tag_3 was subjected to strip_unused_nodes.
1195    tags = ['tag_2', 'tag_3']
1196    with ops.Graph().as_default() as graph:
1197      with session_lib.Session(graph=graph) as sess:
1198        loader.load(sess, tags, export_dir)
1199        assets = [
1200            x.eval()
1201            for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
1202        ]
1203        self.assertItemsEqual([expected_vocab_file], assets)
1204        graph_ops = [x.name for x in graph.get_operations()]
1205        self.assertTrue('input_example_tensor' in graph_ops)
1206        self.assertTrue('ParseExample/ParseExample' in graph_ops)
1207        self.assertTrue('linear/linear/feature/matmul' in graph_ops)
1208        # The Saver used to restore the checkpoint into the export Session
1209        # was not added to the SAVERS collection, so strip_unused_nodes removes
1210        # it.  The one explicitly created in export_savedmodel is tracked in
1211        # the MetaGraphDef saver_def field, so that one is retained.
1212        # TODO(soergel): Make Savers sane again.  I understand this is all a bit
1213        # nuts but for now the test demonstrates what actually happens.
1214        self.assertFalse('save/SaveV2/tensor_names' in graph_ops)
1215        self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops)
1216        # The fake hash table lookup wasn't connected to anything; stripped.
1217        self.assertFalse('hash_table_Lookup' in graph_ops)
1218
1219    # cleanup
1220    gfile.DeleteRecursively(tmpdir)
1221
1222
class InferRealValuedColumnsTest(test.TestCase):
  """Tests for the `infer_real_valued_columns_from_input*` helpers."""

  # Error message fragment expected for bool and string inputs.
  _UNSUPPORTED_DTYPE_REGEX = (
      'on integer or non floating types are not supported')

  def testInvalidArgs(self):
    """Rejects a `None` input and a tensor passed as `x`."""
    infer_columns = estimator.infer_real_valued_columns_from_input

    with self.assertRaisesRegexp(ValueError, 'x or input_fn must be provided'):
      infer_columns(None)

    with self.assertRaisesRegexp(ValueError, 'cannot be tensors'):
      infer_columns(constant_op.constant(1.0))

  def _assert_single_feature_column(self, expected_shape, expected_dtype,
                                    feature_columns):
    """Asserts there is one column named '' with the expected config."""
    self.assertEqual(1, len(feature_columns))
    only_column = feature_columns[0]
    self.assertEqual('', only_column.name)
    expected_config = {
        '':
            parsing_ops.FixedLenFeature(
                shape=expected_shape, dtype=expected_dtype)
    }
    self.assertEqual(expected_config, only_column.config)

  def testInt32Input(self):
    """A [7, 8] int32 array yields one int32 column of shape [8]."""
    features = np.ones(shape=[7, 8], dtype=np.int32)
    columns = estimator.infer_real_valued_columns_from_input(features)
    self._assert_single_feature_column([8], dtypes.int32, columns)

  def testInt32InputFn(self):
    """A [7, 8] int32 tensor yields one int32 column of shape [8]."""

    def _input_fn():
      return array_ops.ones(shape=[7, 8], dtype=dtypes.int32), None

    columns = estimator.infer_real_valued_columns_from_input_fn(_input_fn)
    self._assert_single_feature_column([8], dtypes.int32, columns)

  def testInt64Input(self):
    """A [7, 8] int64 array yields one int64 column of shape [8]."""
    features = np.ones(shape=[7, 8], dtype=np.int64)
    columns = estimator.infer_real_valued_columns_from_input(features)
    self._assert_single_feature_column([8], dtypes.int64, columns)

  def testInt64InputFn(self):
    """A [7, 8] int64 tensor yields one int64 column of shape [8]."""

    def _input_fn():
      return array_ops.ones(shape=[7, 8], dtype=dtypes.int64), None

    columns = estimator.infer_real_valued_columns_from_input_fn(_input_fn)
    self._assert_single_feature_column([8], dtypes.int64, columns)

  def testFloat32Input(self):
    """A [7, 8] float32 array yields one float32 column of shape [8]."""
    features = np.ones(shape=[7, 8], dtype=np.float32)
    columns = estimator.infer_real_valued_columns_from_input(features)
    self._assert_single_feature_column([8], dtypes.float32, columns)

  def testFloat32InputFn(self):
    """A [7, 8] float32 tensor yields one float32 column of shape [8]."""

    def _input_fn():
      return array_ops.ones(shape=[7, 8], dtype=dtypes.float32), None

    columns = estimator.infer_real_valued_columns_from_input_fn(_input_fn)
    self._assert_single_feature_column([8], dtypes.float32, columns)

  def testFloat64Input(self):
    """A [7, 8] float64 array yields one float64 column of shape [8]."""
    features = np.ones(shape=[7, 8], dtype=np.float64)
    columns = estimator.infer_real_valued_columns_from_input(features)
    self._assert_single_feature_column([8], dtypes.float64, columns)

  def testFloat64InputFn(self):
    """A [7, 8] float64 tensor yields one float64 column of shape [8]."""

    def _input_fn():
      return array_ops.ones(shape=[7, 8], dtype=dtypes.float64), None

    columns = estimator.infer_real_valued_columns_from_input_fn(_input_fn)
    self._assert_single_feature_column([8], dtypes.float64, columns)

  def testBoolInput(self):
    """A boolean array is rejected with `ValueError`."""
    all_false = np.zeros([7, 8], dtype=np.bool_)
    with self.assertRaisesRegexp(ValueError, self._UNSUPPORTED_DTYPE_REGEX):
      estimator.infer_real_valued_columns_from_input(all_false)

  def testBoolInputFn(self):
    """A boolean tensor from an input_fn is rejected with `ValueError`."""

    def _input_fn():
      features = constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool)
      return features, None

    with self.assertRaisesRegexp(ValueError, self._UNSUPPORTED_DTYPE_REGEX):
      estimator.infer_real_valued_columns_from_input_fn(_input_fn)

  def testStringInput(self):
    """A string array is rejected with `ValueError`."""
    row = ['%d.0' % i for i in xrange(8)]
    string_features = np.array([row] * 7)
    with self.assertRaisesRegexp(ValueError, self._UNSUPPORTED_DTYPE_REGEX):
      estimator.infer_real_valued_columns_from_input(string_features)

  def testStringInputFn(self):
    """A string tensor from an input_fn is rejected with `ValueError`."""

    def _input_fn():
      row = ['%d.0' % i for i in xrange(8)]
      return constant_op.constant([row] * 7), None

    with self.assertRaisesRegexp(ValueError, self._UNSUPPORTED_DTYPE_REGEX):
      estimator.infer_real_valued_columns_from_input_fn(_input_fn)

  def testBostonInputFn(self):
    """`boston_input_fn` yields one float64 column of the Boston width."""
    columns = estimator.infer_real_valued_columns_from_input_fn(
        boston_input_fn)
    self._assert_single_feature_column([_BOSTON_INPUT_DIM], dtypes.float64,
                                       columns)

  def testIrisInputFn(self):
    """`iris_input_fn` yields one float64 column of the Iris width."""
    columns = estimator.infer_real_valued_columns_from_input_fn(iris_input_fn)
    self._assert_single_feature_column([_IRIS_INPUT_DIM], dtypes.float64,
                                       columns)
1326
1327
class ReplicaDeviceSetterTest(test.TestCase):
  """Tests device placement under `estimator._get_replica_device_setter`."""

  def _run_config_for(self, tf_config):
    """Builds a `RunConfig` while `TF_CONFIG` is patched to `tf_config`."""
    patched_env = {'TF_CONFIG': json.dumps(tf_config)}
    with test.mock.patch.dict('os.environ', patched_env):
      return run_config.RunConfig()

  def testVariablesAreOnPs(self):
    """With a PS in the cluster, variables go to PS and the sum to worker."""
    config = self._run_config_for({
        'cluster': {
            run_config.TaskType.PS: ['fake_ps_0']
        }
    })

    with ops.device(estimator._get_replica_device_setter(config)):
      v = variables_lib.Variable([1, 2])
      w = variables_lib.Variable([2, 1])
      a = v + w

    for ps_placed in (v, v.initializer, w, w.initializer):
      self.assertDeviceEqual('/job:ps/task:0', ps_placed.device)
    self.assertDeviceEqual('/job:worker', a.device)

  def testVariablesAreLocal(self):
    """With a default `RunConfig`, no device is assigned to any op."""
    device_setter = estimator._get_replica_device_setter(
        run_config.RunConfig())
    with ops.device(device_setter):
      v = variables_lib.Variable([1, 2])
      w = variables_lib.Variable([2, 1])
      a = v + w

    for unplaced in (v, v.initializer, w, w.initializer, a):
      self.assertDeviceEqual('', unplaced.device)

  def testMutableHashTableIsOnPs(self):
    """With a PS in the cluster, the table and its lookup are on the PS."""
    config = self._run_config_for({
        'cluster': {
            run_config.TaskType.PS: ['fake_ps_0']
        }
    })

    with ops.device(estimator._get_replica_device_setter(config)):
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
      keys = constant_op.constant(['brain', 'salad', 'tank'])
      looked_up = table.lookup(keys)

    self.assertDeviceEqual('/job:ps/task:0', table.resource_handle.device)
    self.assertDeviceEqual('/job:ps/task:0', looked_up.device)

  def testMutableHashTableIsLocal(self):
    """With a default `RunConfig`, the table and lookup have no device."""
    device_setter = estimator._get_replica_device_setter(
        run_config.RunConfig())
    with ops.device(device_setter):
      default_val = constant_op.constant([-1, -1], dtypes.int64)
      table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val)
      keys = constant_op.constant(['brain', 'salad', 'tank'])
      looked_up = table.lookup(keys)

    self.assertDeviceEqual('', table.resource_handle.device)
    self.assertDeviceEqual('', looked_up.device)

  def testTaskIsSetOnWorkerWhenJobNameIsSet(self):
    """A worker task with index 3 places non-variable ops on worker task 3."""
    config = self._run_config_for({
        'cluster': {
            run_config.TaskType.PS: ['fake_ps_0']
        },
        'task': {
            'type': run_config.TaskType.WORKER,
            'index': 3
        }
    })

    with ops.device(estimator._get_replica_device_setter(config)):
      v = variables_lib.Variable([1, 2])
      w = variables_lib.Variable([2, 1])
      a = v + w

    for ps_placed in (v, v.initializer, w, w.initializer):
      self.assertDeviceEqual('/job:ps/task:0', ps_placed.device)
    self.assertDeviceEqual('/job:worker/task:3', a.device)
1408
1409
if __name__ == '__main__':
  # Hand off to `test.main()` when this file is executed as a script.
  test.main()
1412