# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper of Session API which runs hooks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import os
import sys

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

# The list of exceptions that we should recover from. Exceptions not in this
# list may terminate the job.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

# Value that indicates no value was provided.
USE_DEFAULT = object()


@tf_export(v1=['train.Scaffold'])
class Scaffold(object):
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph.  The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if needed.

  If you call the scaffold constructor without any arguments, it will pick
  pieces from the collections, creating default ones if needed when
  `scaffold.finalize()` is called.  You can pass arguments to the constructor to
  provide your own pieces.  Pieces that you pass to the constructor are not
  added to the graph collections.

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
    variables.  Picked from and stored into the `SAVERS` collection in the
    graph by default.
  * `init_op`: An op to run to initialize the variables.  Picked from and
    stored into the `INIT_OP` collection in the graph by default.
  * `ready_op`: An op to verify that the variables are initialized.  Picked
    from and stored into the `READY_OP` collection in the graph by default.
  * `ready_for_local_init_op`: An op to verify that global state has been
    initialized and it is alright to run `local_init_op`.  Picked from and
    stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by
    default. This is needed when the initialization of local variables depends
    on the values of global variables.
  * `local_init_op`: An op to initialize the local variables.  Picked
    from and stored into the `LOCAL_INIT_OP` collection in the graph by default.
  * `summary_op`: An op to run and merge the summaries in the graph.  Picked
    from and stored into the `SUMMARY_OP` collection in the graph by default.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
    running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations.  The callable will be called as
    `init_fn(scaffold, session)`.

  """

  def __init__(self,
               init_op=None,
               init_feed_dict=None,
               init_fn=None,
               ready_op=None,
               ready_for_local_init_op=None,
               local_init_op=None,
               summary_op=None,
               saver=None,
               copy_from_scaffold=None,
               local_init_feed_dict=None):
    """Create a scaffold.

    Args:
      init_op: Optional op for initializing variables.
      init_feed_dict: Optional session feed dictionary to use when running the
        init_op.
      init_fn: Optional function to use to initialize the model after running
        the init_op.  Will be called as `init_fn(scaffold, session)`.
      ready_op: Optional op to verify that the variables are initialized.  Must
        return an empty 1D string tensor when the variables are initialized, or
        a non-empty 1D string tensor listing the names of the non-initialized
        variables.
      ready_for_local_init_op: Optional op to verify that the global variables
        are initialized and `local_init_op` can be run. Must return an empty 1D
        string tensor when the global variables are initialized, or a non-empty
        1D string tensor listing the names of the non-initialized global
        variables.
      local_init_op: Optional op to initialize local variables.
      summary_op: Optional op to gather all summaries.  Must return a scalar
        string tensor containing a serialized `Summary` proto.
      saver: Optional `tf.compat.v1.train.Saver` object to use to save and
        restore variables.  May also be a `tf.train.Checkpoint` object, in which
        case object-based checkpoints are saved. This will also load some
        object-based checkpoints saved from elsewhere, but that loading may be
        fragile since it uses fixed keys rather than performing a full
        graph-based match. For example, if a variable has two paths from the
        `Checkpoint` object because two `Model` objects share the `Layer` object
        that owns it, removing one `Model` may change the keys and break
        checkpoint loading through this API, whereas a graph-based match would
        match the variable through the other `Model`.
      copy_from_scaffold: Optional scaffold object to copy fields from. Its
        fields will be overwritten by the provided fields in this function.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.
    """
    if copy_from_scaffold is not None:
      if not isinstance(copy_from_scaffold, Scaffold):
        raise TypeError('copy_from_scaffold is not a Scaffold instance.')
      # We need `coalesce` since a Tensor is not converted to bool
      # automatically, so the common idiom of (a or b) does not work.
      coalesce = lambda a, b: a if a is not None else b
      init_op = coalesce(init_op, copy_from_scaffold.init_op)
      init_feed_dict = coalesce(init_feed_dict,
                                copy_from_scaffold.init_feed_dict)
      # Use the original init_fn provided by the user to init the new Scaffold.
      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
      ready_for_local_init_op = coalesce(
          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
      local_init_feed_dict = coalesce(local_init_feed_dict,
                                      copy_from_scaffold.local_init_feed_dict)
      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
      saver = coalesce(saver, copy_from_scaffold.saver)

    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver.  Is there a better way?
    self._user_init_fn = init_fn
    if init_fn:
      self._init_fn = lambda sess: init_fn(self, sess)
    else:
      self._init_fn = None

    self._init_op = init_op
    self._init_feed_dict = init_feed_dict
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._local_init_op = local_init_op
    self._local_init_feed_dict = local_init_feed_dict
    self._summary_op = summary_op
    self._saver = saver

  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:

      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()))

      self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP,
                                              default_init_op)
    if self._ready_op is None:

      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)

      self._ready_op = Scaffold.get_or_default('ready_op',
                                               ops.GraphKeys.READY_OP,
                                               default_ready_op)
    if self._ready_for_local_init_op is None:

      def default_ready_for_local_init_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(
                variables.global_variables()),
            resources.report_uninitialized_resources(
                resources.shared_resources())
        ], 0)

      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold.default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default('summary_op',
                                                 ops.GraphKeys.SUMMARY_OP,
                                                 summary.merge_all)
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
          sharded=True)
    else:
      self._saver.build()

    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self

  @property
  def init_fn(self):
    return self._init_fn

  @property
  def init_op(self):
    return self._init_op

  @property
  def ready_op(self):
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def local_init_op(self):
    return self._local_init_op

  @property
  def local_init_feed_dict(self):
    return self._local_init_feed_dict

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def saver(self):
    return self._saver

  @property
  def init_feed_dict(self):
    return self._init_feed_dict

  @staticmethod
  def get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError(
            'More than one item in the collection "%s". '
            'Please indicate which one to use by passing it to '
            'the tf.Scaffold constructor as: '
            'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def default_local_init_op():
    """Returns an op that groups the default local init ops.

    This op is used during session initialization when a Scaffold is
    initialized without specifying the local_init_op arg. It includes
    `tf.compat.v1.local_variables_initializer`,
    `tf.compat.v1.tables_initializer`, and also
    initializes local session resources.

    Returns:
      The default Scaffold local init op.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()))


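# Editor's note: a minimal usage sketch, not part of the library. It assumes
# TF1 graph mode; `_example_scaffold_usage` and the no-op `init_fn` body are
# hypothetical names added purely for illustration.
def _example_scaffold_usage():
  """Builds a Scaffold with a custom init_fn and finalizes the graph."""
  with ops.Graph().as_default():

    def init_fn(scaffold, session):
      # Called once, after init_op runs; e.g. warm-start weights here.
      del scaffold, session

    scaffold = Scaffold(init_fn=init_fn)
    # finalize() fills in default init/ready/local_init/summary ops and a
    # default Saver, then freezes the graph against further modification.
    return scaffold.finalize()

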
def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None,
    save_graph_def=True):
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)
  if chief_only_hooks and worker_context.is_chief:
    all_hooks.extend(chief_only_hooks)

  # We need to call save or summary ops on all workers since these ops may
  # contain collective ops; only running save ops on some workers would make
  # collective ops hang. Therefore on those workers that don't need to actually
  # write checkpoints or summaries, we let them write to a temp directory.
  # pylint: disable=protected-access
  if type(
      worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
                                             'CollectiveAllReduceStrategyV1',
                                             'MultiWorkerMirroredStrategy'):
    if worker_context.task_type:
      tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
    else:
      tmpdir = 'tmp'

    if save_checkpoint_secs:
      logging.warning('Collective ops may deadlock with '
                      '`save_checkpoint_secs`, please use '
                      '`save_checkpoint_steps` instead. Clearing '
                      '`save_checkpoint_secs` and setting '
                      '`save_checkpoint_steps` to 1000 now.')
      save_checkpoint_secs = None
      save_checkpoint_steps = 1000
    if save_summaries_secs:
      logging.warning('Collective ops may run out of sync with '
                      '`save_summaries_secs`, please use '
                      '`save_summaries_steps` instead.')
  else:
    tmpdir = None

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir and log_step_count_steps and log_step_count_steps > 0:
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=os.path.join(summary_dir, tmpdir),
              every_n_steps=log_step_count_steps))

  if (((save_summaries_steps and save_summaries_steps > 0) or
       (save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=os.path.join(summary_dir, tmpdir)))

  if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
       (save_checkpoint_steps and save_checkpoint_steps > 0)) and
      checkpoint_dir):
    if worker_context.should_checkpoint:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              os.path.join(checkpoint_dir, tmpdir),
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))

  logging.info('all_hooks %r', all_hooks)
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None,
    save_graph_def=True):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets the proper session initializer/restorer. It
  also creates hooks related to checkpoint and summary saving. For workers,
  this utility sets a proper session creator which waits for the chief to
  initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
  more information.

  Args:
    master: `String` representation of the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory from which to
      restore variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If not
      specified, a default one is created. It's used to finalize the graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. These hooks are
      activated if `is_chief==True`, and ignored otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If both `save_checkpoint_steps` and
      `save_checkpoint_secs` are set to `None`, then the default checkpoint
      saver isn't used. If both are provided, then only `save_checkpoint_secs`
      is used. Default 600.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are written
      to disk using a default summary saver.  If both `save_summaries_steps` and
      `save_summaries_secs` are set to `None`, then the default summary saver
      isn't used. Default not enabled.
    config: An instance of `tf.compat.v1.ConfigProto` proto used to configure
      the session. It's the `config` argument of constructor of
      `tf.compat.v1.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.
    max_wait_secs: Maximum time workers should wait for the session to become
      available. This should be kept relatively short to help detect incorrect
      code, but sometimes may need to be increased if the chief takes a while to
      start up.
    save_checkpoint_steps: The frequency, in number of global steps, that a
      checkpoint is saved using a default checkpoint saver. If both
      `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
      the default checkpoint saver isn't used. If both are provided, then only
      `save_checkpoint_secs` is used. Default not enabled.
    summary_dir: A string.  Optional path to a directory where summaries will
      be saved. If None, checkpoint_dir is used instead.
    save_graph_def: Whether to save the GraphDef and MetaGraphDef to
      `checkpoint_dir`. The GraphDef is saved after the session is created as
      `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
      `model.ckpt-*.meta`.

  Returns:
    A `MonitoredSession` object.
  """
  if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
    save_summaries_steps = 100
    save_summaries_secs = None
  elif save_summaries_secs == USE_DEFAULT:
    save_summaries_secs = None
  elif save_summaries_steps == USE_DEFAULT:
    save_summaries_steps = None

  if (save_checkpoint_steps == USE_DEFAULT and
      save_checkpoint_secs == USE_DEFAULT):
    save_checkpoint_steps = None
    save_checkpoint_secs = 600
  elif save_checkpoint_secs == USE_DEFAULT:
    save_checkpoint_secs = None
  elif save_checkpoint_steps == USE_DEFAULT:
    save_checkpoint_steps = None

  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir,
        save_graph_def=save_graph_def)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir:
    if log_step_count_steps and log_step_count_steps > 0:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))

    if (save_summaries_steps and
        save_summaries_steps > 0) or (save_summaries_secs and
                                      save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and
        save_checkpoint_secs > 0) or (save_checkpoint_steps and
                                      save_checkpoint_steps > 0):
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold,
              save_graph_def=save_graph_def))

  if hooks:
    all_hooks.extend(hooks)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)

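
# Editor's note: a hedged usage sketch, not part of the library. It assumes a
# graph containing `train_op` and a global step (StopAtStepHook requires one);
# `_example_training_loop` and its arguments are hypothetical names.
def _example_training_loop(train_op, num_steps, ckpt_dir):
  """Runs `train_op` under MonitoredTrainingSession with common hooks."""
  hooks = [basic_session_run_hooks.StopAtStepHook(last_step=num_steps)]
  # With checkpoint_dir set, a default CheckpointSaverHook (every 600 secs)
  # plus default summary and step-count hooks are added, per the defaults
  # resolved above.
  with MonitoredTrainingSession(checkpoint_dir=ckpt_dir, hooks=hooks) as sess:
    while not sess.should_stop():
      sess.run(train_op)
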

@tf_export(v1=['train.SessionCreator'])
@six.add_metaclass(abc.ABCMeta)
class SessionCreator(object):
  """A factory for tf.Session."""

  @abc.abstractmethod
  def create_session(self):
    raise NotImplementedError(
        'create_session is not implemented for {}.'.format(self))


@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Initializes a chief session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory from which to
        restore variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    """Gets or creates a SessionManager."""
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        local_init_feed_dict=self._scaffold.local_init_feed_dict,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.compat.v1.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    """Gets or creates a SessionManager."""
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        local_init_feed_dict=self._scaffold.local_init_feed_dict,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


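# Editor's note: an illustrative sketch, not part of the library, contrasting
# the two creators in a distributed job; `_example_make_session` and its
# arguments are hypothetical names.
def _example_make_session(is_chief, master, ckpt_dir):
  """Chief prepares/restores the session; workers wait for it."""
  if is_chief:
    creator = ChiefSessionCreator(master=master, checkpoint_dir=ckpt_dir)
  else:
    creator = WorkerSessionCreator(master=master, max_wait_secs=30 * 60)
  return MonitoredSession(session_creator=creator)

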
class _MonitoredSession(object):
  """See `MonitoredSession` or `SingularMonitoredSession`."""

  def __init__(self,
               session_creator,
               hooks,
               should_recover,
               stop_grace_period_secs=120):
    """Sets up a Monitored or Hooked Session.

    Args:
      session_creator: A factory object to create session. Typically a
        `ChiefSessionCreator` or a `WorkerSessionCreator`.
      hooks: An iterable of `SessionRunHook` objects.
      should_recover: A bool. Indicates whether to recover from `AbortedError`
        and `UnavailableError` or not.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    self._graph_was_finalized = ops.get_default_graph().finalized
    self._hooks = hooks or []
    for h in self._hooks:
      h.begin()

    worker_context = distribute_coordinator_context.get_current_worker_context()
    if not session_creator and worker_context:
      session_creator = worker_context.session_creator()

    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
        hooks=self._hooks,
        stop_grace_period_secs=stop_grace_period_secs)
    if should_recover:
      self._sess = _RecoverableSession(self._coordinated_creator)
    else:
      self._sess = self._coordinated_creator.create_session()

  @property
  def graph(self):
    """The graph that was launched in this session."""
    if self._tf_sess() is None:
      return None
    return self._tf_sess().graph

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run ops in the monitored session.

    This method is completely compatible with the `tf.Session.run()` method.

    Args:
      fetches: Same as `tf.Session.run()`.
      feed_dict: Same as `tf.Session.run()`.
      options: Same as `tf.Session.run()`.
      run_metadata: Same as `tf.Session.run()`.

    Returns:
      Same as `tf.Session.run()`.
    """
    return self._sess.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

  def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`.  The function may use methods of the argument to perform
        computations with access to a raw session.  The returned value of the
        `step_fn` will be returned from `run_step_fn`, unless a stop is
        requested.  In that case, the next `should_stop` call will return True.
        Example usage:
            ```python
            with tf.Graph().as_default():
              c = tf.compat.v1.placeholder(tf.dtypes.float32)
              v = tf.add(c, 4.0)
              w = tf.add(c, 0.5)

              def step_fn(step_context):
                a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
                if a <= 4.5:
                  step_context.request_stop()
                return step_context.run_with_hooks(fetches=w,
                                                   feed_dict={c: 0.1})

              with tf.compat.v1.train.MonitoredSession() as session:
                while not session.should_stop():
                  a = session.run_step_fn(step_fn)
            ```
            Hooks interact with the `run_with_hooks()` call inside the
            `step_fn` as they do with a `MonitoredSession.run` call.

    Returns:
      Returns the returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext(object):
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

    Users of `step_fn` may perform `run()` calls without running hooks
    by accessing the `session`.  A `run()` call with hooks may be performed
    using `run_with_hooks()`.  Computation flow can be interrupted using
    `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.compat.v1.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

      Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    return self._sess is None or self._sess.should_stop()

  def close(self):
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for _CoordinatedSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    try:
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return the underlying tf.compat.v1.Session object.

    Warning: accessing the returned object in user code is likely to cause
    races or "flaky tests".

    Returns:
      A tf.compat.v1.Session object.
    """
    return self._coordinated_creator.tf_sess


@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does the following
  things, in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by the user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the run() call again

  Exit: At `close()`, the monitored session does the following things in order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error, which indicates that all inputs have
    been processed, if the monitored session is used as a context

  How to set `tf.compat.v1.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
    session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In a distributed setting for a non-chief worker, you can use the following:

  ```python
  MonitoredSession(
    session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or worker.

  Note: This is not a `tf.compat.v1.Session`. For example, it cannot do the
  following:

  * it cannot be set as a default session.
  * it cannot be sent to saver.save.
  * it cannot be sent to tf.train.start_queue_runners.

  Args:
    session_creator: A factory object to create session. Typically a
      `ChiefSessionCreator` which is the default one.
    hooks: An iterable of `SessionRunHook` objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
  The differences between `MonitoredSession` and `SingularMonitoredSession`
  are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.compat.v1.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
    private. This can be used:
      - To `run` without hooks.
      - To save and restore.
  * All other functionality is identical.

  Example usage:
  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the hooked session does the following things
  in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked by the user

  Exit: At `close()`, the hooked session does the following things in order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error, which indicates that all inputs have
    been processed, if the `SingularMonitoredSession` is used as a context.
  """

  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook` objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory from which to
        restore variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
    """Returns the underlying `tf.compat.v1.Session` object."""
    return self._tf_sess()


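# Editor's note: a hedged sketch, not part of the library, showing hook-free
# access via `raw_session()`; `_example_singular_usage` and its arguments are
# hypothetical names.
def _example_singular_usage(train_op, saver, save_path):
  """Trains, then saves through the raw session, bypassing all hooks."""
  with SingularMonitoredSession() as sess:
    while not sess.should_stop():
      sess.run(train_op)
    # raw_session() is a plain tf.compat.v1.Session: no hooks run here, which
    # is the supported way to call saver.save under this class.
    saver.save(sess.raw_session(), save_path)

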
class _WrappedSession(object):
  """Wrapper around a `tf.compat.v1.Session`.

  This wrapper is used as a base class for various session wrappers
  that provide additional functionality such as monitoring, coordination,
  and recovery.

  In addition to the methods exported by `SessionInterface` the wrapper
  provides a method to check for stop and never raises exceptions from
  calls to `close()`.
  """

  def __init__(self, sess):
    """Creates a `_WrappedSession`.

    Args:
      sess: A `tf.compat.v1.Session` or `_WrappedSession` object.  The wrapped
        session.
    """
    self._sess = sess
    self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)

  @property
  def graph(self):
    return self._sess.graph

  @property
  def sess_str(self):
    return self._sess.sess_str

  def should_stop(self):
    """Return true if this session should not be used anymore.

    Always return True if the session was closed.

    Returns:
      True if the session should stop, False otherwise.
    """
    if self._check_stop():
      return True
    if self._sess:
      return self._wrapped_is_stoppable and self._sess.should_stop()
    return True

  def _check_stop(self):
    """Hook for subclasses to provide their own stop condition.

    Returns:
      True if the session should stop, False otherwise.
    """
    return False

  def close(self):
    if self._sess:
      try:
        self._sess.close()
      except _PREEMPTION_ERRORS as e:
        logging.error(
            'An error occurred when attempting to close the '
            'session. This may be due to a preemption in a '
            'connected worker or parameter server. Error: %s', e)
      finally:
        self._sess = None

  def run(self, *args, **kwargs):
    return self._sess.run(*args, **kwargs)

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
    # It is `None` when called from `_CoordinatedSession`. In that case
    # `self.run` is `_CoordinatedSession.run`.
    run_with_hooks = run_with_hooks or self.run
    return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))


class _RecoverableSession(_WrappedSession):
  """A wrapped session that recreates a session upon certain kinds of errors.

  The constructor is passed a SessionCreator object, not a session.

  Calls to `run()` are delegated to the wrapped session.  If a call raises the
  exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the
  wrapped session is closed, and a new one is created by calling the factory
  again.
  """

  def __init__(self, sess_creator):
    """Create a new `_RecoverableSession`.

    The value returned by calling `sess_creator.create_session()` will be the
    session wrapped by this recoverable session.

    Args:
      sess_creator: A `SessionCreator` to be wrapped by this recoverable
        session.
    """
    self._sess_creator = sess_creator
    _WrappedSession.__init__(self, self._create_session())

  def _create_session(self):
    while True:
      try:
        return self._sess_creator.create_session()
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised while a session was being created. '
            'This may be due to a preemption of a connected worker '
            'or parameter server. A new session will be created. '
            'This error may also occur due to a gRPC failure caused '
            'by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)

  def _check_stop(self):
    try:
      if self._sess:
        return self._sess._check_stop()  # pylint: disable=protected-access
      else:
        return True
    except _PREEMPTION_ERRORS as e:
      logging.info(
          'An error was raised while considering whether the '
          'session is complete. This may be due to a preemption in '
          'a connected worker or parameter server. The current '
          'session will be closed and a new session will be '
          'created. This error may also occur due to a gRPC failure '
          'caused by high memory or network bandwidth usage in the '
          'parameter servers. If this error occurs repeatedly, try '
          'increasing the number of parameter servers assigned to '
          'the job. Error: %s', e)
      self.close()
      self._sess = self._create_session()
      # Since we have just recreated the session, the overall computation
      # should not stop:
      return False
    except Exception:  # pylint: disable=broad-except
      # `should_stop` should return True instead of raising an exception.
      return True

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()
        return self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()

        run_with_hooks = self._sess.run
        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None


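# Editor's note: an illustrative sketch only. `_RecoverableSession` is
# normally built for you by `MonitoredSession` (should_recover=True); it is
# constructed directly here just to show the retry contract.
# `_example_recoverable_run` and `fetches` are hypothetical names.
def _example_recoverable_run(sess_creator, fetches):
  """run() retries across AbortedError/UnavailableError preemptions."""
  recoverable = _RecoverableSession(sess_creator)
  # On a preemption error, the wrapped session is closed and recreated from
  # sess_creator, then the run is retried; other errors propagate normally.
  return recoverable.run(fetches)

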
class _CoordinatedSession(_WrappedSession):
  """A wrapped session that works with a `tf.Coordinator`.

  Calls to `run()` are delegated to the wrapped session.  If a call
  raises an exception, the exception is reported to the coordinator.

  In addition, after each call to `run()` this session asks the coordinator if
  the session should stop.  In that case it will join all the threads
  registered with the coordinator before returning.

  If the coordinator was requested to stop with an exception, that exception
  will be re-raised from the call to `run()`.
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.compat.v1.Session` object.  The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    _WrappedSession.__init__(self, sess)
    self._coord = coord
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
    self._coord.request_stop()
    try:
      self._coord.join(
          stop_grace_period_secs=self._stop_grace_period_secs,
          ignore_live_threads=True)
    finally:
      try:
        _WrappedSession.close(self)
      except Exception:  # pylint: disable=broad-except
        # We intentionally suppress exceptions from the close() here since
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS:
      raise
    except Exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the `original_exc_info`.
      original_exc_info = sys.exc_info()
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS:
        raise
      except Exception:  # pylint: disable=broad-except
        six.reraise(*original_exc_info)
      else:
        six.reraise(*original_exc_info)


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().

  The list of hooks to call is passed in the constructor.  Before each call
  to `run()` the session calls the `before_run()` method of the hooks, which
  can return additional ops or tensors to run.  These are added to the arguments
  of the call to `run()`.

  When the `run()` call finishes, the session calls the `after_run()` methods of
  the hooks, passing the values returned by the `run()` call corresponding to
  the ops and tensors that each hook requested.

  If any call to the hooks requests a stop via run_context, the session will be
  marked as needing to stop and its `should_stop()` method will now return
  `True`.
  """

  def __init__(self, sess, hooks):
    """Initializes a _HookedSession object.

    Args:
      sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
      hooks: An iterable of `SessionRunHook` objects.
    """

    _WrappedSession.__init__(self, sess)
    self._hooks = hooks
    self._should_stop = False

  def _check_stop(self):
    """See base class."""
    return self._should_stop

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class."""
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    actual_fetches = {'caller': fetches}

    run_context = session_run_hook.SessionRunContext(
        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
        session=self._sess)

    options = options or config_pb2.RunOptions()
    feed_dict = self._call_hook_before_run(run_context, actual_fetches,
                                           feed_dict, options)

    # Do session run.
    run_metadata = run_metadata or config_pb2.RunMetadata()
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    for hook in self._hooks:
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=outputs[hook] if hook in outputs else None,
              options=options,
              run_metadata=run_metadata))
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    """Calls hooks.before_run and handles requests from hooks."""
    hook_feeds = {}
    for hook in self._hooks:
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
                                          'Same tensor is fed by two hooks.')
          hook_feeds.update(request.feed_dict)
        if request.options:
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds

    self._raise_if_feeds_intersects(
        user_feed_dict, hook_feeds,
        'Same tensor is fed by a SessionRunHook and user.')
    hook_feeds.update(user_feed_dict)
    return hook_feeds

  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
    intersection = set(feeds1.keys()) & set(feeds2.keys())
    if intersection:
      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))

  def _merge_run_options(self, options, incoming_options):
    """Merges two instances of RunOptions into the first one.

    During the merge, the numerical fields (trace_level, timeout_in_ms,
    inter_op_thread_pool) are set to the larger of the two, and boolean fields
    are set to the logical OR of the two. debug_tensor_watch_opts of the
    original options is extended with that of the incoming one.

    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)