# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper of Session API which runs hooks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import os
import sys

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

# The list of exceptions that we should recover from. Exceptions not in this
# list may terminate the job.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

# Value that indicates no value was provided.
USE_DEFAULT = object()


@tf_export(v1=['train.Scaffold'])
class Scaffold(object):
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph.  The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if needed.

  If you call the scaffold constructor without any arguments, it will pick
  pieces from the collections, creating default ones if needed when
  `scaffold.finalize()` is called.  You can pass arguments to the constructor to
  provide your own pieces.  Pieces that you pass to the constructor are not
  added to the graph collections.

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.compat.v1.train.Saver` object taking care of saving the
    variables.  Picked from and stored into the `SAVERS` collection in the
    graph by default.
  * `init_op`: An op to run to initialize the variables.  Picked from and
    stored into the `INIT_OP` collection in the graph by default.
  * `ready_op`: An op to verify that the variables are initialized.  Picked
    from and stored into the `READY_OP` collection in the graph by default.
  * `ready_for_local_init_op`: An op to verify that global state has been
    initialized and it is alright to run `local_init_op`.  Picked from and
    stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by
    default. This is needed when the initialization of local variables depends
    on the values of global variables.
  * `local_init_op`: An op to initialize the local variables.  Picked
    from and stored into the `LOCAL_INIT_OP` collection in the graph by default.
  * `summary_op`: An op to run and merge the summaries in the graph.  Picked
    from and stored into the `SUMMARY_OP` collection in the graph by default.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
     running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations.  The callable will be called as
    `init_fn(scaffold, session)`.
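
  For example, a minimal sketch (the `assign_op` and `train_op` names below
  are placeholders, not part of this API):

  ```python
  def extra_init(scaffold, sess):
    # Runs after the scaffold's `init_op`; `assign_op` is hypothetical.
    sess.run(assign_op)

  scaffold = tf.compat.v1.train.Scaffold(init_fn=extra_init)
  with tf.compat.v1.train.MonitoredSession(
      session_creator=tf.compat.v1.train.ChiefSessionCreator(
          scaffold=scaffold)) as sess:
    sess.run(train_op)
  ```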

  """

  def __init__(self,
               init_op=None,
               init_feed_dict=None,
               init_fn=None,
               ready_op=None,
               ready_for_local_init_op=None,
               local_init_op=None,
               summary_op=None,
               saver=None,
               copy_from_scaffold=None,
               local_init_feed_dict=None):
    """Create a scaffold.

    Args:
      init_op: Optional op for initializing variables.
      init_feed_dict: Optional session feed dictionary to use when running the
        init_op.
      init_fn: Optional function to use to initialize the model after running
        the init_op.  Will be called as `init_fn(scaffold, session)`.
      ready_op: Optional op to verify that the variables are initialized.  Must
        return an empty 1D string tensor when the variables are initialized, or
        a non-empty 1D string tensor listing the names of the non-initialized
        variables.
      ready_for_local_init_op: Optional op to verify that the global variables
        are initialized and `local_init_op` can be run. Must return an empty 1D
        string tensor when the global variables are initialized, or a non-empty
        1D string tensor listing the names of the non-initialized global
        variables.
      local_init_op: Optional op to initialize local variables.
      summary_op: Optional op to gather all summaries.  Must return a scalar
        string tensor containing a serialized `Summary` proto.
      saver: Optional `tf.compat.v1.train.Saver` object to use to save and
        restore variables.  May also be a `tf.train.Checkpoint` object, in which
        case object-based checkpoints are saved. This will also load some
        object-based checkpoints saved from elsewhere, but that loading may be
        fragile since it uses fixed keys rather than performing a full
        graph-based match. For example if a variable has two paths from the
        `Checkpoint` object because two `Model` objects share the `Layer` object
        that owns it, removing one `Model` may change the keys and break
        checkpoint loading through this API, whereas a graph-based match would
        match the variable through the other `Model`.
      copy_from_scaffold: Optional scaffold object to copy fields from. Its
        fields will be overwritten by the provided fields in this function.
      local_init_feed_dict: Optional session feed dictionary to use when running
        the local_init_op.
    """
    if copy_from_scaffold is not None:
      if not isinstance(copy_from_scaffold, Scaffold):
        raise TypeError('copy_from_scaffold is not a Scaffold instance.')
      # We need `coalesce` since a Tensor is not converted to bool
      # automatically, so the common idiom of (a or b) does not work.
      coalesce = lambda a, b: a if a is not None else b
      init_op = coalesce(init_op, copy_from_scaffold.init_op)
      init_feed_dict = coalesce(init_feed_dict,
                                copy_from_scaffold.init_feed_dict)
      # Use the original init_fn provided by the user to init the new Scaffold.
      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
      ready_for_local_init_op = coalesce(
          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
      local_init_feed_dict = coalesce(local_init_feed_dict,
                                      copy_from_scaffold.local_init_feed_dict)
      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
      saver = coalesce(saver, copy_from_scaffold.saver)

    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver.  Is there a better way?
    self._user_init_fn = init_fn
    if init_fn:
      self._init_fn = lambda sess: init_fn(self, sess)
    else:
      self._init_fn = None

    self._init_op = init_op
    self._init_feed_dict = init_feed_dict
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._local_init_op = local_init_op
    self._local_init_feed_dict = local_init_feed_dict
    self._summary_op = summary_op
    self._saver = saver

  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:

      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()),
            ops.get_collection('saved_model_initializers'))

      self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP,
                                              default_init_op)
    if self._ready_op is None:

      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)

      self._ready_op = Scaffold.get_or_default('ready_op',
                                               ops.GraphKeys.READY_OP,
                                               default_ready_op)
    if self._ready_for_local_init_op is None:

      def default_ready_for_local_init_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(
                variables.global_variables()),
            resources.report_uninitialized_resources(
                resources.shared_resources())
        ], 0)

      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold.default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default('summary_op',
                                                 ops.GraphKeys.SUMMARY_OP,
                                                 summary.merge_all)
    # pylint: disable=g-long-lambda
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    # pylint: enable=g-long-lambda
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
          sharded=True)
    else:
      self._saver.build()

    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self

  @property
  def init_fn(self):
    return self._init_fn

  @property
  def init_op(self):
    return self._init_op

  @property
  def ready_op(self):
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def local_init_op(self):
    return self._local_init_op

  @property
  def local_init_feed_dict(self):
    return self._local_init_feed_dict

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def saver(self):
    return self._saver

  @property
  def init_feed_dict(self):
    return self._init_feed_dict

  @staticmethod
  def get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError(
            'More than one item in the collection "%s". '
            'Please indicate which one to use by passing it to '
            'the tf.Scaffold constructor as: '
            'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def default_local_init_op():
    """Returns an op that groups the default local init ops.

    This op is used during session initialization when a Scaffold is
    initialized without specifying the local_init_op arg. It includes
    `tf.compat.v1.local_variables_initializer`,
    `tf.compat.v1.tables_initializer`, and also
    initializes local session resources.

    Returns:
      The default Scaffold local init op.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()))


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None,
    save_graph_def=True):
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)
  if chief_only_hooks and worker_context.is_chief:
    all_hooks.extend(chief_only_hooks)

  # We need to call save or summary ops on all workers since these ops may
  # contain collective ops; only running save ops on some workers would make
  # collective ops hang. Therefore on those workers that don't need to actually
  # write checkpoints or summaries, we let them write to a temp directory.
  # pylint: disable=protected-access
  if type(
      worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
                                             'CollectiveAllReduceStrategyV1',
                                             'MultiWorkerMirroredStrategy'):
    if worker_context.task_type:
      tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
    else:
      tmpdir = 'tmp'

    if save_checkpoint_secs:
      logging.warning('Collective ops may deadlock with '
                      '`save_checkpoint_secs`; please use '
                      '`save_checkpoint_steps` instead. Clearing '
                      '`save_checkpoint_secs` and setting '
                      '`save_checkpoint_steps` to 1000 now.')
      save_checkpoint_secs = None
      save_checkpoint_steps = 1000
    if save_summaries_secs:
      logging.warning('Collective ops may run out of sync with '
                      '`save_summaries_secs`, please use '
                      '`save_summaries_steps` instead.')
  else:
    tmpdir = None

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir and log_step_count_steps and log_step_count_steps > 0:
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=os.path.join(summary_dir, tmpdir),
              every_n_steps=log_step_count_steps))

  if (((save_summaries_steps and save_summaries_steps > 0) or
       (save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=os.path.join(summary_dir, tmpdir)))

    if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
         (save_checkpoint_steps and save_checkpoint_steps > 0)) and
        checkpoint_dir):
      if worker_context.should_checkpoint:
        all_hooks.append(
            basic_session_run_hooks.CheckpointSaverHook(
                checkpoint_dir,
                save_steps=save_checkpoint_steps,
                save_secs=save_checkpoint_secs,
                scaffold=scaffold,
                save_graph_def=save_graph_def))
      elif tmpdir:
        all_hooks.append(
            basic_session_run_hooks.CheckpointSaverHook(
                os.path.join(checkpoint_dir, tmpdir),
                save_steps=save_checkpoint_steps,
                save_secs=save_checkpoint_secs,
                scaffold=scaffold,
                save_graph_def=save_graph_def))

  logging.info('all_hooks %r', all_hooks)
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None,
    save_graph_def=True):
452  """Creates a `MonitoredSession` for training.
453
454  For a chief, this utility sets proper session initializer/restorer. It also
455  creates hooks related to checkpoint and summary saving. For workers, this
456  utility sets proper session creator which waits for the chief to
457  initialize/restore. Please check `tf.compat.v1.train.MonitoredSession` for
458  more
459  information.
460
461
  Args:
    master: `String` representation of the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If not
      specified, a default one is created. It's used to finalize the graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: Optional list of `SessionRunHook` objects. These hooks
      are activated if `is_chief==True` and ignored otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If both `save_checkpoint_steps` and
      `save_checkpoint_secs` are set to `None`, then the default checkpoint
      saver isn't used. If both are provided, then only `save_checkpoint_secs`
      is used. Default 600.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are written
      to disk using a default summary saver.  If both `save_summaries_steps` and
      `save_summaries_secs` are set to `None`, then the default summary saver
      isn't used. Default not enabled.
    config: An instance of `tf.compat.v1.ConfigProto` proto used to configure
      the session. It's the `config` argument of the constructor of
      `tf.compat.v1.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.
    max_wait_secs: Maximum time workers should wait for the session to become
      available. This should be kept relatively short to help detect incorrect
      code, but sometimes may need to be increased if the chief takes a while to
      start up.
    save_checkpoint_steps: The frequency, in number of global steps, that a
      checkpoint is saved using a default checkpoint saver. If both
      `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
      the default checkpoint saver isn't used. If both are provided, then only
      `save_checkpoint_secs` is used. Default not enabled.
    summary_dir: A string.  Optional path to a directory where to save
      summaries. If None, checkpoint_dir is used instead.
    save_graph_def: Whether to save the GraphDef and MetaGraphDef to
      `checkpoint_dir`. The GraphDef is saved after the session is created as
      `graph.pbtxt`. MetaGraphDefs are saved out for every checkpoint as
      `model.ckpt-*.meta`.
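
  Example (an illustrative sketch; the directory and ops below are
  placeholders, not requirements of this API):

  ```python
  with tf.Graph().as_default():
    global_step = tf.compat.v1.train.get_or_create_global_step()
    train_op = tf.compat.v1.assign_add(global_step, 1)
    with tf.compat.v1.train.MonitoredTrainingSession(
        checkpoint_dir='/tmp/train_dir',
        save_checkpoint_secs=600) as sess:
      while not sess.should_stop():
        sess.run(train_op)
  ```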

  Returns:
    A `MonitoredSession` object.
512  """
513  if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
514    save_summaries_steps = 100
515    save_summaries_secs = None
516  elif save_summaries_secs == USE_DEFAULT:
517    save_summaries_secs = None
518  elif save_summaries_steps == USE_DEFAULT:
519    save_summaries_steps = None
520
521  if (save_checkpoint_steps == USE_DEFAULT and
522      save_checkpoint_secs == USE_DEFAULT):
523    save_checkpoint_steps = None
524    save_checkpoint_secs = 600
525  elif save_checkpoint_secs == USE_DEFAULT:
526    save_checkpoint_secs = None
527  elif save_checkpoint_steps == USE_DEFAULT:
528    save_checkpoint_steps = None
529
530  scaffold = scaffold or Scaffold()
531  worker_context = distribute_coordinator_context.get_current_worker_context()
532
533  if worker_context:
534    return _create_monitored_session_with_worker_context(
535        worker_context,
536        scaffold,
537        checkpoint_dir=checkpoint_dir,
538        hooks=hooks,
539        chief_only_hooks=chief_only_hooks,
540        save_checkpoint_secs=save_checkpoint_secs,
541        save_summaries_steps=save_summaries_steps,
542        save_summaries_secs=save_summaries_secs,
543        config=config,
544        stop_grace_period_secs=stop_grace_period_secs,
545        log_step_count_steps=log_step_count_steps,
546        max_wait_secs=max_wait_secs,
547        save_checkpoint_steps=save_checkpoint_steps,
548        summary_dir=summary_dir,
549        save_graph_def=save_graph_def)
550
551  if not is_chief:
552    session_creator = WorkerSessionCreator(
553        scaffold=scaffold,
554        master=master,
555        config=config,
556        max_wait_secs=max_wait_secs)
557    return MonitoredSession(
558        session_creator=session_creator,
559        hooks=hooks or [],
560        stop_grace_period_secs=stop_grace_period_secs)
561
562  all_hooks = []
563  if chief_only_hooks:
564    all_hooks.extend(chief_only_hooks)
565  session_creator = ChiefSessionCreator(
566      scaffold=scaffold,
567      checkpoint_dir=checkpoint_dir,
568      master=master,
569      config=config)
570
571  summary_dir = summary_dir or checkpoint_dir
572  if summary_dir:
573    if log_step_count_steps and log_step_count_steps > 0:
574      all_hooks.append(
575          basic_session_run_hooks.StepCounterHook(
576              output_dir=summary_dir, every_n_steps=log_step_count_steps))
577
578    if (save_summaries_steps and
579        save_summaries_steps > 0) or (save_summaries_secs and
580                                      save_summaries_secs > 0):
581      all_hooks.append(
582          basic_session_run_hooks.SummarySaverHook(
583              scaffold=scaffold,
584              save_steps=save_summaries_steps,
585              save_secs=save_summaries_secs,
586              output_dir=summary_dir))
587
588  if checkpoint_dir:
589    if (save_checkpoint_secs and
590        save_checkpoint_secs > 0) or (save_checkpoint_steps and
591                                      save_checkpoint_steps > 0):
592      all_hooks.append(
593          basic_session_run_hooks.CheckpointSaverHook(
594              checkpoint_dir,
595              save_steps=save_checkpoint_steps,
596              save_secs=save_checkpoint_secs,
597              scaffold=scaffold,
598              save_graph_def=save_graph_def))
599
600  if hooks:
601    all_hooks.extend(hooks)
602  return MonitoredSession(
603      session_creator=session_creator,
604      hooks=all_hooks,
605      stop_grace_period_secs=stop_grace_period_secs)
606
607
608@tf_export(v1=['train.SessionCreator'])
609@six.add_metaclass(abc.ABCMeta)
610class SessionCreator(object):
611  """A factory for tf.Session."""
612
613  @abc.abstractmethod
614  def create_session(self):
615    raise NotImplementedError(
616        'create_session is not implemented for {}.'.format(self))
617
618
619@tf_export(v1=['train.ChiefSessionCreator'])
620class ChiefSessionCreator(SessionCreator):
621  """Creates a tf.compat.v1.Session for a chief."""
622
623  def __init__(self,
624               scaffold=None,
625               master='',
626               config=None,
627               checkpoint_dir=None,
628               checkpoint_filename_with_path=None):
629    """Initializes a chief session creator.
630
631    Args:
632      scaffold: A `Scaffold` used for gathering or building supportive ops. If
633        not specified a default one is created. It's used to finalize the graph.
634      master: `String` representation of the TensorFlow master to use.
635      config: `ConfigProto` proto used to configure the session.
636      checkpoint_dir: A string.  Optional path to a directory where to restore
637        variables.
638      checkpoint_filename_with_path: Full file name path to the checkpoint file.
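
    For example (an illustrative sketch; the directory and `train_op` are
    placeholders):

    ```python
    creator = tf.compat.v1.train.ChiefSessionCreator(
        checkpoint_dir='/tmp/train_dir')
    with tf.compat.v1.train.MonitoredSession(
        session_creator=creator) as sess:
      sess.run(train_op)
    ```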
639    """
640    self._checkpoint_dir = checkpoint_dir
641    self._checkpoint_filename_with_path = checkpoint_filename_with_path
642    self._scaffold = scaffold or Scaffold()
643    self._session_manager = None
644    self._master = master
645    self._config = config
646
647  def _get_session_manager(self):
648    """Gets or creates a SessionManager."""
649    if self._session_manager:
650      return self._session_manager
651
652    self._session_manager = sm.SessionManager(
653        local_init_op=self._scaffold.local_init_op,
654        local_init_feed_dict=self._scaffold.local_init_feed_dict,
655        ready_op=self._scaffold.ready_op,
656        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
657        graph=ops.get_default_graph())
658    return self._session_manager
659
660  def create_session(self):
661    self._scaffold.finalize()
662    return self._get_session_manager().prepare_session(
663        self._master,
664        saver=self._scaffold.saver,
665        checkpoint_dir=self._checkpoint_dir,
666        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
667        config=self._config,
668        init_op=self._scaffold.init_op,
669        init_feed_dict=self._scaffold.init_feed_dict,
670        init_fn=self._scaffold.init_fn)
671
672
673@tf_export(v1=['train.WorkerSessionCreator'])
674class WorkerSessionCreator(SessionCreator):
675  """Creates a tf.compat.v1.Session for a worker."""
676
677  def __init__(self,
678               scaffold=None,
679               master='',
680               config=None,
681               max_wait_secs=30 * 60):
682    """Initializes a worker session creator.
683
684    Args:
685      scaffold: A `Scaffold` used for gathering or building supportive ops. If
686        not specified a default one is created. It's used to finalize the graph.
687      master: `String` representation of the TensorFlow master to use.
688      config: `ConfigProto` proto used to configure the session.
689      max_wait_secs: Maximum time to wait for the session to become available.
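
    For example (an illustrative sketch; the master address and `train_op` are
    placeholders):

    ```python
    creator = tf.compat.v1.train.WorkerSessionCreator(
        master='grpc://chief-host:2222')
    # `create_session` blocks until the chief has initialized the graph.
    with tf.compat.v1.train.MonitoredSession(
        session_creator=creator) as sess:
      sess.run(train_op)
    ```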
690    """
691    self._scaffold = scaffold or Scaffold()
692    self._session_manager = None
693    self._master = master
694    self._config = config
695    self._max_wait_secs = max_wait_secs
696
697  def _get_session_manager(self):
698    """Gets or creates a SessionManager."""
699    if self._session_manager:
700      return self._session_manager
701
702    self._session_manager = sm.SessionManager(
703        local_init_op=self._scaffold.local_init_op,
704        local_init_feed_dict=self._scaffold.local_init_feed_dict,
705        ready_op=self._scaffold.ready_op,
706        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
707        graph=ops.get_default_graph())
708    return self._session_manager
709
710  def create_session(self):
711    self._scaffold.finalize()
712    return self._get_session_manager().wait_for_session(
713        self._master, config=self._config, max_wait_secs=self._max_wait_secs)
714
715
716class _MonitoredSession(object):
717  """See `MonitoredSession` or `SingularMonitoredSession`."""
718
719  def __init__(self,
720               session_creator,
721               hooks,
722               should_recover,
723               stop_grace_period_secs=120):
724    """Sets up a Monitored or Hooked Session.
725
726    Args:
727      session_creator: A factory object to create session. Typically a
728        `ChiefSessionCreator` or a `WorkerSessionCreator`.
729      hooks: An iterable of `SessionRunHook' objects.
730      should_recover: A bool. Indicates whether to recover from `AbortedError`
731        and `UnavailableError` or not.
732      stop_grace_period_secs: Number of seconds given to threads to stop after
733        `close()` has been called.
734    """
735    self._graph_was_finalized = ops.get_default_graph().finalized
736    self._hooks = hooks or []
737    for h in self._hooks:
738      h.begin()
739
740    worker_context = distribute_coordinator_context.get_current_worker_context()
741    if not session_creator and worker_context:
742      session_creator = worker_context.session_creator()
743
744    # Create the session.
745    self._coordinated_creator = self._CoordinatedSessionCreator(
746        session_creator=session_creator or ChiefSessionCreator(),
747        hooks=self._hooks,
748        stop_grace_period_secs=stop_grace_period_secs)
749    if should_recover:
750      self._sess = _RecoverableSession(self._coordinated_creator)
751    else:
752      self._sess = self._coordinated_creator.create_session()
753
754  @property
755  def graph(self):
756    """The graph that was launched in this session."""
757    if self._tf_sess() is None:
758      return None
759    return self._tf_sess().graph
760
761  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
762    """Run ops in the monitored session.
763
764    This method is completely compatible with the `tf.Session.run()` method.
765
766    Args:
767      fetches: Same as `tf.Session.run()`.
768      feed_dict: Same as `tf.Session.run()`.
769      options: Same as `tf.Session.run()`.
770      run_metadata: Same as `tf.Session.run()`.
771
772    Returns:
773      Same as `tf.Session.run()`.
774    """
775    return self._sess.run(
776        fetches,
777        feed_dict=feed_dict,
778        options=options,
779        run_metadata=run_metadata)
780
781  def run_step_fn(self, step_fn):
782    """Run ops using a step function.
783
784    Args:
785      step_fn: A function or a method with a single argument of type
786        `StepContext`.  The function may use methods of the argument to perform
787        computations with access to a raw session.  The returned value of the
788        `step_fn` will be returned from `run_step_fn`, unless a stop is
789        requested.  In that case, the next `should_stop` call will return True.
790        Example usage:
791            ```python
792            with tf.Graph().as_default():
793              c = tf.compat.v1.placeholder(dtypes.float32)
794              v = tf.add(c, 4.0)
795              w = tf.add(c, 0.5)
796              def step_fn(step_context):
797                a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
798                if a <= 4.5:
799                  step_context.request_stop()
800                  return step_context.run_with_hooks(fetches=w,
801                                                     feed_dict={c: 0.1})
802
803              with tf.MonitoredSession() as session:
804                while not session.should_stop():
805                  a = session.run_step_fn(step_fn)
806            ```
807            Hooks interact with the `run_with_hooks()` call inside the
808                 `step_fn` as they do with a `MonitoredSession.run` call.
809
810    Returns:
811      Returns the returned value of `step_fn`.
812
813    Raises:
814      StopIteration: if `step_fn` has called `request_stop()`.  It may be
815        caught by `with tf.MonitoredSession()` to close the session.
816      ValueError: if `step_fn` doesn't have a single argument called
817        `step_context`. It may also optionally have `self` for cases when it
818        belongs to an object.
819    """
820    step_fn_arguments = function_utils.fn_args(step_fn)
821    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
822        'self',
823        'step_context',
824    ):
825      raise ValueError(
826          '`step_fn` may either have one `step_context` argument, or'
827          ' `self` and `step_context` arguments if it\'s an instance'
828          ' method. Got {} instead.'.format(step_fn_arguments))
829
830    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
831    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
832    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext(object):
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

       Users of `step_fn` may perform `run()` calls without running hooks
       by accessing the `session`.  A `run()` call with hooks may be performed
       using `run_with_hooks()`.  Computation flow can be interrupted using
       `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.compat.v1.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

         Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    return self._sess is None or self._sess.should_stop()

  def close(self):
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for _CoordinatedSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    try:
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return underlying tf.compat.v1.Session object.

    Warning: accessing the returned object in user code is likely to cause races
    or "flaky tests".

    Returns:
      A tf.compat.v1.Session object.
    """
    return self._coordinated_creator.tf_sess


@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does the following
  things in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked for by the user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the `run()` call again

  Exit: At `close()`, the monitored session does the following things in order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses `OutOfRange` error which indicates that all inputs have been
    processed if the monitored_session is used as a context

  How to set `tf.compat.v1.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
    session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In a distributed setting for a non-chief worker, you can use the following:

  ```python
  MonitoredSession(
    session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or worker.

  Note: This is not a `tf.compat.v1.Session`. For example, it cannot do the
  following:

  * it cannot be set as default session.
  * it cannot be sent to saver.save.
  * it cannot be sent to tf.train.start_queue_runners.

  Args:
    session_creator: A factory object to create session. Typically a
      `ChiefSessionCreator` which is the default one.
    hooks: An iterable of `SessionRunHook` objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.compat.v1.train.MonitoredSession`.
  The differences between `MonitoredSession` and `SingularMonitoredSession`
  are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.compat.v1.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
    private. This can be used:
      - To `run` without hooks.
      - To save and restore.
  * All other functionality is identical.

  Example usage:
  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```
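
  Since `raw_session()` exposes the underlying session, you can, for example,
  run fetches without triggering any hooks (an illustrative sketch):

  ```python
  with SingularMonitoredSession(hooks=[saver_hook]) as sess:
    raw = sess.raw_session()  # the underlying tf.compat.v1.Session
    raw.run(train_op)         # runs without invoking any hooks
  ```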

  Initialization: At creation time the hooked session does the following things
  in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` asked for by the user

  Exit: At `close()`, the hooked session does the following things in order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses `OutOfRange` error which indicates that all inputs have been
    processed if the `SingularMonitoredSession` is used as a context.
  """

  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook` objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
1132    """Returns underlying `TensorFlow.Session` object."""
    return self._tf_sess()


class _WrappedSession(object):
  """Wrapper around a `tf.compat.v1.Session`.

  This wrapper is used as a base class for various session wrappers
  that provide additional functionality such as monitoring, coordination,
  and recovery.

  In addition to the methods exported by `SessionInterface` the wrapper
  provides a method to check for stop and never raises exceptions from
  calls to `close()`.
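
  For example, a subclass can supply its own stop condition by overriding
  `_check_stop()` (an illustrative sketch; `stop_flag` is a hypothetical
  `threading.Event`):

  ```python
  class _StopOnFlagSession(_WrappedSession):

    def __init__(self, sess, stop_flag):
      _WrappedSession.__init__(self, sess)
      self._stop_flag = stop_flag

    def _check_stop(self):
      # Report a stop as soon as the flag is set.
      return self._stop_flag.is_set()
  ```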
1146  """
1147
1148  def __init__(self, sess):
1149    """Creates a `_WrappedSession`.
1150
1151    Args:
1152      sess: A `tf.compat.v1.Session` or `_WrappedSession` object.  The wrapped
1153        session.
1154    """
1155    self._sess = sess
1156    self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)
1157
1158  @property
1159  def graph(self):
1160    return self._sess.graph
1161
1162  @property
1163  def sess_str(self):
1164    return self._sess.sess_str
1165
1166  def should_stop(self):
1167    """Return true if this session should not be used anymore.
1168
1169    Always return True if the session was closed.
1170
1171    Returns:
1172      True if the session should stop, False otherwise.
1173    """
1174    if self._check_stop():
1175      return True
1176    if self._sess:
1177      return self._wrapped_is_stoppable and self._sess.should_stop()
1178    return True
1179
1180  def _check_stop(self):
1181    """Hook for subclasses to provide their own stop condition.
1182
1183    Returns:
1184      True if the session should stop, False otherwise.
1185    """
1186    return False
1187
1188  def close(self):
1189    if self._sess:
1190      try:
1191        self._sess.close()
1192      except _PREEMPTION_ERRORS as e:
1193        logging.error(
1194            'An error occurred when attempting to close the '
1195            'session. This may be due to a preemption in a '
1196            'connected worker or parameter server. Error: %s', e)
1197      finally:
1198        self._sess = None
1199
1200  def run(self, *args, **kwargs):
1201    return self._sess.run(*args, **kwargs)
1202
1203  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
1204    # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
1205    # It is `None` when called from `_CoordinatedSession`. In that case
1206    # `self.run` is `_CoordinatedSession.run`.
1207    run_with_hooks = run_with_hooks or self.run
1208    return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))
1209
1210
1211class _RecoverableSession(_WrappedSession):
1212  """A wrapped session that recreates a session upon certain kinds of errors.
1213
1214  The constructor is passed a SessionCreator object, not a session.
1215
1216  Calls to `run()` are delegated to the wrapped session.  If a call raises the
1217  exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the
1218  wrapped session is closed, and a new one is created by calling the factory
1219  again.
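
  For example (an illustrative sketch using this module's internals;
  `train_op` is a placeholder):

  ```python
  creator = ChiefSessionCreator(checkpoint_dir='/tmp/train_dir')
  sess = _RecoverableSession(creator)
  # `run` transparently recreates the session if a preemption error occurs.
  sess.run(train_op)
  ```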
1220  """
1221
1222  def __init__(self, sess_creator):
1223    """Create a new `_RecoverableSession`.
1224
1225    The value returned by calling `sess_creator.create_session()` will be the
1226    session wrapped by this recoverable session.
1227
1228    Args:
      sess_creator: A `SessionCreator` to be wrapped by the recoverable
        session.
1230    """
1231    self._sess_creator = sess_creator
1232    _WrappedSession.__init__(self, self._create_session())
1233
1234  def _create_session(self):
1235    while True:
1236      try:
1237        return self._sess_creator.create_session()
1238      except _PREEMPTION_ERRORS as e:
1239        logging.info(
1240            'An error was raised while a session was being created. '
1241            'This may be due to a preemption of a connected worker '
1242            'or parameter server. A new session will be created. '
1243            'This error may also occur due to a gRPC failure caused '
1244            'by high memory or network bandwidth usage in the '
1245            'parameter servers. If this error occurs repeatedly, try '
1246            'increasing the number of parameter servers assigned to '
1247            'the job. Error: %s', e)
1248
1249  def _check_stop(self):
1250    try:
1251      if self._sess:
1252        return self._sess._check_stop()  # pylint: disable=protected-access
1253      else:
1254        return True
1255    except _PREEMPTION_ERRORS as e:
1256      logging.info(
1257          'An error was raised while considering whether the '
1258          'session is complete. This may be due to a preemption in '
1259          'a connected worker or parameter server. The current '
1260          'session will be closed and a new session will be '
1261          'created. This error may also occur due to a gRPC failure '
1262          'caused by high memory or network bandwidth usage in the '
1263          'parameter servers. If this error occurs repeatedly, try '
1264          'increasing the number of parameter servers assigned to '
1265          'the job. Error: %s', e)
1266      self.close()
1267      self._sess = self._create_session()
1268      # Since we have just recreated the session, the overall computation should
1269      # not stop:
1270      return False
1271    except Exception:  # pylint: disable=broad-except
1272      # `should_stop` should return True instead of raising an exception.
1273      return True
1274
1275  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
1276    while True:
1277      try:
1278        if not self._sess:
1279          self._sess = self._create_session()
1280        return self._sess.run(
1281            fetches,
1282            feed_dict=feed_dict,
1283            options=options,
1284            run_metadata=run_metadata)
1285      except _PREEMPTION_ERRORS as e:
1286        logging.info(
1287            'An error was raised. This may be due to a preemption in '
1288            'a connected worker or parameter server. The current '
1289            'session will be closed and a new session will be '
1290            'created. This error may also occur due to a gRPC failure '
1291            'caused by high memory or network bandwidth usage in the '
1292            'parameter servers. If this error occurs repeatedly, try '
1293            'increasing the number of parameter servers assigned to '
1294            'the job. Error: %s', e)
1295        self.close()
1296        self._sess = None
1297
1298  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
1299    while True:
1300      try:
1301        if not self._sess:
1302          self._sess = self._create_session()
1303
1304        run_with_hooks = self._sess.run
1305        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
1306      except _PREEMPTION_ERRORS as e:
1307        logging.info(
1308            'An error was raised. This may be due to a preemption in '
1309            'a connected worker or parameter server. The current '
1310            'session will be closed and a new session will be '
1311            'created. This error may also occur due to a gRPC failure '
1312            'caused by high memory or network bandwidth usage in the '
1313            'parameter servers. If this error occurs repeatedly, try '
1314            'increasing the number of parameter servers assigned to '
1315            'the job. Error: %s', e)
1316        self.close()
1317        self._sess = None
1318
1319
1320class _CoordinatedSession(_WrappedSession):
1321  """A wrapped session that works with a `tf.Coordinator`.
1322
1323  Calls to `run()` are delegated to the wrapped session.  If a call
1324  raises an exception, the exception is reported to the coordinator.

  In addition, after each call to `run()` this session asks the coordinator
  whether the session should stop.  If so, it joins all the threads
  registered with the coordinator before returning.

  If the coordinator was requested to stop with an exception, that exception
  will be re-raised from the call to `run()`.
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.compat.v1.Session` object.  The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    _WrappedSession.__init__(self, sess)
    self._coord = coord
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
    self._coord.request_stop()
    try:
      self._coord.join(
          stop_grace_period_secs=self._stop_grace_period_secs,
          ignore_live_threads=True)
    finally:
      try:
        _WrappedSession.close(self)
      except Exception:  # pylint: disable=broad-except
        # We intentionally suppress exceptions from the close() here since
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS:
      raise
    except Exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the `original_exc_info`.
      original_exc_info = sys.exc_info()
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS:
        raise
      except Exception:  # pylint: disable=broad-except
        six.reraise(*original_exc_info)
      else:
        six.reraise(*original_exc_info)


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().

  The list of hooks to call is passed in the constructor.  Before each call
  to `run()` the session calls the `before_run()` method of the hooks, which
  can return additional ops or tensors to run.  These are added to the arguments
  of the call to `run()`.

  When the `run()` call finishes, the session calls the `after_run()` methods of
  the hooks, passing the values returned by the `run()` call corresponding to
  the ops and tensors that each hook requested.

  If any of the hooks requests a stop via its `run_context`, the session will
  be marked as needing to stop, and its `should_stop()` method will then
  return `True`.
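
  For example, a hook that asks the session to fetch an extra tensor might
  look like this (an illustrative sketch, not part of this module):

  ```python
  class LogGlobalStepHook(tf.compat.v1.train.SessionRunHook):

    def before_run(self, run_context):
      # Ask the session to additionally fetch the global step.
      return tf.compat.v1.train.SessionRunArgs(
          tf.compat.v1.train.get_global_step())

    def after_run(self, run_context, run_values):
      # `run_values.results` holds the value of the extra fetch.
      print('global step:', run_values.results)
  ```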
1403  """
1404
1405  def __init__(self, sess, hooks):
1406    """Initializes a _HookedSession object.
1407
1408    Args:
1409      sess: A `tf.compat.v1.Session` or a `_WrappedSession` object.
1410      hooks: An iterable of `SessionRunHook' objects.
1411    """
1412
1413    _WrappedSession.__init__(self, sess)
1414    self._hooks = hooks
1415    self._should_stop = False
1416
1417  def _check_stop(self):
1418    """See base class."""
1419    return self._should_stop
1420
1421  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
1422    """See base class."""
1423    if self.should_stop():
1424      raise RuntimeError('Run called even after should_stop requested.')
1425
1426    actual_fetches = {'caller': fetches}
1427
1428    run_context = session_run_hook.SessionRunContext(
1429        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
1430        session=self._sess)
1431
1432    options = options or config_pb2.RunOptions()
1433    feed_dict = self._call_hook_before_run(run_context, actual_fetches,
1434                                           feed_dict, options)
1435
1436    # Do session run.
1437    run_metadata = run_metadata or config_pb2.RunMetadata()
1438    outputs = _WrappedSession.run(
1439        self,
1440        fetches=actual_fetches,
1441        feed_dict=feed_dict,
1442        options=options,
1443        run_metadata=run_metadata)
1444
1445    for hook in self._hooks:
1446      hook.after_run(
1447          run_context,
1448          session_run_hook.SessionRunValues(
1449              results=outputs[hook] if hook in outputs else None,
1450              options=options,
1451              run_metadata=run_metadata))
1452    self._should_stop = self._should_stop or run_context.stop_requested
1453
1454    return outputs['caller']
1455
1456  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
1457                            options):
1458    """Calls hooks.before_run and handles requests from hooks."""
1459    hook_feeds = {}
1460    for hook in self._hooks:
1461      request = hook.before_run(run_context)
1462      if request is not None:
1463        if request.fetches is not None:
1464          fetch_dict[hook] = request.fetches
1465        if request.feed_dict:
1466          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
1467                                          'Same tensor is fed by two hooks.')
1468          hook_feeds.update(request.feed_dict)
1469        if request.options:
1470          self._merge_run_options(options, request.options)
1471
1472    if not hook_feeds:
1473      return user_feed_dict
1474
1475    if not user_feed_dict:
1476      return hook_feeds
1477
1478    self._raise_if_feeds_intersects(
1479        user_feed_dict, hook_feeds,
1480        'Same tensor is fed by a SessionRunHook and user.')
1481    hook_feeds.update(user_feed_dict)
1482    return hook_feeds
1483
1484  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
1485    intersection = set(feeds1.keys()) & set(feeds2.keys())
1486    if intersection:
1487      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))
1488
1489  def _merge_run_options(self, options, incoming_options):
1490    """Merge two instances of RunOptions into the first one.
1491
    During the merge, the numerical fields, including trace_level,
    timeout_in_ms, and inter_op_thread_pool, are set to the larger of the two.
    Boolean fields are set to the logical OR of the two.
    The debug_tensor_watch_opts of the original options are extended with
    those from the incoming one.
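
    For example (an illustrative sketch):

    ```python
    a = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.NO_TRACE,
                              timeout_in_ms=500)
    b = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
    self._merge_run_options(a, b)
    # Now a.trace_level == FULL_TRACE and a.timeout_in_ms == 500.
    ```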

    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)