# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

@@LoggingTensorHook
@@StopAtStepHook
@@CheckpointSaverHook
@@StepCounterHook
@@NanLossDuringTrainingError
@@NanTensorHook
@@SummarySaverHook
@@GlobalStepWaiterHook
@@ProfilerHook
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export


class _HookTimer(object):
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the
      number of seconds between the current trigger and the last one (a
      float), and `elapsed_steps` is the number of steps between the current
      trigger and the last one. Both values will be set to `None` on the
      first trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered step, or None if never triggered."""
    raise NotImplementedError


@tf_export("train.SecondOrStepTimer")
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps."""

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the
      last trigger exceeds `every_secs`, or if the difference between the
      current step and the last triggered step exceeds `every_steps`. False
      otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step
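
# Usage sketch for SecondOrStepTimer (illustrative only; the loop, `step`,
# and `do_work` below are hypothetical, not part of this module). The timer
# only answers "should I trigger for this step?" and records when it last did:
#
#   timer = SecondOrStepTimer(every_steps=100)
#   for step in range(1000):
#     if timer.should_trigger_for_step(step):
#       elapsed_secs, elapsed_steps = timer.update_last_triggered_step(step)
#       do_work(step)  # runs at most once every 100 steps
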

class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export("train.LoggingTensorHook")
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
  tf.logging.set_verbosity(tf.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  """

  def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
               at_end=False, formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
        or `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every
        N seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at
        the end of the run.
      formatter: function that takes a dict of `tag`->`Tensor` values and
        returns a string. If `None`, uses a default formatter that prints all
        tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = tensors.keys()
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else
        SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {tag: _as_graph_element(tensor)
                             for (tag, tensor) in self._tensors.items()}

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)
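
# Usage sketch for LoggingTensorHook (illustrative only; `loss` and `train_op`
# are hypothetical tensors/ops defined elsewhere):
#
#   hook = LoggingTensorHook({"loss": loss}, every_n_iter=100)
#   with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)
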

@tf_export("train.StopAtStepHook")
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can
    be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step is
      # reached. The global_step read tensor holds the value of the global
      # step *before* the run, so we cannot tell from the fetched value alone
      # whether this session.run incremented the global step; re-read it here
      # to be sure.
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()
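
# Usage sketch for StopAtStepHook (illustrative only; `train_op` is a
# hypothetical training op):
#
#   # Stop 1000 steps after the session starts:
#   hook = StopAtStepHook(num_steps=1000)
#   # ...or stop once the global step reaches 10000:
#   # hook = StopAtStepHook(last_step=10000)
#   with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)
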

@tf_export("train.CheckpointSaverListener")
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook`
  is triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own
  schedule to act less frequently, e.g. based on global_step_value. In this
  case, implementors should implement the `end()` method to handle actions
  related to the last checkpoint save. But the listener should not act twice
  if `after_save()` already handled this last checkpoint save.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass


@tf_export("train.CheckpointSaverHook")
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps. Exactly one of `save_secs` and
        `save_steps` should be set.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used
        for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    self._listeners = listeners or []

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def before_run(self, run_context):  # pylint: disable=unused-argument
    if self._timer.last_triggered_step() is None:
      # Write the graph and saver_def at the first call of before_run. We
      # cannot do this in begin, since we let other hooks change the graph and
      # add variables there; the graph is finalized only after all begin
      # calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir,
          "graph.pbtxt")
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = ops.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True),
          saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)

    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step + 1):
      # Get the real value after the train op runs.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        self._save(run_context.session, global_step)

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for l in self._listeners:
      l.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)

    for l in self._listeners:
      l.after_save(session, step)

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]
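
# Usage sketch for CheckpointSaverHook (illustrative only; relies on a global
# step created elsewhere, e.g. via tf.train.get_or_create_global_step()):
#
#   saver_hook = CheckpointSaverHook(checkpoint_dir="/tmp/model",
#                                    save_steps=500)
#   with tf.train.MonitoredTrainingSession(
#       chief_only_hooks=[saver_hook]) as sess:
#     ...
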

@tf_export("train.StepCounterHook")
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(every_steps=every_n_steps,
                                    every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._global_step_check_count = 0

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[Summary.Value(
          tag=self._summary_tag, simple_value=steps_per_sec)])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step + 1):
      # Get the real value after the train op runs.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use
    # the timer.last_triggered_step as the timer might record a different
    # global step value such that the comparison could be unreliable. For
    # simplicity, we just compare the stale_global_step with the previously
    # recorded version.
    if stale_global_step == self._last_global_step:
      # Use a counter to count how many times we have observed that the
      # global step has not been increased. For some optimizers, the global
      # step is not increased on every step by design. For example,
      # SyncReplicasOptimizer does not increase the global step in the
      # worker's main train step.
      self._global_step_check_count += 1
      if self._global_step_check_count % 20 == 0:
        self._global_step_check_count = 0
        logging.warning(
            "It seems that global step (tf.train.get_global_step) has not "
            "been increased. Current value (could be stable): %s vs previous "
            "value: %s. You could increase the global step by passing "
            "tf.train.get_global_step() to Optimizer.apply_gradients or "
            "Optimizer.minimize.", stale_global_step, self._last_global_step)
    else:
      # Whenever we observe the increment, reset the counter.
      self._global_step_check_count = 0

    self._last_global_step = stale_global_step
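
# Usage sketch for StepCounterHook (illustrative only): writes a
# "<global_step_name>/sec" summary to `output_dir` roughly every 100 steps.
#
#   counter_hook = StepCounterHook(every_n_steps=100, output_dir="/tmp/model")
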

@tf_export("train.NanLossDuringTrainingError")
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export("train.NanTensorHook")
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with an exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise an exception when loss is
        NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()
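
# Usage sketch for NanTensorHook (illustrative only; `loss` is a hypothetical
# loss tensor):
#
#   # Raise NanLossDuringTrainingError when the loss becomes NaN:
#   nan_hook = NanTensorHook(loss)
#   # ...or just log a warning and stop training instead:
#   # nan_hook = NanTensorHook(loss, fail_on_nan_loss=False)
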

@tf_export("train.SummarySaverHook")
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used
        if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
        passed, one will be created accordingly.
      scaffold: `Scaffold` to get the summary_op from if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized
        `Summary` protocol buffer, or a list of such `Tensor`s. These are
        most likely output by TF summary methods like `tf.summary.scalar` or
        `tf.summary.merge_all`. A single tensor can be passed directly; more
        than one must be passed as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`s.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op
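
# Usage sketch for SummarySaverHook (illustrative only; assumes summaries were
# created with tf.summary.* elsewhere, so tf.summary.merge_all() is not None):
#
#   summary_hook = SummarySaverHook(
#       save_steps=100,
#       output_dir="/tmp/model",
#       summary_op=tf.summary.merge_all())
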

@tf_export("train.GlobalStepWaiterHook")
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))` assuming
  that task_id=0 is the chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int` specifying the global step to wait for.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info("Waiting for global step %d before starting training. "
                     "Current step is %d.", self._wait_until_step,
                     current_step)
        last_logged_step = current_step
      time.sleep(0.5)
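
# Usage sketch for GlobalStepWaiterHook (illustrative only; `task_id` is a
# hypothetical worker index and `math` would need to be imported). Staggers
# worker startup so workers join training gradually:
#
#   wait_hook = GlobalStepWaiterHook(
#       wait_until_step=int(25 * math.log(task_id + 1)))
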

@tf_export("train.FinalOpsHook")
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      self._final_ops_values = session.run(self._final_ops,
                                           feed_dict=self._final_ops_feed_dict)


@tf_export("train.FeedFnHook")
class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns `dict` of `Tensor`
        to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())
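
# Usage sketch for FinalOpsHook (illustrative only; `accuracy_op` is a
# hypothetical metric tensor). The fetched values are available via the
# `final_ops_values` property once the session has ended:
#
#   final_hook = FinalOpsHook(final_ops={"accuracy": accuracy_op})
#   # ...run the session to completion with this hook attached, then:
#   print(final_hook.final_ops_values["accuracy"])
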

@tf_export("train.ProfilerHook")
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    The `options.run_metadata` argument of `tf.Session.run` is used to collect
    metadata about execution. This hook sets the metadata and dumps it in
    Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
            if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step,
                 self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow,
              show_memory=self._show_memory))


def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e. the op has a single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element
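
# Usage sketch for _as_graph_element (illustrative only; `loss` is a
# hypothetical tensor in the default graph). Both a name and a tensor resolve
# to the same graph element:
#
#   t = _as_graph_element("loss")   # looks up "loss:0" in the default graph
#   t = _as_graph_element(loss)     # passes a Tensor through unchanged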