# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export


_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"


class _HookTimer(object):
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number
      of seconds between the current trigger and the last one (a float), and
      `elapsed_steps` is the number of steps between the current trigger and
      the last one. Both values will be set to `None` on the first trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered time step or None if never triggered."""
    raise NotImplementedError


@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.
  """

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the last
      trigger exceeds `every_secs`, or if the difference between the current
      step and the last triggered step exceeds `every_steps`. False otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step


class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
    tf.logging.set_verbosity(tf.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  """

  def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
               at_end=False, formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
        or `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every
        N seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at the
        end of the run.
      formatter: function, takes dict of `tag`->`Tensor` and returns a string.
        If `None`, uses default printing of all tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else
        SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {tag: _as_graph_element(tensor)
                             for (tag, tensor) in self._tensors.items()}

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)


def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each
  device program execution and before the next execution.

  The purpose of using a variable, rather than a constant, is to allow
  Estimator to adapt the device training iterations according to the final
  steps specified by users. For example, if the user sets the steps_per_run as
  4 and steps as 10 in Estimator.train(), the steps_per_run
  variable will have the following value before each training run.

      - 1-st execution: steps_per_run = 4
      - 2-nd execution: steps_per_run = 4
      - 3-rd execution: steps_per_run = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 inputs passed
  in by users.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple steps_per_run variables were found.
  """
  graph = ops.get_default_graph()
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  steps_per_run_vars = graph.get_collection(collection_name)
  if len(steps_per_run_vars) == 1:
    return steps_per_run_vars[0]
  elif len(steps_per_run_vars) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")

  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        initializer=init_ops.ones_initializer(),
        shape=[],
        dtype=dtypes.int32,
        trainable=False,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
        use_resource=True)


class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `_MultiStepStopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can
    be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    In Estimator, the user-provided computation, the model_fn, is wrapped
    inside a tf.while_loop for peak performance. The steps_per_run variable
    determines the number of iterations of the loop before returning to the
    CPU.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    steps = min(self._last_step - global_step,
                self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps
    self._update_steps_per_run_variable(global_step, session)

  def after_run(self, run_context, run_values):
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # a race condition in hook execution.
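    # Instead, the freshest value is read back below with an explicit
    # session.run call after the train op has finished.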
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      self._update_steps_per_run_variable(global_step, run_context.session)


@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can
    be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to make sure the targeted last step is
      # reached. The global_step read tensor holds the value of the global
      # step from before the operation ran, so we cannot tell whether the
      # current session.run already incremented it; re-read it here to check.
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()


@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook`
  is triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own schedule
  to act less frequently, e.g. based on global_step_value. In this case,
  implementors should implement the `end()` method to handle actions related to
  the last checkpoint save. But the listener should not act twice if
  `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped by returning
  True in `after_save`. Please note that, in a replicated distributed training
  setting, only `chief` should use this behavior. Otherwise each worker will do
  its own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass


@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    # We write the graph and saver_def here rather than in begin, because other
    # hooks may still modify the graph and add variables in their begin calls.
    # The graph is finalized only after all begin calls.
    training_util.write_graph(
        ops.get_default_graph().as_graph_def(add_shapes=True),
        self._checkpoint_dir,
        "graph.pbtxt")
    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(
        stale_global_step + self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for l in self._listeners:
      l.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)

    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: {}".format(l))
        should_stop = True
    return should_stop

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor.".
          format(collection_key))

    self._saver = savers[0]
    return savers[0]


@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(every_steps=every_n_steps,
                                    every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._global_step_check_count = 0
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[Summary.Value(
          tag=self._summary_tag, simple_value=steps_per_sec)])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(
        stale_global_step + self._steps_per_run):
      # Get the real value after the train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use
    # timer.last_triggered_step because the timer might record a different
    # global step value, which would make the comparison unreliable. For
    # simplicity, we just compare stale_global_step with the previously
    # recorded value.
    if stale_global_step == self._last_global_step:
      # Here, we use a counter to count how many times we have observed that
      # the global step has not been increased. For some Optimizers, the
      # global step is not increased each time by design. For example,
      # SyncReplicasOptimizer doesn't increase the global step in worker's
      # main train step.
      self._global_step_check_count += 1
      if self._global_step_check_count % 20 == 0:
        self._global_step_check_count = 0
        logging.warning(
            "It seems that global step (tf.train.get_global_step) has not "
            "been increased. Current value (could be stale): %s vs previous "
            "value: %s. You could increase the global step by passing "
            "tf.train.get_global_step() to Optimizer.apply_gradients or "
            "Optimizer.minimize.", stale_global_step, self._last_global_step)
    else:
      # Whenever we observe an increment, reset the counter.
      self._global_step_check_count = 0

    self._last_global_step = stale_global_step


@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()


@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only used
        if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
        passed, one will be created accordingly.
      scaffold: `Scaffold` to get summary_op if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized `Summary`
        protocol buffer or a list of `Tensor`. They are most likely an output
        by TF summary methods like `tf.summary.scalar` or
        `tf.summary.merge_all`. It can be passed in as one tensor; if more
        than one, they must be passed in as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor` objects.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))` assuming that
  task_id=0 is the chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int`, the global step to wait for before starting.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info("Waiting for global step %d before starting training. "
                     "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)


@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      try:
        self._final_ops_values = session.run(
            self._final_ops, feed_dict=self._final_ops_feed_dict)
      except (errors.OutOfRangeError, StopIteration) as e:
        logging.warning(
            "An OutOfRangeError or StopIteration exception is raised by the "
            "code in FinalOpsHook. This typically means the ops run by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two Ops: "
            "`value_op` and `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in the input source and, "
            "once it is exhausted, it calls the `value_op` to get the metric "
            "values. The `value_op` here should depend only on variable "
            "reads, rather than reading another batch from the input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data read, which ends in an "
            "OutOfRangeError/StopIteration. Please fix that.")
        raise e


@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns `dict` of `Tensor`
        to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    The `options.run_metadata` argument of `tf.Session.run` is used to collect
    metadata about execution. This hook sets the metadata and dumps it in
    Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
            if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or seconds
      # have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step,
                 self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow, show_memory=self._show_memory))


def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (i.e. that the op has a single output).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element
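

# Example usage (illustrative sketch, not part of the library): the hooks in
# this module are typically composed through `tf.train.MonitoredTrainingSession`,
# which drives each hook's `begin`, `before_run`, `after_run`, and `end`
# callbacks. The names `loss` and `train_op` below are placeholders for
# whatever the surrounding model code defines.
#
#   global_step = tf.train.get_or_create_global_step()
#   hooks = [
#       tf.train.StopAtStepHook(last_step=10000),
#       tf.train.LoggingTensorHook({"loss": loss}, every_n_iter=100),
#       tf.train.ProfilerHook(save_steps=1000, output_dir="/tmp/profiles"),
#   ]
#   with tf.train.MonitoredTrainingSession(checkpoint_dir="/tmp/model",
#                                          hooks=hooks) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)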