# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper of the Session API that runs hooks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import os
import sys

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

# The list of exceptions that we should recover from. Exceptions not in this
# list may terminate the job.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

# Value that indicates no value was provided.
USE_DEFAULT = object()


@tf_export(v1=['train.Scaffold'])
class Scaffold(object):
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph. The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if needed.

  If you call the scaffold constructor without any arguments, it will pick
  pieces from the collections, creating default ones if needed when
  `scaffold.finalize()` is called. You can pass arguments to the constructor to
  provide your own pieces. Pieces that you pass to the constructor are not
  added to the graph collections.

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.train.Saver` object taking care of saving the variables.
    Picked from and stored into the `SAVERS` collection in the graph by
    default.
  * `init_op`: An op to run to initialize the variables. Picked from and
    stored into the `INIT_OP` collection in the graph by default.
  * `ready_op`: An op to verify that the variables are initialized. Picked
    from and stored into the `READY_OP` collection in the graph by default.
  * `ready_for_local_init_op`: An op to verify that global state has been
    initialized and it is alright to run `local_init_op`. Picked from and
    stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by
    default. This is needed when the initialization of local variables depends
    on the values of global variables.
  * `local_init_op`: An op to initialize the local variables. Picked from and
    stored into the `LOCAL_INIT_OP` collection in the graph by default.
  * `summary_op`: An op to run and merge the summaries in the graph. Picked
    from and stored into the `SUMMARY_OP` collection in the graph by default.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
    running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations. The callable will be called as
    `init_fn(scaffold, session)`.
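
  Example (an illustrative sketch; `my_assign_op` and `train_op` are
  placeholders for ops defined by the surrounding model code):

  ```python
  scaffold = tf.train.Scaffold(
      init_fn=lambda scaffold, session: session.run(my_assign_op))
  with tf.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```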
  """

  def __init__(self,
               init_op=None,
               init_feed_dict=None,
               init_fn=None,
               ready_op=None,
               ready_for_local_init_op=None,
               local_init_op=None,
               summary_op=None,
               saver=None,
               copy_from_scaffold=None):
    """Create a scaffold.

    Args:
      init_op: Optional op for initializing variables.
      init_feed_dict: Optional session feed dictionary to use when running the
        init_op.
      init_fn: Optional function to use to initialize the model after running
        the init_op. Will be called as `init_fn(scaffold, session)`.
      ready_op: Optional op to verify that the variables are initialized. Must
        return an empty 1D string tensor when the variables are initialized,
        or a non-empty 1D string tensor listing the names of the
        non-initialized variables.
      ready_for_local_init_op: Optional op to verify that the global variables
        are initialized and `local_init_op` can be run. Must return an empty
        1D string tensor when the global variables are initialized, or a
        non-empty 1D string tensor listing the names of the non-initialized
        global variables.
      local_init_op: Optional op to initialize local variables.
      summary_op: Optional op to gather all summaries. Must return a scalar
        string tensor containing a serialized `Summary` proto.
      saver: Optional `tf.train.Saver` object to use to save and restore
        variables. May also be a `tf.train.Checkpoint` object, in which case
        object-based checkpoints are saved. This will also load some
        object-based checkpoints saved from elsewhere, but that loading may be
        fragile since it uses fixed keys rather than performing a full
        graph-based match. For example, if a variable has two paths from the
        `Checkpoint` object because two `Model` objects share the `Layer`
        object that owns it, removing one `Model` may change the keys and
        break checkpoint loading through this API, whereas a graph-based match
        would match the variable through the other `Model`.
      copy_from_scaffold: Optional scaffold object to copy fields from. Its
        fields will be overwritten by the fields explicitly provided to this
        constructor.
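
    Example (an illustrative sketch; `old_scaffold` and `my_init_fn` are
    placeholders): reusing an existing scaffold while overriding its
    `init_fn`:

    ```python
    new_scaffold = Scaffold(init_fn=my_init_fn,
                            copy_from_scaffold=old_scaffold)
    ```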
    """
    if copy_from_scaffold is not None:
      if not isinstance(copy_from_scaffold, Scaffold):
        raise TypeError('copy_from_scaffold is not a Scaffold instance.')
      # We need `coalesce` since a Tensor is not converted to bool
      # automatically, so the common idiom of (a or b) does not work.
      coalesce = lambda a, b: a if a is not None else b
      init_op = coalesce(init_op, copy_from_scaffold.init_op)
      init_feed_dict = coalesce(init_feed_dict,
                                copy_from_scaffold.init_feed_dict)
      # Use the original init_fn provided by the user to init the new
      # Scaffold.
      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
      ready_for_local_init_op = coalesce(
          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
      saver = coalesce(saver, copy_from_scaffold.saver)

    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver. Is there a better way?
    self._user_init_fn = init_fn
    if init_fn:
      self._init_fn = lambda sess: init_fn(self, sess)
    else:
      self._init_fn = None

    self._init_op = init_op
    self._init_feed_dict = init_feed_dict
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._local_init_op = local_init_op
    self._summary_op = summary_op
    self._saver = saver

  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:

      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()))

      self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP,
                                              default_init_op)
    if self._ready_op is None:

      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)

      self._ready_op = Scaffold.get_or_default(
          'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
    if self._ready_for_local_init_op is None:

      def default_ready_for_local_init_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(
                variables.global_variables()),
            resources.report_uninitialized_resources(
                resources.shared_resources())
        ], 0)

      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold.default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default(
          'summary_op', ops.GraphKeys.SUMMARY_OP, summary.merge_all)
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
          sharded=True)
    else:
      self._saver.build()
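
    # All pieces are now in place; freeze the graph so that no new ops can be
    # added to it from this point on.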
    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self

  @property
  def init_fn(self):
    return self._init_fn

  @property
  def init_op(self):
    return self._init_op

  @property
  def ready_op(self):
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def local_init_op(self):
    return self._local_init_op

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def saver(self):
    return self._saver

  @property
  def init_feed_dict(self):
    return self._init_feed_dict

  @staticmethod
  def get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError(
            'More than one item in the collection "%s". '
            'Please indicate which one to use by passing it to '
            'the tf.Scaffold constructor as: '
            'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def default_local_init_op():
    """Returns an op that groups the default local init ops.

    This op is used during session initialization when a Scaffold is
    initialized without specifying the local_init_op arg. It includes
    `tf.local_variables_initializer`, `tf.tables_initializer`, and also
    initializes local session resources.

    Returns:
      The default Scaffold local init op.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()))


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)
  if chief_only_hooks and worker_context.is_chief:
    all_hooks.extend(chief_only_hooks)

  # We need to call save or summary ops on all workers since these ops may
  # contain collective ops; only running save ops on some workers would make
  # collective ops hang. Therefore, on those workers that don't need to
  # actually write checkpoints or summaries, we let them write to a temp
  # directory.
  # pylint: disable=protected-access
  if type(worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
                                                 'MultiWorkerMirroredStrategy'):
    if worker_context.task_type:
      tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
    else:
      tmpdir = 'tmp'

    if save_checkpoint_secs:
      logging.warning('Collective ops may deadlock with '
                      '`save_checkpoint_secs`; please use '
                      '`save_checkpoint_steps` instead. Clearing '
                      '`save_checkpoint_secs` and setting '
                      '`save_checkpoint_steps` to 1000 now.')
      save_checkpoint_secs = None
      save_checkpoint_steps = 1000
    if save_summaries_secs:
      logging.warning('Collective ops may run out of sync with '
                      '`save_summaries_secs`; please use '
                      '`save_summaries_steps` instead.')
  else:
    tmpdir = None

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir and log_step_count_steps and log_step_count_steps > 0:
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=os.path.join(summary_dir, tmpdir),
              every_n_steps=log_step_count_steps))

  if (((save_summaries_steps and save_summaries_steps > 0) or
       (save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=os.path.join(summary_dir, tmpdir)))

  if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
       (save_checkpoint_steps and save_checkpoint_steps > 0)) and
      checkpoint_dir):
    if worker_context.should_checkpoint:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              os.path.join(checkpoint_dir, tmpdir),
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))

  logging.info('all_hooks %r', all_hooks)
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets a proper session creator which waits for the chief to
  initialize/restore. Please check `tf.train.MonitoredSession` for more
  information.

  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string. Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the
      graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks
      if `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
      saved using a default checkpoint saver. If both `save_checkpoint_steps`
      and `save_checkpoint_secs` are set to `None`, then the default
      checkpoint saver isn't used. If both are provided, then only
      `save_checkpoint_secs` is used. Default 600.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in seconds, that the summaries are
      written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default not enabled.
    config: an instance of `tf.ConfigProto` proto used to configure the
      session. It's the `config` argument of the constructor of `tf.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.
    max_wait_secs: Maximum time workers should wait for the session to become
      available. This should be kept relatively short to help detect incorrect
      code, but sometimes may need to be increased if the chief takes a while
      to start up.
    save_checkpoint_steps: The frequency, in number of global steps, that a
      checkpoint is saved using a default checkpoint saver. If both
      `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`,
      then the default checkpoint saver isn't used. If both are provided, then
      only `save_checkpoint_secs` is used. Default not enabled.
    summary_dir: A string. Optional path to a directory where to save
      summaries. If None, checkpoint_dir is used instead.

  Returns:
    A `MonitoredSession` object.
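
  Example (an illustrative sketch; `train_op` is a placeholder for the
  training op of the surrounding model code):

  ```python
  with tf.train.MonitoredTrainingSession(
      is_chief=True,
      checkpoint_dir='/tmp/my_model',
      save_checkpoint_secs=600,
      save_summaries_steps=100) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```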
  """
  if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
    save_summaries_steps = 100
    save_summaries_secs = None
  elif save_summaries_secs == USE_DEFAULT:
    save_summaries_secs = None
  elif save_summaries_steps == USE_DEFAULT:
    save_summaries_steps = None

  if (save_checkpoint_steps == USE_DEFAULT and
      save_checkpoint_secs == USE_DEFAULT):
    save_checkpoint_steps = None
    save_checkpoint_secs = 600
  elif save_checkpoint_secs == USE_DEFAULT:
    save_checkpoint_secs = None
  elif save_checkpoint_steps == USE_DEFAULT:
    save_checkpoint_steps = None

  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir:
    if log_step_count_steps and log_step_count_steps > 0:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))

    if (save_summaries_steps and
        save_summaries_steps > 0) or (save_summaries_secs and
                                      save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and
        save_checkpoint_secs > 0) or (save_checkpoint_steps and
                                      save_checkpoint_steps > 0):
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))

  if hooks:
    all_hooks.extend(hooks)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SessionCreator'])
@six.add_metaclass(abc.ABCMeta)
class SessionCreator(object):
  """A factory for tf.Session."""

  @abc.abstractmethod
  def create_session(self):
    raise NotImplementedError(
        'create_session is not implemented for {}.'.format(self))
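

# A custom `SessionCreator` only needs to implement `create_session()`. A
# minimal sketch (hypothetical; not part of this module's API):
#
#   class FixedSessionCreator(SessionCreator):
#     """Hands out a pre-built session instead of creating one."""
#
#     def __init__(self, sess):
#       self._sess = sess
#
#     def create_session(self):
#       return self._sess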


@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Initializes a chief session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint
        file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


class _MonitoredSession(object):
  """See `MonitoredSession` or `SingularMonitoredSession`."""

  def __init__(self,
               session_creator,
               hooks,
               should_recover,
               stop_grace_period_secs=120):
    """Sets up a Monitored or Hooked Session.

    Args:
      session_creator: A factory object to create session. Typically a
        `ChiefSessionCreator` or a `WorkerSessionCreator`.
      hooks: An iterable of `SessionRunHook` objects.
      should_recover: A bool. Indicates whether to recover from `AbortedError`
        and `UnavailableError` or not.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    self._graph_was_finalized = ops.get_default_graph().finalized
    self._hooks = hooks or []
    for h in self._hooks:
      h.begin()

    worker_context = distribute_coordinator_context.get_current_worker_context()
    if not session_creator and worker_context:
      session_creator = worker_context.session_creator()

    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
        hooks=self._hooks,
        stop_grace_period_secs=stop_grace_period_secs)
    if should_recover:
      self._sess = _RecoverableSession(self._coordinated_creator)
    else:
      self._sess = self._coordinated_creator.create_session()

  @property
  def graph(self):
    """The graph that was launched in this session."""
    if self._tf_sess() is None:
      return None
    return self._tf_sess().graph

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run ops in the monitored session.

    This method is completely compatible with the `tf.Session.run()` method.

    Args:
      fetches: Same as `tf.Session.run()`.
      feed_dict: Same as `tf.Session.run()`.
      options: Same as `tf.Session.run()`.
      run_metadata: Same as `tf.Session.run()`.

    Returns:
      Same as `tf.Session.run()`.
    """
    return self._sess.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

  def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`. The function may use methods of the argument to perform
        computations with access to a raw session. The returned value of the
        `step_fn` will be returned from `run_step_fn`, unless a stop is
        requested. In that case, the next `should_stop` call will return True.

        Example usage:

        ```python
        with tf.Graph().as_default():
          c = tf.placeholder(dtypes.float32)
          v = tf.add(c, 4.0)
          w = tf.add(c, 0.5)

          def step_fn(step_context):
            a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
            if a <= 4.5:
              step_context.request_stop()
            return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

          with tf.MonitoredSession() as session:
            while not session.should_stop():
              a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the `step_fn`
        as they do with a `MonitoredSession.run` call.

    Returns:
      The returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`. It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either a `_RecoverableSession` or a
    # `_CoordinatedSession`.
    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext(object):
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

    Users of `step_fn` may perform `run()` calls without running hooks
    by accessing the `session`. A `run()` call with hooks may be performed
    using `run_with_hooks()`. Computation flow can be interrupted using
    `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

      Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    return self._sess is None or self._sess.should_stop()

  def close(self):
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for the _RecoverableSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want the coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    try:
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return the underlying tf.Session object.

    Warning: accessing the returned object in user code is likely to cause
    races or "flaky tests".

    Returns:
      A tf.Session object.
    """
    return self._coordinated_creator.tf_sess


@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does the following
  things in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does the following
  things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by the user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the run() call again

  Exit: At the `close()`, the monitored session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error which indicates that all inputs have
    been processed if the monitored_session is used as a context

  How to set `tf.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
      session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In a distributed setting for a non-chief worker, you can use the
    following:

  ```python
  MonitoredSession(
      session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or
  worker.

  Note: This is not a `tf.Session`. For example, it cannot do the following:

  * it cannot be set as default session.
  * it cannot be sent to saver.save.
  * it cannot be sent to tf.train.start_queue_runners.

  Args:
    session_creator: A factory object to create session. Typically a
      `ChiefSessionCreator` which is the default one.
    hooks: An iterable of `SessionRunHook` objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.train.MonitoredSession`. The
  differences between `MonitoredSession` and `SingularMonitoredSession` are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
    private. This can be used:
    - To `run` without hooks.
    - To save and restore.
  * All other functionality is identical.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the hooked session does the following
  things in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by the user

  Exit: At the `close()`, the hooked session does the following things in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error which indicates that all inputs have
    been processed if the `SingularMonitoredSession` is used as a context.
  """

  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook` objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
    """Returns the underlying `TensorFlow.Session` object."""
    return self._tf_sess()


class _WrappedSession(object):
  """Wrapper around a `tf.Session`.

  This wrapper is used as a base class for various session wrappers that
  provide additional functionality such as monitoring, coordination, and
  recovery.

  In addition to the methods exported by `SessionInterface`, the wrapper
  provides a method to check for stop and never raises exceptions from calls
  to `close()`.
  """

  def __init__(self, sess):
    """Creates a `_WrappedSession`.

    Args:
      sess: A `tf.Session` or `_WrappedSession` object. The wrapped session.
    """
    self._sess = sess
    self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)

  @property
  def graph(self):
    return self._sess.graph

  @property
  def sess_str(self):
    return self._sess.sess_str

  def should_stop(self):
    """Returns true if this session should not be used anymore.

    Always returns True if the session was closed.

    Returns:
      True if the session should stop, False otherwise.
    """
    if self._check_stop():
      return True
    if self._sess:
      return self._wrapped_is_stoppable and self._sess.should_stop()
    return True

  def _check_stop(self):
    """Hook for subclasses to provide their own stop condition.

    Returns:
      True if the session should stop, False otherwise.
    """
    return False

  def close(self):
    if self._sess:
      try:
        self._sess.close()
      except _PREEMPTION_ERRORS as e:
        logging.warning(
            'An error occurred when attempting to close the '
            'session. This may be due to a preemption in a '
            'connected worker or parameter server. Error: %s', e)
      finally:
        self._sess = None

  def run(self, *args, **kwargs):
    return self._sess.run(*args, **kwargs)

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
    # It is `None` when called from `_CoordinatedSession`. In that case
    # `self.run` is `_CoordinatedSession.run`.
    run_with_hooks = run_with_hooks or self.run
    return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))


class _RecoverableSession(_WrappedSession):
  """A wrapped session that recreates a session upon certain kinds of errors.

  The constructor is passed a SessionCreator object, not a session.

  Calls to `run()` are delegated to the wrapped session. If a call raises the
  exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the
  wrapped session is closed, and a new one is created by calling the factory
  again.
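
  For example (an illustrative sketch; this class is private and is normally
  constructed for you by `MonitoredSession`):

  ```python
  sess = _RecoverableSession(ChiefSessionCreator(master=..., config=...))
  sess.run(fetches)  # Recreates the session if a preemption error is raised.
  ```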
  """

  def __init__(self, sess_creator):
    """Create a new `_RecoverableSession`.

    The value returned by calling `sess_creator.create_session()` will be the
    session wrapped by this recoverable session.

    Args:
      sess_creator: A `SessionCreator` to be wrapped by the recoverable
        session.
    """
    self._sess_creator = sess_creator
    _WrappedSession.__init__(self, self._create_session())

  def _create_session(self):
    while True:
      try:
        return self._sess_creator.create_session()
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised while a session was being created. '
            'This may be due to a preemption of a connected worker '
            'or parameter server. A new session will be created. '
            'This error may also occur due to a gRPC failure caused '
            'by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)

  def _check_stop(self):
    try:
      if self._sess:
        return self._sess._check_stop()  # pylint: disable=protected-access
      else:
        return True
    except _PREEMPTION_ERRORS as e:
      logging.info(
          'An error was raised while considering whether the '
          'session is complete. This may be due to a preemption in '
          'a connected worker or parameter server. The current '
          'session will be closed and a new session will be '
          'created. This error may also occur due to a gRPC failure '
          'caused by high memory or network bandwidth usage in the '
          'parameter servers. If this error occurs repeatedly, try '
          'increasing the number of parameter servers assigned to '
          'the job. Error: %s', e)
      self.close()
      self._sess = self._create_session()
      # Since we have just recreated the session, the overall computation
      # should not stop:
      return False
    except Exception:  # pylint: disable=broad-except
      # `should_stop` should return True instead of raising an exception.
      return True

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()
        return self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()

        run_with_hooks = self._sess.run
        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None


class _CoordinatedSession(_WrappedSession):
  """A wrapped session that works with a `tf.Coordinator`.

  Calls to `run()` are delegated to the wrapped session. If a call
  raises an exception, the exception is reported to the coordinator.

  In addition, after each call to `run()` this session asks the coordinator
  whether the session should stop. In that case it will join all the threads
  registered with the coordinator before returning.

  If the coordinator was requested to stop with an exception, that exception
  will be re-raised from the call to `run()`.
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.Session` object. The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    _WrappedSession.__init__(self, sess)
    self._coord = coord
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
    self._coord.request_stop()
    try:
      self._coord.join(
          stop_grace_period_secs=self._stop_grace_period_secs,
          ignore_live_threads=True)
    finally:
      try:
        _WrappedSession.close(self)
      except Exception:  # pylint: disable=broad-except
        # We intentionally suppress exceptions from the close() here since
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS:
      raise
    except Exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the `original_exc_info`.
      original_exc_info = sys.exc_info()
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS:
        raise
      except Exception:  # pylint: disable=broad-except
        six.reraise(*original_exc_info)
      else:
        six.reraise(*original_exc_info)


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().

  The list of hooks to call is passed in the constructor. Before each call
  to `run()` the session calls the `before_run()` method of the hooks, which
  can return additional ops or tensors to run. These are added to the
  arguments of the call to `run()`.

  When the `run()` call finishes, the session calls the `after_run()` methods
  of the hooks, passing the values returned by the `run()` call corresponding
  to the ops and tensors that each hook requested.

  If any of the hooks requests a stop via `run_context`, the session will be
  marked as needing to stop and its `should_stop()` method will then return
  `True`.
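
  For example (an illustrative sketch; `loss` and `train_op` are placeholder
  tensors/ops from the surrounding model code):

  ```python
  class LogLossHook(session_run_hook.SessionRunHook):

    def before_run(self, run_context):
      # Ask the session to additionally fetch `loss`.
      return session_run_hook.SessionRunArgs(loss)

    def after_run(self, run_context, run_values):
      logging.info('loss = %s', run_values.results)

  hooked_sess = _HookedSession(sess, [LogLossHook()])
  hooked_sess.run(train_op)
  ```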
  """

  def __init__(self, sess, hooks):
    """Initializes a _HookedSession object.

    Args:
      sess: A `tf.Session` or a `_WrappedSession` object.
      hooks: An iterable of `SessionRunHook` objects.
    """

    _WrappedSession.__init__(self, sess)
    self._hooks = hooks
    self._should_stop = False

  def _check_stop(self):
    """See base class."""
    return self._should_stop

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class."""
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    actual_fetches = {'caller': fetches}

    run_context = session_run_hook.SessionRunContext(
        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
        session=self._sess)

    options = options or config_pb2.RunOptions()
    feed_dict = self._call_hook_before_run(run_context, actual_fetches,
                                           feed_dict, options)

    # Do session run.
    run_metadata = run_metadata or config_pb2.RunMetadata()
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    for hook in self._hooks:
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=outputs[hook] if hook in outputs else None,
              options=options,
              run_metadata=run_metadata))
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    """Calls hooks.before_run and handles requests from hooks."""
    hook_feeds = {}
    for hook in self._hooks:
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
                                          'Same tensor is fed by two hooks.')
          hook_feeds.update(request.feed_dict)
        if request.options:
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds

    self._raise_if_feeds_intersects(
        user_feed_dict, hook_feeds,
        'Same tensor is fed by a SessionRunHook and user.')
    hook_feeds.update(user_feed_dict)
    return hook_feeds

  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
    intersection = set(feeds1.keys()) & set(feeds2.keys())
    if intersection:
      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))

  def _merge_run_options(self, options, incoming_options):
    """Merge two instances of RunOptions into the first one.

    During the merge, the numerical fields including trace_level,
    timeout_in_ms, and inter_op_thread_pool are set to the larger one of the
    two. The boolean values are set to the logical OR of the two.
    debug_tensor_watch_opts of the original options is extended with that from
    the incoming one.

    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
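
    For example (an illustrative sketch of the merge semantics):

    ```python
    options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.NO_TRACE)
    incoming_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.FULL_TRACE)
    self._merge_run_options(options, incoming_options)
    # options.trace_level is now FULL_TRACE, the larger of the two values.
    ```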
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)