# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Monitors instrument the training process (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20
21@@get_default_monitors
22@@BaseMonitor
23@@CaptureVariable
24@@CheckpointSaver
25@@EveryN
26@@ExportMonitor
27@@GraphDump
28@@LoggingTrainable
29@@NanLoss
30@@PrintTensor
31@@StepCounter
32@@StopAtStep
33@@SummarySaver
34@@ValidationMonitor
35"""
36
37from __future__ import absolute_import
38from __future__ import division
39from __future__ import print_function
40
41import copy
42import os
43import time
44
45import numpy as np
46import six
47
48from tensorflow.core.framework.summary_pb2 import Summary
49from tensorflow.core.util.event_pb2 import SessionLog
50from tensorflow.python.estimator import estimator as core_estimator
51from tensorflow.python.framework import ops
52from tensorflow.python.platform import tf_logging as logging
53from tensorflow.python.summary import summary as core_summary
54from tensorflow.python.training import checkpoint_management
55from tensorflow.python.training import session_run_hook
56from tensorflow.python.training import training_util
57from tensorflow.python.util import deprecation
58from tensorflow.python.util import tf_inspect
59
60
# TODO(ptucker): Split each monitor class into a separate file.
# TODO(ptucker): Fail if epoch or step does not monotonically increase?
class BaseMonitor(object):
  """Base class for Monitors.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Defines basic interfaces of Monitors.
  Monitors can either be run on all workers or, more commonly, restricted
  to run exclusively on the elected chief worker.
  """

  @deprecation.deprecated(
      "2016-12-05",
      "Monitors are deprecated. Please use tf.train.SessionRunHook.")
  def __init__(self):
    self._begun = False
    self._current_epoch = None
    self._current_step = None
    self._max_steps = None
    self._estimator = None

  @property
  def run_on_all_workers(self):
    return False

  def set_estimator(self, estimator):
    """A setter called automatically by the target estimator.

    If the estimator is locked, this method does nothing.

    Args:
      estimator: the estimator that this monitor monitors.

    Raises:
      ValueError: if the estimator is None.
    """
    if estimator is None:
      raise ValueError("Missing estimator.")
    # TODO(mdan): This should fail if called twice with the same estimator.
    self._estimator = estimator

  def begin(self, max_steps=None):
    """Called at the beginning of training.

    When called, the default graph is the one we are executing.

    Args:
      max_steps: `int`, the maximum global step this training will run until.

    Raises:
      ValueError: if we've already begun a run.
    """
    if self._begun:
      raise ValueError("begin called twice without end.")
    self._max_steps = max_steps
    self._begun = True

  def end(self, session=None):
    """Callback at the end of training/evaluation.

    Args:
      session: A `tf.Session` object that can be used to run ops.

    Raises:
      ValueError: if we've not begun a run.
    """
    _ = session
    if not self._begun:
      raise ValueError("end called without begin.")
    self._max_steps = None
    self._begun = False

  def epoch_begin(self, epoch):
    """Begin epoch.

    Args:
      epoch: `int`, the epoch number.

    Raises:
      ValueError: if we've already begun an epoch, or `epoch` < 0.
    """
    if self._current_epoch is not None:
      raise ValueError("epoch_begin called twice without epoch_end.")
    if epoch < 0:
      raise ValueError("Invalid epoch %s." % epoch)
    self._current_epoch = epoch

  def epoch_end(self, epoch):
    """End epoch.

    Args:
      epoch: `int`, the epoch number.

    Raises:
      ValueError: if we've not begun an epoch, or `epoch` number does not match.
    """
    if self._current_epoch != epoch:
      raise ValueError("epoch_end expected %s but got %s." %
                       (self._current_epoch, epoch))
    self._current_epoch = None

  def step_begin(self, step):
    """Callback before training step begins.

    You may use this callback to request evaluation of additional tensors
    in the graph.

    Args:
      step: `int`, the current value of the global step.

    Returns:
      List of `Tensor` objects or string tensor names to be run.

    Raises:
      ValueError: if we've already begun a step, or `step` < 0, or
          `step` > `max_steps`.
    """
    if (step < 0) or ((self._max_steps is not None) and
                      (step > self._max_steps)):
      raise ValueError("Invalid step %s." % step)
    self._current_step = step
    return []

  def step_end(self, step, output):  # pylint: disable=unused-argument
    """Callback after training step finished.

    This callback provides access to the tensors/ops evaluated at this step,
    including the additional tensors for which evaluation was requested in
    `step_begin`.

    In addition, the callback has the opportunity to stop training by returning
    `True`. This is useful for early stopping, for example.

    Note that this method is not called if the call to `Session.run()` that
    followed the last call to `step_begin()` failed.

    Args:
      step: `int`, the current value of the global step.
      output: `dict` mapping `string` values representing tensor names to
        the values resulting from running these tensors. Values may be either
        scalars, for scalar tensors, or Numpy `array`, for non-scalar tensors.

    Returns:
      `bool`. True if training should stop.

    Raises:
      ValueError: if we've not begun a step, or `step` number does not match.
    """
    if self._current_step != step:
      raise ValueError("step_end expected %s but got %s." %
                       (self._current_step, step))
    self._current_step = None
    return False

  def post_step(self, step, session):  # pylint: disable=unused-argument
    """Callback after the step is finished.

    Called after `step_end` and receives a session to perform extra
    `session.run` calls. This callback is invoked even if a failure occurred
    during the step.

    Args:
      step: `int`, global step of the model.
      session: `Session` object.
    """
    _ = step, session


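# A minimal sketch of a custom monitor built on `BaseMonitor`; the class name
# and log message are hypothetical, not part of this module:
#
#   class StepAnnouncer(BaseMonitor):
#
#     def step_begin(self, step):
#       super(StepAnnouncer, self).step_begin(step)
#       return []  # No extra tensors requested.
#
#     def step_end(self, step, output):
#       super(StepAnnouncer, self).step_end(step, output)
#       logging.info("Finished step %d.", step)
#       return False  # Never requests early stopping.

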
def _extract_output(outputs, request):
  """Returns the run output for `request`, keyed by itself or by its name."""
  if request in outputs:
    return outputs[request]
  return outputs[request.name]


class EveryN(BaseMonitor):
  """Base class for monitors that execute callbacks every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  This class adds three new callbacks:
    - every_n_step_begin
    - every_n_step_end
    - every_n_post_step

  The callbacks are executed every n steps, or optionally every step for the
  first m steps, where m and n can both be user-specified.

  When extending this class, note that if you wish to use any of the
  `BaseMonitor` callbacks, you must call their respective super implementation:

    def step_begin(self, step):
      super(ExampleMonitor, self).step_begin(step)
      return []

  Failing to call the super implementation will cause unpredictable behavior.

  The `every_n_post_step()` callback is also called after the last step if it
  was not already called through the regular conditions.  Note that
  `every_n_step_begin()` and `every_n_step_end()` do not receive that special
  treatment.
  """

  # TODO(ipolosukhin): Add also every n seconds.

  def __init__(self, every_n_steps=100, first_n_steps=1):
    """Initializes an `EveryN` monitor.

    Args:
      every_n_steps: `int`, the number of steps to allow between callbacks.
      first_n_steps: `int`, specifying the number of initial steps during
        which the callbacks will always be executed, regardless of the value
        of `every_n_steps`. Note that this value is relative to the global
        step.
    """
    super(EveryN, self).__init__()
    self._every_n_steps = every_n_steps
    self._first_n_steps = first_n_steps
    # Last step in the model.
    self._last_successful_step = None
    # Last step at which we called one of the every_n methods.
    self._last_active_step = 0
    self._every_n_step_begin_called = False

  def every_n_step_begin(self, step):  # pylint: disable=unused-argument
    """Callback before every n'th step begins.

    Args:
      step: `int`, the current value of the global step.

    Returns:
      A `list` of tensors that will be evaluated at this step.
    """
    return []

  def every_n_step_end(self, step, outputs):  # pylint: disable=unused-argument
    """Callback after every n'th step finished.

    This callback provides access to the tensors/ops evaluated at this step,
    including the additional tensors for which evaluation was requested in
    `step_begin`.

    In addition, the callback has the opportunity to stop training by returning
    `True`. This is useful for early stopping, for example.

    Args:
      step: `int`, the current value of the global step.
      outputs: `dict` mapping `string` values representing tensor names to
        the values resulting from running these tensors. Values may be either
        scalars, for scalar tensors, or Numpy `array`, for non-scalar tensors.

    Returns:
      `bool`. True if training should stop.
    """
    return False

  def every_n_post_step(self, step, session):
    """Callback after a step is finished or `end()` is called.

    Args:
      step: `int`, the current value of the global step.
      session: `Session` object.
    """
    pass

  def step_begin(self, step):
    """Overrides `BaseMonitor.step_begin`.

    When overriding this method, you must call the super implementation.

    Args:
      step: `int`, the current value of the global step.
    Returns:
      A `list`, the result of every_n_step_begin, if that was called this step,
      or an empty list otherwise.

    Raises:
      ValueError: if called more than once during a step.
    """
    super(EveryN, self).step_begin(step)
    if (step <= self._first_n_steps or
        step >= (self._every_n_steps + self._last_active_step) or
        step == self._max_steps):  # Note: max_steps can be None here.
      self._every_n_step_begin_called = True
      return self.every_n_step_begin(step)
    self._every_n_step_begin_called = False
    return []

  def step_end(self, step, output):
    """Overrides `BaseMonitor.step_end`.

    When overriding this method, you must call the super implementation.

    Args:
      step: `int`, the current value of the global step.
      output: `dict` mapping `string` values representing tensor names to
        the values resulting from running these tensors. Values may be either
        scalars, for scalar tensors, or Numpy `array`, for non-scalar tensors.
    Returns:
      `bool`, the result of every_n_step_end, if that was called this step,
      or `False` otherwise.
    """
    super(EveryN, self).step_end(step, output)
    if self._every_n_step_begin_called:
      return self.every_n_step_end(step, output)
    return False

  def post_step(self, step, session):
    super(EveryN, self).post_step(step, session)
    if self._every_n_step_begin_called:
      self.every_n_post_step(step, session)
      self._last_active_step = step
    self._last_successful_step = step

  def end(self, session=None):
    super(EveryN, self).end(session=session)
    if self._last_successful_step != self._last_active_step:
      self.every_n_post_step(self._last_successful_step, session)


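# A minimal sketch of an `EveryN` subclass; the class name and log message are
# hypothetical, not part of this module:
#
#   class EveryNAnnouncer(EveryN):
#
#     def every_n_step_end(self, step, outputs):
#       super(EveryNAnnouncer, self).every_n_step_end(step, outputs)
#       logging.info("Every-N callback fired at step %d.", step)
#       return False
#
#   # Fires every 50 steps, and on each of the first 5 steps.
#   monitor = EveryNAnnouncer(every_n_steps=50, first_n_steps=5)

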
class StopAtStep(BaseMonitor):
  """Monitor to request stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Create a StopAtStep monitor.

    This monitor requests stop after either a number of steps have been
    executed or a last step has been reached.  Only one of the two options
    can be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called.  If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `step_begin()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    super(StopAtStep, self).__init__()
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  @property
  def run_on_all_workers(self):
    return True

  def step_begin(self, step):
    super(StopAtStep, self).step_begin(step)
    if self._last_step is None:
      self._last_step = step + self._num_steps - 1
    return []

  def step_end(self, step, output):
    super(StopAtStep, self).step_end(step, output)
    return step >= self._last_step


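# Usage sketch (the estimator and input_fn names are hypothetical):
#
#   # Stop after 1000 more steps from the step at which training begins:
#   my_estimator.fit(input_fn=my_input_fn,
#                    monitors=[StopAtStep(num_steps=1000)])
#
#   # Or stop once the global step reaches 10000:
#   my_estimator.fit(input_fn=my_input_fn,
#                    monitors=[StopAtStep(last_step=10000)])

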
# TODO(ptucker): Rename to LoggingTensor since it's not writing to stdout.
class PrintTensor(EveryN):
  """Prints given tensors every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  This is an `EveryN` monitor and has consistent semantics for `every_n`
  and `first_n`.

  The tensors will be printed to the log, with `INFO` severity.
  """

  def __init__(self, tensor_names, every_n=100, first_n=1):
    """Initializes a PrintTensor monitor.

    Args:
      tensor_names: `dict` of tag to tensor names or
          `iterable` of tensor names (strings).
      every_n: `int`, print every N steps. See `EveryN`.
      first_n: `int`, also print the first N steps. See `EveryN`.
    """
    super(PrintTensor, self).__init__(every_n, first_n)
    if not isinstance(tensor_names, dict):
      tensor_names = {item: item for item in tensor_names}
    self._tensor_names = tensor_names

  def every_n_step_begin(self, step):
    super(PrintTensor, self).every_n_step_begin(step)
    return list(self._tensor_names.values())

  def every_n_step_end(self, step, outputs):
    super(PrintTensor, self).every_n_step_end(step, outputs)
    stats = []
    for tag, tensor_name in six.iteritems(self._tensor_names):
      if tensor_name in outputs:
        stats.append("%s = %s" % (tag,
                                  str(_extract_output(outputs, tensor_name))))
    logging.info("Step %d: %s", step, ", ".join(stats))


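# Usage sketch (the tensor names are hypothetical; they must exist in the
# graph):
#
#   # Log the loss tensor under the tag "loss" every 100 steps:
#   monitor = PrintTensor({"loss": "loss_op:0"}, every_n=100)
#
#   # Or pass an iterable of names, in which case each name is its own tag:
#   monitor = PrintTensor(["loss_op:0", "accuracy_op:0"])

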
class LoggingTrainable(EveryN):
  """Writes trainable variable values into log every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Writes the values of trainable variables to the log every `every_n` steps,
  and also for each of the first `first_n` steps.
  """

  def __init__(self, scope=None, every_n=100, first_n=1):
    """Initializes LoggingTrainable monitor.

    Args:
      scope: An optional string to match variable names using re.match.
      every_n: Print every N steps.
      first_n: Print first N steps.
    """
    super(LoggingTrainable, self).__init__(every_n, first_n)
    self._scope = scope

  def every_n_step_begin(self, step):
    super(LoggingTrainable, self).every_n_step_begin(step)
    # Get a list of trainable variables at the beginning of every N steps.
    # We cannot get this in __init__ because train_op has not been generated.
    trainables = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES, scope=self._scope)
    self._names = {}
    for var in trainables:
      self._names[var.name] = var.value().name
    return list(self._names.values())

  def every_n_step_end(self, step, outputs):
    super(LoggingTrainable, self).every_n_step_end(step, outputs)
    stats = []
    for tag, tensor_name in six.iteritems(self._names):
      if tensor_name in outputs:
        stats.append("%s = %s" % (tag,
                                  str(_extract_output(outputs, tensor_name))))
    logging.info("Logging Trainable: Step %d: %s", step, ", ".join(stats))


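# Usage sketch (the scope name is hypothetical):
#
#   # Log every trainable variable whose name matches "dnn" every 200 steps;
#   # with scope=None, all trainable variables are logged.
#   monitor = LoggingTrainable(scope="dnn", every_n=200)

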
class SummarySaver(EveryN):
  """Saves summaries every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               summary_op,
               save_steps=100,
               output_dir=None,
               summary_writer=None,
               scaffold=None):
    """Initializes a `SummarySaver` monitor.

    Args:
      summary_op: `Tensor` of type `string`. A serialized `Summary` protocol
          buffer, as output by TF summary methods like `summary.scalar` or
          `summary.merge_all`.
      save_steps: `int`, save summaries every N steps. See `EveryN`.
      output_dir: `string`, the directory to save the summaries to. Only used
          if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
          one will be created accordingly.
      scaffold: `Scaffold` to get summary_op if it's not provided.
    """
    # TODO(ipolosukhin): Implement every N seconds.
    super(SummarySaver, self).__init__(every_n_steps=save_steps)
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    if summary_writer is None and output_dir:
      self._summary_writer = core_summary.FileWriter(output_dir)
    self._scaffold = scaffold
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def set_estimator(self, estimator):
    super(SummarySaver, self).set_estimator(estimator)
    # TODO(mdan): This line looks redundant.
    if self._summary_writer is None:
      self._summary_writer = core_summary.FileWriter(estimator.model_dir)

  def every_n_step_begin(self, step):
    super(SummarySaver, self).every_n_step_begin(step)
    if self._summary_op is None and self._scaffold is not None:
      self._summary_op = self._scaffold.summary_op
    if self._summary_op is not None:
      return [self._summary_op]
    return []

  def every_n_step_end(self, step, outputs):
    super(SummarySaver, self).every_n_step_end(step, outputs)
    if self._summary_op is not None:
      summary_strs = _extract_output(outputs, self._summary_op)
      if self._summary_writer:
        self._summary_writer.add_summary(summary_strs, step)
    return False

  def end(self, session=None):
    super(SummarySaver, self).end(session=session)
    if self._summary_writer:
      self._summary_writer.flush()


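# Usage sketch (the output directory is hypothetical):
#
#   summary_op = core_summary.merge_all()
#   monitor = SummarySaver(summary_op, save_steps=100,
#                          output_dir="/tmp/model_summaries")

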
class ValidationMonitor(EveryN):
  """Runs evaluation of a given estimator, at most every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Note that the evaluation is done based on the saved checkpoint, which will
  usually be older than the current step.

  Can do early stopping on validation metrics if `early_stopping_rounds` is
  provided.
  """

  def __init__(self,
               x=None,
               y=None,
               input_fn=None,
               batch_size=None,
               eval_steps=None,
               every_n_steps=100,
               metrics=None,
               hooks=None,
               early_stopping_rounds=None,
               early_stopping_metric="loss",
               early_stopping_metric_minimize=True,
               name=None,
               check_interval_secs=5):
    """Initializes a ValidationMonitor.

    Args:
      x: See `BaseEstimator.evaluate`.
      y: See `BaseEstimator.evaluate`.
      input_fn: See `BaseEstimator.evaluate`.
      batch_size: See `BaseEstimator.evaluate`.
      eval_steps: See `BaseEstimator.evaluate`.
      every_n_steps: Check for new checkpoints to evaluate every N steps. If a
          new checkpoint is found, it is evaluated. See `EveryN`.
      metrics: See `BaseEstimator.evaluate`.
      hooks: A list of `SessionRunHook` hooks to pass to the
        `Estimator`'s `evaluate` function.
      early_stopping_rounds: `int`. If the metric indicated by
          `early_stopping_metric` does not improve (in the direction given by
          `early_stopping_metric_minimize`) for this many steps, then training
          will be stopped.
      early_stopping_metric: `string`, name of the metric to check for early
          stopping.
      early_stopping_metric_minimize: `bool`, True if `early_stopping_metric` is
          expected to decrease (thus early stopping occurs when this metric
          stops decreasing), False if `early_stopping_metric` is expected to
          increase. Typically, `early_stopping_metric_minimize` is True for
          loss metrics like mean squared error, and False for performance
          metrics like accuracy.
      name: See `BaseEstimator.evaluate`.
      check_interval_secs: Only check for a new checkpoint if at least
          `check_interval_secs` have passed. Ignored if None. Defaults to 5
          seconds.

    Raises:
      ValueError: If both x and input_fn are provided.
    """
    super(ValidationMonitor, self).__init__(
        every_n_steps=every_n_steps, first_n_steps=-1)
    # TODO(mdan): Checks like this are already done by evaluate.
    if x is None and input_fn is None:
      raise ValueError("Either x or input_fn should be provided.")
    self.x = x
    self.y = y
    self.input_fn = input_fn
    self.batch_size = batch_size
    self.eval_steps = eval_steps
    self.metrics = metrics
    self.hooks = hooks
    self.early_stopping_rounds = early_stopping_rounds
    self.early_stopping_metric = early_stopping_metric
    self.early_stopping_metric_minimize = early_stopping_metric_minimize
    self.name = name
    self._best_value_step = None
    self._best_value = None
    self._best_metrics = None
    self._early_stopped = False
    self._latest_path = None
    self._latest_path_step = None
    self._last_checkpoint_check_time = None
    self._check_interval_secs = check_interval_secs

  @property
  def early_stopped(self):
    """Returns True if this monitor caused an early stop."""
    return self._early_stopped

  @property
  def best_step(self):
    """Returns the step at which the best early stopping metric was found."""
    return self._best_value_step

  @property
  def best_value(self):
    """Returns the best early stopping metric value found so far."""
    return self._best_value

  @property
  def best_metrics(self):
    """Returns all eval metrics computed with the best early stopping metric.

    For instance, if the metrics computed in two successive evals are
    1. {'loss':40, 'auc':0.5}
    2. {'loss':50, 'auc':0.6}
    this function would return the first dict {'loss':40, 'auc':0.5} after both
    first and second eval (if `early_stopping_metric` is 'loss' and
    `early_stopping_metric_minimize` is True).

    Returns:
      The output dict of estimator.evaluate which contains the best value of
      the early stopping metric seen so far.
    """
    return self._best_metrics

  def _evaluate_estimator(self):
    if isinstance(self._estimator, core_estimator.Estimator):
      if any((x is not None
              for x in [self.x, self.y, self.batch_size, self.metrics])):
        raise ValueError(
            "tf.estimator.Estimator does not support the following "
            "arguments: x, y, batch_size, metrics. Set them to `None` "
            "in ValidationMonitor.")
      return self._estimator.evaluate(
          input_fn=self.input_fn,
          steps=self.eval_steps,
          hooks=self.hooks,
          name=self.name)
    else:
      return self._estimator.evaluate(
          x=self.x,
          y=self.y,
          input_fn=self.input_fn,
          batch_size=self.batch_size,
          steps=self.eval_steps,
          metrics=self.metrics,
          hooks=self.hooks,
          name=self.name)

  def every_n_step_end(self, step, outputs):
    super(ValidationMonitor, self).every_n_step_end(step, outputs)
    # TODO(mdan): The use of step below is probably misleading.
    # The code should probably use the step from the checkpoint, because
    # that's what is being evaluated.
    if self._estimator is None:
      raise ValueError("Missing call to set_estimator.")
    current_time = time.time()
    if (self._check_interval_secs is not None and
        self._last_checkpoint_check_time is not None and
        current_time - self._last_checkpoint_check_time <=
        self._check_interval_secs):
      logging.debug(
          "Skipping evaluation since less than %d seconds have passed since "
          "last check for a new checkpoint.", self._check_interval_secs)
      return False
    self._last_checkpoint_check_time = current_time
    # Check that we are not running evaluation on the same checkpoint.
    latest_path = checkpoint_management.latest_checkpoint(
        self._estimator.model_dir)
    if latest_path is None:
      logging.debug("Skipping evaluation since model has not been saved yet "
                    "at step %d.", step)
      return False
    if latest_path is not None and latest_path == self._latest_path:
      logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                    "as for step %d.", latest_path, step,
                    self._latest_path_step)
      return False
    self._latest_path = latest_path
    self._latest_path_step = step

    # Run evaluation and log it.
    validation_outputs = self._evaluate_estimator()
    stats = []
    for name in validation_outputs:
      stats.append("%s = %s" % (name, str(validation_outputs[name])))
    logging.info("Validation (step %d): %s", step, ", ".join(stats))

    # Early stopping logic.
    if self.early_stopping_rounds is not None:
      if self.early_stopping_metric not in validation_outputs:
        raise ValueError("Metric %s missing from outputs %s." %
                         (self.early_stopping_metric,
                          set(validation_outputs.keys())))
      current_value = validation_outputs[self.early_stopping_metric]
      if (self._best_value is None or (self.early_stopping_metric_minimize and
                                       (current_value < self._best_value)) or
          (not self.early_stopping_metric_minimize and
           (current_value > self._best_value))):
        self._best_value = current_value
        self._best_metrics = copy.deepcopy(validation_outputs)
        self._best_value_step = step
      stop_now = (step - self._best_value_step >= self.early_stopping_rounds)
      if stop_now:
        logging.info("Stopping. Best step: {} with {} = {}.".format(
            self._best_value_step, self.early_stopping_metric,
            self._best_value))
        self._early_stopped = True
        return True
    return False


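# Usage sketch (the estimator and input functions are hypothetical):
#
#   # Evaluate on the latest checkpoint at most every 500 training steps, and
#   # stop training if "loss" has not improved in the last 2000 steps.
#   monitor = ValidationMonitor(input_fn=eval_input_fn,
#                               every_n_steps=500,
#                               early_stopping_rounds=2000,
#                               early_stopping_metric="loss",
#                               early_stopping_metric_minimize=True)
#   my_estimator.fit(input_fn=train_input_fn, max_steps=100000,
#                    monitors=[monitor])

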
# TODO(ptucker): This really reads any tensor, not just vars, and requires the
# ':0' suffix on var_name.
class CaptureVariable(EveryN):
  """Captures a variable's values into a collection.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  This monitor is useful for unit testing. You should exercise caution when
  using this monitor in production, since it never discards values.

  This is an `EveryN` monitor and has consistent semantics for `every_n`
  and `first_n`.
  """

  def __init__(self, var_name, every_n=100, first_n=1):
    """Initializes a CaptureVariable monitor.

    Args:
      var_name: `string`. The variable name, including suffix (typically ":0").
      every_n: `int`, capture every N steps. See `EveryN`.
      first_n: `int`, also capture the first N steps. See `EveryN`.
    """
    super(CaptureVariable, self).__init__(every_n, first_n)
    self._var_name = var_name
    self._var_values = {}

  @property
  def values(self):
    """Returns the values captured so far.

    Returns:
      `dict` mapping `int` step numbers to the value of the variable at the
          respective step.
    """
    return self._var_values

  def every_n_step_begin(self, step):
    super(CaptureVariable, self).every_n_step_begin(step)
    return [self._var_name]

  def every_n_step_end(self, step, outputs):
    super(CaptureVariable, self).every_n_step_end(step, outputs)
    self._var_values[step] = _extract_output(outputs, self._var_name)


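# Usage sketch, e.g. in a unit test (the variable name is hypothetical):
#
#   capture = CaptureVariable("my_weights:0", every_n=10)
#   my_estimator.fit(input_fn=train_input_fn, steps=100, monitors=[capture])
#   # capture.values now maps step numbers to the captured variable values.

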
@deprecation.deprecated(None, "Use tf.train.MonitoredTrainingSession.")
def get_default_monitors(loss_op=None,
                         summary_op=None,
                         save_summary_steps=100,
                         output_dir=None,
                         summary_writer=None):
  """Returns a default set of typically-used monitors.

  Args:
    loss_op: `Tensor`, the loss tensor. This will be printed using `PrintTensor`
        at the default interval.
    summary_op: See `SummarySaver`.
    save_summary_steps: See `SummarySaver`.
    output_dir: See `SummarySaver`.
    summary_writer: See `SummarySaver`.
  Returns:
    `list` of monitors.
  """

  monitors = []
  if loss_op is not None:
    monitors.append(PrintTensor(tensor_names={"loss": loss_op.name}))
  if summary_op is not None:
    monitors.append(
        SummarySaver(
            summary_op,
            save_steps=save_summary_steps,
            output_dir=output_dir,
            summary_writer=summary_writer))
  return monitors


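# Usage sketch (loss_op and the output directory are hypothetical):
#
#   monitors = get_default_monitors(loss_op=loss_op,
#                                   summary_op=core_summary.merge_all(),
#                                   output_dir="/tmp/model_dir")
#   my_estimator.fit(input_fn=train_input_fn, monitors=monitors)

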
class GraphDump(BaseMonitor):
  """Dumps almost all tensors in the graph at every step.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Note that this is very expensive; prefer `PrintTensor` in production.
  """

  IGNORE_OPS = [
      "Const", "Assign", "Identity", "Placeholder", "RandomUniform", "Cast",
      "RestoreSlice"
  ]

  def __init__(self, ignore_ops=None):
    """Initializes GraphDump monitor.

    Args:
      ignore_ops: `list` of `string`. Names of ops to ignore.
          If None, `GraphDump.IGNORE_OPS` is used.
    """
    super(GraphDump, self).__init__()
    self._ignore_ops = ignore_ops or GraphDump.IGNORE_OPS
    self._data = {}

  def begin(self, max_steps=None):
    super(GraphDump, self).begin(max_steps=max_steps)
    self._tensors = []
    graph = ops.get_default_graph()
    graph_def = graph.as_graph_def()
    for node in graph_def.node:
      if node.op in self._ignore_ops:
        continue
      logging.info("op=%s name=%s.", node.op, node.name)
      try:
        self._tensors.append(graph.get_tensor_by_name(node.name + ":0"))
      except KeyError:
        pass

  def step_begin(self, step):
    super(GraphDump, self).step_begin(step)
    return self._tensors

  def step_end(self, step, output):
    super(GraphDump, self).step_end(step, output)
    self._data[step] = output

  @property
  def data(self):
    return self._data

  # TODO(ptucker): Handle keys that are in one but not the other.
  def compare(self, other_dump, step, atol=1e-06):
    """Compares two `GraphDump` monitors and returns differences.

    Args:
      other_dump: Another `GraphDump` monitor.
      step: `int`, step to compare on.
      atol: `float`, absolute tolerance in comparison of floating arrays.

    Returns:
      Returns tuple:
        matched: `list` of keys that matched.
        non_matched: `dict` of keys to tuple of 2 mismatched values.

    Raises:
      ValueError: if a key in `data` is missing from `other_dump` at `step`.
    """
    non_matched = {}
    matched = []
    this_output = self.data[step] if step in self.data else {}
    other_output = other_dump.data[step] if step in other_dump.data else {}
    for key in this_output:
      if not isinstance(key, six.string_types):
        continue
      if key not in other_output:
        raise ValueError("%s missing at step %s." % (key, step))
      value1 = _extract_output(this_output, key)
      value2 = _extract_output(other_output, key)
      if isinstance(value1, str):
        continue
      if isinstance(value1, np.ndarray):
        if not np.allclose(value1, value2, atol=atol):
          non_matched[key] = value1 - value2
        else:
          matched.append(key)
      else:
        if value1 != value2:
          non_matched[key] = (value1, value2)
        else:
          matched.append(key)
    return matched, non_matched


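# Usage sketch for comparing two runs, e.g. when debugging reproducibility
# (the two monitor instances are hypothetical and assumed to have been passed
# to two separate training runs over the same graph):
#
#   dump_a, dump_b = GraphDump(), GraphDump()
#   # ... run training twice, once with monitors=[dump_a], once with [dump_b].
#   matched, non_matched = dump_a.compare(dump_b, step=100, atol=1e-6)
#   # non_matched maps tensor names to the mismatched values or differences.

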
class ExportMonitor(EveryN):
  """Monitor that exports Estimator every N steps.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  @deprecation.deprecated("2017-03-25",
                          "ExportMonitor is deprecated. Please pass an "
                          "ExportStrategy to Experiment instead.")
  def __init__(self,
               every_n_steps,
               export_dir,
               input_fn=None,
               input_feature_key=None,
               exports_to_keep=5,
               signature_fn=None,
               default_batch_size=1):
    """Initializes ExportMonitor.

    Args:
      every_n_steps: Run monitor every N steps.
      export_dir: str, folder to export.
      input_fn: A function that takes no argument and returns a tuple of
        (features, labels), where features is a dict of string key to `Tensor`
        and labels is a `Tensor` that's currently not used (and so can be
        `None`).
      input_feature_key: String key into the features dict returned by
        `input_fn` that corresponds to the raw `Example` strings `Tensor` that
        the exported model will take as input. Should be `None` if and only if
        you're passing in a `signature_fn` that does not use the first arg
        (`Tensor` of `Example` strings).
      exports_to_keep: int, number of exports to keep.
      signature_fn: Function that returns a default signature and a named
        signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s
        for features and `dict` of `Tensor`s for predictions.
      default_batch_size: Default batch size of the `Example` placeholder.

    Raises:
      ValueError: If `input_fn` and `input_feature_key` are not both defined or
        are not both `None`.
    """
    super(ExportMonitor, self).__init__(every_n_steps=every_n_steps)
    self._export_dir = export_dir
    self._input_fn = input_fn
    self._input_feature_key = input_feature_key
    self._use_deprecated_input_fn = input_fn is None
    self._exports_to_keep = exports_to_keep
    self._signature_fn = signature_fn
    self._default_batch_size = default_batch_size
    self._last_export_dir = None

  @property
  def export_dir(self):
    return self._export_dir

  @property
  def exports_to_keep(self):
    return self._exports_to_keep

  @property
  def signature_fn(self):
    return self._signature_fn

  @property
  def last_export_dir(self):
    """Returns the directory containing the last completed export.

    Returns:
      The string path to the exported directory. NB: this functionality was
      added on 2016/09/25; clients that depend on the return value may need
      to handle the case where this function returns None because the
      estimator being fitted does not yet return a value during export.
    """
    return self._last_export_dir

  def every_n_step_end(self, step, outputs):
    super(ExportMonitor, self).every_n_step_end(step, outputs)
    try:
      if isinstance(self._estimator, core_estimator.Estimator):
        raise ValueError(
            "ExportMonitor does not support `tf.estimator.Estimator`. "
            "Please pass an ExportStrategy to Experiment instead.")
      self._last_export_dir = self._estimator.export(
          self.export_dir,
          exports_to_keep=self.exports_to_keep,
          signature_fn=self.signature_fn,
          input_fn=self._input_fn,
          default_batch_size=self._default_batch_size,
          input_feature_key=self._input_feature_key,
          use_deprecated_input_fn=self._use_deprecated_input_fn)
    except RuntimeError:
      # Currently we are not synchronized with saving checkpoints, which leads
      # to runtime errors when we are calling export on the same global step.
      # Exports depend on saved checkpoints for constructing the graph and
      # getting the global step from the graph instance saved in the
      # checkpoint. If the checkpoint is stale with respect to the current
      # step, the global step is taken to be the last saved checkpoint's
      # global step, and the exporter doesn't export the same checkpoint
      # again, raising the RuntimeError caught here.
      logging.info("Skipping exporting because the existing checkpoint has "
                   "already been exported. "
                   "Consider exporting less frequently.")

  def end(self, session=None):
    super(ExportMonitor, self).end(session=session)
    latest_path = checkpoint_management.latest_checkpoint(
        self._estimator.model_dir)
    if latest_path is None:
      logging.info("Skipping export at the end since model has not been saved "
                   "yet.")
      return
    if isinstance(self._estimator, core_estimator.Estimator):
      raise ValueError(
          "ExportMonitor does not support `tf.estimator.Estimator`. "
          "Please pass an ExportStrategy to Experiment instead.")
    try:
      self._last_export_dir = self._estimator.export(
          self.export_dir,
          exports_to_keep=self.exports_to_keep,
          signature_fn=self.signature_fn,
          input_fn=self._input_fn,
          default_batch_size=self._default_batch_size,
          input_feature_key=self._input_feature_key,
          use_deprecated_input_fn=self._use_deprecated_input_fn)
    except RuntimeError:
      logging.info("Skipping exporting for the same step.")


class CheckpointSaver(BaseMonitor):
  """Saves checkpoints every N steps or N seconds.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None):
    """Initialize CheckpointSaver monitor.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get saver object.

    Raises:
      ValueError: If both `save_steps` and `save_secs` are not `None`.
      ValueError: If both `save_steps` and `save_secs` are `None`.
    """
    logging.info("Create CheckpointSaver.")
    super(CheckpointSaver, self).__init__()
    self._saver = saver
    self._summary_writer = core_summary.FileWriterCache.get(checkpoint_dir)
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._save_secs = save_secs
    self._save_steps = save_steps
    self._last_saved_time = None
    self._last_begin_step = None
    self._last_saved_step = None

    if save_steps is None and save_secs is None:
      raise ValueError("Either save_steps or save_secs should be provided.")
    if (save_steps is not None) and (save_secs is not None):
      raise ValueError("Cannot provide both save_steps and save_secs.")

  def begin(self, max_steps=None):
    super(CheckpointSaver, self).begin(max_steps)
    self._last_saved_time = None
    self._last_begin_step = None
    self._last_saved_step = None

  def step_begin(self, step):
    super(CheckpointSaver, self).step_begin(step)
    self._last_begin_step = step

  def post_step(self, step, session):
    super(CheckpointSaver, self).post_step(step, session)
    if self._last_saved_time is None:
      self._save(step, session)

    if self._save_steps is not None:
      if step >= self._last_saved_step + self._save_steps:
        self._save(step, session)

    if self._save_secs is not None:
      if time.time() >= self._last_saved_time + self._save_secs:
        self._save(step, session)

  def end(self, session=None):
    super(CheckpointSaver, self).end(session)
    self._save(self._last_begin_step, session)

  def _save(self, step, session):
    """Saves the latest checkpoint."""
    if step == self._last_saved_step:
      return
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._last_saved_time = time.time()
    self._last_saved_step = step
    if self._saver is None:
      self._scaffold.saver.save(session, self._save_path, global_step=step)
    else:
      self._saver.save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)


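# Usage sketch (the checkpoint directory and saver are hypothetical; exactly
# one of save_secs/save_steps may be given):
#
#   saver_monitor = CheckpointSaver("/tmp/model_dir", save_steps=1000,
#                                   saver=my_saver)

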
class StepCounter(EveryN):
  """Steps per second monitor.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
  """

  def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None):
    super(StepCounter, self).__init__(every_n_steps=every_n_steps)
    self._summary_tag = "global_step/sec"
    self._last_reported_step = None
    self._last_reported_time = None
    self._summary_writer = summary_writer
    if summary_writer is None and output_dir:
      self._summary_writer = core_summary.FileWriterCache.get(output_dir)

  def set_estimator(self, estimator):
    super(StepCounter, self).set_estimator(estimator)
    if self._summary_writer is None:
      self._summary_writer = core_summary.FileWriterCache.get(
          estimator.model_dir)

  def every_n_step_end(self, current_step, outputs):
    current_time = time.time()
    if self._last_reported_time is not None and self._summary_writer:
      added_steps = current_step - self._last_reported_step
      elapsed_time = current_time - self._last_reported_time
      steps_per_sec = added_steps / elapsed_time
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, current_step)
    self._last_reported_step = current_step
    self._last_reported_time = current_time


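# Usage sketch (the output directory is hypothetical): writes a
# "global_step/sec" summary every 100 steps.
#
#   counter = StepCounter(every_n_steps=100, output_dir="/tmp/model_dir")

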
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


class NanLoss(EveryN):
  """NaN Loss monitor.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Monitors loss and stops training if loss is NaN.
  Can either fail with an exception or just stop training.
  """

  def __init__(self, loss_tensor, every_n_steps=100, fail_on_nan_loss=True):
    """Initializes NanLoss monitor.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      every_n_steps: `int`, run check every this many steps.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    super(NanLoss, self).__init__(every_n_steps=every_n_steps)
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def every_n_step_begin(self, step):
    super(NanLoss, self).every_n_step_begin(step)
    return [self._loss_tensor]

  def every_n_step_end(self, step, outputs):
    super(NanLoss, self).every_n_step_end(step, outputs)
    if np.isnan(_extract_output(outputs, self._loss_tensor)):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # Return "should stop" instead of raising, so training stops without
        # an exception.
        return True


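# Usage sketch (loss_op is hypothetical):
#
#   # Check the loss every 100 steps; stop quietly instead of raising.
#   nan_monitor = NanLoss(loss_op, every_n_steps=100, fail_on_nan_loss=False)

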
class RunHookAdapterForMonitors(session_run_hook.SessionRunHook):
  """Wraps monitors into a SessionRunHook."""

  def __init__(self, monitors):
    self._monitors = monitors

  def begin(self):
    self._last_step = None
    self._global_step_tensor = training_util.get_global_step()
    for m in self._monitors:
      m.begin(max_steps=None)

  def before_run(self, run_context):
    if self._last_step is None:
      self._last_step = run_context.session.run(self._global_step_tensor) + 1

    request = {self._global_step_tensor: self._global_step_tensor}
    monitor_fetches = []
    for m in self._monitors:
      monitor_requests = m.step_begin(self._last_step)
      if monitor_requests:
        if not isinstance(monitor_requests, list):
          raise ValueError("Monitor.step_begin should return a list.")
        monitor_fetches.extend(monitor_requests)
    if monitor_fetches:
      request["monitors"] = dict(
          zip(monitor_fetches, [_as_graph_element(f) for f in monitor_fetches]))

    return session_run_hook.SessionRunArgs(request)

  def after_run(self, run_context, run_values):
    result = run_values.results.get("monitors", {})
    for m in self._monitors:
      induce_stop = m.step_end(self._last_step, result)
      if induce_stop:
        run_context.request_stop()

    for m in self._monitors:
      m.post_step(self._last_step, run_context.session)

    self._last_step = run_values.results[self._global_step_tensor] + 1

  def end(self, session):
    self._last_step = None
    for m in self._monitors:
      if "session" in tf_inspect.getargspec(m.end).args:
        m.end(session=session)
      else:
        m.end()


def replace_monitors_with_hooks(monitors_or_hooks, estimator):
  """Wraps monitors with a hook.

  `Monitor` is deprecated in favor of `SessionRunHook`. If you're using a
  monitor, you can wrap it with a hook using this function. It is recommended
  that you implement a hook version of your monitor instead.

  Args:
    monitors_or_hooks: A `list` that may contain both monitors and hooks.
    estimator: An `Estimator` that the monitors will be used with.

  Returns:
    A list of hooks. Any monitor in the given list is replaced by a hook.
  """
  monitors_or_hooks = monitors_or_hooks or []
  hooks = [
      m for m in monitors_or_hooks
      if isinstance(m, session_run_hook.SessionRunHook)
  ]

  deprecated_monitors = [
      m for m in monitors_or_hooks
      if not isinstance(m, session_run_hook.SessionRunHook)
  ]

  if not estimator.config.is_chief:
    # Prune the list of monitors to the ones runnable on all workers.
    deprecated_monitors = [
        m for m in deprecated_monitors if m.run_on_all_workers
    ]

  # Setup monitors.
  for monitor in deprecated_monitors:
    monitor.set_estimator(estimator)

  if deprecated_monitors:
    hooks.append(RunHookAdapterForMonitors(deprecated_monitors))

  return hooks


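# Usage sketch (the estimator, hook, and monitor instances are hypothetical):
#
#   hooks = replace_monitors_with_hooks(
#       [my_logging_hook, ValidationMonitor(input_fn=eval_input_fn)],
#       my_estimator)
#   # hooks now contains my_logging_hook plus a RunHookAdapterForMonitors
#   # wrapping the ValidationMonitor.

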
def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e., the op has a single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element