1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A client interface for TensorFlow.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import functools 23import re 24import threading 25import warnings 26 27import numpy as np 28import wrapt 29 30from tensorflow.core.protobuf import config_pb2 31from tensorflow.core.protobuf import rewriter_config_pb2 32from tensorflow.python.client import pywrap_tf_session as tf_session 33from tensorflow.python.eager import context 34from tensorflow.python.eager import monitoring 35from tensorflow.python.framework import device 36from tensorflow.python.framework import error_interpolation 37from tensorflow.python.framework import errors 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import sparse_tensor 40from tensorflow.python.ops import session_ops 41from tensorflow.python.platform import tf_logging as logging 42from tensorflow.python.training.experimental import mixed_precision_global_state 43from tensorflow.python.util import compat 44from tensorflow.python.util import nest 45from tensorflow.python.util.compat import collections_abc 46from tensorflow.python.util.tf_export import tf_export 47 48_python_session_create_counter = monitoring.Counter( 49 '/tensorflow/api/python/session_create_counter', 50 'Counter for number of sessions created in Python.') 51 52 53class SessionInterface(object): 54 """Base class for implementations of TensorFlow client sessions.""" 55 56 @property 57 def graph(self): 58 """The underlying TensorFlow graph, to be used in building Operations.""" 59 raise NotImplementedError('graph') 60 61 @property 62 def sess_str(self): 63 """The TensorFlow process to which this session will connect.""" 64 raise NotImplementedError('sess_str') 65 66 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 67 """Runs operations in the session. See `BaseSession.run()` for details.""" 68 raise NotImplementedError('run') 69 70 def partial_run_setup(self, fetches, feeds=None): 71 """Sets up the feeds and fetches for partial runs in the session.""" 72 raise NotImplementedError('partial_run_setup') 73 74 def partial_run(self, handle, fetches, feed_dict=None): 75 """Continues the execution with additional feeds and fetches.""" 76 raise NotImplementedError('partial_run') 77 78 79def _get_indexed_slices_value_from_fetches(fetched_vals): 80 return ops.IndexedSlicesValue( 81 fetched_vals[0], fetched_vals[1], 82 fetched_vals[2] if len(fetched_vals) == 3 else None) 83 84 85def _get_feeds_for_indexed_slices(feed, feed_val): 86 return list( 87 zip([feed.values, feed.indices] if feed.dense_shape is None else 88 [feed.values, feed.indices, feed.dense_shape], feed_val)) 89 90 91# List of extensions supported to convert run arguments into actual fetches and 92# feeds. 93# 94# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2), 95# where the function signatures are: 96# fetch_fn : Type -> (list of Tensors, 97# lambda: list of fetched np.ndarray -> TypeVal) 98# feed_fn1 : Type, TypeVal -> list of (Tensor, value) 99# feed_fn2 : Type -> list of Tensors 100# 101# `fetch_fn` describes how to expand fetch into its 102# component Tensors and how to contract the fetched results back into 103# a single return value. 104# 105# Each feed function describes how to unpack a single fed value and map it to 106# feeds of one or more tensors and their corresponding values: `feed_fn1` is 107# used to feed a run, `feed_fn2` to set up a partial run. 108# 109# TODO(touts): We could reimplement these as specialized _FeedMapper 110# implementations after we refactor the feed handling code to use them. 111# 112# Eventually, this registration could be opened up to support custom Tensor 113# expansions. 114# pylint: disable=g-long-lambda 115_REGISTERED_EXPANSIONS = [ 116 # SparseTensors are fetched as SparseTensorValues. They can be fed 117 # SparseTensorValues or normal tuples. 118 (sparse_tensor.SparseTensor, lambda fetch: ([ 119 fetch.indices, fetch.values, fetch.dense_shape 120 ], lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)), 121 lambda feed, feed_val: list( 122 zip([feed.indices, feed.values, feed.dense_shape], feed_val)), 123 lambda feed: [feed.indices, feed.values, feed.dense_shape]), 124 # IndexedSlices are fetched as IndexedSlicesValues. They can be fed 125 # IndexedSlicesValues or normal tuples. 126 (ops.IndexedSlices, 127 lambda fetch: ([fetch.values, fetch.indices] if fetch.dense_shape is None 128 else [fetch.values, fetch.indices, fetch.dense_shape 129 ], _get_indexed_slices_value_from_fetches), 130 _get_feeds_for_indexed_slices, 131 lambda feed: [feed.values, feed.indices] if feed.dense_shape is None else 132 [feed.values, feed.indices, feed.dense_shape]), 133 # The default catches all other types and performs no expansions. 134 (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), 135 lambda feed, feed_val: [(feed, feed_val)], lambda feed: [feed]) 136] 137 138# pylint: enable=g-long-lambda 139 140 141def _convert_to_numpy_obj(numpy_dtype, obj): 142 """Explicitly convert obj based on numpy type except for string type.""" 143 return numpy_dtype(obj) if numpy_dtype is not object else str(obj) 144 145 146def register_session_run_conversion_functions( 147 tensor_type, 148 fetch_function, 149 feed_function=None, 150 feed_function_for_partial_run=None): 151 """Register fetch and feed conversion functions for `tf.Session.run()`. 152 153 This function registers a triple of conversion functions for fetching and/or 154 feeding values of user-defined types in a call to tf.Session.run(). 155 156 An example 157 158 ```python 159 class SquaredTensor(object): 160 def __init__(self, tensor): 161 self.sq = tf.square(tensor) 162 #you can define conversion functions as follows: 163 fetch_function = lambda squared_tensor:([squared_tensor.sq], 164 lambda val: val[0]) 165 feed_function = lambda feed, feed_val: [(feed.sq, feed_val)] 166 feed_function_for_partial_run = lambda feed: [feed.sq] 167 #then after invoking this register function, you can use as follows: 168 session.run(squared_tensor1, 169 feed_dict = {squared_tensor2 : some_numpy_array}) 170 ``` 171 172 Args: 173 tensor_type: The type for which you want to register a conversion function. 174 fetch_function: A callable that takes an object of type `tensor_type` and 175 returns a tuple, where the first element is a list of `tf.Tensor` objects, 176 and the second element is a callable that takes a list of ndarrays and 177 returns an object of some value type that corresponds to `tensor_type`. 178 fetch_function describes how to expand fetch into its component Tensors 179 and how to contract the fetched results back into a single return value. 180 feed_function: A callable that takes feed_key and feed_value as input, and 181 returns a list of tuples (feed_tensor, feed_val), feed_key must have type 182 `tensor_type`, and feed_tensor must have type `tf.Tensor`. Each feed 183 function describes how to unpack a single fed value and map it to feeds of 184 one or more tensors and their corresponding values. 185 feed_function_for_partial_run: A callable for specifying tensor values to 186 feed when setting up a partial run, which takes a `tensor_type` type 187 object as input, and returns a list of Tensors. 188 189 Raises: 190 ValueError: If `tensor_type` has already been registered. 191 """ 192 for conversion_function in _REGISTERED_EXPANSIONS: 193 if issubclass(conversion_function[0], tensor_type): 194 raise ValueError('%s has already been registered so ignore it.' % 195 tensor_type) 196 197 _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function, 198 feed_function_for_partial_run)) 199 200 201def _is_attrs_instance(obj): 202 """Returns True if the given obj is an instance of attrs-decorated class.""" 203 return getattr(obj.__class__, '__attrs_attrs__', None) is not None 204 205 206def _get_attrs_values(obj): 207 """Returns the list of values from an attrs instance.""" 208 attrs = getattr(obj.__class__, '__attrs_attrs__') 209 return [getattr(obj, a.name) for a in attrs] 210 211 212class _FetchMapper(object): 213 """Definition of the interface provided by fetch mappers. 214 215 Fetch mappers are utility classes used by the _FetchHandler to handle 216 arbitrary structures for the `fetch` argument to `Session.run()`. 217 218 The `fetch` argument can be of various shapes: single tensor or op, list of 219 fetches, tuple of fetches, namedtuple of fetches, or dict of fetches. The 220 structures can be arbitrarily nested. 221 222 The low level run() API only wants a list of tensor or op names. The various 223 `_FetchMapper` subclasses below take care of handling the different shapes: 224 uniquifying the fetches, and constructing results with the original shape. 225 """ 226 227 def unique_fetches(self): 228 """Return the list of unique tensors or ops needed by this fetch mapper. 229 230 Returns: 231 A list of tensors or ops. 232 """ 233 raise NotImplementedError('Must be implemented by subclasses') 234 235 def build_results(self, values): 236 """Build results that match the original shape of the fetch. 237 238 Args: 239 values: List of values returned by run(). The values correspond exactly to 240 the list tensors or ops returned by unique_fetches(). 241 242 Returns: 243 A struct of the same shape as the original fetch object handled by 244 this fetch mapper. In the returned struct, the original fetches are 245 replaced by their fetched values. 246 """ 247 raise NotImplementedError('Must be implemented by subclasses') 248 249 @staticmethod 250 def for_fetch(fetch): 251 """Creates fetch mapper that handles the structure of `fetch`. 252 253 The default graph must be the one from which we want to fetch values when 254 this function is called. 255 256 Args: 257 fetch: An arbitrary fetch structure: singleton, list, tuple, namedtuple, 258 or dict. 259 260 Returns: 261 An instance of a subclass of `_FetchMapper` that handles the shape. 262 """ 263 if fetch is None: 264 raise TypeError('Fetch argument %r has invalid type %r' % 265 (fetch, type(fetch))) 266 elif isinstance(fetch, (list, tuple)): 267 # NOTE(touts): This is also the code path for namedtuples. 268 return _ListFetchMapper(fetch) 269 elif isinstance(fetch, collections_abc.Mapping): 270 return _DictFetchMapper(fetch) 271 elif _is_attrs_instance(fetch): 272 return _AttrsFetchMapper(fetch) 273 else: 274 # Look for a handler in the registered expansions. 275 for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS: 276 if isinstance(fetch, tensor_type): 277 fetches, contraction_fn = fetch_fn(fetch) 278 return _ElementFetchMapper(fetches, contraction_fn) 279 # Did not find anything. 280 raise TypeError('Fetch argument %r has invalid type %r' % 281 (fetch, type(fetch))) 282 283 284class _ElementFetchMapper(_FetchMapper): 285 """Fetch mapper for singleton tensors and ops.""" 286 287 def __init__(self, fetches, contraction_fn): 288 """Creates an _ElementFetchMapper. 289 290 This is the fetch mapper used for leaves in the fetch struct. Because of 291 the expansions mechanism, a leaf can actually fetch more than one tensor. 292 293 Also note that the fetches here can be just strings (tensor or op names) or 294 any other object that the graph knows how to convert to a tensor, such as a 295 Variable. So we have to run each fetch through `as_graph_element()` to get 296 the corresponding tensor or op. 297 298 Args: 299 fetches: List of objects, as returned by a fetch_fn defined in 300 _REGISTERED_EXPANSIONS. 301 contraction_fn: Callable as returned by a fetch_fn. 302 """ 303 self._unique_fetches = [] 304 for fetch in fetches: 305 try: 306 self._unique_fetches.append(ops.get_default_graph().as_graph_element( 307 fetch, allow_tensor=True, allow_operation=True)) 308 except TypeError as e: 309 raise TypeError('Fetch argument %r has invalid type %r, ' 310 'must be a string or Tensor. (%s)' % 311 (fetch, type(fetch), str(e))) 312 except ValueError as e: 313 raise ValueError('Fetch argument %r cannot be interpreted as a ' 314 'Tensor. (%s)' % (fetch, str(e))) 315 except KeyError as e: 316 raise ValueError('Fetch argument %r cannot be interpreted as a ' 317 'Tensor. (%s)' % (fetch, str(e))) 318 self._contraction_fn = contraction_fn 319 320 def unique_fetches(self): 321 return self._unique_fetches 322 323 def build_results(self, values): 324 if not values: 325 # 'Operation' case 326 return None 327 else: 328 return self._contraction_fn(values) 329 330 331def _uniquify_fetches(fetch_mappers): 332 """Uniquifies fetches from a list of fetch_mappers. 333 334 This is a utility function used by _ListFetchMapper and _DictFetchMapper. It 335 gathers all the unique fetches from a list of mappers and builds a list 336 containing all of them but without duplicates (unique_fetches). 337 338 It also returns a 2-D list of integers (values_indices) indicating at which 339 index in unique_fetches the fetches of the mappers are located. 340 341 This list is as follows: 342 values_indices[mapper_index][mapper_fetch_index] = unique_fetches_index 343 344 Args: 345 fetch_mappers: list of fetch mappers. 346 347 Returns: 348 A list of fetches. 349 A 2-D list of integers. 350 """ 351 unique_fetches = [] 352 value_indices = [] 353 seen_fetches = {} 354 for m in fetch_mappers: 355 m_value_indices = [] 356 for f in m.unique_fetches(): 357 j = seen_fetches.get(id(f)) 358 if j is None: 359 j = len(seen_fetches) 360 seen_fetches[id(f)] = j 361 unique_fetches.append(f) 362 m_value_indices.append(j) 363 value_indices.append(m_value_indices) 364 return unique_fetches, value_indices 365 366 367class _ListFetchMapper(_FetchMapper): 368 """Fetch mapper for lists, tuples, and namedtuples.""" 369 370 def __init__(self, fetches): 371 """Creates a _ListFetchMapper. 372 373 Args: 374 fetches: List, tuple, or namedtuple of fetches. 375 """ 376 if isinstance(fetches, wrapt.ObjectProxy): 377 self._fetch_type = type(fetches.__wrapped__) 378 else: 379 self._fetch_type = type(fetches) 380 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches] 381 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 382 383 def unique_fetches(self): 384 return self._unique_fetches 385 386 def build_results(self, values): 387 # Create the list of results for each mapper. 388 results = [] 389 for m, vi in zip(self._mappers, self._value_indices): 390 results.append(m.build_results([values[j] for j in vi])) 391 # Return a value of the original type of the fetches. 392 if issubclass(self._fetch_type, list): 393 return results 394 elif self._fetch_type == tuple: 395 return tuple(results) 396 else: 397 # This is the code path for namedtuple. 398 return self._fetch_type(*results) 399 400 401class _DictFetchMapper(_FetchMapper): 402 """Fetch mapper for dicts.""" 403 404 def __init__(self, fetches): 405 """Creates a _DictFetchMapper. 406 407 Args: 408 fetches: Dict of fetches. 409 """ 410 self._fetch_type = type(fetches) 411 if isinstance(fetches, collections.defaultdict): 412 self._type_ctor = functools.partial(collections.defaultdict, 413 fetches.default_factory) 414 else: 415 self._type_ctor = self._fetch_type 416 417 self._keys = fetches.keys() 418 self._mappers = [ 419 _FetchMapper.for_fetch(fetch) for fetch in fetches.values() 420 ] 421 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 422 423 def unique_fetches(self): 424 return self._unique_fetches 425 426 def build_results(self, values): 427 428 def _generator(): 429 for k, m, vi in zip(self._keys, self._mappers, self._value_indices): 430 yield k, m.build_results([values[j] for j in vi]) 431 432 return self._type_ctor(_generator()) 433 434 435class _AttrsFetchMapper(_FetchMapper): 436 """Fetch mapper for attrs decorated classes.""" 437 438 def __init__(self, fetches): 439 """Creates a _AttrsFetchMapper. 440 441 Args: 442 fetches: An instance of an attrs decorated class. 443 """ 444 values = _get_attrs_values(fetches) 445 self._fetch_type = type(fetches) 446 self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in values] 447 self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) 448 449 def unique_fetches(self): 450 return self._unique_fetches 451 452 def build_results(self, values): 453 results = [] 454 for m, vi in zip(self._mappers, self._value_indices): 455 results.append(m.build_results([values[j] for j in vi])) 456 return self._fetch_type(*results) 457 458 459class _FetchHandler(object): 460 """Handler for structured fetches. 461 462 Given a graph, a user-provided structure for fetches, and a feed dict, this 463 class takes care of generating a list of tensor names to fetch and op names 464 to run for a low level `run()` call. 465 466 Given the results of the low level run call, this class can also rebuild a 467 result structure matching the user-provided structure for fetches, but 468 containing the corresponding results. 469 """ 470 471 # TODO(touts): Make this class also take care of destructuring the feed 472 # dict instead of doing it in the callers. 473 474 def __init__(self, graph, fetches, feeds, feed_handles=None): 475 """Creates a fetch handler. 476 477 Args: 478 graph: Graph of the fetches. Used to check for fetchability and to 479 convert all fetches to tensors or ops as needed. 480 fetches: An arbitrary fetch structure: singleton, list, tuple, namedtuple, 481 or dict. 482 feeds: A feed dict where keys are Tensors. 483 feed_handles: A dict from feed Tensors to TensorHandle objects used as 484 direct feeds. 485 """ 486 with graph.as_default(): 487 self._fetch_mapper = _FetchMapper.for_fetch(fetches) 488 self._fetches = [] 489 self._targets = [] 490 self._feeds = feeds 491 self._feed_handles = feed_handles or {} 492 self._ops = [] 493 self._fetch_handles = {} 494 for fetch in self._fetch_mapper.unique_fetches(): 495 if isinstance(fetch, ops.Operation): 496 self._assert_fetchable(graph, fetch) 497 self._targets.append(fetch) 498 self._ops.append(True) 499 else: 500 self._assert_fetchable(graph, fetch.op) 501 self._fetches.append(fetch) 502 self._ops.append(False) 503 # Remember the fetch if it is for a tensor handle. 504 if (isinstance(fetch, ops.Tensor) and 505 (fetch.op.type == 'GetSessionHandle' or 506 fetch.op.type == 'GetSessionHandleV2')): 507 self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype 508 self._final_fetches = [x for x in self._fetches if x.ref() not in feeds] 509 510 def _assert_fetchable(self, graph, op): 511 if not graph.is_fetchable(op): 512 raise errors.InaccessibleTensorError( 513 'Operation %r has been marked as not fetchable. Typically this' 514 ' happens when it is defined in another function or code block.' 515 ' Use return values,explicit Python locals or TensorFlow collections' 516 ' to access it.' 517 % op.name) 518 519 def fetches(self): 520 """Return the unique names of tensors to fetch. 521 522 Returns: 523 A list of strings. 524 """ 525 return self._final_fetches 526 527 def targets(self): 528 """Return the unique names of ops to run. 529 530 Returns: 531 A list of strings. 532 """ 533 return self._targets 534 535 def build_results(self, session, tensor_values): 536 """Build results matching the original fetch shape. 537 538 `tensor_values` must be a list of the same length as 539 the one returned by `fetches()`, and holding the requested 540 fetch values. 541 542 This method builds a struct with the same shape as the original `fetches` 543 passed to the constructor, in which the fetches are replaced by their 544 fetched value. 545 546 Args: 547 session: The enclosing session. Used for tensor handles. 548 tensor_values: List of values matching the list returned by fetches(). 549 550 Returns: 551 A structure of the same shape as the original `fetches` argument but 552 containing tensors or None (for fetched ops). 553 """ 554 full_values = [] 555 assert len(self._final_fetches) == len(tensor_values) 556 i = 0 557 j = 0 558 for is_op in self._ops: 559 if is_op: 560 full_values.append(None) 561 else: 562 # If the fetch was in the feeds, use the fed value, otherwise 563 # use the returned value. 564 if self._fetches[i].ref() in self._feed_handles: 565 # A fetch had a corresponding direct TensorHandle feed. Call eval() 566 # to obtain the Tensor value from the TensorHandle. 567 value = self._feed_handles[self._fetches[i].ref()].eval() 568 else: 569 value = self._feeds.get(self._fetches[i].ref()) 570 if value is None: 571 value = tensor_values[j] 572 j += 1 573 dtype = self._fetch_handles.get(self._fetches[i].ref()) 574 if dtype: 575 full_values.append(session_ops.TensorHandle(value, dtype, session)) 576 else: 577 full_values.append(value) 578 i += 1 579 assert j == len(tensor_values) 580 return self._fetch_mapper.build_results(full_values) 581 582 583def _name_list(tensor_list): 584 """Utility function for transitioning to the new session API. 585 586 Args: 587 tensor_list: a list of `Tensor`s. 588 589 Returns: 590 A list of each `Tensor`s name (as byte arrays). 591 """ 592 return [compat.as_bytes(t.name) for t in tensor_list] 593 594 595class _DeviceAttributes(object): 596 """Struct-like object describing a device's attributes. 597 598 Each device has 3 key properties: 599 - name: the fully-qualified TensorFlow path to the device. For 600 example: /job:worker/replica:0/task:3/device:CPU:0 601 - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.) 602 - memory_limit_bytes: the maximum amount of memory available on the device 603 (in bytes). 604 """ 605 606 def __init__(self, name, device_type, memory_limit_bytes, incarnation): 607 self._name = device.canonical_name(name) 608 self._device_type = device_type 609 self._memory_limit_bytes = memory_limit_bytes 610 self._incarnation = incarnation 611 612 @property 613 def name(self): 614 return self._name 615 616 @property 617 def device_type(self): 618 return self._device_type 619 620 @property 621 def memory_limit_bytes(self): 622 return self._memory_limit_bytes 623 624 @property 625 def incarnation(self): 626 return self._incarnation 627 628 def __repr__(self): 629 return '_DeviceAttributes(%s, %s, %d, %d)' % ( 630 self.name, 631 self.device_type, 632 self.memory_limit_bytes, 633 self.incarnation, 634 ) 635 636 637class BaseSession(SessionInterface): 638 """A class for interacting with a TensorFlow computation. 639 640 The BaseSession enables incremental graph building with inline 641 execution of Operations and evaluation of Tensors. 642 """ 643 644 def __init__(self, target='', graph=None, config=None): 645 """Constructs a new TensorFlow session. 646 647 Args: 648 target: (Optional) The TensorFlow execution engine to connect to. 649 graph: (Optional) The graph to be used. If this argument is None, the 650 default graph will be used. 651 config: (Optional) ConfigProto proto used to configure the session. If no 652 config is specified, the global default will be used. The global default 653 can be configured via the tf.config APIs. 654 655 Raises: 656 tf.errors.OpError: Or one of its subclasses if an error occurs while 657 creating the TensorFlow session. 658 TypeError: If one of the arguments has the wrong type. 659 """ 660 _python_session_create_counter.get_cell().increase_by(1) 661 if graph is None: 662 self._graph = ops.get_default_graph() 663 else: 664 if not isinstance(graph, ops.Graph): 665 raise TypeError('graph must be a tf.Graph, but got %s' % type(graph)) 666 self._graph = graph 667 668 self._closed = False 669 670 if target is not None: 671 try: 672 self._target = compat.as_bytes(target) 673 except TypeError: 674 if isinstance(target, config_pb2.ConfigProto): 675 raise TypeError('target must be a string, but got %s.' 676 ' Did you do "Session(config)" instead of' 677 ' "Session(config=config)"?' % type(target)) 678 raise TypeError('target must be a string, but got %s' % type(target)) 679 else: 680 self._target = None 681 682 self._delete_lock = threading.Lock() 683 self._dead_handles = [] 684 685 if config is None: 686 config = context.context().config 687 688 if not isinstance(config, config_pb2.ConfigProto): 689 raise TypeError('config must be a tf.ConfigProto, but got %s' % 690 type(config)) 691 692 if (mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled 693 and config.graph_options.rewrite_options.auto_mixed_precision != 694 rewriter_config_pb2.RewriterConfig.OFF): 695 new_config = config_pb2.ConfigProto() 696 new_config.CopyFrom(config) 697 new_config.graph_options.rewrite_options.auto_mixed_precision = ( 698 rewriter_config_pb2.RewriterConfig.ON) 699 config = new_config 700 elif (config.graph_options.rewrite_options.auto_mixed_precision != 701 rewriter_config_pb2.RewriterConfig.ON): 702 mixed_precision_global_state.non_mixed_precision_session_created = True 703 704 self._config = config 705 self._add_shapes = config.graph_options.infer_shapes 706 707 self._session = None 708 opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) 709 try: 710 # pylint: disable=protected-access 711 self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts) 712 # pylint: enable=protected-access 713 finally: 714 tf_session.TF_DeleteSessionOptions(opts) 715 716 def list_devices(self): 717 """Lists available devices in this session. 718 719 ```python 720 devices = sess.list_devices() 721 for d in devices: 722 print(d.name) 723 ``` 724 725 Where: 726 Each element in the list has the following properties 727 name: A string with the full name of the device. ex: 728 `/job:worker/replica:0/task:3/device:CPU:0` 729 device_type: The type of the device (e.g. `CPU`, `GPU`, `TPU`.) 730 memory_limit: The maximum amount of memory available on the device. 731 Note: depending on the device, it is possible the usable memory could 732 be substantially less. 733 734 Raises: 735 tf.errors.OpError: If it encounters an error (e.g. session is in an 736 invalid state, or network errors occur). 737 738 Returns: 739 A list of devices in the session. 740 """ 741 raw_device_list = tf_session.TF_SessionListDevices(self._session) 742 device_list = [] 743 size = tf_session.TF_DeviceListCount(raw_device_list) 744 for i in range(size): 745 name = tf_session.TF_DeviceListName(raw_device_list, i) 746 device_type = tf_session.TF_DeviceListType(raw_device_list, i) 747 memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i) 748 incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i) 749 device_list.append( 750 _DeviceAttributes(name, device_type, memory, incarnation)) 751 tf_session.TF_DeleteDeviceList(raw_device_list) 752 return device_list 753 754 def close(self): 755 """Closes this session. 756 757 Calling this method frees all resources associated with the session. 758 759 Raises: 760 tf.errors.OpError: Or one of its subclasses if an error occurs while 761 closing the TensorFlow session. 762 """ 763 if self._session and not self._closed: 764 self._closed = True 765 tf_session.TF_CloseSession(self._session) 766 767 def __del__(self): 768 # cleanly ignore all exceptions 769 try: 770 self.close() 771 except Exception: # pylint: disable=broad-except 772 pass 773 if self._session is not None: 774 try: 775 tf_session.TF_DeleteSession(self._session) 776 except (AttributeError, TypeError): 777 # At shutdown, `c_api_util`, `tf_session`, or 778 # `tf_session.TF_DeleteSession` may have been garbage collected, causing 779 # the above method calls to fail. In this case, silently leak since the 780 # program is about to terminate anyway. 781 pass 782 self._session = None 783 784 @property 785 def graph(self): 786 """The graph that was launched in this session.""" 787 return self._graph 788 789 @property 790 def graph_def(self): 791 """A serializable version of the underlying TensorFlow graph. 792 793 Returns: 794 A graph_pb2.GraphDef proto containing nodes for all of the Operations in 795 the underlying TensorFlow graph. 796 """ 797 return self._graph.as_graph_def(add_shapes=self._add_shapes) 798 799 @property 800 def sess_str(self): 801 return self._target 802 803 def as_default(self): 804 """Returns a context manager that makes this object the default session. 805 806 Use with the `with` keyword to specify that calls to 807 `tf.Operation.run` or `tf.Tensor.eval` should be executed in 808 this session. 809 810 ```python 811 c = tf.constant(..) 812 sess = tf.compat.v1.Session() 813 814 with sess.as_default(): 815 assert tf.compat.v1.get_default_session() is sess 816 print(c.eval()) 817 ``` 818 819 To get the current default session, use `tf.compat.v1.get_default_session`. 820 821 *N.B.* The `as_default` context manager *does not* close the 822 session when you exit the context, and you must close the session 823 explicitly. 824 825 ```python 826 c = tf.constant(...) 827 sess = tf.compat.v1.Session() 828 with sess.as_default(): 829 print(c.eval()) 830 # ... 831 with sess.as_default(): 832 print(c.eval()) 833 834 sess.close() 835 ``` 836 837 Alternatively, you can use `with tf.compat.v1.Session():` to create a 838 session that is automatically closed on exiting the context, 839 including when an uncaught exception is raised. 840 841 *N.B.* The default session is a property of the current thread. If you 842 create a new thread, and wish to use the default session in that 843 thread, you must explicitly add a `with sess.as_default():` in that 844 thread's function. 845 846 *N.B.* Entering a `with sess.as_default():` block does not affect 847 the current default graph. If you are using multiple graphs, and 848 `sess.graph` is different from the value of 849 `tf.compat.v1.get_default_graph`, you must explicitly enter a 850 `with sess.graph.as_default():` block to make `sess.graph` the default 851 graph. 852 853 Returns: 854 A context manager using this session as the default session. 855 """ 856 return ops.default_session(self) 857 858 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 859 """Runs operations and evaluates tensors in `fetches`. 860 861 This method runs one "step" of TensorFlow computation, by 862 running the necessary graph fragment to execute every `Operation` 863 and evaluate every `Tensor` in `fetches`, substituting the values in 864 `feed_dict` for the corresponding input values. 865 866 The `fetches` argument may be a single graph element, or an arbitrarily 867 nested list, tuple, namedtuple, dict, or OrderedDict containing graph 868 elements at its leaves. A graph element can be one of the following types: 869 870 * A `tf.Operation`. 871 The corresponding fetched value will be `None`. 872 * A `tf.Tensor`. 873 The corresponding fetched value will be a numpy ndarray containing the 874 value of that tensor. 875 * A `tf.sparse.SparseTensor`. 876 The corresponding fetched value will be a 877 `tf.compat.v1.SparseTensorValue` 878 containing the value of that sparse tensor. 879 * A `get_tensor_handle` op. The corresponding fetched value will be a 880 numpy ndarray containing the handle of that tensor. 881 * A `string` which is the name of a tensor or operation in the graph. 882 883 The value returned by `run()` has the same shape as the `fetches` argument, 884 where the leaves are replaced by the corresponding values returned by 885 TensorFlow. 886 887 Example: 888 889 ```python 890 a = tf.constant([10, 20]) 891 b = tf.constant([1.0, 2.0]) 892 # 'fetches' can be a singleton 893 v = session.run(a) 894 # v is the numpy array [10, 20] 895 # 'fetches' can be a list. 896 v = session.run([a, b]) 897 # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the 898 # 1-D array [1.0, 2.0] 899 # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts: 900 MyData = collections.namedtuple('MyData', ['a', 'b']) 901 v = session.run({'k1': MyData(a, b), 'k2': [b, a]}) 902 # v is a dict with 903 # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and 904 # 'b' (the numpy array [1.0, 2.0]) 905 # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array 906 # [10, 20]. 907 ``` 908 909 The optional `feed_dict` argument allows the caller to override 910 the value of tensors in the graph. Each key in `feed_dict` can be 911 one of the following types: 912 913 * If the key is a `tf.Tensor`, the 914 value may be a Python scalar, string, list, or numpy ndarray 915 that can be converted to the same `dtype` as that 916 tensor. Additionally, if the key is a 917 `tf.compat.v1.placeholder`, the shape of 918 the value will be checked for compatibility with the placeholder. 919 * If the key is a 920 `tf.sparse.SparseTensor`, 921 the value should be a 922 `tf.compat.v1.SparseTensorValue`. 923 * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value 924 should be a nested tuple with the same structure that maps to their 925 corresponding values as above. 926 927 Each value in `feed_dict` must be convertible to a numpy array of the dtype 928 of the corresponding key. 929 930 The optional `options` argument expects a [`RunOptions`] proto. The options 931 allow controlling the behavior of this particular step (e.g. turning tracing 932 on). 933 934 The optional `run_metadata` argument expects a [`RunMetadata`] proto. When 935 appropriate, the non-Tensor output of this step will be collected there. For 936 example, when users turn on tracing in `options`, the profiled info will be 937 collected into this argument and passed back. 938 939 Args: 940 fetches: A single graph element, a list of graph elements, or a dictionary 941 whose values are graph elements or lists of graph elements (described 942 above). 943 feed_dict: A dictionary that maps graph elements to values (described 944 above). 945 options: A [`RunOptions`] protocol buffer 946 run_metadata: A [`RunMetadata`] protocol buffer 947 948 Returns: 949 Either a single value if `fetches` is a single graph element, or 950 a list of values if `fetches` is a list, or a dictionary with the 951 same keys as `fetches` if that is a dictionary (described above). 952 Order in which `fetches` operations are evaluated inside the call 953 is undefined. 954 955 Raises: 956 RuntimeError: If this `Session` is in an invalid state (e.g. has been 957 closed). 958 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 959 ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a 960 `Tensor` that doesn't exist. 961 """ 962 options_ptr = tf_session.TF_NewBufferFromString( 963 compat.as_bytes(options.SerializeToString())) if options else None 964 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 965 966 try: 967 result = self._run(None, fetches, feed_dict, options_ptr, 968 run_metadata_ptr) 969 if run_metadata: 970 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 971 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 972 finally: 973 if run_metadata_ptr: 974 tf_session.TF_DeleteBuffer(run_metadata_ptr) 975 if options: 976 tf_session.TF_DeleteBuffer(options_ptr) 977 return result 978 979 def partial_run(self, handle, fetches, feed_dict=None): 980 """Continues the execution with more feeds and fetches. 981 982 This is EXPERIMENTAL and subject to change. 983 984 To use partial execution, a user first calls `partial_run_setup()` and 985 then a sequence of `partial_run()`. `partial_run_setup` specifies the 986 list of feeds and fetches that will be used in the subsequent 987 `partial_run` calls. 988 989 The optional `feed_dict` argument allows the caller to override 990 the value of tensors in the graph. See run() for more information. 991 992 Below is a simple example: 993 994 ```python 995 a = array_ops.placeholder(dtypes.float32, shape=[]) 996 b = array_ops.placeholder(dtypes.float32, shape=[]) 997 c = array_ops.placeholder(dtypes.float32, shape=[]) 998 r1 = math_ops.add(a, b) 999 r2 = math_ops.multiply(r1, c) 1000 1001 h = sess.partial_run_setup([r1, r2], [a, b, c]) 1002 res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) 1003 res = sess.partial_run(h, r2, feed_dict={c: res}) 1004 ``` 1005 1006 Args: 1007 handle: A handle for a sequence of partial runs. 1008 fetches: A single graph element, a list of graph elements, or a dictionary 1009 whose values are graph elements or lists of graph elements (see 1010 documentation for `run`). 1011 feed_dict: A dictionary that maps graph elements to values (described 1012 above). 1013 1014 Returns: 1015 Either a single value if `fetches` is a single graph element, or 1016 a list of values if `fetches` is a list, or a dictionary with the 1017 same keys as `fetches` if that is a dictionary 1018 (see documentation for `run`). 1019 1020 Raises: 1021 tf.errors.OpError: Or one of its subclasses on error. 1022 """ 1023 # TODO(touts): Support feeding and fetching the same tensor. 1024 return self._run(handle, fetches, feed_dict, None, None) 1025 1026 def partial_run_setup(self, fetches, feeds=None): 1027 """Sets up a graph with feeds and fetches for partial run. 1028 1029 This is EXPERIMENTAL and subject to change. 1030 1031 Note that contrary to `run`, `feeds` only specifies the graph elements. 1032 The tensors will be supplied by the subsequent `partial_run` calls. 1033 1034 Args: 1035 fetches: A single graph element, or a list of graph elements. 1036 feeds: A single graph element, or a list of graph elements. 1037 1038 Returns: 1039 A handle for partial run. 1040 1041 Raises: 1042 RuntimeError: If this `Session` is in an invalid state (e.g. has been 1043 closed). 1044 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 1045 tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens. 1046 """ 1047 1048 def _feed_fn(feed): 1049 for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS: 1050 if isinstance(feed, tensor_type): 1051 return feed_fn(feed) 1052 raise TypeError('Feed argument %r has invalid type %r' % 1053 (feed, type(feed))) 1054 1055 # Check session. 1056 if self._closed: 1057 raise RuntimeError('Attempted to use a closed Session.') 1058 if self.graph.version == 0: 1059 raise RuntimeError('The Session graph is empty. Add operations to the ' 1060 'graph before calling run().') 1061 1062 if feeds is None: 1063 feeds = [] 1064 # Create request. 1065 feed_list = [] 1066 1067 # Validate and process feed_list. 1068 is_list_feed = isinstance(feeds, (list, tuple)) 1069 if not is_list_feed: 1070 feeds = [feeds] 1071 for feed in feeds: 1072 for subfeed in _feed_fn(feed): 1073 try: 1074 subfeed_t = self.graph.as_graph_element( 1075 subfeed, allow_tensor=True, allow_operation=False) 1076 # pylint: disable=protected-access 1077 feed_list.append(subfeed_t._as_tf_output()) 1078 # pylint: enable=protected-access 1079 except Exception as e: 1080 e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message) 1081 e.args = (e.message,) 1082 raise e 1083 1084 # Validate and process fetches. 1085 # TODO(touts): Support feeding and fetching the same tensor. 1086 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1087 1088 # Set up a graph with feeds and fetches for partial run. 1089 def _setup_fn(session, feed_list, fetch_list, target_list): 1090 self._extend_graph() 1091 return tf_session.TF_SessionPRunSetup_wrapper(session, feed_list, 1092 fetch_list, target_list) 1093 1094 # pylint: disable=protected-access 1095 final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] 1096 final_targets = [op._c_op for op in fetch_handler.targets()] 1097 # pylint: enable=protected-access 1098 1099 return self._do_call(_setup_fn, self._session, feed_list, final_fetches, 1100 final_targets) 1101 1102 def _run(self, handle, fetches, feed_dict, options, run_metadata): 1103 """Perform either run or partial_run, depending the presence of `handle`.""" 1104 1105 def _feed_fn(feed, feed_val): 1106 for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS: 1107 if isinstance(feed, tensor_type): 1108 return feed_fn(feed, feed_val) 1109 raise TypeError('Feed argument %r has invalid type %r' % 1110 (feed, type(feed))) 1111 1112 # Check session. 1113 if self._closed: 1114 raise RuntimeError('Attempted to use a closed Session.') 1115 if self.graph.version == 0: 1116 raise RuntimeError('The Session graph is empty. Add operations to the ' 1117 'graph before calling run().') 1118 1119 # Create request. 1120 feed_dict_tensor = {} 1121 feed_map = {} 1122 1123 # Validate and process feed_dict. 1124 feed_handles = {} 1125 if feed_dict: 1126 feed_dict = nest.flatten_dict_items(feed_dict) 1127 for feed, feed_val in feed_dict.items(): 1128 for subfeed, subfeed_val in _feed_fn(feed, feed_val): 1129 try: 1130 subfeed_t = self.graph.as_graph_element( 1131 subfeed, allow_tensor=True, allow_operation=False) 1132 except Exception as e: 1133 raise TypeError('Cannot interpret feed_dict key as Tensor: ' + 1134 e.args[0]) 1135 1136 if isinstance(subfeed_val, ops.Tensor): 1137 raise TypeError('The value of a feed cannot be a tf.Tensor object. ' 1138 'Acceptable feed values include Python scalars, ' 1139 'strings, lists, numpy ndarrays, or TensorHandles. ' 1140 'For reference, the tensor object was ' + 1141 str(feed_val) + ' which was passed to the ' 1142 'feed with key ' + str(feed) + '.') 1143 1144 subfeed_dtype = subfeed_t.dtype.as_numpy_dtype 1145 if isinstance(subfeed_val, int) and _convert_to_numpy_obj( 1146 subfeed_dtype, subfeed_val) != subfeed_val: 1147 raise TypeError( 1148 'Type of feed value ' + str(subfeed_val) + ' with type ' + 1149 str(type(subfeed_val)) + 1150 ' is not compatible with Tensor type ' + str(subfeed_dtype) + 1151 '. Try explicitly setting the type of the feed tensor' 1152 ' to a larger type (e.g. int64).') 1153 1154 is_tensor_handle_feed = isinstance(subfeed_val, 1155 session_ops.TensorHandle) 1156 if is_tensor_handle_feed: 1157 np_val = subfeed_val.to_numpy_array() 1158 feed_handles[subfeed_t.ref()] = subfeed_val 1159 else: 1160 np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) 1161 1162 if (not is_tensor_handle_feed and 1163 not subfeed_t.get_shape().is_compatible_with(np_val.shape)): 1164 raise ValueError( 1165 'Cannot feed value of shape %r for Tensor %r, ' 1166 'which has shape %r' % 1167 (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape()))) 1168 if not self.graph.is_feedable(subfeed_t): 1169 raise ValueError('Tensor %s may not be fed.' % subfeed_t) 1170 1171 feed_dict_tensor[subfeed_t.ref()] = np_val 1172 feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val) 1173 1174 # Create a fetch handler to take care of the structure of fetches. 1175 fetch_handler = _FetchHandler( 1176 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 1177 1178 # Run request and get response. 1179 # We need to keep the returned movers alive for the following _do_run(). 1180 # These movers are no longer needed when _do_run() completes, and 1181 # are deleted when `movers` goes out of scope when this _run() ends. 1182 # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding 1183 # of a handle from a different device as an error. 1184 _ = self._update_with_movers(feed_dict_tensor, feed_map) 1185 final_fetches = fetch_handler.fetches() 1186 final_targets = fetch_handler.targets() 1187 # We only want to really perform the run if fetches or targets are provided, 1188 # or if the call is a partial run that specifies feeds. 1189 if final_fetches or final_targets or (handle and feed_dict_tensor): 1190 results = self._do_run(handle, final_targets, final_fetches, 1191 feed_dict_tensor, options, run_metadata) 1192 else: 1193 results = [] 1194 return fetch_handler.build_results(self, results) 1195 1196 def make_callable(self, fetches, feed_list=None, accept_options=False): 1197 """Returns a Python callable that runs a particular step. 1198 1199 The returned callable will take `len(feed_list)` arguments whose types 1200 must be compatible feed values for the respective elements of `feed_list`. 1201 For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th 1202 argument to the returned callable must be a numpy ndarray (or something 1203 convertible to an ndarray) with matching element type and shape. See 1204 `tf.Session.run` for details of the allowable feed key and value types. 1205 1206 The returned callable will have the same return type as 1207 `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`, 1208 the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`, 1209 it will return `None`. 1210 1211 Args: 1212 fetches: A value or list of values to fetch. See `tf.Session.run` for 1213 details of the allowable fetch types. 1214 feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run` 1215 for details of the allowable feed key types. 1216 accept_options: (Optional.) If `True`, the returned `Callable` will be 1217 able to accept `tf.compat.v1.RunOptions` and `tf.compat.v1.RunMetadata` 1218 as optional keyword arguments `options` and `run_metadata`, 1219 respectively, with the same syntax and semantics as `tf.Session.run`, 1220 which is useful for certain use cases (profiling and debugging) but will 1221 result in measurable slowdown of the `Callable`'s 1222 performance. Default: `False`. 1223 1224 Returns: 1225 A function that when called will execute the step defined by 1226 `feed_list` and `fetches` in this session. 1227 1228 Raises: 1229 TypeError: If `fetches` or `feed_list` cannot be interpreted 1230 as arguments to `tf.Session.run`. 1231 """ 1232 if feed_list is not None: 1233 if not isinstance(feed_list, (list, tuple)): 1234 raise TypeError('`feed_list` must be a list or tuple.') 1235 # Delegate any non-empty feed lists to the existing `run()` logic. 1236 # TODO(mrry): Refactor the feed handling logic from 1237 # `Session._run()` so that we can convert the feeds to a list of 1238 # strings here. 1239 def _generic_run(*feed_args, **kwargs): 1240 feed_dict = { 1241 feed: feed_val for feed, feed_val in zip(feed_list, feed_args) 1242 } 1243 return self.run(fetches, feed_dict=feed_dict, **kwargs) 1244 1245 return _generic_run 1246 1247 # Ensure any changes to the graph are reflected in the runtime. 1248 # Note that we don't need to do this on subsequent calls to the 1249 # returned object, because the arguments to `fetches` must already be 1250 # in the graph. 1251 self._extend_graph() 1252 1253 # Create a fetch handler to take care of the structure of fetches. 1254 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1255 # pylint: disable=protected-access 1256 fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] 1257 target_list = [op._c_op for op in fetch_handler.targets()] 1258 1259 # pylint: enable=protected-access 1260 1261 def _callable_template_with_options_and_metadata(fetch_list, 1262 target_list, 1263 fetch_handler, 1264 options=None, 1265 run_metadata=None): 1266 """Template callable that accepts RunOptions and RunMetadata.""" 1267 options_ptr = tf_session.TF_NewBufferFromString( 1268 compat.as_bytes(options.SerializeToString())) if options else None 1269 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 1270 try: 1271 results = self._call_tf_sessionrun(options_ptr, {}, fetch_list, 1272 target_list, run_metadata_ptr) 1273 if fetch_handler: 1274 results = fetch_handler.build_results(self, results) 1275 else: 1276 results = results[0] if results else None 1277 if run_metadata: 1278 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 1279 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1280 finally: 1281 if run_metadata_ptr: 1282 tf_session.TF_DeleteBuffer(run_metadata_ptr) 1283 if options: 1284 tf_session.TF_DeleteBuffer(options_ptr) 1285 return results 1286 1287 if accept_options: 1288 return functools.partial(_callable_template_with_options_and_metadata, 1289 fetch_list, target_list, fetch_handler) 1290 elif isinstance(fetches, ops.Operation): 1291 # Special case for fetching a single operation, because the 1292 # function will have no return value. 1293 assert not fetch_list 1294 assert len(target_list) == 1 1295 1296 def _single_operation_run(): 1297 self._call_tf_sessionrun(None, {}, [], target_list, None) 1298 1299 return _single_operation_run 1300 elif isinstance(fetches, ops.Tensor): 1301 # Special case for fetching a single tensor, because the 1302 # function can return the result of `TF_Run()` directly. 1303 assert len(fetch_list) == 1 1304 assert not target_list 1305 1306 def _single_tensor_run(): 1307 results = self._call_tf_sessionrun(None, {}, fetch_list, [], None) 1308 return results[0] 1309 1310 return _single_tensor_run 1311 else: 1312 # In all other cases, we must use `fetch_handler` to build the 1313 # results for us. 1314 def _fetch_handler_run(): 1315 results = self._call_tf_sessionrun(None, {}, fetch_list, target_list, 1316 None) 1317 return fetch_handler.build_results(self, results) 1318 1319 return _fetch_handler_run 1320 1321 # Captures the name of a node in an error status. The regex below matches 1322 # both the old and the new formats: 1323 # Old format: [[Node: <node_name> = ...]] 1324 # New format: [[{{node <node_name>}} = ...]] 1325 _NODEDEF_NAME_RE = re.compile( 1326 r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*') 1327 1328 def _do_run(self, handle, target_list, fetch_list, feed_dict, options, 1329 run_metadata): 1330 """Runs a step based on the given fetches and feeds. 1331 1332 Args: 1333 handle: a handle for partial_run. None if this is just a call to run(). 1334 target_list: A list of operations to be run, but not fetched. 1335 fetch_list: A list of tensors to be fetched. 1336 feed_dict: A dictionary that maps tensors to numpy ndarrays. 1337 options: A (pointer to a) [`RunOptions`] protocol buffer, or None 1338 run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None 1339 1340 Returns: 1341 A list of numpy ndarrays, corresponding to the elements of 1342 `fetch_list`. If the ith element of `fetch_list` contains the 1343 name of an operation, the first Tensor output of that operation 1344 will be returned for that element. 1345 1346 Raises: 1347 tf.errors.OpError: Or one of its subclasses on error. 1348 """ 1349 # pylint: disable=protected-access 1350 feeds = dict((t.deref()._as_tf_output(), v) for t, v in feed_dict.items()) 1351 fetches = [t._as_tf_output() for t in fetch_list] 1352 targets = [op._c_op for op in target_list] 1353 1354 # pylint: enable=protected-access 1355 1356 def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata): 1357 # Ensure any changes to the graph are reflected in the runtime. 1358 self._extend_graph() 1359 return self._call_tf_sessionrun(options, feed_dict, fetch_list, 1360 target_list, run_metadata) 1361 1362 def _prun_fn(handle, feed_dict, fetch_list): 1363 if target_list: 1364 raise RuntimeError('partial_run() requires empty target_list.') 1365 return self._call_tf_sessionprun(handle, feed_dict, fetch_list) 1366 1367 if handle is None: 1368 return self._do_call(_run_fn, feeds, fetches, targets, options, 1369 run_metadata) 1370 else: 1371 return self._do_call(_prun_fn, handle, feeds, fetches) 1372 1373 def _do_call(self, fn, *args): 1374 try: 1375 return fn(*args) 1376 except errors.OpError as e: 1377 message = compat.as_text(e.message) 1378 m = BaseSession._NODEDEF_NAME_RE.search(message) 1379 node_def = None 1380 op = None 1381 if m is not None: 1382 node_name = m.group(3) 1383 try: 1384 op = self._graph.get_operation_by_name(node_name) 1385 node_def = op.node_def 1386 except KeyError: 1387 pass 1388 message = error_interpolation.interpolate(message, self._graph) 1389 if 'only supports NHWC tensor format' in message: 1390 message += ('\nA possible workaround: Try disabling Grappler optimizer' 1391 '\nby modifying the config for creating the session eg.' 1392 '\nsession_config.graph_options.rewrite_options.' 1393 'disable_meta_optimizer = True') 1394 raise type(e)(node_def, op, message) 1395 1396 def _extend_graph(self): 1397 with self._graph._session_run_lock(): # pylint: disable=protected-access 1398 tf_session.ExtendSession(self._session) 1399 1400 # The threshold to run garbage collection to delete dead tensors. 1401 _DEAD_HANDLES_THRESHOLD = 10 1402 1403 def _register_dead_handle(self, handle): 1404 # Register a dead handle in the session. Delete the dead tensors when 1405 # the number of dead tensors exceeds certain threshold. 1406 tensors_to_delete = None 1407 with self._delete_lock: 1408 self._dead_handles.append(handle) 1409 if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD: 1410 tensors_to_delete = self._dead_handles 1411 self._dead_handles = [] 1412 # Delete the dead tensors. 1413 if tensors_to_delete: 1414 feeds = {} 1415 fetches = [] 1416 for deleter_key, tensor_handle in enumerate(tensors_to_delete): 1417 holder, deleter = session_ops._get_handle_deleter( 1418 self.graph, deleter_key, tensor_handle) 1419 feeds[holder] = tensor_handle 1420 fetches.append(deleter) 1421 self.run(fetches, feed_dict=feeds) 1422 1423 def _update_with_movers(self, feed_dict, feed_map): 1424 # If a tensor handle that is fed to a device incompatible placeholder, 1425 # we move the tensor to the right device, generate a new tensor handle, 1426 # and update `feed_dict` to use the new handle. 1427 handle_movers = [] 1428 for feed_name, val in feed_map.items(): 1429 mover = session_ops._get_handle_mover(self.graph, *val) 1430 if mover: 1431 handle_movers.append((feed_name, val[1], mover)) 1432 # Transfer a tensor to the right device if needed. 1433 if not handle_movers: 1434 return [] 1435 else: 1436 feeds = {} 1437 fetches = [] 1438 for _, handle, mover in handle_movers: 1439 feeds[mover[0]] = handle 1440 fetches.append(mover[1]) 1441 handles = self.run(fetches, feed_dict=feeds) 1442 for handle_mover, handle in zip(handle_movers, handles): 1443 np_val = np.array(handle.handle, dtype=np.object) 1444 feed_name = handle_mover[0] 1445 feed_tensor = feed_map[feed_name][0] 1446 feed_dict[feed_tensor.ref()] = np_val 1447 return handles 1448 1449 def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, 1450 run_metadata): 1451 return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict, 1452 fetch_list, target_list, 1453 run_metadata) 1454 1455 def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): 1456 return tf_session.TF_SessionPRun_wrapper(self._session, handle, feed_dict, 1457 fetch_list) 1458 1459 # pylint: disable=protected-access 1460 class _Callable(object): 1461 """Experimental wrapper for the C++ `Session::MakeCallable()` API.""" 1462 1463 def __init__(self, session, callable_options): 1464 self._session = session 1465 self._handle = None 1466 options_ptr = tf_session.TF_NewBufferFromString( 1467 compat.as_bytes(callable_options.SerializeToString())) 1468 try: 1469 self._handle = tf_session.TF_SessionMakeCallable( 1470 session._session, options_ptr) 1471 finally: 1472 tf_session.TF_DeleteBuffer(options_ptr) 1473 1474 def __call__(self, *args, **kwargs): 1475 # TODO(b/74355905): Support argument and return value nested structures, 1476 # and tensor-like objects such as SparseTensors. 1477 run_metadata = kwargs.get('run_metadata', None) 1478 try: 1479 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 1480 ret = tf_session.TF_SessionRunCallable(self._session._session, 1481 self._handle, args, 1482 run_metadata_ptr) 1483 if run_metadata: 1484 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 1485 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1486 finally: 1487 if run_metadata_ptr: 1488 tf_session.TF_DeleteBuffer(run_metadata_ptr) 1489 return ret 1490 1491 def __del__(self): 1492 # NOTE(mrry): It is possible that `self._session.__del__()` could be 1493 # called before this destructor, in which case `self._session._session` 1494 # will be `None`. 1495 if (self._handle is not None and self._session._session is not None and 1496 not self._session._closed): 1497 tf_session.TF_SessionReleaseCallable(self._session._session, 1498 self._handle) 1499 1500 # pylint: enable=protected-access 1501 1502 # TODO(b/74355905): Reimplement `Session.make_callable()` using this method 1503 # where possible. 1504 def _make_callable_from_options(self, callable_options): 1505 """Returns a handle to a "callable" with the given options. 1506 1507 Args: 1508 callable_options: A `CallableOptions` protocol buffer message describing 1509 the computation that will be performed by the callable. 1510 1511 Returns: 1512 A handle to the new callable. 1513 """ 1514 self._extend_graph() 1515 return BaseSession._Callable(self, callable_options) 1516 1517 1518@tf_export(v1=['Session']) 1519class Session(BaseSession): 1520 """A class for running TensorFlow operations. 1521 1522 A `Session` object encapsulates the environment in which `Operation` 1523 objects are executed, and `Tensor` objects are evaluated. For 1524 example: 1525 1526 ```python 1527 tf.compat.v1.disable_eager_execution() # need to disable eager in TF2.x 1528 # Build a graph. 1529 a = tf.constant(5.0) 1530 b = tf.constant(6.0) 1531 c = a * b 1532 1533 # Launch the graph in a session. 1534 sess = tf.compat.v1.Session() 1535 1536 # Evaluate the tensor `c`. 1537 print(sess.run(c)) # prints 30.0 1538 ``` 1539 1540 A session may own resources, such as 1541 `tf.Variable`, `tf.queue.QueueBase`, 1542 and `tf.compat.v1.ReaderBase`. It is important to release 1543 these resources when they are no longer required. To do this, either 1544 invoke the `tf.Session.close` method on the session, or use 1545 the session as a context manager. The following two examples are 1546 equivalent: 1547 1548 ```python 1549 # Using the `close()` method. 1550 sess = tf.compat.v1.Session() 1551 sess.run(...) 1552 sess.close() 1553 1554 # Using the context manager. 1555 with tf.compat.v1.Session() as sess: 1556 sess.run(...) 1557 ``` 1558 1559 The 1560 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1561 protocol buffer exposes various configuration options for a 1562 session. For example, to create a session that uses soft constraints 1563 for device placement, and log the resulting placement decisions, 1564 create a session as follows: 1565 1566 ```python 1567 # Launch the graph in a session that allows soft device placement and 1568 # logs the placement decisions. 1569 sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto( 1570 allow_soft_placement=True, 1571 log_device_placement=True)) 1572 ``` 1573 """ 1574 1575 def __init__(self, target='', graph=None, config=None): 1576 """Creates a new TensorFlow session. 1577 1578 If no `graph` argument is specified when constructing the session, 1579 the default graph will be launched in the session. If you are 1580 using more than one graph (created with `tf.Graph()`) in the same 1581 process, you will have to use different sessions for each graph, 1582 but each graph can be used in multiple sessions. In this case, it 1583 is often clearer to pass the graph to be launched explicitly to 1584 the session constructor. 1585 1586 Args: 1587 target: (Optional.) The execution engine to connect to. Defaults to using 1588 an in-process engine. See 1589 [Distributed TensorFlow](https://tensorflow.org/deploy/distributed) for 1590 more examples. 1591 graph: (Optional.) The `Graph` to be launched (described above). 1592 config: (Optional.) A 1593 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1594 protocol buffer with configuration options for the session. 1595 """ 1596 super(Session, self).__init__(target, graph, config=config) 1597 # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle. 1598 self._default_graph_context_manager = None 1599 self._default_session_context_manager = None 1600 1601 def __enter__(self): 1602 if self._default_graph_context_manager is None: 1603 self._default_graph_context_manager = self.graph.as_default() 1604 else: 1605 raise RuntimeError('Session context managers are not re-entrant. ' 1606 'Use `Session.as_default()` if you want to enter ' 1607 'a session multiple times.') 1608 if self._default_session_context_manager is None: 1609 self._default_session_context_manager = self.as_default() 1610 self._default_graph_context_manager.__enter__() 1611 return self._default_session_context_manager.__enter__() 1612 1613 def __exit__(self, exec_type, exec_value, exec_tb): 1614 if exec_type is errors.OpError: 1615 logging.error('Session closing due to OpError: %s', (exec_value,)) 1616 try: 1617 self._default_session_context_manager.__exit__(exec_type, exec_value, 1618 exec_tb) 1619 except RuntimeError as error: 1620 if error == exec_value: 1621 # NOTE(skyewm): for some reason, in Python3, 1622 # _default_session_context_manager.__exit__ will re-raise the "not 1623 # re-entrant" exception raised in __enter__ above (note that if we're 1624 # here, we're in the outer session context manager, since __exit__ is 1625 # not called when __enter__ raises an exception). We still want to 1626 # continue cleaning up this context manager before the exception is 1627 # further propagated, so we ignore it here (note that it'll continue 1628 # being propagated after this method completes). 1629 pass 1630 else: 1631 raise 1632 self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb) 1633 1634 self._default_session_context_manager = None 1635 self._default_graph_context_manager = None 1636 1637 # If we are closing due to an exception, set a time limit on our Close() to 1638 # avoid blocking forever. 1639 # TODO(b/120204635) remove this when deadlock is fixed. 1640 if exec_type: 1641 close_thread = threading.Thread( 1642 name='SessionCloseThread', target=self.close) 1643 close_thread.daemon = True 1644 close_thread.start() 1645 close_thread.join(30.0) 1646 if close_thread.is_alive(): 1647 logging.error( 1648 'Session failed to close after 30 seconds. Continuing after this ' 1649 'point may leave your program in an undefined state.') 1650 else: 1651 self.close() 1652 1653 @staticmethod 1654 def reset(target, containers=None, config=None): 1655 """Resets resource containers on `target`, and close all connected sessions. 1656 1657 A resource container is distributed across all workers in the 1658 same cluster as `target`. When a resource container on `target` 1659 is reset, resources associated with that container will be cleared. 1660 In particular, all Variables in the container will become undefined: 1661 they lose their values and shapes. 1662 1663 NOTE: 1664 (i) reset() is currently only implemented for distributed sessions. 1665 (ii) Any sessions on the master named by `target` will be closed. 1666 1667 If no resource containers are provided, all containers are reset. 1668 1669 Args: 1670 target: The execution engine to connect to. 1671 containers: A list of resource container name strings, or `None` if all of 1672 all the containers are to be reset. 1673 config: (Optional.) Protocol buffer with configuration options. 1674 1675 Raises: 1676 tf.errors.OpError: Or one of its subclasses if an error occurs while 1677 resetting containers. 1678 """ 1679 if target is not None: 1680 target = compat.as_bytes(target) 1681 if containers is not None: 1682 containers = [compat.as_bytes(c) for c in containers] 1683 else: 1684 containers = [] 1685 tf_session.TF_Reset(target, containers, config) 1686 1687 1688@tf_export(v1=['InteractiveSession']) 1689class InteractiveSession(BaseSession): 1690 """A TensorFlow `Session` for use in interactive contexts, such as a shell. 1691 1692 The only difference with a regular `Session` is that an `InteractiveSession` 1693 installs itself as the default session on construction. 1694 The methods `tf.Tensor.eval` 1695 and `tf.Operation.run` 1696 will use that session to run ops. 1697 1698 This is convenient in interactive shells and [IPython 1699 notebooks](http://ipython.org), as it avoids having to pass an explicit 1700 `Session` object to run ops. 1701 1702 For example: 1703 1704 ```python 1705 sess = tf.compat.v1.InteractiveSession() 1706 a = tf.constant(5.0) 1707 b = tf.constant(6.0) 1708 c = a * b 1709 # We can just use 'c.eval()' without passing 'sess' 1710 print(c.eval()) 1711 sess.close() 1712 ``` 1713 1714 Note that a regular session installs itself as the default session when it 1715 is created in a `with` statement. The common usage in non-interactive 1716 programs is to follow that pattern: 1717 1718 ```python 1719 a = tf.constant(5.0) 1720 b = tf.constant(6.0) 1721 c = a * b 1722 with tf.compat.v1.Session(): 1723 # We can also use 'c.eval()' here. 1724 print(c.eval()) 1725 ``` 1726 """ 1727 1728 _count_lock = threading.Lock() 1729 _active_session_count = 0 # GUARDED_BY(_count_lock) 1730 1731 def __init__(self, target='', graph=None, config=None): 1732 """Creates a new interactive TensorFlow session. 1733 1734 If no `graph` argument is specified when constructing the session, 1735 the default graph will be launched in the session. If you are 1736 using more than one graph (created with `tf.Graph()`) in the same 1737 process, you will have to use different sessions for each graph, 1738 but each graph can be used in multiple sessions. In this case, it 1739 is often clearer to pass the graph to be launched explicitly to 1740 the session constructor. 1741 1742 Args: 1743 target: (Optional.) The execution engine to connect to. Defaults to using 1744 an in-process engine. 1745 graph: (Optional.) The `Graph` to be launched (described above). 1746 config: (Optional) `ConfigProto` proto used to configure the session. 1747 """ 1748 if not config: 1749 # If config is not provided, choose some reasonable defaults for 1750 # interactive use: 1751 # 1752 # - Grow GPU memory as needed at the cost of fragmentation. 1753 gpu_options = config_pb2.GPUOptions(allow_growth=True) 1754 config = config_pb2.ConfigProto(gpu_options=gpu_options) 1755 # Interactive sessions always place pruned graphs. 1756 config.graph_options.place_pruned_graph = True 1757 1758 super(InteractiveSession, self).__init__(target, graph, config) 1759 with InteractiveSession._count_lock: 1760 if InteractiveSession._active_session_count > 0: 1761 warnings.warn('An interactive session is already active. This can ' 1762 'cause out-of-memory errors in some cases. You must ' 1763 'explicitly call `InteractiveSession.close()` to release ' 1764 'resources held by the other session(s).') 1765 InteractiveSession._active_session_count += 1 1766 # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful 1767 # semantics (in particular, it is not set to true if `Session.close()` is 1768 # called on a session that has not been "opened" by running a step) and we 1769 # cannot change those semantics without breaking existing code. 1770 self._explicitly_closed = False 1771 1772 self._default_session = self.as_default() 1773 self._default_session.enforce_nesting = False 1774 self._default_session.__enter__() 1775 self._explicit_graph = graph 1776 if self._explicit_graph is not None: 1777 self._default_graph = graph.as_default() 1778 self._default_graph.enforce_nesting = False 1779 self._default_graph.__enter__() 1780 1781 def close(self): 1782 """Closes an `InteractiveSession`.""" 1783 super(InteractiveSession, self).close() 1784 with InteractiveSession._count_lock: 1785 if not self._explicitly_closed: 1786 InteractiveSession._active_session_count -= 1 1787 self._explicitly_closed = True 1788 else: 1789 return 1790 if self._explicit_graph is not None: 1791 self._default_graph.__exit__(None, None, None) 1792 self._default_graph = None 1793 self._default_session.__exit__(None, None, None) 1794 self._default_session = None 1795