/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

struct TF_Operation;

namespace tensorflow {

isomorphic to TF_Output, but we cannot use the latter here due 49 // to layering concerns (TF_Output is defined at the C API layer). 50 struct OutputGraphNode { 51 TF_Operation* oper; 52 int index; // The index of the output within oper. 53 }; 54 55 // Associates a Tensor and a Device, used in the eager runtime. Internal version 56 // of the TFE_TensorHandle struct and the python EagerTensor class 57 // (unrelated to python TensorHandle). 58 class TensorHandle : public core::RefCounted { 59 public: 60 TensorHandle(const Tensor& t, Device* d, Device* op_device, 61 EagerContext* ctx); 62 TensorHandle(uint64 node_id, Device* d, Device* op_device, 63 Device* resource_device, DataType dtype, EagerContext* ctx); 64 65 // Remote tensor handle constructor. 66 TensorHandle(int64 op_id, int32 output_num, uint64 remote_shape_node_id, 67 DataType dtype, std::function<void()> call_on_destroy, Device* d, 68 Device* op_device, Device* resource_device, EagerContext* ctx); 69 70 // Symbolic tensor constructor. 71 TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype); 72 ~TensorHandle()73 ~TensorHandle() override { 74 VLOG(1) << "Deleting internal TensorHandle " << this; 75 if (call_on_destroy_) { 76 call_on_destroy_(); 77 } 78 } 79 80 Status Tensor(const tensorflow::Tensor** t); 81 82 Status TensorValue(tensorflow::TensorValue* t); 83 device()84 tensorflow::Device* device() const { return device_; } op_device()85 tensorflow::Device* op_device() const { return op_device_; } resource_device()86 tensorflow::Device* resource_device() const { return resource_device_; } 87 88 Status TensorAndDevice(const tensorflow::Tensor** tensor, 89 tensorflow::Device** device, 90 tensorflow::Device** op_device); 91 92 Status Shape(tensorflow::TensorShape* shape); 93 94 Status NumDims(int* num_dims); 95 Status Dim(int dim_index, int64* dim); 96 Status NumElements(int64* num_elements); 97 98 // Return the op_id and output num if the handle refers to a remote tensor. 
99 Status RemoteAddress(int64* op_id, int32* output_num); 100 101 // Note that this can be called at most once, and only on non-ready handles, 102 // and makes them ready. 103 void SetTensor(const tensorflow::Tensor& tensor); 104 105 Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd, 106 TensorHandle** output); 107 108 // Warning: can return nullptr for CPU tensors. Context()109 EagerContext* Context() { 110 mutex_lock ml(ctx_mutex_); 111 return ctx_; 112 } 113 114 // dtype for the handle. It must be the same as t.dtype() once the handle is 115 // ready. 116 const DataType dtype; 117 SetRemoteShape(std::unique_ptr<TensorShape> remote_shape)118 void SetRemoteShape(std::unique_ptr<TensorShape> remote_shape) { 119 remote_shape_ = std::move(remote_shape); 120 } 121 OnHostCPU()122 bool OnHostCPU() { 123 mutex_lock ml(ctx_mutex_); 124 return device_ == nullptr || 125 (ctx_ == nullptr || ctx_->HostCPU() == device_); 126 } 127 128 bool IsRemote(); 129 getSymbolicTensor()130 OutputGraphNode* getSymbolicTensor() const { return symbolic_tensor.get(); } 131 132 string DebugString() const; 133 134 private: 135 // If the contents of the Tensor pointed to by this handle is yet to be 136 // computed by a EagerNode, this function will block till that computation is 137 // done and the handle is "ready". 138 Status WaitReady(); 139 Status WaitForNode(uint64 node_id, bool return_if_is_ready); 140 141 bool IsReady(); 142 143 // Id for the EagerNode that will compute the value pointed to by this handle. 144 // If the value is 0, the handle is already ready, but not vice-versa. 145 const uint64 node_id_; 146 147 tensorflow::Tensor tensor_; 148 149 // TODO(ashankar): device_ == nullptr iff local CPU 150 // This was expedient, but perhaps worth revisiting ('device_' should always 151 // be a valid pointer?) 152 // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are 153 // provided with the appropriate TFE_Context. 
154 // 155 // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a 156 // TFE_TensorHandle does not outlive the TFE_Context from which it came? 157 tensorflow::Device* const device_; 158 159 // Device in which the op producing this tensor was executed. Equals to 160 // device_ for constant tensors. 161 // Can be nullptr if the op producing this tensor was a function executed 162 // with function library runtime or if this tensor represents a symbolic 163 // tensor. 164 tensorflow::Device* const op_device_; 165 166 // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device 167 // backing the resource. Else resource_device_ is nullptr. 168 tensorflow::Device* const resource_device_; 169 170 // IDs required when this class is representing a remote tensor handle. 171 const int64 remote_op_id_; 172 const int32 remote_output_num_; 173 std::unique_ptr<TensorShape> remote_shape_; 174 const uint64 remote_shape_node_id_; 175 176 // A callback that is executed when the class is destroyed. 177 // 178 // This is currently used for remote tensor handles. 179 const std::function<void()> call_on_destroy_; 180 181 mutex ctx_mutex_; 182 183 // `ctx` is only guaranteed to be set if the handle is not "ready". This is 184 // typically true when the handle was produced during async execution. 185 // `ctx` object is not owned and should outlive this handle. 186 EagerContext* ctx_ GUARDED_BY(ctx_mutex_); 187 bool is_ready_ GUARDED_BY(ctx_mutex_); 188 189 // When non-NULL, this tensor handle instance represents a symbolic tensor 190 // (corresponding to a graph node), whose concrete value is to be produced by 191 // executing that graph node. 192 std::unique_ptr<OutputGraphNode> symbolic_tensor; 193 }; 194 195 // If tensor's dtype is DT_RESOURCE, returns the device backing the resource. 196 // Else, returns nullptr. 
197 Device* GetResourceDevice(const Tensor& t, EagerContext* ctx); 198 199 } // namespace tensorflow 200 201 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_ 202