#  Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Tests for TaskRunner and Experiment class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import tempfile
import time

from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import dnn
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
from tensorflow.contrib.learn.python.learn.estimators import test_data
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
from tensorflow.python.util import tf_inspect


class SheepCounter(object):
  """To be patched in for the time module, replacing sleep() and time()."""

  def __init__(self):
    self._total_time = 0
    self._sleeptimes = []
    self._time_calls = 0

  def sleep(self, t):
    self._total_time += t
    self._sleeptimes.append(t)

  def time(self):
    self._time_calls += 1
    return self._total_time

  @property
  def sleep_times(self):
    return self._sleeptimes

  @property
  def time_calls(self):
    return self._time_calls


class TestBaseEstimator(object):
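  """Fake estimator recording fit/evaluate/export calls for assertions."""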

  def __init__(self, config, max_evals, eval_dict):
    self.eval_count = 0
    self.fit_count = 0
    self._max_evals = max_evals
    self.export_count = 0
    self.monitors = []
    self.eval_hooks = []
    self._config = config or run_config.RunConfig()
    self._model_dir = tempfile.mkdtemp()
    self._eval_dict = eval_dict

  @property
  def model_dir(self):
    return self._model_dir

  @property
  def config(self):
    return self._config

  def evaluate(self, **kwargs):
    tf_logging.info('evaluate called with args: %s' % kwargs)
    if 'hooks' in kwargs:
      self.eval_hooks = kwargs['hooks']
    self.eval_count += 1
    if self.eval_count > self._max_evals:
      tf_logging.info('Ran %d evals. Done.' % self.eval_count)
      raise StopIteration()
    return self._eval_dict

  def fake_checkpoint(self):
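    """Writes a trivial checkpoint so evaluations have something to load."""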
    save_path = os.path.join(self.model_dir, 'model.ckpt')
    with session.Session() as sess:
      var = variables.Variable(1.0, name='var0')
      save = saver.Saver({var.op.name: var})
      var.initializer.run()
      save.save(sess, save_path, global_step=0)

  def train(self, **kwargs):
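    """Records the call and returns kwargs as sorted (key, value) pairs."""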
    self.fake_checkpoint()
    tf_logging.info('fit called with args: %s' % kwargs)
    self.fit_count += 1

    return [(key, kwargs[key]) for key in sorted(kwargs.keys())]

  def export_savedmodel(self, export_dir_base, serving_input_fn, **kwargs):
    tf_logging.info('export_savedmodel called with args: %s, %s, %s' %
                    (export_dir_base, serving_input_fn, kwargs))
    self.export_count += 1
    return os.path.join(
        compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))


def _check_method_supports_args(method, kwargs):
  """Checks that the given method supports the given args."""
  supported_args = tuple(tf_inspect.getargspec(method).args)
  for kwarg in kwargs:
    if kwarg not in supported_args:
      raise ValueError(
          'Argument `{}` is not supported in method {}.'.format(kwarg, method))


class TestEstimator(
    TestBaseEstimator, evaluable.Evaluable, trainable.Trainable):
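  """Fake contrib-style estimator implementing Evaluable and Trainable."""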

  def __init__(self, config=None, max_evals=5, eval_dict=None):
    super(TestEstimator, self).__init__(config, max_evals, eval_dict)
    tf_logging.info('Create Estimator')

  def evaluate(self, **kwargs):
    _check_method_supports_args(evaluable.Evaluable.evaluate, kwargs)
    return super(TestEstimator, self).evaluate(**kwargs)

  def fit(self, **kwargs):
    _check_method_supports_args(trainable.Trainable.fit, kwargs)
    if 'monitors' in kwargs:
      self.monitors = kwargs['monitors']
    return super(TestEstimator, self).train(**kwargs)

  def train(self, **kwargs):
    raise ValueError('`train` is not defined in Estimator.')

  def export_savedmodel(
      self, export_dir_base, serving_input_fn, **kwargs):
    _check_method_supports_args(
        estimator_lib.Estimator.export_savedmodel, kwargs)
    return super(TestEstimator, self).export_savedmodel(
        export_dir_base, serving_input_fn, **kwargs)


class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
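  """Fake core (tf.python.estimator) Estimator built on TestBaseEstimator."""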

  def __init__(self, config=None, max_evals=5, eval_dict=None):
    super(TestCoreEstimator, self).__init__(config, max_evals, eval_dict)
    tf_logging.info('Create Core Estimator')

  def evaluate(self, **kwargs):
    _check_method_supports_args(core_estimator.Estimator.evaluate, kwargs)
    return super(TestCoreEstimator, self).evaluate(**kwargs)

  def train(self, **kwargs):
    _check_method_supports_args(core_estimator.Estimator.train, kwargs)
    if 'hooks' in kwargs:
      self.monitors = kwargs['hooks']
    return super(TestCoreEstimator, self).train(**kwargs)

  def export_savedmodel(
      self, export_dir_base, serving_input_receiver_fn, **kwargs):
    _check_method_supports_args(
        core_estimator.Estimator.export_savedmodel, kwargs)
    return super(TestCoreEstimator, self).export_savedmodel(
        export_dir_base, serving_input_receiver_fn, **kwargs)


class _NoopHook(session_run_hook.SessionRunHook):
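  """A no-op SessionRunHook used to verify hook plumbing."""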
  pass


class ExperimentTest(test.TestCase):
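  """Tests Experiment behavior against contrib and core estimators."""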

  def _cluster_spec(self):
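    """Returns a cluster spec with two PS and three worker tasks."""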
    return {
        run_config_lib.TaskType.PS: ['host1:2222', 'host2:2222'],
        run_config_lib.TaskType.WORKER:
            ['host3:2222', 'host4:2222', 'host5:2222']
    }

  def _estimators_for_tests(self, config=None, eval_dict=None):
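    """Returns a contrib-style and a core estimator to cover both paths."""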
    return [TestEstimator(config=config, eval_dict=eval_dict),
            TestCoreEstimator(config=config, eval_dict=eval_dict)]

  def test_eval_metrics_for_core_estimator(self):
    est = TestCoreEstimator()
    with self.assertRaisesRegexp(
        ValueError, '`eval_metrics` must be `None`'):
      experiment.Experiment(
          est,
          train_input_fn='train_input',
          train_steps='train_steps',
          eval_input_fn='eval_input',
          eval_metrics='eval_metrics')

  def test_default_output_alternative_key_core_estimator(self):
    est = TestCoreEstimator()
    export_strategy = saved_model_export_utils.make_export_strategy(
        est,
        default_output_alternative_key='export_key',
        exports_to_keep=None)
    ex = experiment.Experiment(
        est,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps=100,
        eval_steps=100,
        export_strategies=export_strategy)
    with self.assertRaisesRegexp(
        ValueError, 'default_output_alternative_key is not supported'):
      ex.train_and_evaluate()

  def test_train(self):
    for est in self._estimators_for_tests():
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          train_steps='train_steps',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          saving_listeners=saving_listeners)
      fit_args = ex.train(delay_secs=0)
      self.assertEqual(1, est.fit_count)
      self.assertIn(('max_steps', 'train_steps'), fit_args)
      self.assertEqual(0, est.eval_count)

  def test_train_delay(self):
    for est in self._estimators_for_tests():
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')
      for delay in [0, 1, 3]:
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.train(delay_secs=delay)
            self.assertAlmostEqual(delay, sheep.time(), delta=1e-4)

  def test_train_default_delay(self):
    for task_id in [0, 1, 3]:
      tf_config = {'task': {'index': task_id}}
      with test.mock.patch.dict('os.environ',
                                {'TF_CONFIG': json.dumps(tf_config)}):
        config = run_config.RunConfig()
      for est in self._estimators_for_tests(config):
        ex = experiment.Experiment(
            est, train_input_fn='train_input', eval_input_fn='eval_input')

        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.train()
            self.assertAlmostEqual(task_id * 5, sheep.time(), delta=1e-4)

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_starts_server(self, mock_server):
    # Arrange.
    tf_config = {
        'cluster': self._cluster_spec(),
        'environment': run_config_lib.Environment.CLOUD,
        'task': {
            'type': run_config_lib.TaskType.WORKER,
            'index': 1
        }
    }
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(
          master='host4:2222', num_cores=15, gpu_memory_fraction=0.314)

    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')

      # Act.
      # We want to make sure we discount the time it takes to start the server
      # in our accounting of the delay, so we set a small delay here.
      sheep = SheepCounter()
      with test.mock.patch.object(time, 'time', sheep.time):
        with test.mock.patch.object(time, 'sleep', sheep.sleep):
          ex.train(delay_secs=1)
          # Ensure that the delay takes into account the time to start the
          # server.
          self.assertAlmostEqual(1, sheep.time(), delta=1e-4)

      # Assert.
      expected_config_proto = config_pb2.ConfigProto()
      expected_config_proto.inter_op_parallelism_threads = 15
      expected_config_proto.intra_op_parallelism_threads = 15
      expected_config_proto.gpu_options.per_process_gpu_memory_fraction = 0.314
      mock_server.assert_called_with(
          config.cluster_spec,
          job_name=run_config_lib.TaskType.WORKER,
          task_index=1,
          config=expected_config_proto,
          start=False)
      mock_server.assert_has_calls([test.mock.call().start()])

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_server_does_not_start_without_cluster_spec(self, mock_server):
    config = run_config_lib.RunConfig(master='host4:2222')
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      ex.train()

      # The server should not have started because there was no ClusterSpec.
      self.assertFalse(mock_server.called)

  @test.mock.patch.object(server_lib, 'Server')
  def test_train_server_does_not_start_with_empty_master(self, mock_server):
    tf_config = {'cluster': self._cluster_spec()}
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(master='')
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      ex.train()
      # The server should not have started because master was the empty string.
      self.assertFalse(mock_server.called)

  def test_train_raises_if_job_name_is_missing(self):
    tf_config = {
        'cluster': self._cluster_spec(),
        'environment': run_config_lib.Environment.CLOUD,
        'task': {
            'index': 1
        }
    }
    with test.mock.patch.dict(
        'os.environ',
        {'TF_CONFIG': json.dumps(tf_config)}), self.assertRaises(ValueError):
      config = run_config_lib.RunConfig(
          master='host3:2222'  # Normally selected by task type.
      )
      for est in self._estimators_for_tests(config):
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input')
        ex.train()

  def test_evaluate(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_steps='steps',
          eval_delay_secs=0)
      ex.evaluate()
      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_evaluate_delay(self):
    for est in self._estimators_for_tests():
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input',
          eval_hooks=[noop_hook])

      for delay in [0, 1, 3]:
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            ex.evaluate(delay_secs=delay)
        self.assertAlmostEqual(delay, sheep.time(), delta=1e-4)
        self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      self.assertRaises(StopIteration, ex.continuous_eval,
                        evaluate_checkpoint_only_once=False)
      self.assertEqual(0, est.fit_count)
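      # The default max_evals is 5, so the sixth evaluate() call raises
      # StopIteration, leaving eval_count at 6.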
      self.assertEqual(6, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_ends_after_train_step(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0,
          train_steps=100)
      ex.continuous_eval()
      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_throttle_delay(self):
    for delay in [0, 1, 2]:
      for est in self._estimators_for_tests():
        eval_metrics = 'eval_metrics' if not isinstance(
            est, core_estimator.Estimator) else None
        est.fake_checkpoint()
        noop_hook = _NoopHook()
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            eval_metrics=eval_metrics,
            eval_hooks=[noop_hook],
            continuous_eval_throttle_secs=delay,
            eval_delay_secs=0)
        sheep = SheepCounter()
        with test.mock.patch.object(time, 'time', sheep.time):
          with test.mock.patch.object(time, 'sleep', sheep.sleep):
            self.assertRaises(
                StopIteration,
                ex.continuous_eval,
                evaluate_checkpoint_only_once=False)
            self.assertAlmostEqual(5 * delay, sheep.time(), delta=1e-4)

  def test_continuous_eval_predicate_fn(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()

      def _predicate_fn(unused_eval_result):
        return est.eval_count < 3  # pylint: disable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(evaluate_checkpoint_only_once=False,
                         continuous_eval_predicate_fn=_predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(3, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_eval_predicate_fn_with_checkpoint(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()
      noop_hook = _NoopHook()

      def _predicate_fn(eval_result, checkpoint_path):
        self.assertEqual(eval_result is None,
                         checkpoint_path is None)
        return est.eval_count < 3  # pylint: disable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(
          evaluate_checkpoint_only_once=False,
          continuous_eval_predicate_fn=_predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(3, est.eval_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_run_local(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          local_eval_frequency=10)
      ex.local_run()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, len(est.monitors))
      self.assertEqual([noop_hook], est.eval_hooks)
      self.assertTrue(isinstance(est.monitors[0],
                                 session_run_hook.SessionRunHook))

  def test_train_hooks_extend_does_not_mutate_input_hooks(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      input_hooks = [noop_hook]

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          train_monitors=input_hooks)
      self.assertAllEqual([noop_hook], ex._train_monitors)

      another_noop_hook = _NoopHook()
      # Assert that the extend API mutates the hooks, but not the input hooks
      ex.extend_train_hooks([another_noop_hook])
      self.assertAllEqual([noop_hook, another_noop_hook], ex._train_monitors)
      self.assertAllEqual([noop_hook], input_hooks)

  def test_invalid_export_strategies(self):
    for est in self._estimators_for_tests():
      with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps=100,
            eval_steps=100,
            export_strategies='not_an_export_strategy')
      with self.assertRaisesRegexp(ValueError, 'ExportStrategy'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps=100,
            eval_steps=100,
            export_strategies=['not_an_export_strategy'])

  def test_export_strategies_reset(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      export_strategy_1 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_1',
          exports_to_keep=None)

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          train_steps=100,
          eval_steps=100,
          export_strategies=(export_strategy_1,))
      ex.train_and_evaluate()
      self.assertEqual(1, est.export_count)

      # After a reset with no argument (None), the export count stays the same,
      # and the previously configured export strategies are returned intact.
      old_es = ex.reset_export_strategies()
      ex.train_and_evaluate()
      self.assertAllEqual([export_strategy_1], old_es)
      self.assertEqual(1, est.export_count)

      # After reset with list, the count should increase with the number of
      # items.
      export_strategy_2 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_2',
          exports_to_keep=None)
      export_strategy_3 = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_3',
          exports_to_keep=None)

      old_es = ex.reset_export_strategies(
          [export_strategy_2, export_strategy_3])
      ex.train_and_evaluate()
      self.assertAllEqual([], old_es)
      self.assertEqual(3, est.export_count)

  def test_train_and_evaluate(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          export_strategies=export_strategy)
      ex.train_and_evaluate()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)
      self.assertEqual(1, len(est.monitors))
      self.assertEqual([noop_hook], est.eval_hooks)
      self.assertTrue(isinstance(est.monitors[0],
                                 session_run_hook.SessionRunHook))

  def test_train_and_evaluate_with_no_eval_during_training(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      noop_hook = _NoopHook()
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          min_eval_frequency=0)
      ex.train_and_evaluate()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(0, len(est.monitors))

  def test_min_eval_frequency_defaults(self):
    def dummy_model_fn(features, labels):  # pylint: disable=unused-argument
      pass
    estimator = core_estimator.Estimator(dummy_model_fn, '/tmp/dummy')
    ex = experiment.Experiment(
        estimator, train_input_fn=None, eval_input_fn=None)
    self.assertEqual(1, ex._min_eval_frequency)

  def test_continuous_train_and_eval(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      noop_hook = _NoopHook()
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_hooks=[noop_hook],
          train_steps=100,
          eval_steps=100,
          export_strategies=export_strategy,
          saving_listeners=saving_listeners)
      ex.continuous_train_and_eval()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)
      self.assertEqual([noop_hook], est.eval_hooks)

  def test_continuous_train_and_eval_with_predicate_fn(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      export_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          # A train_steps value large enough that `ex` never stops on its own.
          train_steps=100000000000,
          eval_steps=100,
          export_strategies=export_strategy)

      def predicate_fn(eval_result):
        del eval_result  # Unused; present only to match the fn signature.
        return False

      ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
      self.assertEqual(0, est.fit_count)
      self.assertEqual(0, est.eval_count)
      self.assertEqual(0, est.export_count)

  def test_continuous_train_and_eval_with_adapted_steps_per_iteration(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    total_steps = 100000000000000
    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps=total_steps)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
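    # With no train_steps_per_iteration given, Experiment adapts the
    # per-iteration step count to one tenth of the total train_steps.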
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=int(total_steps / 10),
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    total_steps = 100000000000000
    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps_per_iteration=1234,
        train_steps=total_steps)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=1234,
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_default_steps_per_iteration(self):
    mock_estimator = test.mock.Mock(core_estimator.Estimator)
    type(mock_estimator).model_dir = test.mock.PropertyMock(
        return_value='test_dir')

    ex = experiment.Experiment(
        mock_estimator,
        train_input_fn='train_input',
        eval_input_fn='eval_input',
        train_steps_per_iteration=None,
        train_steps=None)

    def predicate_fn(eval_result):
      # Allow only the first invocation.
      return eval_result is None

    ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn)
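    # With neither train_steps_per_iteration nor train_steps given, each
    # training iteration defaults to 1000 steps.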
    mock_estimator.train.assert_called_once_with(
        input_fn='train_input',
        steps=1000,
        max_steps=test.mock.ANY,
        hooks=test.mock.ANY,
        saving_listeners=test.mock.ANY)

  def test_continuous_train_and_eval_with_invalid_predicate_fn(self):
    for est in self._estimators_for_tests():
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input')
      with self.assertRaisesRegexp(
          ValueError, '`continuous_eval_predicate_fn` must be a callable'):
        ex.continuous_train_and_eval(continuous_eval_predicate_fn='fn')

  def test_continuous_train_and_eval_with_invalid_train_steps_iterations(self):
    for est in self._estimators_for_tests():
      with self.assertRaisesRegexp(
          ValueError, '`train_steps_per_iteration` must be an integer.'):
        experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input',
            train_steps_per_iteration='123')

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_std_server(self, mock_server):
    # Arrange.
    tf_config = {
        'cluster': self._cluster_spec(),
        'task': {
            'type': run_config_lib.TaskType.PS,
            'index': 1
        }
    }
    with test.mock.patch.dict('os.environ',
                              {'TF_CONFIG': json.dumps(tf_config)}):
      config = run_config_lib.RunConfig(
          master='host2:2222',
          num_cores=15,
          gpu_memory_fraction=0.314,)
    for est in self._estimators_for_tests(config):
      ex = experiment.Experiment(
          est, train_input_fn='train_input', eval_input_fn='eval_input')

      # Act.
      ex.run_std_server()

      # Assert.
      mock_server.assert_has_calls(
          [test.mock.call().start(), test.mock.call().join()])

  @test.mock.patch.object(server_lib, 'Server')
  def test_run_std_server_raises_without_cluster_spec(self, mock_server):
    config = run_config_lib.RunConfig(master='host4:2222')
    for est in self._estimators_for_tests(config):
      with self.assertRaises(ValueError):
        ex = experiment.Experiment(
            est,
            train_input_fn='train_input',
            eval_input_fn='eval_input')
        ex.run_std_server()

  def test_test(self):
    for est in self._estimators_for_tests():
      exp_strategy = saved_model_export_utils.make_export_strategy(
          est,
          None if isinstance(est, core_estimator.Estimator) else 'export_input',
          exports_to_keep=None)
      if isinstance(est, core_estimator.Estimator):
        eval_metrics = None
        saving_listeners = 'saving_listeners'
      else:
        eval_metrics = 'eval_metrics'
        saving_listeners = None
      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          export_strategies=(exp_strategy,),
          eval_metrics=eval_metrics,
          saving_listeners=saving_listeners)
      ex.test()
      self.assertEqual(1, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(1, est.export_count)

  def test_continuous_eval_evaluates_checkpoint_once(self):
    for est in self._estimators_for_tests(eval_dict={'global_step': 100}):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      est.fake_checkpoint()

      result = {
          'called': 0,
          'called_with_eval_result': 0,
      }
      # pylint: disable=cell-var-from-loop
      def _predicate_fn(eval_result):
        result['called'] += 1
        if eval_result:
          # If eval_result is neither empty nor None, the checkpoint has
          # been evaluated.
          result['called_with_eval_result'] += 1
        # Calling the predicate 300 times should be enough to show that the
        # checkpoint is evaluated only once.
        return result['called'] < 300
      # pylint: enable=cell-var-from-loop

      ex = experiment.Experiment(
          est,
          train_input_fn='train_input',
          eval_input_fn='eval_input',
          eval_metrics=eval_metrics,
          eval_delay_secs=0,
          continuous_eval_throttle_secs=0)
      ex.continuous_eval(evaluate_checkpoint_only_once=True,
                         continuous_eval_predicate_fn=_predicate_fn)

      self.assertEqual(0, est.fit_count)
      self.assertEqual(1, est.eval_count)
      self.assertEqual(300, result['called'])
      self.assertEqual(1, result['called_with_eval_result'])

  def test_checkpoint_and_export(self):
    model_dir = tempfile.mkdtemp()
    config = run_config_lib.RunConfig(save_checkpoints_steps=3)
    est = dnn.DNNClassifier(
        n_classes=3,
        feature_columns=[
            feature_column.real_valued_column('feature', dimension=4)
        ],
        hidden_units=[3, 3],
        model_dir=model_dir,
        config=config)

    exp_strategy = saved_model_export_utils.make_export_strategy(
        est, 'export_input', exports_to_keep=None)

    ex = experiment.Experiment(
        est,
        train_input_fn=test_data.iris_input_multiclass_fn,
        eval_input_fn=test_data.iris_input_multiclass_fn,
        export_strategies=(exp_strategy,),
        train_steps=8,
        checkpoint_and_export=True,
        eval_delay_secs=0)

    with test.mock.patch.object(ex, '_maybe_export'):
      with test.mock.patch.object(ex, '_call_evaluate'):
        ex.train_and_evaluate()
        # Eval and export are called after steps 1, 4, 7, and 8 (after training
        # is completed).
        self.assertEqual(ex._maybe_export.call_count, 4)
        self.assertEqual(ex._call_evaluate.call_count, 4)


if __name__ == '__main__':
  test.main()