• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A client interface for TensorFlow."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import re
24import threading
25import warnings
26
27import numpy as np
28
29from tensorflow.core.protobuf import config_pb2
30from tensorflow.python import pywrap_tensorflow as tf_session
31from tensorflow.python.framework import device
32from tensorflow.python.framework import error_interpolation
33from tensorflow.python.framework import errors
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import sparse_tensor
36from tensorflow.python.ops import session_ops
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.util import compat
39from tensorflow.python.util import nest
40from tensorflow.python.util.tf_export import tf_export
41
42
class SessionInterface(object):
  """Base class for implementations of TensorFlow client sessions.

  Concrete sessions implement these properties and methods; every member
  here simply raises `NotImplementedError` naming the missing member.
  """

  @property
  def graph(self):
    """The underlying TensorFlow graph, to be used in building Operations."""
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    """The TensorFlow process to which this session will connect."""
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs operations in the session. See `BaseSession.run()` for details."""
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    """Sets up the feeds and fetches for partial runs in the session."""
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    """Continues the execution with additional feeds and fetches."""
    raise NotImplementedError('partial_run')
67
68
def _get_indexed_slices_value_from_fetches(fetched_vals):
  """Packs fetched component arrays into an `ops.IndexedSlicesValue`.

  Args:
    fetched_vals: A list of 2 or 3 fetched values, ordered as
      [values, indices] or [values, indices, dense_shape].

  Returns:
    An `ops.IndexedSlicesValue`; `dense_shape` is None when only two
    components were fetched.
  """
  values = fetched_vals[0]
  indices = fetched_vals[1]
  dense_shape = fetched_vals[2] if len(fetched_vals) == 3 else None
  return ops.IndexedSlicesValue(values, indices, dense_shape)
73
74
75def _get_feeds_for_indexed_slices(feed, feed_val):
76  return list(
77      zip([feed.values, feed.indices] if feed.dense_shape is None else
78          [feed.values, feed.indices, feed.dense_shape], feed_val))
79
80
# List of extensions supported to convert run arguments into actual fetches and
# feeds.
#
# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2),
# where the function signatures are:
#   fetch_fn : Type -> (list of Tensors,
#                       lambda: list of fetched np.ndarray -> TypeVal)
#   feed_fn1 : Type, TypeVal -> list of (Tensor, value)
#   feed_fn2 : Type -> list of Tensors
#
# `fetch_fn` describes how to expand fetch into its
# component Tensors and how to contract the fetched results back into
# a single return value.
#
# Each feed function describes how to unpack a single fed value and map it to
# feeds of one or more tensors and their corresponding values: `feed_fn1` is
# used to feed a run, `feed_fn2` to set up a partial run.
#
# Entries are matched in order by `isinstance`, so more specific types must
# come before the catch-all `object` entry at the end.
#
# TODO(touts): We could reimplement these as specialized _FeedMapper
# implementations after we refactor the feed handling code to use them.
#
# Eventually, this registration could be opened up to support custom Tensor
# expansions.
# pylint: disable=g-long-lambda
_REGISTERED_EXPANSIONS = [
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (sparse_tensor.SparseTensor,
     lambda fetch: (
         [fetch.indices, fetch.values, fetch.dense_shape],
         lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)),
     lambda feed, feed_val: list(zip(
         [feed.indices, feed.values, feed.dense_shape], feed_val)),
     lambda feed: [feed.indices, feed.values, feed.dense_shape]),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples.
    (ops.IndexedSlices,
     lambda fetch: (
         [fetch.values, fetch.indices] if fetch.dense_shape is None
         else [fetch.values, fetch.indices, fetch.dense_shape],
         _get_indexed_slices_value_from_fetches),
     _get_feeds_for_indexed_slices,
     lambda feed: [feed.values, feed.indices] if feed.dense_shape is None
     else [feed.values, feed.indices, feed.dense_shape]),
    # The default catches all other types and performs no expansions.
    (object,
     lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
     lambda feed, feed_val: [(feed, feed_val)],
     lambda feed: [feed])]

# pylint: enable=g-long-lambda
132
133
134def _convert_to_numpy_obj(numpy_dtype, obj):
135  """Explicitly convert obj based on numpy type except for string type."""
136  return numpy_dtype(obj) if numpy_dtype is not object else str(obj)
137
138
def register_session_run_conversion_functions(
    tensor_type,
    fetch_function,
    feed_function=None,
    feed_function_for_partial_run=None):
  """Register fetch and feed conversion functions for `tf.Session.run()`.

  This function registers a triple of conversion functions for fetching and/or
  feeding values of user-defined types in a call to tf.Session.run().

  An example

  ```python
     class SquaredTensor(object):
       def __init__(self, tensor):
         self.sq = tf.square(tensor)
     #you can define conversion functions as follows:
     fetch_function = lambda squared_tensor:([squared_tensor.sq],
                                             lambda val: val[0])
     feed_function = lambda feed, feed_val: [(feed.sq, feed_val)]
     feed_function_for_partial_run = lambda feed: [feed.sq]
     #then after invoking this register function, you can use as follows:
     session.run(squared_tensor1,
                 feed_dict = {squared_tensor2 : some_numpy_array})
  ```

  Args:
    tensor_type: The type for which you want to register a conversion function.
    fetch_function: A callable that takes an object of type `tensor_type` and
      returns a tuple, where the first element is a list of `tf.Tensor` objects,
      and the second element is a callable that takes a list of ndarrays and
      returns an object of some value type that corresponds to `tensor_type`.
      fetch_function describes how to expand fetch into its component Tensors
      and how to contract the fetched results back into a single return value.
    feed_function: A callable that takes feed_key and feed_value as input, and
      returns a list of tuples (feed_tensor, feed_val), feed_key must have type
      `tensor_type`, and feed_tensor must have type `tf.Tensor`. Each feed
      function describes how to unpack a single fed value and map it to feeds
      of one or more tensors and their corresponding values.
    feed_function_for_partial_run: A callable for specifying tensor values to
      feed when setting up a partial run, which takes a `tensor_type` type
      object as input, and returns a list of Tensors.

  Raises:
    ValueError: If `tensor_type` has already been registered.
  """
  # Reject registrations that would shadow (or be shadowed by) an existing
  # entry, since expansions are matched by isinstance in list order.
  for registered_type, _, _, _ in _REGISTERED_EXPANSIONS:
    if issubclass(registered_type, tensor_type):
      raise ValueError('%s has already been registered so ignore it.' %
                       tensor_type)

  # Prepend so the new, more specific type is matched before the catch-all
  # `object` entry.
  new_entry = (tensor_type, fetch_function, feed_function,
               feed_function_for_partial_run)
  _REGISTERED_EXPANSIONS.insert(0, new_entry)
192
193
194def _is_attrs_instance(obj):
195  """Returns True if the given obj is an instance of attrs-decorated class."""
196  return getattr(obj.__class__, '__attrs_attrs__', None) is not None
197
198
199def _get_attrs_values(obj):
200  """Returns the list of values from an attrs instance."""
201  attrs = getattr(obj.__class__, '__attrs_attrs__')
202  return [getattr(obj, a.name) for a in attrs]
203
204
205class _FetchMapper(object):
206  """Definition of the interface provided by fetch mappers.
207
208  Fetch mappers are utility classes used by the _FetchHandler to handle
209  arbitrary structures for the `fetch` argument to `Session.run()`.
210
211  The `fetch` argument can be of various shapes: single tensor or op, list of
212  fetches, tuple of fetches, namedtuple of fetches, or dict of fetches.  The
213  structures can be arbitrarily nested.
214
215  The low level run() API only wants a list of tensor or op names.  The various
216  `_FetchMapper` subclasses below take care of handling the different shapes:
217  uniquifying the fetches, and constructing results with the original shape.
218  """
219
220  def unique_fetches(self):
221    """Return the list of unique tensors or ops needed by this fetch mapper.
222
223    Returns:
224      A list of tensors or ops.
225    """
226    raise NotImplementedError('Must be implemented by subclasses')
227
228  def build_results(self, values):
229    """Build results that match the original shape of the fetch.
230
231    Args:
232      values: List of values returned by run(). The values correspond
233        exactly to the list tensors or ops returned by unique_fetches().
234
235    Returns:
236      A struct of the same shape as the original fetch object handled by
237      this fetch mapper.  In the returned struct, the original fetches are
238      replaced by their fetched values.
239    """
240    raise NotImplementedError('Must be implemented by subclasses')
241
242  @staticmethod
243  def for_fetch(fetch):
244    """Creates fetch mapper that handles the structure of `fetch`.
245
246    The default graph must be the one from which we want to fetch values when
247    this function is called.
248
249    Args:
250      fetch: An arbitrary fetch structure: singleton, list, tuple,
251        namedtuple, or dict.
252
253    Returns:
254      An instance of a subclass of `_FetchMapper` that handles the shape.
255    """
256    if fetch is None:
257      raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
258                                                                 type(fetch)))
259    elif isinstance(fetch, (list, tuple)):
260      # NOTE(touts): This is also the code path for namedtuples.
261      return _ListFetchMapper(fetch)
262    elif isinstance(fetch, collections.Mapping):
263      return _DictFetchMapper(fetch)
264    elif _is_attrs_instance(fetch):
265      return _AttrsFetchMapper(fetch)
266    else:
267      # Look for a handler in the registered expansions.
268      for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
269        if isinstance(fetch, tensor_type):
270          fetches, contraction_fn = fetch_fn(fetch)
271          return _ElementFetchMapper(fetches, contraction_fn)
272    # Did not find anything.
273    raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
274                                                               type(fetch)))
275
276
class _ElementFetchMapper(_FetchMapper):
  """Fetch mapper for singleton tensors and ops."""

  def __init__(self, fetches, contraction_fn):
    """Creates an _ElementFetchMapper.

    This is the fetch mapper used for leaves in the fetch struct.  Because of
    the expansions mechanism, a leaf can actually fetch more than one tensor.

    Also note that the fetches here can be just strings (tensor or op names) or
    any other object that the graph knows how to convert to a tensor, such as a
    Variable.  So we have to run each fetch through `as_graph_element()` to get
    the corresponding tensor or op.

    Args:
      fetches: List of objects, as returned by a fetch_fn defined
        in _REGISTERED_EXPANSIONS.
      contraction_fn: Callable as returned by a fetch_fn.
    """
    self._unique_fetches = []
    for fetch in fetches:
      try:
        element = ops.get_default_graph().as_graph_element(
            fetch, allow_tensor=True, allow_operation=True)
      except TypeError as e:
        raise TypeError('Fetch argument %r has invalid type %r, '
                        'must be a string or Tensor. (%s)' %
                        (fetch, type(fetch), str(e)))
      except (ValueError, KeyError) as e:
        # The graph raises ValueError or KeyError for fetches it cannot
        # resolve; both are surfaced as the same ValueError to the caller.
        raise ValueError('Fetch argument %r cannot be interpreted as a '
                         'Tensor. (%s)' % (fetch, str(e)))
      self._unique_fetches.append(element)
    self._contraction_fn = contraction_fn

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    if not values:
      # An empty value list means the leaf was an Operation, which has no
      # fetched value.
      return None
    return self._contraction_fn(values)
322
323
324def _uniquify_fetches(fetch_mappers):
325  """Uniquifies fetches from a list of fetch_mappers.
326
327  This is a utility function used by _ListFetchMapper and _DictFetchMapper.  It
328  gathers all the unique fetches from a list of mappers and builds a list
329  containing all of them but without duplicates (unique_fetches).
330
331  It also returns a 2-D list of integers (values_indices) indicating at which
332  index in unique_fetches the fetches of the mappers are located.
333
334  This list is as follows:
335    values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index
336
337  Args:
338    fetch_mappers: list of fetch mappers.
339
340  Returns:
341    A list of fetches.
342    A 2-D list of integers.
343  """
344  unique_fetches = []
345  value_indices = []
346  seen_fetches = {}
347  for m in fetch_mappers:
348    m_value_indices = []
349    for f in m.unique_fetches():
350      j = seen_fetches.get(f)
351      if j is None:
352        j = len(seen_fetches)
353        seen_fetches[f] = j
354        unique_fetches.append(f)
355      m_value_indices.append(j)
356    value_indices.append(m_value_indices)
357  return unique_fetches, value_indices
358
359
class _ListFetchMapper(_FetchMapper):
  """Fetch mapper for lists, tuples, and namedtuples."""

  def __init__(self, fetches):
    """Creates a _ListFetchMapper.

    Args:
      fetches: List, tuple, or namedtuple of fetches.
    """
    # Remember the concrete type so build_results can reconstruct it.
    self._fetch_type = type(fetches)
    self._mappers = [_FetchMapper.for_fetch(f) for f in fetches]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Build each sub-mapper's result from its slice of the unique values.
    results = [
        mapper.build_results([values[i] for i in indices])
        for mapper, indices in zip(self._mappers, self._value_indices)
    ]
    # Return a value of the original type of the fetches.
    if issubclass(self._fetch_type, list):
      return results
    if self._fetch_type == tuple:
      return tuple(results)
    # Namedtuple: positional construction preserves field order.
    return self._fetch_type(*results)
389
390
class _DictFetchMapper(_FetchMapper):
  """Fetch mapper for dicts."""

  def __init__(self, fetches):
    """Creates a _DictFetchMapper.

    Args:
      fetches: Dict of fetches.
    """
    # Keep the keys and the concrete dict type so build_results can rebuild
    # a mapping of the same type with the same key order.
    self._fetch_type = type(fetches)
    self._keys = fetches.keys()
    self._mappers = [_FetchMapper.for_fetch(v) for v in fetches.values()]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    results = self._fetch_type()
    for key, mapper, indices in zip(self._keys, self._mappers,
                                    self._value_indices):
      results[key] = mapper.build_results([values[i] for i in indices])
    return results
415
416
class _AttrsFetchMapper(_FetchMapper):
  """Fetch mapper for attrs decorated classes."""

  def __init__(self, fetches):
    """Creates a _AttrsFetchMapper.

    Args:
      fetches: An instance of an attrs decorated class.
    """
    # Expand the instance into its attribute values, in declaration order,
    # and map each value independently.
    attr_values = _get_attrs_values(fetches)
    self._fetch_type = type(fetches)
    self._mappers = [_FetchMapper.for_fetch(v) for v in attr_values]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Rebuild an instance of the attrs class from per-attribute results,
    # relying on positional construction matching declaration order.
    results = [
        mapper.build_results([values[i] for i in indices])
        for mapper, indices in zip(self._mappers, self._value_indices)
    ]
    return self._fetch_type(*results)
441
442
class _FetchHandler(object):
  """Handler for structured fetches.

  Given a graph, a user-provided structure for fetches, and a feed dict, this
  class takes care of generating a list of tensor names to fetch and op names
  to run for a low level `run()` call.

  Given the results of the low level run call, this class can also rebuild a
  result structure matching the user-provided structure for fetches, but
  containing the corresponding results.
  """

  # TODO(touts): Make this class also take care of destructuring the feed
  # dict instead of doing it in the callers.

  def __init__(self, graph, fetches, feeds, feed_handles=None):
    """Creates a fetch handler.

    Args:
      graph: Graph of the fetches.   Used to check for fetchability
        and to convert all fetches to tensors or ops as needed.
      fetches: An arbitrary fetch structure: singleton, list, tuple,
        namedtuple, or dict.
      feeds: A feed dict where keys are Tensors.
      feed_handles: A dict from feed Tensors to TensorHandle objects used as
        direct feeds.
    """
    # for_fetch resolves names against the default graph, so make `graph`
    # the default while building the mapper.
    with graph.as_default():
      self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    self._fetches = []
    self._targets = []
    self._feeds = feeds
    self._feed_handles = feed_handles or {}
    # _ops[k] records whether the k-th unique fetch is an Operation (True)
    # or a Tensor (False); build_results replays this to interleave None
    # results for ops with tensor values.
    self._ops = []
    self._fetch_handles = {}
    for fetch in self._fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        self._assert_fetchable(graph, fetch)
        self._targets.append(fetch)
        self._ops.append(True)
      else:
        self._assert_fetchable(graph, fetch.op)
        self._fetches.append(fetch)
        self._ops.append(False)
      # Remember the fetch if it is for a tensor handle.
      if (isinstance(fetch, ops.Tensor) and
          (fetch.op.type == 'GetSessionHandle' or
           fetch.op.type == 'GetSessionHandleV2')):
        self._fetch_handles[fetch] = fetch.op.inputs[0].dtype
    # Tensors that are also fed do not need to be fetched from the runtime.
    self._final_fetches = [x for x in self._fetches if x not in feeds]

  def _assert_fetchable(self, graph, op):
    # Raises if the graph has marked this op as not fetchable (e.g. inside
    # a control-flow context).
    if not graph.is_fetchable(op):
      raise ValueError(
          'Operation %r has been marked as not fetchable.' % op.name)

  def fetches(self):
    """Return the unique names of tensors to fetch.

    Returns:
      A list of strings.
    """
    return self._final_fetches

  def targets(self):
    """Return the unique names of ops to run.

    Returns:
      A list of strings.
    """
    return self._targets

  def build_results(self, session, tensor_values):
    """Build results matching the original fetch shape.

    `tensor_values` must be a list of the same length as
    the one returned by `fetches()`, and holding the requested
    fetch values.

    This method builds a struct with the same shape as the original `fetches`
    passed to the constructor, in which the fetches are replaced by their
    fetched value.

    Args:
      session: The enclosing session.  Used for tensor handles.
      tensor_values: List of values matching the list returned
        by fetches().

    Returns:
      A structure of the same shape as the original `fetches` argument but
        containing tensors or None (for fetched ops).
    """
    full_values = []
    assert len(self._final_fetches) == len(tensor_values)
    # `i` walks self._fetches (all tensor fetches, fed or not); `j` walks
    # tensor_values (only the fetches actually returned by the runtime,
    # i.e. those not satisfied directly from the feed dict).
    i = 0
    j = 0
    for is_op in self._ops:
      if is_op:
        full_values.append(None)
      else:
        # If the fetch was in the feeds, use the fed value, otherwise
        # use the returned value.
        if self._fetches[i] in self._feed_handles:
          # A fetch had a corresponding direct TensorHandle feed. Call eval()
          # to obtain the Tensor value from the TensorHandle.
          value = self._feed_handles[self._fetches[i]].eval()
        else:
          value = self._feeds.get(self._fetches[i])
        if value is None:
          value = tensor_values[j]
          j += 1
        dtype = self._fetch_handles.get(self._fetches[i])
        if dtype:
          # Wrap raw handle values back into TensorHandle objects.
          full_values.append(session_ops.TensorHandle(value, dtype, session))
        else:
          full_values.append(value)
        i += 1
    assert j == len(tensor_values)
    return self._fetch_mapper.build_results(full_values)
562
563
def _name_list(tensor_list):
  """Utility function for transitioning to the new session API.

  Args:
    tensor_list: a list of `Tensor`s.

  Returns:
    A list of each `Tensor`s name (as byte arrays).
  """
  names = []
  for tensor in tensor_list:
    names.append(compat.as_bytes(tensor.name))
  return names
574
575
576class _DeviceAttributes(object):
577  """Struct-like object describing a device's attributes.
578
579  Each device has 3 key properties:
580   - name: the fully-qualified TensorFlow path to the device. For
581        example: /job:worker/replica:0/task:3/device:CPU:0
582   - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.)
583   - memory_limit_bytes: the maximum amount of memory available on the device
584        (in bytes).
585  """
586
587  def __init__(self, name, device_type, memory_limit_bytes, incarnation):
588    self._name = device.canonical_name(name)
589    self._device_type = device_type
590    self._memory_limit_bytes = memory_limit_bytes
591    self._incarnation = incarnation
592
593  @property
594  def name(self):
595    return self._name
596
597  @property
598  def device_type(self):
599    return self._device_type
600
601  @property
602  def memory_limit_bytes(self):
603    return self._memory_limit_bytes
604
605  @property
606  def incarnation(self):
607    return self._incarnation
608
609  def __repr__(self):
610    return '_DeviceAttributes(%s, %s, %d, %d)' % (
611        self.name,
612        self.device_type,
613        self.memory_limit_bytes,
614        self.incarnation,
615    )
616
617
618class BaseSession(SessionInterface):
619  """A class for interacting with a TensorFlow computation.
620
621  The BaseSession enables incremental graph building with inline
622  execution of Operations and evaluation of Tensors.
623  """
624
  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None,
        the default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
      TypeError: If one of the arguments has the wrong type.
    """
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      if not isinstance(graph, ops.Graph):
        raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
      self._graph = graph

    self._opened = False
    self._closed = False

    self._current_version = 0
    self._extend_lock = threading.Lock()
    if target is not None:
      try:
        # The C API expects the target as bytes.
        self._target = compat.as_bytes(target)
      except TypeError:
        raise TypeError('target must be a string, but got %s' % type(target))
    else:
      self._target = None

    self._delete_lock = threading.Lock()
    # Tensor handles whose deletion is pending.
    self._dead_handles = []

    if config is not None:
      if not isinstance(config, config_pb2.ConfigProto):
        raise TypeError(
            'config must be a tf.ConfigProto, but got %s' % type(config))
      self._config = config
      self._add_shapes = config.graph_options.infer_shapes
    else:
      self._config = None
      self._add_shapes = False

    # Initialize to None before the C call so __del__ is safe even if
    # session creation fails partway through.
    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      # pylint: disable=protected-access
      self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts)
      # pylint: enable=protected-access
    finally:
      # The options struct is only needed for the creation call; always
      # release it, even when session creation raised.
      tf_session.TF_DeleteSessionOptions(opts)
680
681  def list_devices(self):
682    """Lists available devices in this session.
683
684    ```python
685    devices = sess.list_devices()
686    for d in devices:
687      print(d.name)
688    ```
689
690    Each element in the list has the following properties:
691     - `name`: A string with the full name of the device. ex:
692          `/job:worker/replica:0/task:3/device:CPU:0`
693     - `device_type`: The type of the device (e.g. `CPU`, `GPU`, `TPU`.)
694     - `memory_limit`: The maximum amount of memory available on the device.
695          Note: depending on the device, it is possible the usable memory could
696          be substantially less.
697    Raises:
698      tf.errors.OpError: If it encounters an error (e.g. session is in an
699      invalid state, or network errors occur).
700
701    Returns:
702      A list of devices in the session.
703    """
704    raw_device_list = tf_session.TF_SessionListDevices(self._session)
705    device_list = []
706    size = tf_session.TF_DeviceListCount(raw_device_list)
707    for i in range(size):
708      name = tf_session.TF_DeviceListName(raw_device_list, i)
709      device_type = tf_session.TF_DeviceListType(raw_device_list, i)
710      memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
711      incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
712      device_list.append(
713          _DeviceAttributes(name, device_type, memory, incarnation))
714    tf_session.TF_DeleteDeviceList(raw_device_list)
715    return device_list
716
717  def close(self):
718    """Closes this session.
719
720    Calling this method frees all resources associated with the session.
721
722    Raises:
723      tf.errors.OpError: Or one of its subclasses if an error occurs while
724        closing the TensorFlow session.
725    """
726    if self._session and not self._closed:
727      self._closed = True
728      tf_session.TF_CloseSession(self._session)
729
  def __del__(self):
    """Best-effort cleanup of the underlying C session at finalization."""
    # cleanly ignore all exceptions
    try:
      self.close()
    except Exception:  # pylint: disable=broad-except
      pass
    if self._session is not None:
      try:
        tf_session.TF_DeleteSession(self._session)
      except (AttributeError, TypeError):
        # At shutdown, `c_api_util`, `tf_session`, or
        # `tf_session.TF_DeleteSession` may have been garbage collected, causing
        # the above method calls to fail. In this case, silently leak since the
        # program is about to terminate anyway.
        pass
      self._session = None
746
  @property
  def graph(self):
    """The graph that was launched in this session."""
    return self._graph
751
  @property
  def graph_def(self):
    """A serializable version of the underlying TensorFlow graph.

    Returns:
      A graph_pb2.GraphDef proto containing nodes for all of the Operations in
      the underlying TensorFlow graph.
    """
    # Shape inference info is only attached when the session config
    # requested it (graph_options.infer_shapes).
    return self._graph.as_graph_def(add_shapes=self._add_shapes)
761
  @property
  def sess_str(self):
    """The TensorFlow process to which this session will connect."""
    return self._target
765
  def as_default(self):
    """Returns a context manager that makes this object the default session.

    Use with the `with` keyword to specify that calls to
    `tf.Operation.run` or `tf.Tensor.eval` should be executed in
    this session.

    ```python
    c = tf.constant(..)
    sess = tf.Session()

    with sess.as_default():
      assert tf.get_default_session() is sess
      print(c.eval())
    ```

    To get the current default session, use `tf.get_default_session`.

    *N.B.* The `as_default` context manager *does not* close the
    session when you exit the context, and you must close the session
    explicitly.

    ```python
    c = tf.constant(...)
    sess = tf.Session()
    with sess.as_default():
      print(c.eval())
    # ...
    with sess.as_default():
      print(c.eval())

    sess.close()
    ```

    Alternatively, you can use `with tf.Session():` to create a
    session that is automatically closed on exiting the context,
    including when an uncaught exception is raised.

    *N.B.* The default session is a property of the current thread. If you
    create a new thread, and wish to use the default session in that
    thread, you must explicitly add a `with sess.as_default():` in that
    thread's function.

    *N.B.* Entering a `with sess.as_default():` block does not affect
    the current default graph. If you are using multiple graphs, and
    `sess.graph` is different from the value of `tf.get_default_graph`,
    you must explicitly enter a `with sess.graph.as_default():` block
    to make `sess.graph` the default graph.

    Returns:
      A context manager using this session as the default session.
    """
    # Delegates to the graph-level default-session stack in `ops`.
    return ops.default_session(self)
819
  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs operations and evaluates tensors in `fetches`.

    This method runs one "step" of TensorFlow computation, by
    running the necessary graph fragment to execute every `Operation`
    and evaluate every `Tensor` in `fetches`, substituting the values in
    `feed_dict` for the corresponding input values.

    The `fetches` argument may be a single graph element, or an arbitrarily
    nested list, tuple, namedtuple, dict, or OrderedDict containing graph
    elements at its leaves.  A graph element can be one of the following types:

    * A `tf.Operation`.
      The corresponding fetched value will be `None`.
    * A `tf.Tensor`.
      The corresponding fetched value will be a numpy ndarray containing the
      value of that tensor.
    * A `tf.SparseTensor`.
      The corresponding fetched value will be a
      `tf.SparseTensorValue`
      containing the value of that sparse tensor.
    * A `get_tensor_handle` op.  The corresponding fetched value will be a
      numpy ndarray containing the handle of that tensor.
    * A `string` which is the name of a tensor or operation in the graph.

    The value returned by `run()` has the same shape as the `fetches` argument,
    where the leaves are replaced by the corresponding values returned by
    TensorFlow.

    Example:

    ```python
       a = tf.constant([10, 20])
       b = tf.constant([1.0, 2.0])
       # 'fetches' can be a singleton
       v = session.run(a)
       # v is the numpy array [10, 20]
       # 'fetches' can be a list.
       v = session.run([a, b])
       # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
       # 1-D array [1.0, 2.0]
       # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
       MyData = collections.namedtuple('MyData', ['a', 'b'])
       v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
       # v is a dict with
       # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
       # 'b' (the numpy array [1.0, 2.0])
       # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
       # [10, 20].
    ```

    The optional `feed_dict` argument allows the caller to override
    the value of tensors in the graph. Each key in `feed_dict` can be
    one of the following types:

    * If the key is a `tf.Tensor`, the
      value may be a Python scalar, string, list, or numpy ndarray
      that can be converted to the same `dtype` as that
      tensor. Additionally, if the key is a
      `tf.placeholder`, the shape of
      the value will be checked for compatibility with the placeholder.
    * If the key is a
      `tf.SparseTensor`,
      the value should be a
      `tf.SparseTensorValue`.
    * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
      should be a nested tuple with the same structure that maps to their
      corresponding values as above.

    Each value in `feed_dict` must be convertible to a numpy array of the dtype
    of the corresponding key.

    The optional `options` argument expects a [`RunOptions`] proto. The options
    allow controlling the behavior of this particular step (e.g. turning tracing
    on).

    The optional `run_metadata` argument expects a [`RunMetadata`] proto. When
    appropriate, the non-Tensor output of this step will be collected there. For
    example, when users turn on tracing in `options`, the profiled info will be
    collected into this argument and passed back.

    Args:
      fetches: A single graph element, a list of graph elements,
        or a dictionary whose values are graph elements or lists of graph
        elements (described above).
      feed_dict: A dictionary that maps graph elements to values
        (described above).
      options: A [`RunOptions`] protocol buffer
      run_metadata: A [`RunMetadata`] protocol buffer

    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list, or a dictionary with the
      same keys as `fetches` if that is a dictionary (described above).
      Order in which `fetches` operations are evaluated inside the call
      is undefined.

    Raises:
      RuntimeError: If this `Session` is in an invalid state (e.g. has been
        closed).
      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
      ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
        `Tensor` that doesn't exist.
    """
    # Serialize any RunOptions proto into a C buffer for the underlying
    # session-run call. Both C buffers allocated here are released in the
    # `finally` clause below.
    options_ptr = tf_session.TF_NewBufferFromString(
        compat.as_bytes(options.SerializeToString())) if options else None
    run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None

    try:
      result = self._run(None, fetches, feed_dict, options_ptr,
                         run_metadata_ptr)
      # Copy the collected metadata out of the C buffer into the caller's
      # RunMetadata proto before the buffer is freed.
      if run_metadata:
        proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
        run_metadata.ParseFromString(compat.as_bytes(proto_data))
    finally:
      # Free the C buffers whether or not the run succeeded.
      if run_metadata_ptr:
        tf_session.TF_DeleteBuffer(run_metadata_ptr)
      if options:
        tf_session.TF_DeleteBuffer(options_ptr)
    return result
940
  def partial_run(self, handle, fetches, feed_dict=None):
    """Continues the execution with more feeds and fetches.

    This is EXPERIMENTAL and subject to change.

    To use partial execution, a user first calls `partial_run_setup()` and
    then a sequence of `partial_run()`. `partial_run_setup` specifies the
    list of feeds and fetches that will be used in the subsequent
    `partial_run` calls.

    The optional `feed_dict` argument allows the caller to override
    the value of tensors in the graph. See run() for more information.

    Below is a simple example:

    ```python
    a = array_ops.placeholder(dtypes.float32, shape=[])
    b = array_ops.placeholder(dtypes.float32, shape=[])
    c = array_ops.placeholder(dtypes.float32, shape=[])
    r1 = math_ops.add(a, b)
    r2 = math_ops.multiply(r1, c)

    h = sess.partial_run_setup([r1, r2], [a, b, c])
    res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
    res = sess.partial_run(h, r2, feed_dict={c: res})
    ```

    Args:
      handle: A handle for a sequence of partial runs.
      fetches: A single graph element, a list of graph elements,
        or a dictionary whose values are graph elements or lists of graph
        elements (see documentation for `run`).
      feed_dict: A dictionary that maps graph elements to values
        (described above).

    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list, or a dictionary with the
      same keys as `fetches` if that is a dictionary
      (see documentation for `run`).

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
    """
    # TODO(touts): Support feeding and fetching the same tensor.
    # Passing a non-None `handle` routes `_run` onto the partial-run code
    # path (no RunOptions/RunMetadata are supported there).
    return self._run(handle, fetches, feed_dict, None, None)
987
988  def partial_run_setup(self, fetches, feeds=None):
989    """Sets up a graph with feeds and fetches for partial run.
990
991    This is EXPERIMENTAL and subject to change.
992
993    Note that contrary to `run`, `feeds` only specifies the graph elements.
994    The tensors will be supplied by the subsequent `partial_run` calls.
995
996    Args:
997      fetches: A single graph element, or a list of graph elements.
998      feeds: A single graph element, or a list of graph elements.
999
1000    Returns:
1001      A handle for partial run.
1002
1003    Raises:
1004      RuntimeError: If this `Session` is in an invalid state (e.g. has been
1005        closed).
1006      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
1007      tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.
1008    """
1009
1010    def _feed_fn(feed):
1011      for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
1012        if isinstance(feed, tensor_type):
1013          return feed_fn(feed)
1014      raise TypeError('Feed argument %r has invalid type %r' % (feed,
1015                                                                type(feed)))
1016
1017    # Check session.
1018    if self._closed:
1019      raise RuntimeError('Attempted to use a closed Session.')
1020    if self.graph.version == 0:
1021      raise RuntimeError('The Session graph is empty.  Add operations to the '
1022                         'graph before calling run().')
1023
1024    if feeds is None:
1025      feeds = []
1026    # Create request.
1027    feed_list = []
1028
1029    # Validate and process feed_list.
1030    is_list_feed = isinstance(feeds, (list, tuple))
1031    if not is_list_feed:
1032      feeds = [feeds]
1033    for feed in feeds:
1034      for subfeed in _feed_fn(feed):
1035        try:
1036          subfeed_t = self.graph.as_graph_element(
1037              subfeed, allow_tensor=True, allow_operation=False)
1038          # pylint: disable=protected-access
1039          feed_list.append(subfeed_t._as_tf_output())
1040          # pylint: enable=protected-access
1041        except Exception as e:
1042          e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
1043          e.args = (e.message,)
1044          raise e
1045
1046    # Validate and process fetches.
1047    # TODO(touts): Support feeding and fetching the same tensor.
1048    fetch_handler = _FetchHandler(self._graph, fetches, {})
1049
1050    # Set up a graph with feeds and fetches for partial run.
1051    def _setup_fn(session, feed_list, fetch_list, target_list):
1052      self._extend_graph()
1053      return tf_session.TF_SessionPRunSetup_wrapper(
1054          session, feed_list, fetch_list, target_list)
1055
1056    # pylint: disable=protected-access
1057    final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
1058    final_targets = [op._c_op for op in fetch_handler.targets()]
1059    # pylint: enable=protected-access
1060
1061    return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
1062                         final_targets)
1063
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Performs either `run` or `partial_run`, depending on `handle`.

    Args:
      handle: A partial-run handle, or None for a plain `run` call.
      fetches: Graph elements (possibly nested) to fetch; see `run`.
      feed_dict: Maps graph elements to values, or None.
      options: A serialized-`RunOptions` C buffer pointer, or None.
      run_metadata: A `RunMetadata` C buffer pointer, or None.

    Returns:
      Fetched values with the same structure as `fetches`.
    """

    def _feed_fn(feed, feed_val):
      # Expand a (feed, value) pair into its constituent (tensor, value)
      # pairs via the registered expansion functions.
      for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r' % (feed,
                                                                type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Create request.
    feed_dict_tensor = {}
    feed_map = {}

    # Validate and process feed_dict.
    feed_handles = {}
    if feed_dict:
      # Flatten nested-tuple keys so each key is a single feedable element.
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(
                subfeed, allow_tensor=True, allow_operation=False)
          except Exception as e:
            raise TypeError(
                'Cannot interpret feed_dict key as Tensor: ' + e.args[0])

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, numpy ndarrays, or TensorHandles. '
                            'For reference, the tensor object was ' +
                            str(feed_val) + ' which was passed to the '
                            'feed with key ' + str(feed) + '.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # Reject Python ints that would silently overflow/truncate when
          # converted to the tensor's numpy dtype.
          if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
              subfeed_dtype, subfeed_val) != subfeed_val:
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' with type ' + str(
                    type(subfeed_val)) +
                ' is not compatible with Tensor type ' + str(subfeed_dtype) +
                '. Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          is_tensor_handle_feed = isinstance(subfeed_val,
                                             session_ops.TensorHandle)
          if is_tensor_handle_feed:
            # Tensor handles are fed as their encoded numpy form; remember
            # the handle so the fetch handler can resolve it later.
            np_val = subfeed_val.to_numpy_array()
            feed_handles[subfeed_t] = subfeed_val
          else:
            np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          # Shape compatibility is only checked for regular feeds; handle
          # feeds carry an encoded handle, not the tensor's actual shape.
          if (not is_tensor_handle_feed and
              not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
            raise ValueError('Cannot feed value of shape %r for Tensor %r, '
                             'which has shape %r' %
                             (np_val.shape, subfeed_t.name,
                              str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)

          feed_dict_tensor[subfeed_t] = np_val
          feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(
        self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)

    # Run request and get response.
    # We need to keep the returned movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    _ = self._update_with_movers(feed_dict_tensor, feed_map)
    final_fetches = fetch_handler.fetches()
    final_targets = fetch_handler.targets()
    # We only want to really perform the run if fetches or targets are provided,
    # or if the call is a partial run that specifies feeds.
    if final_fetches or final_targets or (handle and feed_dict_tensor):
      results = self._do_run(handle, final_targets, final_fetches,
                             feed_dict_tensor, options, run_metadata)
    else:
      results = []
    return fetch_handler.build_results(self, results)
1157
  def make_callable(self, fetches, feed_list=None, accept_options=False):
    """Returns a Python callable that runs a particular step.

    The returned callable will take `len(feed_list)` arguments whose types
    must be compatible feed values for the respective elements of `feed_list`.
    For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
    argument to the returned callable must be a numpy ndarray (or something
    convertible to an ndarray) with matching element type and shape. See
    `tf.Session.run` for details of the allowable feed key and value types.

    The returned callable will have the same return type as
    `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
    the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
    it will return `None`.

    Args:
      fetches: A value or list of values to fetch. See `tf.Session.run`
        for details of the allowable fetch types.
      feed_list: (Optional.) A list of `feed_dict` keys. See
        `tf.Session.run` for details of the allowable feed key types.
      accept_options: (Optional.) If `True`, the returned `Callable` will be
        able to accept `tf.RunOptions` and `tf.RunMetadata` as optional
        keyword arguments `options` and `run_metadata`, respectively, with
        the same syntax and semantics as `tf.Session.run`, which is useful
        for certain use cases (profiling and debugging) but will result in
        measurable slowdown of the `Callable`'s performance. Default: `False`.

    Returns:
      A function that when called will execute the step defined by
      `feed_list` and `fetches` in this session.

    Raises:
      TypeError: If `fetches` or `feed_list` cannot be interpreted
        as arguments to `tf.Session.run`.
    """
    if feed_list is not None:
      if not isinstance(feed_list, (list, tuple)):
        raise TypeError('`feed_list` must be a list or tuple.')
      # Delegate any non-empty feed lists to the existing `run()` logic.
      # TODO(mrry): Refactor the feed handling logic from
      # `Session._run()` so that we can convert the feeds to a list of
      # strings here.
      def _generic_run(*feed_args, **kwargs):
        # Pair positional arguments with feed keys in `feed_list` order.
        feed_dict = {
            feed: feed_val
            for feed, feed_val in zip(feed_list, feed_args)
        }
        return self.run(fetches, feed_dict=feed_dict, **kwargs)

      return _generic_run

    # Ensure any changes to the graph are reflected in the runtime.
    # Note that we don't need to do this on subsequent calls to the
    # returned object, because the arguments to `fetches` must already be
    # in the graph.
    self._extend_graph()

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(self._graph, fetches, {})
    # pylint: disable=protected-access
    fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
    target_list = [op._c_op for op in fetch_handler.targets()]
    # pylint: enable=protected-access

    def _callable_template_with_options_and_metadata(fetch_list,
                                                     target_list,
                                                     fetch_handler,
                                                     options=None,
                                                     run_metadata=None):
      """Template callable that accepts RunOptions and RunMetadata."""
      # Mirrors the buffer lifecycle in `BaseSession.run`: allocate C
      # buffers, run, parse metadata, free the buffers in `finally`.
      options_ptr = tf_session.TF_NewBufferFromString(
          compat.as_bytes(options.SerializeToString())) if options else None
      run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
      try:
        results = self._call_tf_sessionrun(
            options_ptr, {}, fetch_list, target_list, run_metadata_ptr)
        if fetch_handler:
          results = fetch_handler.build_results(self, results)
        else:
          results = results[0] if results else None
        if run_metadata:
          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
          run_metadata.ParseFromString(compat.as_bytes(proto_data))
      finally:
        if run_metadata_ptr:
          tf_session.TF_DeleteBuffer(run_metadata_ptr)
        if options:
          tf_session.TF_DeleteBuffer(options_ptr)
      return results

    if accept_options:
      return functools.partial(_callable_template_with_options_and_metadata,
                               fetch_list, target_list, fetch_handler)
    elif isinstance(fetches, ops.Operation):
      # Special case for fetching a single operation, because the
      # function will have no return value.
      assert not fetch_list
      assert len(target_list) == 1

      def _single_operation_run():
        self._call_tf_sessionrun(None, {}, [], target_list, None)

      return _single_operation_run
    elif isinstance(fetches, ops.Tensor):
      # Special case for fetching a single tensor, because the
      # function can return the result of `TF_Run()` directly.
      assert len(fetch_list) == 1
      assert not target_list

      def _single_tensor_run():
        results = self._call_tf_sessionrun(None, {}, fetch_list, [], None)
        return results[0]

      return _single_tensor_run
    else:
      # In all other cases, we must use `fetch_handler` to build the
      # results for us.
      def _fetch_handler_run():
        results = self._call_tf_sessionrun(
            None, {}, fetch_list, target_list, None)
        return fetch_handler.build_results(self, results)

      return _fetch_handler_run
1281
  # Captures the name of a node in an error status. The regex below matches
  # both the old and the new formats:
  # Old format: [[Node: <node_name> = ...]]
  # New format: [[{{node <node_name>}} = ...]]
  # Group 3 holds the node name; consumed by `_do_call` below to look up the
  # failing operation in the graph.
  _NODEDEF_NAME_RE = re.compile(
      r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*')
1288
1289  def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
1290              run_metadata):
1291    """Runs a step based on the given fetches and feeds.
1292
1293    Args:
1294      handle: a handle for partial_run. None if this is just a call to run().
1295      target_list: A list of operations to be run, but not fetched.
1296      fetch_list: A list of tensors to be fetched.
1297      feed_dict: A dictionary that maps tensors to numpy ndarrays.
1298      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
1299      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None
1300
1301    Returns:
1302      A list of numpy ndarrays, corresponding to the elements of
1303      `fetch_list`.  If the ith element of `fetch_list` contains the
1304      name of an operation, the first Tensor output of that operation
1305      will be returned for that element.
1306
1307    Raises:
1308      tf.errors.OpError: Or one of its subclasses on error.
1309    """
1310    # pylint: disable=protected-access
1311    feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
1312    fetches = [t._as_tf_output() for t in fetch_list]
1313    targets = [op._c_op for op in target_list]
1314    # pylint: enable=protected-access
1315
1316    def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata):
1317      # Ensure any changes to the graph are reflected in the runtime.
1318      self._extend_graph()
1319      return self._call_tf_sessionrun(
1320          options, feed_dict, fetch_list, target_list, run_metadata)
1321
1322    def _prun_fn(handle, feed_dict, fetch_list):
1323      if target_list:
1324        raise RuntimeError('partial_run() requires empty target_list.')
1325      return self._call_tf_sessionprun(handle, feed_dict, fetch_list)
1326
1327    if handle is None:
1328      return self._do_call(_run_fn, feeds, fetches, targets, options,
1329                           run_metadata)
1330    else:
1331      return self._do_call(_prun_fn, handle, feeds, fetches)
1332
1333  def _do_call(self, fn, *args):
1334    try:
1335      return fn(*args)
1336    except errors.OpError as e:
1337      message = compat.as_text(e.message)
1338      m = BaseSession._NODEDEF_NAME_RE.search(message)
1339      node_def = None
1340      op = None
1341      if m is not None:
1342        node_name = m.group(3)
1343        try:
1344          op = self._graph.get_operation_by_name(node_name)
1345          node_def = op.node_def
1346        except KeyError:
1347          pass
1348      message = error_interpolation.interpolate(message, self._graph)
1349      raise type(e)(node_def, op, message)
1350
  def _extend_graph(self):
    # Mirror any operations added to the Python graph since the last call
    # into the C session, holding the graph's session-run lock so extension
    # is serialized with concurrent run() calls.
    with self._graph._session_run_lock():  # pylint: disable=protected-access
      tf_session.ExtendSession(self._session)
1354
  # The threshold to run garbage collection to delete dead tensors.
  # `_register_dead_handle` batches handle deletions until this many have
  # accumulated, then deletes them all in a single run.
  _DEAD_HANDLES_THRESHOLD = 10
1357
1358  def _register_dead_handle(self, handle):
1359    # Register a dead handle in the session. Delete the dead tensors when
1360    # the number of dead tensors exceeds certain threshold.
1361    tensors_to_delete = None
1362    with self._delete_lock:
1363      self._dead_handles.append(handle)
1364      if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD:
1365        tensors_to_delete = self._dead_handles
1366        self._dead_handles = []
1367    # Delete the dead tensors.
1368    if tensors_to_delete:
1369      feeds = {}
1370      fetches = []
1371      for deleter_key, tensor_handle in enumerate(tensors_to_delete):
1372        holder, deleter = session_ops._get_handle_deleter(
1373            self.graph, deleter_key, tensor_handle)
1374        feeds[holder] = tensor_handle
1375        fetches.append(deleter)
1376      self.run(fetches, feed_dict=feeds)
1377
1378  def _update_with_movers(self, feed_dict, feed_map):
1379    # If a tensor handle that is fed to a device incompatible placeholder,
1380    # we move the tensor to the right device, generate a new tensor handle,
1381    # and update `feed_dict` to use the new handle.
1382    handle_movers = []
1383    for feed_name, val in feed_map.items():
1384      mover = session_ops._get_handle_mover(self.graph, *val)
1385      if mover:
1386        handle_movers.append((feed_name, val[1], mover))
1387    # Transfer a tensor to the right device if needed.
1388    if not handle_movers:
1389      return []
1390    else:
1391      feeds = {}
1392      fetches = []
1393      for _, handle, mover in handle_movers:
1394        feeds[mover[0]] = handle
1395        fetches.append(mover[1])
1396      handles = self.run(fetches, feed_dict=feeds)
1397      for handle_mover, handle in zip(handle_movers, handles):
1398        np_val = np.array(handle.handle, dtype=np.object)
1399        feed_name = handle_mover[0]
1400        feed_tensor = feed_map[feed_name][0]
1401        feed_dict[feed_tensor] = np_val
1402      return handles
1403
  def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
                          run_metadata):
    # Thin wrapper over the SWIG binding for TF_SessionRun; all arguments
    # are already in C-API form (TF_Output handles, C buffer pointers).
    return tf_session.TF_SessionRun_wrapper(
        self._session, options, feed_dict, fetch_list, target_list,
        run_metadata)
1409
  def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
    # Thin wrapper over the SWIG binding for TF_SessionPRun (partial run).
    return tf_session.TF_SessionPRun_wrapper(
        self._session, handle, feed_dict, fetch_list)
1413
1414  # pylint: disable=protected-access
1415  class _Callable(object):
1416    """Experimental wrapper for the C++ `Session::MakeCallable()` API."""
1417
1418    def __init__(self, session, callable_options):
1419      self._session = session
1420      self._handle = None
1421      options_ptr = tf_session.TF_NewBufferFromString(
1422          compat.as_bytes(callable_options.SerializeToString()))
1423      try:
1424        with errors.raise_exception_on_not_ok_status() as status:
1425          self._handle = tf_session.TF_SessionMakeCallable(
1426              session._session, options_ptr, status)
1427      finally:
1428        tf_session.TF_DeleteBuffer(options_ptr)
1429
1430    def __call__(self, *args, **kwargs):
1431      # TODO(b/74355905): Support argument and return value nested structures,
1432      # and tensor-like objects such as SparseTensors.
1433      run_metadata = kwargs.get('run_metadata', None)
1434      try:
1435        run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
1436        # TODO(mrry): Switch to raising an exception from the SWIG wrapper.
1437        with errors.raise_exception_on_not_ok_status() as status:
1438          ret = tf_session.TF_SessionRunCallable(
1439              self._session._session, self._handle, args, status,
1440              run_metadata_ptr)
1441        if run_metadata:
1442          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
1443          run_metadata.ParseFromString(compat.as_bytes(proto_data))
1444      finally:
1445        if run_metadata_ptr:
1446          tf_session.TF_DeleteBuffer(run_metadata_ptr)
1447      return ret
1448
1449    def __del__(self):
1450      # NOTE(mrry): It is possible that `self._session.__del__()` could be
1451      # called before this destructor, in which case `self._session._session`
1452      # will be `None`.
1453      if self._handle is not None and self._session._session is not None:
1454        with errors.raise_exception_on_not_ok_status() as status:
1455          tf_session.TF_SessionReleaseCallable(
1456              self._session._session, self._handle, status)
1457  # pylint: enable=protected-access
1458
  # TODO(b/74355905): Reimplement `Session.make_callable()` using this method
  # where possible.
  def _make_callable_from_options(self, callable_options):
    """Returns a handle to a "callable" with the given options.

    Args:
      callable_options: A `CallableOptions` protocol buffer message describing
        the computation that will be performed by the callable.

    Returns:
      A handle to the new callable.
    """
    # Flush pending graph additions to the runtime before the callable is
    # created, since the callable will not re-extend the graph itself.
    self._extend_graph()
    return BaseSession._Callable(self, callable_options)
1473
1474
1475@tf_export(v1=['Session'])
1476class Session(BaseSession):
1477  """A class for running TensorFlow operations.
1478
1479  A `Session` object encapsulates the environment in which `Operation`
1480  objects are executed, and `Tensor` objects are evaluated. For
1481  example:
1482
1483  ```python
1484  # Build a graph.
1485  a = tf.constant(5.0)
1486  b = tf.constant(6.0)
1487  c = a * b
1488
1489  # Launch the graph in a session.
1490  sess = tf.Session()
1491
1492  # Evaluate the tensor `c`.
1493  print(sess.run(c))
1494  ```
1495
1496  A session may own resources, such as
1497  `tf.Variable`, `tf.QueueBase`,
1498  and `tf.ReaderBase`. It is important to release
1499  these resources when they are no longer required. To do this, either
1500  invoke the `tf.Session.close` method on the session, or use
1501  the session as a context manager. The following two examples are
1502  equivalent:
1503
1504  ```python
1505  # Using the `close()` method.
1506  sess = tf.Session()
1507  sess.run(...)
1508  sess.close()
1509
1510  # Using the context manager.
1511  with tf.Session() as sess:
1512    sess.run(...)
1513  ```
1514
1515  The
1516  [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
1517  protocol buffer exposes various configuration options for a
1518  session. For example, to create a session that uses soft constraints
1519  for device placement, and log the resulting placement decisions,
1520  create a session as follows:
1521
1522  ```python
1523  # Launch the graph in a session that allows soft device placement and
1524  # logs the placement decisions.
1525  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
1526                                          log_device_placement=True))
1527  ```
1528  """
1529
  def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to.
        Defaults to using an in-process engine. See
        [Distributed TensorFlow](https://tensorflow.org/deploy/distributed)
        for more examples.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A
        [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
        protocol buffer with configuration options for the session.

    """
    super(Session, self).__init__(target, graph, config=config)
    # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle.
    # `__enter__` also uses these to detect (and reject) re-entry.
    self._default_graph_context_manager = None
    self._default_session_context_manager = None
1556
1557  def __enter__(self):
1558    if self._default_graph_context_manager is None:
1559      self._default_graph_context_manager = self.graph.as_default()
1560    else:
1561      raise RuntimeError('Session context managers are not re-entrant. '
1562                         'Use `Session.as_default()` if you want to enter '
1563                         'a session multiple times.')
1564    if self._default_session_context_manager is None:
1565      self._default_session_context_manager = self.as_default()
1566    self._default_graph_context_manager.__enter__()
1567    return self._default_session_context_manager.__enter__()
1568
1569  def __exit__(self, exec_type, exec_value, exec_tb):
1570    if exec_type is errors.OpError:
1571      logging.error('Session closing due to OpError: %s', (exec_value,))
1572    try:
1573      self._default_session_context_manager.__exit__(exec_type, exec_value,
1574                                                     exec_tb)
1575    except RuntimeError as error:
1576      if error == exec_value:
1577        # NOTE(skyewm): for some reason, in Python3,
1578        # _default_session_context_manager.__exit__ will re-raise the "not
1579        # re-entrant" exception raised in __enter__ above (note that if we're
1580        # here, we're in the outer session context manager, since __exit__ is
1581        # not called when __enter__ raises an exception). We still want to
1582        # continue cleaning up this context manager before the exception is
1583        # further propagated, so we ignore it here (note that it'll continue
1584        # being propagated after this method completes).
1585        pass
1586      else:
1587        raise
1588    self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)
1589
1590    self._default_session_context_manager = None
1591    self._default_graph_context_manager = None
1592
1593    # If we are closing due to an exception, set a time limit on our Close() to
1594    # avoid blocking forever.
1595    # TODO(b/120204635) remove this when deadlock is fixed.
1596    if exec_type:
1597      close_thread = threading.Thread(
1598          name='SessionCloseThread', target=self.close)
1599      close_thread.daemon = True
1600      close_thread.start()
1601      close_thread.join(30.0)
1602      if close_thread.is_alive():
1603        logging.error(
1604            'Session failed to close after 30 seconds. Continuing after this '
1605            'point may leave your program in an undefined state.')
1606    else:
1607      self.close()
1608
1609  @staticmethod
1610  def reset(target, containers=None, config=None):
1611    """Resets resource containers on `target`, and close all connected sessions.
1612
1613    A resource container is distributed across all workers in the
1614    same cluster as `target`.  When a resource container on `target`
1615    is reset, resources associated with that container will be cleared.
1616    In particular, all Variables in the container will become undefined:
1617    they lose their values and shapes.
1618
1619    NOTE:
1620    (i) reset() is currently only implemented for distributed sessions.
1621    (ii) Any sessions on the master named by `target` will be closed.
1622
1623    If no resource containers are provided, all containers are reset.
1624
1625    Args:
1626      target: The execution engine to connect to.
1627      containers: A list of resource container name strings, or `None` if all of
1628        all the containers are to be reset.
1629      config: (Optional.) Protocol buffer with configuration options.
1630
1631    Raises:
1632      tf.errors.OpError: Or one of its subclasses if an error occurs while
1633        resetting containers.
1634    """
1635    if target is not None:
1636      target = compat.as_bytes(target)
1637    if containers is not None:
1638      containers = [compat.as_bytes(c) for c in containers]
1639    else:
1640      containers = []
1641    tf_session.TF_Reset(target, containers, config)
1642
1643
@tf_export(v1=['InteractiveSession'])
class InteractiveSession(BaseSession):
  """A TensorFlow `Session` for use in interactive contexts, such as a shell.

  The only difference with a regular `Session` is that an `InteractiveSession`
  installs itself as the default session on construction.
  The methods `tf.Tensor.eval`
  and `tf.Operation.run`
  will use that session to run ops.

  This is convenient in interactive shells and [IPython
  notebooks](http://ipython.org), as it avoids having to pass an explicit
  `Session` object to run ops.

  For example:

  ```python
  sess = tf.InteractiveSession()
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  # We can just use 'c.eval()' without passing 'sess'
  print(c.eval())
  sess.close()
  ```

  Note that a regular session installs itself as the default session when it
  is created in a `with` statement.  The common usage in non-interactive
  programs is to follow that pattern:

  ```python
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  with tf.Session():
    # We can also use 'c.eval()' here.
    print(c.eval())
  ```
  """

  _count_lock = threading.Lock()
  _active_session_count = 0  # GUARDED_BY(_count_lock)

  def __init__(self, target='', graph=None, config=None):
    """Creates a new interactive TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to.
        Defaults to using an in-process engine.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional) `ConfigProto` proto used to configure the session.
    """
    if not config:
      # If config is not provided, choose some reasonable defaults for
      # interactive use:
      #
      #   - Grow GPU memory as needed at the cost of fragmentation.
      gpu_options = config_pb2.GPUOptions(allow_growth=True)
      config = config_pb2.ConfigProto(gpu_options=gpu_options)
    else:
      # Work on a copy so that enabling pruned-graph placement below does
      # not silently mutate the caller's `ConfigProto`.
      copied_config = config_pb2.ConfigProto()
      copied_config.CopyFrom(config)
      config = copied_config
    # Interactive sessions always place pruned graphs.
    config.graph_options.place_pruned_graph = True

    super(InteractiveSession, self).__init__(target, graph, config)
    with InteractiveSession._count_lock:
      if InteractiveSession._active_session_count > 0:
        warnings.warn('An interactive session is already active. This can '
                      'cause out-of-memory errors in some cases. You must '
                      'explicitly call `InteractiveSession.close()` to release '
                      'resources held by the other session(s).')
      InteractiveSession._active_session_count += 1
    # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful
    # semantics (in particular, it is not set to true if `Session.close()` is
    # called on a session that has not been "opened" by running a step) and we
    # cannot change those semantics without breaking existing code.
    self._explicitly_closed = False

    # Install this session (and, if the caller supplied one, its graph) as
    # the process-wide default; the matching __exit__ calls happen in close().
    self._default_session = self.as_default()
    self._default_session.enforce_nesting = False
    self._default_session.__enter__()
    self._explicit_graph = graph
    if self._explicit_graph is not None:
      self._default_graph = graph.as_default()
      self._default_graph.enforce_nesting = False
      self._default_graph.__enter__()

  def close(self):
    """Closes an `InteractiveSession`."""
    super(InteractiveSession, self).close()
    # Decrement the active-session count exactly once, even if close() is
    # called repeatedly; later calls return without re-exiting the context
    # managers below.
    with InteractiveSession._count_lock:
      if not self._explicitly_closed:
        InteractiveSession._active_session_count -= 1
        self._explicitly_closed = True
      else:
        return
    if self._explicit_graph is not None:
      self._default_graph.__exit__(None, None, None)
      self._default_graph = None
    self._default_session.__exit__(None, None, None)
    self._default_session = None
1751