# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""High level operations on graphs (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools
import sys
import threading
import time

import numpy as np

from six import reraise

from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.ops import ops as contrib_ops
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import monitors as monitors_lib
from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import session_manager as session_manager_lib
from tensorflow.python.training import summary_io
from tensorflow.python.training import supervisor as tf_supervisor
from tensorflow.python.util.deprecation import deprecated

# Singleton for SummaryWriter per logdir folder.
_SUMMARY_WRITERS = {}

# Lock protecting _SUMMARY_WRITERS
_summary_writer_lock = threading.Lock()

_graph_action_deprecation = deprecated(
    '2017-02-15',
    'graph_actions.py will be deleted. Use tf.train.* utilities instead. '
    'You can use learn/estimators/estimator.py as an example.')


@_graph_action_deprecation
def clear_summary_writers():
  """Clear cached summary writers. Currently only used for unit tests."""
  return summary_io.SummaryWriterCache.clear()


@deprecated(None, 'Use `SummaryWriterCache.get` directly.')
def get_summary_writer(logdir):
  """Returns single SummaryWriter per logdir in current run.

  Args:
    logdir: str, folder to write summaries.

  Returns:
    Existing `SummaryWriter` object or new one if never wrote to given
    directory.
86 """ 87 return summary_io.SummaryWriterCache.get(logdir) 88 89 90def _make_saver(graph, keep_checkpoint_max=5): 91 vars_to_save = (graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + 92 graph.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS)) 93 if vars_to_save: 94 return tf_saver.Saver(vars_to_save, 95 sharded=True, 96 max_to_keep=keep_checkpoint_max) 97 else: 98 return None 99 100 101def _restore_from_checkpoint(session, graph, checkpoint_path, saver=None): 102 logging.info('Loading model from checkpoint: %s.', checkpoint_path) 103 saver = saver or _make_saver(graph) 104 if saver: 105 saver.restore(session, checkpoint_path) 106 else: 107 logging.info('No variables found in graph, not creating Saver() object.') 108 109 110def _run_with_monitors(session, step, tensors, feed_dict, monitors): 111 """Runs session for given tensors with monitor callbacks.""" 112 for monitor in monitors: 113 tensors += monitor.step_begin(step) 114 tensors = list(set(tensors)) 115 116 outputs = session.run(tensors, feed_dict=feed_dict) 117 outputs = dict(zip( 118 [t.name if isinstance(t, ops.Tensor) else t for t in tensors], 119 outputs)) 120 121 should_stop = False 122 for monitor in monitors: 123 induce_stop = monitor.step_end(step, outputs) 124 should_stop = should_stop or induce_stop 125 return outputs, should_stop 126 127 128@_graph_action_deprecation 129def train(graph, 130 output_dir, 131 train_op, 132 loss_op, 133 global_step_tensor=None, 134 init_op=None, 135 init_feed_dict=None, 136 init_fn=None, 137 log_every_steps=10, 138 supervisor_is_chief=True, 139 supervisor_master='', 140 supervisor_save_model_secs=600, 141 keep_checkpoint_max=5, 142 supervisor_save_summaries_steps=100, 143 feed_fn=None, 144 steps=None, 145 fail_on_nan_loss=True, 146 monitors=None, 147 max_steps=None): 148 """Train a model. 149 150 Given `graph`, a directory to write outputs to (`output_dir`), and some ops, 151 run a training loop. The given `train_op` performs one step of training on the 152 model. The `loss_op` represents the objective function of the training. It is 153 expected to increment the `global_step_tensor`, a scalar integer tensor 154 counting training steps. This function uses `Supervisor` to initialize the 155 graph (from a checkpoint if one is available in `output_dir`), write summaries 156 defined in the graph, and write regular checkpoints as defined by 157 `supervisor_save_model_secs`. 158 159 Training continues until `global_step_tensor` evaluates to `max_steps`, or, if 160 `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the 161 program is terminated with exit code 1. 162 163 Args: 164 graph: A graph to train. It is expected that this graph is not in use 165 elsewhere. 166 output_dir: A directory to write outputs to. 167 train_op: An op that performs one training step when run. 168 loss_op: A scalar loss tensor. 169 global_step_tensor: A tensor representing the global step. If none is given, 170 one is extracted from the graph using the same logic as in `Supervisor`. 171 init_op: An op that initializes the graph. If `None`, use `Supervisor`'s 172 default. 173 init_feed_dict: A dictionary that maps `Tensor` objects to feed values. 174 This feed dictionary will be used when `init_op` is evaluated. 175 init_fn: Optional callable passed to Supervisor to initialize the model. 176 log_every_steps: Output logs regularly. The logs contain timing data and the 177 current loss. 
    supervisor_is_chief: Whether the current process is the chief supervisor
      in charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save a checkpoint every
      `supervisor_save_model_secs` seconds when training.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If None or 0,
      all checkpoint files are kept. This is simply passed as the max_to_keep
      arg to the `tf.train.Saver` constructor.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    monitors: List of `BaseMonitor` subclass instances. Used for callbacks
      inside the training loop.
    max_steps: Number of total steps for which to train the model. If `None`,
      train forever. Two calls to fit(steps=100) means 200 training iterations.
      On the other hand, two calls to fit(max_steps=100) means that the second
      call will not do any iterations, since the first call did all 100 steps.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
      is not provided. See `tf.contrib.framework.get_global_step` for how we
      look up the latter if not provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
      evaluates to `NaN`.
    ValueError: If both `steps` and `max_steps` are not `None`.
  """
  while True:
    try:
      return _train_internal(graph,
                             output_dir,
                             train_op,
                             loss_op,
                             global_step_tensor,
                             init_op,
                             init_feed_dict,
                             init_fn,
                             log_every_steps,
                             supervisor_is_chief,
                             supervisor_master,
                             supervisor_save_model_secs,
                             keep_checkpoint_max,
                             supervisor_save_summaries_steps,
                             feed_fn,
                             steps,
                             fail_on_nan_loss,
                             monitors,
                             max_steps)
    except errors.AbortedError:
      # Happens when PS restarts, keep training.
      logging.warning('Training got Aborted error. Keep training.')


def _train_internal(graph,
                    output_dir,
                    train_op,
                    loss_op,
                    global_step_tensor,
                    init_op,
                    init_feed_dict,
                    init_fn,
                    log_every_steps,
                    supervisor_is_chief,
                    supervisor_master,
                    supervisor_save_model_secs,
                    keep_checkpoint_max,
                    supervisor_save_summaries_steps,
                    feed_fn,
                    steps,
                    fail_on_nan_loss,
                    monitors,
                    max_steps):
  """See train."""
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')

  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
    if global_step_tensor is None:
      raise ValueError('No "global_step" was provided or found in the graph.')

    # Get current step.
    try:
      start_step = load_variable(output_dir, global_step_tensor.name)
    except (errors.NotFoundError, ValueError):
      start_step = 0

    summary_writer = (get_summary_writer(output_dir)
                      if supervisor_is_chief else None)

    # Add default chief monitors if none were provided.
    if not monitors:
      monitors = monitors_lib.get_default_monitors(
          loss_op=loss_op,
          summary_op=logging_ops.get_summary_op(),
          save_summary_steps=supervisor_save_summaries_steps,
          summary_writer=summary_writer) if supervisor_is_chief else []

    # TODO(ipolosukhin): Replace all functionality of Supervisor
    # with Chief-Exclusive Monitors.
    if not supervisor_is_chief:
      # Prune the list of monitors to the ones runnable on all workers.
      monitors = [monitor for monitor in monitors
                  if monitor.run_on_all_workers]

    if max_steps is None:
      max_steps = (start_step + steps) if steps else None
    # Start monitors, can create graph parts.
    for monitor in monitors:
      monitor.begin(max_steps=max_steps)

    supervisor = tf_supervisor.Supervisor(
        graph,
        init_op=init_op or tf_supervisor.Supervisor.USE_DEFAULT,
        init_feed_dict=init_feed_dict,
        is_chief=supervisor_is_chief,
        logdir=output_dir,
        saver=_make_saver(graph, keep_checkpoint_max),
        global_step=global_step_tensor,
        summary_op=None,
        summary_writer=summary_writer,
        save_model_secs=supervisor_save_model_secs,
        init_fn=init_fn)
    session = supervisor.PrepareSession(master=supervisor_master,
                                        start_standard_services=True)
    supervisor.StartQueueRunners(session)

    with session:
      get_current_step = lambda: session.run(global_step_tensor)

      start_step = get_current_step()
      last_step = start_step
      last_log_step = start_step
      loss_value = None
      logging.info('Training steps [%d,%s)', last_step,
                   'inf' if max_steps is None else str(max_steps))

      excinfo = None
      try:
        while not supervisor.ShouldStop() and (
            (max_steps is None) or (last_step < max_steps)):
          start_time = time.time()
          feed_dict = feed_fn() if feed_fn is not None else None

          outputs, should_stop = _run_with_monitors(
              session, last_step + 1, [train_op, loss_op], feed_dict, monitors)

          loss_value = outputs[loss_op.name]
          if np.isnan(loss_value):
            failure_message = 'Model diverged with loss = NaN.'
            if fail_on_nan_loss:
              logging.error(failure_message)
              raise monitors_lib.NanLossDuringTrainingError()
            else:
              logging.warning(failure_message)

          if should_stop:
            break

          this_step = get_current_step()

          if this_step <= last_step:
            logging.error(
                'Global step was not incremented by train op at step %s'
                ': new step %d', last_step, this_step)

          last_step = this_step
          is_last_step = (max_steps is not None) and (last_step >= max_steps)
          if is_last_step or (last_step - last_log_step >= log_every_steps):
            logging.info(
                'training step %d, loss = %.5f (%.3f sec/batch).',
                last_step, loss_value, float(time.time() - start_time))
            last_log_step = last_step
      except errors.OutOfRangeError as e:
        logging.warn('Got exception during tf.learn training loop possibly '
                     'due to exhausted input queue %s.', e)
      except StopIteration:
        logging.info('Exhausted input iterator.')
      except BaseException as e:  # pylint: disable=broad-except
        # Hold on to any other exceptions while we try recording a final
        # checkpoint and summary.
        excinfo = sys.exc_info()
      finally:
        try:
          # Call supervisor.Stop() from within a try block because it
          # re-raises exceptions thrown by the supervised threads.
          supervisor.Stop(close_summary_writer=False)

          # Save one last checkpoint and summaries.
          # TODO(wicke): This should be handled by Supervisor.

          # In case we encountered an exception in the try block before we
          # updated last_step, update it here (again).
          last_step = get_current_step()
          if supervisor_is_chief:
            ckpt_path = supervisor.save_path
            logging.info('Saving checkpoint for step %d to checkpoint: %s.',
                         last_step, ckpt_path)
            supervisor.saver.save(session, ckpt_path, global_step=last_step)

          # Finish monitors.
          for monitor in monitors:
            monitor.end()

        # Catch OutOfRangeError, which is thrown when the queue is out of data
        # (and for other reasons as well).
        except errors.OutOfRangeError as e:
          logging.warn('OutOfRangeError in tf.learn final checkpoint possibly '
                       'due to exhausted input queue. Note: summary_op is not '
                       'expected to trigger dequeues. %s.', e)
        except BaseException as e:  # pylint: disable=broad-except
          # If we don't already have an exception to re-raise, raise this one.
          if not excinfo:
            raise
          # Otherwise, log this one and raise the other in the finally block.
          logging.error('Got exception during tf.learn final checkpoint %s.',
                        e)
        finally:
          if excinfo:
            reraise(*excinfo)
      return loss_value


def _get_first_op_from_collection(collection_name):
  elements = ops.get_collection(collection_name)
  if elements:
    return elements[0]
  return None


def _get_saver():
  """Lazy init and return saver."""
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is None and variables.global_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver


def _get_ready_op():
  ready_op = _get_first_op_from_collection(ops.GraphKeys.READY_OP)
  if ready_op is None:
    ready_op = variables.report_uninitialized_variables()
    ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
  return ready_op


def _get_local_init_op():
  """Returns the local init op to initialize tables and local variables."""
  local_init_op = _get_first_op_from_collection(
      ops.GraphKeys.LOCAL_INIT_OP)
  if local_init_op is None:
    op_list = [
        variables.local_variables_initializer(),
        lookup_ops.tables_initializer()
    ]
    if op_list:
      local_init_op = control_flow_ops.group(*op_list)
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
  return local_init_op


def _eval_results_to_str(eval_results):
  return ', '.join('%s = %s' % (k, v) for k, v in sorted(eval_results.items()))


def _write_summary_results(output_dir, eval_results, current_global_step):
  """Writes eval results into summary file in given dir."""
  logging.info('Saving evaluation summary for step %d: %s',
               current_global_step, _eval_results_to_str(eval_results))
  summary_writer = get_summary_writer(output_dir)
  summary = summary_pb2.Summary()
  for key in eval_results:
    if eval_results[key] is None:
      continue
    value = summary.value.add()
    value.tag = key
    if (isinstance(eval_results[key], np.float32) or
        isinstance(eval_results[key], float)):
      value.simple_value = float(eval_results[key])
    else:
      logging.warn('Skipping summary for %s, must be a float or np.float32.',
                   key)
  summary_writer.add_summary(summary, current_global_step)
  summary_writer.flush()


@_graph_action_deprecation
def evaluate(graph,
             output_dir,
             checkpoint_path,
             eval_dict,
             update_op=None,
             global_step_tensor=None,
             supervisor_master='',
             log_every_steps=10,
             feed_fn=None,
             max_steps=None):
  """Evaluate a model loaded from a checkpoint.

  Given `graph`, a directory to write summaries to (`output_dir`), a checkpoint
  to restore variables from, and a `dict` of `Tensor`s to evaluate, run an eval
  loop for `max_steps` steps, or until an exception (generally, an
  end-of-input signal from a reader operation) is raised from running
  `eval_dict`.

  In each step of evaluation, all tensors in the `eval_dict` are evaluated, and
  every `log_every_steps` steps, they are logged. At the very end of
  evaluation, a summary is evaluated (finding the summary ops using
  `Supervisor`'s logic) and written to `output_dir`.

  Args:
    graph: A `Graph` to evaluate. It is expected that this graph is not in use
      elsewhere.
    output_dir: A string containing the directory to write a summary to.
    checkpoint_path: A string containing the path to a checkpoint to restore.
      Can be `None` if the graph doesn't require loading any variables.
    eval_dict: A `dict` mapping string names to tensors to evaluate. It is
      evaluated in every logging step. The result of the final evaluation is
      returned. If `update_op` is None, then it's evaluated in every step. If
      `max_steps` is `None`, this should depend on a reader that will raise an
      end-of-input exception when the inputs are exhausted.
    update_op: A `Tensor` which is run in every step.
    global_step_tensor: A `Variable` containing the global step. If `None`,
      one is extracted from the graph using the same logic as in `Supervisor`.
      Used to place eval summaries on training curves.
    supervisor_master: The master string to use when preparing the session.
    log_every_steps: Integer. Output logs every `log_every_steps` evaluation
      steps. The logs contain the `eval_dict` and timing information.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    max_steps: Integer. Evaluate `eval_dict` this many times.

  Returns:
    A tuple `(eval_results, global_step)`:
    eval_results: A `dict` mapping `string` to numeric values (`int`, `float`)
      that are the result of running eval_dict in the last step. `None` if no
      eval steps were run.
    global_step: The global step this evaluation corresponds to.

  Raises:
    ValueError: if `output_dir` is empty.
  """
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)

    # Create or get summary op, global_step and saver.
    saver = _get_saver()
    local_init_op = _get_local_init_op()
    ready_for_local_init_op = _get_first_op_from_collection(
        ops.GraphKeys.READY_FOR_LOCAL_INIT_OP)
    ready_op = _get_ready_op()

    session_manager = session_manager_lib.SessionManager(
        local_init_op=local_init_op,
        ready_op=ready_op,
        ready_for_local_init_op=ready_for_local_init_op)
    session, initialized = session_manager.recover_session(
        master=supervisor_master,
        saver=saver,
        checkpoint_dir=checkpoint_path)

    # Start queue runners.
    coord = coordinator.Coordinator()
    threads = queue_runner.start_queue_runners(session, coord)

    with session:
      if not initialized:
        logging.warning('Failed to initialize from %s.', checkpoint_path)
        # TODO(ipolosukhin): This should be failing, but old code relies on
        # that.
        session.run(variables.global_variables_initializer())
        if checkpoint_path:
          _restore_from_checkpoint(session, graph, checkpoint_path, saver)

      current_global_step = session.run(global_step_tensor)
      eval_results = None
      # TODO(amodei): Fix this to run through the eval set exactly once.
      step = 0
      eval_step = None
      feed_dict = None
      logging.info('Eval steps [%d,%s) for training step %d.', step,
                   'inf' if max_steps is None else str(max_steps),
                   current_global_step)
      try:
        try:
          while (max_steps is None) or (step < max_steps):
            step += 1
            start_time = time.time()
            feed_dict = feed_fn() if feed_fn is not None else None
            if update_op is not None:
              session.run(update_op, feed_dict=feed_dict)
            else:
              eval_results = session.run(eval_dict, feed_dict=feed_dict)
              eval_step = step

            # TODO(wicke): We should assert that the global step hasn't
            # changed.
            if step % log_every_steps == 0:
              if eval_step is None or step != eval_step:
                eval_results = session.run(eval_dict, feed_dict=feed_dict)
                eval_step = step
              duration = time.time() - start_time
              logging.info('Results after %d steps (%.3f sec/batch): %s.',
                           step, float(duration),
                           _eval_results_to_str(eval_results))
        finally:
          if eval_results is None or step != eval_step:
            eval_results = session.run(eval_dict, feed_dict=feed_dict)
            eval_step = step
          # Stop session first, before queue runners.
          session.close()

          # Stop queue runners.
          try:
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=120)
          except (RuntimeError, errors.CancelledError) as e:
            logging.warning('Coordinator didn\'t stop cleanly: %s', e)

      # Catch OutOfRangeError, which is thrown when the queue is out of data
      # (and for other reasons as well).
      except errors.OutOfRangeError as e:
        if max_steps is None:
          logging.info('Input queue is exhausted.')
        else:
          logging.warn('Input queue is exhausted: %s.', e)
      # Catch StopIteration, which is thrown when the DataReader is out of
      # data.
      except StopIteration as e:
        if max_steps is None:
          logging.info('Input iterator is exhausted.')
        else:
          logging.warn('Input iterator is exhausted: %s.', e)

    # Save summaries for this evaluation.
    _write_summary_results(output_dir, eval_results, current_global_step)

    return eval_results, current_global_step


@_graph_action_deprecation
def run_n(output_dict, feed_dict=None, restore_checkpoint_path=None, n=1):
  """Run `output_dict` tensors `n` times, with the same `feed_dict` each run.

  Args:
    output_dict: A `dict` mapping string names to tensors to run. Must all be
      from the same graph.
    feed_dict: `dict` of input values to feed each run.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.
    n: Number of times to repeat.

  Returns:
    A list of `n` `dict` objects, each containing values read from
    `output_dict` tensors.
  """
  return run_feeds(
      output_dict=output_dict,
      feed_dicts=itertools.repeat(feed_dict, n),
      restore_checkpoint_path=restore_checkpoint_path)


@_graph_action_deprecation
def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
  """Run `output_dict` tensors with each input in `feed_dicts`.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dicts: Iterable of `dict` objects of input values to feed.
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.

  Yields:
    A sequence of dicts of values read from `output_dict` tensors, one item
    yielded for each item in `feed_dicts`. Keys are the same as `output_dict`,
    values are the results read from the corresponding `Tensor` in
    `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
  """
  if not output_dict:
    raise ValueError('output_dict is invalid: %s.' % output_dict)
  if not feed_dicts:
    raise ValueError('feed_dicts is invalid: %s.' % feed_dicts)

  graph = contrib_ops.get_graph_from_inputs(output_dict.values())
  with graph.as_default() as g:
    with tf_session.Session('') as session:
      session.run(
          resources.initialize_resources(resources.shared_resources() +
                                         resources.local_resources()))
      if restore_checkpoint_path:
        _restore_from_checkpoint(session, g, restore_checkpoint_path)
      else:
        session.run(variables.global_variables_initializer())
        session.run(variables.local_variables_initializer())
        session.run(lookup_ops.tables_initializer())
      coord = coordinator.Coordinator()
      threads = None
      try:
        threads = queue_runner.start_queue_runners(session, coord=coord)
        for f in feed_dicts:
          yield session.run(output_dict, f)
      finally:
        coord.request_stop()
        if threads:
          coord.join(threads, stop_grace_period_secs=120)


@_graph_action_deprecation
def run_feeds(*args, **kwargs):
  """See run_feeds_iter(). Returns a `list` instead of an iterator."""
  return list(run_feeds_iter(*args, **kwargs))


@_graph_action_deprecation
def infer(restore_checkpoint_path, output_dict, feed_dict=None):
  """Restore graph from `restore_checkpoint_path` and run `output_dict` tensors.

  If `restore_checkpoint_path` is supplied, restore from checkpoint. Otherwise,
  init all variables.

  Args:
    restore_checkpoint_path: A string containing the path to a checkpoint to
      restore.
    output_dict: A `dict` mapping string names to `Tensor` objects to run.
      Tensors must all be from the same graph.
    feed_dict: `dict` object mapping `Tensor` objects to input values to feed.

  Returns:
    Dict of values read from `output_dict` tensors. Keys are the same as
    `output_dict`, values are the results read from the corresponding `Tensor`
    in `output_dict`.

  Raises:
    ValueError: if `output_dict` or `feed_dicts` is None or empty.
732 """ 733 return run_feeds(output_dict=output_dict, 734 feed_dicts=[feed_dict] if feed_dict is not None else [None], 735 restore_checkpoint_path=restore_checkpoint_path)[0] 736