# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

@@LoggingTensorHook
@@StopAtStepHook
@@CheckpointSaverHook
@@StepCounterHook
@@NanLossDuringTrainingError
@@NanTensorHook
@@SummarySaverHook
@@GlobalStepWaiterHook
@@ProfilerHook
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export


class _HookTimer(object):
54  """Base timer for determining when Hooks should trigger.
55
56  Should not be instantiated directly.
57  """
58
59  def __init__(self):
60    pass
61
62  def reset(self):
63    """Resets the timer."""
64    pass
65
66  def should_trigger_for_step(self, step):
67    """Return true if the timer should trigger for the specified step."""
68    raise NotImplementedError
69
70  def update_last_triggered_step(self, step):
71    """Update the last triggered time and step number.
72
73    Args:
74      step: The current step.
75
76    Returns:
77      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number
78      of seconds between the current trigger and the last one (a float), and
79      `elapsed_steps` is the number of steps between the current trigger and
80      the last one. Both values will be set to `None` on the first trigger.
81    """
82    raise NotImplementedError
83
84  def last_triggered_step(self):
85    """Returns the last triggered time step or None if never triggered."""
86    raise NotImplementedError
87
88
89@tf_export("train.SecondOrStepTimer")
90class SecondOrStepTimer(_HookTimer):
91  """Timer that triggers at most once every N seconds or once every N steps.
92  """
93
94  def __init__(self, every_secs=None, every_steps=None):
95    self.reset()
96    self._every_secs = every_secs
97    self._every_steps = every_steps
98
99    if self._every_secs is None and self._every_steps is None:
100      raise ValueError("Either every_secs or every_steps should be provided.")
101    if (self._every_secs is not None) and (self._every_steps is not None):
102      raise ValueError("Can not provide both every_secs and every_steps.")
103
104    super(SecondOrStepTimer, self).__init__()
105
106  def reset(self):
107    self._last_triggered_step = None
108    self._last_triggered_time = None
109
110  def should_trigger_for_step(self, step):
111    """Return true if the timer should trigger for the specified step.
112
113    Args:
114      step: Training step to trigger on.
115
116    Returns:
117      True if the difference between the current time and the time of the last
118      trigger exceeds `every_secs`, or if the difference between the current
119      step and the last triggered step exceeds `every_steps`. False otherwise.
120    """
121    if self._last_triggered_step is None:
122      return True
123
124    if self._last_triggered_step == step:
125      return False
126
127    if self._every_secs is not None:
128      if time.time() >= self._last_triggered_time + self._every_secs:
129        return True
130
131    if self._every_steps is not None:
132      if step >= self._last_triggered_step + self._every_steps:
133        return True
134
135    return False
136
137  def update_last_triggered_step(self, step):
138    current_time = time.time()
139    if self._last_triggered_time is None:
140      elapsed_secs = None
141      elapsed_steps = None
142    else:
143      elapsed_secs = current_time - self._last_triggered_time
144      elapsed_steps = step - self._last_triggered_step
145
146    self._last_triggered_time = current_time
147    self._last_triggered_step = step
148    return (elapsed_secs, elapsed_steps)
149
150  def last_triggered_step(self):
151    return self._last_triggered_step
152
153
154class NeverTriggerTimer(_HookTimer):
155  """Timer that never triggers."""
156
157  def should_trigger_for_step(self, step):
158    _ = step
159    return False
160
161  def update_last_triggered_step(self, step):
162    _ = step
163    return (None, None)
164
165  def last_triggered_step(self):
166    return None
167
168
169@tf_export("train.LoggingTensorHook")
170class LoggingTensorHook(session_run_hook.SessionRunHook):
171  """Prints the given tensors every N local steps, every N seconds, or at end.
172
173  The tensors will be printed to the log, with `INFO` severity. If you are not
174  seeing the logs, you might want to add the following line after your imports:
175
176  ```python
177    tf.logging.set_verbosity(tf.logging.INFO)
178  ```
179
180  Note that if `at_end` is True, `tensors` should not include any tensor
181  whose evaluation produces a side effect such as consuming additional inputs.
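
  A minimal usage sketch, assuming `import tensorflow as tf`; the tensor names
  and `train_op` are illustrative:

  ```python
  logging_hook = tf.train.LoggingTensorHook(
      tensors={"loss": "loss:0", "step": "global_step:0"},
      every_n_iter=100)
  with tf.train.MonitoredTrainingSession(hooks=[logging_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```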
182  """
183
184  def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
185               at_end=False, formatter=None):
186    """Initializes a `LoggingTensorHook`.
187
188    Args:
189      tensors: `dict` that maps string-valued tags to tensors/tensor names,
190          or `iterable` of tensors/tensor names.
191      every_n_iter: `int`, print the values of `tensors` once every N local
192          steps taken on the current worker.
193      every_n_secs: `int` or `float`, print the values of `tensors` once every N
194          seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
195          provided.
196      at_end: `bool` specifying whether to print the values of `tensors` at the
197          end of the run.
198      formatter: function, takes dict of `tag`->`Tensor` and returns a string.
199          If `None` uses default printing all tensors.
200
201    Raises:
202      ValueError: if `every_n_iter` is non-positive.
203    """
204    only_log_at_end = (
205        at_end and (every_n_iter is None) and (every_n_secs is None))
206    if (not only_log_at_end and
207        (every_n_iter is None) == (every_n_secs is None)):
208      raise ValueError(
209          "either at_end and/or exactly one of every_n_iter and every_n_secs "
210          "must be provided.")
211    if every_n_iter is not None and every_n_iter <= 0:
212      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
213    if not isinstance(tensors, dict):
214      self._tag_order = tensors
215      tensors = {item: item for item in tensors}
216    else:
217      self._tag_order = tensors.keys()
218    self._tensors = tensors
219    self._formatter = formatter
220    self._timer = (
221        NeverTriggerTimer() if only_log_at_end else
222        SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
223    self._log_at_end = at_end
224
  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {tag: _as_graph_element(tensor)
                             for (tag, tensor) in self._tensors.items()}

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)


@tf_export("train.StopAtStepHook")
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.
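
    A minimal sketch, assuming `import tensorflow as tf` and an existing
    `train_op` that increments the global step:

    ```python
    hook = tf.train.StopAtStepHook(num_steps=1000)
    with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
      while not sess.should_stop():
        sess.run(train_op)
    ```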

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step is
      # reached. The global_step read tensor holds the value from before the
      # operation ran, so we cannot tell from it alone whether the current
      # session.run incremented the global step; re-read the value here.
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()


@tf_export("train.CheckpointSaverListener")
class CheckpointSaverListener(object):
326  """Interface for listeners that take action before or after checkpoint save.
327
328  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook` is
329  triggered, and provides callbacks at the following points:
330   - before using the session
331   - before each call to `Saver.save()`
332   - after each call to `Saver.save()`
333   - at the end of session
334
335  To use a listener, implement a class and pass the listener to a
336  `CheckpointSaverHook`, as in this example:
337
338  ```python
339  class ExampleCheckpointSaverListener(CheckpointSaverListener):
340    def begin(self):
341      # You can add ops to the graph here.
342      print('Starting the session.')
343      self.your_tensor = ...
344
345    def before_save(self, session, global_step_value):
346      print('About to write a checkpoint')
347
348    def after_save(self, session, global_step_value):
349      print('Done writing checkpoint.')
350
351    def end(self, session, global_step_value):
352      print('Done with the session.')
353
354  ...
355  listener = ExampleCheckpointSaverListener()
356  saver_hook = tf.train.CheckpointSaverHook(
357      checkpoint_dir, listeners=[listener])
358  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
359    ...
360  ```
361
362  A `CheckpointSaverListener` may simply take some action after every
363  checkpoint save. It is also possible for the listener to use its own schedule
364  to act less frequently, e.g. based on global_step_value. In this case,
365  implementors should implement the `end()` method to handle actions related to
366  the last checkpoint save. But the listener should not act twice if
367  `after_save()` already handled this last checkpoint save.
368  """
369
370  def begin(self):
371    pass
372
373  def before_save(self, session, global_step_value):
374    pass
375
376  def after_save(self, session, global_step_value):
377    pass
378
379  def end(self, session, global_step_value):
380    pass
381
382
383@tf_export("train.CheckpointSaverHook")
384class CheckpointSaverHook(session_run_hook.SessionRunHook):
385  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: If neither `save_secs` nor `save_steps` is set.
      ValueError: If both `saver` and `scaffold` are set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    self._listeners = listeners or []
  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def before_run(self, run_context):  # pylint: disable=unused-argument
    if self._timer.last_triggered_step() is None:
      # Write the graph and saver_def once, at the first call of before_run.
      # We cannot do this in begin, because other hooks may still change the
      # graph and add variables there; the graph is finalized only after all
      # begin calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir,
          "graph.pbtxt")
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = ops.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True),
          saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)

    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step + 1):
      # Get the real value of the global step after the train op has run.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        self._save(run_context.session, global_step)

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for l in self._listeners:
      l.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)

    for l in self._listeners:
      l.after_save(session, step)

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor.".
          format(collection_key))

    self._saver = savers[0]
    return savers[0]


@tf_export("train.StepCounterHook")
class StepCounterHook(session_run_hook.SessionRunHook):
508  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(every_steps=every_n_steps,
                                    every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._global_step_check_count = 0

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[Summary.Value(
          tag=self._summary_tag, simple_value=steps_per_sec)])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step + 1):
      # Get the real value of the global step after the train op has run.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. We do not use
    # timer.last_triggered_step here, because the timer might record a
    # different global step value, which would make the comparison unreliable.
    # For simplicity, we just compare stale_global_step with the previously
    # recorded value.
    if stale_global_step == self._last_global_step:
      # Count how many times we have observed that the global step has not
      # increased. For some optimizers the global step is intentionally not
      # increased on every run; for example, SyncReplicasOptimizer does not
      # increase the global step in the worker's main train step.
      self._global_step_check_count += 1
      if self._global_step_check_count % 20 == 0:
        self._global_step_check_count = 0
        logging.warning(
            "It seems that global step (tf.train.get_global_step) has not "
            "been increased. Current value (could be stale): %s vs previous "
            "value: %s. You could increase the global step by passing "
            "tf.train.get_global_step() to Optimizer.apply_gradients or "
            "Optimizer.minimize.", stale_global_step, self._last_global_step)
    else:
      # Whenever we observe an increment, reset the counter.
      self._global_step_check_count = 0

    self._last_global_step = stale_global_step


@tf_export("train.NanLossDuringTrainingError")
class NanLossDuringTrainingError(RuntimeError):
  """Raised when the loss tensor becomes NaN during training."""

  def __str__(self):
    return "NaN loss during training."


@tf_export("train.NanTensorHook")
class NanTensorHook(session_run_hook.SessionRunHook):
594  """Monitors the loss tensor and stops training if loss is NaN.
595
596  Can either fail with exception or just stop training.
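
  A minimal sketch, assuming `import tensorflow as tf` and a scalar `loss`
  tensor from your model:

  ```python
  nan_hook = tf.train.NanTensorHook(loss, fail_on_nan_loss=False)
  ```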
597  """
598
599  def __init__(self, loss_tensor, fail_on_nan_loss=True):
600    """Initializes a `NanTensorHook`.
601
602    Args:
603      loss_tensor: `Tensor`, the loss tensor.
604      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
605    """
606    self._loss_tensor = loss_tensor
607    self._fail_on_nan_loss = fail_on_nan_loss
608
609  def before_run(self, run_context):  # pylint: disable=unused-argument
610    return SessionRunArgs(self._loss_tensor)
611
612  def after_run(self, run_context, run_values):
613    if np.isnan(run_values.results):
614      failure_message = "Model diverged with loss = NaN."
615      if self._fail_on_nan_loss:
616        logging.error(failure_message)
617        raise NanLossDuringTrainingError
618      else:
619        logging.warning(failure_message)
620        # We don't raise an error but we request stop without an exception.
621        run_context.request_stop()
622
623
624@tf_export("train.SummarySaverHook")
625class SummarySaverHook(session_run_hook.SessionRunHook):
626  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
          `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used
          if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed,
          one will be created accordingly.
      scaffold: `Scaffold`, used to get the summary_op if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized `Summary`
          protocol buffer, or a list of such `Tensor`s. These are most likely
          the output of TF summary methods like `tf.summary.scalar` or
          `tf.summary.merge_all`. A single tensor can be passed directly;
          multiple tensors must be passed as a list.

    Raises:
      ValueError: If exactly one of `scaffold` and `summary_op` is not
        provided.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`s, or `None` if no summary op is available.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


@tf_export("train.GlobalStepWaiterHook")
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
735  """Delays execution until global step reaches `wait_until_step`.
736
737  This hook delays execution until global step reaches to `wait_until_step`. It
738  is used to gradually start workers in distributed settings. One example usage
739  would be setting `wait_until_step=int(K*log(task_id+1))` assuming that
740  task_id=0 is the chief.
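
  A minimal sketch, assuming `import tensorflow as tf`; `task_id` is an
  illustrative worker index:

  ```python
  wait_hook = tf.train.GlobalStepWaiterHook(wait_until_step=25 * task_id)
  ```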
741  """
742
743  def __init__(self, wait_until_step):
744    """Initializes a `GlobalStepWaiterHook`.
745
746    Args:
747      wait_until_step: an `int` shows until which global step should we wait.
748    """
749    self._wait_until_step = wait_until_step
750
751  def begin(self):
752    self._worker_is_started = False
753    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
754    if self._global_step_tensor is None:
755      raise RuntimeError(
756          "Global step should be created to use _GlobalStepWaiterHook.")
757
758  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info("Waiting for global step %d before starting training. "
                     "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)


@tf_export("train.FinalOpsHook")
class FinalOpsHook(session_run_hook.SessionRunHook):
783  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      self._final_ops_values = session.run(self._final_ops,
                                           feed_dict=self._final_ops_feed_dict)


@tf_export("train.FeedFnHook")
class FeedFnHook(session_run_hook.SessionRunHook):
810  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns `dict` of `Tensor`
        to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


@tf_export("train.ProfilerHook")
class ProfilerHook(session_run_hook.SessionRunHook):
828  """Captures CPU/GPU profiling information every N steps or seconds.
829
830  This produces files called "timeline-<step>.json", which are in Chrome
831  Trace format.
832
833  For more information see:
834  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
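
  A minimal usage sketch, assuming `import tensorflow as tf`; the schedule and
  directory are illustrative:

  ```python
  profiler_hook = tf.train.ProfilerHook(save_steps=1000,
                                        output_dir="/tmp/profiles")
  ```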
835  """
836
837  def __init__(self,
838               save_steps=None,
839               save_secs=None,
840               output_dir="",
841               show_dataflow=True,
842               show_memory=False):
843    """Initializes a hook that takes periodic profiling snapshots.
844
845    `options.run_metadata` argument of `tf.Session.Run` is used to collect
846    metadata about execution. This hook sets the metadata and dumps it in Chrome
847    Trace format.
848
849
850    Args:
851      save_steps: `int`, save profile traces every N steps. Exactly one of
852          `save_secs` and `save_steps` should be set.
853      save_secs: `int` or `float`, save profile traces every N seconds.
854      output_dir: `string`, the directory to save the profile traces to.
855          Defaults to the current directory.
856      show_dataflow: `bool`, if True, add flow events to the trace connecting
857          producers and consumers of tensors.
858      show_memory: `bool`, if True, add object snapshot events to the trace
859          showing the sizes and lifetimes of tensors.
860    """
861    self._output_file = os.path.join(output_dir, "timeline-{}.json")
862    self._show_dataflow = show_dataflow
863    self._show_memory = show_memory
864    self._timer = SecondOrStepTimer(
865        every_secs=save_secs, every_steps=save_steps)
866
867  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
            if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step,
                 self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow, show_memory=self._show_memory))


def _as_graph_element(obj):
905  """Retrieves Graph element."""
906  graph = ops.get_default_graph()
907  if not isinstance(obj, six.string_types):
908    if not hasattr(obj, "graph") or obj.graph != graph:
909      raise ValueError("Passed %s should have graph attribute that is equal "
910                       "to current graph %s." % (obj, graph))
911    return obj
912  if ":" in obj:
913    element = graph.as_graph_element(obj)
914  else:
915    element = graph.as_graph_element(obj + ":0")
916    # Check that there is no :1 (e.g. it's single output).
917    try:
918      graph.as_graph_element(obj + ":1")
919    except (KeyError, ValueError):
920      pass
921    else:
922      raise ValueError("Name %s is ambiguous, "
923                       "as this `Operation` has multiple outputs "
924                       "(at least 2)." % obj)
925  return element
926