# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A wrapper of Session API which runs hooks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import os
import sys

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import util as trackable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export

# The list of exceptions that we should recover from. Exceptions not in this
# list may terminate the job.
_PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)

# Value that indicates no value was provided.
USE_DEFAULT = object()


@tf_export(v1=['train.Scaffold'])
class Scaffold(object):
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph.  The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if needed.

  If you call the scaffold constructor without any arguments, it will pick
  pieces from the collections, creating default ones if needed when
  `scaffold.finalize()` is called.  You can pass arguments to the constructor to
  provide your own pieces.  Pieces that you pass to the constructor are not
  added to the graph collections.

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.train.Saver` object taking care of saving the variables.
    Picked from and stored into the `SAVERS` collection in the graph by default.
  * `init_op`: An op to run to initialize the variables.  Picked from and
    stored into the `INIT_OP` collection in the graph by default.
  * `ready_op`: An op to verify that the variables are initialized.  Picked
    from and stored into the `READY_OP` collection in the graph by default.
  * `ready_for_local_init_op`: An op to verify that global state has been
    initialized and it is alright to run `local_init_op`.  Picked from and
    stored into the `READY_FOR_LOCAL_INIT_OP` collection in the graph by
    default. This is needed when the initialization of local variables depends
    on the values of global variables.
  * `local_init_op`: An op to initialize the local variables.  Picked
    from and stored into the `LOCAL_INIT_OP` collection in the graph by default.
  * `summary_op`: An op to run and merge the summaries in the graph.  Picked
    from and stored into the `SUMMARY_OP` collection in the graph by default.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
     running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations.  The callable will be called as
    `init_fn(scaffold, session)`.

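  For example, a `Scaffold` that runs extra initialization logic after the
  init op can be built as follows (a minimal sketch; `restore_pretrained` is
  a placeholder for user code):

  ```python
  def init_fn(scaffold, session):
    # Called once after `init_op`; `scaffold` gives access to the saver
    # and the other finalized pieces.
    restore_pretrained(session)

  scaffold = tf.train.Scaffold(init_fn=init_fn)
  session_creator = tf.train.ChiefSessionCreator(scaffold=scaffold)
  with tf.train.MonitoredSession(session_creator=session_creator) as sess:
    ...
  ```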
  """

  def __init__(self,
               init_op=None,
               init_feed_dict=None,
               init_fn=None,
               ready_op=None,
               ready_for_local_init_op=None,
               local_init_op=None,
               summary_op=None,
               saver=None,
               copy_from_scaffold=None):
    """Create a scaffold.

    Args:
      init_op: Optional op for initializing variables.
      init_feed_dict: Optional session feed dictionary to use when running the
        init_op.
      init_fn: Optional function to use to initialize the model after running
        the init_op.  Will be called as `init_fn(scaffold, session)`.
      ready_op: Optional op to verify that the variables are initialized.  Must
        return an empty 1D string tensor when the variables are initialized, or
        a non-empty 1D string tensor listing the names of the non-initialized
        variables.
      ready_for_local_init_op: Optional op to verify that the global variables
        are initialized and `local_init_op` can be run. Must return an empty 1D
        string tensor when the global variables are initialized, or a non-empty
        1D string tensor listing the names of the non-initialized global
        variables.
      local_init_op: Optional op to initialize local variables.
      summary_op: Optional op to gather all summaries.  Must return a scalar
        string tensor containing a serialized `Summary` proto.
      saver: Optional `tf.train.Saver` object to use to save and restore
        variables.  May also be a `tf.train.Checkpoint` object, in which case
        object-based checkpoints are saved. This will also load some
        object-based checkpoints saved from elsewhere, but that loading may be
        fragile since it uses fixed keys rather than performing a full
        graph-based match. For example if a variable has two paths from the
        `Checkpoint` object because two `Model` objects share the `Layer` object
        that owns it, removing one `Model` may change the keys and break
        checkpoint loading through this API, whereas a graph-based match would
        match the variable through the other `Model`.
      copy_from_scaffold: Optional scaffold object to copy fields from. Its
        fields will be overwritten by the provided fields in this function.
    """
    if copy_from_scaffold is not None:
      if not isinstance(copy_from_scaffold, Scaffold):
        raise TypeError('copy_from_scaffold is not a Scaffold instance.')
      # We need coalesce since Tensor is not converted to bool automatically,
      # so the common idiom of (a or b) does not work.
      coalesce = lambda a, b: a if a is not None else b
      init_op = coalesce(init_op, copy_from_scaffold.init_op)
      init_feed_dict = coalesce(init_feed_dict,
                                copy_from_scaffold.init_feed_dict)
      # Use the original init_fn provided by the user to init the new Scaffold.
      init_fn = coalesce(init_fn, copy_from_scaffold._user_init_fn)  # pylint: disable=protected-access
      ready_op = coalesce(ready_op, copy_from_scaffold.ready_op)
      ready_for_local_init_op = coalesce(
          ready_for_local_init_op, copy_from_scaffold.ready_for_local_init_op)
      local_init_op = coalesce(local_init_op, copy_from_scaffold.local_init_op)
      summary_op = coalesce(summary_op, copy_from_scaffold.summary_op)
      saver = coalesce(saver, copy_from_scaffold.saver)

    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver.  Is there a better way?
    self._user_init_fn = init_fn
    if init_fn:
      self._init_fn = lambda sess: init_fn(self, sess)
    else:
      self._init_fn = None

    self._init_op = init_op
    self._init_feed_dict = init_feed_dict
    self._ready_op = ready_op
    self._ready_for_local_init_op = ready_for_local_init_op
    self._local_init_op = local_init_op
    self._summary_op = summary_op
    self._saver = saver

  def finalize(self):
    """Creates operations if needed and finalizes the graph."""
    if self._init_op is None:

      def default_init_op():
        return control_flow_ops.group(
            variables.global_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()))

      self._init_op = Scaffold.get_or_default('init_op', ops.GraphKeys.INIT_OP,
                                              default_init_op)
    if self._ready_op is None:

      def default_ready_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(),
            resources.report_uninitialized_resources()
        ], 0)

      self._ready_op = Scaffold.get_or_default(
          'ready_op', ops.GraphKeys.READY_OP, default_ready_op)
    if self._ready_for_local_init_op is None:

      def default_ready_for_local_init_op():
        return array_ops.concat([
            variables.report_uninitialized_variables(
                variables.global_variables()),
            resources.report_uninitialized_resources(
                resources.shared_resources())
        ], 0)

      self._ready_for_local_init_op = Scaffold.get_or_default(
          'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
          default_ready_for_local_init_op)
    if self._local_init_op is None:
      self._local_init_op = Scaffold.get_or_default(
          'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
          Scaffold.default_local_init_op)
    if self._summary_op is None:
      self._summary_op = Scaffold.get_or_default(
          'summary_op', ops.GraphKeys.SUMMARY_OP, summary.merge_all)
    # pylint: disable=g-long-lambda
    if self._saver is None:
      self._saver = training_saver._get_saver_or_default()  # pylint: disable=protected-access
    # pylint: enable=g-long-lambda
    if isinstance(self._saver, trackable_util.Checkpoint):
      self._saver = training_saver.Saver(
          var_list=graph_view.ObjectGraphView(
              self._saver).frozen_saveable_objects(),
          sharded=True)
    else:
      self._saver.build()

    ops.get_default_graph().finalize()
    logging.info('Graph was finalized.')
    return self

  @property
  def init_fn(self):
    return self._init_fn

  @property
  def init_op(self):
    return self._init_op

  @property
  def ready_op(self):
    return self._ready_op

  @property
  def ready_for_local_init_op(self):
    return self._ready_for_local_init_op

  @property
  def local_init_op(self):
    return self._local_init_op

  @property
  def summary_op(self):
    return self._summary_op

  @property
  def saver(self):
    return self._saver

  @property
  def init_feed_dict(self):
    return self._init_feed_dict

  @staticmethod
  def get_or_default(arg_name, collection_key, default_constructor):
    """Get from cache or create a default operation."""
    elements = ops.get_collection(collection_key)
    if elements:
      if len(elements) > 1:
        raise RuntimeError(
            'More than one item in the collection "%s". '
            'Please indicate which one to use by passing it to '
            'the tf.Scaffold constructor as: '
            'tf.Scaffold(%s=item to use)' % (collection_key, arg_name))
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def default_local_init_op():
    """Returns an op that groups the default local init ops.

    This op is used during session initialization when a Scaffold is
    initialized without specifying the local_init_op arg. It includes
    `tf.local_variables_initializer`, `tf.tables_initializer`, and also
    initializes local session resources.

    Returns:
      The default Scaffold local init op.
    """
    return control_flow_ops.group(
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer(),
        resources.initialize_resources(resources.local_resources()))


def _create_monitored_session_with_worker_context(
    worker_context,  # pylint: disable=missing-docstring
    scaffold,
    checkpoint_dir=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=None,
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=None,
    summary_dir=None):
  all_hooks = []
  if hooks:
    all_hooks.extend(hooks)
  if chief_only_hooks and worker_context.is_chief:
    all_hooks.extend(chief_only_hooks)

  # We need to call save or summary ops on all workers since these ops may
  # contain collective ops, only running save ops on some workers would make
  # collective ops hang. Therefore on those workers that don't need to actually
  # write checkpoints or summaries, we let them write to a temp directory.
  # pylint: disable=protected-access
  if type(worker_context._strategy).__name__ in ('CollectiveAllReduceStrategy',
                                                 'MultiWorkerMirroredStrategy'):
    if worker_context.task_type:
      tmpdir = 'tmp_%s_%d' % (worker_context.task_type, worker_context.task_id)
    else:
      tmpdir = 'tmp'

    if save_checkpoint_secs:
      logging.warning('Collective ops may deadlock with '
                      '`save_checkpoint_secs`, please use '
                      '`save_checkpoint_steps` instead. Clearing '
                      '`save_checkpoint_secs` and setting '
                      '`save_checkpoint_steps` to 1000 now.')
      save_checkpoint_secs = None
      save_checkpoint_steps = 1000
    if save_summaries_secs:
      logging.warning('Collective ops may run out of sync with '
                      '`save_summaries_secs`, please use '
                      '`save_summaries_steps` instead.')
  else:
    tmpdir = None

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir and log_step_count_steps and log_step_count_steps > 0:
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=os.path.join(summary_dir, tmpdir),
              every_n_steps=log_step_count_steps))

  if (((save_summaries_steps and save_summaries_steps > 0) or
       (save_summaries_secs and save_summaries_secs > 0)) and summary_dir):
    if worker_context.should_save_summary:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=os.path.join(summary_dir, tmpdir)))

  if (((save_checkpoint_secs and save_checkpoint_secs > 0) or
       (save_checkpoint_steps and save_checkpoint_steps > 0)) and
      checkpoint_dir):
    if worker_context.should_checkpoint:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))
    elif tmpdir:
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              os.path.join(checkpoint_dir, tmpdir),
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))

  logging.info('all_hooks %r', all_hooks)
  session_creator = worker_context.session_creator(
      scaffold,
      config=config,
      checkpoint_dir=checkpoint_dir,
      max_wait_secs=max_wait_secs)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.MonitoredTrainingSession'])
def MonitoredTrainingSession(
    master='',  # pylint: disable=invalid-name
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=USE_DEFAULT,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200,
    save_checkpoint_steps=USE_DEFAULT,
    summary_dir=None):
  """Creates a `MonitoredSession` for training.

  For a chief, this utility sets the proper session initializer/restorer. It
  also creates hooks related to checkpoint and summary saving. For workers,
  this utility sets a proper session creator which waits for the chief to
  initialize/restore. Please check `tf.train.MonitoredSession` for more
  information.
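  For example, a typical single-worker training loop looks like this (a
  minimal sketch; `train_op` stands in for the user's training op):

  ```python
  with tf.train.MonitoredTrainingSession(
      checkpoint_dir='/tmp/train', save_checkpoint_secs=600) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```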


  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief
      to initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If not
      specified, a default one is created. It's used to finalize the graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if
      `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If both `save_checkpoint_steps` and
      `save_checkpoint_secs` are set to `None`, then the default checkpoint
      saver isn't used. If both are provided, then only `save_checkpoint_secs`
      is used. Default 600.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are written
      to disk using a default summary saver.  If both `save_summaries_steps` and
      `save_summaries_secs` are set to `None`, then the default summary saver
      isn't used. Default not enabled.
    config: an instance of `tf.ConfigProto` used to configure the session.
      It's the `config` argument of the constructor of `tf.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.
    max_wait_secs: Maximum time workers should wait for the session to become
      available. This should be kept relatively short to help detect incorrect
      code, but sometimes may need to be increased if the chief takes a while to
      start up.
    save_checkpoint_steps: The frequency, in number of global steps, that a
      checkpoint is saved using a default checkpoint saver. If both
      `save_checkpoint_steps` and `save_checkpoint_secs` are set to `None`, then
      the default checkpoint saver isn't used. If both are provided, then only
      `save_checkpoint_secs` is used. Default not enabled.
    summary_dir: A string.  Optional path to a directory where to save
      summaries. If None, checkpoint_dir is used instead.

  Returns:
    A `MonitoredSession` object.
  """
  if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
    save_summaries_steps = 100
    save_summaries_secs = None
  elif save_summaries_secs == USE_DEFAULT:
    save_summaries_secs = None
  elif save_summaries_steps == USE_DEFAULT:
    save_summaries_steps = None

  if (save_checkpoint_steps == USE_DEFAULT and
      save_checkpoint_secs == USE_DEFAULT):
    save_checkpoint_steps = None
    save_checkpoint_secs = 600
  elif save_checkpoint_secs == USE_DEFAULT:
    save_checkpoint_secs = None
  elif save_checkpoint_steps == USE_DEFAULT:
    save_checkpoint_steps = None

  scaffold = scaffold or Scaffold()
  worker_context = distribute_coordinator_context.get_current_worker_context()

  if worker_context:
    return _create_monitored_session_with_worker_context(
        worker_context,
        scaffold,
        checkpoint_dir=checkpoint_dir,
        hooks=hooks,
        chief_only_hooks=chief_only_hooks,
        save_checkpoint_secs=save_checkpoint_secs,
        save_summaries_steps=save_summaries_steps,
        save_summaries_secs=save_summaries_secs,
        config=config,
        stop_grace_period_secs=stop_grace_period_secs,
        log_step_count_steps=log_step_count_steps,
        max_wait_secs=max_wait_secs,
        save_checkpoint_steps=save_checkpoint_steps,
        summary_dir=summary_dir)

  if not is_chief:
    session_creator = WorkerSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        max_wait_secs=max_wait_secs)
    return MonitoredSession(
        session_creator=session_creator,
        hooks=hooks or [],
        stop_grace_period_secs=stop_grace_period_secs)

  all_hooks = []
  if chief_only_hooks:
    all_hooks.extend(chief_only_hooks)
  session_creator = ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_dir=checkpoint_dir,
      master=master,
      config=config)

  summary_dir = summary_dir or checkpoint_dir
  if summary_dir:
    if log_step_count_steps and log_step_count_steps > 0:
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              output_dir=summary_dir, every_n_steps=log_step_count_steps))

    if (save_summaries_steps and
        save_summaries_steps > 0) or (save_summaries_secs and
                                      save_summaries_secs > 0):
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              scaffold=scaffold,
              save_steps=save_summaries_steps,
              save_secs=save_summaries_secs,
              output_dir=summary_dir))

  if checkpoint_dir:
    if (save_checkpoint_secs and
        save_checkpoint_secs > 0) or (save_checkpoint_steps and
                                      save_checkpoint_steps > 0):
      all_hooks.append(
          basic_session_run_hooks.CheckpointSaverHook(
              checkpoint_dir,
              save_steps=save_checkpoint_steps,
              save_secs=save_checkpoint_secs,
              scaffold=scaffold))

  if hooks:
    all_hooks.extend(hooks)
  return MonitoredSession(
      session_creator=session_creator,
      hooks=all_hooks,
      stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SessionCreator'])
@six.add_metaclass(abc.ABCMeta)
class SessionCreator(object):
  """A factory for tf.Session."""

  @abc.abstractmethod
  def create_session(self):
    raise NotImplementedError(
        'create_session is not implemented for {}.'.format(self))


@tf_export(v1=['train.ChiefSessionCreator'])
class ChiefSessionCreator(SessionCreator):
  """Creates a tf.Session for a chief."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               checkpoint_filename_with_path=None):
    """Initializes a chief session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
    """
    self._checkpoint_dir = checkpoint_dir
    self._checkpoint_filename_with_path = checkpoint_filename_with_path
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().prepare_session(
        self._master,
        saver=self._scaffold.saver,
        checkpoint_dir=self._checkpoint_dir,
        checkpoint_filename_with_path=self._checkpoint_filename_with_path,
        config=self._config,
        init_op=self._scaffold.init_op,
        init_feed_dict=self._scaffold.init_feed_dict,
        init_fn=self._scaffold.init_fn)


@tf_export(v1=['train.WorkerSessionCreator'])
class WorkerSessionCreator(SessionCreator):
  """Creates a tf.Session for a worker."""

  def __init__(self,
               scaffold=None,
               master='',
               config=None,
               max_wait_secs=30 * 60):
    """Initializes a worker session creator.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      max_wait_secs: Maximum time to wait for the session to become available.
    """
    self._scaffold = scaffold or Scaffold()
    self._session_manager = None
    self._master = master
    self._config = config
    self._max_wait_secs = max_wait_secs

  def _get_session_manager(self):
    if self._session_manager:
      return self._session_manager

    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        ready_for_local_init_op=self._scaffold.ready_for_local_init_op,
        graph=ops.get_default_graph())
    return self._session_manager

  def create_session(self):
    self._scaffold.finalize()
    return self._get_session_manager().wait_for_session(
        self._master, config=self._config, max_wait_secs=self._max_wait_secs)


class _MonitoredSession(object):
  """See `MonitoredSession` or `SingularMonitoredSession`."""

  def __init__(self,
               session_creator,
               hooks,
               should_recover,
               stop_grace_period_secs=120):
    """Sets up a Monitored or Hooked Session.

    Args:
      session_creator: A factory object to create the session. Typically a
        `ChiefSessionCreator` or a `WorkerSessionCreator`.
      hooks: An iterable of `SessionRunHook` objects.
      should_recover: A bool. Indicates whether to recover from `AbortedError`
        and `UnavailableError` or not.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    self._graph_was_finalized = ops.get_default_graph().finalized
    self._hooks = hooks or []
    for h in self._hooks:
      h.begin()

    worker_context = distribute_coordinator_context.get_current_worker_context()
    if not session_creator and worker_context:
      session_creator = worker_context.session_creator()

    # Create the session.
    self._coordinated_creator = self._CoordinatedSessionCreator(
        session_creator=session_creator or ChiefSessionCreator(),
        hooks=self._hooks,
        stop_grace_period_secs=stop_grace_period_secs)
    if should_recover:
      self._sess = _RecoverableSession(self._coordinated_creator)
    else:
      self._sess = self._coordinated_creator.create_session()

  @property
  def graph(self):
    """The graph that was launched in this session."""
    if self._tf_sess() is None:
      return None
    return self._tf_sess().graph

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run ops in the monitored session.

    This method is completely compatible with the `tf.Session.run()` method.

    Args:
      fetches: Same as `tf.Session.run()`.
      feed_dict: Same as `tf.Session.run()`.
      options: Same as `tf.Session.run()`.
      run_metadata: Same as `tf.Session.run()`.

    Returns:
      Same as `tf.Session.run()`.
    """
    return self._sess.run(
        fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

  def run_step_fn(self, step_fn):
    """Run ops using a step function.

    Args:
      step_fn: A function or a method with a single argument of type
        `StepContext`.  The function may use methods of the argument to perform
        computations with access to a raw session.  The returned value of the
        `step_fn` will be returned from `run_step_fn`, unless a stop is
        requested.  In that case, the next `should_stop` call will return True.
        Example usage:

        ```python
        with tf.Graph().as_default():
          c = tf.placeholder(dtypes.float32)
          v = tf.add(c, 4.0)
          w = tf.add(c, 0.5)

          def step_fn(step_context):
            a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
            if a <= 4.5:
              step_context.request_stop()
            return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

          with tf.MonitoredSession() as session:
            while not session.should_stop():
              a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the `step_fn`
        as they do with a `MonitoredSession.run` call.

    Returns:
      Returns the returned value of `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if `step_fn` doesn't have a single argument called
        `step_context`. It may also optionally have `self` for cases when it
        belongs to an object.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    if step_fn_arguments != ('step_context',) and step_fn_arguments != (
        'self',
        'step_context',
    ):
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either `_RecoverableSession` or a `_CoordinatedSession`.
    # Setting `run_with_hooks` to `None` will cause `run_with_hooks` to be
    # `_CoordinatedSession.run` downstream in either case. This allows
    # `_PREEMPTION_ERRORS` to propagate from within `step_fn` to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)

  class StepContext(object):
    """Control flow instrument for the `step_fn` from `run_step_fn()`.

       Users of `step_fn` may perform `run()` calls without running hooks
       by accessing the `session`.  A `run()` call with hooks may be performed
       using `run_with_hooks()`.  Computation flow can be interrupted using
       `request_stop()`.
    """

    def __init__(self, session, run_with_hooks_fn):
      """Initializes the `step_context` argument for a `step_fn` invocation.

      Args:
        session: An instance of `tf.Session`.
        run_with_hooks_fn: A function for running fetches and hooks.
      """
      self._session = session
      self._run_with_hooks_fn = run_with_hooks_fn

    @property
    def session(self):
      return self._session

    def run_with_hooks(self, *args, **kwargs):
      """Same as `MonitoredSession.run`. Accepts the same arguments."""
      return self._run_with_hooks_fn(*args, **kwargs)

    def request_stop(self):
      """Exit the training loop by causing `should_stop()` to return `True`.

         Causes `step_fn` to exit by raising an exception.

      Raises:
        StopIteration
      """
      raise StopIteration('step_fn has requested the iterations to stop.')

  def should_stop(self):
    return self._sess is None or self._sess.should_stop()

  def close(self):
    self._close_internal()

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type in [errors.OutOfRangeError, StopIteration]:
      exception_type = None
    self._close_internal(exception_type)
    # __exit__ should return True to suppress an exception.
    return exception_type is None

  class _CoordinatedSessionCreator(SessionCreator):
    """Factory for the _RecoverableSession."""

    def __init__(self, session_creator, hooks, stop_grace_period_secs):
      self._session_creator = session_creator
      self._hooks = hooks
      self.coord = None
      self.tf_sess = None
      self._stop_grace_period_secs = stop_grace_period_secs

    def create_session(self):
      """Creates a coordinated session."""
      # Keep the tf_sess for unit testing.
      self.tf_sess = self._session_creator.create_session()
      # We don't want coordinator to suppress any exception.
      self.coord = coordinator.Coordinator(clean_stop_exception_types=[])
      if ops.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
        queue_runner.start_queue_runners(sess=self.tf_sess, coord=self.coord)
      # Inform the hooks that a new session has been created.
      for hook in self._hooks:
        hook.after_create_session(self.tf_sess, self.coord)
      return _CoordinatedSession(
          _HookedSession(self.tf_sess, self._hooks), self.coord,
          self._stop_grace_period_secs)

  def _close_internal(self, exception_type=None):
    try:
      if not exception_type:
        for h in self._hooks:
          h.end(self._coordinated_creator.tf_sess)
    finally:
      try:
        if self._sess is None:
          raise RuntimeError('Session is already closed.')
        self._sess.close()
      finally:
        self._sess = None
        self._coordinated_creator.tf_sess = None
        self._coordinated_creator.coord = None
        if not self._graph_was_finalized:
          ops.get_default_graph()._unsafe_unfinalize()  # pylint: disable=protected-access

  def _is_closed(self):
    """Return True if the monitored session is closed.

    For tests only.

    Returns:
      A boolean.
    """
    return self._coordinated_creator.tf_sess is None

  def _tf_sess(self):
    """Return underlying tf.Session object.

    Warning: accessing the returned object in user code is likely to cause races
    or "flaky tests".

    Returns:
      A tf.Session object.
    """
    return self._coordinated_creator.tf_sess


@tf_export(v1=['train.MonitoredSession'])
class MonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, recovery and hooks.

  Example usage:

  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with MonitoredSession(session_creator=ChiefSessionCreator(...),
                        hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```

  Initialization: At creation time the monitored session does the following
  things, in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners
  * calls `hook.after_create_session()`

  Run: When `run()` is called, the monitored session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` requested by the user
  * if `AbortedError` or `UnavailableError` occurs, it recovers or
    reinitializes the session before executing the `run()` call again


  Exit: At `close()`, the monitored session does the following things, in
  order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error, which indicates that all inputs have
    been processed, if the monitored session is used as a context manager

  How to set `tf.Session` arguments:

  * In most cases you can set session arguments as follows:

  ```python
  MonitoredSession(
    session_creator=ChiefSessionCreator(master=..., config=...))
  ```

  * In a distributed setting for a non-chief worker, you can use the
    following:

  ```python
  MonitoredSession(
    session_creator=WorkerSessionCreator(master=..., config=...))
  ```

  See `MonitoredTrainingSession` for an example usage based on chief or worker.

  Note: This is not a `tf.Session`. For example, it cannot do the following:

  * it cannot be set as a default session.
  * it cannot be sent to saver.save.
  * it cannot be sent to tf.train.start_queue_runners.

  Args:
    session_creator: A factory object to create the session. Typically a
      `ChiefSessionCreator`, which is the default one.
    hooks: An iterable of `SessionRunHook` objects.

  Returns:
    A MonitoredSession object.
  """

  def __init__(self,
               session_creator=None,
               hooks=None,
               stop_grace_period_secs=120):
    super(MonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=True,
        stop_grace_period_secs=stop_grace_period_secs)


@tf_export(v1=['train.SingularMonitoredSession'])
class SingularMonitoredSession(_MonitoredSession):
  """Session-like object that handles initialization, restoring, and hooks.

  Please note that this utility is not recommended for distributed settings.
  For distributed settings, please use `tf.train.MonitoredSession`. The
  differences between `MonitoredSession` and `SingularMonitoredSession` are:

  * `MonitoredSession` handles `AbortedError` and `UnavailableError` for
    distributed settings, but `SingularMonitoredSession` does not.
  * `MonitoredSession` can be created in `chief` or `worker` modes.
    `SingularMonitoredSession` is always created as `chief`.
  * You can access the raw `tf.Session` object used by
    `SingularMonitoredSession`, whereas in MonitoredSession the raw session is
    private. This can be used:
      - To `run` without hooks.
      - To save and restore.
  * All other functionality is identical.

  Example usage:
  ```python
  saver_hook = CheckpointSaverHook(...)
  summary_hook = SummarySaverHook(...)
  with SingularMonitoredSession(hooks=[saver_hook, summary_hook]) as sess:
    while not sess.should_stop():
      sess.run(train_op)
  ```
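  Since the raw session is accessible, occasional side computations can
  bypass the hooks (a sketch; `train_op` and `side_op` are placeholders):

  ```python
  with SingularMonitoredSession(hooks=[saver_hook]) as sess:
    sess.run(train_op)               # runs with hooks
    sess.raw_session().run(side_op)  # runs without triggering any hooks
  ```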

  Initialization: At creation time the hooked session does the following
  things, in the given order:

  * calls `hook.begin()` for each given hook
  * finalizes the graph via `scaffold.finalize()`
  * creates the session
  * initializes the model via initialization ops provided by `Scaffold`
  * restores variables if a checkpoint exists
  * launches queue runners

  Run: When `run()` is called, the hooked session does the following things:

  * calls `hook.before_run()`
  * calls TensorFlow `session.run()` with merged fetches and feed_dict
  * calls `hook.after_run()`
  * returns the result of `session.run()` requested by the user

  Exit: At `close()`, the hooked session does the following things, in order:

  * calls `hook.end()`
  * closes the queue runners and the session
  * suppresses the `OutOfRange` error, which indicates that all inputs have
    been processed, if the `SingularMonitoredSession` is used as a context
    manager.
  """

  def __init__(self,
               hooks=None,
               scaffold=None,
               master='',
               config=None,
               checkpoint_dir=None,
               stop_grace_period_secs=120,
               checkpoint_filename_with_path=None):
    """Creates a SingularMonitoredSession.

    Args:
      hooks: An iterable of `SessionRunHook` objects.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      master: `String` representation of the TensorFlow master to use.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string.  Optional path to a directory where to restore
        variables.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
      checkpoint_filename_with_path: A string. Optional path to a checkpoint
        file from which to restore variables.
    """
    session_creator = ChiefSessionCreator(
        scaffold=scaffold,
        master=master,
        config=config,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    super(SingularMonitoredSession, self).__init__(
        session_creator,
        hooks,
        should_recover=False,
        stop_grace_period_secs=stop_grace_period_secs)

  def raw_session(self):
    """Returns the underlying `tf.Session` object."""
    return self._tf_sess()


class _WrappedSession(object):
  """Wrapper around a `tf.Session`.

  This wrapper is used as a base class for various session wrappers
  that provide additional functionality such as monitoring, coordination,
  and recovery.

  In addition to the methods exported by `SessionInterface` the wrapper
  provides a method to check for stop and never raises exceptions from
  calls to `close()`.
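  A subclass typically overrides `_check_stop()` to contribute its own stop
  condition, for example (an illustrative sketch, not a class in this file):

  ```python
  class _StopOnEventSession(_WrappedSession):
    # Stops as soon as an external threading.Event is set.

    def __init__(self, sess, stop_event):
      _WrappedSession.__init__(self, sess)
      self._stop_event = stop_event

    def _check_stop(self):
      # Consulted by `should_stop()` in addition to the wrapped session.
      return self._stop_event.is_set()
  ```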
  """

  def __init__(self, sess):
    """Creates a `_WrappedSession`.

    Args:
      sess: A `tf.Session` or `_WrappedSession` object.  The wrapped session.
    """
    self._sess = sess
    self._wrapped_is_stoppable = isinstance(self._sess, _WrappedSession)

  @property
  def graph(self):
    return self._sess.graph

  @property
  def sess_str(self):
    return self._sess.sess_str

  def should_stop(self):
    """Returns True if this session should not be used anymore.

    Always returns True if the session was closed.

    Returns:
      True if the session should stop, False otherwise.
    """
    if self._check_stop():
      return True
    if self._sess:
      return self._wrapped_is_stoppable and self._sess.should_stop()
    return True

  def _check_stop(self):
    """Hook for subclasses to provide their own stop condition.

    Returns:
      True if the session should stop, False otherwise.
    """
    return False

  def close(self):
    if self._sess:
      try:
        self._sess.close()
      except _PREEMPTION_ERRORS as e:
        logging.warning(
            'An error occurred when attempting to close the '
            'session. This may be due to a preemption in a '
            'connected worker or parameter server. Error: %s', e)
      finally:
        self._sess = None

  def run(self, *args, **kwargs):
    return self._sess.run(*args, **kwargs)

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    # `_RecoverableSession` sets `run_with_hooks` to `_CoordinatedSession.run`.
    # It is `None` when called from `_CoordinatedSession`. In that case
    # `self.run` is `_CoordinatedSession.run`.
    run_with_hooks = run_with_hooks or self.run
    return step_fn(_MonitoredSession.StepContext(raw_session, run_with_hooks))


class _RecoverableSession(_WrappedSession):
  """A wrapped session that recreates a session upon certain kinds of errors.

  The constructor is passed a SessionCreator object, not a session.

  Calls to `run()` are delegated to the wrapped session.  If a call raises the
  exception `tf.errors.AbortedError` or `tf.errors.UnavailableError`, the
  wrapped session is closed, and a new one is created by calling the factory
  again.
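  For example (an illustrative sketch of the internal wiring; `scaffold` and
  `train_op` are placeholders):

  ```python
  sess = _RecoverableSession(ChiefSessionCreator(scaffold=scaffold))
  # If this run hits AbortedError or UnavailableError, the wrapped session
  # is closed and transparently recreated from the creator before retrying.
  result = sess.run(train_op)
  ```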
  """

  def __init__(self, sess_creator):
    """Create a new `_RecoverableSession`.

    The value returned by calling `sess_creator.create_session()` will be the
    session wrapped by this recoverable session.

    Args:
      sess_creator: A `SessionCreator` to be wrapped by the recoverable
        session.
    """
    self._sess_creator = sess_creator
    _WrappedSession.__init__(self, self._create_session())

  def _create_session(self):
    while True:
      try:
        return self._sess_creator.create_session()
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised while a session was being created. '
            'This may be due to a preemption of a connected worker '
            'or parameter server. A new session will be created. '
            'This error may also occur due to a gRPC failure caused '
            'by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)

  def _check_stop(self):
    try:
      if self._sess:
        return self._sess._check_stop()  # pylint: disable=protected-access
      else:
        return True
    except _PREEMPTION_ERRORS as e:
      logging.info(
          'An error was raised while considering whether the '
          'session is complete. This may be due to a preemption in '
          'a connected worker or parameter server. The current '
          'session will be closed and a new session will be '
          'created. This error may also occur due to a gRPC failure '
          'caused by high memory or network bandwidth usage in the '
          'parameter servers. If this error occurs repeatedly, try '
          'increasing the number of parameter servers assigned to '
          'the job. Error: %s', e)
      self.close()
      self._sess = self._create_session()
      # Since we have just recreated the session, the overall computation should
      # not stop:
      return False
    except Exception:  # pylint: disable=broad-except
      # `should_stop` should return True instead of raising an exception.
      return True

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()
        return self._sess.run(
            fetches,
            feed_dict=feed_dict,
            options=options,
            run_metadata=run_metadata)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None

  def run_step_fn(self, step_fn, raw_session, run_with_hooks):
    while True:
      try:
        if not self._sess:
          self._sess = self._create_session()

        run_with_hooks = self._sess.run
        return self._sess.run_step_fn(step_fn, raw_session, run_with_hooks)
      except _PREEMPTION_ERRORS as e:
        logging.info(
            'An error was raised. This may be due to a preemption in '
            'a connected worker or parameter server. The current '
            'session will be closed and a new session will be '
            'created. This error may also occur due to a gRPC failure '
            'caused by high memory or network bandwidth usage in the '
            'parameter servers. If this error occurs repeatedly, try '
            'increasing the number of parameter servers assigned to '
            'the job. Error: %s', e)
        self.close()
        self._sess = None


class _CoordinatedSession(_WrappedSession):
  """A wrapped session that works with a `tf.Coordinator`.

  Calls to `run()` are delegated to the wrapped session.  If a call
  raises an exception, the exception is reported to the coordinator.
  In addition, after each call to `run()` this session asks the coordinator
  whether the session should stop.  If so, it will join all the threads
  registered with the coordinator before returning.

  If the coordinator was requested to stop with an exception, that exception
  will be re-raised from the call to `run()`.
  """

  def __init__(self, sess, coord, stop_grace_period_secs=120):
    """Create a new `_CoordinatedSession`.

    Args:
      sess: A `tf.Session` object.  The wrapped session.
      coord: A `tf.train.Coordinator` object.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `close()` has been called.
    """
    _WrappedSession.__init__(self, sess)
    self._coord = coord
    self._stop_grace_period_secs = stop_grace_period_secs

  def _check_stop(self):
    # If the coordinator was asked to stop due to an exception, then it needs
    # to be propagated to this stack.
    self._coord.raise_requested_exception()
    # At this point, no exceptions are recorded in the coordinator.
    return self._coord.should_stop()

  def close(self):
    self._coord.request_stop()
    try:
      self._coord.join(
          stop_grace_period_secs=self._stop_grace_period_secs,
          ignore_live_threads=True)
    finally:
      try:
        _WrappedSession.close(self)
      except Exception:  # pylint: disable=broad-except
        # We intentionally suppress exceptions from the close() here since
        # useful exceptions are already reported by join().
        pass

  def run(self, *args, **kwargs):
    try:
      return self._sess.run(*args, **kwargs)
    except _PREEMPTION_ERRORS:
      raise
    except Exception:  # pylint: disable=broad-except
      # A non-preemption error could have been caused by a preemption error
      # in the coordinator. If this is the case, raise that exception instead,
      # since it's the root cause. Otherwise, stick to the `original_exc_info`.
      original_exc_info = sys.exc_info()
      try:
        self._coord.raise_requested_exception()
      except _PREEMPTION_ERRORS:
        raise
      except Exception:  # pylint: disable=broad-except
        raise six.reraise(*original_exc_info)
      else:
        raise six.reraise(*original_exc_info)


class _HookedSession(_WrappedSession):
  """A _WrappedSession that calls hooks during calls to run().

  The list of hooks to call is passed in the constructor.  Before each call
  to `run()` the session calls the `before_run()` method of the hooks, which
  can return additional ops or tensors to run.  These are added to the arguments
  of the call to `run()`.

  When the `run()` call finishes, the session calls the `after_run()` methods of
  the hooks, passing the values returned by the `run()` call corresponding to
  the ops and tensors that each hook requested.

  If any of the hooks requests a stop via `run_context`, the session will be
  marked as needing to stop, and its `should_stop()` method will then return
  `True`.
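  For example, a hook that fetches an extra tensor on every step and stops
  the session when it crosses a threshold might look like this (an
  illustrative sketch; `loss_tensor` is a placeholder):

  ```python
  class _StopOnHighLossHook(session_run_hook.SessionRunHook):

    def before_run(self, run_context):
      # Ask the session to additionally fetch `loss_tensor` on this step.
      return session_run_hook.SessionRunArgs(loss_tensor)

    def after_run(self, run_context, run_values):
      # `run_values.results` holds the fetched value of `loss_tensor`.
      if run_values.results > 100.0:
        run_context.request_stop()
  ```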
  """

  def __init__(self, sess, hooks):
    """Initializes a _HookedSession object.

    Args:
      sess: A `tf.Session` or a `_WrappedSession` object.
      hooks: An iterable of `SessionRunHook` objects.
    """

    _WrappedSession.__init__(self, sess)
    self._hooks = hooks
    self._should_stop = False

  def _check_stop(self):
    """See base class."""
    return self._should_stop

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """See base class."""
    if self.should_stop():
      raise RuntimeError('Run called even after should_stop requested.')

    actual_fetches = {'caller': fetches}

    run_context = session_run_hook.SessionRunContext(
        original_args=session_run_hook.SessionRunArgs(fetches, feed_dict),
        session=self._sess)

    options = options or config_pb2.RunOptions()
    feed_dict = self._call_hook_before_run(run_context, actual_fetches,
                                           feed_dict, options)

    # Do session run.
    run_metadata = run_metadata or config_pb2.RunMetadata()
    outputs = _WrappedSession.run(
        self,
        fetches=actual_fetches,
        feed_dict=feed_dict,
        options=options,
        run_metadata=run_metadata)

    for hook in self._hooks:
      hook.after_run(
          run_context,
          session_run_hook.SessionRunValues(
              results=outputs[hook] if hook in outputs else None,
              options=options,
              run_metadata=run_metadata))
    self._should_stop = self._should_stop or run_context.stop_requested

    return outputs['caller']

  def _call_hook_before_run(self, run_context, fetch_dict, user_feed_dict,
                            options):
    """Calls hooks.before_run and handles requests from hooks."""
    hook_feeds = {}
    for hook in self._hooks:
      request = hook.before_run(run_context)
      if request is not None:
        if request.fetches is not None:
          fetch_dict[hook] = request.fetches
        if request.feed_dict:
          self._raise_if_feeds_intersects(hook_feeds, request.feed_dict,
                                          'Same tensor is fed by two hooks.')
          hook_feeds.update(request.feed_dict)
        if request.options:
          self._merge_run_options(options, request.options)

    if not hook_feeds:
      return user_feed_dict

    if not user_feed_dict:
      return hook_feeds

    self._raise_if_feeds_intersects(
        user_feed_dict, hook_feeds,
        'Same tensor is fed by a SessionRunHook and user.')
    hook_feeds.update(user_feed_dict)
    return hook_feeds

  def _raise_if_feeds_intersects(self, feeds1, feeds2, message):
    intersection = set(feeds1.keys()) & set(feeds2.keys())
    if intersection:
      raise RuntimeError(message + ' Conflict(s): ' + str(list(intersection)))

  def _merge_run_options(self, options, incoming_options):
    """Merge two instances of RunOptions into the first one.

    During the merge, the numerical fields, including trace_level,
    timeout_in_ms, and inter_op_thread_pool, are set to the larger of the two.
    The boolean fields are set to the logical OR of the two.
    debug_tensor_watch_opts of the original options is extended with that of
    the incoming one.

    Args:
      options: The options to merge into.
      incoming_options: The options to be merged into the first argument.
    """
    options.trace_level = max(options.trace_level, incoming_options.trace_level)
    options.timeout_in_ms = max(options.timeout_in_ms,
                                incoming_options.timeout_in_ms)
    options.inter_op_thread_pool = max(options.inter_op_thread_pool,
                                       incoming_options.inter_op_thread_pool)
    options.output_partition_graphs = max(
        options.output_partition_graphs,
        incoming_options.output_partition_graphs)
    options.debug_options.debug_tensor_watch_opts.extend(
        incoming_options.debug_options.debug_tensor_watch_opts)
    options.debug_options.reset_disk_byte_usage = (
        options.debug_options.reset_disk_byte_usage or
        incoming_options.debug_options.reset_disk_byte_usage)
    options.report_tensor_allocations_upon_oom = (
        options.report_tensor_allocations_upon_oom or
        incoming_options.report_tensor_allocations_upon_oom)