/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "absl/types/variant.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/tensor.h"

#include "tensorflow/core/lib/core/stringpiece.h"

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {

class EagerContext;

// Associates a Tensor and a Device, used in the eager runtime. Internal version
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public ImmediateExecutionTensorHandle {
  // TensorHandle for dtype != DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               Device* resource_device, EagerContext* ctx);
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               EagerContext* ctx);
  TensorHandle(Device* d, Device* op_device, Device* resource_device,
               tensorflow::DataType dtype, EagerContext* ctx);

#if !defined(IS_MOBILE_PLATFORM)
  TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
               tensorflow::DataType dtype, Device* device, EagerContext* ctx,
               const bool unknown_device);
  TensorHandle(int64 op_id, int32 output_num, tensorflow::DataType dtype,
               Device* device, const bool is_ready, EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

 public:
  // TensorHandle with no assigned device
  static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device, EagerContext* ctx);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                         Device* op_device,
                                         Device* resource_device,
                                         EagerContext* ctx);
  static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
                                              Device* resource_device,
                                              tensorflow::DataType dtype,
                                              EagerContext* ctx);
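
  // A minimal usage sketch (illustrative, not part of the API contract): the
  // handle takes ownership of the moved-in tensor, and the caller holds a
  // reference that must be dropped via Release(). `ctx` is assumed to be an
  // existing EagerContext; a null device is assumed to mean the HostCPU.
  //
  //   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
  //   TensorHandle* h = TensorHandle::CreateLocalHandle(
  //       std::move(t), /*d=*/nullptr, /*op_device=*/nullptr, ctx);
  //   ...
  //   h->Release();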

  // Create a handle which packs the given handles of the same dtype and shape.
  // If the handles are on different devices, assign the packed handle to a
  // CompositeDevice.
  //
  // The new tensor handle shares ownership of the given handles: their
  // reference counts are each incremented by one after a call to
  // `CreatePackedHandle`.
  // TODO(b/170414377): Use `TensorHandlePtr` instead.
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   const tensorflow::DataType dtype,
                                   const tensorflow::TensorShape& shape,
                                   const string& device_name, EagerContext* ctx,
                                   TensorHandle** packed_handle);
  static Status CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                   EagerContext* ctx,
                                   TensorHandle** packed_handle);
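
  // Packing sketch (error handling elided; `h0` and `h1` are illustrative
  // handles of the same dtype and shape). The packed handle shares ownership
  // of them, so the caller may still release its own references afterwards:
  //
  //   TensorHandle* packed = nullptr;
  //   Status s = TensorHandle::CreatePackedHandle({h0, h1}, ctx, &packed);
  //   if (s.ok()) packed->Release();  // when done with it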

#if !defined(IS_MOBILE_PLATFORM)
  // An unshaped remote handle refers to a tensor on a remote worker. It's not
  // ready until the shape is set. It controls the lifetime of the remote
  // tensor.
  static TensorHandle* CreateUnshapedRemoteHandle(
      int64 op_id, int32 output_num, const string& remote_task,
      tensorflow::DataType dtype, Device* d, EagerContext* ctx,
      const bool unknown_device = false);
  // A lazy remote handle refers to a tensor on a remote worker. The lifetime of
  // the remote tensor is controlled by the remote worker, but not by the lazy
  // remote handle. Lazy handles are normally created on a default function
  // device.
  static TensorHandle* CreateLazyRemoteHandle(int64 op_id, int32 output_num,
                                              tensorflow::DataType dtype,
                                              Device* d, const bool is_ready,
                                              EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

  void Release() override;

  tensorflow::DataType DataType() const override;
  Status Shape(tensorflow::PartialTensorShape* shape) const override;
  Status NumDims(int* num_dims) const override;
  Status NumElements(int64* num_elements) const override;
  Status Dim(int dim_index, int64* dim) const override;
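
  // The shape accessors above may block until the handle is ready. A
  // hypothetical inspection sketch using them:
  //
  //   int num_dims = 0;
  //   TF_RETURN_IF_ERROR(h->NumDims(&num_dims));
  //   for (int i = 0; i < num_dims; ++i) {
  //     int64 dim = 0;
  //     TF_RETURN_IF_ERROR(h->Dim(i, &dim));
  //   }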

  const char* DeviceName(Status* status) const override;
  const char* BackingDeviceName(Status* status) const override;
  const char* DeviceType(Status* status) const override;
  int DeviceId(Status* status) const override;
  AbstractTensorInterface* Resolve(Status* status) override;

  ImmediateExecutionTensorHandle* Copy() override;

  // Return the Tensor from the default device.
  Status Tensor(const tensorflow::Tensor** t) const;
  // Return the Tensor from the specified device, which could be either the
  // default device or a local mirror. The device pointer should be nullptr if
  // requesting the HostCPU.
  Status TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const;

  // Return the TensorValue from the specified device, which could be either
  // the default device or a local mirror. The device pointer should be nullptr
  // if requesting the HostCPU.
  Status TensorValue(const Device* d, tensorflow::TensorValue* t);
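
  // Retrieval sketch: the returned Tensor pointer aliases storage owned by
  // the handle and must not outlive it. Passing nullptr requests the HostCPU,
  // per the comments above:
  //
  //   const tensorflow::Tensor* t = nullptr;
  //   TF_RETURN_IF_ERROR(h->TensorFromDevice(/*d=*/nullptr, &t));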
  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }
  int64 resource_remote_device_incarnation() const {
    return resource_remote_device_incarnation_;
  }

  // If the devices are unknown at creation time, block until the actual
  // devices are set (data is ready).
  Status WaitUnknownDevice() const;

  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);

  Status Unprotect(const Device* d);

  // Checks if a mirror tensor exists for the specified device. Mirrors are only
  // maintained for local devices, like CPUs & GPUs. Note that a mirror may be
  // empty, as it may not yet have been set by an async operation. A typical
  // lifecycle is sketched after these declarations.
  bool HasLocalMirror(const Device* d) const;
  // Add an empty mirror placeholder for the specified device. The expectation
  // is that this will be populated by a call to SetTensor.
  Status AddEmptyLocalMirror(const Device* d);
  // Add a local mirror. This will fail if an empty local mirror was previously
  // added; in that case, SetTensor should be used instead.
  Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d);
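
  // Sketch of the async mirror lifecycle referenced above (`dst` and
  // `copied_tensor` are illustrative names):
  //
  //   if (!h->HasLocalMirror(dst)) {
  //     TF_RETURN_IF_ERROR(h->AddEmptyLocalMirror(dst));
  //   }
  //   // ... once the async copy completes:
  //   TF_RETURN_IF_ERROR(h->SetTensor(std::move(copied_tensor), dst));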

#if !defined(IS_MOBILE_PLATFORM)
  bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
  bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;

  Status AddUnshapedRemoteMirror(const Device* d, int64 op_id, int output_num,
                                 const string& remote_task, EagerContext* ctx);
  Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
                                EagerContext* ctx);

  // Return the op_id and output num if the handle refers to a remote tensor.
  // If wait_until_ready is true, block until the remote tensor is ready on the
  // given remote worker.
  Status RemoteAddress(const Device* d, const bool wait_until_ready,
                       int64* op_id, int32* output_num) const;

  // Called on an async remote tensor once its shape has been determined. This
  // transitions the tensor handle from a non-ready to a ready state by
  // replacing the backing data abstraction to allow the shape to be queried.
  // This is needed when the shape is unknown at the time of creating a
  // TensorHandle (e.g. for a remote output of a remote function).
  // This method or PoisonRemote must be called exactly once for remote tensors
  // that were created without a known shape.
  Status SetRemoteShape(const TensorShape& shape, const Device* d,
                        uint64 context_view_id);
  // If op_device is not empty, reset the devices of a remote tensor which was
  // created without known devices (e.g. function outputs).
  Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d,
                                 uint64 context_view_id, string op_device);

  // Poisons either this handle or a remote mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the remote shape will return this error `status`.
  // Exactly one of SetRemoteShape or PoisonRemote must be called on an
  // unshaped handle on a remote device.
  void PoisonRemote(Status status, const Device* d, uint64 context_view_id);
#endif
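
  // Sketch of the remote readiness contract: exactly one of SetRemoteShape or
  // PoisonRemote finalizes an unshaped remote handle (`RunRemoteOp` and
  // `view_id` are illustrative placeholders):
  //
  //   Status s = RunRemoteOp(...);
  //   if (s.ok()) {
  //     s = h->SetRemoteShape(shape, remote_device, view_id);
  //   } else {
  //     h->PoisonRemote(s, remote_device, view_id);
  //   }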

  // Sets the `tensor` for this async non-ready handle, making it ready.
  // This method or Poison must be called exactly once for non-ready async
  // handles to make them ready.
  Status SetTensor(tensorflow::Tensor&& tensor, const Device* d);

  // Poisons either this handle or a local mirror with error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the actual tensor or shape will return this error `status`.
  // Exactly one of SetTensor or Poison must be called on a non-ready
  // tensor for a specific device.
  void Poison(Status status, const Device* d);
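
  // The analogous local contract, sketched for a handle made with
  // CreateEmptyLocalHandle (`ComputeTensor` is an illustrative producer):
  //
  //   tensorflow::Tensor result;
  //   Status s = ComputeTensor(&result);
  //   if (s.ok()) {
  //     s = h->SetTensor(std::move(result), d);
  //   } else {
  //     h->Poison(s, d);
  //   }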

  // TODO(b/154282629): Consider moving it to EagerContext.
  // Copies to the tensor on the given device `d`, or to host iff `d` is null.
  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d,
                      tensorflow::Tensor* output) const;

  Status InferenceShape(
      shape_inference::InferenceContext* const inference_context,
      shape_inference::ShapeHandle* shape_handle);
  void SetInferenceShape(
      shape_inference::InferenceContext* const inference_context,
      const shape_inference::ShapeHandle& shape_handle);
  Status CopyInferenceShape(TensorHandle* other);

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const tensorflow::DataType dtype;

  enum HandleType { LOCAL = 0, PACKED = 1, REMOTE = 2 };

  HandleType Type() const;
  string TypeString() const;

  string DebugString() const;

  void SetResourceHandleDtypeAndShape(
      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

  // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
  // return data types and shapes of the underlying resource.
  Status GetResourceHandleDtypesAndShapes(
      std::vector<DtypeAndPartialTensorShape>* result);

  // Returns the number of packed handles, or 0 if the handle type is not
  // PACKED.
  int NumPackedHandles() const;
  // Must be called on a packed TensorHandle; extracts the handle at the given
  // index.
  Status ExtractPackedHandle(const int index, TensorHandle** handle) const;
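
  // Illustrative iteration over a packed handle's components; each extracted
  // handle is assumed to be borrowed (no new reference is added):
  //
  //   if (h->Type() == TensorHandle::PACKED) {
  //     for (int i = 0; i < h->NumPackedHandles(); ++i) {
  //       TensorHandle* component = nullptr;
  //       TF_RETURN_IF_ERROR(h->ExtractPackedHandle(i, &component));
  //     }
  //   }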

  // For LLVM style RTTI.
  static bool classof(const AbstractTensorHandle* ptr) {
    return ptr->getKind() == kEager;
  }

 private:
  friend class PackedTensorHandleTest;

  TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
               const tensorflow::DataType dtype,
               const tensorflow::TensorShape& shape, EagerContext* ctx);

  ~TensorHandle() override;

  // The TensorHandleData can represent either a local or a remote tensor
  // handle. Further, it can be in a non-ready state; it becomes ready with a
  // call to either SetTensor or SetRemoteShape, which replaces the underlying
  // data with a ready version of the tensor handle data.
  bool IsReady() const;
  Status WaitReady(const char* caller) const;

  tensorflow::Device* device_;

  // Device on which the op producing this tensor was executed. Equals
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // with the function library runtime.
  tensorflow::Device* op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Otherwise resource_device_ is nullptr.
  tensorflow::Device* resource_device_;
  // Incarnation ID of the resource device if it resides on a remote device, or
  // 0 if it resides on a local device.
  int64 resource_remote_device_incarnation_;

  // If true, the handle refers to a remote tensor which was created without
  // known devices. The actual devices are set by SetRemoteShape. The devices
  // should only be accessed once the handle is ready.
  const bool unknown_device_ = false;

  mutable mutex mu_;

  // Map of local mirrors. This can include both ready and non-ready mirrors.
  std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
      local_mirrors_ TF_GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM)
  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
  // variable is ready, since we could get the shape locally without remote copy
  // then.
  std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
      TF_GUARDED_BY(mu_);
  // TODO(gjn): Is std::unordered_map the best choice here? Perhaps this should
  // be a fixed size map.
  std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
      TF_GUARDED_BY(mu_);
#endif

  // `ctx` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // `ctx` object is not owned and should outlive this handle.
  //
  // TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
  // a TensorHandle does not outlive the EagerContext from which it came?
  EagerContext* const ctx_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, is_poisoned_ is immutable.
  Status is_poisoned_;

  // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
  // refers to a remote resource handle, we store the data types and shapes of
  // the underlying resource.
  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

  // A handle data which refers to multiple TensorHandles of the same dtype and
  // shape.
  class PackedTensorHandleData {
   public:
    // Initialize handle data from a list of tensor handles.
    // Ownership of the tensor handles is shared between the
    // `PackedTensorHandleData` and the caller (the reference count for the
    // given handles is incremented).
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    PackedTensorHandleData(std::vector<TensorHandle*>&& handles,
                           const TensorShape& shape);

    ~PackedTensorHandleData();

    Status Shape(TensorShape* shape) const;
    Status NumDims(int* num_dims) const;
    Status Dim(int dim_index, int64* dim) const;
    Status NumElements(int64* num_elements) const;
    Status Unprotect();
    bool IsReady() const;
    Status WaitReady(const char* caller) const;
    void Poison(Status status);
    string DebugString() const;

    // Number of packed handles.
    int NumPackedHandles() const;
    // Extract the handle at the given index.
    Status ExtractPackedHandle(const int index, TensorHandle** handle) const;

   private:
    // TODO(b/170414377): Use `TensorHandlePtr` instead.
    const std::vector<TensorHandle*> handles_;
    const TensorShape shape_;

    mutable mutex mu_;
    Status is_poisoned_ TF_GUARDED_BY(mu_);
  };

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, data_ is immutable.
#if !defined(IS_MOBILE_PLATFORM)
  absl::variant<LocalTensorHandleData, PackedTensorHandleData,
                RemoteTensorHandleData>
      data_;
#else
  absl::variant<LocalTensorHandleData, PackedTensorHandleData> data_;
#endif

  PartialTensorShape inference_shape_;
};

// Returns the device backing the resource, or nullptr if there is none.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

class TensorHandleInterface : public ImmediateExecutionTensorHandle {
 public:
};

template <typename T>
inline TensorHandle* TensorHandleFromInterface(T* handle) {
  return down_cast<TensorHandle*>(handle);
}
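
// Usage sketch: recover the concrete TensorHandle from an
// ImmediateExecutionTensorHandle produced by this runtime. The down_cast is
// unchecked in release builds, so the caller is assumed to know the handle's
// provenance (e.g. via TensorHandle::classof):
//
//   ImmediateExecutionTensorHandle* generic = ...;
//   TensorHandle* concrete = TensorHandleFromInterface(generic);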

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_