• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experiment class collecting information for a single training run (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import contextlib
27import functools
28import math
29import os
30import time
31
32from tensorflow.contrib.framework import deprecated
33from tensorflow.contrib.framework.python.framework import experimental
34from tensorflow.contrib.learn.python.learn import evaluable
35from tensorflow.contrib.learn.python.learn import export_strategy
36from tensorflow.contrib.learn.python.learn import monitors
37from tensorflow.contrib.learn.python.learn import trainable
38from tensorflow.contrib.learn.python.learn.estimators import run_config
39from tensorflow.contrib.tpu.python.tpu import tpu_estimator
40from tensorflow.python.estimator import estimator as core_estimator
41from tensorflow.python.framework import ops
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.training import basic_session_run_hooks
44from tensorflow.python.training import checkpoint_management
45from tensorflow.python.training import server_lib
46from tensorflow.python.util import compat
47from tensorflow.python.util import function_utils
48
49__all__ = ["Experiment"]
50
51
def _get_standardized_predicate_fn(predicate_fn):
  """Returns a predicate fn that always accepts `(eval_results, checkpoint_path)`.

  If `predicate_fn` already declares a `checkpoint_path` argument it is
  returned unchanged; otherwise it is wrapped so the extra argument is
  accepted and ignored.
  """
  accepts_checkpoint_path = (
      "checkpoint_path" in function_utils.fn_args(predicate_fn))
  if accepts_checkpoint_path:
    return predicate_fn

  # pylint: disable=unused-argument
  def _pred_fn_wrapper(eval_results, checkpoint_path):
    return predicate_fn(eval_results)

  return _pred_fn_wrapper
62
63
class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
  """Listener that evaluates and exports a model after creating a checkpoint.

  The `EvalAndExportListener` waits for the associated `CheckpointSaverHook`
  to save a checkpoint. It then uses the provided `eval_fn` and `export_fn` to
  first evaluate the model using the newly-created checkpoint, and then export
  the model according to the `export_strategies` provided in the `Experiment`.

  This listener is experimental and may be changed or removed in the future.
  """

  def __init__(self, eval_fn, export_fn, model_dir):
    """Initializes an `EvalAndExportListener`.

    Args:
      eval_fn: function which evaluates the model with the following signature:
        `(name, checkpoint_path) -> eval_result`
      export_fn: function which exports the model according to a set of export
        strategies. Has the following signature:
        `(eval_result, checkpoint_path) -> export_results`
      model_dir: directory which contains estimator parameters and checkpoints.
    """
    self._eval_fn = eval_fn
    self._export_fn = export_fn
    self._model_dir = model_dir
    # Path of the last checkpoint evaluated; used to skip duplicate work.
    self._latest_path = None
    self._eval_result = None
    self._export_results = None

  def after_save(self, session, global_step_value):
    """Evaluates and exports the model after a checkpoint is created."""
    # Load and cache the path of the most recent checkpoint to avoid duplicate
    # searches on GCS.
    logging.info("Checking for checkpoint in %s", self._model_dir)
    latest_path = checkpoint_management.latest_checkpoint(self._model_dir)

    # Guard clauses: nothing to do without a new checkpoint.
    if not latest_path:
      logging.warning("Skipping evaluation and export since model has not been "
                      "saved yet.")
      return
    if latest_path == self._latest_path:
      logging.warning("Skipping evaluation due to same latest checkpoint %s.",
                      latest_path)
      return

    # A genuinely new checkpoint: evaluate it, then export from it.
    self._latest_path = latest_path
    self._eval_result = self._eval_fn(
        name="intermediate_export", checkpoint_path=latest_path)
    self._export_results = self._export_fn(
        self._eval_result, checkpoint_path=latest_path)

  @property
  def eval_result(self):
    """Most recent evaluation result, or `None` before the first evaluation."""
    return self._eval_result

  @property
  def export_results(self):
    """Most recent export results, or `None` before the first export."""
    return self._export_results
120
121
122class Experiment(object):
123  """Experiment is a class containing all information needed to train a model.
124
125  THIS CLASS IS DEPRECATED. See
126  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
127  for general migration instructions.
128
129  After an experiment is created (by passing an Estimator and inputs for
130  training and evaluation), an Experiment instance knows how to invoke training
131  and eval loops in a sensible fashion for distributed training.
132  """
133
134  # TODO(ispir): remove delay_workers_by_global_step and make global step based
135  # waiting as only behavior.
136  @deprecated(None, "Please switch to tf.estimator.train_and_evaluate. You will"
137              " also have to convert to a tf.estimator.Estimator.")
138  def __init__(self,
139               estimator,
140               train_input_fn,
141               eval_input_fn,
142               eval_metrics=None,
143               train_steps=None,
144               eval_steps=100,
145               train_monitors=None,
146               eval_hooks=None,
147               local_eval_frequency=None,
148               eval_delay_secs=120,
149               continuous_eval_throttle_secs=60,
150               min_eval_frequency=None,
151               delay_workers_by_global_step=False,
152               export_strategies=None,
153               train_steps_per_iteration=None,
154               checkpoint_and_export=False,
155               saving_listeners=None,
156               check_interval_secs=5):
157    """Constructor for `Experiment`.
158
159    Creates an Experiment instance. None of the functions passed to this
160    constructor are executed at construction time. They are stored and used
161    when a method is executed which requires it.
162
163    Args:
164      estimator: Object implementing Estimator interface, which could be a
165        combination of `tf.contrib.learn.Trainable` and
166        `tf.contrib.learn.Evaluable` (deprecated), or
167        `tf.estimator.Estimator`.
168      train_input_fn: function, returns features and labels for training.
169      eval_input_fn: function, returns features and labels for evaluation. If
170        `eval_steps` is `None`, this should be configured only to produce for a
171        finite number of batches (generally, 1 epoch over the evaluation data).
172      eval_metrics: `dict` of string, metric function. If `None`, default set
173        is used. This should be `None` if the `estimator` is
174        `tf.estimator.Estimator`. If metrics are provided they will be
175        *appended* to the default set.
176      train_steps: Perform this many steps of training. `None`, the default,
177        means train forever.
178      eval_steps: `evaluate` runs until input is exhausted (or another exception
179        is raised), or for `eval_steps` steps, if specified.
180      train_monitors: A list of monitors to pass to the `Estimator`'s `fit`
181        function.
182      eval_hooks: A list of `SessionRunHook` hooks to pass to the
183        `Estimator`'s `evaluate` function.
184      local_eval_frequency: (applies only to local_run) Frequency of running
185        eval in steps. If `None`, runs evaluation only at the end of training.
186      eval_delay_secs: Start evaluating after waiting for this many seconds.
187      continuous_eval_throttle_secs: Do not re-evaluate unless the last
188        evaluation was started at least this many seconds ago for
189        continuous_eval().
190      min_eval_frequency: (applies only to train_and_evaluate). the minimum
191        number of steps between evaluations. Of course, evaluation does not
192        occur if no new snapshot is available, hence, this is the minimum.
193        If 0, the evaluation will only happen after training.
194        If None, defaults to 1. To avoid checking for new checkpoints too
195        frequent, the interval is further limited to be at least
196        check_interval_secs between checks.
197      delay_workers_by_global_step: if `True` delays training workers
198        based on global step instead of time.
199      export_strategies: Iterable of `ExportStrategy`s, or a single one, or
200        `None`.
201      train_steps_per_iteration: (applies only to continuous_train_and_eval).
202        Perform this many (integer) number of train steps for each
203        training-evaluation iteration. With a small value, the model will be
204        evaluated more frequently with more checkpoints saved. If `None`, will
205        use a default value (which is smaller than `train_steps` if provided).
206      checkpoint_and_export: (applies only to train_and_evaluate). If `True`,
207        performs intermediate model checkpoints and exports during the training
208        process, rather than only once model training is complete. This
209        parameter is experimental and may be changed or removed in the future.
210        Setting this parameter leads to the following: the value of
211        `min_eval_frequency` will be ignored, and the number of steps between
212        evaluations and exports will instead be determined by the Estimator
213        configuration parameters `save_checkpoints_secs` and
214        `save_checkpoints_steps`. Also, this parameter leads to the creation of
215        a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the
216        provided `train_monitors` will need to be adjusted accordingly.
217      saving_listeners: list of `CheckpointSaverListener` objects. Used by
218        tf.estimator.Estimator for callbacks that run immediately before or
219        after checkpoint savings.
220      check_interval_secs:
221        Minimum time between subsequent checks for a new checkpoint. This
222        mostly applies if both min_eval_frequency and the time spent per
223        training step is low.
224    Raises:
225      ValueError: if `estimator` does not implement Estimator interface,
226        or if export_strategies has the wrong type.
227    """
228    if isinstance(estimator, core_estimator.Estimator):
229      self._core_estimator_used = True
230      if eval_metrics is not None:
231        raise ValueError(
232            "`eval_metrics` must be `None` with `tf.estimator.Estimator`. "
233            "Use `eval_metric_ops` in `tf.estimator.EstimatorSpec` instead.")
234    else:
235      self._core_estimator_used = False
236      if not isinstance(estimator, evaluable.Evaluable):
237        raise ValueError(
238            "`estimator` must implement `tf.contrib.learn.Evaluable` "
239            "or `tf.estimator.Estimator`.")
240      if not isinstance(estimator, trainable.Trainable):
241        raise ValueError(
242            "`estimator` must implement `tf.contrib.learn.Trainable`"
243            "or `tf.estimator.`Estimator`.")
244      if saving_listeners is not None:
245        raise ValueError("`saving_listeners` must be `None` with "
246                         "`tf.contrib.learn.Estimator`.")
247
248    if isinstance(estimator, tpu_estimator.TPUEstimator):
249      logging.warn(
250          "`Experiment` class cannot work with `tf.contrib.tpu.TPUEstimator`. "
251          "Please call `TPUEstimator` train/evaluate directly. \n"
252          "Details: `Experiment` class is designed for between-graph "
253          "distributed training, while `TPUEstimator` is working in in-graph "
254          "distributed mode. Use with care.")
255
256    super(Experiment, self).__init__()
257    # Immutable fields.
258    self._estimator = estimator
259    self._train_input_fn = train_input_fn
260    self._eval_input_fn = eval_input_fn
261    self._eval_metrics = eval_metrics
262    self._train_steps = train_steps
263    self._eval_steps = eval_steps
264    self._local_eval_frequency = local_eval_frequency
265    self._eval_delay_secs = eval_delay_secs
266    self._continuous_eval_throttle_secs = continuous_eval_throttle_secs
267    self._checkpoint_and_export = checkpoint_and_export
268    self._saving_listeners = saving_listeners
269    self._min_eval_frequency = min_eval_frequency if (
270        min_eval_frequency is not None) else 1
271    self._check_interval_secs = check_interval_secs
272    self._delay_workers_by_global_step = delay_workers_by_global_step
273    self._train_monitors = train_monitors[:] if train_monitors else []
274    self._eval_hooks = eval_hooks[:] if eval_hooks else []
275    self._set_export_strategies(export_strategies)
276
277    self._train_steps_per_iteration = train_steps_per_iteration
278    if (self._train_steps_per_iteration is not None and
279        not isinstance(self._train_steps_per_iteration, int)):
280      raise ValueError("`train_steps_per_iteration` must be an integer.")
281
  @property
  def estimator(self):
    """Returns the estimator passed to the constructor."""
    return self._estimator
285
  @property
  def eval_metrics(self):
    """Returns the `eval_metrics` dict passed to the constructor (may be None)."""
    return self._eval_metrics
289
  @property
  def train_steps(self):
    """Returns the number of training steps, or None for train-forever."""
    return self._train_steps
293
  @property
  def eval_steps(self):
    """Returns the number of steps per evaluation, or None for run-to-exhaustion."""
    return self._eval_steps
297
298  def _set_export_strategies(self, values):  # pylint: disable=missing-docstring
299    export_strategies = []
300    if values:
301      if isinstance(values, export_strategy.ExportStrategy):
302        export_strategies.append(values)
303      else:
304        for value in values:
305          if not isinstance(value, export_strategy.ExportStrategy):
306            raise ValueError("`export_strategies` must be an ExportStrategy,"
307                             " an iterable of ExportStrategy, or `None`,"
308                             " found %s." % value)
309          export_strategies.append(value)
310    self._export_strategies = tuple(export_strategies)
311
312  def extend_train_hooks(self, additional_hooks):
313    """Extends the hooks for training."""
314    self._train_monitors.extend(additional_hooks)
315
316  def reset_export_strategies(self, new_export_strategies=None):
317    """Resets the export strategies with the `new_export_strategies`.
318
319    Args:
320      new_export_strategies: A new list of `ExportStrategy`s, or a single one,
321        or None.
322
323    Returns:
324      The old export strategies.
325    """
326    old_export_strategies = self._export_strategies
327    self._set_export_strategies(new_export_strategies)
328    return old_export_strategies
329
330  def train(self, delay_secs=None):
331    """Fit the estimator using the training data.
332
333    Train the estimator for `self._train_steps` steps, after waiting for
334    `delay_secs` seconds. If `self._train_steps` is `None`, train forever.
335
336    Args:
337      delay_secs: Start training after this many seconds.
338
339    Returns:
340      The trained estimator.
341    """
342    start = time.time()
343
344    # Start the server, if needed. It's important to start the server before
345    # we (optionally) sleep for the case where no device_filters are set.
346    # Otherwise, the servers will wait to connect to each other before starting
347    # to train. We might as well start as soon as we can.
348    config = self._estimator.config
349    if isinstance(config, run_config.RunConfig):
350      if (config.cluster_spec and config.master and
351          config.environment == run_config.Environment.LOCAL):
352        logging.warn("ClusterSpec and master are provided, but environment is "
353                     "set to 'local'. Set environment to 'cloud' if you intend "
354                     "to use the distributed runtime.")
355      if (config.environment != run_config.Environment.LOCAL and
356          config.environment != run_config.Environment.GOOGLE and
357          config.cluster_spec and config.master):
358        self._start_server()
359    elif config.cluster_spec and config.master:
360      raise ValueError(
361          "For distributed runtime, Experiment class only works with "
362          "tf.contrib.learn.RunConfig for now, but provided {}".format(
363              type(config)))
364
365    extra_hooks = []
366    if delay_secs is None:
367      task_id = self._estimator.config.task_id or 0
368      if self._delay_workers_by_global_step:
369        # Wait 5500 global steps for the second worker. Each worker waits more
370        # then previous one but with a diminishing number of steps.
371        extra_hooks.append(
372            basic_session_run_hooks.GlobalStepWaiterHook(
373                int(8000.0 * math.log(task_id + 1))))
374        delay_secs = 0
375      else:
376        # Wait 5 secs more for each new worker up to 60 secs.
377        delay_secs = min(60, task_id * 5)
378
379    if delay_secs > 0:
380      elapsed_secs = time.time() - start
381      remaining = delay_secs - elapsed_secs
382      logging.info("Waiting %d secs before starting training.", remaining)
383      time.sleep(delay_secs)
384
385    return self._call_train(
386        input_fn=self._train_input_fn,
387        max_steps=self._train_steps,
388        hooks=self._train_monitors + extra_hooks,
389        saving_listeners=self._saving_listeners)
390
391  def evaluate(self, delay_secs=None, name=None):
392    """Evaluate on the evaluation data.
393
394    Runs evaluation on the evaluation data and returns the result. Runs for
395    `self._eval_steps` steps, or if it's `None`, then run until input is
396    exhausted or another exception is raised. Start the evaluation after
397    `delay_secs` seconds, or if it's `None`, defaults to using
398    `self._eval_delay_secs` seconds.
399
400    Args:
401      delay_secs: Start evaluating after this many seconds. If `None`, defaults
402        to using `self._eval_delays_secs`.
403      name: Gives the name to the evauation for the case multiple evaluation is
404        run for the same experiment.
405
406    Returns:
407      The result of the `evaluate` call to the `Estimator`.
408    """
409    if delay_secs is None:
410      delay_secs = self._eval_delay_secs
411
412    if delay_secs:
413      logging.info("Waiting %d secs before starting eval.", delay_secs)
414      time.sleep(delay_secs)
415
416    return self._call_evaluate(
417        input_fn=self._eval_input_fn,
418        steps=self._eval_steps,
419        metrics=self._eval_metrics,
420        name=(name or "one_pass"),
421        hooks=self._eval_hooks)
422
  @deprecated(
      "2016-10-23",
      "local_run will be renamed to train_and_evaluate and the new default "
      "behavior will be to run evaluation every time there is a new "
      "checkpoint.")
  def local_run(self):
    """Runs `train_and_evaluate` using `local_eval_frequency` as eval frequency.

    Temporarily overrides `self._min_eval_frequency` with the constructor's
    `local_eval_frequency` for the duration of the call.
    NOTE(review): `_new_attr_context` is defined elsewhere in this module;
    presumably it restores the previous attribute value on exit — confirm.

    Returns:
      The result of `train_and_evaluate`.
    """
    with _new_attr_context(self, "_min_eval_frequency"):
      self._min_eval_frequency = self._local_eval_frequency
      return self.train_and_evaluate()
432
433  # TODO(xiejw): Allow continuous_eval_predicate_fn to be passed via constructor
434  # once stopping all jobs is implemented.
  def _continuous_eval(self,
                       input_fn,
                       name,
                       delay_secs,
                       throttle_delay_secs,
                       evaluate_checkpoint_only_once=True,
                       continuous_eval_predicate_fn=None,
                       export=True):
    """Run continuous eval.

    Runs infinite eval on the evaluation data set. This function starts
    evaluating after `delay_secs` seconds and then runs no more than one
    evaluation (with `self._eval_steps` steps each time) per
    `throttle_delay_secs`. If `train_steps` is not None, will return after
    global_step reaches `train_steps`.

    Args:
      input_fn: The input to use for this eval.
      name: A string appended to the folder name of evaluation results.
      delay_secs: Start evaluating after this many seconds. If None, defaults to
        self._eval_delay_secs.
      throttle_delay_secs: Do not re-evaluate unless the last evaluation was
        started at least this many seconds ago. If None, defaults to
        self._continuous_eval_throttle_secs.
      evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints
        that have already been evaluated. Default is `True`.
      continuous_eval_predicate_fn: A predicate function determining whether to
        continue eval after each iteration. A `predicate_fn` has one of the
        following signatures:
          * (eval_results) -> boolean
          * (eval_results, checkpoint_path) -> boolean
        Where `eval_results` is the dictionary of metric evaluations and
        checkpoint_path is the path to the checkpoint containing the parameters
        on which that evaluation was based.
        At the beginning of evaluation, the passed `eval_results` will be None
        so it's expected that the predicate function handles that gracefully.
        Continuous eval behavior under different conditions:
          * When `predicate_fn` is specified:
            + if `train_steps` is None, run until `predicate_fn` returns False.
            + if `train_steps` is specified, run until either global step
              reaches `train_steps` or `predicate_fn` returns False.
          * When `predicate_fn` is not specified:
            + if `train_steps` is None, run in an infinite loop.
            + if `train_steps` is specified, run until global step reaches
              `train_steps`.
      export: Whether to export from this step. Default is 'True'.

    Raises:
      ValueError: if `continuous_eval_predicate_fn` is neither None nor
        callable.
    """
    # Normalize the predicate to the two-argument
    # (eval_results, checkpoint_path) signature.
    if continuous_eval_predicate_fn is not None:
      if not callable(continuous_eval_predicate_fn):
        raise ValueError(
            "`continuous_eval_predicate_fn` must be a callable, or None.")
      predicate_fn = _get_standardized_predicate_fn(
          continuous_eval_predicate_fn)
    else:
      predicate_fn = None

    if delay_secs is None:
      delay_secs = self._eval_delay_secs
    if throttle_delay_secs is None:
      throttle_delay_secs = self._continuous_eval_throttle_secs

    if delay_secs:
      logging.info("Waiting %f secs before starting eval.", delay_secs)
      time.sleep(delay_secs)

    # Loop state: last checkpoint actually evaluated, last eval result (None
    # on the first predicate call), and the timestamp of the last skip
    # warning (used to rate-limit warnings).
    previous_path = None
    eval_result = None
    last_warning_time = 0
    while (not predicate_fn or predicate_fn(
        eval_result, checkpoint_path=previous_path)):
      # Exit if we have already reached number of steps to train.
      if self._has_training_stopped(eval_result):
        logging.info("Exiting continuous eval, global_step=%s >= "
                     "train_step=%s", eval_result[ops.GraphKeys.GLOBAL_STEP],
                     self._train_steps)
        return

      start = time.time()

      error_msg = None
      latest_path = checkpoint_management.latest_checkpoint(
          self._estimator.model_dir)
      if not latest_path:
        error_msg = ("Estimator is not fitted yet. "
                     "Will start an evaluation when a checkpoint is ready.")
      elif evaluate_checkpoint_only_once and latest_path == previous_path:
        error_msg = "No new checkpoint ready for evaluation."

      if error_msg:
        # Print warning message every 10 mins.
        eval_result = {}
        if time.time() - last_warning_time > 600:
          logging.warning(error_msg)
          last_warning_time = time.time()
      else:
        eval_result = self._call_evaluate(
            input_fn=input_fn,
            steps=self._eval_steps,
            metrics=self._eval_metrics,
            name=name,
            checkpoint_path=latest_path,
            hooks=self._eval_hooks)
        # Ensure eval result is not None for next round of evaluation.
        if not eval_result:
          eval_result = {}

        if export:
          self._maybe_export(eval_result, checkpoint_path=latest_path)

        # Clear warning timer and update last evaluated checkpoint
        last_warning_time = 0
        previous_path = latest_path

      # Throttle: if this iteration finished early, sleep so that evaluations
      # start at most once per `throttle_delay_secs`.
      duration = time.time() - start
      if duration < throttle_delay_secs:
        difference = throttle_delay_secs - duration
        logging.info("Waiting %f secs before starting next eval run.",
                     difference)
        time.sleep(difference)
558
559  def _has_training_stopped(self, eval_result):
560    """Determines whether the training has stopped."""
561    if not eval_result:
562      return False
563
564    global_step = eval_result.get(ops.GraphKeys.GLOBAL_STEP)
565    return global_step and self._train_steps and (global_step >=
566                                                  self._train_steps)
567
  def continuous_eval(self,
                      delay_secs=None,
                      throttle_delay_secs=None,
                      evaluate_checkpoint_only_once=True,
                      continuous_eval_predicate_fn=None,
                      name="continuous"):
    """Runs continuous eval on the evaluation data.

    Thin wrapper around `_continuous_eval` using `self._eval_input_fn`;
    see `_continuous_eval` for the full argument semantics.

    Args:
      delay_secs: Start evaluating after this many seconds. If None, defaults
        to `self._eval_delay_secs`.
      throttle_delay_secs: Do not re-evaluate unless the last evaluation was
        started at least this many seconds ago. If None, defaults to
        `self._continuous_eval_throttle_secs`.
      evaluate_checkpoint_only_once: Whether to skip checkpoints that have
        already been evaluated. Default is `True`.
      continuous_eval_predicate_fn: Predicate deciding whether to continue
        eval after each iteration; see `_continuous_eval`.
      name: A string appended to the folder name of evaluation results.
    """
    self._continuous_eval(
        self._eval_input_fn,
        name=name,
        delay_secs=delay_secs,
        throttle_delay_secs=throttle_delay_secs,
        evaluate_checkpoint_only_once=evaluate_checkpoint_only_once,
        continuous_eval_predicate_fn=continuous_eval_predicate_fn)
581
  def continuous_eval_on_train_data(self,
                                    delay_secs=None,
                                    throttle_delay_secs=None,
                                    continuous_eval_predicate_fn=None,
                                    name="continuous_on_train_data"):
    """Runs continuous eval on the *training* data.

    Same as `continuous_eval` but evaluates `self._train_input_fn` and
    disables export (`export=False`); see `_continuous_eval` for argument
    semantics.

    Args:
      delay_secs: Start evaluating after this many seconds. If None, defaults
        to `self._eval_delay_secs`.
      throttle_delay_secs: Do not re-evaluate unless the last evaluation was
        started at least this many seconds ago. If None, defaults to
        `self._continuous_eval_throttle_secs`.
      continuous_eval_predicate_fn: Predicate deciding whether to continue
        eval after each iteration; see `_continuous_eval`.
      name: A string appended to the folder name of evaluation results.
    """
    self._continuous_eval(
        self._train_input_fn,
        name=name,
        delay_secs=delay_secs,
        throttle_delay_secs=throttle_delay_secs,
        continuous_eval_predicate_fn=continuous_eval_predicate_fn,
        export=False)
594
  def train_and_evaluate(self):
    """Interleaves training and evaluation.

    The frequency of evaluation is controlled by the constructor arg
    `min_eval_frequency`. When this parameter is 0, evaluation happens
    only after training has completed. Note that evaluation cannot happen
    more frequently than checkpoints are taken. If no new snapshots are
    available when evaluation is supposed to occur, then evaluation doesn't
    happen for another `min_eval_frequency` steps (assuming a checkpoint is
    available at that point). Thus, setting `min_eval_frequency` to 1 means
    that the model will be evaluated every time there is a new checkpoint.

    This is particularly useful for a "Master" task in the cloud, whose
    responsibility it is to take checkpoints, evaluate those checkpoints,
    and write out summaries. Participating in training as the supervisor
    allows such a task to accomplish the first and last items, while
    performing evaluation allows for the second.

    Returns:
      The result of the `evaluate` call to the `Estimator` as well as the
      export results using the specified `ExportStrategy`.
    """
    # The directory to which evaluation summaries are written is determined
    # by adding a suffix to 'eval'; that suffix is the 'name' parameter to
    # the various evaluate(...) methods. By setting it to None, we force
    # the directory name to simply be 'eval'.
    eval_dir_suffix = None

    # We set every_n_steps to 1, but evaluation only occurs when a new
    # snapshot is available. If, by the time we finish evaluation
    # there is a new snapshot, then we just evaluate again. Otherwise,
    # we keep training until one becomes available.
    with _new_attr_context(self, "_train_monitors"):
      self._train_monitors = self._train_monitors or []
      config = self._estimator.config
      intermediate_export = self._checkpoint_and_export and (
          config.save_checkpoints_secs or config.save_checkpoints_steps)
      if intermediate_export:
        # Create a partially specified evaluate function with the desired
        # arguments. This will be executed by the _EvalAndExportListener,
        # which will specify the latest checkpoint path.
        eval_fn = functools.partial(
            self._call_evaluate,
            input_fn=self._eval_input_fn,
            steps=self._eval_steps,
            metrics=self._eval_metrics,
            hooks=self._eval_hooks)

        export_listener = _EvalAndExportListener(
            eval_fn=eval_fn,
            export_fn=self._maybe_export,
            model_dir=self._estimator.model_dir)

        saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            checkpoint_dir=self._estimator.model_dir,
            save_secs=config.save_checkpoints_secs,
            save_steps=config.save_checkpoints_steps,
            listeners=[export_listener])
        self._train_monitors += [saver_hook]
      else:
        if self._min_eval_frequency:
          # Using low min_eval_frequency (default is 1) on a non-cached file
          # system requires a lot of overhead to read the checkpoint state file.
          # This is particularly bad on GCS and CNS. See also b/36498507 for
          # context. `check_interval_secs = 5` avoids polling a remote
          # filesystem too often.

          self._train_monitors += [
              monitors.ValidationMonitor(
                  input_fn=self._eval_input_fn,
                  eval_steps=self._eval_steps,
                  metrics=self._eval_metrics,
                  every_n_steps=self._min_eval_frequency,
                  check_interval_secs=self._check_interval_secs,
                  name=eval_dir_suffix,
                  hooks=self._eval_hooks)
          ]
      self.train(delay_secs=0)

    # If the checkpoint_and_export flag and appropriate estimator configuration
    # parameters are set, then model evaluations and exports are done during the
    # training process. In particular, this will always occur at the end of
    # training, so we return the most recent results to avoid performing a
    # duplicate evaluation and model export.
    if intermediate_export:
      return export_listener.eval_result, export_listener.export_results
    else:
      eval_result = self._call_evaluate(
          input_fn=self._eval_input_fn,
          steps=self._eval_steps,
          metrics=self._eval_metrics,
          name=eval_dir_suffix,
          hooks=self._eval_hooks)
      export_results = self._maybe_export(eval_result)
      return eval_result, export_results
690
691  @experimental
692  def continuous_train_and_eval(self, continuous_eval_predicate_fn=None):
693    """Interleaves training and evaluation.
694
695    The frequency of evaluation is controlled by the `train_steps_per_iteration`
696    (via constructor). The model will be first trained for
697    `train_steps_per_iteration`, and then be evaluated in turns.
698
699    This method is intended for single machine usage.
700
701    This differs from `train_and_evaluate` as follows:
702
703      1. The procedure will have train and evaluation in turns. The model
704      will be trained for a number of steps (usually smaller than `train_steps`
705      if provided) and then be evaluated.  `train_and_evaluate` will train the
706      model for `train_steps` (no small training iterations).
707
708      2. Due to the different approach this schedule takes, it leads to two
709      differences in resource control. First, the resources (e.g., memory) used
710      by training will be released before evaluation (`train_and_evaluate` takes
711      double resources). Second, more checkpoints will be saved as a checkpoint
712      is generated at the end of each training iteration.
713
714      3. As the estimator.train starts from scratch (new graph, new states for
715      input, etc) at each iteration, it is recommended to have the
716      `train_steps_per_iteration` larger. It is also recommended to shuffle your
717      input.
718
719    Args:
720      continuous_eval_predicate_fn: A predicate function determining whether to
721        continue eval after each iteration. A `predicate_fn` has one of the
722        following signatures:
723          * (eval_results) -> boolean
724          * (eval_results, checkpoint_path) -> boolean
725        Where `eval_results` is the dictionary of metric evaluations and
726        checkpoint_path is the path to the checkpoint containing the parameters
727        on which that evaluation was based.
728        At the beginning of evaluation, the passed `eval_results` and
729        `checkpoint_path` will be None so it's expected that the predicate
730        function handles that gracefully.
731        When `predicate_fn` is not specified, continuous eval will run in an
732        infinite loop (if `train_steps` is None). or exit once global step
733        reaches `train_steps`.
734
735    Returns:
736      A tuple of the result of the `evaluate` call to the `Estimator` and the
737      export results using the specified `ExportStrategy`.
738
739    Raises:
740      ValueError: if `continuous_eval_predicate_fn` is neither None nor
741        callable.
742    """
743
744    if continuous_eval_predicate_fn is not None:
745      if not callable(continuous_eval_predicate_fn):
746        raise ValueError(
747            "`continuous_eval_predicate_fn` must be a callable, or None.")
748      predicate_fn = _get_standardized_predicate_fn(
749          continuous_eval_predicate_fn)
750    else:
751      predicate_fn = None
752
753    export_results = None
754    latest_checkpoint = None
755    eval_result = None
756
757    # Set the default value for train_steps_per_iteration, which will be
758    # overridden by other settings.
759    train_steps_per_iteration = 1000
760    if self._train_steps_per_iteration is not None:
761      train_steps_per_iteration = self._train_steps_per_iteration
762    elif self._train_steps is not None:
763      train_steps_per_iteration = int(self._train_steps / 10)
764
765    while (not predicate_fn or predicate_fn(
766        eval_result, checkpoint_path=latest_checkpoint
767        if eval_result else None)):
768
769      if self._has_training_stopped(eval_result):
770        # Exits once max steps of training is satisfied.
771        logging.info("Stop training model as max steps reached")
772        break
773
774      logging.info("Training model for %s steps", train_steps_per_iteration)
775      self._call_train(
776          input_fn=self._train_input_fn,
777          steps=train_steps_per_iteration,
778          hooks=self._train_monitors,
779          saving_listeners=self._saving_listeners)
780
781      logging.info("Evaluating model now.")
782      latest_checkpoint = checkpoint_management.latest_checkpoint(
783          self._estimator.model_dir)
784      eval_result = self._call_evaluate(
785          input_fn=self._eval_input_fn,
786          steps=self._eval_steps,
787          metrics=self._eval_metrics,
788          name="one_pass",
789          checkpoint_path=latest_checkpoint,
790          hooks=self._eval_hooks)
791      export_results = self._maybe_export(eval_result)
792
793    return eval_result, export_results
794
795  def _maybe_export(self, eval_result, checkpoint_path=None):
796    """Export the Estimator using export_fn, if defined."""
797    export_dir_base = os.path.join(
798        compat.as_bytes(self._estimator.model_dir), compat.as_bytes("export"))
799
800    export_results = []
801    for strategy in self._export_strategies:
802      export_results.append(
803          strategy.export(
804              self._estimator,
805              os.path.join(
806                  compat.as_bytes(export_dir_base),
807                  compat.as_bytes(strategy.name)),
808              checkpoint_path=checkpoint_path,
809              eval_result=eval_result))
810
811    return export_results
812
813  def run_std_server(self):
814    """Starts a TensorFlow server and joins the serving thread.
815
816    Typically used for parameter servers.
817
818    Raises:
819      ValueError: if not enough information is available in the estimator's
820        config to create a server.
821    """
822    self._start_server().join()
823
824  def test(self):
825    """Tests training, evaluating and exporting the estimator for a single step.
826
827    Returns:
828      The result of the `evaluate` call to the `Estimator`.
829    """
830    self._call_train(
831        input_fn=self._train_input_fn,
832        steps=1,
833        hooks=self._train_monitors,
834        saving_listeners=self._saving_listeners)
835
836    eval_result = self._call_evaluate(
837        input_fn=self._eval_input_fn,
838        steps=1,
839        metrics=self._eval_metrics,
840        name="one_pass")
841    _ = self._maybe_export(eval_result)
842
843    return eval_result
844
845  def _start_server(self):
846    """Creates, starts, and returns a server_lib.Server."""
847    config = self._estimator.config
848    if (not config.cluster_spec or not config.task_type or not config.master or
849        config.task_id is None):
850      raise ValueError("Could not start server; be sure to specify "
851                       "cluster_spec, task_type, master, and task in "
852                       "RunConfig or set the TF_CONFIG environment variable.")
853    server = server_lib.Server(
854        config.cluster_spec,
855        job_name=config.task_type,
856        task_index=config.task_id,
857        config=config.tf_config,
858        start=False)
859    server.start()
860    return server
861
862  def _call_train(
863      self,
864      _sentinel=None,  # pylint: disable=invalid-name,
865      input_fn=None,
866      steps=None,
867      hooks=None,
868      max_steps=None,
869      saving_listeners=None):
870    if _sentinel is not None:
871      raise ValueError("_call_train should be called with keyword args only")
872
873    # Estimator in core cannot work with monitors. We need to convert them
874    # to hooks. For Estimator in contrib, it is converted internally. So, it is
875    # safe to convert for both cases.
876    hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
877    if self._core_estimator_used:
878      return self._estimator.train(
879          input_fn=input_fn,
880          steps=steps,
881          max_steps=max_steps,
882          hooks=hooks,
883          saving_listeners=saving_listeners)
884    else:
885      return self._estimator.fit(
886          input_fn=input_fn, steps=steps, max_steps=max_steps, monitors=hooks)
887
888  def _call_evaluate(
889      self,
890      _sentinel=None,  # pylint: disable=invalid-name,
891      input_fn=None,
892      steps=None,
893      metrics=None,
894      name=None,
895      checkpoint_path=None,
896      hooks=None):
897    if _sentinel is not None:
898      raise ValueError("_call_evaluate should be called with keyword args only")
899
900    if self._core_estimator_used:
901      if metrics is not None:
902        raise ValueError(
903            "`eval_metrics` must be `None` with `tf.estimator.Estimator`")
904      return self._estimator.evaluate(
905          input_fn=input_fn,
906          steps=steps,
907          name=name,
908          checkpoint_path=checkpoint_path,
909          hooks=hooks)
910    else:
911      return self._estimator.evaluate(
912          input_fn=input_fn,
913          steps=steps,
914          metrics=metrics,
915          name=name,
916          checkpoint_path=checkpoint_path,
917          hooks=hooks)
918
919
920@contextlib.contextmanager
921def _new_attr_context(obj, attr):
922  """Creates a new context in which an object's attribute can be changed.
923
924  This creates a context in which an object's attribute can be changed.
925  Once the context is exited, the attribute reverts to its original value.
926
927  Args:
928    obj: An object whose attribute to restore at the end of the context.
929    attr: An attribute to remember and restore at the end of the context.
930
931  Yields:
932    Context.
933
934  Example:
935    my_obj.x = 1
936    with _new_attr_context(my_obj, "x"):
937      my_obj.x = 2
938      print(my_obj.x)
939    print(my_obj.x)
940  """
941  saved = getattr(obj, attr)
942  try:
943    yield
944  finally:
945    setattr(obj, attr, saved)
946