# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A client interface for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import re
import threading
import warnings

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as tf_session
from tensorflow.python.framework import device
from tensorflow.python.framework import error_interpolation
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import session_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class SessionInterface(object):
  """Abstract interface shared by TensorFlow client session implementations.

  Every member of this class raises `NotImplementedError`; concrete session
  classes (for example `BaseSession`) are expected to override them.
  """

  @property
  def graph(self):
    """Graph that Operations for this session are built into."""
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    """Target string naming the TensorFlow process this session talks to."""
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Executes one step of the graph; see `BaseSession.run()` for details."""
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    """Declares the feeds and fetches used by subsequent partial runs."""
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    """Resumes a partial run, supplying more feeds and requesting fetches."""
    raise NotImplementedError('partial_run')


def _get_indexed_slices_value_from_fetches(fetched_vals):
  """Reassembles fetched `IndexedSlices` components into a value object.

  Args:
    fetched_vals: The fetched results ordered as `[values, indices]` or
      `[values, indices, dense_shape]`.

  Returns:
    An `ops.IndexedSlicesValue`. Its dense shape is `None` unless exactly
    three components were fetched.
  """
  values, indices = fetched_vals[0], fetched_vals[1]
  has_dense_shape = len(fetched_vals) == 3
  dense_shape = fetched_vals[2] if has_dense_shape else None
  return ops.IndexedSlicesValue(values, indices, dense_shape)


def _get_feeds_for_indexed_slices(feed, feed_val):
  """Pairs the component tensors of an `IndexedSlices` feed with fed values.

  Args:
    feed: The `ops.IndexedSlices` used as a feed key.
    feed_val: A sequence of values ordered as values, indices and, only when
      `feed.dense_shape` is not None, dense_shape.

  Returns:
    A list of `(tensor, value)` pairs. As with `zip`, surplus entries on
    either side are dropped silently.
  """
  component_tensors = [feed.values, feed.indices]
  if feed.dense_shape is not None:
    component_tensors.append(feed.dense_shape)
  return list(zip(component_tensors, feed_val))


# List of extensions supported to convert run arguments into actual fetches and
# feeds.
#
# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2),
# where the function signatures are:
#   fetch_fn : Type -> (list of Tensors,
#                       lambda: list of fetched np.ndarray -> TypeVal)
#   feed_fn1 : Type, TypeVal -> list of (Tensor, value)
#   feed_fn2 : Type -> list of Tensors
#
# `fetch_fn` describes how to expand a fetch into its component Tensors and how
# to contract the fetched results back into a single return value.
#
# Each feed function describes how to unpack a single fed value and map it to
# feeds of one or more tensors and their corresponding values: `feed_fn1` is
# used to feed a run, `feed_fn2` to set up a partial run.
#
# TODO(touts): We could reimplement these as specialized _FeedMapper
# implementations after we refactor the feed handling code to use them.
#
# Eventually, this registration could be opened up to support custom Tensor
# expansions.
# pylint: disable=g-long-lambda
_REGISTERED_EXPANSIONS = [
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (sparse_tensor.SparseTensor,
     lambda fetch: (
         [fetch.indices, fetch.values, fetch.dense_shape],
         lambda fetched_vals: sparse_tensor.SparseTensorValue(*fetched_vals)),
     lambda feed, feed_val: list(zip(
         [feed.indices, feed.values, feed.dense_shape], feed_val)),
     lambda feed: [feed.indices, feed.values, feed.dense_shape]),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples.
    (ops.IndexedSlices,
     lambda fetch: (
         [fetch.values, fetch.indices] if fetch.dense_shape is None
         else [fetch.values, fetch.indices, fetch.dense_shape],
         _get_indexed_slices_value_from_fetches),
     _get_feeds_for_indexed_slices,
     lambda feed: [feed.values, feed.indices] if feed.dense_shape is None
     else [feed.values, feed.indices, feed.dense_shape]),
    # The default catches all other types and performs no expansions. It must
    # stay last: lookups use the first entry whose type matches, and
    # registration inserts new entries at the front.
    (object,
     lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),
     lambda feed, feed_val: [(feed, feed_val)],
     lambda feed: [feed])]

# pylint: enable=g-long-lambda


def _convert_to_numpy_obj(numpy_dtype, obj):
  """Explicitly convert obj based on numpy type except for string type.

  Args:
    numpy_dtype: The type to convert with (presumably a numpy scalar type),
      or `object`.
    obj: The value to convert.

  Returns:
    `numpy_dtype(obj)`, or `str(obj)` when `numpy_dtype` is `object`.
  """
  return numpy_dtype(obj) if numpy_dtype is not object else str(obj)


def register_session_run_conversion_functions(
    tensor_type,
    fetch_function,
    feed_function=None,
    feed_function_for_partial_run=None):
  """Register fetch and feed conversion functions for `tf.Session.run()`.

  This function registers a triple of conversion functions for fetching and/or
  feeding values of user-defined types in a call to tf.Session.run().

  An example

  ```python
     class SquaredTensor(object):
       def __init__(self, tensor):
         self.sq = tf.square(tensor)
     # You can define conversion functions as follows:
     fetch_function = lambda squared_tensor:([squared_tensor.sq],
                                             lambda val: val[0])
     feed_function = lambda feed, feed_val: [(feed.sq, feed_val)]
     feed_function_for_partial_run = lambda feed: [feed.sq]
     # Then, after invoking this register function, you can use it as follows:
     session.run(squared_tensor1,
                 feed_dict = {squared_tensor2 : some_numpy_array})
  ```

  Args:
    tensor_type: The type for which you want to register a conversion function.
    fetch_function: A callable that takes an object of type `tensor_type` and
      returns a tuple, where the first element is a list of `tf.Tensor` objects,
      and the second element is a callable that takes a list of ndarrays and
      returns an object of some value type that corresponds to `tensor_type`.
      fetch_function describes how to expand fetch into its component Tensors
      and how to contract the fetched results back into a single return value.
    feed_function: A callable that takes feed_key and feed_value as input, and
      returns a list of tuples (feed_tensor, feed_val), feed_key must have type
      `tensor_type`, and feed_tensor must have type `tf.Tensor`. Each feed
      function describes how to unpack a single fed value and map it to feeds
      of one or more tensors and their corresponding values.
    feed_function_for_partial_run: A callable for specifying tensor values to
      feed when setting up a partial run, which takes a `tensor_type` type
      object as input, and returns a list of Tensors.

  Raises:
    ValueError: If `tensor_type`, or a subclass of it, has already been
      registered. Because `object` is always registered, registering
      `object` itself also raises.
  """
  for conversion_function in _REGISTERED_EXPANSIONS:
    if issubclass(conversion_function[0], tensor_type):
      # Wrap in a 1-tuple so a tuple-valued `tensor_type` (which `issubclass`
      # accepts) cannot break the %-formatting.
      raise ValueError('%s has already been registered so ignore it.' %
                       (tensor_type,))

  # Insert at the front so the new type is matched before the `object`
  # catch-all entry.
  _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
                                    feed_function_for_partial_run))


def _is_attrs_instance(obj):
  """Returns True if the given obj is an instance of attrs-decorated class."""
  return getattr(obj.__class__, '__attrs_attrs__', None) is not None


def _get_attrs_values(obj):
  """Returns the values of an attrs instance, ordered as `__attrs_attrs__`."""
  attrs = getattr(obj.__class__, '__attrs_attrs__')
  return [getattr(obj, a.name) for a in attrs]


# The collection ABCs moved to `collections.abc` in Python 3.3 and the aliases
# in `collections` were removed in Python 3.10, so `collections.Mapping` no
# longer exists there. Fall back to `collections` itself on Python 2.
try:
  from collections import abc as _collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:  # Python 2.
  _collections_abc = collections


class _FetchMapper(object):
  """Definition of the interface provided by fetch mappers.

  Fetch mappers are utility classes used by the _FetchHandler to handle
  arbitrary structures for the `fetch` argument to `Session.run()`.

  The `fetch` argument can be of various shapes: single tensor or op, list of
  fetches, tuple of fetches, namedtuple of fetches, dict of fetches, or an
  instance of an attrs-decorated class. The structures can be arbitrarily
  nested.

  The low level run() API only wants a list of tensor or op names. The various
  `_FetchMapper` subclasses below take care of handling the different shapes:
  uniquifying the fetches, and constructing results with the original shape.
  """

  def unique_fetches(self):
    """Return the list of unique tensors or ops needed by this fetch mapper.

    Returns:
      A list of tensors or ops.
    """
    raise NotImplementedError('Must be implemented by subclasses')

  def build_results(self, values):
    """Build results that match the original shape of the fetch.

    Args:
      values: List of values returned by run(). The values correspond
        exactly to the list of tensors or ops returned by unique_fetches().

    Returns:
      A struct of the same shape as the original fetch object handled by
      this fetch mapper.  In the returned struct, the original fetches are
      replaced by their fetched values.
    """
    raise NotImplementedError('Must be implemented by subclasses')

  @staticmethod
  def for_fetch(fetch):
    """Creates fetch mapper that handles the structure of `fetch`.

    The default graph must be the one from which we want to fetch values when
    this function is called.

    Args:
      fetch: An arbitrary fetch structure: singleton, list, tuple,
        namedtuple, dict (any `Mapping`), or attrs-decorated instance.

    Returns:
      An instance of a subclass of `_FetchMapper` that handles the shape.

    Raises:
      TypeError: If `fetch` is None or no registered expansion matches it.
    """
    if fetch is None:
      raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
                                                                 type(fetch)))
    elif isinstance(fetch, (list, tuple)):
      # NOTE(touts): This is also the code path for namedtuples.
      return _ListFetchMapper(fetch)
    elif isinstance(fetch, _collections_abc.Mapping):
      return _DictFetchMapper(fetch)
    elif _is_attrs_instance(fetch):
      return _AttrsFetchMapper(fetch)
    else:
      # Look for a handler in the registered expansions. The first match wins.
      for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
        if isinstance(fetch, tensor_type):
          fetches, contraction_fn = fetch_fn(fetch)
          return _ElementFetchMapper(fetches, contraction_fn)
    # Did not find anything.
    raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
                                                               type(fetch)))


class _ElementFetchMapper(_FetchMapper):
  """Fetch mapper for singleton tensors and ops."""

  def __init__(self, fetches, contraction_fn):
    """Creates an _ElementFetchMapper.

    This is the fetch mapper used for leaves in the fetch struct.  Because of
    the expansions mechanism, a leaf can actually fetch more than one tensor.

    Also note that the fetches here can be just strings (tensor or op names) or
    any other object that the graph knows how to convert to a tensor, such as a
    Variable.  So we have to run each fetch through `as_graph_element()` to get
    the corresponding tensor or op.

    Args:
      fetches: List of objects, as returned by a fetch_fn defined
        in _REGISTERED_EXPANSIONS.
      contraction_fn: Callable as returned by a fetch_fn.

    Raises:
      TypeError: If a fetch has a type the graph cannot interpret.
      ValueError: If a fetch cannot be resolved to a tensor or op in the
        default graph.
    """
    # The default graph cannot change while this loop runs, so look it up once.
    graph = ops.get_default_graph()
    self._unique_fetches = []
    for fetch in fetches:
      try:
        self._unique_fetches.append(graph.as_graph_element(
            fetch, allow_tensor=True, allow_operation=True))
      except TypeError as e:
        raise TypeError('Fetch argument %r has invalid type %r, '
                        'must be a string or Tensor. (%s)' %
                        (fetch, type(fetch), str(e)))
      except (ValueError, KeyError) as e:
        # A missing name surfaces as KeyError; report both as ValueError.
        raise ValueError('Fetch argument %r cannot be interpreted as a '
                         'Tensor. (%s)' % (fetch, str(e)))
    self._contraction_fn = contraction_fn

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    if not values:
      # Only reached when the fetch expanded to no graph elements at all.
      # A fetched Operation arrives here as `[None]` and goes through the
      # contraction function instead.
      return None
    else:
      return self._contraction_fn(values)


def _uniquify_fetches(fetch_mappers):
  """Uniquifies fetches from a list of fetch_mappers.

  This is a utility function used by _ListFetchMapper, _DictFetchMapper and
  _AttrsFetchMapper.  It gathers all the unique fetches from a list of mappers
  and builds a list containing all of them but without duplicates
  (unique_fetches).  Uniqueness is decided by dict-key equality, and the
  first-seen order is preserved.

  It also returns a 2-D list of integers (value_indices) indicating at which
  index in unique_fetches the fetches of the mappers are located.

  This list is as follows:
    value_indices[mapper_index][mapper_fetch_index] = unique_fetches_index

  Args:
    fetch_mappers: list of fetch mappers.

  Returns:
    A tuple `(unique_fetches, value_indices)`: the list of unique fetches and
    the 2-D list of integers described above.
  """
  unique_fetches = []
  value_indices = []
  seen_fetches = {}
  for m in fetch_mappers:
    m_value_indices = []
    for f in m.unique_fetches():
      j = seen_fetches.get(f)
      if j is None:
        j = len(seen_fetches)
        seen_fetches[f] = j
        unique_fetches.append(f)
      m_value_indices.append(j)
    value_indices.append(m_value_indices)
  return unique_fetches, value_indices


class _ListFetchMapper(_FetchMapper):
  """Fetch mapper for lists, tuples, and namedtuples."""

  def __init__(self, fetches):
    """Creates a _ListFetchMapper.

    Args:
      fetches: List, tuple, or namedtuple of fetches.
    """
    self._fetch_type = type(fetches)
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Create the list of results for each mapper.
    results = []
    for m, vi in zip(self._mappers, self._value_indices):
      results.append(m.build_results([values[j] for j in vi]))
    # Return a value of the original type of the fetches. Note that list
    # subclasses come back as plain lists.
    if issubclass(self._fetch_type, list):
      return results
    elif self._fetch_type == tuple:
      return tuple(results)
    else:
      # This is the code path for namedtuple.
      return self._fetch_type(*results)


class _DictFetchMapper(_FetchMapper):
  """Fetch mapper for dicts."""

  def __init__(self, fetches):
    """Creates a _DictFetchMapper.

    Args:
      fetches: Dict (or other mapping) of fetches.
    """
    self._fetch_type = type(fetches)
    if self._fetch_type is collections.defaultdict:
      # `defaultdict()` alone would silently drop the caller's factory.
      self._type_ctor = functools.partial(collections.defaultdict,
                                          fetches.default_factory)
    else:
      self._type_ctor = self._fetch_type
    # Snapshot the items once so keys and mappers stay aligned even if the
    # caller later mutates `fetches` (a live `keys()` view would not).
    items = list(fetches.items())
    self._keys = [key for key, _ in items]
    self._mappers = [_FetchMapper.for_fetch(fetch) for _, fetch in items]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    results = self._type_ctor()
    for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
      results[k] = m.build_results([values[j] for j in vi])
    return results


class _AttrsFetchMapper(_FetchMapper):
  """Fetch mapper for attrs decorated classes."""

  def __init__(self, fetches):
    """Creates a _AttrsFetchMapper.

    Args:
      fetches: An instance of an attrs decorated class.
    """
    values = _get_attrs_values(fetches)
    self._fetch_type = type(fetches)
    self._mappers = [
        _FetchMapper.for_fetch(fetch) for fetch in values
    ]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    results = []
    for m, vi in zip(self._mappers, self._value_indices):
      results.append(m.build_results([values[j] for j in vi]))
    # Rebuild the instance positionally, in `__attrs_attrs__` order.
    return self._fetch_type(*results)


class _FetchHandler(object):
  """Handler for structured fetches.

  Given a graph, a user-provided structure for fetches, and a feed dict, this
  class takes care of generating the list of tensors to fetch and the list of
  ops to run for a low level `run()` call.

  Given the results of the low level run call, this class can also rebuild a
  result structure matching the user-provided structure for fetches, but
  containing the corresponding results.
  """

  # TODO(touts): Make this class also take care of destructuring the feed
  # dict instead of doing it in the callers.
457 458 def __init__(self, graph, fetches, feeds, feed_handles=None): 459 """Creates a fetch handler. 460 461 Args: 462 graph: Graph of the fetches. Used to check for fetchability 463 and to convert all fetches to tensors or ops as needed. 464 fetches: An arbitrary fetch structure: singleton, list, tuple, 465 namedtuple, or dict. 466 feeds: A feed dict where keys are Tensors. 467 feed_handles: A dict from feed Tensors to TensorHandle objects used as 468 direct feeds. 469 """ 470 with graph.as_default(): 471 self._fetch_mapper = _FetchMapper.for_fetch(fetches) 472 self._fetches = [] 473 self._targets = [] 474 self._feeds = feeds 475 self._feed_handles = feed_handles or {} 476 self._ops = [] 477 self._fetch_handles = {} 478 for fetch in self._fetch_mapper.unique_fetches(): 479 if isinstance(fetch, ops.Operation): 480 self._assert_fetchable(graph, fetch) 481 self._targets.append(fetch) 482 self._ops.append(True) 483 else: 484 self._assert_fetchable(graph, fetch.op) 485 self._fetches.append(fetch) 486 self._ops.append(False) 487 # Remember the fetch if it is for a tensor handle. 488 if (isinstance(fetch, ops.Tensor) and 489 (fetch.op.type == 'GetSessionHandle' or 490 fetch.op.type == 'GetSessionHandleV2')): 491 self._fetch_handles[fetch] = fetch.op.inputs[0].dtype 492 self._final_fetches = [x for x in self._fetches if x not in feeds] 493 494 def _assert_fetchable(self, graph, op): 495 if not graph.is_fetchable(op): 496 raise ValueError( 497 'Operation %r has been marked as not fetchable.' % op.name) 498 499 def fetches(self): 500 """Return the unique names of tensors to fetch. 501 502 Returns: 503 A list of strings. 504 """ 505 return self._final_fetches 506 507 def targets(self): 508 """Return the unique names of ops to run. 509 510 Returns: 511 A list of strings. 512 """ 513 return self._targets 514 515 def build_results(self, session, tensor_values): 516 """Build results matching the original fetch shape. 
517 518 `tensor_values` must be a list of the same length as 519 the one returned by `fetches()`, and holding the requested 520 fetch values. 521 522 This method builds a struct with the same shape as the original `fetches` 523 passed to the constructor, in which the fetches are replaced by their 524 fetched value. 525 526 Args: 527 session: The enclosing session. Used for tensor handles. 528 tensor_values: List of values matching the list returned 529 by fetches(). 530 531 Returns: 532 A structure of the same shape as the original `fetches` argument but 533 containing tensors or None (for fetched ops). 534 """ 535 full_values = [] 536 assert len(self._final_fetches) == len(tensor_values) 537 i = 0 538 j = 0 539 for is_op in self._ops: 540 if is_op: 541 full_values.append(None) 542 else: 543 # If the fetch was in the feeds, use the fed value, otherwise 544 # use the returned value. 545 if self._fetches[i] in self._feed_handles: 546 # A fetch had a corresponding direct TensorHandle feed. Call eval() 547 # to obtain the Tensor value from the TensorHandle. 548 value = self._feed_handles[self._fetches[i]].eval() 549 else: 550 value = self._feeds.get(self._fetches[i]) 551 if value is None: 552 value = tensor_values[j] 553 j += 1 554 dtype = self._fetch_handles.get(self._fetches[i]) 555 if dtype: 556 full_values.append(session_ops.TensorHandle(value, dtype, session)) 557 else: 558 full_values.append(value) 559 i += 1 560 assert j == len(tensor_values) 561 return self._fetch_mapper.build_results(full_values) 562 563 564def _name_list(tensor_list): 565 """Utility function for transitioning to the new session API. 566 567 Args: 568 tensor_list: a list of `Tensor`s. 569 570 Returns: 571 A list of each `Tensor`s name (as byte arrays). 572 """ 573 return [compat.as_bytes(t.name) for t in tensor_list] 574 575 576class _DeviceAttributes(object): 577 """Struct-like object describing a device's attributes. 
578 579 Each device has 3 key properties: 580 - name: the fully-qualified TensorFlow path to the device. For 581 example: /job:worker/replica:0/task:3/device:CPU:0 582 - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.) 583 - memory_limit_bytes: the maximum amount of memory available on the device 584 (in bytes). 585 """ 586 587 def __init__(self, name, device_type, memory_limit_bytes, incarnation): 588 self._name = device.canonical_name(name) 589 self._device_type = device_type 590 self._memory_limit_bytes = memory_limit_bytes 591 self._incarnation = incarnation 592 593 @property 594 def name(self): 595 return self._name 596 597 @property 598 def device_type(self): 599 return self._device_type 600 601 @property 602 def memory_limit_bytes(self): 603 return self._memory_limit_bytes 604 605 @property 606 def incarnation(self): 607 return self._incarnation 608 609 def __repr__(self): 610 return '_DeviceAttributes(%s, %s, %d, %d)' % ( 611 self.name, 612 self.device_type, 613 self.memory_limit_bytes, 614 self.incarnation, 615 ) 616 617 618class BaseSession(SessionInterface): 619 """A class for interacting with a TensorFlow computation. 620 621 The BaseSession enables incremental graph building with inline 622 execution of Operations and evaluation of Tensors. 623 """ 624 625 def __init__(self, target='', graph=None, config=None): 626 """Constructs a new TensorFlow session. 627 628 Args: 629 target: (Optional) The TensorFlow execution engine to connect to. 630 graph: (Optional) The graph to be used. If this argument is None, 631 the default graph will be used. 632 config: (Optional) ConfigProto proto used to configure the session. 633 634 Raises: 635 tf.errors.OpError: Or one of its subclasses if an error occurs while 636 creating the TensorFlow session. 637 TypeError: If one of the arguments has the wrong type. 
638 """ 639 if graph is None: 640 self._graph = ops.get_default_graph() 641 else: 642 if not isinstance(graph, ops.Graph): 643 raise TypeError('graph must be a tf.Graph, but got %s' % type(graph)) 644 self._graph = graph 645 646 self._opened = False 647 self._closed = False 648 649 self._current_version = 0 650 self._extend_lock = threading.Lock() 651 if target is not None: 652 try: 653 self._target = compat.as_bytes(target) 654 except TypeError: 655 raise TypeError('target must be a string, but got %s' % type(target)) 656 else: 657 self._target = None 658 659 self._delete_lock = threading.Lock() 660 self._dead_handles = [] 661 662 if config is not None: 663 if not isinstance(config, config_pb2.ConfigProto): 664 raise TypeError( 665 'config must be a tf.ConfigProto, but got %s' % type(config)) 666 self._config = config 667 self._add_shapes = config.graph_options.infer_shapes 668 else: 669 self._config = None 670 self._add_shapes = False 671 672 self._session = None 673 opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) 674 try: 675 # pylint: disable=protected-access 676 self._session = tf_session.TF_NewSessionRef(self._graph._c_graph, opts) 677 # pylint: enable=protected-access 678 finally: 679 tf_session.TF_DeleteSessionOptions(opts) 680 681 def list_devices(self): 682 """Lists available devices in this session. 683 684 ```python 685 devices = sess.list_devices() 686 for d in devices: 687 print(d.name) 688 ``` 689 690 Each element in the list has the following properties: 691 - `name`: A string with the full name of the device. ex: 692 `/job:worker/replica:0/task:3/device:CPU:0` 693 - `device_type`: The type of the device (e.g. `CPU`, `GPU`, `TPU`.) 694 - `memory_limit`: The maximum amount of memory available on the device. 695 Note: depending on the device, it is possible the usable memory could 696 be substantially less. 697 Raises: 698 tf.errors.OpError: If it encounters an error (e.g. 
session is in an 699 invalid state, or network errors occur). 700 701 Returns: 702 A list of devices in the session. 703 """ 704 raw_device_list = tf_session.TF_SessionListDevices(self._session) 705 device_list = [] 706 size = tf_session.TF_DeviceListCount(raw_device_list) 707 for i in range(size): 708 name = tf_session.TF_DeviceListName(raw_device_list, i) 709 device_type = tf_session.TF_DeviceListType(raw_device_list, i) 710 memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i) 711 incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i) 712 device_list.append( 713 _DeviceAttributes(name, device_type, memory, incarnation)) 714 tf_session.TF_DeleteDeviceList(raw_device_list) 715 return device_list 716 717 def close(self): 718 """Closes this session. 719 720 Calling this method frees all resources associated with the session. 721 722 Raises: 723 tf.errors.OpError: Or one of its subclasses if an error occurs while 724 closing the TensorFlow session. 725 """ 726 if self._session and not self._closed: 727 self._closed = True 728 tf_session.TF_CloseSession(self._session) 729 730 def __del__(self): 731 # cleanly ignore all exceptions 732 try: 733 self.close() 734 except Exception: # pylint: disable=broad-except 735 pass 736 if self._session is not None: 737 try: 738 tf_session.TF_DeleteSession(self._session) 739 except (AttributeError, TypeError): 740 # At shutdown, `c_api_util`, `tf_session`, or 741 # `tf_session.TF_DeleteSession` may have been garbage collected, causing 742 # the above method calls to fail. In this case, silently leak since the 743 # program is about to terminate anyway. 744 pass 745 self._session = None 746 747 @property 748 def graph(self): 749 """The graph that was launched in this session.""" 750 return self._graph 751 752 @property 753 def graph_def(self): 754 """A serializable version of the underlying TensorFlow graph. 
755 756 Returns: 757 A graph_pb2.GraphDef proto containing nodes for all of the Operations in 758 the underlying TensorFlow graph. 759 """ 760 return self._graph.as_graph_def(add_shapes=self._add_shapes) 761 762 @property 763 def sess_str(self): 764 return self._target 765 766 def as_default(self): 767 """Returns a context manager that makes this object the default session. 768 769 Use with the `with` keyword to specify that calls to 770 `tf.Operation.run` or `tf.Tensor.eval` should be executed in 771 this session. 772 773 ```python 774 c = tf.constant(..) 775 sess = tf.Session() 776 777 with sess.as_default(): 778 assert tf.get_default_session() is sess 779 print(c.eval()) 780 ``` 781 782 To get the current default session, use `tf.get_default_session`. 783 784 *N.B.* The `as_default` context manager *does not* close the 785 session when you exit the context, and you must close the session 786 explicitly. 787 788 ```python 789 c = tf.constant(...) 790 sess = tf.Session() 791 with sess.as_default(): 792 print(c.eval()) 793 # ... 794 with sess.as_default(): 795 print(c.eval()) 796 797 sess.close() 798 ``` 799 800 Alternatively, you can use `with tf.Session():` to create a 801 session that is automatically closed on exiting the context, 802 including when an uncaught exception is raised. 803 804 *N.B.* The default session is a property of the current thread. If you 805 create a new thread, and wish to use the default session in that 806 thread, you must explicitly add a `with sess.as_default():` in that 807 thread's function. 808 809 *N.B.* Entering a `with sess.as_default():` block does not affect 810 the current default graph. If you are using multiple graphs, and 811 `sess.graph` is different from the value of `tf.get_default_graph`, 812 you must explicitly enter a `with sess.graph.as_default():` block 813 to make `sess.graph` the default graph. 814 815 Returns: 816 A context manager using this session as the default session. 
817 """ 818 return ops.default_session(self) 819 820 def run(self, fetches, feed_dict=None, options=None, run_metadata=None): 821 """Runs operations and evaluates tensors in `fetches`. 822 823 This method runs one "step" of TensorFlow computation, by 824 running the necessary graph fragment to execute every `Operation` 825 and evaluate every `Tensor` in `fetches`, substituting the values in 826 `feed_dict` for the corresponding input values. 827 828 The `fetches` argument may be a single graph element, or an arbitrarily 829 nested list, tuple, namedtuple, dict, or OrderedDict containing graph 830 elements at its leaves. A graph element can be one of the following types: 831 832 * A `tf.Operation`. 833 The corresponding fetched value will be `None`. 834 * A `tf.Tensor`. 835 The corresponding fetched value will be a numpy ndarray containing the 836 value of that tensor. 837 * A `tf.SparseTensor`. 838 The corresponding fetched value will be a 839 `tf.SparseTensorValue` 840 containing the value of that sparse tensor. 841 * A `get_tensor_handle` op. The corresponding fetched value will be a 842 numpy ndarray containing the handle of that tensor. 843 * A `string` which is the name of a tensor or operation in the graph. 844 845 The value returned by `run()` has the same shape as the `fetches` argument, 846 where the leaves are replaced by the corresponding values returned by 847 TensorFlow. 848 849 Example: 850 851 ```python 852 a = tf.constant([10, 20]) 853 b = tf.constant([1.0, 2.0]) 854 # 'fetches' can be a singleton 855 v = session.run(a) 856 # v is the numpy array [10, 20] 857 # 'fetches' can be a list. 
858 v = session.run([a, b]) 859 # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the 860 # 1-D array [1.0, 2.0] 861 # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts: 862 MyData = collections.namedtuple('MyData', ['a', 'b']) 863 v = session.run({'k1': MyData(a, b), 'k2': [b, a]}) 864 # v is a dict with 865 # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and 866 # 'b' (the numpy array [1.0, 2.0]) 867 # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array 868 # [10, 20]. 869 ``` 870 871 The optional `feed_dict` argument allows the caller to override 872 the value of tensors in the graph. Each key in `feed_dict` can be 873 one of the following types: 874 875 * If the key is a `tf.Tensor`, the 876 value may be a Python scalar, string, list, or numpy ndarray 877 that can be converted to the same `dtype` as that 878 tensor. Additionally, if the key is a 879 `tf.placeholder`, the shape of 880 the value will be checked for compatibility with the placeholder. 881 * If the key is a 882 `tf.SparseTensor`, 883 the value should be a 884 `tf.SparseTensorValue`. 885 * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value 886 should be a nested tuple with the same structure that maps to their 887 corresponding values as above. 888 889 Each value in `feed_dict` must be convertible to a numpy array of the dtype 890 of the corresponding key. 891 892 The optional `options` argument expects a [`RunOptions`] proto. The options 893 allow controlling the behavior of this particular step (e.g. turning tracing 894 on). 895 896 The optional `run_metadata` argument expects a [`RunMetadata`] proto. When 897 appropriate, the non-Tensor output of this step will be collected there. For 898 example, when users turn on tracing in `options`, the profiled info will be 899 collected into this argument and passed back. 
900 901 Args: 902 fetches: A single graph element, a list of graph elements, 903 or a dictionary whose values are graph elements or lists of graph 904 elements (described above). 905 feed_dict: A dictionary that maps graph elements to values 906 (described above). 907 options: A [`RunOptions`] protocol buffer 908 run_metadata: A [`RunMetadata`] protocol buffer 909 910 Returns: 911 Either a single value if `fetches` is a single graph element, or 912 a list of values if `fetches` is a list, or a dictionary with the 913 same keys as `fetches` if that is a dictionary (described above). 914 Order in which `fetches` operations are evaluated inside the call 915 is undefined. 916 917 Raises: 918 RuntimeError: If this `Session` is in an invalid state (e.g. has been 919 closed). 920 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 921 ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a 922 `Tensor` that doesn't exist. 923 """ 924 options_ptr = tf_session.TF_NewBufferFromString( 925 compat.as_bytes(options.SerializeToString())) if options else None 926 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 927 928 try: 929 result = self._run(None, fetches, feed_dict, options_ptr, 930 run_metadata_ptr) 931 if run_metadata: 932 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 933 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 934 finally: 935 if run_metadata_ptr: 936 tf_session.TF_DeleteBuffer(run_metadata_ptr) 937 if options: 938 tf_session.TF_DeleteBuffer(options_ptr) 939 return result 940 941 def partial_run(self, handle, fetches, feed_dict=None): 942 """Continues the execution with more feeds and fetches. 943 944 This is EXPERIMENTAL and subject to change. 945 946 To use partial execution, a user first calls `partial_run_setup()` and 947 then a sequence of `partial_run()`. `partial_run_setup` specifies the 948 list of feeds and fetches that will be used in the subsequent 949 `partial_run` calls. 
950 951 The optional `feed_dict` argument allows the caller to override 952 the value of tensors in the graph. See run() for more information. 953 954 Below is a simple example: 955 956 ```python 957 a = array_ops.placeholder(dtypes.float32, shape=[]) 958 b = array_ops.placeholder(dtypes.float32, shape=[]) 959 c = array_ops.placeholder(dtypes.float32, shape=[]) 960 r1 = math_ops.add(a, b) 961 r2 = math_ops.multiply(r1, c) 962 963 h = sess.partial_run_setup([r1, r2], [a, b, c]) 964 res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) 965 res = sess.partial_run(h, r2, feed_dict={c: res}) 966 ``` 967 968 Args: 969 handle: A handle for a sequence of partial runs. 970 fetches: A single graph element, a list of graph elements, 971 or a dictionary whose values are graph elements or lists of graph 972 elements (see documentation for `run`). 973 feed_dict: A dictionary that maps graph elements to values 974 (described above). 975 976 Returns: 977 Either a single value if `fetches` is a single graph element, or 978 a list of values if `fetches` is a list, or a dictionary with the 979 same keys as `fetches` if that is a dictionary 980 (see documentation for `run`). 981 982 Raises: 983 tf.errors.OpError: Or one of its subclasses on error. 984 """ 985 # TODO(touts): Support feeding and fetching the same tensor. 986 return self._run(handle, fetches, feed_dict, None, None) 987 988 def partial_run_setup(self, fetches, feeds=None): 989 """Sets up a graph with feeds and fetches for partial run. 990 991 This is EXPERIMENTAL and subject to change. 992 993 Note that contrary to `run`, `feeds` only specifies the graph elements. 994 The tensors will be supplied by the subsequent `partial_run` calls. 995 996 Args: 997 fetches: A single graph element, or a list of graph elements. 998 feeds: A single graph element, or a list of graph elements. 999 1000 Returns: 1001 A handle for partial run. 1002 1003 Raises: 1004 RuntimeError: If this `Session` is in an invalid state (e.g. 
has been 1005 closed). 1006 TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. 1007 tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens. 1008 """ 1009 1010 def _feed_fn(feed): 1011 for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS: 1012 if isinstance(feed, tensor_type): 1013 return feed_fn(feed) 1014 raise TypeError('Feed argument %r has invalid type %r' % (feed, 1015 type(feed))) 1016 1017 # Check session. 1018 if self._closed: 1019 raise RuntimeError('Attempted to use a closed Session.') 1020 if self.graph.version == 0: 1021 raise RuntimeError('The Session graph is empty. Add operations to the ' 1022 'graph before calling run().') 1023 1024 if feeds is None: 1025 feeds = [] 1026 # Create request. 1027 feed_list = [] 1028 1029 # Validate and process feed_list. 1030 is_list_feed = isinstance(feeds, (list, tuple)) 1031 if not is_list_feed: 1032 feeds = [feeds] 1033 for feed in feeds: 1034 for subfeed in _feed_fn(feed): 1035 try: 1036 subfeed_t = self.graph.as_graph_element( 1037 subfeed, allow_tensor=True, allow_operation=False) 1038 # pylint: disable=protected-access 1039 feed_list.append(subfeed_t._as_tf_output()) 1040 # pylint: enable=protected-access 1041 except Exception as e: 1042 e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message) 1043 e.args = (e.message,) 1044 raise e 1045 1046 # Validate and process fetches. 1047 # TODO(touts): Support feeding and fetching the same tensor. 1048 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1049 1050 # Set up a graph with feeds and fetches for partial run. 
1051 def _setup_fn(session, feed_list, fetch_list, target_list): 1052 self._extend_graph() 1053 return tf_session.TF_SessionPRunSetup_wrapper( 1054 session, feed_list, fetch_list, target_list) 1055 1056 # pylint: disable=protected-access 1057 final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] 1058 final_targets = [op._c_op for op in fetch_handler.targets()] 1059 # pylint: enable=protected-access 1060 1061 return self._do_call(_setup_fn, self._session, feed_list, final_fetches, 1062 final_targets) 1063 1064 def _run(self, handle, fetches, feed_dict, options, run_metadata): 1065 """Perform either run or partial_run, depending the presence of `handle`.""" 1066 1067 def _feed_fn(feed, feed_val): 1068 for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS: 1069 if isinstance(feed, tensor_type): 1070 return feed_fn(feed, feed_val) 1071 raise TypeError('Feed argument %r has invalid type %r' % (feed, 1072 type(feed))) 1073 1074 # Check session. 1075 if self._closed: 1076 raise RuntimeError('Attempted to use a closed Session.') 1077 if self.graph.version == 0: 1078 raise RuntimeError('The Session graph is empty. Add operations to the ' 1079 'graph before calling run().') 1080 1081 # Create request. 1082 feed_dict_tensor = {} 1083 feed_map = {} 1084 1085 # Validate and process feed_dict. 1086 feed_handles = {} 1087 if feed_dict: 1088 feed_dict = nest.flatten_dict_items(feed_dict) 1089 for feed, feed_val in feed_dict.items(): 1090 for subfeed, subfeed_val in _feed_fn(feed, feed_val): 1091 try: 1092 subfeed_t = self.graph.as_graph_element( 1093 subfeed, allow_tensor=True, allow_operation=False) 1094 except Exception as e: 1095 raise TypeError( 1096 'Cannot interpret feed_dict key as Tensor: ' + e.args[0]) 1097 1098 if isinstance(subfeed_val, ops.Tensor): 1099 raise TypeError('The value of a feed cannot be a tf.Tensor object. ' 1100 'Acceptable feed values include Python scalars, ' 1101 'strings, lists, numpy ndarrays, or TensorHandles. 
' 1102 'For reference, the tensor object was ' + 1103 str(feed_val) + ' which was passed to the ' 1104 'feed with key ' + str(feed) + '.') 1105 1106 subfeed_dtype = subfeed_t.dtype.as_numpy_dtype 1107 if isinstance(subfeed_val, int) and _convert_to_numpy_obj( 1108 subfeed_dtype, subfeed_val) != subfeed_val: 1109 raise TypeError( 1110 'Type of feed value ' + str(subfeed_val) + ' with type ' + str( 1111 type(subfeed_val)) + 1112 ' is not compatible with Tensor type ' + str(subfeed_dtype) + 1113 '. Try explicitly setting the type of the feed tensor' 1114 ' to a larger type (e.g. int64).') 1115 1116 is_tensor_handle_feed = isinstance(subfeed_val, 1117 session_ops.TensorHandle) 1118 if is_tensor_handle_feed: 1119 np_val = subfeed_val.to_numpy_array() 1120 feed_handles[subfeed_t] = subfeed_val 1121 else: 1122 np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) 1123 1124 if (not is_tensor_handle_feed and 1125 not subfeed_t.get_shape().is_compatible_with(np_val.shape)): 1126 raise ValueError('Cannot feed value of shape %r for Tensor %r, ' 1127 'which has shape %r' % 1128 (np_val.shape, subfeed_t.name, 1129 str(subfeed_t.get_shape()))) 1130 if not self.graph.is_feedable(subfeed_t): 1131 raise ValueError('Tensor %s may not be fed.' % subfeed_t) 1132 1133 feed_dict_tensor[subfeed_t] = np_val 1134 feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val) 1135 1136 # Create a fetch handler to take care of the structure of fetches. 1137 fetch_handler = _FetchHandler( 1138 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles) 1139 1140 # Run request and get response. 1141 # We need to keep the returned movers alive for the following _do_run(). 1142 # These movers are no longer needed when _do_run() completes, and 1143 # are deleted when `movers` goes out of scope when this _run() ends. 1144 # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding 1145 # of a handle from a different device as an error. 
1146 _ = self._update_with_movers(feed_dict_tensor, feed_map) 1147 final_fetches = fetch_handler.fetches() 1148 final_targets = fetch_handler.targets() 1149 # We only want to really perform the run if fetches or targets are provided, 1150 # or if the call is a partial run that specifies feeds. 1151 if final_fetches or final_targets or (handle and feed_dict_tensor): 1152 results = self._do_run(handle, final_targets, final_fetches, 1153 feed_dict_tensor, options, run_metadata) 1154 else: 1155 results = [] 1156 return fetch_handler.build_results(self, results) 1157 1158 def make_callable(self, fetches, feed_list=None, accept_options=False): 1159 """Returns a Python callable that runs a particular step. 1160 1161 The returned callable will take `len(feed_list)` arguments whose types 1162 must be compatible feed values for the respective elements of `feed_list`. 1163 For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th 1164 argument to the returned callable must be a numpy ndarray (or something 1165 convertible to an ndarray) with matching element type and shape. See 1166 `tf.Session.run` for details of the allowable feed key and value types. 1167 1168 The returned callable will have the same return type as 1169 `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`, 1170 the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`, 1171 it will return `None`. 1172 1173 Args: 1174 fetches: A value or list of values to fetch. See `tf.Session.run` 1175 for details of the allowable fetch types. 1176 feed_list: (Optional.) A list of `feed_dict` keys. See 1177 `tf.Session.run` for details of the allowable feed key types. 1178 accept_options: (Optional.) 
If `True`, the returned `Callable` will be 1179 able to accept `tf.RunOptions` and `tf.RunMetadata` as optional 1180 keyword arguments `options` and `run_metadata`, respectively, with 1181 the same syntax and semantics as `tf.Session.run`, which is useful 1182 for certain use cases (profiling and debugging) but will result in 1183 measurable slowdown of the `Callable`'s performance. Default: `False`. 1184 1185 Returns: 1186 A function that when called will execute the step defined by 1187 `feed_list` and `fetches` in this session. 1188 1189 Raises: 1190 TypeError: If `fetches` or `feed_list` cannot be interpreted 1191 as arguments to `tf.Session.run`. 1192 """ 1193 if feed_list is not None: 1194 if not isinstance(feed_list, (list, tuple)): 1195 raise TypeError('`feed_list` must be a list or tuple.') 1196 # Delegate any non-empty feed lists to the existing `run()` logic. 1197 # TODO(mrry): Refactor the feed handling logic from 1198 # `Session._run()` so that we can convert the feeds to a list of 1199 # strings here. 1200 def _generic_run(*feed_args, **kwargs): 1201 feed_dict = { 1202 feed: feed_val 1203 for feed, feed_val in zip(feed_list, feed_args) 1204 } 1205 return self.run(fetches, feed_dict=feed_dict, **kwargs) 1206 1207 return _generic_run 1208 1209 # Ensure any changes to the graph are reflected in the runtime. 1210 # Note that we don't need to do this on subsequent calls to the 1211 # returned object, because the arguments to `fetches` must already be 1212 # in the graph. 1213 self._extend_graph() 1214 1215 # Create a fetch handler to take care of the structure of fetches. 
1216 fetch_handler = _FetchHandler(self._graph, fetches, {}) 1217 # pylint: disable=protected-access 1218 fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] 1219 target_list = [op._c_op for op in fetch_handler.targets()] 1220 # pylint: enable=protected-access 1221 1222 def _callable_template_with_options_and_metadata(fetch_list, 1223 target_list, 1224 fetch_handler, 1225 options=None, 1226 run_metadata=None): 1227 """Template callable that accepts RunOptions and RunMetadata.""" 1228 options_ptr = tf_session.TF_NewBufferFromString( 1229 compat.as_bytes(options.SerializeToString())) if options else None 1230 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 1231 try: 1232 results = self._call_tf_sessionrun( 1233 options_ptr, {}, fetch_list, target_list, run_metadata_ptr) 1234 if fetch_handler: 1235 results = fetch_handler.build_results(self, results) 1236 else: 1237 results = results[0] if results else None 1238 if run_metadata: 1239 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 1240 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1241 finally: 1242 if run_metadata_ptr: 1243 tf_session.TF_DeleteBuffer(run_metadata_ptr) 1244 if options: 1245 tf_session.TF_DeleteBuffer(options_ptr) 1246 return results 1247 1248 if accept_options: 1249 return functools.partial(_callable_template_with_options_and_metadata, 1250 fetch_list, target_list, fetch_handler) 1251 elif isinstance(fetches, ops.Operation): 1252 # Special case for fetching a single operation, because the 1253 # function will have no return value. 1254 assert not fetch_list 1255 assert len(target_list) == 1 1256 1257 def _single_operation_run(): 1258 self._call_tf_sessionrun(None, {}, [], target_list, None) 1259 1260 return _single_operation_run 1261 elif isinstance(fetches, ops.Tensor): 1262 # Special case for fetching a single tensor, because the 1263 # function can return the result of `TF_Run()` directly. 
1264 assert len(fetch_list) == 1 1265 assert not target_list 1266 1267 def _single_tensor_run(): 1268 results = self._call_tf_sessionrun(None, {}, fetch_list, [], None) 1269 return results[0] 1270 1271 return _single_tensor_run 1272 else: 1273 # In all other cases, we must use `fetch_handler` to build the 1274 # results for us. 1275 def _fetch_handler_run(): 1276 results = self._call_tf_sessionrun( 1277 None, {}, fetch_list, target_list, None) 1278 return fetch_handler.build_results(self, results) 1279 1280 return _fetch_handler_run 1281 1282 # Captures the name of a node in an error status. The regex below matches 1283 # both the old and the new formats: 1284 # Old format: [[Node: <node_name> = ...]] 1285 # New format: [[{{node <node_name>}} = ...]] 1286 _NODEDEF_NAME_RE = re.compile( 1287 r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*') 1288 1289 def _do_run(self, handle, target_list, fetch_list, feed_dict, options, 1290 run_metadata): 1291 """Runs a step based on the given fetches and feeds. 1292 1293 Args: 1294 handle: a handle for partial_run. None if this is just a call to run(). 1295 target_list: A list of operations to be run, but not fetched. 1296 fetch_list: A list of tensors to be fetched. 1297 feed_dict: A dictionary that maps tensors to numpy ndarrays. 1298 options: A (pointer to a) [`RunOptions`] protocol buffer, or None 1299 run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None 1300 1301 Returns: 1302 A list of numpy ndarrays, corresponding to the elements of 1303 `fetch_list`. If the ith element of `fetch_list` contains the 1304 name of an operation, the first Tensor output of that operation 1305 will be returned for that element. 1306 1307 Raises: 1308 tf.errors.OpError: Or one of its subclasses on error. 
1309 """ 1310 # pylint: disable=protected-access 1311 feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items()) 1312 fetches = [t._as_tf_output() for t in fetch_list] 1313 targets = [op._c_op for op in target_list] 1314 # pylint: enable=protected-access 1315 1316 def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata): 1317 # Ensure any changes to the graph are reflected in the runtime. 1318 self._extend_graph() 1319 return self._call_tf_sessionrun( 1320 options, feed_dict, fetch_list, target_list, run_metadata) 1321 1322 def _prun_fn(handle, feed_dict, fetch_list): 1323 if target_list: 1324 raise RuntimeError('partial_run() requires empty target_list.') 1325 return self._call_tf_sessionprun(handle, feed_dict, fetch_list) 1326 1327 if handle is None: 1328 return self._do_call(_run_fn, feeds, fetches, targets, options, 1329 run_metadata) 1330 else: 1331 return self._do_call(_prun_fn, handle, feeds, fetches) 1332 1333 def _do_call(self, fn, *args): 1334 try: 1335 return fn(*args) 1336 except errors.OpError as e: 1337 message = compat.as_text(e.message) 1338 m = BaseSession._NODEDEF_NAME_RE.search(message) 1339 node_def = None 1340 op = None 1341 if m is not None: 1342 node_name = m.group(3) 1343 try: 1344 op = self._graph.get_operation_by_name(node_name) 1345 node_def = op.node_def 1346 except KeyError: 1347 pass 1348 message = error_interpolation.interpolate(message, self._graph) 1349 raise type(e)(node_def, op, message) 1350 1351 def _extend_graph(self): 1352 with self._graph._session_run_lock(): # pylint: disable=protected-access 1353 tf_session.ExtendSession(self._session) 1354 1355 # The threshold to run garbage collection to delete dead tensors. 1356 _DEAD_HANDLES_THRESHOLD = 10 1357 1358 def _register_dead_handle(self, handle): 1359 # Register a dead handle in the session. Delete the dead tensors when 1360 # the number of dead tensors exceeds certain threshold. 
1361 tensors_to_delete = None 1362 with self._delete_lock: 1363 self._dead_handles.append(handle) 1364 if len(self._dead_handles) == BaseSession._DEAD_HANDLES_THRESHOLD: 1365 tensors_to_delete = self._dead_handles 1366 self._dead_handles = [] 1367 # Delete the dead tensors. 1368 if tensors_to_delete: 1369 feeds = {} 1370 fetches = [] 1371 for deleter_key, tensor_handle in enumerate(tensors_to_delete): 1372 holder, deleter = session_ops._get_handle_deleter( 1373 self.graph, deleter_key, tensor_handle) 1374 feeds[holder] = tensor_handle 1375 fetches.append(deleter) 1376 self.run(fetches, feed_dict=feeds) 1377 1378 def _update_with_movers(self, feed_dict, feed_map): 1379 # If a tensor handle that is fed to a device incompatible placeholder, 1380 # we move the tensor to the right device, generate a new tensor handle, 1381 # and update `feed_dict` to use the new handle. 1382 handle_movers = [] 1383 for feed_name, val in feed_map.items(): 1384 mover = session_ops._get_handle_mover(self.graph, *val) 1385 if mover: 1386 handle_movers.append((feed_name, val[1], mover)) 1387 # Transfer a tensor to the right device if needed. 
1388 if not handle_movers: 1389 return [] 1390 else: 1391 feeds = {} 1392 fetches = [] 1393 for _, handle, mover in handle_movers: 1394 feeds[mover[0]] = handle 1395 fetches.append(mover[1]) 1396 handles = self.run(fetches, feed_dict=feeds) 1397 for handle_mover, handle in zip(handle_movers, handles): 1398 np_val = np.array(handle.handle, dtype=np.object) 1399 feed_name = handle_mover[0] 1400 feed_tensor = feed_map[feed_name][0] 1401 feed_dict[feed_tensor] = np_val 1402 return handles 1403 1404 def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, 1405 run_metadata): 1406 return tf_session.TF_SessionRun_wrapper( 1407 self._session, options, feed_dict, fetch_list, target_list, 1408 run_metadata) 1409 1410 def _call_tf_sessionprun(self, handle, feed_dict, fetch_list): 1411 return tf_session.TF_SessionPRun_wrapper( 1412 self._session, handle, feed_dict, fetch_list) 1413 1414 # pylint: disable=protected-access 1415 class _Callable(object): 1416 """Experimental wrapper for the C++ `Session::MakeCallable()` API.""" 1417 1418 def __init__(self, session, callable_options): 1419 self._session = session 1420 self._handle = None 1421 options_ptr = tf_session.TF_NewBufferFromString( 1422 compat.as_bytes(callable_options.SerializeToString())) 1423 try: 1424 with errors.raise_exception_on_not_ok_status() as status: 1425 self._handle = tf_session.TF_SessionMakeCallable( 1426 session._session, options_ptr, status) 1427 finally: 1428 tf_session.TF_DeleteBuffer(options_ptr) 1429 1430 def __call__(self, *args, **kwargs): 1431 # TODO(b/74355905): Support argument and return value nested structures, 1432 # and tensor-like objects such as SparseTensors. 1433 run_metadata = kwargs.get('run_metadata', None) 1434 try: 1435 run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None 1436 # TODO(mrry): Switch to raising an exception from the SWIG wrapper. 
1437 with errors.raise_exception_on_not_ok_status() as status: 1438 ret = tf_session.TF_SessionRunCallable( 1439 self._session._session, self._handle, args, status, 1440 run_metadata_ptr) 1441 if run_metadata: 1442 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) 1443 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 1444 finally: 1445 if run_metadata_ptr: 1446 tf_session.TF_DeleteBuffer(run_metadata_ptr) 1447 return ret 1448 1449 def __del__(self): 1450 # NOTE(mrry): It is possible that `self._session.__del__()` could be 1451 # called before this destructor, in which case `self._session._session` 1452 # will be `None`. 1453 if self._handle is not None and self._session._session is not None: 1454 with errors.raise_exception_on_not_ok_status() as status: 1455 tf_session.TF_SessionReleaseCallable( 1456 self._session._session, self._handle, status) 1457 # pylint: enable=protected-access 1458 1459 # TODO(b/74355905): Reimplement `Session.make_callable()` using this method 1460 # where possible. 1461 def _make_callable_from_options(self, callable_options): 1462 """Returns a handle to a "callable" with the given options. 1463 1464 Args: 1465 callable_options: A `CallableOptions` protocol buffer message describing 1466 the computation that will be performed by the callable. 1467 1468 Returns: 1469 A handle to the new callable. 1470 """ 1471 self._extend_graph() 1472 return BaseSession._Callable(self, callable_options) 1473 1474 1475@tf_export(v1=['Session']) 1476class Session(BaseSession): 1477 """A class for running TensorFlow operations. 1478 1479 A `Session` object encapsulates the environment in which `Operation` 1480 objects are executed, and `Tensor` objects are evaluated. For 1481 example: 1482 1483 ```python 1484 # Build a graph. 1485 a = tf.constant(5.0) 1486 b = tf.constant(6.0) 1487 c = a * b 1488 1489 # Launch the graph in a session. 1490 sess = tf.Session() 1491 1492 # Evaluate the tensor `c`. 
1493 print(sess.run(c)) 1494 ``` 1495 1496 A session may own resources, such as 1497 `tf.Variable`, `tf.QueueBase`, 1498 and `tf.ReaderBase`. It is important to release 1499 these resources when they are no longer required. To do this, either 1500 invoke the `tf.Session.close` method on the session, or use 1501 the session as a context manager. The following two examples are 1502 equivalent: 1503 1504 ```python 1505 # Using the `close()` method. 1506 sess = tf.Session() 1507 sess.run(...) 1508 sess.close() 1509 1510 # Using the context manager. 1511 with tf.Session() as sess: 1512 sess.run(...) 1513 ``` 1514 1515 The 1516 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1517 protocol buffer exposes various configuration options for a 1518 session. For example, to create a session that uses soft constraints 1519 for device placement, and log the resulting placement decisions, 1520 create a session as follows: 1521 1522 ```python 1523 # Launch the graph in a session that allows soft device placement and 1524 # logs the placement decisions. 1525 sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, 1526 log_device_placement=True)) 1527 ``` 1528 """ 1529 1530 def __init__(self, target='', graph=None, config=None): 1531 """Creates a new TensorFlow session. 1532 1533 If no `graph` argument is specified when constructing the session, 1534 the default graph will be launched in the session. If you are 1535 using more than one graph (created with `tf.Graph()`) in the same 1536 process, you will have to use different sessions for each graph, 1537 but each graph can be used in multiple sessions. In this case, it 1538 is often clearer to pass the graph to be launched explicitly to 1539 the session constructor. 1540 1541 Args: 1542 target: (Optional.) The execution engine to connect to. 1543 Defaults to using an in-process engine. 
See 1544 [Distributed TensorFlow](https://tensorflow.org/deploy/distributed) 1545 for more examples. 1546 graph: (Optional.) The `Graph` to be launched (described above). 1547 config: (Optional.) A 1548 [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto) 1549 protocol buffer with configuration options for the session. 1550 1551 """ 1552 super(Session, self).__init__(target, graph, config=config) 1553 # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle. 1554 self._default_graph_context_manager = None 1555 self._default_session_context_manager = None 1556 1557 def __enter__(self): 1558 if self._default_graph_context_manager is None: 1559 self._default_graph_context_manager = self.graph.as_default() 1560 else: 1561 raise RuntimeError('Session context managers are not re-entrant. ' 1562 'Use `Session.as_default()` if you want to enter ' 1563 'a session multiple times.') 1564 if self._default_session_context_manager is None: 1565 self._default_session_context_manager = self.as_default() 1566 self._default_graph_context_manager.__enter__() 1567 return self._default_session_context_manager.__enter__() 1568 1569 def __exit__(self, exec_type, exec_value, exec_tb): 1570 if exec_type is errors.OpError: 1571 logging.error('Session closing due to OpError: %s', (exec_value,)) 1572 try: 1573 self._default_session_context_manager.__exit__(exec_type, exec_value, 1574 exec_tb) 1575 except RuntimeError as error: 1576 if error == exec_value: 1577 # NOTE(skyewm): for some reason, in Python3, 1578 # _default_session_context_manager.__exit__ will re-raise the "not 1579 # re-entrant" exception raised in __enter__ above (note that if we're 1580 # here, we're in the outer session context manager, since __exit__ is 1581 # not called when __enter__ raises an exception). 
We still want to 1582 # continue cleaning up this context manager before the exception is 1583 # further propagated, so we ignore it here (note that it'll continue 1584 # being propagated after this method completes). 1585 pass 1586 else: 1587 raise 1588 self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb) 1589 1590 self._default_session_context_manager = None 1591 self._default_graph_context_manager = None 1592 1593 # If we are closing due to an exception, set a time limit on our Close() to 1594 # avoid blocking forever. 1595 # TODO(b/120204635) remove this when deadlock is fixed. 1596 if exec_type: 1597 close_thread = threading.Thread( 1598 name='SessionCloseThread', target=self.close) 1599 close_thread.daemon = True 1600 close_thread.start() 1601 close_thread.join(30.0) 1602 if close_thread.is_alive(): 1603 logging.error( 1604 'Session failed to close after 30 seconds. Continuing after this ' 1605 'point may leave your program in an undefined state.') 1606 else: 1607 self.close() 1608 1609 @staticmethod 1610 def reset(target, containers=None, config=None): 1611 """Resets resource containers on `target`, and close all connected sessions. 1612 1613 A resource container is distributed across all workers in the 1614 same cluster as `target`. When a resource container on `target` 1615 is reset, resources associated with that container will be cleared. 1616 In particular, all Variables in the container will become undefined: 1617 they lose their values and shapes. 1618 1619 NOTE: 1620 (i) reset() is currently only implemented for distributed sessions. 1621 (ii) Any sessions on the master named by `target` will be closed. 1622 1623 If no resource containers are provided, all containers are reset. 1624 1625 Args: 1626 target: The execution engine to connect to. 1627 containers: A list of resource container name strings, or `None` if all of 1628 all the containers are to be reset. 1629 config: (Optional.) 
Protocol buffer with configuration options. 1630 1631 Raises: 1632 tf.errors.OpError: Or one of its subclasses if an error occurs while 1633 resetting containers. 1634 """ 1635 if target is not None: 1636 target = compat.as_bytes(target) 1637 if containers is not None: 1638 containers = [compat.as_bytes(c) for c in containers] 1639 else: 1640 containers = [] 1641 tf_session.TF_Reset(target, containers, config) 1642 1643 1644@tf_export(v1=['InteractiveSession']) 1645class InteractiveSession(BaseSession): 1646 """A TensorFlow `Session` for use in interactive contexts, such as a shell. 1647 1648 The only difference with a regular `Session` is that an `InteractiveSession` 1649 installs itself as the default session on construction. 1650 The methods `tf.Tensor.eval` 1651 and `tf.Operation.run` 1652 will use that session to run ops. 1653 1654 This is convenient in interactive shells and [IPython 1655 notebooks](http://ipython.org), as it avoids having to pass an explicit 1656 `Session` object to run ops. 1657 1658 For example: 1659 1660 ```python 1661 sess = tf.InteractiveSession() 1662 a = tf.constant(5.0) 1663 b = tf.constant(6.0) 1664 c = a * b 1665 # We can just use 'c.eval()' without passing 'sess' 1666 print(c.eval()) 1667 sess.close() 1668 ``` 1669 1670 Note that a regular session installs itself as the default session when it 1671 is created in a `with` statement. The common usage in non-interactive 1672 programs is to follow that pattern: 1673 1674 ```python 1675 a = tf.constant(5.0) 1676 b = tf.constant(6.0) 1677 c = a * b 1678 with tf.Session(): 1679 # We can also use 'c.eval()' here. 1680 print(c.eval()) 1681 ``` 1682 """ 1683 1684 _count_lock = threading.Lock() 1685 _active_session_count = 0 # GUARDED_BY(_count_lock) 1686 1687 def __init__(self, target='', graph=None, config=None): 1688 """Creates a new interactive TensorFlow session. 
1689 1690 If no `graph` argument is specified when constructing the session, 1691 the default graph will be launched in the session. If you are 1692 using more than one graph (created with `tf.Graph()`) in the same 1693 process, you will have to use different sessions for each graph, 1694 but each graph can be used in multiple sessions. In this case, it 1695 is often clearer to pass the graph to be launched explicitly to 1696 the session constructor. 1697 1698 Args: 1699 target: (Optional.) The execution engine to connect to. 1700 Defaults to using an in-process engine. 1701 graph: (Optional.) The `Graph` to be launched (described above). 1702 config: (Optional) `ConfigProto` proto used to configure the session. 1703 """ 1704 if not config: 1705 # If config is not provided, choose some reasonable defaults for 1706 # interactive use: 1707 # 1708 # - Grow GPU memory as needed at the cost of fragmentation. 1709 gpu_options = config_pb2.GPUOptions(allow_growth=True) 1710 config = config_pb2.ConfigProto(gpu_options=gpu_options) 1711 # Interactive sessions always place pruned graphs. 1712 config.graph_options.place_pruned_graph = True 1713 1714 super(InteractiveSession, self).__init__(target, graph, config) 1715 with InteractiveSession._count_lock: 1716 if InteractiveSession._active_session_count > 0: 1717 warnings.warn('An interactive session is already active. This can ' 1718 'cause out-of-memory errors in some cases. You must ' 1719 'explicitly call `InteractiveSession.close()` to release ' 1720 'resources held by the other session(s).') 1721 InteractiveSession._active_session_count += 1 1722 # NOTE(mrry): We do not use `Session._closed` here because it has unhelpful 1723 # semantics (in particular, it is not set to true if `Session.close()` is 1724 # called on a session that has not been "opened" by running a step) and we 1725 # cannot change those semantics without breaking existing code. 
1726 self._explicitly_closed = False 1727 1728 self._default_session = self.as_default() 1729 self._default_session.enforce_nesting = False 1730 self._default_session.__enter__() 1731 self._explicit_graph = graph 1732 if self._explicit_graph is not None: 1733 self._default_graph = graph.as_default() 1734 self._default_graph.enforce_nesting = False 1735 self._default_graph.__enter__() 1736 1737 def close(self): 1738 """Closes an `InteractiveSession`.""" 1739 super(InteractiveSession, self).close() 1740 with InteractiveSession._count_lock: 1741 if not self._explicitly_closed: 1742 InteractiveSession._active_session_count -= 1 1743 self._explicitly_closed = True 1744 else: 1745 return 1746 if self._explicit_graph is not None: 1747 self._default_graph.__exit__(None, None, None) 1748 self._default_graph = None 1749 self._default_session.__exit__(None, None, None) 1750 self._default_session = None 1751