1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A client interface for TensorFlow."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import re
24import threading
25import warnings
26
27import numpy as np
28import wrapt
29
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.core.protobuf import rewriter_config_pb2
32from tensorflow.python.client import pywrap_tf_session as tf_session
33from tensorflow.python.eager import context
34from tensorflow.python.eager import monitoring
35from tensorflow.python.framework import device
36from tensorflow.python.framework import error_interpolation
37from tensorflow.python.framework import errors
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.ops import session_ops
41from tensorflow.python.platform import tf_logging as logging
42from tensorflow.python.training.experimental import mixed_precision_global_state
43from tensorflow.python.util import compat
44from tensorflow.python.util import nest
45from tensorflow.python.util.compat import collections_abc
46from tensorflow.python.util.tf_export import tf_export
47
# Process-wide metric counting how many Python `Session` objects have been
# created; incremented once per construction in `BaseSession.__init__`.
_python_session_create_counter = monitoring.Counter(
    '/tensorflow/api/python/session_create_counter',
    'Counter for number of sessions created in Python.')
51
52
class SessionInterface(object):
  """Abstract interface shared by TensorFlow client session implementations.

  Concrete sessions (e.g. `BaseSession`) must override every member below;
  each default implementation simply raises `NotImplementedError`.
  """

  @property
  def graph(self):
    """Returns the graph in which this session builds and runs operations."""
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    """Returns the target identifying the TensorFlow process to connect to."""
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Executes fetches in the session; see `BaseSession.run()` for details."""
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    """Prepares the feeds and fetches for subsequent partial runs."""
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    """Resumes execution of a partial run with more feeds and fetches."""
    raise NotImplementedError('partial_run')
77
78
def _get_indexed_slices_value_from_fetches(fetched_vals):
  """Packs fetched [values, indices(, dense_shape)] into an IndexedSlicesValue.

  The dense shape is optional: a two-element list means no dense shape was
  fetched, and `None` is recorded in its place.
  """
  dense_shape = fetched_vals[2] if len(fetched_vals) == 3 else None
  return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1], dense_shape)
83
84
85def _get_feeds_for_indexed_slices(feed, feed_val):
86  return list(
87      zip([feed.values, feed.indices] if feed.dense_shape is None else
88          [feed.values, feed.indices, feed.dense_shape], feed_val))
89
90
91# List of extensions supported to convert run arguments into actual fetches and
92# feeds.
93#
94# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2),
95# where the function signatures are:
96#   fetch_fn : Type -> (list of Tensors,
97#                       lambda: list of fetched np.ndarray -> TypeVal)
98#   feed_fn1 : Type, TypeVal -> list of (Tensor, value)
99#   feed_fn2 : Type -> list of Tensors
100#
101# `fetch_fn` describes how to expand fetch into its
102# component Tensors and how to contract the fetched results back into
103# a single return value.
104#
105# Each feed function describes how to unpack a single fed value and map it to
106# feeds of one or more tensors and their corresponding values: `feed_fn1` is
107# used to feed a run, `feed_fn2` to set up a partial run.
108#
109# TODO(touts): We could reimplement these as specialized _FeedMapper
110# implementations after we refactor the feed handling code to use them.
111#
112# Eventually, this registration could be opened up to support custom Tensor
113# expansions.
# NOTE: Order matters.  `_FetchMapper.for_fetch` walks this list and uses the
# first entry whose type matches via `isinstance`, and
# `register_session_run_conversion_functions` inserts new entries at index 0,
# so the catch-all `object` entry must remain last.
# pylint: disable=g-long-lambda
_REGISTERED_EXPANSIONS = [
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (sparse_tensor.SparseTensor, lambda fetch: ([
        fetch.indices, fetch.values, fetch.dense_shape
    ], lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)),
     lambda feed, feed_val: list(
         zip([feed.indices, feed.values, feed.dense_shape], feed_val)),
     lambda feed: [feed.indices, feed.values, feed.dense_shape]),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples.
    (ops.IndexedSlices,
     lambda fetch: ([fetch.values, fetch.indices] if fetch.dense_shape is None
                    else [fetch.values, fetch.indices, fetch.dense_shape
                         ], _get_indexed_slices_value_from_fetches),
     _get_feeds_for_indexed_slices,
     lambda feed: [feed.values, feed.indices] if feed.dense_shape is None else
     [feed.values, feed.indices, feed.dense_shape]),
    # The default catches all other types and performs no expansions.
    (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
     lambda feed, feed_val: [(feed, feed_val)], lambda feed: [feed])
]

# pylint: enable=g-long-lambda
139
140
141def _convert_to_numpy_obj(numpy_dtype, obj):
142  """Explicitly convert obj based on numpy type except for string type."""
143  return numpy_dtype(obj) if numpy_dtype is not object else str(obj)
144
145
def register_session_run_conversion_functions(
    tensor_type,
    fetch_function,
    feed_function=None,
    feed_function_for_partial_run=None):
  """Register fetch and feed conversion functions for `tf.Session.run()`.

  Registers a triple of conversion functions that let `tf.Session.run()`
  fetch and/or feed values of a user-defined type.

  An example

  ```python
     class SquaredTensor(object):
       def __init__(self, tensor):
         self.sq = tf.square(tensor)
     #you can define conversion functions as follows:
     fetch_function = lambda squared_tensor:([squared_tensor.sq],
                                             lambda val: val[0])
     feed_function = lambda feed, feed_val: [(feed.sq, feed_val)]
     feed_function_for_partial_run = lambda feed: [feed.sq]
     #then after invoking this register function, you can use as follows:
     session.run(squared_tensor1,
                 feed_dict = {squared_tensor2 : some_numpy_array})
  ```

  Args:
    tensor_type: The type for which you want to register a conversion function.
    fetch_function: A callable that takes an object of type `tensor_type` and
      returns a tuple, where the first element is a list of `tf.Tensor` objects,
      and the second element is a callable that takes a list of ndarrays and
      returns an object of some value type that corresponds to `tensor_type`.
      It describes how to expand a fetch into its component Tensors and how to
      contract the fetched results back into a single return value.
    feed_function: A callable that takes feed_key and feed_value as input, and
      returns a list of tuples (feed_tensor, feed_val); feed_key must have type
      `tensor_type`, and feed_tensor must have type `tf.Tensor`. It describes
      how to unpack a single fed value into feeds of one or more tensors and
      their corresponding values.
    feed_function_for_partial_run: A callable for specifying tensor values to
      feed when setting up a partial run, which takes a `tensor_type` type
      object as input, and returns a list of Tensors.

  Raises:
    ValueError: If `tensor_type` has already been registered.
  """
  for registered_type, _, _, _ in _REGISTERED_EXPANSIONS:
    if issubclass(registered_type, tensor_type):
      raise ValueError('%s has already been registered so ignore it.' %
                       tensor_type)

  # New registrations go to the front so they take priority over the
  # catch-all `object` entry at the end of the list.
  new_entry = (tensor_type, fetch_function, feed_function,
               feed_function_for_partial_run)
  _REGISTERED_EXPANSIONS.insert(0, new_entry)
199
200
201def _is_attrs_instance(obj):
202  """Returns True if the given obj is an instance of attrs-decorated class."""
203  return getattr(obj.__class__, '__attrs_attrs__', None) is not None
204
205
206def _get_attrs_values(obj):
207  """Returns the list of values from an attrs instance."""
208  attrs = getattr(obj.__class__, '__attrs_attrs__')
209  return [getattr(obj, a.name) for a in attrs]
210
211
212class _FetchMapper(object):
213  """Definition of the interface provided by fetch mappers.
214
215  Fetch mappers are utility classes used by the _FetchHandler to handle
216  arbitrary structures for the `fetch` argument to `Session.run()`.
217
218  The `fetch` argument can be of various shapes: single tensor or op, list of
219  fetches, tuple of fetches, namedtuple of fetches, or dict of fetches.  The
220  structures can be arbitrarily nested.
221
222  The low level run() API only wants a list of tensor or op names.  The various
223  `_FetchMapper` subclasses below take care of handling the different shapes:
224  uniquifying the fetches, and constructing results with the original shape.
225  """
226
227  def unique_fetches(self):
228    """Return the list of unique tensors or ops needed by this fetch mapper.
229
230    Returns:
231      A list of tensors or ops.
232    """
233    raise NotImplementedError('Must be implemented by subclasses')
234
235  def build_results(self, values):
236    """Build results that match the original shape of the fetch.
237
238    Args:
239      values: List of values returned by run(). The values correspond exactly to
240        the list tensors or ops returned by unique_fetches().
241
242    Returns:
243      A struct of the same shape as the original fetch object handled by
244      this fetch mapper.  In the returned struct, the original fetches are
245      replaced by their fetched values.
246    """
247    raise NotImplementedError('Must be implemented by subclasses')
248
249  @staticmethod
250  def for_fetch(fetch):
251    """Creates fetch mapper that handles the structure of `fetch`.
252
253    The default graph must be the one from which we want to fetch values when
254    this function is called.
255
256    Args:
257      fetch: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
258        or dict.
259
260    Returns:
261      An instance of a subclass of `_FetchMapper` that handles the shape.
262    """
263    if fetch is None:
264      raise TypeError('Fetch argument %r has invalid type %r' %
265                      (fetch, type(fetch)))
266    elif isinstance(fetch, (list, tuple)):
267      # NOTE(touts): This is also the code path for namedtuples.
268      return _ListFetchMapper(fetch)
269    elif isinstance(fetch, collections_abc.Mapping):
270      return _DictFetchMapper(fetch)
271    elif _is_attrs_instance(fetch):
272      return _AttrsFetchMapper(fetch)
273    else:
274      # Look for a handler in the registered expansions.
275      for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
276        if isinstance(fetch, tensor_type):
277          fetches, contraction_fn = fetch_fn(fetch)
278          return _ElementFetchMapper(fetches, contraction_fn)
279    # Did not find anything.
280    raise TypeError('Fetch argument %r has invalid type %r' %
281                    (fetch, type(fetch)))
282
283
class _ElementFetchMapper(_FetchMapper):
  """Fetch mapper for singleton tensors and ops."""

  def __init__(self, fetches, contraction_fn):
    """Creates an _ElementFetchMapper.

    Leaves of the fetch structure are handled here.  Because of the expansion
    mechanism, one leaf may fetch several tensors.  Each fetch may also be a
    string (tensor or op name) or any object the graph can convert to a
    tensor, such as a Variable, so every entry is resolved through
    `as_graph_element()` on the default graph.

    Args:
      fetches: List of objects, as returned by a fetch_fn defined in
        _REGISTERED_EXPANSIONS.
      contraction_fn: Callable as returned by a fetch_fn.

    Raises:
      TypeError: If a fetch has a type the graph cannot convert.
      ValueError: If a fetch cannot be interpreted as a tensor or op.
    """
    self._unique_fetches = []
    for fetch in fetches:
      try:
        element = ops.get_default_graph().as_graph_element(
            fetch, allow_tensor=True, allow_operation=True)
      except TypeError as e:
        raise TypeError('Fetch argument %r has invalid type %r, '
                        'must be a string or Tensor. (%s)' %
                        (fetch, type(fetch), str(e)))
      except (ValueError, KeyError) as e:
        # Both failure modes carry the same user-facing message.
        raise ValueError('Fetch argument %r cannot be interpreted as a '
                         'Tensor. (%s)' % (fetch, str(e)))
      self._unique_fetches.append(element)
    self._contraction_fn = contraction_fn

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # An Operation fetch yields no values; its result is None.
    if not values:
      return None
    return self._contraction_fn(values)
329
330
331def _uniquify_fetches(fetch_mappers):
332  """Uniquifies fetches from a list of fetch_mappers.
333
334  This is a utility function used by _ListFetchMapper and _DictFetchMapper.  It
335  gathers all the unique fetches from a list of mappers and builds a list
336  containing all of them but without duplicates (unique_fetches).
337
338  It also returns a 2-D list of integers (values_indices) indicating at which
339  index in unique_fetches the fetches of the mappers are located.
340
341  This list is as follows:
342    values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index
343
344  Args:
345    fetch_mappers: list of fetch mappers.
346
347  Returns:
348    A list of fetches.
349    A 2-D list of integers.
350  """
351  unique_fetches = []
352  value_indices = []
353  seen_fetches = {}
354  for m in fetch_mappers:
355    m_value_indices = []
356    for f in m.unique_fetches():
357      j = seen_fetches.get(id(f))
358      if j is None:
359        j = len(seen_fetches)
360        seen_fetches[id(f)] = j
361        unique_fetches.append(f)
362      m_value_indices.append(j)
363    value_indices.append(m_value_indices)
364  return unique_fetches, value_indices
365
366
class _ListFetchMapper(_FetchMapper):
  """Fetch mapper for lists, tuples, and namedtuples."""

  def __init__(self, fetches):
    """Creates a _ListFetchMapper.

    Args:
      fetches: List, tuple, or namedtuple of fetches.
    """
    # Unwrap object proxies so results are rebuilt with the real type.
    if isinstance(fetches, wrapt.ObjectProxy):
      self._fetch_type = type(fetches.__wrapped__)
    else:
      self._fetch_type = type(fetches)
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Rebuild each child mapper's result from its slots in `values`.
    results = [
        mapper.build_results([values[i] for i in indices])
        for mapper, indices in zip(self._mappers, self._value_indices)
    ]
    # Return a value of the original type of the fetches.  Note that list
    # subclasses intentionally come back as a plain list.
    if issubclass(self._fetch_type, list):
      return results
    if self._fetch_type == tuple:
      return tuple(results)
    # Namedtuples are reconstructed from positional arguments.
    return self._fetch_type(*results)
399
400
class _DictFetchMapper(_FetchMapper):
  """Fetch mapper for dicts."""

  def __init__(self, fetches):
    """Creates a _DictFetchMapper.

    Args:
      fetches: Dict of fetches.
    """
    self._fetch_type = type(fetches)
    # defaultdict cannot be built from pairs alone: its constructor needs
    # the default factory as the first argument.
    if isinstance(fetches, collections.defaultdict):
      self._type_ctor = functools.partial(collections.defaultdict,
                                          fetches.default_factory)
    else:
      self._type_ctor = self._fetch_type

    self._keys = fetches.keys()
    self._mappers = [
        _FetchMapper.for_fetch(fetch) for fetch in fetches.values()
    ]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Stream (key, rebuilt_value) pairs straight into the dict constructor.
    pairs = ((key, mapper.build_results([values[i] for i in indices]))
             for key, mapper, indices in zip(self._keys, self._mappers,
                                             self._value_indices))
    return self._type_ctor(pairs)
433
434
class _AttrsFetchMapper(_FetchMapper):
  """Fetch mapper for attrs decorated classes."""

  def __init__(self, fetches):
    """Creates a _AttrsFetchMapper.

    Args:
      fetches: An instance of an attrs decorated class.
    """
    self._fetch_type = type(fetches)
    self._mappers = [
        _FetchMapper.for_fetch(fetch) for fetch in _get_attrs_values(fetches)
    ]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    rebuilt = []
    for mapper, indices in zip(self._mappers, self._value_indices):
      rebuilt.append(mapper.build_results([values[i] for i in indices]))
    # attrs classes accept their attribute values as positional arguments.
    return self._fetch_type(*rebuilt)
457
458
class _FetchHandler(object):
  """Handler for structured fetches.

  Given a graph, a user-provided structure for fetches, and a feed dict, this
  class takes care of generating a list of tensor names to fetch and op names
  to run for a low level `run()` call.

  Given the results of the low level run call, this class can also rebuild a
  result structure matching the user-provided structure for fetches, but
  containing the corresponding results.
  """

  # TODO(touts): Make this class also take care of destructuring the feed
  # dict instead of doing it in the callers.

  def __init__(self, graph, fetches, feeds, feed_handles=None):
    """Creates a fetch handler.

    Args:
      graph: Graph of the fetches.   Used to check for fetchability and to
        convert all fetches to tensors or ops as needed.
      fetches: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
        or dict.
      feeds: A feed dict where keys are Tensors.
      feed_handles: A dict from feed Tensors to TensorHandle objects used as
        direct feeds.
    """
    with graph.as_default():
      self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    self._fetches = []
    self._targets = []
    self._feeds = feeds
    self._feed_handles = feed_handles or {}
    # One boolean per unique fetch, in order: True for ops, False for tensors.
    self._ops = []
    self._fetch_handles = {}
    for fetch in self._fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        self._assert_fetchable(graph, fetch)
        self._targets.append(fetch)
        self._ops.append(True)
      else:
        self._assert_fetchable(graph, fetch.op)
        self._fetches.append(fetch)
        self._ops.append(False)
      # Remember the fetch if it is for a tensor handle.
      if (isinstance(fetch, ops.Tensor) and
          (fetch.op.type == 'GetSessionHandle' or
           fetch.op.type == 'GetSessionHandleV2')):
        self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype
    # Tensors that are also directly fed need not be fetched from the runtime.
    self._final_fetches = [x for x in self._fetches if x.ref() not in feeds]

  def _assert_fetchable(self, graph, op):
    """Raises InaccessibleTensorError unless `op` is fetchable in `graph`."""
    if not graph.is_fetchable(op):
      raise errors.InaccessibleTensorError(
          'Operation %r has been marked as not fetchable. Typically this'
          ' happens when it is defined in another function or code block.'
          ' Use return values, explicit Python locals or TensorFlow collections'
          ' to access it.'
          % op.name)

  def fetches(self):
    """Return the unique tensors to fetch.

    Returns:
      A list of tensors (the fetches that are not satisfied by a direct feed).
    """
    return self._final_fetches

  def targets(self):
    """Return the unique ops to run.

    Returns:
      A list of ops.
    """
    return self._targets

  def build_results(self, session, tensor_values):
    """Build results matching the original fetch shape.

    `tensor_values` must be a list of the same length as
    the one returned by `fetches()`, and holding the requested
    fetch values.

    This method builds a struct with the same shape as the original `fetches`
    passed to the constructor, in which the fetches are replaced by their
    fetched value.

    Args:
      session: The enclosing session.  Used for tensor handles.
      tensor_values: List of values matching the list returned by fetches().

    Returns:
      A structure of the same shape as the original `fetches` argument but
        containing tensors or None (for fetched ops).
    """
    full_values = []
    assert len(self._final_fetches) == len(tensor_values)
    # `i` indexes self._fetches (all fetched tensors); `j` indexes
    # tensor_values (only the tensors actually fetched, i.e. not fed).
    i = 0
    j = 0
    for is_op in self._ops:
      if is_op:
        full_values.append(None)
      else:
        # If the fetch was in the feeds, use the fed value, otherwise
        # use the returned value.
        if self._fetches[i].ref() in self._feed_handles:
          # A fetch had a corresponding direct TensorHandle feed. Call eval()
          # to obtain the Tensor value from the TensorHandle.
          value = self._feed_handles[self._fetches[i].ref()].eval()
        else:
          value = self._feeds.get(self._fetches[i].ref())
        if value is None:
          value = tensor_values[j]
          j += 1
        dtype = self._fetch_handles.get(self._fetches[i].ref())
        if dtype:
          # This fetch produces a session handle; wrap the raw value.
          full_values.append(session_ops.TensorHandle(value, dtype, session))
        else:
          full_values.append(value)
        i += 1
    assert j == len(tensor_values)
    return self._fetch_mapper.build_results(full_values)
581
582
def _name_list(tensor_list):
  """Utility function for transitioning to the new session API.

  Args:
    tensor_list: a list of `Tensor`s.

  Returns:
    A list of each `Tensor`s name (as byte arrays).
  """
  names = []
  for tensor in tensor_list:
    names.append(compat.as_bytes(tensor.name))
  return names
593
594
595class _DeviceAttributes(object):
596  """Struct-like object describing a device's attributes.
597
598  Each device has 3 key properties:
599   - name: the fully-qualified TensorFlow path to the device. For
600        example: /job:worker/replica:0/task:3/device:CPU:0
601   - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.)
602   - memory_limit_bytes: the maximum amount of memory available on the device
603        (in bytes).
604  """
605
606  def __init__(self, name, device_type, memory_limit_bytes, incarnation):
607    self._name = device.canonical_name(name)
608    self._device_type = device_type
609    self._memory_limit_bytes = memory_limit_bytes
610    self._incarnation = incarnation
611
612  @property
613  def name(self):
614    return self._name
615
616  @property
617  def device_type(self):
618    return self._device_type
619
620  @property
621  def memory_limit_bytes(self):
622    return self._memory_limit_bytes
623
624  @property
625  def incarnation(self):
626    return self._incarnation
627
628  def __repr__(self):
629    return '_DeviceAttributes(%s, %s, %d, %d)' % (
630        self.name,
631        self.device_type,
632        self.memory_limit_bytes,
633        self.incarnation,
634    )
635
636
637class BaseSession(SessionInterface):
638  """A class for interacting with a TensorFlow computation.
639
640  The BaseSession enables incremental graph building with inline
641  execution of Operations and evaluation of Tensors.
642  """
643
  def __init__(self, target='', graph=None, config=None):
    """Constructs a new TensorFlow session.

    Args:
      target: (Optional) The TensorFlow execution engine to connect to.
      graph: (Optional) The graph to be used. If this argument is None, the
        default graph will be used.
      config: (Optional) ConfigProto proto used to configure the session. If no
        config is specified, the global default will be used. The global default
        can be configured via the tf.config APIs.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow session.
      TypeError: If one of the arguments has the wrong type.
    """
    _python_session_create_counter.get_cell().increase_by(1)
    if graph is None:
      self._graph = ops.get_default_graph()
    else:
      if not isinstance(graph, ops.Graph):
        raise TypeError('graph must be a tf.Graph, but got %s' % type(graph))
      self._graph = graph

    self._closed = False

    if target is not None:
      try:
        self._target = compat.as_bytes(target)
      except TypeError:
        # A common mistake is passing the config positionally as `target`;
        # give a more helpful message for that case.
        if isinstance(target, config_pb2.ConfigProto):
          raise TypeError('target must be a string, but got %s.'
                          ' Did you do "Session(config)" instead of'
                          ' "Session(config=config)"?' % type(target))
        raise TypeError('target must be a string, but got %s' % type(target))
    else:
      self._target = None

    # NOTE(review): presumably used to batch deletion of dead tensor handles;
    # the consuming code is outside this chunk — confirm against the rest of
    # the file.
    self._delete_lock = threading.Lock()
    self._dead_handles = []

    if config is None:
      config = context.context().config

    if not isinstance(config, config_pb2.ConfigProto):
      raise TypeError('config must be a tf.ConfigProto, but got %s' %
                      type(config))

    # If the mixed-precision graph rewrite was requested globally, turn it on
    # in a copy of the config (unless the user explicitly disabled it there).
    if (mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled
        and config.graph_options.rewrite_options.auto_mixed_precision !=
        rewriter_config_pb2.RewriterConfig.OFF):
      new_config = config_pb2.ConfigProto()
      new_config.CopyFrom(config)
      new_config.graph_options.rewrite_options.auto_mixed_precision = (
          rewriter_config_pb2.RewriterConfig.ON)
      config = new_config
    elif (config.graph_options.rewrite_options.auto_mixed_precision !=
          rewriter_config_pb2.RewriterConfig.ON):
      # Record that a session without mixed precision was created.
      mixed_precision_global_state.non_mixed_precision_session_created = True

    self._config = config
    self._add_shapes = config.graph_options.infer_shapes

    self._session = None
    opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
    try:
      # pylint: disable=protected-access
      self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts)
      # pylint: enable=protected-access
    finally:
      # Always release the options struct, even if session creation failed.
      tf_session.TF_DeleteSessionOptions(opts)
715
716  def list_devices(self):
717    """Lists available devices in this session.
718
719    ```python
720    devices = sess.list_devices()
721    for d in devices:
722      print(d.name)
723    ```
724
725    Where:
726      Each element in the list has the following properties
727      name: A string with the full name of the device. ex:
728          `/job:worker/replica:0/task:3/device:CPU:0`
729      device_type: The type of the device (e.g. `CPU`, `GPU`, `TPU`.)
730      memory_limit: The maximum amount of memory available on the device.
731          Note: depending on the device, it is possible the usable memory could
732          be substantially less.
733
734    Raises:
735      tf.errors.OpError: If it encounters an error (e.g. session is in an
736      invalid state, or network errors occur).
737
738    Returns:
739      A list of devices in the session.
740    """
741    raw_device_list = tf_session.TF_SessionListDevices(self._session)
742    device_list = []
743    size = tf_session.TF_DeviceListCount(raw_device_list)
744    for i in range(size):
745      name = tf_session.TF_DeviceListName(raw_device_list, i)
746      device_type = tf_session.TF_DeviceListType(raw_device_list, i)
747      memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
748      incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
749      device_list.append(
750          _DeviceAttributes(name, device_type, memory, incarnation))
751    tf_session.TF_DeleteDeviceList(raw_device_list)
752    return device_list
753
754  def close(self):
755    """Closes this session.
756
757    Calling this method frees all resources associated with the session.
758
759    Raises:
760      tf.errors.OpError: Or one of its subclasses if an error occurs while
761        closing the TensorFlow session.
762    """
763    if self._session and not self._closed:
764      self._closed = True
765      tf_session.TF_CloseSession(self._session)
766
  def __del__(self):
    """Releases the underlying C session when this object is destroyed."""
    # cleanly ignore all exceptions
    try:
      self.close()
    except Exception:  # pylint: disable=broad-except
      pass
    if self._session is not None:
      try:
        tf_session.TF_DeleteSession(self._session)
      except (AttributeError, TypeError):
        # At shutdown, `c_api_util`, `tf_session`, or
        # `tf_session.TF_DeleteSession` may have been garbage collected, causing
        # the above method calls to fail. In this case, silently leak since the
        # program is about to terminate anyway.
        pass
      self._session = None
783
  @property
  def graph(self):
    """The graph that was launched in this session.

    Returns:
      The `Graph` passed to the constructor, or the default graph if none was.
    """
    return self._graph
788
  @property
  def graph_def(self):
    """A serializable version of the underlying TensorFlow graph.

    Returns:
      A graph_pb2.GraphDef proto containing nodes for all of the Operations in
      the underlying TensorFlow graph.
    """
    # `add_shapes` mirrors the `infer_shapes` graph option from the config.
    return self._graph.as_graph_def(add_shapes=self._add_shapes)
798
  @property
  def sess_str(self):
    """The target this session connects to, as bytes (or None if unset)."""
    return self._target
802
  def as_default(self):
    """Returns a context manager that makes this object the default session.

    Use with the `with` keyword to specify that calls to
    `tf.Operation.run` or `tf.Tensor.eval` should be executed in
    this session.

    ```python
    c = tf.constant(..)
    sess = tf.compat.v1.Session()

    with sess.as_default():
      assert tf.compat.v1.get_default_session() is sess
      print(c.eval())
    ```

    To get the current default session, use `tf.compat.v1.get_default_session`.

    *N.B.* The `as_default` context manager *does not* close the
    session when you exit the context, and you must close the session
    explicitly.

    ```python
    c = tf.constant(...)
    sess = tf.compat.v1.Session()
    with sess.as_default():
      print(c.eval())
    # ...
    with sess.as_default():
      print(c.eval())

    sess.close()
    ```

    Alternatively, you can use `with tf.compat.v1.Session():` to create a
    session that is automatically closed on exiting the context,
    including when an uncaught exception is raised.

    *N.B.* The default session is a property of the current thread. If you
    create a new thread, and wish to use the default session in that
    thread, you must explicitly add a `with sess.as_default():` in that
    thread's function.

    *N.B.* Entering a `with sess.as_default():` block does not affect
    the current default graph. If you are using multiple graphs, and
    `sess.graph` is different from the value of
    `tf.compat.v1.get_default_graph`, you must explicitly enter a
    `with sess.graph.as_default():` block to make `sess.graph` the default
    graph.

    Returns:
      A context manager using this session as the default session.
    """
    # Delegates to the thread-local default-session stack managed by `ops`.
    return ops.default_session(self)
857
  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs operations and evaluates tensors in `fetches`.

    This method runs one "step" of TensorFlow computation, by
    running the necessary graph fragment to execute every `Operation`
    and evaluate every `Tensor` in `fetches`, substituting the values in
    `feed_dict` for the corresponding input values.

    The `fetches` argument may be a single graph element, or an arbitrarily
    nested list, tuple, namedtuple, dict, or OrderedDict containing graph
    elements at its leaves.  A graph element can be one of the following types:

    * A `tf.Operation`.
      The corresponding fetched value will be `None`.
    * A `tf.Tensor`.
      The corresponding fetched value will be a numpy ndarray containing the
      value of that tensor.
    * A `tf.sparse.SparseTensor`.
      The corresponding fetched value will be a
      `tf.compat.v1.SparseTensorValue`
      containing the value of that sparse tensor.
    * A `get_tensor_handle` op.  The corresponding fetched value will be a
      numpy ndarray containing the handle of that tensor.
    * A `string` which is the name of a tensor or operation in the graph.

    The value returned by `run()` has the same shape as the `fetches` argument,
    where the leaves are replaced by the corresponding values returned by
    TensorFlow.

    Example:

    ```python
       a = tf.constant([10, 20])
       b = tf.constant([1.0, 2.0])
       # 'fetches' can be a singleton
       v = session.run(a)
       # v is the numpy array [10, 20]
       # 'fetches' can be a list.
       v = session.run([a, b])
       # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
       # 1-D array [1.0, 2.0]
       # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
       MyData = collections.namedtuple('MyData', ['a', 'b'])
       v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
       # v is a dict with
       # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
       # 'b' (the numpy array [1.0, 2.0])
       # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
       # [10, 20].
    ```

    The optional `feed_dict` argument allows the caller to override
    the value of tensors in the graph. Each key in `feed_dict` can be
    one of the following types:

    * If the key is a `tf.Tensor`, the
      value may be a Python scalar, string, list, or numpy ndarray
      that can be converted to the same `dtype` as that
      tensor. Additionally, if the key is a
      `tf.compat.v1.placeholder`, the shape of
      the value will be checked for compatibility with the placeholder.
    * If the key is a
      `tf.sparse.SparseTensor`,
      the value should be a
      `tf.compat.v1.SparseTensorValue`.
    * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
      should be a nested tuple with the same structure that maps to their
      corresponding values as above.

    Each value in `feed_dict` must be convertible to a numpy array of the dtype
    of the corresponding key.

    The optional `options` argument expects a [`RunOptions`] proto. The options
    allow controlling the behavior of this particular step (e.g. turning tracing
    on).

    The optional `run_metadata` argument expects a [`RunMetadata`] proto. When
    appropriate, the non-Tensor output of this step will be collected there. For
    example, when users turn on tracing in `options`, the profiled info will be
    collected into this argument and passed back.

    Args:
      fetches: A single graph element, a list of graph elements, or a dictionary
        whose values are graph elements or lists of graph elements (described
        above).
      feed_dict: A dictionary that maps graph elements to values (described
        above).
      options: A [`RunOptions`] protocol buffer
      run_metadata: A [`RunMetadata`] protocol buffer

    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list, or a dictionary with the
      same keys as `fetches` if that is a dictionary (described above).
      Order in which `fetches` operations are evaluated inside the call
      is undefined.

    Raises:
      RuntimeError: If this `Session` is in an invalid state (e.g. has been
        closed).
      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
      ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
        `Tensor` that doesn't exist.
    """
    # Serialize `options` into a C-owned TF_Buffer. Both pointers below are
    # allocated *before* the `try` so the `finally` clause can always see them.
    options_ptr = tf_session.TF_NewBufferFromString(
        compat.as_bytes(options.SerializeToString())) if options else None
    run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None

    try:
      result = self._run(None, fetches, feed_dict, options_ptr,
                         run_metadata_ptr)
      if run_metadata:
        # Copy collected metadata out of the C buffer into the caller's proto
        # before the buffer is freed below.
        proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
        run_metadata.ParseFromString(compat.as_bytes(proto_data))
    finally:
      # Free the C buffers whether or not the step succeeded.
      if run_metadata_ptr:
        tf_session.TF_DeleteBuffer(run_metadata_ptr)
      if options:
        tf_session.TF_DeleteBuffer(options_ptr)
    return result
978
  def partial_run(self, handle, fetches, feed_dict=None):
    """Continues the execution with more feeds and fetches.

    This is EXPERIMENTAL and subject to change.

    To use partial execution, a user first calls `partial_run_setup()` and
    then a sequence of `partial_run()`. `partial_run_setup` specifies the
    list of feeds and fetches that will be used in the subsequent
    `partial_run` calls.

    The optional `feed_dict` argument allows the caller to override
    the value of tensors in the graph. See run() for more information.

    Below is a simple example:

    ```python
    a = array_ops.placeholder(dtypes.float32, shape=[])
    b = array_ops.placeholder(dtypes.float32, shape=[])
    c = array_ops.placeholder(dtypes.float32, shape=[])
    r1 = math_ops.add(a, b)
    r2 = math_ops.multiply(r1, c)

    h = sess.partial_run_setup([r1, r2], [a, b, c])
    res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
    res = sess.partial_run(h, r2, feed_dict={c: res})
    ```

    Args:
      handle: A handle for a sequence of partial runs.
      fetches: A single graph element, a list of graph elements, or a dictionary
        whose values are graph elements or lists of graph elements (see
        documentation for `run`).
      feed_dict: A dictionary that maps graph elements to values (described
        above).

    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list, or a dictionary with the
      same keys as `fetches` if that is a dictionary
      (see documentation for `run`).

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
    """
    # TODO(touts): Support feeding and fetching the same tensor.
    # A non-None `handle` routes `_run` to the partial-run code path.
    return self._run(handle, fetches, feed_dict, None, None)
1025
1026  def partial_run_setup(self, fetches, feeds=None):
1027    """Sets up a graph with feeds and fetches for partial run.
1028
1029    This is EXPERIMENTAL and subject to change.
1030
1031    Note that contrary to `run`, `feeds` only specifies the graph elements.
1032    The tensors will be supplied by the subsequent `partial_run` calls.
1033
1034    Args:
1035      fetches: A single graph element, or a list of graph elements.
1036      feeds: A single graph element, or a list of graph elements.
1037
1038    Returns:
1039      A handle for partial run.
1040
1041    Raises:
1042      RuntimeError: If this `Session` is in an invalid state (e.g. has been
1043        closed).
1044      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
1045      tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.
1046    """
1047
1048    def _feed_fn(feed):
1049      for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
1050        if isinstance(feed, tensor_type):
1051          return feed_fn(feed)
1052      raise TypeError('Feed argument %r has invalid type %r' %
1053                      (feed, type(feed)))
1054
1055    # Check session.
1056    if self._closed:
1057      raise RuntimeError('Attempted to use a closed Session.')
1058    if self.graph.version == 0:
1059      raise RuntimeError('The Session graph is empty.  Add operations to the '
1060                         'graph before calling run().')
1061
1062    if feeds is None:
1063      feeds = []
1064    # Create request.
1065    feed_list = []
1066
1067    # Validate and process feed_list.
1068    is_list_feed = isinstance(feeds, (list, tuple))
1069    if not is_list_feed:
1070      feeds = [feeds]
1071    for feed in feeds:
1072      for subfeed in _feed_fn(feed):
1073        try:
1074          subfeed_t = self.graph.as_graph_element(
1075              subfeed, allow_tensor=True, allow_operation=False)
1076          # pylint: disable=protected-access
1077          feed_list.append(subfeed_t._as_tf_output())
1078          # pylint: enable=protected-access
1079        except Exception as e:
1080          e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message)
1081          e.args = (e.message,)
1082          raise e
1083
1084    # Validate and process fetches.
1085    # TODO(touts): Support feeding and fetching the same tensor.
1086    fetch_handler = _FetchHandler(self._graph, fetches, {})
1087
1088    # Set up a graph with feeds and fetches for partial run.
1089    def _setup_fn(session, feed_list, fetch_list, target_list):
1090      self._extend_graph()
1091      return tf_session.TF_SessionPRunSetup_wrapper(session, feed_list,
1092                                                    fetch_list, target_list)
1093
1094    # pylint: disable=protected-access
1095    final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
1096    final_targets = [op._c_op for op in fetch_handler.targets()]
1097    # pylint: enable=protected-access
1098
1099    return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
1100                         final_targets)
1101
  def _run(self, handle, fetches, feed_dict, options, run_metadata):
    """Perform either run or partial_run, depending the presence of `handle`.

    Args:
      handle: A partial-run handle from `partial_run_setup`, or None for a
        regular run.
      fetches: Nested structure of graph elements to fetch.
      feed_dict: Mapping from graph elements to feed values, or None.
      options: (Pointer to a) serialized `RunOptions` buffer, or None.
      run_metadata: (Pointer to a) `RunMetadata` buffer, or None.

    Returns:
      Fetched values arranged in the same structure as `fetches`.

    Raises:
      RuntimeError: If the session is closed or the graph is empty.
      TypeError: If a feed key/value has an unacceptable type.
      ValueError: If a feed value has an incompatible shape or the tensor is
        not feedable.
    """

    def _feed_fn(feed, feed_val):
      # Expand a (feed, value) pair into (sub-feed, sub-value) pairs, using
      # the first registered expansion whose type matches.
      for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
        if isinstance(feed, tensor_type):
          return feed_fn(feed, feed_val)
      raise TypeError('Feed argument %r has invalid type %r' %
                      (feed, type(feed)))

    # Check session.
    if self._closed:
      raise RuntimeError('Attempted to use a closed Session.')
    if self.graph.version == 0:
      raise RuntimeError('The Session graph is empty.  Add operations to the '
                         'graph before calling run().')

    # Create request.
    feed_dict_tensor = {}
    feed_map = {}

    # Validate and process feed_dict.
    feed_handles = {}
    if feed_dict:
      feed_dict = nest.flatten_dict_items(feed_dict)
      for feed, feed_val in feed_dict.items():
        for subfeed, subfeed_val in _feed_fn(feed, feed_val):
          try:
            subfeed_t = self.graph.as_graph_element(
                subfeed, allow_tensor=True, allow_operation=False)
          except Exception as e:
            raise TypeError('Cannot interpret feed_dict key as Tensor: ' +
                            e.args[0])

          if isinstance(subfeed_val, ops.Tensor):
            raise TypeError('The value of a feed cannot be a tf.Tensor object. '
                            'Acceptable feed values include Python scalars, '
                            'strings, lists, numpy ndarrays, or TensorHandles. '
                            'For reference, the tensor object was ' +
                            str(feed_val) + ' which was passed to the '
                            'feed with key ' + str(feed) + '.')

          subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
          # Reject integer feeds that would silently wrap/truncate when
          # converted to the tensor's numpy dtype.
          if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
              subfeed_dtype, subfeed_val) != subfeed_val:
            raise TypeError(
                'Type of feed value ' + str(subfeed_val) + ' with type ' +
                str(type(subfeed_val)) +
                ' is not compatible with Tensor type ' + str(subfeed_dtype) +
                '. Try explicitly setting the type of the feed tensor'
                ' to a larger type (e.g. int64).')

          is_tensor_handle_feed = isinstance(subfeed_val,
                                             session_ops.TensorHandle)
          if is_tensor_handle_feed:
            np_val = subfeed_val.to_numpy_array()
            feed_handles[subfeed_t.ref()] = subfeed_val
          else:
            np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

          # Shape checking is skipped for handle feeds, which carry the real
          # tensor's shape on the runtime side rather than in `np_val`.
          if (not is_tensor_handle_feed and
              not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
            raise ValueError(
                'Cannot feed value of shape %r for Tensor %r, '
                'which has shape %r' %
                (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
          if not self.graph.is_feedable(subfeed_t):
            raise ValueError('Tensor %s may not be fed.' % subfeed_t)

          feed_dict_tensor[subfeed_t.ref()] = np_val
          feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(
        self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)

    # Run request and get response.
    # We need to keep the returned movers alive for the following _do_run().
    # These movers are no longer needed when _do_run() completes, and
    # are deleted when `movers` goes out of scope when this _run() ends.
    # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
    # of a handle from a different device as an error.
    _ = self._update_with_movers(feed_dict_tensor, feed_map)
    final_fetches = fetch_handler.fetches()
    final_targets = fetch_handler.targets()
    # We only want to really perform the run if fetches or targets are provided,
    # or if the call is a partial run that specifies feeds.
    if final_fetches or final_targets or (handle and feed_dict_tensor):
      results = self._do_run(handle, final_targets, final_fetches,
                             feed_dict_tensor, options, run_metadata)
    else:
      results = []
    return fetch_handler.build_results(self, results)
1195
  def make_callable(self, fetches, feed_list=None, accept_options=False):
    """Returns a Python callable that runs a particular step.

    The returned callable will take `len(feed_list)` arguments whose types
    must be compatible feed values for the respective elements of `feed_list`.
    For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
    argument to the returned callable must be a numpy ndarray (or something
    convertible to an ndarray) with matching element type and shape. See
    `tf.Session.run` for details of the allowable feed key and value types.

    The returned callable will have the same return type as
    `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
    the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
    it will return `None`.

    Args:
      fetches: A value or list of values to fetch. See `tf.Session.run` for
        details of the allowable fetch types.
      feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run`
        for details of the allowable feed key types.
      accept_options: (Optional.) If `True`, the returned `Callable` will be
        able to accept `tf.compat.v1.RunOptions` and `tf.compat.v1.RunMetadata`
        as optional keyword arguments `options` and `run_metadata`,
        respectively, with the same syntax and semantics as `tf.Session.run`,
        which is useful for certain use cases (profiling and debugging) but will
        result in measurable slowdown of the `Callable`'s
        performance. Default: `False`.

    Returns:
      A function that when called will execute the step defined by
      `feed_list` and `fetches` in this session.

    Raises:
      TypeError: If `fetches` or `feed_list` cannot be interpreted
        as arguments to `tf.Session.run`.
    """
    if feed_list is not None:
      if not isinstance(feed_list, (list, tuple)):
        raise TypeError('`feed_list` must be a list or tuple.')
      # Delegate any non-empty feed lists to the existing `run()` logic.
      # TODO(mrry): Refactor the feed handling logic from
      # `Session._run()` so that we can convert the feeds to a list of
      # strings here.
      def _generic_run(*feed_args, **kwargs):
        # Pair positional arguments with the corresponding feed keys.
        feed_dict = {
            feed: feed_val for feed, feed_val in zip(feed_list, feed_args)
        }
        return self.run(fetches, feed_dict=feed_dict, **kwargs)

      return _generic_run

    # Ensure any changes to the graph are reflected in the runtime.
    # Note that we don't need to do this on subsequent calls to the
    # returned object, because the arguments to `fetches` must already be
    # in the graph.
    self._extend_graph()

    # Create a fetch handler to take care of the structure of fetches.
    fetch_handler = _FetchHandler(self._graph, fetches, {})
    # pylint: disable=protected-access
    fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
    target_list = [op._c_op for op in fetch_handler.targets()]

    # pylint: enable=protected-access

    def _callable_template_with_options_and_metadata(fetch_list,
                                                     target_list,
                                                     fetch_handler,
                                                     options=None,
                                                     run_metadata=None):
      """Template callable that accepts RunOptions and RunMetadata."""
      # Same buffer-lifetime pattern as `run`: allocate before `try`, free in
      # `finally`.
      options_ptr = tf_session.TF_NewBufferFromString(
          compat.as_bytes(options.SerializeToString())) if options else None
      run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
      try:
        results = self._call_tf_sessionrun(options_ptr, {}, fetch_list,
                                           target_list, run_metadata_ptr)
        if fetch_handler:
          results = fetch_handler.build_results(self, results)
        else:
          results = results[0] if results else None
        if run_metadata:
          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
          run_metadata.ParseFromString(compat.as_bytes(proto_data))
      finally:
        if run_metadata_ptr:
          tf_session.TF_DeleteBuffer(run_metadata_ptr)
        if options:
          tf_session.TF_DeleteBuffer(options_ptr)
      return results

    if accept_options:
      return functools.partial(_callable_template_with_options_and_metadata,
                               fetch_list, target_list, fetch_handler)
    elif isinstance(fetches, ops.Operation):
      # Special case for fetching a single operation, because the
      # function will have no return value.
      assert not fetch_list
      assert len(target_list) == 1

      def _single_operation_run():
        self._call_tf_sessionrun(None, {}, [], target_list, None)

      return _single_operation_run
    elif isinstance(fetches, ops.Tensor):
      # Special case for fetching a single tensor, because the
      # function can return the result of `TF_Run()` directly.
      assert len(fetch_list) == 1
      assert not target_list

      def _single_tensor_run():
        results = self._call_tf_sessionrun(None, {}, fetch_list, [], None)
        return results[0]

      return _single_tensor_run
    else:
      # In all other cases, we must use `fetch_handler` to build the
      # results for us.
      def _fetch_handler_run():
        results = self._call_tf_sessionrun(None, {}, fetch_list, target_list,
                                           None)
        return fetch_handler.build_results(self, results)

      return _fetch_handler_run
1320
  # Captures the name of a node in an error status. The regex below matches
  # both the old and the new formats:
  # Old format: [[Node: <node_name> = ...]]
  # New format: [[{{node <node_name>}} = ...]]
  # The node name itself is captured by group 3 (used in `_do_call`).
  _NODEDEF_NAME_RE = re.compile(
      r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*')
1327
1328  def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
1329              run_metadata):
1330    """Runs a step based on the given fetches and feeds.
1331
1332    Args:
1333      handle: a handle for partial_run. None if this is just a call to run().
1334      target_list: A list of operations to be run, but not fetched.
1335      fetch_list: A list of tensors to be fetched.
1336      feed_dict: A dictionary that maps tensors to numpy ndarrays.
1337      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
1338      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None
1339
1340    Returns:
1341      A list of numpy ndarrays, corresponding to the elements of
1342      `fetch_list`.  If the ith element of `fetch_list` contains the
1343      name of an operation, the first Tensor output of that operation
1344      will be returned for that element.
1345
1346    Raises:
1347      tf.errors.OpError: Or one of its subclasses on error.
1348    """
1349    # pylint: disable=protected-access
1350    feeds = dict((t.deref()._as_tf_output(), v) for t, v in feed_dict.items())
1351    fetches = [t._as_tf_output() for t in fetch_list]
1352    targets = [op._c_op for op in target_list]
1353
1354    # pylint: enable=protected-access
1355
1356    def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata):
1357      # Ensure any changes to the graph are reflected in the runtime.
1358      self._extend_graph()
1359      return self._call_tf_sessionrun(options, feed_dict, fetch_list,
1360                                      target_list, run_metadata)
1361
1362    def _prun_fn(handle, feed_dict, fetch_list):
1363      if target_list:
1364        raise RuntimeError('partial_run() requires empty target_list.')
1365      return self._call_tf_sessionprun(handle, feed_dict, fetch_list)
1366
1367    if handle is None:
1368      return self._do_call(_run_fn, feeds, fetches, targets, options,
1369                           run_metadata)
1370    else:
1371      return self._do_call(_prun_fn, handle, feeds, fetches)
1372
  def _do_call(self, fn, *args):
    """Calls `fn(*args)`, re-raising any `OpError` with an enriched message.

    On `tf.errors.OpError`, the failing node's name is extracted from the
    message via `_NODEDEF_NAME_RE`, its `Operation`/`NodeDef` are looked up in
    the graph when possible, and an error of the same type is raised with an
    interpolated message.
    """
    try:
      return fn(*args)
    except errors.OpError as e:
      message = compat.as_text(e.message)
      m = BaseSession._NODEDEF_NAME_RE.search(message)
      node_def = None
      op = None
      if m is not None:
        # Group 3 of `_NODEDEF_NAME_RE` carries the node name.
        node_name = m.group(3)
        try:
          op = self._graph.get_operation_by_name(node_name)
          node_def = op.node_def
        except KeyError:
          # Name unknown to the Python graph: continue without op context.
          pass
      message = error_interpolation.interpolate(message, self._graph)
      if 'only supports NHWC tensor format' in message:
        message += ('\nA possible workaround: Try disabling Grappler optimizer'
                    '\nby modifying the config for creating the session eg.'
                    '\nsession_config.graph_options.rewrite_options.'
                    'disable_meta_optimizer = True')
      raise type(e)(node_def, op, message)
1395
  def _extend_graph(self):
    """Ensures any changes to the graph are reflected in the C++ runtime."""
    # The session-run lock serializes graph extension with concurrent runs.
    with self._graph._session_run_lock():  # pylint: disable=protected-access
      tf_session.ExtendSession(self._session)
1399
  # The threshold to run garbage collection to delete dead tensors.
  # `_register_dead_handle` batches deletions: once this many dead handles
  # accumulate, they are all deleted in a single `run` call.
  _DEAD_HANDLES_THRESHOLD = 10
1402
1403  def _register_dead_handle(self, handle):
1404    # Register a dead handle in the session. Delete the dead tensors when
1405    # the number of dead tensors exceeds certain threshold.
1406    tensors_to_delete = None
1407    with self._delete_lock:
1408      self._dead_handles.append(handle)
1409      if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD:
1410        tensors_to_delete = self._dead_handles
1411        self._dead_handles = []
1412    # Delete the dead tensors.
1413    if tensors_to_delete:
1414      feeds = {}
1415      fetches = []
1416      for deleter_key, tensor_handle in enumerate(tensors_to_delete):
1417        holder, deleter = session_ops._get_handle_deleter(
1418            self.graph, deleter_key, tensor_handle)
1419        feeds[holder] = tensor_handle
1420        fetches.append(deleter)
1421      self.run(fetches, feed_dict=feeds)
1422
1423  def _update_with_movers(self, feed_dict, feed_map):
1424    # If a tensor handle that is fed to a device incompatible placeholder,
1425    # we move the tensor to the right device, generate a new tensor handle,
1426    # and update `feed_dict` to use the new handle.
1427    handle_movers = []
1428    for feed_name, val in feed_map.items():
1429      mover = session_ops._get_handle_mover(self.graph, *val)
1430      if mover:
1431        handle_movers.append((feed_name, val[1], mover))
1432    # Transfer a tensor to the right device if needed.
1433    if not handle_movers:
1434      return []
1435    else:
1436      feeds = {}
1437      fetches = []
1438      for _, handle, mover in handle_movers:
1439        feeds[mover[0]] = handle
1440        fetches.append(mover[1])
1441      handles = self.run(fetches, feed_dict=feeds)
1442      for handle_mover, handle in zip(handle_movers, handles):
1443        np_val = np.array(handle.handle, dtype=np.object)
1444        feed_name = handle_mover[0]
1445        feed_tensor = feed_map[feed_name][0]
1446        feed_dict[feed_tensor.ref()] = np_val
1447      return handles
1448
  def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
                          run_metadata):
    """Invokes `TF_SessionRun_wrapper` on the underlying C session."""
    return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
                                            fetch_list, target_list,
                                            run_metadata)
1454
  def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
    """Invokes `TF_SessionPRun_wrapper` (partial run) on the C session."""
    return tf_session.TF_SessionPRun_wrapper(self._session, handle, feed_dict,
                                             fetch_list)
1458
1459  # pylint: disable=protected-access
1460  class _Callable(object):
1461    """Experimental wrapper for the C++ `Session::MakeCallable()` API."""
1462
1463    def __init__(self, session, callable_options):
1464      self._session = session
1465      self._handle = None
1466      options_ptr = tf_session.TF_NewBufferFromString(
1467          compat.as_bytes(callable_options.SerializeToString()))
1468      try:
1469        self._handle = tf_session.TF_SessionMakeCallable(
1470            session._session, options_ptr)
1471      finally:
1472        tf_session.TF_DeleteBuffer(options_ptr)
1473
1474    def __call__(self, *args, **kwargs):
1475      # TODO(b/74355905): Support argument and return value nested structures,
1476      # and tensor-like objects such as SparseTensors.
1477      run_metadata = kwargs.get('run_metadata', None)
1478      try:
1479        run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
1480        ret = tf_session.TF_SessionRunCallable(self._session._session,
1481                                               self._handle, args,
1482                                               run_metadata_ptr)
1483        if run_metadata:
1484          proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
1485          run_metadata.ParseFromString(compat.as_bytes(proto_data))
1486      finally:
1487        if run_metadata_ptr:
1488          tf_session.TF_DeleteBuffer(run_metadata_ptr)
1489      return ret
1490
1491    def __del__(self):
1492      # NOTE(mrry): It is possible that `self._session.__del__()` could be
1493      # called before this destructor, in which case `self._session._session`
1494      # will be `None`.
1495      if (self._handle is not None and self._session._session is not None and
1496          not self._session._closed):
1497        tf_session.TF_SessionReleaseCallable(self._session._session,
1498                                             self._handle)
1499
1500  # pylint: enable=protected-access
1501
1502  # TODO(b/74355905): Reimplement `Session.make_callable()` using this method
1503  # where possible.
  def _make_callable_from_options(self, callable_options):
    """Returns a handle to a "callable" with the given options.

    Args:
      callable_options: A `CallableOptions` protocol buffer message describing
        the computation that will be performed by the callable.

    Returns:
      A handle to the new callable.
    """
    # Push any pending graph changes to the runtime before creating the
    # callable, since the callable captures the graph as of creation time.
    self._extend_graph()
    return BaseSession._Callable(self, callable_options)
1516
1517
@tf_export(v1=['Session'])
class Session(BaseSession):
  """A class for running TensorFlow operations.

  A `Session` object encapsulates the environment in which `Operation`
  objects are executed, and `Tensor` objects are evaluated. For
  example:

  ```python
  tf.compat.v1.disable_eager_execution() # need to disable eager in TF2.x
  # Build a graph.
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b

  # Launch the graph in a session.
  sess = tf.compat.v1.Session()

  # Evaluate the tensor `c`.
  print(sess.run(c)) # prints 30.0
  ```

  A session may own resources, such as
  `tf.Variable`, `tf.queue.QueueBase`,
  and `tf.compat.v1.ReaderBase`. It is important to release
  these resources when they are no longer required. To do this, either
  invoke the `tf.Session.close` method on the session, or use
  the session as a context manager. The following two examples are
  equivalent:

  ```python
  # Using the `close()` method.
  sess = tf.compat.v1.Session()
  sess.run(...)
  sess.close()

  # Using the context manager.
  with tf.compat.v1.Session() as sess:
    sess.run(...)
  ```

  The
  [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
  protocol buffer exposes various configuration options for a
  session. For example, to create a session that uses soft constraints
  for device placement, and log the resulting placement decisions,
  create a session as follows:

  ```python
  # Launch the graph in a session that allows soft device placement and
  # logs the placement decisions.
  sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=True))
  ```
  """

  def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to. Defaults to using
        an in-process engine. See
        [Distributed TensorFlow](https://tensorflow.org/deploy/distributed) for
          more examples.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A
        [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
          protocol buffer with configuration options for the session.
    """
    super(Session, self).__init__(target, graph, config=config)
    # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle.
    self._default_graph_context_manager = None
    self._default_session_context_manager = None

  def __enter__(self):
    if self._default_graph_context_manager is None:
      self._default_graph_context_manager = self.graph.as_default()
    else:
      raise RuntimeError('Session context managers are not re-entrant. '
                         'Use `Session.as_default()` if you want to enter '
                         'a session multiple times.')
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    self._default_graph_context_manager.__enter__()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    # Log when the session is being torn down because a TensorFlow `OpError`
    # propagated out of the `with` block. The previous identity check
    # (`exec_type is errors.OpError`) only matched the exact `OpError` class
    # and never its subclasses (e.g. `NotFoundError`), which is what
    # TensorFlow actually raises, so the log rarely fired. `exec_type` is
    # `None` on a clean exit, so guard before the subclass check.
    if exec_type is not None and issubclass(exec_type, errors.OpError):
      logging.error('Session closing due to OpError: %s', (exec_value,))
    try:
      self._default_session_context_manager.__exit__(exec_type, exec_value,
                                                     exec_tb)
    except RuntimeError as error:
      if error == exec_value:
        # NOTE(skyewm): for some reason, in Python3,
        # _default_session_context_manager.__exit__ will re-raise the "not
        # re-entrant" exception raised in __enter__ above (note that if we're
        # here, we're in the outer session context manager, since __exit__ is
        # not called when __enter__ raises an exception). We still want to
        # continue cleaning up this context manager before the exception is
        # further propagated, so we ignore it here (note that it'll continue
        # being propagated after this method completes).
        pass
      else:
        raise
    self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)

    # Reset the context managers so a subsequent `with session:` can re-enter.
    self._default_session_context_manager = None
    self._default_graph_context_manager = None

    # If we are closing due to an exception, set a time limit on our Close() to
    # avoid blocking forever.
    # TODO(b/120204635) remove this when deadlock is fixed.
    if exec_type:
      close_thread = threading.Thread(
          name='SessionCloseThread', target=self.close)
      close_thread.daemon = True
      close_thread.start()
      close_thread.join(30.0)
      if close_thread.is_alive():
        logging.error(
            'Session failed to close after 30 seconds. Continuing after this '
            'point may leave your program in an undefined state.')
    else:
      self.close()

  @staticmethod
  def reset(target, containers=None, config=None):
    """Resets resource containers on `target`, and close all connected sessions.

    A resource container is distributed across all workers in the
    same cluster as `target`.  When a resource container on `target`
    is reset, resources associated with that container will be cleared.
    In particular, all Variables in the container will become undefined:
    they lose their values and shapes.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` if all of
        all the containers are to be reset.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    # The C API expects bytes for the target and each container name.
    if target is not None:
      target = compat.as_bytes(target)
    if containers is not None:
      containers = [compat.as_bytes(c) for c in containers]
    else:
      containers = []
    tf_session.TF_Reset(target, containers, config)
1686
1687
@tf_export(v1=['InteractiveSession'])
class InteractiveSession(BaseSession):
  """A TensorFlow `Session` for use in interactive contexts, such as a shell.

  The only difference with a regular `Session` is that an `InteractiveSession`
  installs itself as the default session on construction.
  The methods `tf.Tensor.eval`
  and `tf.Operation.run`
  will use that session to run ops.

  This is convenient in interactive shells and [IPython
  notebooks](http://ipython.org), as it avoids having to pass an explicit
  `Session` object to run ops.

  For example:

  ```python
  sess = tf.compat.v1.InteractiveSession()
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  # We can just use 'c.eval()' without passing 'sess'
  print(c.eval())
  sess.close()
  ```

  Note that a regular session installs itself as the default session when it
  is created in a `with` statement.  The common usage in non-interactive
  programs is to follow that pattern:

  ```python
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  with tf.compat.v1.Session():
    # We can also use 'c.eval()' here.
    print(c.eval())
  ```
  """

  # Process-wide bookkeeping used to warn when more than one interactive
  # session is alive at the same time (a common cause of GPU out-of-memory
  # errors in notebooks).
  _count_lock = threading.Lock()
  _active_session_count = 0  # GUARDED_BY(_count_lock)

  def __init__(self, target='', graph=None, config=None):
    """Creates a new interactive TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to. Defaults to using
        an in-process engine.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional) `ConfigProto` proto used to configure the session.
    """
    if not config:
      # If config is not provided, choose some reasonable defaults for
      # interactive use:
      #
      #   - Grow GPU memory as needed at the cost of fragmentation.
      gpu_options = config_pb2.GPUOptions(allow_growth=True)
      config = config_pb2.ConfigProto(gpu_options=gpu_options)
    # Interactive sessions always place pruned graphs.
    # NOTE(review): when the caller supplies a `config`, this line mutates the
    # caller's proto in place rather than a copy — confirm that side effect
    # is intended.
    config.graph_options.place_pruned_graph = True

    super(InteractiveSession, self).__init__(target, graph, config)
    # Register this session in the process-wide count, warning (under the
    # lock) if another interactive session is already alive.
    with InteractiveSession._count_lock:
      if InteractiveSession._active_session_count > 0:
        warnings.warn('An interactive session is already active. This can '
                      'cause out-of-memory errors in some cases. You must '
                      'explicitly call `InteractiveSession.close()` to release '
                      'resources held by the other session(s).')
      InteractiveSession._active_session_count += 1
    # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful
    # semantics (in particular, it is not set to true if `Session.close()` is
    # called on a session that has not been "opened" by running a step) and we
    # cannot change those semantics without breaking existing code.
    self._explicitly_closed = False

    # Install this session as the default without a `with` block.
    # `enforce_nesting = False` because the matching `__exit__` happens later,
    # in `close()`, and may not follow strict context-manager nesting order.
    self._default_session = self.as_default()
    self._default_session.enforce_nesting = False
    self._default_session.__enter__()
    # Only install `graph` as the default graph when the caller passed one
    # explicitly; otherwise the ambient default graph is already current.
    self._explicit_graph = graph
    if self._explicit_graph is not None:
      self._default_graph = graph.as_default()
      self._default_graph.enforce_nesting = False
      self._default_graph.__enter__()

  def close(self):
    """Closes an `InteractiveSession`."""
    super(InteractiveSession, self).close()
    # Deregister from the process-wide count exactly once: on a repeated
    # close we return early so the context managers below (already exited on
    # the first close) are not exited twice.
    with InteractiveSession._count_lock:
      if not self._explicitly_closed:
        InteractiveSession._active_session_count -= 1
        self._explicitly_closed = True
      else:
        return
    # Undo the `__enter__` calls made in the constructor, graph first.
    if self._explicit_graph is not None:
      self._default_graph.__exit__(None, None, None)
      self._default_graph = None
    self._default_session.__exit__(None, None, None)
    self._default_session = None