# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Some common SessionRunHook classes.

Note that the symbols that are exported to the v1 tf.train namespace are also
exported to v2 in the tf.estimator namespace. See
https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.client import timeline
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
from tensorflow.python.util.tf_export import tf_export

_HOOKS = "hooks"
_STEPS_PER_RUN_VAR = "steps_per_run"


class _HookTimer(object):
  """Base timer for determining when Hooks should trigger.

  Should not be instantiated directly.
  """

  def __init__(self):
    pass

  def reset(self):
    """Resets the timer."""
    pass

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step."""
    raise NotImplementedError

  def update_last_triggered_step(self, step):
    """Update the last triggered time and step number.

    Args:
      step: The current step.

    Returns:
      A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the
      number of seconds between the current trigger and the last one (a
      float), and `elapsed_steps` is the number of steps between the current
      trigger and the last one. Both values will be set to `None` on the
      first trigger.
    """
    raise NotImplementedError

  def last_triggered_step(self):
    """Returns the last triggered time step or None if never triggered."""
    raise NotImplementedError


@tf_export(v1=["train.SecondOrStepTimer"])
class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.

  This symbol is also exported to v2 in the tf.estimator namespace. See
  https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py
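
  For example, a minimal sketch of a loop that does some periodic work,
  driven by this timer (the work itself is left as a placeholder):

  ```python
  timer = tf.compat.v1.train.SecondOrStepTimer(every_steps=100)
  for step in range(1000):
    if timer.should_trigger_for_step(step):
      timer.update_last_triggered_step(step)
      # ... do the periodic work here ...
  ```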
  """

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Cannot provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the
      last trigger exceeds `every_secs`, or if the difference between the
      current step and the last triggered step exceeds `every_steps`. False
      otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step


class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export(v1=["train.LoggingTensorHook"])
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are
  not seeing the logs, you might want to add the following line after your
  imports:

  ```python
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional
  inputs.
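
  For example, a minimal sketch that logs two tensors every 100 steps
  (`loss` and `accuracy` here stand in for tensors in your own graph):

  ```python
  logging_hook = tf.compat.v1.train.LoggingTensorHook(
      tensors={"loss": loss, "accuracy": accuracy},
      every_n_iter=100)
  ```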

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
  """

  def __init__(self,
               tensors,
               every_n_iter=None,
               every_n_secs=None,
               at_end=False,
               formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
        or `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once
        every N seconds. Exactly one of `every_n_iter` and `every_n_secs`
        should be provided.
      at_end: `bool` specifying whether to print the values of `tensors` at
        the end of the run.
      formatter: function that takes a dict of `tag`->`Tensor` and returns a
        string. If `None`, uses the default format, printing all tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "Either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s." % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = sorted(tensors.keys())
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else SecondOrStepTimer(
            every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {
        tag: _as_graph_element(tensor)
        for (tag, tensor) in self._tensors.items()
    }

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(
        self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)


def get_or_create_steps_per_run_variable():
  """Gets or creates the steps_per_run variable.

  In Estimator, the user-provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each
  device program execution and before the next execution.

  The purpose of using a variable, rather than a constant, is to allow
  Estimator to adapt the device training iterations according to the final
  steps specified by users. For example, if the user sets steps_per_run to 4
  and steps to 10 in Estimator.train(), the steps_per_run variable will have
  the following value before each training run:

  - 1st execution: steps_per_run = 4
  - 2nd execution: steps_per_run = 4
  - 3rd execution: steps_per_run = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all executions, matching the steps=10 input passed
  in by users.
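
  A pure-Python sketch of that schedule (illustrative only):

  ```python
  steps_per_run, remaining = 4, 10
  while remaining > 0:
    run_steps = min(steps_per_run, remaining)  # 4, then 4, then 2
    remaining -= run_steps
  ```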

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple steps_per_run variables are found.
  """
  graph = ops.get_default_graph()
  collection_name = "{}_{}".format(_HOOKS, _STEPS_PER_RUN_VAR)
  steps_per_run_vars = graph.get_collection(collection_name)
  if len(steps_per_run_vars) == 1:
    return steps_per_run_vars[0]
  elif len(steps_per_run_vars) > 1:
    raise RuntimeError("Multiple steps_per_run_var in collection.")

  with variable_scope.variable_scope(_HOOKS, reuse=variable_scope.AUTO_REUSE):
    return variable_scope.get_variable(
        _STEPS_PER_RUN_VAR,
        initializer=init_ops.ones_initializer(),
        shape=[],
        dtype=dtypes.int32,
        trainable=False,
        collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
        use_resource=True)


class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None, steps_per_run=1):
    """Initializes a `_MultiStepStopAtStepHook`.

    This hook requests stop after either a number of steps have been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the
    `after_run()` call.

    In Estimator, the user-provided computation, the model_fn, is wrapped
    inside a tf.while_loop for peak performance. The steps_per_run variable
    determines the number of iterations of the loop before returning to the
    CPU.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
      steps_per_run: Number of steps executed per run call.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    if steps_per_run is None or steps_per_run < 1:
      raise ValueError("steps_per_run should be greater than 0")
    self._num_steps = num_steps
    self._last_step = last_step
    self._steps_per_run_initial_value = steps_per_run

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StopAtStepHook.")
    self._steps_per_run_variable = get_or_create_steps_per_run_variable()

  def _update_steps_per_run_variable(self, global_step, session):
    steps = min(self._last_step - global_step,
                self._steps_per_run_initial_value)
    self._steps_per_run_variable.load(steps, session=session)

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps
    self._update_steps_per_run_variable(global_step, session)

  def after_run(self, run_context, run_values):
    # The global step cannot be retrieved via SessionRunArgs and before_run
    # due to a race condition in hook execution.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      self._update_steps_per_run_variable(global_step, run_context.session)


@tf_export(v1=["train.StopAtStepHook"])
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.
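
  For example, a minimal sketch that requests a stop once the global step
  reaches 1000:

  ```python
  stop_hook = tf.compat.v1.train.StopAtStepHook(last_step=1000)
  ```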

  @compatibility(TF2)
  Please check this [notebook][notebook] on how to migrate the API to TF2.

  [notebook]:https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/logging_stop_hook.ipynb

  @end_compatibility
  """

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been executed
    or a last step has been reached. Only one of the two options can be
    specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the
    `after_run()` call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step
      # is reached. The global_step read tensor is the value of the global
      # step before running the operation, so we are not sure whether the
      # current session.run incremented the global_step or not. Here we
      # check it.
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()


@tf_export(v1=["train.CheckpointSaverListener"])
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook`
  is triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')
      if decided_to_stop_training():
        return True

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.estimator.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.compat.v1.train.MonitoredTrainingSession(
      chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own
  schedule to act less frequently, e.g. based on global_step_value. In this
  case, implementors should implement the `end()` method to handle actions
  related to the last checkpoint save. But the listener should not act twice
  if `after_save()` already handled this last checkpoint save.

  A `CheckpointSaverListener` can request training to be stopped by returning
  True in `after_save`. Please note that, in a replicated distributed
  training setting, only `chief` should use this behavior. Otherwise each
  worker will do its own evaluation, which may be wasteful of resources.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass


@tf_export(v1=["train.CheckpointSaverHook"])
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds.
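
  For example, a minimal sketch that writes a checkpoint to `/tmp/model`
  (an illustrative path) every 1000 steps:

  ```python
  saver_hook = tf.compat.v1.train.CheckpointSaverHook(
      checkpoint_dir="/tmp/model", save_steps=1000)
  ```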
  """

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None,
               save_graph_def=True):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used
        for callbacks that run immediately before or after this hook saves
        the checkpoint.
      save_graph_def: Whether to save the GraphDef and MetaGraphDef to
        `checkpoint_dir`. The GraphDef is saved after the session is created
        as `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
        `model.ckpt-*.meta`.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1
    self._save_graph_def = save_graph_def

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._save_graph_def:
      # We write the graph and saver_def once the session is created. We
      # cannot do this in begin(), since other hooks may still change the
      # graph and add variables in their begin() calls; the graph is
      # finalized only after all begin() calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        if self._save(run_context.session, global_step):
          run_context.request_stop()

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint, returns should_stop."""
    logging.info("Calling checkpoint listeners before saving checkpoint %d...",
                 step)
    for l in self._listeners:
      l.before_save(session, step)

    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
    self._get_saver().save(session, self._save_path, global_step=step,
                           write_meta_graph=self._save_graph_def)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)
    logging.info("Calling checkpoint listeners after saving checkpoint %d...",
                 step)
    should_stop = False
    for l in self._listeners:
      if l.after_save(session, step):
        logging.info(
            "A CheckpointSaverListener requested that training be stopped. "
            "listener: {}".format(l))
        should_stop = True
    return should_stop

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]


@tf_export(v1=["train.StepCounterHook"])
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second.
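
  For example, a minimal sketch that records steps/sec summaries for
  TensorBoard every 100 steps (`/tmp/model` is an illustrative directory):

  ```python
  counter_hook = tf.compat.v1.train.StepCounterHook(
      every_n_steps=100, output_dir="/tmp/model")
  ```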
  """

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(
        every_steps=every_n_steps, every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._steps_per_run = 1

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[
          Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec)
      ])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step +
                                           self._steps_per_run):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use
    # the timer.last_triggered_step as the timer might record a different
    # global step value such that the comparison could be unreliable. For
    # simplicity, we just compare the stale_global_step with the previously
    # recorded version.
    if stale_global_step == self._last_global_step:
      # Here, we log a warning the first 5 times we observe that the global
      # step has not been increased. For some Optimizers, the global step is
      # not increased each time by design. For example, SyncReplicasOptimizer
      # doesn't increase the global step in the worker's main train step.
      logging.log_first_n(
          logging.WARN,
          "It seems that global step (tf.train.get_global_step) has not "
          "been increased. Current value (could be stale): %s vs previous "
          "value: %s. You could increase the global step by passing "
          "tf.train.get_global_step() to Optimizer.apply_gradients or "
          "Optimizer.minimize.", 5, stale_global_step, self._last_global_step)

    self._last_global_step = stale_global_step


@tf_export(v1=["train.NanLossDuringTrainingError"])
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


@tf_export(v1=["train.NanTensorHook"])
class NanTensorHook(session_run_hook.SessionRunHook):
  """Monitors the loss tensor and stops training if loss is NaN.

  Can either fail with an exception or just stop training.
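
  For example, a minimal sketch that requests a clean stop instead of
  raising when the loss becomes NaN (`loss` stands in for your loss tensor):

  ```python
  nan_hook = tf.compat.v1.train.NanTensorHook(loss, fail_on_nan_loss=False)
  ```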
  """

  def __init__(self, loss_tensor, fail_on_nan_loss=True):
    """Initializes a `NanTensorHook`.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._loss_tensor)

  def after_run(self, run_context, run_values):
    if np.isnan(run_values.results):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we request stop without an exception.
        run_context.request_stop()


@tf_export(v1=["train.SummarySaverHook"])
class SummarySaverHook(session_run_hook.SessionRunHook):
  """Saves summaries every N steps.
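
  For example, a minimal sketch that writes merged summaries to `/tmp/model`
  (an illustrative directory) every 100 steps:

  ```python
  summary_hook = tf.compat.v1.train.SummarySaverHook(
      save_steps=100, output_dir="/tmp/model",
      summary_op=tf.compat.v1.summary.merge_all())
  ```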
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir=None,
               summary_writer=None,
               scaffold=None,
               summary_op=None):
    """Initializes a `SummarySaverHook`.

    Args:
      save_steps: `int`, save summaries every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int`, save summaries every N seconds.
      output_dir: `string`, the directory to save the summaries to. Only
        used if no `summary_writer` is supplied.
      summary_writer: `SummaryWriter`. If `None` and an `output_dir` was
        passed, one will be created accordingly.
      scaffold: `Scaffold` to get the summary_op from if it's not provided.
      summary_op: `Tensor` of type `string` containing the serialized
        `Summary` protocol buffer or a list of such `Tensor`s. They are most
        likely output by TF summary methods like `tf.compat.v1.summary.scalar`
        or `tf.compat.v1.summary.merge_all`. One tensor can be passed in as
        is; more than one must be passed in as a list.

    Raises:
      ValueError: Exactly one of scaffold or summary_op should be set.
    """
    if ((scaffold is None and summary_op is None) or
        (scaffold is not None and summary_op is not None)):
      raise ValueError(
          "Exactly one of scaffold or summary_op must be provided.")
    self._summary_op = summary_op
    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    # TODO(mdan): Throw an error if output_dir and summary_writer are None.

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`s, or `None` if no summary op is set.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


@tf_export(v1=["train.GlobalStepWaiterHook"])
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until the global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))`, assuming
  that task_id=0 is the chief.
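
  For example, a minimal sketch that delays this worker until the global
  step reaches 1000:

  ```python
  wait_hook = tf.compat.v1.train.GlobalStepWaiterHook(wait_until_step=1000)
  ```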
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int`, the global step to wait for before starting.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info(
            "Waiting for global step %d before starting training. "
            "Current step is %d.", self._wait_until_step, current_step)
        last_logged_step = current_step
      time.sleep(0.5)


@tf_export(v1=["train.FinalOpsHook"])
class FinalOpsHook(session_run_hook.SessionRunHook):
  """A hook which evaluates `Tensors` at the end of a session."""

  def __init__(self, final_ops, final_ops_feed_dict=None):
    """Initializes `FinalOpsHook` with ops to run at the end of the session.

    Args:
      final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of
        names to `Tensors`.
      final_ops_feed_dict: A feed dictionary to use when running `final_ops`.
    """
    self._final_ops = final_ops
    self._final_ops_feed_dict = final_ops_feed_dict
    self._final_ops_values = None

  @property
  def final_ops_values(self):
    return self._final_ops_values

  def end(self, session):
    if self._final_ops is not None:
      try:
        self._final_ops_values = session.run(
            self._final_ops, feed_dict=self._final_ops_feed_dict)
      except (errors.OutOfRangeError, StopIteration) as e:
        logging.warning(
            "An OutOfRangeError or StopIteration exception is raised by the "
            "code in FinalOpsHook. This typically means the Ops run by the "
            "FinalOpsHook have a dependency back to some input source, which "
            "should not happen. For example, for metrics in "
            "tf.estimator.Estimator, all metrics functions return two Ops: "
            "`value_op` and `update_op`. Estimator.evaluate calls the "
            "`update_op` for each batch of the data in the input source and, "
            "once it is exhausted, it calls the `value_op` to get the metric "
            "values. The `value_op` here should depend only on reading "
            "variables, rather than reading another batch from the input. "
            "Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers "
            "another data read, which ends in an "
            "OutOfRangeError/StopIteration. Please fix that.")
        raise e


@tf_export(v1=["train.FeedFnHook"])
class FeedFnHook(session_run_hook.SessionRunHook):
  """Runs `feed_fn` and sets the `feed_dict` accordingly.
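
  For example, a minimal sketch that feeds a placeholder on every run call
  (`x` and `next_batch` are illustrative names, not part of this module):

  ```python
  feed_hook = tf.compat.v1.train.FeedFnHook(
      feed_fn=lambda: {x: next_batch()})
  ```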
  """

  def __init__(self, feed_fn):
    """Initializes a `FeedFnHook`.

    Args:
      feed_fn: function that takes no arguments and returns `dict` of
        `Tensor` to feed.
    """
    self.feed_fn = feed_fn

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return session_run_hook.SessionRunArgs(
        fetches=None, feed_dict=self.feed_fn())


@tf_export(v1=["train.ProfilerHook"])
class ProfilerHook(session_run_hook.SessionRunHook):
  """Captures CPU/GPU profiling information every N steps or seconds.

  This produces files called "timeline-<step>.json", which are in Chrome
  Trace format.

  For more information see:
  https://github.com/catapult-project/catapult/blob/master/tracing/README.md
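
  For example, a minimal sketch that dumps a Chrome trace to `/tmp/profile`
  (an illustrative directory) every 100 steps:

  ```python
  profiler_hook = tf.compat.v1.train.ProfilerHook(
      save_steps=100, output_dir="/tmp/profile")
  ```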
  """

  def __init__(self,
               save_steps=None,
               save_secs=None,
               output_dir="",
               show_dataflow=True,
               show_memory=False):
    """Initializes a hook that takes periodic profiling snapshots.

    The `options` and `run_metadata` arguments of `tf.Session.run` are used
    to collect metadata about execution. This hook sets the metadata and
    dumps it in Chrome Trace format.

    Args:
      save_steps: `int`, save profile traces every N steps. Exactly one of
        `save_secs` and `save_steps` should be set.
      save_secs: `int` or `float`, save profile traces every N seconds.
      output_dir: `string`, the directory to save the profile traces to.
        Defaults to the current directory.
      show_dataflow: `bool`, if True, add flow events to the trace connecting
        producers and consumers of tensors.
      show_memory: `bool`, if True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
    """
    self._output_file = os.path.join(output_dir, "timeline-{}.json")
    self._file_writer = SummaryWriterCache.get(output_dir)
    self._show_dataflow = show_dataflow
    self._show_memory = show_memory
    self._timer = SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)

  def begin(self):
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use ProfilerHook.")

  def before_run(self, run_context):
    self._request_summary = (
        self._next_step is not None and
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    opts = (
        config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
        if self._request_summary else None)

    return SessionRunArgs(requests, options=opts)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results["global_step"]
    if self._next_step is None:
      # Update the timer so that it does not activate until N steps or
      # seconds have passed.
      self._timer.update_last_triggered_step(stale_global_step)
    global_step = stale_global_step + 1
    if self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)
      self._timer.update_last_triggered_step(global_step)
      self._save(global_step, self._output_file.format(global_step),
                 run_values.run_metadata.step_stats)
      self._file_writer.add_run_metadata(run_values.run_metadata,
                                         "step_%d" % global_step)

    self._next_step = global_step + 1

  def _save(self, step, save_path, step_stats):
    logging.info("Saving timeline for %d into '%s'.", step, save_path)
    with gfile.Open(save_path, "w") as f:
      trace = timeline.Timeline(step_stats)
      f.write(
          trace.generate_chrome_trace_format(
              show_dataflow=self._show_dataflow,
              show_memory=self._show_memory))


def _as_graph_element(obj):
  """Retrieves Graph element."""
  graph = ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, "graph") or obj.graph != graph:
      raise ValueError("Passed %s should have graph attribute that is equal "
                       "to current graph %s." % (obj, graph))
    return obj
  if ":" in obj:
    element = graph.as_graph_element(obj)
  else:
    element = graph.as_graph_element(obj + ":0")
    # Check that there is no :1 (e.g. it's a single-output operation).
    try:
      graph.as_graph_element(obj + ":1")
    except (KeyError, ValueError):
      pass
    else:
      raise ValueError("Name %s is ambiguous, "
                       "as this `Operation` has multiple outputs "
                       "(at least 2)." % obj)
  return element