# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

Note that the symbols that are exported to the v1 tf.train namespace are also
exported to v2 in the tf.estimator namespace. See
https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export

_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"


class _HookTimer(object):
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the
      number of seconds between the current trigger and the last one (a
      float), and `elapsed_steps` is the number of steps between the current
      trigger and the last one. Both values will be set to `None` on the
      first trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered step, or None if never triggered."""
    raise NotImplementedError


@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.

  This symbol is also exported to v2 in the tf.estimator namespace. See
  https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
  """
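
  # A minimal usage sketch (an illustrative assumption, not part of the API
  # contract: `local_step` and `num_steps` are hypothetical counters kept by
  # the caller):
  #
  #   timer = SecondOrStepTimer(every_steps=100)
  #   for local_step in range(num_steps):
  #     if timer.should_trigger_for_step(local_step):
  #       elapsed_secs, elapsed_steps = timer.update_last_triggered_step(
  #           local_step)
  #       ...  # periodic work, e.g. logging or saving
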
  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the
      last trigger exceeds `every_secs`, or if the difference between the
      current step and the last triggered step exceeds `every_steps`. False
      otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step


class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
        or `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once
        every N seconds. Exactly one of `every_n_iter` and `every_n_secs`
        should be provided.
      at_end: `bool` specifying whether to print the values of `tensors` at
        the end of the run.
      formatter: function that takes a dict of `tag`->`Tensor` and returns a
        string. If `None`, all tensors are printed using a default format.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)
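

# A minimal LoggingTensorHook usage sketch (hedged: `loss`, `global_step` and
# `train_op` are hypothetical tensors/ops from the caller's graph, not
# defined here):
#
#   logging_hook = LoggingTensorHook(
#       {"loss": loss, "step": global_step}, every_n_iter=100)
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       hooks=[logging_hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)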


def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each
  device program execution and before the next execution.

  The purpose of using a variable, rather than a constant, is to allow
  Estimator to adapt the device training iterations according to the final
  steps specified by users. For example, if the user sets steps_per_run to 4
  and steps to 10 in Estimator.train(), the steps_per_run variable will have
  the following value before each training run.

  - 1st execution: steps_per_run = 4
  - 2nd execution: steps_per_run = 4
  - 3rd execution: steps_per_run = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 input passed
  in by users.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple steps_per_run variables were found.
  """
  graph = ops.get_default_graph()
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  steps_per_run_vars = graph.get_collection(collection_name)
  if len(steps_per_run_vars) == 1:
    return steps_per_run_vars[0]
  elif len(steps_per_run_vars) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")

  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        initializer=init_ops.ones_initializer(),
        shape=[],
        dtype=dtypes.int32,
        trainable=False,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
        use_resource=True)


class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `_MultiStepStopAtStepHook`.

    This hook requests stop after either a number of steps has been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the
    `after_run()` call.

    In Estimator, the user-provided computation, the model_fn, is wrapped
    inside a tf.while_loop for peak performance. The steps_per_run variable
    determines the number of iterations of the loop before returning to the
    CPU.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    steps = min(self._last_step - global_step,
                self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps
    self._update_steps_per_run_variable(global_step, session)

  def after_run(self, run_context, run_values):
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # a race condition in hook execution.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      self._update_steps_per_run_variable(global_step, run_context.session)


@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps has been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the
    `after_run()` call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step
      # has been reached. The global_step read tensor holds the value of the
      # global step before the operation ran, so we cannot tell whether the
      # current session.run incremented the global step or not; re-read it
      # here to be sure.
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()
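

# A minimal StopAtStepHook usage sketch (hedged: `train_op` is a hypothetical
# training op; the hook reads the graph's global step, which must already
# exist):
#
#   stop_hook = StopAtStepHook(num_steps=1000)
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       hooks=[stop_hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)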


@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook`
  is triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.estimator.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.compat.v1.train.MonitoredTrainingSession(
      chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own
  schedule to act less frequently, e.g. based on global_step_value. In this
  case, implementors should implement the `end()` method to handle actions
  related to the last checkpoint save. But the listener should not act twice
  if `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped by returning
  True in `after_save`. Please note that, in a replicated distributed
  training setting, only the `chief` should use this behavior. Otherwise each
  worker will do its own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass


@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None,
               save_graph_def=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used
        for callbacks that run immediately before or after this hook saves
        the checkpoint.
      save_graph_def: Whether to save the GraphDef and MetaGraphDef to
        `checkpoint_dir`. The GraphDef is saved after the session is created
        as `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
        `model.ckpt-*.meta`.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1
    self._save_graph_def = save_graph_def

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._save_graph_def:
      # We write the graph and saver_def here rather than in begin(), because
      # other hooks may still change the graph and add variables in their
      # begin() calls; the graph is finalized only after all begin() calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = ops.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info("Calling checkpoint listeners before saving checkpoint %d...",
                 step)
    for l in self._listeners:
      l.before_save(session, step)

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._get_saver().save(session, self._save_path, global_step=step,
                           write_meta_graph=self._save_graph_def)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    logging.info("Calling checkpoint listeners after saving checkpoint %d...",
                 step)
    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: {}".format(l))
        should_stop = True
    return should_stop

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]
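

# A minimal CheckpointSaverHook usage sketch (hedged: "/tmp/model" is an
# arbitrary example directory and `train_op` is a hypothetical training op):
#
#   saver_hook = CheckpointSaverHook(
#       checkpoint_dir="/tmp/model", save_steps=500)
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       chief_only_hooks=[saver_hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)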


@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use
    # the timer.last_triggered_step, as the timer might record a different
    # global step value such that the comparison could be unreliable. For
    # simplicity, we just compare the stale_global_step with the previously
    # recorded version.
    if stale_global_step == self._last_global_step:
      # Here, we log a warning the first 5 times we observe that the global
      # step has not been increased. For some optimizers, the global step is
      # not increased each time by design. For example,
      # SyncReplicasOptimizer does not increase the global step in the
      # worker's main train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stable): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step
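

# A minimal StepCounterHook usage sketch (hedged: "/tmp/model" is an
# arbitrary example directory for the steps/sec summaries):
#
#   counter_hook = StepCounterHook(every_n_steps=100, output_dir="/tmp/model")
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       hooks=[counter_hook]) as sess:
#     ...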


@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with an exception or just stop training.
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise an exception when loss is
        NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()


@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps."""

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only
        used if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
        passed, one will be created accordingly.
      scaffold: `Scaffold` to get the summary_op from if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized
        `Summary` protocol buffer, or a list of such `Tensor`s. These are
        most likely outputs of TF summary methods like
        `tf.compat.v1.summary.scalar` or `tf.compat.v1.summary.merge_all`. A
        single tensor can be passed directly; multiple tensors must be
        passed as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.
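
  # A minimal usage sketch (hedged: assumes the caller built summaries via
  # tf.compat.v1.summary and "/tmp/model" is an arbitrary example directory):
  #
  #   summary_hook = SummarySaverHook(
  #       save_steps=100,
  #       output_dir="/tmp/model",
  #       summary_op=tf.compat.v1.summary.merge_all())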

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`s.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches
  `wait_until_step`. It is used to gradually start workers in distributed
  settings. One example usage would be setting
  `wait_until_step=int(K*log(task_id+1))`, assuming that task_id=0 is the
  chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int`, the global step to wait for before starting.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info(
            "Waiting for global step %d before starting training. "
            "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)


@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      try:
        self._final_ops_values = session.run(
            self._final_ops, feed_dict=self._final_ops_feed_dict)
      except (errors.OutOfRangeError, StopIteration) as e:
        logging.warning(
            "An OutOfRangeError or StopIteration exception is raised by the "
            "code in FinalOpsHook. This typically means the ops run by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two ops: "
            "`value_op` and `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in the input source and, "
            "once it is exhausted, it calls the `value_op` to get the metric "
            "values. The `value_op` here should depend only on variable "
            "reads, rather than reading another batch from the input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data read, which ends in an "
            "OutOfRangeError/StopIteration. Please fix that.")
        raise e
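

# A minimal FinalOpsHook usage sketch (hedged: `accuracy_value_op` is a
# hypothetical metric tensor; values become available only after the session
# ends):
#
#   final_hook = FinalOpsHook(final_ops={"accuracy": accuracy_value_op})
#   with tf.compat.v1.train.MonitoredSession(hooks=[final_hook]) as sess:
#     ...
#   print(final_hook.final_ops_values["accuracy"])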


@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` and sets the `feed_dict` accordingly."""

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns a `dict` of
        `Tensor` to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    The `options.run_metadata` argument of `tf.Session.run` is used to
    collect metadata about execution. This hook sets the metadata and dumps
    it in Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (
        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
        if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or
      # seconds have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step, self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow,
              show_memory=self._show_memory))


def _as_graph_element(obj):
  """Retrieves a Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (e.g. it's a single-output op).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element
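

# A minimal ProfilerHook usage sketch (hedged: "/tmp/profile" is an arbitrary
# example directory and `train_op` is a hypothetical training op; the
# resulting timeline-<step>.json files can be opened in chrome://tracing):
#
#   profiler_hook = ProfilerHook(save_steps=1000, output_dir="/tmp/profile")
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       hooks=[profiler_hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)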