1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Tensor Handle Operations.""" 17 18# pylint: disable=g-bad-name 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import numpy as np 24 25from tensorflow.core.framework import resource_handle_pb2 26from tensorflow.python.client import pywrap_tf_session 27from tensorflow.python.framework import device as pydev 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import gen_data_flow_ops 32from tensorflow.python.util import compat 33from tensorflow.python.util.tf_export import tf_export 34 35 36def encode_resource_handle(resource_handle): 37 """Encode a ResourceHandle proto as custom numpy struct type.""" 38 return np.asarray(bytearray(resource_handle.SerializeToString()), 39 dtype=dtypes.np_resource) 40 41 42class TensorHandle(object): 43 """Represents a handle for a live tensor in a session.""" 44 45 def __init__(self, handle, dtype, session): 46 """Constructs a new tensor handle. 47 48 A tensor handle for a persistent tensor is a python string 49 that has the form of "tensor_name;unique_id;device_name". 50 51 Args: 52 handle: A tensor handle. 53 dtype: The data type of the tensor represented by `handle`. 54 session: The session in which the tensor is produced. 55 """ 56 self._handle = compat.as_str_any(handle) 57 self._resource_handle = None 58 self._dtype = dtype 59 self._session = session 60 self._auto_gc_enabled = True 61 62 def __del__(self): 63 if self._auto_gc_enabled: 64 self._session._register_dead_handle(self.handle) 65 66 def __str__(self): 67 return self._handle 68 69 def _get_resource_handle(self): 70 """The ResourceHandle representation of this handle.""" 71 if not self._resource_handle: 72 self._resource_handle = resource_handle_pb2.ResourceHandleProto() 73 self._resource_handle.device = self._handle.split(";")[-1] 74 self._resource_handle.container = (pywrap_tf_session.TENSOR_HANDLE_KEY) 75 self._resource_handle.name = self._handle 76 return self._resource_handle 77 78 def to_numpy_array(self): 79 """Convert a TensorHandle object to a feedable numpy value. 80 81 Returns: 82 A numpy array of a custom struct type that can be used as a feed value 83 to run(). 84 """ 85 return encode_resource_handle(self._get_resource_handle()) 86 87 @property 88 def handle(self): 89 """The string representation of this handle.""" 90 return self._handle 91 92 def eval(self): 93 """Return the value of the tensor represented by this handle.""" 94 if not self._auto_gc_enabled: 95 raise TypeError("Persistent tensor %s may have already been deleted." 96 % self.handle) 97 holder, reader = _get_handle_reader(self._session.graph, self._handle, 98 self._dtype) 99 return self._session.run(reader, feed_dict={holder: self._handle}) 100 101 def delete(self): 102 """Force the deletion of this persistent tensor.""" 103 if not self._auto_gc_enabled: 104 raise TypeError("Persistent tensor %s may have already been deleted." 105 % self.handle) 106 self._auto_gc_enabled = False 107 holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle) 108 self._session.run(deleter, feed_dict={holder: self.handle}) 109 110 def get_raw_handle(self): 111 """Return the raw handle of the tensor. 112 113 Note that the method disables the automatic garbage collection of this 114 persistent tensor. The caller is now responsible for managing the life 115 time of the tensor. 116 """ 117 self._auto_gc_enabled = False 118 return self._handle 119 120 @staticmethod 121 def _get_device_name(handle): 122 """The device name encoded in the handle.""" 123 handle_str = compat.as_str_any(handle) 124 return pydev.canonical_name(handle_str.split(";")[-1]) 125 126 @staticmethod 127 def _get_reader_key(handle): 128 """The graph key for reader.""" 129 handle_parts = str(handle).split(";") 130 return handle_parts[0] + ";" + handle_parts[-1] 131 132 @staticmethod 133 def _get_mover_key(feeder, handle): 134 """The graph key for mover.""" 135 return feeder.op.name + ";" + TensorHandle._get_reader_key(handle) 136 137 138@tf_export(v1=["get_session_handle"]) 139def get_session_handle(data, name=None): 140 """Return the handle of `data`. 141 142 This is EXPERIMENTAL and subject to change. 143 144 Keep `data` "in-place" in the runtime and create a handle that can be 145 used to retrieve `data` in a subsequent run(). 146 147 Combined with `get_session_tensor`, we can keep a tensor produced in 148 one run call in place, and use it as the input in a future run call. 149 150 Args: 151 data: A tensor to be stored in the session. 152 name: Optional name prefix for the return tensor. 153 154 Returns: 155 A scalar string tensor representing a unique handle for `data`. 156 157 Raises: 158 TypeError: if `data` is not a Tensor. 159 160 Example: 161 162 ```python 163 c = tf.multiply(a, b) 164 h = tf.compat.v1.get_session_handle(c) 165 h = sess.run(h) 166 167 p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32) 168 b = tf.multiply(a, 10) 169 c = sess.run(b, feed_dict={p: h.handle}) 170 ``` 171 172 """ 173 if not isinstance(data, ops.Tensor): 174 raise TypeError("`data` must be of type Tensor.") 175 176 # Colocate this operation with data. 177 with ops.colocate_with(data): 178 return gen_data_flow_ops.get_session_handle(data, name=name) 179 180 181@tf_export(v1=["get_session_tensor"]) 182def get_session_tensor(handle, dtype, name=None): 183 """Get the tensor of type `dtype` by feeding a tensor handle. 184 185 This is EXPERIMENTAL and subject to change. 186 187 Get the value of the tensor from a tensor handle. The tensor 188 is produced in a previous run() and stored in the state of the 189 session. 190 191 Args: 192 handle: The string representation of a persistent tensor handle. 193 dtype: The type of the output tensor. 194 name: Optional name prefix for the return tensor. 195 196 Returns: 197 A pair of tensors. The first is a placeholder for feeding a 198 tensor handle and the second is the tensor in the session state 199 keyed by the tensor handle. 200 201 Example: 202 203 ```python 204 c = tf.multiply(a, b) 205 h = tf.compat.v1.get_session_handle(c) 206 h = sess.run(h) 207 208 p, a = tf.compat.v1.get_session_tensor(h.handle, tf.float32) 209 b = tf.multiply(a, 10) 210 c = sess.run(b, feed_dict={p: h.handle}) 211 ``` 212 213 """ 214 handle_device = TensorHandle._get_device_name(handle) 215 with ops.device(handle_device): 216 holder = array_ops.placeholder(dtypes.string) 217 _register_handle_feeder(holder.graph, holder, dtype) 218 tensor = gen_data_flow_ops.get_session_tensor(holder, dtype, name=name) 219 return (holder, tensor) 220 221 222@tf_export(v1=["delete_session_tensor"]) 223def delete_session_tensor(handle, name=None): 224 """Delete the tensor for the given tensor handle. 225 226 This is EXPERIMENTAL and subject to change. 227 228 Delete the tensor of a given tensor handle. The tensor is produced 229 in a previous run() and stored in the state of the session. 230 231 Args: 232 handle: The string representation of a persistent tensor handle. 233 name: Optional name prefix for the return tensor. 234 235 Returns: 236 A pair of graph elements. The first is a placeholder for feeding a 237 tensor handle and the second is a deletion operation. 238 """ 239 handle_device = TensorHandle._get_device_name(handle) 240 with ops.device(handle_device): 241 holder = array_ops.placeholder(dtypes.string) 242 deleter = gen_data_flow_ops.delete_session_tensor(holder, name=name) 243 return (holder, deleter) 244 245 246def _register_handle_feeder(graph, feeder, dtype): 247 graph._handle_feeders[feeder.op.name] = dtype 248 249 250def _get_handle_feeder(graph, feeder): 251 return graph._handle_feeders.get(feeder.op.name) 252 253 254def _get_handle_reader(graph, handle, dtype): 255 """Return a read subgraph for this handle.""" 256 graph_key = TensorHandle._get_reader_key(handle) 257 result = graph._handle_readers.get(graph_key) 258 if result is None: 259 # Create reader if we haven't done it. 260 handle_device = TensorHandle._get_device_name(handle) 261 with graph.as_default(), graph.device(handle_device): 262 holder = array_ops.placeholder(dtypes.string) 263 _register_handle_feeder(holder.graph, holder, dtype) 264 reader = gen_data_flow_ops.get_session_tensor(holder, dtype) 265 result = (holder, reader) 266 graph._handle_readers[graph_key] = result 267 return result 268 269 270def _get_handle_mover(graph, feeder, handle): 271 """Return a move subgraph for this pair of feeder and handle.""" 272 dtype = _get_handle_feeder(graph, feeder) 273 if dtype is None: 274 return None 275 handle_device = TensorHandle._get_device_name(handle) 276 if feeder.op.device == handle_device: 277 return None 278 # Now we know we have to move the tensor. 279 graph_key = TensorHandle._get_mover_key(feeder, handle) 280 result = graph._handle_movers.get(graph_key) 281 if result is None: 282 # Create mover if we haven't done it. 283 holder, reader = _get_handle_reader(graph, handle, dtype) 284 with graph.as_default(), graph.device(feeder.op.device): 285 mover = gen_data_flow_ops.get_session_handle(reader) 286 result = (holder, mover) 287 graph._handle_movers[graph_key] = result 288 return result 289 290 291def _get_handle_deleter(graph, deleter_key, handle): 292 """Return a deletion subgraph for this handle.""" 293 result = graph._handle_deleters.get(deleter_key) 294 if result is None: 295 # Create deleter if we haven't done it. 296 handle_device = TensorHandle._get_device_name(handle) 297 with graph.as_default(), graph.device(handle_device): 298 holder = array_ops.placeholder(dtypes.string) 299 deleter = gen_data_flow_ops.delete_session_tensor(holder) 300 result = (holder, deleter) 301 graph._handle_deleters[deleter_key] = result 302 return result 303