/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/variant.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/lib/core/stringpiece.h"

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

class EagerContext;

// Associates a Tensor and a Device, used in the eager runtime. Internal
// version of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public ImmediateExecutionTensorHandle {
  // TensorHandle for dtype != DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               Device* resource_device, EagerContext* ctx);
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               EagerContext* ctx);
  TensorHandle(Device* d, Device* op_device, Device* resource_device,
               tensorflow::DataType dtype, EagerContext* ctx);

#if !defined(IS_MOBILE_PLATFORM)
  TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
               tensorflow::DataType dtype, Device* device, EagerContext* ctx,
               const bool unknown_device);
  TensorHandle(int64 op_id, int32 output_num, tensorflow::DataType dtype,
               Device* device, const bool is_ready, EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

 public:
  // TensorHandle with no assigned device
  static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device, EagerContext* ctx);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device,
                                         Device* resource_device,
                                         EagerContext* ctx);
  static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
                                              Device* resource_device,
                                              tensorflow::DataType dtype,
                                              EagerContext* ctx);
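
  // Example usage (a minimal sketch, not part of this API; assumes an
  // initialized EagerContext* `ctx` and a local Device* `cpu` obtained from
  // the context's device manager):
  //
  //   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
  //   TensorHandle* h = TensorHandle::CreateLocalHandle(
  //       std::move(t), /*d=*/cpu, /*op_device=*/cpu, ctx);
  //   ...
  //   h->Unref();  // Drop the reference taken at creation.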

  // Create a handle which packs the given handles of the same dtype and shape.
  // If the handles are on different devices, the packed handle is assigned to
  // a CompositeDevice.
  //
  // The new tensor handle shares ownership of the given handles: their
  // reference counts will be increased by one after a call to
  // `CreatePackedHandle`.
  // TODO(b/170414377): Use `TensorHandlePtr` instead.
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   const tensorflow::DataType dtype,
                                   const tensorflow::TensorShape& shape,
                                   const string& device_name, EagerContext* ctx,
                                   TensorHandle** packed_handle);
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   EagerContext* ctx,
                                   TensorHandle** packed_handle);
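
  // Example usage (a sketch; `h0` and `h1` are assumed to be valid
  // TensorHandle* of the same dtype and shape, possibly on different devices):
  //
  //   std::vector<TensorHandle*> components = {h0, h1};
  //   TensorHandle* packed = nullptr;
  //   TF_RETURN_IF_ERROR(TensorHandle::CreatePackedHandle(
  //       std::move(components), ctx, &packed));
  //   // `packed` now holds an extra reference on each component handle.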

#if !defined(IS_MOBILE_PLATFORM)
  // An unshaped remote handle refers to a tensor on a remote worker. It's not
  // ready until the shape is set. It controls the lifetime of the remote
  // tensor.
  static TensorHandle* CreateUnshapedRemoteHandle(
      int64 op_id, int32 output_num, const string& remote_task,
      tensorflow::DataType dtype, Device* d, EagerContext* ctx,
      const bool unknown_device = false);
  // A lazy remote handle refers to a tensor on a remote worker. The lifetime
  // of the remote tensor is controlled by the remote worker, not by the lazy
  // remote handle. Lazy handles are normally created on a default function
  // device.
  static TensorHandle* CreateLazyRemoteHandle(int64 op_id, int32 output_num,
                                              tensorflow::DataType dtype,
                                              Device* d, const bool is_ready,
                                              EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

  void Release() override;

  tensorflow::DataType DataType() const override;
  Status Shape(tensorflow::PartialTensorShape* shape) const override;
  Status NumDims(int* num_dims) const override;
  Status NumElements(int64* num_elements) const override;
  Status Dim(int dim_index, int64* dim) const override;

  const char* DeviceName(Status* status) const override;
  const char* BackingDeviceName(Status* status) const override;
  const char* DeviceType(Status* status) const override;
  int DeviceId(Status* status) const override;
  AbstractTensorInterface* Resolve(Status* status) override;

  ImmediateExecutionTensorHandle* Copy() override;

  // Returns the Tensor from the default device.
  Status Tensor(const tensorflow::Tensor** t) const;
  // Returns the Tensor from the specified device, which may be either the
  // default device or a local mirror. The device pointer should be nullptr if
  // requesting the HostCPU.
  Status TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const;

  // Returns the TensorValue from the specified device, which may be either the
  // default device or a local mirror. The device pointer should be nullptr if
  // requesting the HostCPU.
  Status TensorValue(const Device* d, tensorflow::TensorValue* t);
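
  // Example usage (a sketch; `h` is a valid TensorHandle*, and the calls block
  // until the handle is ready if it was produced asynchronously):
  //
  //   const tensorflow::Tensor* t = nullptr;
  //   TF_RETURN_IF_ERROR(h->Tensor(&t));                      // default device
  //   TF_RETURN_IF_ERROR(h->TensorFromDevice(d, &t));         // mirror on `d`
  //   TF_RETURN_IF_ERROR(h->TensorFromDevice(nullptr, &t));   // HostCPU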

  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }
  int64 resource_remote_device_incarnation() const {
    return resource_remote_device_incarnation_;
  }

  // If the devices were unknown at creation time, block until the actual
  // devices are set (data is ready).
  Status WaitUnknownDevice() const;

  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);

  Status Unprotect(const Device* d);

  // Checks if a mirror tensor exists for the specified device. Mirrors are
  // only maintained for local devices, like CPUs and GPUs. Note that a mirror
  // may be empty, as it may still be pending population by an async operation.
  bool HasLocalMirror(const Device* d) const;
  // Adds an empty mirror placeholder for the specified device. The expectation
  // is that this will be populated by a call to SetTensor.
  Status AddEmptyLocalMirror(const Device* d);
  // Adds a local mirror. This will fail if an empty local mirror was
  // previously added; in that case, SetTensor should be used instead.
  Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d);
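
  // Example of the async mirroring flow (a sketch; the tensor copy to device
  // `d` is assumed to happen elsewhere, e.g. on an executor thread):
  //
  //   if (!h->HasLocalMirror(d)) {
  //     TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(d));
  //   }
  //   // ... later, once the copy to `d` has completed:
  //   TF_RETURN_IF_ERROR(h->SetTensor(std::move(copied_tensor), d));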

#if !defined(IS_MOBILE_PLATFORM)
  bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
  bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;

  Status AddUnshapedRemoteMirror(const Device* d, int64 op_id, int output_num,
                                 const string& remote_task, EagerContext* ctx);
  Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
                                EagerContext* ctx);

  // Returns the op_id and output number if the handle refers to a remote
  // tensor. If wait_until_ready is true, block until the remote tensor is
  // ready on the given remote worker.
  Status RemoteAddress(const Device* d, const bool wait_until_ready,
                       int64* op_id, int32* output_num) const;

  // Called on an async remote tensor once its shape has been determined. This
  // transitions the tensor handle from a non-ready to a ready state by
  // replacing the backing data abstraction, allowing the shape to be queried.
  // This method or PoisonRemote must be called exactly once for remote tensors
  // that were created without a known shape (e.g. the remote output of a
  // remote function).
  Status SetRemoteShape(const TensorShape& shape, const Device* d,
                        uint64 context_view_id);
  // If op_device is not empty, resets the devices of a remote tensor which was
  // created without known devices (e.g. function outputs).
  Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d,
                                 uint64 context_view_id, string op_device);

  // Poisons either this handle or a remote mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the remote shape will return this error `status`.
  // Exactly one of SetRemoteShape or PoisonRemote must be called on an
  // unshaped handle on a remote device.
  void PoisonRemote(Status status, const Device* d, uint64 context_view_id);
#endif

  // Sets the `tensor` for this async non-ready handle, making it ready.
  // This method or Poison must be called exactly once for non-ready async
  // handles to make them ready.
  Status SetTensor(tensorflow::Tensor&& tensor, const Device* d);

  // Poisons either this handle or a local mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the actual tensor or shape will return this error `status`.
  // Exactly one of SetTensor or Poison must be called on a non-ready
  // tensor for a specific device.
  void Poison(Status status, const Device* d);
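
  // Example of the producer-side contract (a sketch; `compute_status` is the
  // outcome of the async computation producing the tensor for device `d`):
  //
  //   if (compute_status.ok()) {
  //     TF_RETURN_IF_ERROR(h->SetTensor(std::move(result), d));
  //   } else {
  //     h->Poison(compute_status, d);  // Readers will now observe the error.
  //   }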

  // TODO(b/154282629): Consider moving it to EagerContext.
  // Copies the tensor to the given device `d`, or to the host iff `d` is null.
  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d,
                      tensorflow::Tensor* output) const;

  Status InferenceShape(
      shape_inference::InferenceContext* const inference_context,
      shape_inference::ShapeHandle* shape_handle);
  void SetInferenceShape(
      shape_inference::InferenceContext* const inference_context,
      const shape_inference::ShapeHandle& shape_handle);
  Status CopyInferenceShape(TensorHandle* other);

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const tensorflow::DataType dtype;

  enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 };

  HandleType Type() const;
  string TypeString() const;

  string DebugString() const;

  void SetResourceHandleDtypeAndShape(
      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

  // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
  // returns the data types and shapes of the underlying resource.
  Status GetResourceHandleDtypesAndShapes(
      std::vector<DtypeAndPartialTensorShape>* result);

  // Returns the number of packed handles, or 0 if the handle type is not
  // PACKED.
  int NumPackedHandles() const;
  // Must be called on a packed TensorHandle: extracts the handle at the given
  // index.
  Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
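
  // Example usage (a sketch; `packed` is assumed to be a TensorHandle whose
  // Type() is PACKED):
  //
  //   for (int i = 0; i < packed->NumPackedHandles(); ++i) {
  //     TensorHandle* component = nullptr;
  //     TF_RETURN_IF_ERROR(packed->ExtractPackedHandle(i, &component));
  //     // `component` is owned by `packed`; no extra reference is taken.
  //   }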

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kEager;
  }
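
  // Example usage (a sketch; assumes LLVM-style isa<> / down_cast helpers are
  // available in the including translation unit):
  //
  //   AbstractTensorHandle* abstract = ...;
  //   if (isa<TensorHandle>(abstract)) {
  //     auto* h = down_cast<TensorHandle*>(abstract);
  //   }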

 private:
  friend class PackedTensorHandleTest;

  TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
               const tensorflow::DataType dtype,
               const tensorflow::TensorShape& shape, EagerContext* ctx);

  ~TensorHandle() override;

  // The TensorHandleData can represent either a local or a remote tensor
  // handle. Further, it can be in a non-ready state. It becomes ready with a
  // call to either SetTensor or SetRemoteShape, which replaces the underlying
  // data with a ready version of the tensor handle data.
  bool IsReady() const;
  Status WaitReady(const char* caller) const;

  tensorflow::Device* device_;

  // Device on which the op producing this tensor was executed. Equal to
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // with the function library runtime.
  tensorflow::Device* op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Otherwise resource_device_ is nullptr.
  tensorflow::Device* resource_device_;
  // Incarnation ID of the resource device if it resides on a remote device, or
  // 0 if it resides on a local device.
  int64 resource_remote_device_incarnation_;

  // If true, the handle refers to a remote tensor which was created without
  // known devices. The actual devices are set by SetRemoteShape. The devices
  // should be accessed only once the handle is ready.
  const bool unknown_device_ = false;

  mutable mutex mu_;

  // Map of local mirrors. This can include both ready and non-ready mirrors.
  std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
      local_mirrors_ TF_GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM)
  // TODO(yujingzhang): Remove resource_shape_mirrors_ once the scalable
  // per-replica variable is ready, since we could then get the shape locally
  // without a remote copy.
  std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
      TF_GUARDED_BY(mu_);
  // TODO(gjn): Is std::unordered_map the most optimal choice here? Perhaps
  // this should be a fixed-size map.
  std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
      TF_GUARDED_BY(mu_);
#endif

  // `ctx` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // The `ctx` object is not owned and should outlive this handle.
  //
  // TODO(b/150614042): Reference count EagerContext to ensure that 'device_'
  // of a TensorHandle does not outlive the EagerContext from which it came?
  EagerContext* const ctx_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, is_poisoned_ is immutable.
  Status is_poisoned_;

  // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
  // refers to a remote resource handle, we store the data types and shapes of
  // the underlying resource.
  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

  // Handle data which refers to multiple TensorHandles of the same dtype and
  // shape.
  class PackedTensorHandleData {
   public:
    // Initializes the handle data from a list of tensor handles.
    // Ownership of the tensor handles is shared between the
    // `PackedTensorHandleData` and the caller (the reference count for the
    // given handles is incremented).
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    PackedTensorHandleData(std::vector<TensorHandle*>&& handles,
                           const TensorShape& shape);

    ~PackedTensorHandleData();

    Status Shape(TensorShape* shape) const;
    Status NumDims(int* num_dims) const;
    Status Dim(int dim_index, int64* dim) const;
    Status NumElements(int64* num_elements) const;
    Status Unprotect();
    bool IsReady() const;
    Status WaitReady(const char* caller) const;
    void Poison(Status status);
    string DebugString() const;

    // Number of packed handles.
    int NumPackedHandles() const;
    // Extracts the handle at the given index.
    Status ExtractPackedHandle(const int index, TensorHandle** handle) const;

   private:
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    const std::vector<TensorHandle*> handles_;
    const TensorShape shape_;

    mutable mutex mu_;
    Status is_poisoned_ TF_GUARDED_BY(mu_);
  };

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, data_ is immutable.
#if !defined(IS_MOBILE_PLATFORM)
  absl::variant<LocalTensorHandleData, PackedTensorHandleData,
                RemoteTensorHandleData>
      data_;
#else
  absl::variant<LocalTensorHandleData, PackedTensorHandleData> data_;
#endif

  PartialTensorShape inference_shape_;
};

// Returns the device backing the resource, or nullptr if there is none.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

class TensorHandleInterface : public ImmediateExecutionTensorHandle {
 public:
};

template <typename T>
inline TensorHandle* TensorHandleFromInterface(T* handle) {
  return down_cast<TensorHandle*>(handle);
}
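
// Example usage (a sketch; `iface` is an ImmediateExecutionTensorHandle* that
// is known to wrap a tensorflow::TensorHandle):
//
//   TensorHandle* h = TensorHandleFromInterface(iface);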

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_