/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"

#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Associates a Tensor and a Device, used in the eager runtime. Internal
// version of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public core::RefCounted {
  // TensorHandle for dtype != DT_RESOURCE
  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
               Device* d, Device* op_device, EagerContext* ctx);
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
               const ResourceHandle& resource_handle, Device* d,
               Device* op_device, EagerContext* ctx);
  TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
               Device* d, Device* op_device, Device* resource_device,
               DataType dtype, EagerContext* ctx);

#if !defined(IS_MOBILE_PLATFORM)
  TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, DataType dtype,
               Device* d, Device* resource_device, EagerContext* ctx);
  TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
               DataType dtype, Device* device, EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

 public:
  // TensorHandle with no assigned device
  static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
  // TensorHandle with device == op_device
  static Status CreateLocalHandle(const class Tensor& t, Device* d,
                                  EagerContext* ctx, TensorHandle** h);
  static Status CreateLocalHandle(const class Tensor& t, Device* d,
                                  Device* op_device, EagerContext* ctx,
                                  TensorHandle** h);
  static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
                                       Device* resource_device, DataType dtype,
                                       EagerContext* ctx, TensorHandle** h);
#if !defined(IS_MOBILE_PLATFORM)
  static Status CreateRemoteHandle(int64 op_id, int output_num,
                                   const TensorShape& shape,
                                   const string& remote_task, uint64 context_id,
                                   DataType dtype, Device* d,
                                   Device* resource_device, EagerContext* ctx,
                                   TensorHandle** h);
  static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t,
                                   DataType dtype, Device* d,
                                   Device* resource_device, EagerContext* ctx,
                                   TensorHandle** h);
  static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
                                           const string& remote_task,
                                           uint64 context_id, DataType dtype,
                                           Device* device, EagerContext* ctx,
                                           TensorHandle** h);
  static Status CreateUnshapedRemoteHandle(
      std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
      Device* device, EagerContext* ctx, TensorHandle** h);
#endif  // IS_MOBILE_PLATFORM

  ~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }

  Status Tensor(const tensorflow::Tensor** t);

  Status TensorValue(tensorflow::TensorValue* t);

  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }

  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);
  Status NumDims(int* num_dims) const;
  Status Dim(int dim_index, int64* dim) const;
  Status NumElements(int64* num_elements) const;
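
  // A minimal usage sketch (not part of the original header): wrapping a host
  // tensor in a handle via the CreateLocalHandle overloads above and querying
  // its shape. `ctx` is assumed to be an already-initialized EagerContext*;
  // passing a null Device* means "host CPU:0" (see the TODO on device_ below).
  //
  //   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
  //   TensorHandle* handle = nullptr;
  //   TF_RETURN_IF_ERROR(
  //       TensorHandle::CreateLocalHandle(t, /*d=*/nullptr, ctx, &handle));
  //   TensorShape shape;
  //   TF_RETURN_IF_ERROR(handle->Shape(&shape));  // shape is now {2, 2}.
  //   handle->Unref();  // Handles are ref-counted via core::RefCounted.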

#if !defined(IS_MOBILE_PLATFORM)
  bool HasRemoteMirror(Device* d);
  bool HasResourceShapeMirror(Device* d);

  Status AddUnshapedRemoteMirror(
      std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
  Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Device* d);
  Status AddResourceShapeMirror(
      std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);

  // Return the op_id and output num if the handle refers to a remote tensor.
  Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;

  // Set remote_op_id_ and remote_output_num_ if the handle refers to a local
  // tensor that needs to be copied to remote workers.
  void SetRemoteOpIdAndOutputNumToLocalTensorHandle(const int64 op_id,
                                                    const int32 output_num);

  // Called on an async remote tensor once its shape has been determined. This
  // transitions the tensor handle from a non-ready to a ready state by
  // replacing the backing data abstraction to allow the shape to be queried.
  // This method or Poison must be called exactly once for remote tensors that
  // were created without a known shape.
  Status SetRemoteShape(const TensorShape& shape, tensorflow::Device* d);
#endif

  // Sets the `tensor` for this async non-ready handle, making it ready.
  // This method or Poison must be called exactly once for non-ready async
  // handles to make them ready.
  Status SetTensor(tensorflow::Tensor&& tensor);

  // Poisons this non-ready handle with an error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the actual tensor or shape will return this error `status`.
  // Exactly one of SetTensor, SetRemoteShape, or Poison must be called on a
  // non-ready tensor.
  void Poison(Status status);

  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* dstd,
                      tensorflow::Tensor* output);

  Status InferenceShape(
      shape_inference::InferenceContext* const inference_context,
      shape_inference::ShapeHandle* shape_handle);
  void SetInferenceShape(
      shape_inference::InferenceContext* const inference_context,
      const shape_inference::ShapeHandle& shape_handle);
  Status CopyInferenceShape(TensorHandle* other);

  // Warning: can return nullptr for CPU tensors.
  // TODO(b/136608821): Move away from nullptr
  EagerContext* Context() { return ctx_; }

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const DataType dtype;

  // TODO(b/136608821): Move away from nullptr
  bool OnHostCPU() const {
    return device_ == nullptr ||
           (ctx_ != nullptr && ctx_->HostCPU() == device_);
  }

  bool IsRemote() const { return is_remote_; }

  string DebugString() const;

  void SetResourceHandleDtypeAndShape(
      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

  // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
  // returns the data types and shapes of the underlying resource.
  Status GetResourceHandleDtypesAndShapes(
      std::vector<DtypeAndPartialTensorShape>* result);

 private:
  // The TensorHandleData can represent either a local or a remote tensor
  // handle. Further, it can be in a non-ready state. It becomes ready with a
  // call to either SetTensor or SetRemoteShape, which replaces the underlying
  // data with a ready version of the tensor handle data.
  bool IsReady() const;

  // If the contents of the Tensor pointed to by this handle are yet to be
  // computed by an EagerNode, this function blocks until that computation is
  // done and the handle is "ready".
  Status WaitReady(const char* caller) const;
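
  // A sketch (not from the original header) of the lifecycle documented on
  // SetTensor/Poison above: a non-ready async handle is created empty and is
  // later fulfilled or poisoned by the producing computation. `ctx`, `d`, and
  // Compute() are hypothetical placeholders.
  //
  //   TensorHandle* h = nullptr;
  //   TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
  //       /*async=*/true, /*d=*/d, /*op_device=*/d, /*resource_device=*/nullptr,
  //       DT_FLOAT, ctx, &h));
  //   // Consumers calling h->Tensor(...) or h->Shape(...) block in WaitReady()
  //   // until the producer does one of the following, exactly once:
  //   tensorflow::Tensor out;
  //   Status s = Compute(&out);
  //   if (s.ok()) {
  //     TF_RETURN_IF_ERROR(h->SetTensor(std::move(out)));  // Handle is ready.
  //   } else {
  //     h->Poison(s);  // Ready, but Tensor()/Shape() now return `s`.
  //   }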

  // TODO(b/136608821): device_ == nullptr iff Host CPU:0
  // This was expedient, but perhaps worth revisiting ('device_' should always
  // be a valid pointer?). This can be done if TFE_NewOp() and the
  // TFE_TensorHandle constructors are provided with the appropriate
  // TFE_Context.
  //
  // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
  // TFE_TensorHandle does not outlive the TFE_Context from which it came?
  tensorflow::Device* const device_;

  // Device on which the op producing this tensor was executed. Equals
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // with the function library runtime.
  tensorflow::Device* const op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Else resource_device_ is nullptr.
  tensorflow::Device* const resource_device_;

  mutable mutex mu_;

#if !defined(IS_MOBILE_PLATFORM)
  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
  // variables are ready, since we could then get the shape locally without a
  // remote copy.
  std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
      resource_shape_mirrors_ GUARDED_BY(mu_);

  // TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
  // Consider replacing the unshaped_remote_mirrors_ map with something more
  // efficient.
  std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
      unshaped_remote_mirrors_ GUARDED_BY(mu_);
  // TODO(gjn): Is std::map the optimal choice here? Perhaps this should be a
  // fixed-size map.
  std::map<tensorflow::Device*, std::unique_ptr<RemoteTensorHandleData>>
      remote_mirrors_ GUARDED_BY(mu_);

  // IDs required when this class is representing a remote tensor handle.
  int64 remote_op_id_;
  int32 remote_output_num_;
  string remote_task_;
  uint64 remote_context_id_;
#endif

  // `ctx_` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // The `ctx_` object is not owned and should outlive this handle.
  EagerContext* const ctx_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, is_poisoned_ is immutable.
  Status is_poisoned_;
  const bool is_remote_;
  const bool is_async_;
  bool is_ready_ GUARDED_BY(mu_);

  // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
  // refers to a remote resource handle, we store the data types and shapes of
  // the underlying resource.
  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
  std::unique_ptr<TensorHandleData> tensor_handle_data_;

  PartialTensorShape inference_shape_;
};

// Returns the device backing the resource in `handle`. Else, returns nullptr.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_