1# pylint: disable=g-bad-file-header 2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""A wrapper of Session API which runs hooks.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import abc 23import os 24import sys 25 26import six 27 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.python.distribute import distribute_coordinator_context 30from tensorflow.python.framework import errors 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import control_flow_ops 34from tensorflow.python.ops import lookup_ops 35from tensorflow.python.ops import resources 36from tensorflow.python.ops import variables 37from tensorflow.python.platform import tf_logging as logging 38from tensorflow.python.summary import summary 39from tensorflow.python.training import basic_session_run_hooks 40from tensorflow.python.training import coordinator 41from tensorflow.python.training import queue_runner 42from tensorflow.python.training import saver as training_saver 43from tensorflow.python.training import session_manager as sm 44from tensorflow.python.training import session_run_hook 45from tensorflow.python.training.tracking import graph_view 46from 
tensorflow.python.training.tracking import util as trackable_util 47from tensorflow.python.util import function_utils 48from tensorflow.python.util.tf_export import tf_export 49 50# The list of exceptions that we should recover from. Exceptions not in this 51# list may terminate the job. 52_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError) 53 54# Value that indicates no value was provided. 55USE_DEFAULT = object() 56 57 58@tf_export(v1=['train.Scaffold']) 59class Scaffold(object): 60 """Structure to create or gather pieces commonly needed to train a model. 61 62 When you build a model for training you usually need ops to initialize 63 variables, a `Saver` to checkpoint them, an op to collect summaries for 64 the visualizer, and so on. 65 66 Various libraries built on top of the core TensorFlow library take care of 67 creating some or all of these pieces and storing them in well known 68 collections in the graph. The `Scaffold` class helps pick these pieces from 69 the graph collections, creating and adding them to the collections if needed. 70 71 If you call the scaffold constructor without any arguments, it will pick 72 pieces from the collections, creating default ones if needed when 73 `scaffold.finalize()` is called. You can pass arguments to the constructor to 74 provide your own pieces. Pieces that you pass to the constructor are not 75 added to the graph collections. 76 77 The following pieces are directly accessible as attributes of the `Scaffold` 78 object: 79 80 * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the 81 variables. 82 Picked from and stored into the `SAVERS` collection in the graph by default. 83 * `init_op`: An op to run to initialize the variables. Picked from and 84 stored into the `INIT_OP` collection in the graph by default. 85 * `ready_op`: An op to verify that the variables are initialized. Picked 86 from and stored into the `READY_OP` collection in the graph by default. 
87 * `ready_for_local_init_op`: An op to verify that global state has been 88 initialized and it is alright to run `local_init_op`. Picked from and 89 stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by 90 default. This is needed when the initialization of local variables depends 91 on the values of global variables. 92 * `local_init_op`: An op to initialize the local variables. Picked 93 from and stored into the `LOCAL_INIT_OP` collection in the graph by default. 94 * `summary_op`: An op to run and merge the summaries in the graph. Picked 95 from and stored into the `SUMMARY_OP` collection in the graph by default. 96 97 You can also pass the following additional pieces to the constructor: 98 99 * `init_feed_dict`: A session feed dictionary that should be used when 100 running the init op. 101 * `init_fn`: A callable to run after the init op to perform additional 102 initializations. The callable will be called as 103 `init_fn(scaffold, session)`. 104 105 """ 106 107 def __init__(self, 108 init_op=None, 109 init_feed_dict=None, 110 init_fn=None, 111 ready_op=None, 112 ready_for_local_init_op=None, 113 local_init_op=None, 114 summary_op=None, 115 saver=None, 116 copy_from_scaffold=None, 117 local_init_feed_dict=None): 118 """Create a scaffold. 119 120 Args: 121 init_op: Optional op for initializing variables. 122 init_feed_dict: Optional session feed dictionary to use when running the 123 init_op. 124 init_fn: Optional function to use to initialize the model after running 125 the init_op. Will be called as `init_fn(scaffold, session)`. 126 ready_op: Optional op to verify that the variables are initialized. Must 127 return an empty 1D string tensor when the variables are initialized, or 128 a non-empty 1D string tensor listing the names of the non-initialized 129 variables. 130 ready_for_local_init_op: Optional op to verify that the global variables 131 are initialized and `local_init_op` can be run. 
Must return an empty 1D 132 string tensor when the global variables are initialized, or a non-empty 133 1D string tensor listing the names of the non-initialized global 134 variables. 135 local_init_op: Optional op to initialize local variables. 136 summary_op: Optional op to gather all summaries. Must return a scalar 137 string tensor containing a serialized `Summary` proto. 138 saver: Optional `tf.compat.v1.train.Saver` object to use to save and 139 restore variables. May also be a `tf.train.Checkpoint` object, in which 140 case object-based checkpoints are saved. This will also load some 141 object-based checkpoints saved from elsewhere, but that loading may be 142 fragile since it uses fixed keys rather than performing a full 143 graph-based match. For example if a variable has two paths from the 144 `Checkpoint` object because two `Model` objects share the `Layer` object 145 that owns it, removing one `Model` may change the keys and break 146 checkpoint loading through this API, whereas a graph-based match would 147 match the variable through the other `Model`. 148 copy_from_scaffold: Optional scaffold object to copy fields from. Its 149 fields will be overwritten by the provided fields in this function. 150 local_init_feed_dict: Optional session feed dictionary to use when running 151 the local_init_op. 152 """ 153 if copy_from_scaffold is not None: 154 if not isinstance(copy_from_scaffold, Scaffold): 155 raise TypeError('copy_from_scaffold is not a Scaffold instance.') 156 # We need _coalesce since Tensor is not converted to bool automatically, 157 # so the common idiom of (a or b) does not work. 158 coalesce = lambda a, b: a if a is not None else b 159 init_op = coalesce(init_op, copy_from_scaffold.init_op) 160 init_feed_dict = coalesce(init_feed_dict, 161 copy_from_scaffold.init_feed_dict) 162 # Use the original init_fn provided by the user to init the new Scaffold. 
163 init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn) # pylint: disable=protected-access 164 ready_op = coalesce(ready_op, copy_from_scaffold.ready_op) 165 ready_for_local_init_op = coalesce( 166 ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op) 167 local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op) 168 local_init_feed_dict = coalesce(local_init_feed_dict, 169 copy_from_scaffold.local_init_feed_dict) 170 summary_op = coalesce(summary_op, copy_from_scaffold.summary_op) 171 saver = coalesce(saver, copy_from_scaffold.saver) 172 173 # NOTE(touts): modifying the init function to be passed the scaffold is a 174 # hack to make it easy to find the saver. Is there a better way? 175 self._user_init_fn = init_fn 176 if init_fn: 177 self._init_fn = lambda sess: init_fn(self, sess) 178 else: 179 self._init_fn = None 180 181 self._init_op = init_op 182 self._init_feed_dict = init_feed_dict 183 self._ready_op = ready_op 184 self._ready_for_local_init_op = ready_for_local_init_op 185 self._local_init_op = local_init_op 186 self._local_init_feed_dict = local_init_feed_dict 187 self._summary_op = summary_op 188 self._saver = saver 189 190 def finalize(self): 191 """Creates operations if needed and finalizes the graph.""" 192 if self._init_op is None: 193 194 def default_init_op(): 195 return control_flow_ops.group( 196 variables.global_variables_initializer(), 197 resources.initialize_resources(resources.shared_resources()), 198 ops.get_collection('saved_model_initializers')) 199 200 self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP, 201 default_init_op) 202 if self._ready_op is None: 203 204 def default_ready_op(): 205 return array_ops.concat([ 206 variables.report_uninitialized_variables(), 207 resources.report_uninitialized_resources() 208 ], 0) 209 210 self._ready_op = Scaffold.get_or_default('ready_op', 211 ops.GraphKeys.READY_OP, 212 default_ready_op) 213 if self._ready_for_local_init_op is None: 214 
215 def default_ready_for_local_init_op(): 216 return array_ops.concat([ 217 variables.report_uninitialized_variables( 218 variables.global_variables()), 219 resources.report_uninitialized_resources( 220 resources.shared_resources()) 221 ], 0) 222 223 self._ready_for_local_init_op = Scaffold.get_or_default( 224 'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP, 225 default_ready_for_local_init_op) 226 if self._local_init_op is None: 227 self._local_init_op = Scaffold.get_or_default( 228 'local_init_op', ops.GraphKeys.LOCAL_INIT_OP, 229 Scaffold.default_local_init_op) 230 if self._summary_op is None: 231 self._summary_op = Scaffold.get_or_default('summary_op', 232 ops.GraphKeys.SUMMARY_OP, 233 summary.merge_all) 234 # pylint: disable=g-long-lambda 235 if self._saver is None: 236 self._saver = training_saver._get_saver_or_default() # pylint: disable=protected-access 237 # pylint: enable=g-long-lambda 238 if isinstance(self._saver, trackable_util.Checkpoint): 239 self._saver = training_saver.Saver( 240 var_list=graph_view.ObjectGraphView( 241 self._saver).frozen_saveable_objects(), 242 sharded=True) 243 else: 244 self._saver.build() 245 246 ops.get_default_graph().finalize() 247 logging.info('Graph was finalized.') 248 return self 249 250 @property 251 def init_fn(self): 252 return self._init_fn 253 254 @property 255 def init_op(self): 256 return self._init_op 257 258 @property 259 def ready_op(self): 260 return self._ready_op 261 262 @property 263 def ready_for_local_init_op(self): 264 return self._ready_for_local_init_op 265 266 @property 267 def local_init_op(self): 268 return self._local_init_op 269 270 @property 271 def local_init_feed_dict(self): 272 return self._local_init_feed_dict 273 274 @property 275 def summary_op(self): 276 return self._summary_op 277 278 @property 279 def saver(self): 280 return self._saver 281 282 @property 283 def init_feed_dict(self): 284 return self._init_feed_dict 285 286 @staticmethod 287 def get_or_default(arg_name, 
collection_key, default_constructor): 288 """Get from cache or create a default operation.""" 289 elements = ops.get_collection(collection_key) 290 if elements: 291 if len(elements) > 1: 292 raise RuntimeError( 293 'More than one item in the collection "%s". ' 294 'Please indicate which one to use by passing it to ' 295 'the tf.Scaffold constructor as: ' 296 'tf.Scaffold(%s=item to use)', collection_key, arg_name) 297 return elements[0] 298 op = default_constructor() 299 if op is not None: 300 ops.add_to_collection(collection_key, op) 301 return op 302 303 @staticmethod 304 def default_local_init_op(): 305 """Returns an op that groups the default local init ops. 306 307 This op is used during session initialization when a Scaffold is 308 initialized without specifying the local_init_op arg. It includes 309 `tf.compat.v1.local_variables_initializer`, 310 `tf.compat.v1.tables_initializer`, and also 311 initializes local session resources. 312 313 Returns: 314 The default Scaffold local init op. 
315 """ 316 return control_flow_ops.group( 317 variables.local_variables_initializer(), 318 lookup_ops.tables_initializer(), 319 resources.initialize_resources(resources.local_resources())) 320 321 322def _create_monitored_session_with_worker_context( 323 worker_context, # pylint: disable=missing-docstring 324 scaffold, 325 checkpoint_dir=None, 326 hooks=None, 327 chief_only_hooks=None, 328 save_checkpoint_secs=None, 329 save_summaries_steps=None, 330 save_summaries_secs=None, 331 config=None, 332 stop_grace_period_secs=120, 333 log_step_count_steps=100, 334 max_wait_secs=7200, 335 save_checkpoint_steps=None, 336 summary_dir=None, 337 save_graph_def=True): 338 all_hooks = [] 339 if hooks: 340 all_hooks.extend(hooks) 341 if chief_only_hooks and worker_context.is_chief: 342 all_hooks.extend(chief_only_hooks) 343 344 # We need to call save or summary ops on all workers since these ops may 345 # contain collective ops, only running save ops on some workers would make 346 # collective ops hang. Therefore on those workers that don't need to actually 347 # write checkpoints or summaries, we let them write to a temp directory. 348 # pylint: disable=protected-access 349 if type( 350 worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy', 351 'CollectiveAllReduceStrategyV1', 352 'MultiWorkerMirroredStrategy'): 353 if worker_context.task_type: 354 tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id) 355 else: 356 tmpdir = 'tmp' 357 358 if save_checkpoint_secs: 359 logging.warning('Collective ops may deadlock with ' 360 '`save_checkpoints_secs` please use ' 361 '`save_checkpoint_steps` instead. 
Clearing ' 362 '`save_checkpoint_secs` and setting ' 363 '`save_checkpoint_steps` to 1000 now.') 364 save_checkpoint_secs = None 365 save_checkpoint_steps = 1000 366 if save_summaries_secs: 367 logging.warning('Collective ops may run out of sync with' 368 '`save_summaries_secs`, please use ' 369 '`save_summaries_steps` instead.') 370 else: 371 tmpdir = None 372 373 summary_dir = summary_dir or checkpoint_dir 374 if summary_dir and log_step_count_steps and log_step_count_steps > 0: 375 if worker_context.should_save_summary: 376 all_hooks.append( 377 basic_session_run_hooks.StepCounterHook( 378 output_dir=summary_dir, every_n_steps=log_step_count_steps)) 379 elif tmpdir: 380 all_hooks.append( 381 basic_session_run_hooks.StepCounterHook( 382 output_dir=os.path.join(summary_dir, tmpdir), 383 every_n_steps=log_step_count_steps)) 384 385 if (((save_summaries_steps and save_summaries_steps > 0) or 386 (save_summaries_secs and save_summaries_secs > 0)) and summary_dir): 387 if worker_context.should_save_summary: 388 all_hooks.append( 389 basic_session_run_hooks.SummarySaverHook( 390 scaffold=scaffold, 391 save_steps=save_summaries_steps, 392 save_secs=save_summaries_secs, 393 output_dir=summary_dir)) 394 elif tmpdir: 395 all_hooks.append( 396 basic_session_run_hooks.SummarySaverHook( 397 scaffold=scaffold, 398 save_steps=save_summaries_steps, 399 save_secs=save_summaries_secs, 400 output_dir=os.path.join(summary_dir, tmpdir))) 401 402 if (((save_checkpoint_secs and save_checkpoint_secs > 0) or 403 (save_checkpoint_steps and save_checkpoint_steps > 0)) and 404 checkpoint_dir): 405 if worker_context.should_checkpoint: 406 all_hooks.append( 407 basic_session_run_hooks.CheckpointSaverHook( 408 checkpoint_dir, 409 save_steps=save_checkpoint_steps, 410 save_secs=save_checkpoint_secs, 411 scaffold=scaffold, 412 save_graph_def=save_graph_def)) 413 elif tmpdir: 414 all_hooks.append( 415 basic_session_run_hooks.CheckpointSaverHook( 416 os.path.join(checkpoint_dir, tmpdir), 417 
save_steps=save_checkpoint_steps, 418 save_secs=save_checkpoint_secs, 419 scaffold=scaffold, 420 save_graph_def=save_graph_def)) 421 422 logging.info('all_hooks %r', all_hooks) 423 session_creator = worker_context.session_creator( 424 scaffold, 425 config=config, 426 checkpoint_dir=checkpoint_dir, 427 max_wait_secs=max_wait_secs) 428 return MonitoredSession( 429 session_creator=session_creator, 430 hooks=all_hooks, 431 stop_grace_period_secs=stop_grace_period_secs) 432 433 434@tf_export(v1=['train.MonitoredTrainingSession']) 435def MonitoredTrainingSession( 436 master='', # pylint: disable=invalid-name 437 is_chief=True, 438 checkpoint_dir=None, 439 scaffold=None, 440 hooks=None, 441 chief_only_hooks=None, 442 save_checkpoint_secs=USE_DEFAULT, 443 save_summaries_steps=USE_DEFAULT, 444 save_summaries_secs=USE_DEFAULT, 445 config=None, 446 stop_grace_period_secs=120, 447 log_step_count_steps=100, 448 max_wait_secs=7200, 449 save_checkpoint_steps=USE_DEFAULT, 450 summary_dir=None, 451 save_graph_def=True): 452 """Creates a `MonitoredSession` for training. 453 454 For a chief, this utility sets proper session initializer/restorer. It also 455 creates hooks related to checkpoint and summary saving. For workers, this 456 utility sets proper session creator which waits for the chief to 457 initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for 458 more 459 information. 460 461 462 Args: 463 master: `String` the TensorFlow master to use. 464 is_chief: If `True`, it will take care of initialization and recovery the 465 underlying TensorFlow session. If `False`, it will wait on a chief to 466 initialize or recover the TensorFlow session. 467 checkpoint_dir: A string. Optional path to a directory where to restore 468 variables. 469 scaffold: A `Scaffold` used for gathering or building supportive ops. If not 470 specified, a default one is created. It's used to finalize the graph. 471 hooks: Optional list of `SessionRunHook` objects. 
472 chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if 473 `is_chief==True`, ignore otherwise. 474 save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved 475 using a default checkpoint saver. If both `save_checkpoint_steps` and 476 `save_checkpoint_secs` are set to `None`, then the default checkpoint 477 saver isn't used. If both are provided, then only `save_checkpoint_secs` 478 is used. Default 600. 479 save_summaries_steps: The frequency, in number of global steps, that the 480 summaries are written to disk using a default summary saver. If both 481 `save_summaries_steps` and `save_summaries_secs` are set to `None`, then 482 the default summary saver isn't used. Default 100. 483 save_summaries_secs: The frequency, in secs, that the summaries are written 484 to disk using a default summary saver. If both `save_summaries_steps` and 485 `save_summaries_secs` are set to `None`, then the default summary saver 486 isn't used. Default not enabled. 487 config: an instance of `tf.compat.v1.ConfigProto` proto used to configure 488 the session. It's the `config` argument of constructor of 489 `tf.compat.v1.Session`. 490 stop_grace_period_secs: Number of seconds given to threads to stop after 491 `close()` has been called. 492 log_step_count_steps: The frequency, in number of global steps, that the 493 global step/sec is logged. 494 max_wait_secs: Maximum time workers should wait for the session to become 495 available. This should be kept relatively short to help detect incorrect 496 code, but sometimes may need to be increased if the chief takes a while to 497 start up. 498 save_checkpoint_steps: The frequency, in number of global steps, that a 499 checkpoint is saved using a default checkpoint saver. If both 500 `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then 501 the default checkpoint saver isn't used. If both are provided, then only 502 `save_checkpoint_secs` is used. Default not enabled. 
503 summary_dir: A string. Optional path to a directory where to save 504 summaries. If None, checkpoint_dir is used instead. 505 save_graph_def: Whether to save the GraphDef and MetaGraphDef to 506 `checkpoint_dir`. The GraphDef is saved after the session is created as 507 `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as 508 `model.ckpt-*.meta`. 509 510 Returns: 511 A `MonitoredSession` object. 512 """ 513 if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT: 514 save_summaries_steps = 100 515 save_summaries_secs = None 516 elif save_summaries_secs == USE_DEFAULT: 517 save_summaries_secs = None 518 elif save_summaries_steps == USE_DEFAULT: 519 save_summaries_steps = None 520 521 if (save_checkpoint_steps == USE_DEFAULT and 522 save_checkpoint_secs == USE_DEFAULT): 523 save_checkpoint_steps = None 524 save_checkpoint_secs = 600 525 elif save_checkpoint_secs == USE_DEFAULT: 526 save_checkpoint_secs = None 527 elif save_checkpoint_steps == USE_DEFAULT: 528 save_checkpoint_steps = None 529 530 scaffold = scaffold or Scaffold() 531 worker_context = distribute_coordinator_context.get_current_worker_context() 532 533 if worker_context: 534 return _create_monitored_session_with_worker_context( 535 worker_context, 536 scaffold, 537 checkpoint_dir=checkpoint_dir, 538 hooks=hooks, 539 chief_only_hooks=chief_only_hooks, 540 save_checkpoint_secs=save_checkpoint_secs, 541 save_summaries_steps=save_summaries_steps, 542 save_summaries_secs=save_summaries_secs, 543 config=config, 544 stop_grace_period_secs=stop_grace_period_secs, 545 log_step_count_steps=log_step_count_steps, 546 max_wait_secs=max_wait_secs, 547 save_checkpoint_steps=save_checkpoint_steps, 548 summary_dir=summary_dir, 549 save_graph_def=save_graph_def) 550 551 if not is_chief: 552 session_creator = WorkerSessionCreator( 553 scaffold=scaffold, 554 master=master, 555 config=config, 556 max_wait_secs=max_wait_secs) 557 return MonitoredSession( 558 
session_creator=session_creator, 559 hooks=hooks or [], 560 stop_grace_period_secs=stop_grace_period_secs) 561 562 all_hooks = [] 563 if chief_only_hooks: 564 all_hooks.extend(chief_only_hooks) 565 session_creator = ChiefSessionCreator( 566 scaffold=scaffold, 567 checkpoint_dir=checkpoint_dir, 568 master=master, 569 config=config) 570 571 summary_dir = summary_dir or checkpoint_dir 572 if summary_dir: 573 if log_step_count_steps and log_step_count_steps > 0: 574 all_hooks.append( 575 basic_session_run_hooks.StepCounterHook( 576 output_dir=summary_dir, every_n_steps=log_step_count_steps)) 577 578 if (save_summaries_steps and 579 save_summaries_steps > 0) or (save_summaries_secs and 580 save_summaries_secs > 0): 581 all_hooks.append( 582 basic_session_run_hooks.SummarySaverHook( 583 scaffold=scaffold, 584 save_steps=save_summaries_steps, 585 save_secs=save_summaries_secs, 586 output_dir=summary_dir)) 587 588 if checkpoint_dir: 589 if (save_checkpoint_secs and 590 save_checkpoint_secs > 0) or (save_checkpoint_steps and 591 save_checkpoint_steps > 0): 592 all_hooks.append( 593 basic_session_run_hooks.CheckpointSaverHook( 594 checkpoint_dir, 595 save_steps=save_checkpoint_steps, 596 save_secs=save_checkpoint_secs, 597 scaffold=scaffold, 598 save_graph_def=save_graph_def)) 599 600 if hooks: 601 all_hooks.extend(hooks) 602 return MonitoredSession( 603 session_creator=session_creator, 604 hooks=all_hooks, 605 stop_grace_period_secs=stop_grace_period_secs) 606 607 608@tf_export(v1=['train.SessionCreator']) 609@six.add_metaclass(abc.ABCMeta) 610class SessionCreator(object): 611 """A factory for tf.Session.""" 612 613 @abc.abstractmethod 614 def create_session(self): 615 raise NotImplementedError( 616 'create_session is not implemented for {}.'.format(self)) 617 618 619@tf_export(v1=['train.ChiefSessionCreator']) 620class ChiefSessionCreator(SessionCreator): 621 """Creates a tf.compat.v1.Session for a chief.""" 622 623 def __init__(self, 624 scaffold=None, 625 master='', 626 
config=None, 627 checkpoint_dir=None, 628 checkpoint_filename_with_path=None): 629 """Initializes a chief session creator. 630 631 Args: 632 scaffold: A `Scaffold` used for gathering or building supportive ops. If 633 not specified a default one is created. It's used to finalize the graph. 634 master: `String` representation of the TensorFlow master to use. 635 config: `ConfigProto` proto used to configure the session. 636 checkpoint_dir: A string. Optional path to a directory where to restore 637 variables. 638 checkpoint_filename_with_path: Full file name path to the checkpoint file. 639 """ 640 self._checkpoint_dir = checkpoint_dir 641 self._checkpoint_filename_with_path = checkpoint_filename_with_path 642 self._scaffold = scaffold or Scaffold() 643 self._session_manager = None 644 self._master = master 645 self._config = config 646 647 def _get_session_manager(self): 648 """Gets or creates a SessionManager.""" 649 if self._session_manager: 650 return self._session_manager 651 652 self._session_manager = sm.SessionManager( 653 local_init_op=self._scaffold.local_init_op, 654 local_init_feed_dict=self._scaffold.local_init_feed_dict, 655 ready_op=self._scaffold.ready_op, 656 ready_for_local_init_op=self._scaffold.ready_for_local_init_op, 657 graph=ops.get_default_graph()) 658 return self._session_manager 659 660 def create_session(self): 661 self._scaffold.finalize() 662 return self._get_session_manager().prepare_session( 663 self._master, 664 saver=self._scaffold.saver, 665 checkpoint_dir=self._checkpoint_dir, 666 checkpoint_filename_with_path=self._checkpoint_filename_with_path, 667 config=self._config, 668 init_op=self._scaffold.init_op, 669 init_feed_dict=self._scaffold.init_feed_dict, 670 init_fn=self._scaffold.init_fn) 671 672 673@tf_export(v1=['train.WorkerSessionCreator']) 674class WorkerSessionCreator(SessionCreator): 675 """Creates a tf.compat.v1.Session for a worker.""" 676 677 def __init__(self, 678 scaffold=None, 679 master='', 680 config=None, 681 
max_wait_secs=30 * 60): 682 """Initializes a worker session creator. 683 684 Args: 685 scaffold: A `Scaffold` used for gathering or building supportive ops. If 686 not specified a default one is created. It's used to finalize the graph. 687 master: `String` representation of the TensorFlow master to use. 688 config: `ConfigProto` proto used to configure the session. 689 max_wait_secs: Maximum time to wait for the session to become available. 690 """ 691 self._scaffold = scaffold or Scaffold() 692 self._session_manager = None 693 self._master = master 694 self._config = config 695 self._max_wait_secs = max_wait_secs 696 697 def _get_session_manager(self): 698 """Gets or creates a SessionManager.""" 699 if self._session_manager: 700 return self._session_manager 701 702 self._session_manager = sm.SessionManager( 703 local_init_op=self._scaffold.local_init_op, 704 local_init_feed_dict=self._scaffold.local_init_feed_dict, 705 ready_op=self._scaffold.ready_op, 706 ready_for_local_init_op=self._scaffold.ready_for_local_init_op, 707 graph=ops.get_default_graph()) 708 return self._session_manager 709 710 def create_session(self): 711 self._scaffold.finalize() 712 return self._get_session_manager().wait_for_session( 713 self._master, config=self._config, max_wait_secs=self._max_wait_secs) 714 715 716class _MonitoredSession(object): 717 """See `MonitoredSession` or `SingularMonitoredSession`.""" 718 719 def __init__(self, 720 session_creator, 721 hooks, 722 should_recover, 723 stop_grace_period_secs=120): 724 """Sets up a Monitored or Hooked Session. 725 726 Args: 727 session_creator: A factory object to create session. Typically a 728 `ChiefSessionCreator` or a `WorkerSessionCreator`. 729 hooks: An iterable of `SessionRunHook' objects. 730 should_recover: A bool. Indicates whether to recover from `AbortedError` 731 and `UnavailableError` or not. 732 stop_grace_period_secs: Number of seconds given to threads to stop after 733 `close()` has been called. 
    """
    # Remember whether the caller had already finalized the graph, so that
    # `_close_internal()` only un-finalizes graphs this class finalized itself.
    self._graph_was_finalized = ops.get_default_graph().finalized
    self._hooks = hooks or []
    # `begin()` lets each hook add ops before the graph is finalized.
    for h in self._hooks:
      h.begin()

    # Under distribute coordinator, defer session creation to the worker
    # context unless the caller supplied an explicit creator.
    worker_context = distribute_coordinator_context.get_current_worker_context()
    if not session_creator and worker_context:
      session_creator = worker_context.session_creator()

    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
        hooks=self._hooks,
        stop_grace_period_secs=stop_grace_period_secs)
    if should_recover:
      # `_RecoverableSession` recreates the session on `_PREEMPTION_ERRORS`
      # (`AbortedError` / `UnavailableError`).
      self._sess = _RecoverableSession(self._coordinated_creator)
    else:
      self._sess = self._coordinated_creator.create_session()

  @property
  def graph(self):
    """The graph that was launched in this session, or None if closed."""
    if self._tf_sess() is None:
      return None
    return self._tf_sess().graph

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run ops in the monitored session.

    This method is completely compatible with the `tf.Session.run()` method.

    Args:
      fetches: Same as `tf.Session.run()`.
      feed_dict: Same as `tf.Session.run()`.
      options: Same as `tf.Session.run()`.
      run_metadata: Same as `tf.Session.run()`.

    Returns:
      Same as `tf.Session.run()`.
    """
    return self._sess.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

  def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`. The function may use methods of the argument to perform
        computations with access to a raw session. The returned value of the
        `step_fn` will be returned from `run_step_fn`, unless a stop is
        requested. In that case, the next `should_stop` call will return True.
        Example usage:
        ```python
        with tf.Graph().as_default():
          c = tf.compat.v1.placeholder(dtypes.float32)
          v = tf.add(c, 4.0)
          w = tf.add(c, 0.5)
          def step_fn(step_context):
            a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
            if a <= 4.5:
              step_context.request_stop()
            return step_context.run_with_hooks(fetches=w,
                                               feed_dict={c: 0.1})

          with tf.MonitoredSession() as session:
            while not session.should_stop():
              a = session.run_step_fn(step_fn)
        ```
        Hooks interact with the `run_with_hooks()` call inside the
        `step_fn` as they do with a `MonitoredSession.run` call.

    Returns:
      Returns the returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propage from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext(object):
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

    Users of `step_fn` may perform `run()` calls without running hooks
    by accessing the `session`.  A `run()` call with hooks may be performed
    using `run_with_hooks()`.  Computation flow can be interrupted using
    `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.compat.v1.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      """The raw `tf.compat.v1.Session`; `run()` calls on it skip hooks."""
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

      Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    """Returns True if the session is closed or a wrapper requested a stop."""
    return self._sess is None or self._sess.should_stop()

  def close(self):
    """Closes the session; hooks' `end()` is called since no exception set."""
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    # `OutOfRangeError` (end of input) and `StopIteration` (request_stop) are
    # normal loop termination, not failures: clear them before closing.
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for _CoordinatedSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      # `coord` and `tf_sess` are public so the owning `_MonitoredSession`
      # can reset them in `_close_internal()`.
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      # Wrap order: hooks run innermost, coordinator checks outermost.
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    """Closes the session, notifying hooks unless exiting with an exception."""
    try:
      # `hook.end()` is only called on clean shutdown.
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        # Always drop references, even if close() raised.
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        # Only un-finalize if this class finalized the graph itself.
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return underlying tf.compat.v1.Session object.

    Warning: accessing the returned object in user code is likely to cause races
    or "flaky tests".

    Returns:
      A tf.compat.v1.Session object.
    """
    return self._coordinated_creator.tf_sess
@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does following things
  in given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * create session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns result of `session.run()` asked by user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the run() call again

  Exit: At the `close()`, the monitored session does following things in order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses `OutOfRange` error which indicates that all inputs have been
    processed if the monitored_session is used as a context

  How to set `tf.compat.v1.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
      session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In distributed setting for a non-chief worker, you can use following:

  ```python
  MonitoredSession(
      session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or worker.

  Note: This is not a `tf.compat.v1.Session`. For example, it cannot do
  following:

  * it cannot be set as default session.
  * it cannot be sent to saver.save.
  * it cannot be sent to tf.train.start_queue_runners.

  Args:
    session_creator: A factory object to create session. Typically a
      `ChiefSessionCreator` which is the default one.
    hooks: An iterable of `SessionRunHook' objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    # `should_recover=True` is what distinguishes this class from
    # `SingularMonitoredSession`: preemption errors recreate the session.
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
  The differences between `MonitoredSession` and `SingularMonitoredSession`
  are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.compat.v1.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
    private. This can be used:
      - To `run` without hooks.
      - To save and restore.
  * All other functionality is identical.

  Example usage:
  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the hooked session does following things
  in given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * create session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns result of `session.run()` asked by user

  Exit: At the `close()`, the hooked session does following things in order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses `OutOfRange` error which indicates that all inputs have been
    processed if the `SingularMonitoredSession` is used as a context.
  """
  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook' objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    # Always runs as chief: build a ChiefSessionCreator from the given args.
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    # `should_recover=False`: preemption errors are not caught here, unlike
    # `MonitoredSession`.
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
    """Returns underlying `TensorFlow.Session` object."""
    return self._tf_sess()


class _WrappedSession(object):
  """Wrapper around a `tf.compat.v1.Session`.

  This wrapper is used as a base class for various session wrappers
  that provide additional functionality such as monitoring, coordination,
  and recovery.

  In addition to the methods exported by `SessionInterface` the wrapper
  provides a method to check for stop and never raises exceptions from
  calls to `close()`.
  """

  def __init__(self, sess):
    """Creates a `_WrappedSession`.

    Args:
      sess: A `tf.compat.v1.Session` or `_WrappedSession` object.  The wrapped
        session.
    """
    self._sess = sess
    # If the wrapped object is itself a wrapper, its `should_stop()` is
    # consulted too (see `should_stop` below).
    self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)

  @property
  def graph(self):
    return self._sess.graph

  @property
  def sess_str(self):
    return self._sess.sess_str
  def should_stop(self):
    """Return true if this session should not be used anymore.

    Always return True if the session was closed.

    Returns:
      True if the session should stop, False otherwise.
    """
    if self._check_stop():
      return True
    if self._sess:
      # Only delegate deeper when the wrapped object is itself a wrapper;
      # a raw tf.Session has no `should_stop()`.
      return self._wrapped_is_stoppable and self._sess.should_stop()
    return True

  def _check_stop(self):
    """Hook for subclasses to provide their own stop condition.

    Returns:
      True if the session should stop, False otherwise.
    """
    return False

  def close(self):
    """Closes the wrapped session; never raises on preemption errors."""
    if self._sess:
      try:
        self._sess.close()
      except _PREEMPTION_ERRORS as e:
        logging.error(
            'An error occurred when attempting to close the '
            'session. This may be due to a preemption in a '
            'connected worker or parameter server. Error: %s', e)
      finally:
        self._sess = None

  def run(self, *args, **kwargs):
    return self._sess.run(*args, **kwargs)

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
    # It is `None` when called from `_CoordinatedSession`. In that case
    # `self.run` is `_CoordinatedSession.run`.
    run_with_hooks = run_with_hooks or self.run
    return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))


class _RecoverableSession(_WrappedSession):
  """A wrapped session that recreates a session upon certain kinds of errors.

  The constructor is passed a SessionCreator object, not a session.

  Calls to `run()` are delegated to the wrapped session.  If a call raises the
  exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the
  wrapped session is closed, and a new one is created by calling the factory
  again.
  """

  def __init__(self, sess_creator):
    """Create a new `_RecoverableSession`.

    The value returned by calling `sess_creator.create_session()` will be the
    session wrapped by this recoverable session.

    Args:
      sess_creator: A 'SessionCreator' to be wrapped by recoverable.
    """
    self._sess_creator = sess_creator
    _WrappedSession.__init__(self, self._create_session())

  def _create_session(self):
    # Retries indefinitely: only returns once a session is created without a
    # preemption error. Any non-preemption exception propagates.
    while True:
      try:
        return self._sess_creator.create_session()
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised while a session was being created. '
            'This may be due to a preemption of a connected worker '
            'or parameter server. A new session will be created. '
            'This error may also occur due to a gRPC failure caused '
            'by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)

  def _check_stop(self):
    try:
      if self._sess:
        return self._sess._check_stop()  # pylint: disable=protected-access
      else:
        return True
    except _PREEMPTION_ERRORS as e:
      logging.info(
          'An error was raised while considering whether the '
          'session is complete. This may be due to a preemption in '
          'a connected worker or parameter server. The current '
          'session will be closed and a new session will be '
          'created. This error may also occur due to a gRPC failure '
          'caused by high memory or network bandwidth usage in the '
          'parameter servers. If this error occurs repeatedly, try '
          'increasing the number of parameter servers assigned to '
          'the job. Error: %s', e)
      self.close()
      self._sess = self._create_session()
      # Since we have just recreated the session, the overall computation should
      # not stop:
      return False
    except Exception:  # pylint: disable=broad-except
      # `should_stop` should return True instead of raising an exception.
      return True

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    # Loop until a run succeeds; on preemption, drop the session and let the
    # next iteration create a fresh one.
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()
        return self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    # Same recovery loop as `run`, but for the step-function entry point.
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()

        # Bind hooks-aware run to the (possibly recreated) inner session.
        run_with_hooks = self._sess.run
        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None
class _CoordinatedSession(_WrappedSession):
  """A wrapped session that works with a `tf.Coordinator`.

  Calls to `run()` are delegated to the wrapped session. If a call
  raises an exception, the exception is reported to the coordinator.

  In addition, after each call to `run()` this session ask the coordinator if
  the session should stop. In that case it will join all the threads
  registered with the coordinator before returning.

  If the coordinator was requested to stop with an exception, that exception
  will be re-raised from the call to `run()`.
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.compat.v1.Session` object.  The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    _WrappedSession.__init__(self, sess)
    self._coord = coord
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
    """Stops coordinated threads, then closes the wrapped session."""
    self._coord.request_stop()
    try:
      self._coord.join(
          stop_grace_period_secs=self._stop_grace_period_secs,
          ignore_live_threads=True)
    finally:
      try:
        _WrappedSession.close(self)
      except Exception:  # pylint: disable=broad-except
        # We intentionally suppress exceptions from the close() here since
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS:
      raise
    except Exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the `original_exc_info`.
      original_exc_info = sys.exc_info()
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS:
        raise
      except Exception:  # pylint: disable=broad-except
        # `six.reraise` always raises, so no `raise` statement is needed (and
        # `raise six.reraise(...)` would attempt to raise its None return
        # value if `reraise` ever returned).
        six.reraise(*original_exc_info)
      else:
        six.reraise(*original_exc_info)


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().

  The list of hooks to call is passed in the constructor.  Before each call
  to `run()` the session calls the `before_run()` method of the hooks, which
  can return additional ops or tensors to run.  These are added to the arguments
  of the call to `run()`.

  When the `run()` call finishes, the session calls the `after_run()` methods of
  the hooks, passing the values returned by the `run()` call corresponding to
  the ops and tensors that each hook requested.

  If any call to the hooks, requests stop via run_context the session will be
  marked as needing to stop and its `should_stop()` method will now return
  `True`.
  """

  def __init__(self, sess, hooks):
    """Initializes a _HookedSession object.

    Args:
      sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
      hooks: An iterable of `SessionRunHook' objects.
    """

    _WrappedSession.__init__(self, sess)
    self._hooks = hooks
    # Set to True when any hook requests a stop via its run_context.
    self._should_stop = False
  def _check_stop(self):
    """See base class."""
    return self._should_stop

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class."""
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    # The user's fetches are wrapped under the 'caller' key so hook fetches
    # can be added alongside without key collisions.
    actual_fetches = {'caller': fetches}

    run_context = session_run_hook.SessionRunContext(
        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
        session=self._sess)

    options = options or config_pb2.RunOptions()
    feed_dict = self._call_hook_before_run(run_context, actual_fetches,
                                           feed_dict, options)

    # Do session run.
    run_metadata = run_metadata or config_pb2.RunMetadata()
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    # Each hook receives only the results of the fetches it requested
    # (keyed by the hook object itself), or None if it requested none.
    for hook in self._hooks:
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=outputs[hook] if hook in outputs else None,
              options=options,
              run_metadata=run_metadata))
    # Once any hook requests a stop, the flag stays set.
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    """Calls hooks.before_run and handles requests from hooks."""
    hook_feeds = {}
    for hook in self._hooks:
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          # Keyed by the hook instance so after_run can find its results.
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          # Two hooks must not feed the same tensor.
          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
                                          'Same tensor is fed by two hooks.')
          hook_feeds.update(request.feed_dict)
        if request.options:
          # Merges into `options` in place.
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds

    # Hooks and the user must not feed the same tensor either.
    self._raise_if_feeds_intersects(
        user_feed_dict, hook_feeds,
        'Same tensor is fed by a SessionRunHook and user.')
    hook_feeds.update(user_feed_dict)
    return hook_feeds

  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
    """Raises RuntimeError if the two feed dicts share any key."""
    intersection = set(feeds1.keys()) & set(feeds2.keys())
    if intersection:
      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))

  def _merge_run_options(self, options, incoming_options):
    """Merge two instances of RunOptions into the first one.

    During the merger, the numerical fields including trace_level,
    timeout_in_ms, inter_op_thread_pool are set to the larger one of the two.
    The boolean value is set to the logical OR of the two.
    debug_tensor_watch_opts of the original options is extended with that from
    the incoming one.

    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    # max() on bools acts as logical OR here.
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)