# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tensor Handle Operations."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import resource_handle_pb2
from tensorflow.python import pywrap_tensorflow_internal
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


def encode_resource_handle(resource_handle):
  """Encode a ResourceHandle proto as custom numpy struct type.

  Args:
    resource_handle: A proto message exposing `SerializeToString()`, such as
      the `ResourceHandleProto` built by `TensorHandle._get_resource_handle`.

  Returns:
    A numpy array over the serialized proto bytes, with dtype
    `dtypes.np_resource`.
  """
  return np.asarray(bytearray(resource_handle.SerializeToString()),
                    dtype=dtypes.np_resource)


class TensorHandle(object):
  """Represents a handle for a live tensor in a session."""

  def __init__(self, handle, dtype, session):
    """Constructs a new tensor handle.

    A tensor handle for a persistent tensor is a python string
    that has the form of "tensor_name;unique_id;device_name".

    Args:
      handle: A tensor handle.
      dtype: The data type of the tensor represented by `handle`.
      session: The session in which the tensor is produced.
    """
    self._handle = compat.as_str_any(handle)
    self._resource_handle = None
    self._dtype = dtype
    self._session = session
    self._auto_gc_enabled = True

  def __del__(self):
    # `__del__` can run on an instance whose `__init__` raised before
    # `_auto_gc_enabled` was assigned; treat such an instance as having
    # nothing to hand back to the session instead of raising AttributeError.
    if getattr(self, "_auto_gc_enabled", False):
      self._session._register_dead_handle(self.handle)

  def __str__(self):
    return self._handle

  def _get_resource_handle(self):
    """The ResourceHandle representation of this handle.

    The proto is built on first use and cached on the instance.
    """
    if self._resource_handle is None:
      self._resource_handle = resource_handle_pb2.ResourceHandleProto()
      # The device name is the last ";"-separated field of the handle string.
      self._resource_handle.device = self._handle.split(";")[-1]
      self._resource_handle.container = (
          pywrap_tensorflow_internal.TENSOR_HANDLE_KEY)
      self._resource_handle.name = self._handle
    return self._resource_handle

  def to_numpy_array(self):
    """Convert a TensorHandle object to a feedable numpy value.

    Returns:
      A numpy array of a custom struct type that can be used as a feed value
      to run().
    """
    return encode_resource_handle(self._get_resource_handle())

  @property
  def handle(self):
    """The string representation of this handle."""
    return self._handle

  def eval(self):
    """Return the value of the tensor represented by this handle.

    Raises:
      TypeError: if automatic garbage collection of this handle has been
        disabled (by `delete()` or `get_raw_handle()`), since the tensor may
        no longer exist.
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    holder, reader = _get_handle_reader(self._session.graph, self._handle,
                                        self._dtype)
    return self._session.run(reader, feed_dict={holder: self._handle})

  def delete(self):
    """Force the deletion of this persistent tensor.

    Raises:
      TypeError: if automatic garbage collection of this handle has already
        been disabled (by a previous `delete()` or by `get_raw_handle()`).
    """
    if not self._auto_gc_enabled:
      raise TypeError("Persistent tensor %s may have already been deleted."
                      % self.handle)
    # Disable auto-GC first so `__del__` does not register it a second time.
    self._auto_gc_enabled = False
    holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
    self._session.run(deleter, feed_dict={holder: self.handle})

  def get_raw_handle(self):
    """Return the raw handle of the tensor.

    Note that the method disables the automatic garbage collection of this
    persistent tensor. The caller is now responsible for managing the life
    time of the tensor.
    """
    self._auto_gc_enabled = False
    return self._handle

  @staticmethod
  def _get_device_name(handle):
    """The canonical device name encoded in the handle."""
    handle_str = compat.as_str_any(handle)
    return pydev.canonical_name(handle_str.split(";")[-1])

  @staticmethod
  def _get_reader_key(handle):
    """The graph key for reader.

    The key is "tensor_name;device_name": the middle unique-id field is
    dropped, so handles with the same name and device share a cached reader.
    `compat.as_str_any` is used (as in `_get_device_name`) so a bytes handle
    yields the same key as its str equivalent rather than "b'...'".
    """
    handle_parts = compat.as_str_any(handle).split(";")
    return handle_parts[0] + ";" + handle_parts[-1]

  @staticmethod
  def _get_mover_key(feeder, handle):
    """The graph key for mover: the feeder op name plus the reader key."""
    return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)


@tf_export(v1=["get_session_handle"])
def get_session_handle(data, name=None):
  """Return the handle of `data`.

  This is EXPERIMENTAL and subject to change.

  Keep `data` "in-place" in the runtime and create a handle that can be
  used to retrieve `data` in a subsequent run().

  Combined with `get_session_tensor`, we can keep a tensor produced in
  one run call in place, and use it as the input in a future run call.

  Args:
    data: A tensor to be stored in the session.
    name: Optional name prefix for the return tensor.

  Returns:
    A scalar string tensor representing a unique handle for `data`.

  Raises:
    TypeError: if `data` is not a Tensor.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  if not isinstance(data, ops.Tensor):
    raise TypeError("`data` must be of type Tensor.")

  # Colocate this operation with data.
  with ops.colocate_with(data):
    return gen_data_flow_ops.get_session_handle(data, name=name)


@tf_export(v1=["get_session_tensor"])
def get_session_tensor(handle, dtype, name=None):
  """Get the tensor of type `dtype` by feeding a tensor handle.

  This is EXPERIMENTAL and subject to change.

  Get the value of the tensor from a tensor handle. The tensor
  is produced in a previous run() and stored in the state of the
  session.

  Args:
    handle: The string representation of a persistent tensor handle.
    dtype: The type of the output tensor.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of tensors. The first is a placeholder for feeding a
    tensor handle and the second is the tensor in the session state
    keyed by the tensor handle.

  Example:

  ```python
  c = tf.multiply(a, b)
  h = tf.get_session_handle(c)
  h = sess.run(h)

  p, a = tf.get_session_tensor(h.handle, tf.float32)
  b = tf.multiply(a, 10)
  c = sess.run(b, feed_dict={p: h.handle})
  ```

  """
  handle_device = TensorHandle._get_device_name(handle)
  with ops.device(handle_device):
    holder = array_ops.placeholder(dtypes.string)
    # Record the placeholder as a handle feeder so that a mover subgraph can
    # later be built for it (see `_get_handle_mover`).
    _register_handle_feeder(holder.graph, holder, dtype)
    tensor = gen_data_flow_ops.get_session_tensor(holder, dtype, name=name)
  return (holder, tensor)


@tf_export(v1=["delete_session_tensor"])
def delete_session_tensor(handle, name=None):
  """Delete the tensor for the given tensor handle.

  This is EXPERIMENTAL and subject to change.

  Delete the tensor of a given tensor handle. The tensor is produced
  in a previous run() and stored in the state of the session.

  Args:
    handle: The string representation of a persistent tensor handle.
    name: Optional name prefix for the return tensor.

  Returns:
    A pair of graph elements. The first is a placeholder for feeding a
    tensor handle and the second is a deletion operation.
  """
  handle_device = TensorHandle._get_device_name(handle)
  with ops.device(handle_device):
    holder = array_ops.placeholder(dtypes.string)
    deleter = gen_data_flow_ops.delete_session_tensor(holder, name=name)
  return (holder, deleter)


def _register_handle_feeder(graph, feeder, dtype):
  """Record that placeholder `feeder` feeds handles of tensors of `dtype`."""
  graph._handle_feeders[feeder.op.name] = dtype


def _get_handle_feeder(graph, feeder):
  """Return the dtype registered for `feeder`, or None if not registered."""
  return graph._handle_feeders.get(feeder.op.name)


def _get_handle_reader(graph, handle, dtype):
  """Return a read subgraph for this handle.

  The (placeholder, get_session_tensor) pair is cached per reader key
  ("tensor_name;device_name"), so it is built at most once per key.

  Returns:
    A `(holder, reader)` tuple.
  """
  graph_key = TensorHandle._get_reader_key(handle)
  result = graph._handle_readers.get(graph_key)
  if result is None:
    # Create reader if we haven't done it.
    handle_device = TensorHandle._get_device_name(handle)
    with graph.as_default(), graph.device(handle_device):
      holder = array_ops.placeholder(dtypes.string)
      _register_handle_feeder(holder.graph, holder, dtype)
      reader = gen_data_flow_ops.get_session_tensor(holder, dtype)
    result = (holder, reader)
    graph._handle_readers[graph_key] = result
  return result


def _get_handle_mover(graph, feeder, handle):
  """Return a move subgraph for this pair of feeder and handle.

  Returns:
    None if `feeder` is not a registered handle feeder, or if it is already
    placed on the handle's device. Otherwise a cached `(holder, mover)` tuple
    that reads the tensor on the handle's device and re-registers it as a
    session handle on the feeder's device.
  """
  dtype = _get_handle_feeder(graph, feeder)
  if dtype is None:
    return None
  handle_device = TensorHandle._get_device_name(handle)
  if feeder.op.device == handle_device:
    return None
  # Now we know we have to move the tensor.
  graph_key = TensorHandle._get_mover_key(feeder, handle)
  result = graph._handle_movers.get(graph_key)
  if result is None:
    # Create mover if we haven't done it.
    holder, reader = _get_handle_reader(graph, handle, dtype)
    with graph.as_default(), graph.device(feeder.op.device):
      mover = gen_data_flow_ops.get_session_handle(reader)
    result = (holder, mover)
    graph._handle_movers[graph_key] = result
  return result


def _get_handle_deleter(graph, deleter_key, handle):
  """Return a deletion subgraph for this handle.

  Returns:
    A `(holder, deleter)` tuple, cached under `deleter_key`.
  """
  # NOTE(review): the cache is keyed only by the caller-chosen `deleter_key`,
  # so the device placement comes from whichever handle first populated that
  # key; later handles reuse it regardless of their own device.
  result = graph._handle_deleters.get(deleter_key)
  if result is None:
    # Create deleter if we haven't done it.
    handle_device = TensorHandle._get_device_name(handle)
    with graph.as_default(), graph.device(handle_device):
      holder = array_ops.placeholder(dtypes.string)
      deleter = gen_data_flow_ops.delete_session_tensor(holder)
    result = (holder, deleter)
    graph._handle_deleters[deleter_key] = result
  return result