• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Tensor Handle Operations. See the @{$python/session_ops} guide.
17
18@@get_session_handle
19@@get_session_handle_v2
20@@get_session_tensor
21@@delete_session_tensor
22"""
23
24# pylint: disable=g-bad-name
25from __future__ import absolute_import
26from __future__ import division
27from __future__ import print_function
28
29import numpy as np
30
31from tensorflow.core.framework import resource_handle_pb2
32from tensorflow.python import pywrap_tensorflow_internal
33from tensorflow.python.framework import device as pydev
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import ops
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import gen_data_flow_ops
38from tensorflow.python.util import compat
39from tensorflow.python.util.tf_export import tf_export
40
41
def encode_resource_handle(resource_handle):
  """Serialize `resource_handle` into a numpy value of the resource dtype.

  Args:
    resource_handle: A `ResourceHandleProto` message (anything exposing
      `SerializeToString()`).

  Returns:
    A numpy array with dtype `dtypes.np_resource` built from the bytes of the
    serialized message.
  """
  serialized = resource_handle.SerializeToString()
  return np.asarray(bytearray(serialized), dtype=dtypes.np_resource)
46
47
class TensorHandle(object):
  """Represents a handle for a live tensor in a session."""

  def __init__(self, handle, dtype, session):
    """Constructs a new tensor handle.

    A tensor handle for a persistent tensor is a python string
    that has the form of "tensor_name;unique_id;device_name".

    Args:
      handle: A tensor handle; normalized with `compat.as_str_any`.
      dtype: The data type of the tensor represented by `handle`.
      session: The session in which the tensor is produced.
    """
    self._handle = compat.as_str_any(handle)
    # Built on first use (and then cached) by `_get_resource_handle`.
    self._resource_handle = None
    self._dtype = dtype
    self._session = session
    # While True, finalizing this object registers the handle with the
    # session as dead. `delete()` and `get_raw_handle()` switch it off.
    self._auto_gc_enabled = True

  def __del__(self):
    # `getattr` keeps the finalizer from raising on a partially constructed
    # object: if `__init__` failed before `_auto_gc_enabled` was assigned,
    # there is no handle to release.
    if getattr(self, "_auto_gc_enabled", False):
      self._session._register_dead_handle(self.handle)  # pylint: disable=protected-access

  def __str__(self):
    return self._handle

  def _check_not_deleted(self):
    """Raise TypeError once `delete()` or `get_raw_handle()` disabled auto-GC."""
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)

  def _get_resource_handle(self):
    """The ResourceHandle representation of this handle.

    The proto is built on first use and cached on the instance.
    """
    if self._resource_handle is None:
      proto = resource_handle_pb2.ResourceHandleProto()
      # The last ";"-separated field of the handle is the device name.
      proto.device = self._handle.split(";")[-1]
      proto.container = pywrap_tensorflow_internal.TENSOR_HANDLE_KEY
      proto.name = self._handle
      # Cache only the fully populated proto.
      self._resource_handle = proto
    return self._resource_handle

  def to_numpy_array(self):
    """Convert a TensorHandle object to a feedable numpy value.

    Returns:
      A numpy array of a custom struct type that can be used as a feed value
      to run().
    """
    return encode_resource_handle(self._get_resource_handle())

  @property
  def handle(self):
    """The string representation of this handle."""
    return self._handle

  def eval(self):
    """Return the value of the tensor represented by this handle.

    Raises:
      TypeError: if `delete()` or `get_raw_handle()` was already called.
    """
    self._check_not_deleted()
    holder, reader = _get_handle_reader(self._session.graph, self._handle,
                                        self._dtype)
    return self._session.run(reader, feed_dict={holder: self._handle})

  def delete(self):
    """Force the deletion of this persistent tensor.

    Raises:
      TypeError: if `delete()` or `get_raw_handle()` was already called.
    """
    self._check_not_deleted()
    # Disable auto-GC first so `__del__` does not register the handle again.
    self._auto_gc_enabled = False
    holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
    self._session.run(deleter, feed_dict={holder: self.handle})

  def get_raw_handle(self):
    """Return the raw handle of the tensor.

    Note that the method disables the automatic garbage collection of this
    persistent tensor. The caller is now responsible for managing the life
    time of the tensor.
    """
    self._auto_gc_enabled = False
    return self._handle

  @staticmethod
  def _get_device_name(handle):
    """The canonical device name encoded in the last field of `handle`."""
    handle_str = compat.as_str_any(handle)
    return pydev.canonical_name(handle_str.split(";")[-1])

  @staticmethod
  def _get_reader_key(handle):
    """The graph key for reader: first and last fields of `handle`."""
    # Normalize exactly like `_get_device_name`. Plain `str()` on a `bytes`
    # handle yields "b'...'" under Python 3, so the same tensor would map to
    # different reader keys depending on whether it arrived as str or bytes.
    handle_parts = compat.as_str_any(handle).split(";")
    return handle_parts[0] + ";" + handle_parts[-1]

  @staticmethod
  def _get_mover_key(feeder, handle):
    """The graph key for mover: feeder op name plus the reader key."""
    return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
143
144
@tf_export("get_session_handle")
def get_session_handle(data, name=None):
  """Return the handle of `data`.

  This is EXPERIMENTAL and subject to change.

  The value of `data` is kept "in-place" in the runtime. The returned handle
  can be used to retrieve that value in a later run().

  Used together with `get_session_tensor`, this lets a tensor produced by one
  run call stay in place and serve as an input to a future run call.

  Args:
    data: A tensor to be stored in the session.
    name: Optional name prefix for the return tensor.

  Returns:
    A scalar string tensor representing a unique handle for `data`.

  Raises:
    TypeError: if `data` is not a Tensor.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  if isinstance(data, ops.Tensor):
    # Place the handle op together with the op that produces `data`.
    with ops.colocate_with(data):
      return gen_data_flow_ops._get_session_handle(data, name=name)  # pylint: disable=protected-access
  raise TypeError("`data` must be of type Tensor.")
186
187
@tf_export("get_session_tensor")
def get_session_tensor(handle, dtype, name=None):
  """Get the tensor of type `dtype` by feeding a tensor handle.

  This is EXPERIMENTAL and subject to change.

  Reads back a tensor that a previous run() produced and stored in the
  session state. The ops are placed on the device encoded in `handle`.

  Args:
    handle: The string representation of a persistent tensor handle.
    dtype: The type of the output tensor.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of tensors. The first is a placeholder for feeding a tensor
    handle. The second is the tensor in the session state keyed by that
    handle.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  device_name = TensorHandle._get_device_name(handle)  # pylint: disable=protected-access
  with ops.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    # Record that this placeholder is fed handles of `dtype` so later feeds
    # can be recognized as handle feeds.
    _register_handle_feeder(feed_placeholder.graph, feed_placeholder, dtype)
    value = gen_data_flow_ops._get_session_tensor(  # pylint: disable=protected-access
        feed_placeholder, dtype, name=name)
  return feed_placeholder, value
227
228
@tf_export("delete_session_tensor")
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Removes a tensor that a previous run() produced and stored in the session
  state. The ops are placed on the device encoded in `handle`.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of graph elements. The first is a placeholder for feeding a
    tensor handle. The second is a deletion operation.
  """
  device_name = TensorHandle._get_device_name(handle)  # pylint: disable=protected-access
  with ops.device(device_name):
    feed_placeholder = array_ops.placeholder(dtypes.string)
    delete_op = gen_data_flow_ops._delete_session_tensor(  # pylint: disable=protected-access
        feed_placeholder, name=name)
  return feed_placeholder, delete_op
251
252
253def _register_handle_feeder(graph, feeder, dtype):
254  graph._handle_feeders[feeder.op.name] = dtype
255
256
257def _get_handle_feeder(graph, feeder):
258  return graph._handle_feeders.get(feeder.op.name)
259
260
def _get_handle_reader(graph, handle, dtype):
  """Return a read subgraph for this handle.

  Subgraphs are cached in `graph._handle_readers` under the reader key of
  `handle`, and a new one is built only on a cache miss.

  Returns:
    A `(placeholder, reader)` tuple. The placeholder takes the handle string
    and the reader yields the stored tensor of `dtype`.
  """
  graph_key = TensorHandle._get_reader_key(handle)  # pylint: disable=protected-access
  cached = graph._handle_readers.get(graph_key)  # pylint: disable=protected-access
  if cached is not None:
    return cached
  # Cache miss: build the read subgraph on the handle's device.
  handle_device = TensorHandle._get_device_name(handle)  # pylint: disable=protected-access
  with graph.as_default(), graph.device(handle_device):
    holder = array_ops.placeholder(dtypes.string)
    _register_handle_feeder(holder.graph, holder, dtype)
    reader = gen_data_flow_ops._get_session_tensor(holder, dtype)  # pylint: disable=protected-access
  built = (holder, reader)
  graph._handle_readers[graph_key] = built  # pylint: disable=protected-access
  return built
275
276
def _get_handle_mover(graph, feeder, handle):
  """Return a move subgraph for this pair of feeder and handle.

  Returns:
    None when `feeder` is not a registered handle feeder, or when its op
    already sits on the device encoded in `handle`. Otherwise, a cached
    `(placeholder, mover)` tuple: the mover reads the stored tensor and emits
    a new session handle from an op placed on `feeder`'s device.
  """
  dtype = _get_handle_feeder(graph, feeder)
  if dtype is None:
    return None
  if feeder.op.device == TensorHandle._get_device_name(handle):  # pylint: disable=protected-access
    # Same device: nothing needs to be moved.
    return None
  graph_key = TensorHandle._get_mover_key(feeder, handle)  # pylint: disable=protected-access
  cached = graph._handle_movers.get(graph_key)  # pylint: disable=protected-access
  if cached is not None:
    return cached
  # Cache miss: read on the handle's device, re-handle on the feeder's device.
  holder, reader = _get_handle_reader(graph, handle, dtype)
  with graph.as_default(), graph.device(feeder.op.device):
    mover = gen_data_flow_ops._get_session_handle(reader)  # pylint: disable=protected-access
  built = (holder, mover)
  graph._handle_movers[graph_key] = built  # pylint: disable=protected-access
  return built
296
297
298def _get_handle_deleter(graph, deleter_key, handle):
299  """Return a deletion subgraph for this handle."""
300  result = graph._handle_deleters.get(deleter_key)
301  if result is None:
302    # Create deleter if we haven't done it.
303    handle_device = TensorHandle._get_device_name(handle)
304    with graph.as_default(), graph.device(handle_device):
305      holder = array_ops.placeholder(dtypes.string)
306      deleter = gen_data_flow_ops._delete_session_tensor(holder)
307    result = (holder, deleter)
308    graph._handle_deleters[deleter_key] = result
309  return result
310