# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""High level operations on graphs (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import sys
import threading
import time

import numpy as np

from six import reraise

from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.ops import ops as contrib_ops
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import session_manager as session_manager_lib
from tensorflow.python.training import summary_io
from tensorflow.python.training import supervisor as tf_supervisor
from tensorflow.python.util.deprecation import deprecated

# Singleton for SummaryWriter per logdir folder.
_SUMMARY_WRITERS = {}

# Lock protecting _SUMMARY_WRITERS
_summary_writer_lock = threading.Lock()

_graph_action_deprecation = deprecated(
    '2017-02-15',
    'graph_actions.py will be deleted. Use tf.train.* utilities instead. '
    'You can use learn/estimators/estimator.py as an example.')


@_graph_action_deprecation
def clear_summary_writers():
  """Clear cached summary writers. Currently only used for unit tests."""
  return summary_io.SummaryWriterCache.clear()


@deprecated(None, 'Use `SummaryWriterCache.get` directly.')
def get_summary_writer(logdir):
  """Returns single SummaryWriter per logdir in current run.

  Args:
    logdir: str, folder to write summaries.

  Returns:
    Existing `SummaryWriter` object or new one if never wrote to given
    directory.
  """
  return summary_io.SummaryWriterCache.get(logdir)


def _make_saver(graph, keep_checkpoint_max=5):
  vars_to_save = (graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
                  graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))
  if vars_to_save:
    return tf_saver.Saver(vars_to_save,
                          sharded=True,
                          max_to_keep=keep_checkpoint_max)
  else:
    return None


def _restore_from_checkpoint(session, graph, checkpoint_path, saver=None):
  logging.info('Loading model from checkpoint: %s.', checkpoint_path)
  saver = saver or _make_saver(graph)
  if saver:
    saver.restore(session, checkpoint_path)
  else:
    logging.info('No variables found in graph, not creating Saver() object.')


def _run_with_monitors(session, step, tensors, feed_dict, monitors):
  """Runs session for given tensors with monitor callbacks."""
  for monitor in monitors:
    tensors += monitor.step_begin(step)
  tensors = list(set(tensors))

  outputs = session.run(tensors, feed_dict=feed_dict)
  outputs = dict(zip(
      [t.name if isinstance(t, ops.Tensor) else t for t in tensors],
      outputs))

  should_stop = False
  for monitor in monitors:
    induce_stop = monitor.step_end(step, outputs)
    should_stop = should_stop or induce_stop
  return outputs, should_stop


@_graph_action_deprecation
def train(graph,
          output_dir,
          train_op,
          loss_op,
          global_step_tensor=None,
          init_op=None,
          init_feed_dict=None,
          init_fn=None,
          log_every_steps=10,
          supervisor_is_chief=True,
          supervisor_master='',
          supervisor_save_model_secs=600,
          keep_checkpoint_max=5,
          supervisor_save_summaries_steps=100,
          feed_fn=None,
          steps=None,
          fail_on_nan_loss=True,
          monitors=None,
          max_steps=None):
  """Train a model.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on
  the model and is expected to increment the `global_step_tensor`, a scalar
  integer tensor counting training steps. The `loss_op` represents the
  objective function of the training. This function uses `Supervisor` to
  initialize the graph (from a checkpoint if one is available in `output_dir`),
  write summaries defined in the graph, and write regular checkpoints as
  defined by `supervisor_save_model_secs`.

  Training continues until `global_step_tensor` evaluates to `max_steps`, or,
  if `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`, in which case a
  `NanLossDuringTrainingError` is raised.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to Supervisor to initialize the model.
    log_every_steps: Output logs every `log_every_steps` steps. The logs
      contain timing data and the current loss.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save a checkpoint every
      `supervisor_save_model_secs` seconds when training.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If `None` or 0,
      all checkpoint files are kept. This is simply passed as the `max_to_keep`
      arg to the `tf.train.Saver` constructor.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    monitors: List of `BaseMonitor` subclass instances. Used for callbacks
      inside the training loop.
    max_steps: Number of total steps for which to train the model. If `None`,
      train forever. Two calls to `fit(steps=100)` mean 200 training
      iterations. By contrast, the second of two calls to `fit(max_steps=100)`
      performs no iterations, since the first call already did all 100 steps.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
      is not provided. See `tf.contrib.framework.get_global_step` for how we
      look up the latter if not provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
      evaluates to `NaN`.
    ValueError: If both `steps` and `max_steps` are not `None`.
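
  Example:
    A minimal usage sketch; the toy model, optimizer choice, and output
    directory below are illustrative placeholders, not part of this module:

      import tensorflow as tf

      graph = tf.Graph()
      with graph.as_default():
        w = tf.Variable(3.0, name='w')
        loss = tf.square(w - 1.0)
        global_step = tf.contrib.framework.get_or_create_global_step()
        # The train op both minimizes the loss and increments the global step.
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
            loss, global_step=global_step)

      final_loss = train(graph=graph,
                         output_dir='/tmp/graph_actions_example',
                         train_op=train_op,
                         loss_op=loss,
                         steps=100)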
  """
  while True:
    try:
      return _train_internal(graph,
                             output_dir,
                             train_op,
                             loss_op,
                             global_step_tensor,
                             init_op,
                             init_feed_dict,
                             init_fn,
                             log_every_steps,
                             supervisor_is_chief,
                             supervisor_master,
                             supervisor_save_model_secs,
                             keep_checkpoint_max,
                             supervisor_save_summaries_steps,
                             feed_fn,
                             steps,
                             fail_on_nan_loss,
                             monitors,
                             max_steps)
    except errors.AbortedError:
      # Happens when PS restarts, keep training.
      logging.warning('Training got Aborted error. Keep training.')


def _train_internal(graph,
                    output_dir,
                    train_op,
                    loss_op,
                    global_step_tensor,
                    init_op,
                    init_feed_dict,
                    init_fn,
                    log_every_steps,
                    supervisor_is_chief,
                    supervisor_master,
                    supervisor_save_model_secs,
                    keep_checkpoint_max,
                    supervisor_save_summaries_steps,
                    feed_fn,
                    steps,
                    fail_on_nan_loss,
                    monitors,
                    max_steps):
  """See train."""
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')

  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
    if global_step_tensor is None:
      raise ValueError('No "global_step" was provided or found in the graph.')

    # Get current step.
    try:
      start_step = load_variable(output_dir, global_step_tensor.name)
    except (errors.NotFoundError, ValueError):
      start_step = 0

    summary_writer = (get_summary_writer(output_dir)
                      if supervisor_is_chief else None)

    # Add default chief monitors if none were provided.
    if not monitors:
      monitors = monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=supervisor_save_summaries_steps,
          summary_writer=summary_writer) if supervisor_is_chief else []

    # TODO(ipolosukhin): Replace all functionality of Supervisor
    # with Chief-Exclusive Monitors.
    if not supervisor_is_chief:
      # Prune the list of monitors to the ones runnable on all workers.
      monitors = [monitor for monitor in monitors if monitor.run_on_all_workers]

    if max_steps is None:
      max_steps = (start_step + steps) if steps else None
    # Start monitors, can create graph parts.
    for monitor in monitors:
      monitor.begin(max_steps=max_steps)

  supervisor = tf_supervisor.Supervisor(
      graph,
      init_op=init_op or tf_supervisor.Supervisor.USE_DEFAULT,
      init_feed_dict=init_feed_dict,
      is_chief=supervisor_is_chief,
      logdir=output_dir,
      saver=_make_saver(graph, keep_checkpoint_max),
      global_step=global_step_tensor,
      summary_op=None,
      summary_writer=summary_writer,
      save_model_secs=supervisor_save_model_secs,
      init_fn=init_fn)
  session = supervisor.PrepareSession(master=supervisor_master,
                                      start_standard_services=True)
  supervisor.StartQueueRunners(session)

  with session:
    get_current_step = lambda: session.run(global_step_tensor)

    start_step = get_current_step()
    last_step = start_step
    last_log_step = start_step
    loss_value = None
    logging.info('Training steps [%d,%s)', last_step, 'inf'
                 if max_steps is None else str(max_steps))

    excinfo = None
    try:
      while not supervisor.ShouldStop() and (
          (max_steps is None) or (last_step < max_steps)):
        start_time = time.time()
        feed_dict = feed_fn() if feed_fn is not None else None

        outputs, should_stop = _run_with_monitors(
            session, last_step + 1, [train_op, loss_op], feed_dict, monitors)

        loss_value = outputs[loss_op.name]
        if np.isnan(loss_value):
          failure_message = 'Model diverged with loss = NaN.'
          if fail_on_nan_loss:
            logging.error(failure_message)
            raise monitors_lib.NanLossDuringTrainingError()
          else:
            logging.warning(failure_message)

        if should_stop:
          break

        this_step = get_current_step()

        if this_step <= last_step:
          logging.error(
              'Global step was not incremented by train op at step %s'
              ': new step %d', last_step, this_step)

        last_step = this_step
        is_last_step = (max_steps is not None) and (last_step >= max_steps)
        if is_last_step or (last_step - last_log_step >= log_every_steps):
          logging.info(
              'training step %d, loss = %.5f (%.3f sec/batch).',
              last_step, loss_value, float(time.time() - start_time))
          last_log_step = last_step
    except errors.OutOfRangeError as e:
      logging.warn('Got exception during tf.learn training loop possibly '
                   'due to exhausted input queue %s.', e)
    except StopIteration:
      logging.info('Exhausted input iterator.')
    except BaseException as e:  # pylint: disable=broad-except
      # Hold on to any other exceptions while we try recording a final
      # checkpoint and summary.
      excinfo = sys.exc_info()
    finally:
      try:
        # Call supervisor.Stop() from within a try block because it re-raises
        # exceptions thrown by the supervised threads.
        supervisor.Stop(close_summary_writer=False)

        # Save one last checkpoint and summaries
        # TODO(wicke): This should be handled by Supervisor

        # In case we encountered an exception in the try block before we updated
        # last_step, update it here (again).
        last_step = get_current_step()
        if supervisor_is_chief:
          ckpt_path = supervisor.save_path
          logging.info('Saving checkpoint for step %d to checkpoint: %s.',
                       last_step, ckpt_path)
          supervisor.saver.save(session, ckpt_path, global_step=last_step)

          # Finish monitors.
          for monitor in monitors:
            monitor.end()

      # catch OutOfRangeError which is thrown when queue is out of data (and for
      # other reasons as well).
      except errors.OutOfRangeError as e:
        logging.warn('OutOfRangeError in tf.learn final checkpoint possibly '
                     'due to exhausted input queue. Note: summary_op is not '
                     'expected to trigger dequeues. %s.', e)
      except BaseException as e:  # pylint: disable=broad-except
        # If we don't already have an exception to re-raise, raise this one.
        if not excinfo:
          raise
        # Otherwise, log this one and raise the other in the finally block.
        logging.error('Got exception during tf.learn final checkpoint %s.', e)
      finally:
        if excinfo:
          reraise(*excinfo)
    return loss_value


def _get_first_op_from_collection(collection_name):
  elements = ops.get_collection(collection_name)
  if elements:
    return elements[0]
  return None


def _get_saver():
  """Lazy init and return saver."""
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is None and variables.global_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver


def _get_ready_op():
  ready_op = _get_first_op_from_collection(ops.GraphKeys.READY_OP)
  if ready_op is None:
    ready_op = variables.report_uninitialized_variables()
    ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
  return ready_op


def _get_local_init_op():
  """Returns the local init ops to initialize tables and local variables."""
  local_init_op = _get_first_op_from_collection(
      ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is None:
    op_list = [
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer()
    ]
    if op_list:
      local_init_op = control_flow_ops.group(*op_list)
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op


def _eval_results_to_str(eval_results):
  return ', '.join('%s = %s' % (k, v) for k, v in sorted(eval_results.items()))


def _write_summary_results(output_dir, eval_results, current_global_step):
  """Writes eval results into summary file in given dir."""
  logging.info('Saving evaluation summary for step %d: %s', current_global_step,
               _eval_results_to_str(eval_results))
  summary_writer = get_summary_writer(output_dir)
  summary = summary_pb2.Summary()
  for key in eval_results:
    if eval_results[key] is None:
      continue
    value = summary.value.add()
    value.tag = key
    if (isinstance(eval_results[key], np.float32) or
        isinstance(eval_results[key], float)):
      value.simple_value = float(eval_results[key])
    else:
      logging.warn('Skipping summary for %s, must be a float or np.float32.',
                   key)
  summary_writer.add_summary(summary, current_global_step)
  summary_writer.flush()


@_graph_action_deprecation
def evaluate(graph,
             output_dir,
             checkpoint_path,
             eval_dict,
             update_op=None,
             global_step_tensor=None,
             supervisor_master='',
             log_every_steps=10,
             feed_fn=None,
             max_steps=None):
  """Evaluate a model loaded from a checkpoint.

  Given `graph`, a directory to write summaries to (`output_dir`), a checkpoint
  to restore variables from, and a `dict` of `Tensor`s to evaluate, run an eval
  loop for `max_steps` steps, or until an exception (generally, an
  end-of-input signal from a reader operation) is raised from running
  `eval_dict`.

  In each step of evaluation, all tensors in the `eval_dict` are evaluated, and
  every `log_every_steps` steps, they are logged. At the very end of evaluation,
  a summary is evaluated (finding the summary ops using `Supervisor`'s logic)
  and written to `output_dir`.

  Args:
    graph: A `Graph` to evaluate. It is expected that this graph is not in use
      elsewhere.
    output_dir: A string containing the directory to write a summary to.
    checkpoint_path: A string containing the path to a checkpoint to restore.
      Can be `None` if the graph doesn't require loading any variables.
    eval_dict: A `dict` mapping string names to tensors to evaluate. It is
      evaluated in every logging step. The result of the final evaluation is
      returned. If `update_op` is None, then it's evaluated in every step. If
      `max_steps` is `None`, this should depend on a reader that will raise an
      end-of-input exception when the inputs are exhausted.
    update_op: A `Tensor` which is run in every step.
    global_step_tensor: A `Variable` containing the global step. If `None`,
      one is extracted from the graph using the same logic as in `Supervisor`.
      Used to place eval summaries on training curves.
    supervisor_master: The master string to use when preparing the session.
    log_every_steps: Integer. Output logs every `log_every_steps` evaluation
      steps. The logs contain the `eval_dict` and timing information.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    max_steps: Integer. Evaluate `eval_dict` this many times.

  Returns:
    A tuple `(eval_results, global_step)`:
    eval_results: A `dict` mapping `string` to numeric values (`int`, `float`)
      that are the result of running eval_dict in the last step. `None` if no
      eval steps were run.
    global_step: The global step this evaluation corresponds to.

  Raises:
    ValueError: if `output_dir` is empty.
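
  Example:
    A minimal sketch, assuming a matching checkpoint was written earlier (for
    instance by `train` above); the variable, tensor names and paths are
    illustrative placeholders:

      import tensorflow as tf

      graph = tf.Graph()
      with graph.as_default():
        tf.contrib.framework.get_or_create_global_step()
        w = tf.Variable(3.0, name='w')
        loss = tf.square(w - 1.0)

      results, global_step = evaluate(
          graph=graph,
          output_dir='/tmp/graph_actions_example_eval',
          checkpoint_path='/tmp/graph_actions_example',
          eval_dict={'loss': loss},
          max_steps=1)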
  """
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)

    # Create or get summary op, global_step and saver.
    saver = _get_saver()
    local_init_op = _get_local_init_op()
    ready_for_local_init_op = _get_first_op_from_collection(
        ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
    ready_op = _get_ready_op()

    session_manager = session_manager_lib.SessionManager(
        local_init_op=local_init_op,
        ready_op=ready_op,
        ready_for_local_init_op=ready_for_local_init_op)
    session, initialized = session_manager.recover_session(
        master=supervisor_master,
        saver=saver,
        checkpoint_dir=checkpoint_path)

    # Start queue runners.
    coord = coordinator.Coordinator()
    threads = queue_runner.start_queue_runners(session, coord)

  with session:
    if not initialized:
      logging.warning('Failed to initialize from %s.', checkpoint_path)
      # TODO(ipolosukhin): This should be failing, but old code relies on that.
      session.run(variables.global_variables_initializer())
      if checkpoint_path:
        _restore_from_checkpoint(session, graph, checkpoint_path, saver)

    current_global_step = session.run(global_step_tensor)
    eval_results = None
    # TODO(amodei): Fix this to run through the eval set exactly once.
    step = 0
    eval_step = None
    feed_dict = None
    logging.info('Eval steps [%d,%s) for training step %d.', step,
                 'inf' if max_steps is None
                 else str(max_steps), current_global_step)
    try:
      try:
        while (max_steps is None) or (step < max_steps):
          step += 1
          start_time = time.time()
          feed_dict = feed_fn() if feed_fn is not None else None
          if update_op is not None:
            session.run(update_op, feed_dict=feed_dict)
          else:
            eval_results = session.run(eval_dict, feed_dict=feed_dict)
            eval_step = step

          # TODO(wicke): We should assert that the global step hasn't changed.
          if step % log_every_steps == 0:
            if eval_step is None or step != eval_step:
              eval_results = session.run(eval_dict, feed_dict=feed_dict)
              eval_step = step
            duration = time.time() - start_time
            logging.info('Results after %d steps (%.3f sec/batch): %s.',
                         step, float(duration),
                         _eval_results_to_str(eval_results))
      finally:
        if eval_results is None or step != eval_step:
          eval_results = session.run(eval_dict, feed_dict=feed_dict)
          eval_step = step
        # Stop session first, before queue runners.
        session.close()

        # Stop queue runners.
        try:
          coord.request_stop()
          coord.join(threads, stop_grace_period_secs=120)
        except (RuntimeError, errors.CancelledError) as e:
          logging.warning('Coordinator didn\'t stop cleanly: %s', e)

    # catch OutOfRangeError which is thrown when queue is out of data (and for
    # other reasons as well).
    except errors.OutOfRangeError as e:
      if max_steps is None:
        logging.info('Input queue is exhausted.')
      else:
        logging.warn('Input queue is exhausted: %s.', e)
    # catch StopIteration which is thrown when DataReader is out of data.
    except StopIteration as e:
      if max_steps is None:
        logging.info('Input iterator is exhausted.')
      else:
        logging.warn('Input iterator is exhausted: %s.', e)

  # Save summaries for this evaluation.
  _write_summary_results(output_dir, eval_results, current_global_step)

  return eval_results, current_global_step


@_graph_action_deprecation
def run_n(output_dict, feed_dict=None, restore_checkpoint_path=None, n=1):
  """Run `output_dict` tensors `n` times, with the same `feed_dict` each run.

  Args:
    output_dict: A `dict` mapping string names to tensors to run. Must all be
      from the same graph.
    feed_dict: `dict` of input values to feed each run.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.
    n: Number of times to repeat.

  Returns:
    A list of `n` `dict` objects, each containing values read from `output_dict`
    tensors.
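
  Example:
    A small sketch; the placeholder tensor and the feed value are illustrative:

      import tensorflow as tf

      x = tf.placeholder(tf.float32, shape=[])
      y = tf.square(x)
      results = run_n({'y': y}, feed_dict={x: 3.0}, n=2)
      # results == [{'y': 9.0}, {'y': 9.0}]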
  """
  return run_feeds(
      output_dict=output_dict,
      feed_dicts=itertools.repeat(feed_dict, n),
      restore_checkpoint_path=restore_checkpoint_path)


@_graph_action_deprecation
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Yields:
    A sequence of dicts of values read from `output_dict` tensors, one item
    yielded for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
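
  Example:
    A small sketch; the placeholder tensor and the feed values are
    illustrative:

      import tensorflow as tf

      x = tf.placeholder(tf.float32, shape=[])
      y = x * 2.0
      for result in run_feeds_iter({'y': y},
                                   feed_dicts=[{x: 1.0}, {x: 2.0}]):
        print(result)  # {'y': 2.0}, then {'y': 4.0}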
  """
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % output_dict)
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as g:
    with tf_session.Session('') as session:
      session.run(
          resources.initialize_resources(resources.shared_resources() +
                                         resources.local_resources()))
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.global_variables_initializer())
      session.run(variables.local_variables_initializer())
      session.run(lookup_ops.tables_initializer())
      coord = coordinator.Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        for f in feed_dicts:
          yield session.run(output_dict, f)
      finally:
        coord.request_stop()
        if threads:
          coord.join(threads, stop_grace_period_secs=120)


@_graph_action_deprecation
def run_feeds(*args, **kwargs):
  """See run_feeds_iter(). Returns a `list` instead of an iterator."""
  return list(run_feeds_iter(*args, **kwargs))


@_graph_action_deprecation
def infer(restore_checkpoint_path, output_dict, feed_dict=None):
  """Restore graph from `restore_checkpoint_path` and run `output_dict` tensors.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dict: `dict` object mapping `Tensor` objects to input values to feed.

  Returns:
    Dict of values read from `output_dict` tensors. Keys are the same as
    `output_dict`, values are the results read from the corresponding `Tensor`
    in `output_dict`.

  Raises:
    ValueError: if `output_dict` is None or empty.
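
  Example:
    A minimal sketch; the checkpoint path is an illustrative placeholder and
    the variables in the graph must match the ones saved at that path:

      import tensorflow as tf

      w = tf.Variable(3.0, name='w')
      doubled = w * 2.0
      outputs = infer('/tmp/graph_actions_example/model.ckpt-100',
                      {'doubled': doubled})
      # outputs maps 'doubled' to the restored value of `w` times two.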
  """
  return run_feeds(output_dict=output_dict,
                   feed_dicts=[feed_dict] if feed_dict is not None else [None],
                   restore_checkpoint_path=restore_checkpoint_path)[0]
