# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

Note that the symbols that are exported to v1 tf.train namespace are also
exported to v2 in tf.estimator namespace. See
https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export

_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"


class _HookTimer(object):
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number
      of seconds between the current trigger and the last one (a float), and
      `elapsed_steps` is the number of steps between the current trigger and
      the last one. Both values will be set to `None` on the first trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered time step or None if never triggered."""
    raise NotImplementedError


@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.

  This symbol is also exported to v2 in tf.estimator namespace. See
  https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
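
  A minimal sketch of the trigger/update cycle, as used by the hooks in this
  file (the step values are illustrative):

  ```python
  timer = SecondOrStepTimer(every_steps=100)
  for step in range(1000):
    if timer.should_trigger_for_step(step):
      elapsed_secs, elapsed_steps = timer.update_last_triggered_step(step)
      # Both values are None on the very first trigger.
  ```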
96  """
97
98  def __init__(self, every_secs=None, every_steps=None):
99    self.reset()
100    self._every_secs = every_secs
101    self._every_steps = every_steps
102
103    if self._every_secs is None and self._every_steps is None:
104      raise ValueError("Either every_secs or every_steps should be provided.")
105    if (self._every_secs is not None) and (self._every_steps is not None):
106      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the last
      trigger exceeds `every_secs`, or if the difference between the current
      step and the last triggered step exceeds `every_steps`. False otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step


class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.

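  A minimal usage sketch, assuming `loss` is a tensor already in the graph:

  ```python
  logging_hook = tf.compat.v1.train.LoggingTensorHook(
      tensors={"loss": loss}, every_n_iter=100)
  with tf.compat.v1.train.MonitoredTrainingSession(
      hooks=[logging_hook]) as sess:
    ...
  ```
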
  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility

  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names, or
        `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every N
        seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at the
        end of the run.
      formatter: function that takes a dict of `tag`->`Tensor` and returns a
        string. If `None`, uses the default of printing all tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)


def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each
  device program execution and before the next execution.

  The purpose of using a variable, rather than a constant, is to allow
  Estimator to adapt the device training iterations according to the final
  steps specified by users. For example, if the user sets steps_per_run to 4
  and steps to 10 in Estimator.train(), the steps_per_run variable will have
  the following values before each training run:

      - 1st execution: steps_per_run = 4
      - 2nd execution: steps_per_run = 4
      - 3rd execution: steps_per_run = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 passed in by
  the user.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple steps_per_run variables are found.
313  """
314  graph = ops.get_default_graph()
315  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
316  steps_per_run_vars = graph.get_collection(collection_name)
317  if len(steps_per_run_vars) == 1:
318    return steps_per_run_vars[0]
319  elif len(steps_per_run_vars) > 1:
320    raise RuntimeError("Multiple steps_per_run_var in collection.")
321
322  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
323    return variable_scope.get_variable(
324        _STEPS_PER_RUN_VAR,
325        initializer=init_ops.ones_initializer(),
326        shape=[],
327        dtype=dtypes.int32,
328        trainable=False,
329        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
330        use_resource=True)
331
332
333class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
334  """Hook that requests stop at a specified step."""
335
336  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
337    """Initializes a `MultiStepStopAtStepHook`.
338
339    This hook requests stop after either a number of steps have been
340    executed or a last step has been reached. Only one of the two options can be
341    specified.
342
343    if `num_steps` is specified, it indicates the number of steps to execute
344    after `begin()` is called. If instead `last_step` is specified, it
345    indicates the last step we want to execute, as passed to the `after_run()`
346    call.
347
348    In Estimator, the user provided computation, the model_fn, is wrapped
349    inside a tf.while_loop for peak performance. The steps_per_run variable
350    determines the number of iterations of the loop before returning to the CPU.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    steps = min(self._last_step - global_step,
                self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps
    self._update_steps_per_run_variable(global_step, session)

  def after_run(self, run_context, run_values):
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # race condition in hook execution.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      self._update_steps_per_run_variable(global_step, run_context.session)


@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.

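  A minimal usage sketch, assuming `train_op` increments the global step:

  ```python
  stop_hook = tf.compat.v1.train.StopAtStepHook(num_steps=1000)
  with tf.compat.v1.train.MonitoredTrainingSession(
      hooks=[stop_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```
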
  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
  """

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can
    be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step is
      # reached. The global_step read tensor holds the value of the global
      # step before the run, so we cannot tell whether this session.run
      # incremented the global step; re-read it here to check.

      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()


@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is
  triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.estimator.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.compat.v1.train.MonitoredTrainingSession(
      chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own schedule
  to act less frequently, e.g. based on global_step_value. In this case,
  implementors should implement the `end()` method to handle actions related to
  the last checkpoint save. But the listener should not act twice if
  `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped by returning
  True in `after_save`. Please note that, in a replicated distributed training
  setting, only `chief` should use this behavior. Otherwise each worker will do
  its own evaluation, which may be wasteful of resources.
513  """
514
515  def begin(self):
516    pass
517
518  def before_save(self, session, global_step_value):
519    pass
520
521  def after_save(self, session, global_step_value):
522    pass
523
524  def end(self, session, global_step_value):
525    pass
526
527
528@tf_export(v1=["train.CheckpointSaverHook"])
529class CheckpointSaverHook(session_run_hook.SessionRunHook):
530  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None,
               save_graph_def=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.
      save_graph_def: Whether to save the GraphDef and MetaGraphDef to
        `checkpoint_dir`. The GraphDef is saved after the session is created as
        `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
        `model.ckpt-*.meta`.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1
    self._save_graph_def = save_graph_def

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._save_graph_def:
      # Write the graph and saver_def once the session is created. We cannot
      # do this in begin, since other hooks may still change the graph and add
      # variables there; the graph is finalized after all begin calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info("Calling checkpoint listeners before saving checkpoint %d...",
                 step)
    for l in self._listeners:
      l.before_save(session, step)

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._get_saver().save(session, self._save_path, global_step=step,
                           write_meta_graph=self._save_graph_def)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    logging.info("Calling checkpoint listeners after saving checkpoint %d...",
                 step)
    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: {}".format(l))
        should_stop = True
    return should_stop

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]


@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
677  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has increased. Here, we do not use the
    # timer.last_triggered_step as the timer might record a different global
    # step value such that the comparison could be unreliable. For simplicity,
    # we just compare the stale_global_step with the previously recorded
    # version.
    if stale_global_step == self._last_global_step:
      # Warn (at most 5 times) if we observe that the global step has not
      # increased. For some Optimizers, the global step is not increased on
      # every step by design. For example, SyncReplicasOptimizer doesn't
      # increase the global step in the worker's main train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stale): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step


@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with exception or just stop training.
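
  A minimal usage sketch, assuming `loss` is the training loss tensor:

  ```python
  nan_hook = tf.compat.v1.train.NanTensorHook(loss)
  with tf.compat.v1.train.MonitoredTrainingSession(
      hooks=[nan_hook]) as sess:
    ...
  ```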
767  """
768
769  def __init__(self, loss_tensor, fail_on_nan_loss=True):
770    """Initializes a `NanTensorHook`.
771
772    Args:
773      loss_tensor: `Tensor`, the loss tensor.
774      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
775    """
776    self._loss_tensor = loss_tensor
777    self._fail_on_nan_loss = fail_on_nan_loss
778
779  def before_run(self, run_context):  # pylint: disable=unused-argument
780    return SessionRunArgs(self._loss_tensor)
781
782  def after_run(self, run_context, run_values):
783    if np.isnan(run_values.results):
784      failure_message = "Model diverged with loss = NaN."
785      if self._fail_on_nan_loss:
786        logging.error(failure_message)
787        raise NanLossDuringTrainingError
788      else:
789        logging.warning(failure_message)
790        # We don't raise an error but we request stop without an exception.
791        run_context.request_stop()
792
793
794@tf_export(v1=["train.SummarySaverHook"])
795class SummarySaverHook(session_run_hook.SessionRunHook):
796  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used if
        no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
        one will be created accordingly.
      scaffold: `Scaffold` to get summary_op if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized `Summary`
        protocol buffer, or a list of such `Tensor`s. These are most likely
        output by TF summary methods like `tf.compat.v1.summary.scalar` or
        `tf.compat.v1.summary.merge_all`. A single tensor can be passed in
        directly; multiple tensors must be passed in as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      Returns a list of summary `Tensor`.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))` assuming that
  task_id=0 is the chief.
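
  A minimal sketch of that staggered-start pattern; `task_id` and the scaling
  constant `K=25` are illustrative:

  ```python
  import math
  waiter_hook = tf.compat.v1.train.GlobalStepWaiterHook(
      wait_until_step=int(25 * math.log(task_id + 1)))
  ```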
911  """
912
913  def __init__(self, wait_until_step):
914    """Initializes a `GlobalStepWaiterHook`.
915
916    Args:
      wait_until_step: an `int`, the global step to wait for before starting
        training.
918    """
919    self._wait_until_step = wait_until_step
920
921  def begin(self):
922    self._worker_is_started = False
923    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
924    if self._global_step_tensor is None:
925      raise RuntimeError(
926          "Global step should be created to use _GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info(
            "Waiting for global step %d before starting training. "
            "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)


@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
954  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
957    """Initializes `FinalOpHook` with ops to run at the end of the session.
958
959    Args:
960      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names
961        to `Tensors`.
962      final_ops_feed_dict: A feed dictionary to use when running
963        `final_ops_dict`.
964    """
965    self._final_ops = final_ops
966    self._final_ops_feed_dict = final_ops_feed_dict
967    self._final_ops_values = None
968
969  @property
970  def final_ops_values(self):
971    return self._final_ops_values
972
973  def end(self, session):
974    if self._final_ops is not None:
975      try:
976        self._final_ops_values = session.run(
977            self._final_ops, feed_dict=self._final_ops_feed_dict)
978      except (errors.OutOfRangeError, StopIteration) as e:
        logging.warning(
            "An OutOfRangeError or StopIteration exception is raised by the "
            "code in FinalOpsHook. This typically means the Ops run by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two Ops: "
            "`value_op` and `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in the input source and, "
            "once it is exhausted, it calls the `value_op` to get the metric "
            "values. The `value_op` here should depend only on reading "
            "variables, rather than reading another batch from the input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data read, which ends in OutOfRangeError/StopIteration. "
            "Please fix that.")
        raise e


@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
998  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns `dict` of `Tensor`
        to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
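
  A minimal usage sketch; the output directory is illustrative, and the
  resulting `timeline-<step>.json` files can be loaded in `chrome://tracing`:

  ```python
  profiler_hook = tf.compat.v1.train.ProfilerHook(
      save_steps=1000, output_dir="/tmp/profile")
  with tf.compat.v1.train.MonitoredTrainingSession(
      hooks=[profiler_hook]) as sess:
    ...
  ```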
1023  """
1024
1025  def __init__(self,
1026               save_steps=None,
1027               save_secs=None,
1028               output_dir="",
1029               show_dataflow=True,
1030               show_memory=False):
1031    """Initializes a hook that takes periodic profiling snapshots.
1032
    The `options` and `run_metadata` arguments of `tf.Session.run` are used
    to collect metadata about execution. This hook requests a full trace via
    `options` and dumps the resulting metadata in Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (
        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
        if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or seconds
      # have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step, self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow, show_memory=self._show_memory))


def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e. the op has a single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element