# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper of Session API which runs hooks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import os
import sys

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

# The list of exceptions that we should recover from. Exceptions not in this
# list may terminate the job.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

# Value that indicates no value was provided.
USE_DEFAULT = object()


@tf_export(v1=['train.Scaffold'])
class Scaffold(object):
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph. The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if
  needed.

  If you call the scaffold constructor without any arguments, it will pick
  pieces from the collections, creating default ones if needed when
  `scaffold.finalize()` is called. You can pass arguments to the constructor
  to provide your own pieces. Pieces that you pass to the constructor are not
  added to the graph collections.
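
  For example, a caller that only needs a custom initialization step can pass
  an `init_fn` (a minimal sketch; `my_extra_init_op` stands in for
  user-provided initialization):

  ```python
  def init_fn(scaffold, session):
    session.run(my_extra_init_op)

  scaffold = tf.compat.v1.train.Scaffold(init_fn=init_fn)
  ```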

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
    variables. Picked from and stored into the `SAVERS` collection in the
    graph by default.
  * `init_op`: An op to run to initialize the variables. Picked from and
    stored into the `INIT_OP` collection in the graph by default.
  * `ready_op`: An op to verify that the variables are initialized. Picked
    from and stored into the `READY_OP` collection in the graph by default.
  * `ready_for_local_init_op`: An op to verify that global state has been
    initialized and it is alright to run `local_init_op`. Picked from and
    stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by
    default. This is needed when the initialization of local variables
    depends on the values of global variables.
  * `local_init_op`: An op to initialize the local variables. Picked from
    and stored into the `LOCAL_INIT_OP` collection in the graph by default.
  * `summary_op`: An op to run and merge the summaries in the graph. Picked
    from and stored into the `SUMMARY_OP` collection in the graph by default.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
    running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations. The callable will be called as
    `init_fn(scaffold, session)`.
  """

  def __init__(self,
               init_op=None,
               init_feed_dict=None,
               init_fn=None,
               ready_op=None,
               ready_for_local_init_op=None,
               local_init_op=None,
               summary_op=None,
               saver=None,
               copy_from_scaffold=None,
               local_init_feed_dict=None):
    """Create a scaffold.

    Args:
      init_op: Optional op for initializing variables.
      init_feed_dict: Optional session feed dictionary to use when running the
        init_op.
      init_fn: Optional function to use to initialize the model after running
        the init_op. Will be called as `init_fn(scaffold, session)`.
      ready_op: Optional op to verify that the variables are initialized. Must
        return an empty 1D string tensor when the variables are initialized,
        or a non-empty 1D string tensor listing the names of the
        non-initialized variables.
      ready_for_local_init_op: Optional op to verify that the global variables
        are initialized and `local_init_op` can be run. Must return an empty
        1D string tensor when the global variables are initialized, or a
        non-empty 1D string tensor listing the names of the non-initialized
        global variables.
      local_init_op: Optional op to initialize local variables.
      summary_op: Optional op to gather all summaries. Must return a scalar
        string tensor containing a serialized `Summary` proto.
      saver: Optional `tf.compat.v1.train.Saver` object to use to save and
        restore variables. May also be a `tf.train.Checkpoint` object, in
        which case object-based checkpoints are saved. This will also load
        some object-based checkpoints saved from elsewhere, but that loading
        may be fragile since it uses fixed keys rather than performing a full
        graph-based match.
        For example, if a variable has two paths from the
        `Checkpoint` object because two `Model` objects share the `Layer`
        object that owns it, removing one `Model` may change the keys and
        break checkpoint loading through this API, whereas a graph-based
        match would match the variable through the other `Model`.
      copy_from_scaffold: Optional scaffold object to copy fields from. Its
        fields will be overwritten by any fields provided explicitly to this
        constructor.
      local_init_feed_dict: Optional session feed dictionary to use when
        running the local_init_op.
    """
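    # If requested, fill in any arguments the caller left unset from
    # `copy_from_scaffold` before the values are stored on `self` below.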
    if copy_from_scaffold is not None:
      if not isinstance(copy_from_scaffold, Scaffold):
        raise TypeError('copy_from_scaffold is not a Scaffold instance.')
      # We need _coalesce since Tensor is not converted to bool automatically,
      # so the common idiom of (a or b) does not work.
      coalesce = lambda a, b: a if a is not None else b
      init_op = coalesce(init_op, copy_from_scaffold.init_op)
      init_feed_dict = coalesce(init_feed_dict,
                                copy_from_scaffold.init_feed_dict)
      # Use the original init_fn provided by the user to init the new Scaffold.
      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
      ready_for_local_init_op = coalesce(
          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
      local_init_feed_dict = coalesce(local_init_feed_dict,
                                      copy_from_scaffold.local_init_feed_dict)
      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
      saver = coalesce(saver, copy_from_scaffold.saver)

    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver. Is there a better way?
    self._user_init_fn = init_fn
    if init_fn:
      self._init_fn = lambda sess: init_fn(self, sess)
    else:
      self._init_fn = None

    self._init_op = init_op
    self._init_feed_dict = init_feed_dict
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._local_init_op = local_init_op
    self._local_init_feed_dict = local_init_feed_dict
    self._summary_op = summary_op
    self._saver = saver

  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:

      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()))

      self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP,
                                              default_init_op)
    if self._ready_op is None:

      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)

      self._ready_op = Scaffold.get_or_default('ready_op',
                                               ops.GraphKeys.READY_OP,
                                               default_ready_op)
    if self._ready_for_local_init_op is None:

      def default_ready_for_local_init_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(
                variables.global_variables()),
            resources.report_uninitialized_resources(
                resources.shared_resources())
        ], 0)

      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold.default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default('summary_op',
                                                 ops.GraphKeys.SUMMARY_OP,
                                                 summary.merge_all)
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
          sharded=True)
    else:
      self._saver.build()

    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self

  @property
  def init_fn(self):
    return self._init_fn

  @property
  def init_op(self):
    return self._init_op

  @property
  def ready_op(self):
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def local_init_op(self):
    return self._local_init_op

  @property
  def local_init_feed_dict(self):
    return self._local_init_feed_dict

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def saver(self):
    return self._saver

  @property
  def init_feed_dict(self):
    return self._init_feed_dict

  @staticmethod
  def get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError(
            'More than one item in the collection "%s". '
            'Please indicate which one to use by passing it to '
            'the tf.Scaffold constructor as: '
            'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def default_local_init_op():
    """Returns an op that groups the default local init ops.

    This op is used during session initialization when a Scaffold is
    initialized without specifying the local_init_op arg. It includes
    `tf.compat.v1.local_variables_initializer`,
    `tf.compat.v1.tables_initializer`, and also
    initializes local session resources.

    Returns:
      The default Scaffold local init op.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()))
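

# Example (a minimal sketch, not part of the public API surface): start from
# an existing scaffold and override a single piece; arguments passed
# explicitly take precedence over the fields copied from `base`.
# `my_summary_op` is assumed to be defined by the caller.
#
#   base = Scaffold(init_feed_dict={...})
#   scaffold = Scaffold(summary_op=my_summary_op, copy_from_scaffold=base)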


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None,
    save_graph_def=True):
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)
  if chief_only_hooks and worker_context.is_chief:
    all_hooks.extend(chief_only_hooks)

  # We need to call save or summary ops on all workers since these ops may
  # contain collective ops; only running save ops on some workers would make
  # collective ops hang. Therefore, on those workers that don't need to
  # actually write checkpoints or summaries, we let them write to a temp
  # directory.
  # pylint: disable=protected-access
  if type(
      worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
                                             'CollectiveAllReduceStrategyV1',
                                             'MultiWorkerMirroredStrategy'):
    if worker_context.task_type:
      tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
    else:
      tmpdir = 'tmp'

    if save_checkpoint_secs:
      logging.warning('Collective ops may deadlock with '
                      '`save_checkpoint_secs`; please use '
                      '`save_checkpoint_steps` instead. Clearing '
                      '`save_checkpoint_secs` and setting '
                      '`save_checkpoint_steps` to 1000 now.')
      save_checkpoint_secs = None
      save_checkpoint_steps = 1000
    if save_summaries_secs:
      logging.warning('Collective ops may run out of sync with '
                      '`save_summaries_secs`; please use '
                      '`save_summaries_steps` instead.')
  else:
    tmpdir = None

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir and log_step_count_steps and log_step_count_steps > 0:
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=os.path.join(summary_dir, tmpdir),
              every_n_steps=log_step_count_steps))

  if (((save_summaries_steps and save_summaries_steps > 0) or
       (save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=os.path.join(summary_dir, tmpdir)))

  if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
       (save_checkpoint_steps and save_checkpoint_steps > 0)) and
      checkpoint_dir):
    if worker_context.should_checkpoint:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              os.path.join(checkpoint_dir, tmpdir),
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))

  logging.info('all_hooks %r', all_hooks)
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None,
    save_graph_def=True):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets the proper session initializer/restorer. It
  also creates hooks related to checkpoint and summary saving. For workers,
  this utility sets a proper session creator which waits for the chief to
  initialize or restore the TensorFlow session. Please check
  `tf.compat.v1.train.MonitoredSession` for more information.

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string. Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the
      graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks
      if `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
      saved using a default checkpoint saver. If both `save_checkpoint_steps`
      and `save_checkpoint_secs` are set to `None`, then the default
      checkpoint saver isn't used. If both are provided, then only
      `save_checkpoint_secs` is used. Default 600.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`,
      then the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are
      written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`,
      then the default summary saver isn't used. Default not enabled.
    config: an instance of `tf.compat.v1.ConfigProto` proto used to configure
      the session. It's the `config` argument of the constructor of
      `tf.compat.v1.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.
    max_wait_secs: Maximum time workers should wait for the session to become
      available. This should be kept relatively short to help detect
      incorrect code, but sometimes may need to be increased if the chief
      takes a while to start up.
    save_checkpoint_steps: The frequency, in number of global steps, that a
      checkpoint is saved using a default checkpoint saver. If both
      `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`,
      then the default checkpoint saver isn't used. If both are provided,
      then only `save_checkpoint_secs` is used. Default not enabled.
    summary_dir: A string. Optional path to a directory where to save
      summaries. If None, checkpoint_dir is used instead.
    save_graph_def: Whether to save the GraphDef and MetaGraphDef to
      `checkpoint_dir`. The GraphDef is saved after the session is created as
      `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
      `model.ckpt-*.meta`.

  Returns:
    A `MonitoredSession` object.
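
  Example (a minimal sketch; `train_op` and the chief flag are assumed to be
  provided by the surrounding training program):

  ```python
  with tf.compat.v1.train.MonitoredTrainingSession(
      is_chief=is_chief, checkpoint_dir='/tmp/train') as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```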
511 """ 512 if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT: 513 save_summaries_steps = 100 514 save_summaries_secs = None 515 elif save_summaries_secs == USE_DEFAULT: 516 save_summaries_secs = None 517 elif save_summaries_steps == USE_DEFAULT: 518 save_summaries_steps = None 519 520 if (save_checkpoint_steps == USE_DEFAULT and 521 save_checkpoint_secs == USE_DEFAULT): 522 save_checkpoint_steps = None 523 save_checkpoint_secs = 600 524 elif save_checkpoint_secs == USE_DEFAULT: 525 save_checkpoint_secs = None 526 elif save_checkpoint_steps == USE_DEFAULT: 527 save_checkpoint_steps = None 528 529 scaffold = scaffold or Scaffold() 530 worker_context = distribute_coordinator_context.get_current_worker_context() 531 532 if worker_context: 533 return _create_monitored_session_with_worker_context( 534 worker_context, 535 scaffold, 536 checkpoint_dir=checkpoint_dir, 537 hooks=hooks, 538 chief_only_hooks=chief_only_hooks, 539 save_checkpoint_secs=save_checkpoint_secs, 540 save_summaries_steps=save_summaries_steps, 541 save_summaries_secs=save_summaries_secs, 542 config=config, 543 stop_grace_period_secs=stop_grace_period_secs, 544 log_step_count_steps=log_step_count_steps, 545 max_wait_secs=max_wait_secs, 546 save_checkpoint_steps=save_checkpoint_steps, 547 summary_dir=summary_dir, 548 save_graph_def=save_graph_def) 549 550 if not is_chief: 551 session_creator = WorkerSessionCreator( 552 scaffold=scaffold, 553 master=master, 554 config=config, 555 max_wait_secs=max_wait_secs) 556 return MonitoredSession( 557 session_creator=session_creator, 558 hooks=hooks or [], 559 stop_grace_period_secs=stop_grace_period_secs) 560 561 all_hooks = [] 562 if chief_only_hooks: 563 all_hooks.extend(chief_only_hooks) 564 session_creator = ChiefSessionCreator( 565 scaffold=scaffold, 566 checkpoint_dir=checkpoint_dir, 567 master=master, 568 config=config) 569 570 summary_dir = summary_dir or checkpoint_dir 571 if summary_dir: 572 if log_step_count_steps and log_step_count_steps > 0: 573 all_hooks.append( 574 basic_session_run_hooks.StepCounterHook( 575 output_dir=summary_dir, every_n_steps=log_step_count_steps)) 576 577 if (save_summaries_steps and 578 save_summaries_steps > 0) or (save_summaries_secs and 579 save_summaries_secs > 0): 580 all_hooks.append( 581 basic_session_run_hooks.SummarySaverHook( 582 scaffold=scaffold, 583 save_steps=save_summaries_steps, 584 save_secs=save_summaries_secs, 585 output_dir=summary_dir)) 586 587 if checkpoint_dir: 588 if (save_checkpoint_secs and 589 save_checkpoint_secs > 0) or (save_checkpoint_steps and 590 save_checkpoint_steps > 0): 591 all_hooks.append( 592 basic_session_run_hooks.CheckpointSaverHook( 593 checkpoint_dir, 594 save_steps=save_checkpoint_steps, 595 save_secs=save_checkpoint_secs, 596 scaffold=scaffold, 597 save_graph_def=save_graph_def)) 598 599 if hooks: 600 all_hooks.extend(hooks) 601 return MonitoredSession( 602 session_creator=session_creator, 603 hooks=all_hooks, 604 stop_grace_period_secs=stop_grace_period_secs) 605 606 607@tf_export(v1=['train.SessionCreator']) 608@six.add_metaclass(abc.ABCMeta) 609class SessionCreator(object): 610 """A factory for tf.Session.""" 611 612 @abc.abstractmethod 613 def create_session(self): 614 raise NotImplementedError( 615 'create_session is not implemented for {}.'.format(self)) 616 617 618@tf_export(v1=['train.ChiefSessionCreator']) 619class ChiefSessionCreator(SessionCreator): 620 """Creates a tf.compat.v1.Session for a chief.""" 621 622 def __init__(self, 623 scaffold=None, 624 master='', 625 


@tf_export(v1=['train.SessionCreator'])
@six.add_metaclass(abc.ABCMeta)
class SessionCreator(object):
  """A factory for tf.Session."""

  @abc.abstractmethod
  def create_session(self):
    raise NotImplementedError(
        'create_session is not implemented for {}.'.format(self))


@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Initializes a chief session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint
        file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    """Gets or creates a SessionManager."""
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        local_init_feed_dict=self._scaffold.local_init_feed_dict,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
689 """ 690 self._scaffold = scaffold or Scaffold() 691 self._session_manager = None 692 self._master = master 693 self._config = config 694 self._max_wait_secs = max_wait_secs 695 696 def _get_session_manager(self): 697 """Gets or creates a SessionManager.""" 698 if self._session_manager: 699 return self._session_manager 700 701 self._session_manager = sm.SessionManager( 702 local_init_op=self._scaffold.local_init_op, 703 local_init_feed_dict=self._scaffold.local_init_feed_dict, 704 ready_op=self._scaffold.ready_op, 705 ready_for_local_init_op=self._scaffold.ready_for_local_init_op, 706 graph=ops.get_default_graph()) 707 return self._session_manager 708 709 def create_session(self): 710 self._scaffold.finalize() 711 return self._get_session_manager().wait_for_session( 712 self._master, config=self._config, max_wait_secs=self._max_wait_secs) 713 714 715class _MonitoredSession(object): 716 """See `MonitoredSession` or `SingularMonitoredSession`.""" 717 718 def __init__(self, 719 session_creator, 720 hooks, 721 should_recover, 722 stop_grace_period_secs=120): 723 """Sets up a Monitored or Hooked Session. 724 725 Args: 726 session_creator: A factory object to create session. Typically a 727 `ChiefSessionCreator` or a `WorkerSessionCreator`. 728 hooks: An iterable of `SessionRunHook' objects. 729 should_recover: A bool. Indicates whether to recover from `AbortedError` 730 and `UnavailableError` or not. 731 stop_grace_period_secs: Number of seconds given to threads to stop after 732 `close()` has been called. 733 """ 734 self._graph_was_finalized = ops.get_default_graph().finalized 735 self._hooks = hooks or [] 736 for h in self._hooks: 737 h.begin() 738 739 worker_context = distribute_coordinator_context.get_current_worker_context() 740 if not session_creator and worker_context: 741 session_creator = worker_context.session_creator() 742 743 # Create the session. 744 self._coordinated_creator = self._CoordinatedSessionCreator( 745 session_creator=session_creator or ChiefSessionCreator(), 746 hooks=self._hooks, 747 stop_grace_period_secs=stop_grace_period_secs) 748 if should_recover: 749 self._sess = _RecoverableSession(self._coordinated_creator) 750 else: 751 self._sess = self._coordinated_creator.create_session() 752 753 @property 754 def graph(self): 755 """The graph that was launched in this session.""" 756 if self._tf_sess() is None: 757 return None 758 return self._tf_sess().graph 759 760 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 761 """Run ops in the monitored session. 762 763 This method is completely compatible with the `tf.Session.run()` method. 764 765 Args: 766 fetches: Same as `tf.Session.run()`. 767 feed_dict: Same as `tf.Session.run()`. 768 options: Same as `tf.Session.run()`. 769 run_metadata: Same as `tf.Session.run()`. 770 771 Returns: 772 Same as `tf.Session.run()`. 773 """ 774 return self._sess.run( 775 fetches, 776 feed_dict=feed_dict, 777 options=options, 778 run_metadata=run_metadata) 779 780 def run_step_fn(self, step_fn): 781 """Run ops using a step function. 782 783 Args: 784 step_fn: A function or a method with a single argument of type 785 `StepContext`. The function may use methods of the argument to perform 786 computations with access to a raw session. The returned value of the 787 `step_fn` will be returned from `run_step_fn`, unless a stop is 788 requested. In that case, the next `should_stop` call will return True. 

        Example usage:

        ```python
        with tf.Graph().as_default():
          c = tf.compat.v1.placeholder(dtypes.float32)
          v = tf.add(c, 4.0)
          w = tf.add(c, 0.5)

          def step_fn(step_context):
            a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
            if a <= 4.5:
              step_context.request_stop()
            return step_context.run_with_hooks(fetches=w,
                                               feed_dict={c: 0.1})

          with tf.MonitoredSession() as session:
            while not session.should_stop():
              a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the
        `step_fn` as they do with a `MonitoredSession.run` call.

    Returns:
      The returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either a `_RecoverableSession` or a
    # `_CoordinatedSession`. Setting `run_with_hooks` to `None` will cause
    # `run_with_hooks` to be `_CoordinatedSession.run` downstream in either
    # case. This allows `_PREEMPTION_ERRORS` to propagate from within
    # `step_fn` to `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext(object):
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

    Users of `step_fn` may perform `run()` calls without running hooks
    by accessing the `session`. A `run()` call with hooks may be performed
    using `run_with_hooks()`. Computation flow can be interrupted using
    `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.compat.v1.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

      Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    return self._sess is None or self._sess.should_stop()

  def close(self):
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for _CoordinatedSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want the coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    try:
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return the underlying tf.compat.v1.Session object.

    Warning: accessing the returned object in user code is likely to cause
    races or "flaky tests".

    Returns:
      A tf.compat.v1.Session object.
    """
    return self._coordinated_creator.tf_sess


@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does the following
  things in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does the following
  things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the run() call again

  Exit: At the `close()`, the monitored session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error which indicates that all inputs have
    been processed if the monitored_session is used as a context

  How to set `tf.compat.v1.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
      session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In a distributed setting for a non-chief worker, you can use the
    following:

  ```python
  MonitoredSession(
      session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or
  worker.

  Note: This is not a `tf.compat.v1.Session`. For example, it cannot do the
  following:

  * it cannot be set as a default session.
  * it cannot be sent to saver.save.
  * it cannot be sent to tf.train.start_queue_runners.

  Args:
    session_creator: A factory object to create a session. Typically a
      `ChiefSessionCreator` which is the default one.
    hooks: An iterable of `SessionRunHook` objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
  The differences between `MonitoredSession` and `SingularMonitoredSession`
  are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.compat.v1.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session
    is private. This can be used:
    - To `run` without hooks.
    - To save and restore.
  * All other functionality is identical.

  Example usage:
  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the hooked session does the following
  things in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by user

  Exit: At the `close()`, the hooked session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error which indicates that all inputs have
    been processed if the `SingularMonitoredSession` is used as a context.
  """

  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook` objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
    """Returns the underlying `tf.compat.v1.Session` object."""
    return self._tf_sess()
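

# Example (a minimal sketch; `saver` and `train_op` are assumed to come from
# the surrounding program): the raw session of a `SingularMonitoredSession`
# can be handed to APIs that require a real `tf.compat.v1.Session`, such as
# `Saver.save`:
#
#   with tf.compat.v1.train.SingularMonitoredSession() as sess:
#     while not sess.should_stop():
#       sess.run(train_op)
#     saver.save(sess.raw_session(), '/tmp/model.ckpt')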
1153 """ 1154 self._sess = sess 1155 self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession) 1156 1157 @property 1158 def graph(self): 1159 return self._sess.graph 1160 1161 @property 1162 def sess_str(self): 1163 return self._sess.sess_str 1164 1165 def should_stop(self): 1166 """Return true if this session should not be used anymore. 1167 1168 Always return True if the session was closed. 1169 1170 Returns: 1171 True if the session should stop, False otherwise. 1172 """ 1173 if self._check_stop(): 1174 return True 1175 if self._sess: 1176 return self._wrapped_is_stoppable and self._sess.should_stop() 1177 return True 1178 1179 def _check_stop(self): 1180 """Hook for subclasses to provide their own stop condition. 1181 1182 Returns: 1183 True if the session should stop, False otherwise. 1184 """ 1185 return False 1186 1187 def close(self): 1188 if self._sess: 1189 try: 1190 self._sess.close() 1191 except _PREEMPTION_ERRORS as e: 1192 logging.error( 1193 'An error occurred when attempting to close the ' 1194 'session. This may be due to a preemption in a ' 1195 'connected worker or parameter server. Error: %s', e) 1196 finally: 1197 self._sess = None 1198 1199 def run(self, *args, **kwargs): 1200 return self._sess.run(*args, **kwargs) 1201 1202 def run_step_fn(self, step_fn, raw_session, run_with_hooks): 1203 # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`. 1204 # It is `None` when called from `_CoordinatedSession`. In that case 1205 # `self.run` is `_CoordinatedSession.run`. 1206 run_with_hooks = run_with_hooks or self.run 1207 return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks)) 1208 1209 1210class _RecoverableSession(_WrappedSession): 1211 """A wrapped session that recreates a session upon certain kinds of errors. 1212 1213 The constructor is passed a SessionCreator object, not a session. 1214 1215 Calls to `run()` are delegated to the wrapped session. If a call raises the 1216 exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the 1217 wrapped session is closed, and a new one is created by calling the factory 1218 again. 1219 """ 1220 1221 def __init__(self, sess_creator): 1222 """Create a new `_RecoverableSession`. 1223 1224 The value returned by calling `sess_creator.create_session()` will be the 1225 session wrapped by this recoverable session. 1226 1227 Args: 1228 sess_creator: A 'SessionCreator' to be wrapped by recoverable. 1229 """ 1230 self._sess_creator = sess_creator 1231 _WrappedSession.__init__(self, self._create_session()) 1232 1233 def _create_session(self): 1234 while True: 1235 try: 1236 return self._sess_creator.create_session() 1237 except _PREEMPTION_ERRORS as e: 1238 logging.info( 1239 'An error was raised while a session was being created. ' 1240 'This may be due to a preemption of a connected worker ' 1241 'or parameter server. A new session will be created. ' 1242 'This error may also occur due to a gRPC failure caused ' 1243 'by high memory or network bandwidth usage in the ' 1244 'parameter servers. If this error occurs repeatedly, try ' 1245 'increasing the number of parameter servers assigned to ' 1246 'the job. Error: %s', e) 1247 1248 def _check_stop(self): 1249 try: 1250 if self._sess: 1251 return self._sess._check_stop() # pylint: disable=protected-access 1252 else: 1253 return True 1254 except _PREEMPTION_ERRORS as e: 1255 logging.info( 1256 'An error was raised while considering whether the ' 1257 'session is complete. 
      self.close()
      self._sess = self._create_session()
      # Since we have just recreated the session, the overall computation
      # should not stop:
      return False
    except Exception:  # pylint: disable=broad-except
      # `should_stop` should return True instead of raising an exception.
      return True

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()
        return self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()

        run_with_hooks = self._sess.run
        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None


class _CoordinatedSession(_WrappedSession):
  """A wrapped session that works with a `tf.Coordinator`.

  Calls to `run()` are delegated to the wrapped session. If a call
  raises an exception, the exception is reported to the coordinator.

  In addition, after each call to `run()` this session asks the coordinator
  if the session should stop. In that case it will join all the threads
  registered with the coordinator before returning.

  If the coordinator was requested to stop with an exception, that exception
  will be re-raised from the call to `run()`.
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.compat.v1.Session` object. The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
1341 """ 1342 _WrappedSession.__init__(self, sess) 1343 self._coord = coord 1344 self._stop_grace_period_secs = stop_grace_period_secs 1345 1346 def _check_stop(self): 1347 # If the coordinator was asked to stop due to an exception, then it needs 1348 # to be propagated to this stack. 1349 self._coord.raise_requested_exception() 1350 # At this point, no exceptions are recorded in the coordinator. 1351 return self._coord.should_stop() 1352 1353 def close(self): 1354 self._coord.request_stop() 1355 try: 1356 self._coord.join( 1357 stop_grace_period_secs=self._stop_grace_period_secs, 1358 ignore_live_threads=True) 1359 finally: 1360 try: 1361 _WrappedSession.close(self) 1362 except Exception: # pylint: disable=broad-except 1363 # We intentionally suppress exceptions from the close() here since 1364 # useful exceptions are already reported by join(). 1365 pass 1366 1367 def run(self, *args, **kwargs): 1368 try: 1369 return self._sess.run(*args, **kwargs) 1370 except _PREEMPTION_ERRORS: 1371 raise 1372 except Exception: # pylint: disable=broad-except 1373 # A non-preemption error could have been caused by a preemption error 1374 # in the coordinator. If this is the case, raise that exception instead, 1375 # since it's the root cause. Otherwise, stick to the `original_exc_info`. 1376 original_exc_info = sys.exc_info() 1377 try: 1378 self._coord.raise_requested_exception() 1379 except _PREEMPTION_ERRORS: 1380 raise 1381 except Exception: # pylint: disable=broad-except 1382 raise six.reraise(*original_exc_info) 1383 else: 1384 raise six.reraise(*original_exc_info) 1385 1386 1387class _HookedSession(_WrappedSession): 1388 """A _WrappedSession that calls hooks during calls to run(). 1389 1390 The list of hooks to call is passed in the constructor. Before each call 1391 to `run()` the session calls the `before_run()` method of the hooks, which 1392 can return additional ops or tensors to run. These are added to the arguments 1393 of the call to `run()`. 1394 1395 When the `run()` call finishes, the session calls the `after_run()` methods of 1396 the hooks, passing the values returned by the `run()` call corresponding to 1397 the ops and tensors that each hook requested. 1398 1399 If any call to the hooks, requests stop via run_context the session will be 1400 marked as needing to stop and its `should_stop()` method will now return 1401 `True`. 1402 """ 1403 1404 def __init__(self, sess, hooks): 1405 """Initializes a _HookedSession object. 1406 1407 Args: 1408 sess: A `tf.compat.v1.Session` or a `_WrappedSession` object. 1409 hooks: An iterable of `SessionRunHook' objects. 1410 """ 1411 1412 _WrappedSession.__init__(self, sess) 1413 self._hooks = hooks 1414 self._should_stop = False 1415 1416 def _check_stop(self): 1417 """See base class.""" 1418 return self._should_stop 1419 1420 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 1421 """See base class.""" 1422 if self.should_stop(): 1423 raise RuntimeError('Run called even after should_stop requested.') 1424 1425 actual_fetches = {'caller': fetches} 1426 1427 run_context = session_run_hook.SessionRunContext( 1428 original_args=session_run_hook.SessionRunArgs(fetches, feed_dict), 1429 session=self._sess) 1430 1431 options = options or config_pb2.RunOptions() 1432 feed_dict = self._call_hook_before_run(run_context, actual_fetches, 1433 feed_dict, options) 1434 1435 # Do session run. 
    run_metadata = run_metadata or config_pb2.RunMetadata()
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    for hook in self._hooks:
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=outputs[hook] if hook in outputs else None,
              options=options,
              run_metadata=run_metadata))
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    """Calls hooks.before_run and handles requests from hooks."""
    hook_feeds = {}
    for hook in self._hooks:
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
                                          'Same tensor is fed by two hooks.')
          hook_feeds.update(request.feed_dict)
        if request.options:
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds

    self._raise_if_feeds_intersects(
        user_feed_dict, hook_feeds,
        'Same tensor is fed by a SessionRunHook and user.')
    hook_feeds.update(user_feed_dict)
    return hook_feeds

  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
    intersection = set(feeds1.keys()) & set(feeds2.keys())
    if intersection:
      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))

  def _merge_run_options(self, options, incoming_options):
    """Merge two instances of RunOptions into the first one.

    During the merge, the numerical fields trace_level, timeout_in_ms, and
    inter_op_thread_pool are set to the larger of the two. Each boolean field
    is set to the logical OR of the two. debug_tensor_watch_opts of the
    original options is extended with that from the incoming one.

    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)
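

# Example (illustrative only, not part of this module's API): a custom
# `SessionRunHook` whose `before_run` fetches are merged with the caller's
# fetches by `_HookedSession`, and whose `after_run` receives only its own
# results. `loss_op` is assumed to be defined by the surrounding program.
#
#   class _LossLoggerHook(session_run_hook.SessionRunHook):
#
#     def before_run(self, run_context):
#       return session_run_hook.SessionRunArgs(loss_op)
#
#     def after_run(self, run_context, run_values):
#       logging.info('loss = %s', run_values.results)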