/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

// clang-format off
// Required for IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/platform.h"
// clang-format on

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"

#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Associates a Tensor and a Device, used in the eager runtime. Internal
// version of the TFE_TensorHandle struct and the Python EagerTensor class
// (unrelated to the Python TensorHandle).
class TensorHandle : public core::RefCounted {
  // TensorHandle for dtype != DT_RESOURCE
  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
               Device* d, Device* op_device, EagerContext* ctx);
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
               const ResourceHandle& resource_handle, Device* d,
               Device* op_device, EagerContext* ctx);
  TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
               Device* d, Device* op_device, Device* resource_device,
               DataType dtype, EagerContext* ctx);

#if !defined(IS_MOBILE_PLATFORM)
  TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, DataType dtype,
               Device* d, Device* resource_device, EagerContext* ctx);
  TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
               DataType dtype, Device* device, EagerContext* ctx);
#endif  // IS_MOBILE_PLATFORM

 public:
  // TensorHandle with no assigned device
  static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
  // TensorHandle with device == op_device
  static Status CreateLocalHandle(const class Tensor& t, Device* d,
                                  EagerContext* ctx, TensorHandle** h);
  static Status CreateLocalHandle(const class Tensor& t, Device* d,
                                  Device* op_device, EagerContext* ctx,
                                  TensorHandle** h);
  static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
                                       Device* resource_device, DataType dtype,
                                       EagerContext* ctx, TensorHandle** h);
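
  // Example (illustrative sketch, not part of this header's API): wrapping an
  // existing tensor in a ref-counted handle. The surrounding function
  // returning Status is assumed for illustration.
  //
  //   tensorflow::Tensor t(tensorflow::DT_FLOAT,
  //                        tensorflow::TensorShape({2, 2}));
  //   tensorflow::TensorHandle* handle = nullptr;
  //   TF_RETURN_IF_ERROR(
  //       tensorflow::TensorHandle::CreateLocalHandle(t, &handle));
  //   // ... use the handle ...
  //   handle->Unref();  // TensorHandle is RefCounted; release when done.
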
#if !defined(IS_MOBILE_PLATFORM)
  static Status CreateRemoteHandle(int64 op_id, int output_num,
                                   const TensorShape& shape,
                                   const string& remote_task, uint64 context_id,
                                   DataType dtype, Device* d,
                                   Device* resource_device, EagerContext* ctx,
                                   TensorHandle** h);
  static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t,
                                   DataType dtype, Device* d,
                                   Device* resource_device, EagerContext* ctx,
                                   TensorHandle** h);
  static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
                                           const string& remote_task,
                                           uint64 context_id, DataType dtype,
                                           Device* device, EagerContext* ctx,
                                           TensorHandle** h);
  static Status CreateUnshapedRemoteHandle(
      std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
      Device* device, EagerContext* ctx, TensorHandle** h);
#endif  // IS_MOBILE_PLATFORM

  ~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }

  Status Tensor(const tensorflow::Tensor** t);

  Status TensorValue(tensorflow::TensorValue* t);

  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }

  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);
  Status NumDims(int* num_dims) const;
  Status Dim(int dim_index, int64* dim) const;
  Status NumElements(int64* num_elements) const;
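
  // Example (illustrative sketch): inspecting the shape of a handle. These
  // calls may block until the producing computation has completed and the
  // handle is ready. The surrounding function returning Status is assumed
  // for illustration.
  //
  //   int num_dims = 0;
  //   TF_RETURN_IF_ERROR(handle->NumDims(&num_dims));
  //   for (int i = 0; i < num_dims; ++i) {
  //     tensorflow::int64 dim = 0;
  //     TF_RETURN_IF_ERROR(handle->Dim(i, &dim));
  //     // ... use `dim` ...
  //   }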

#if !defined(IS_MOBILE_PLATFORM)
  bool HasRemoteMirror(Device* d);
  bool HasResourceShapeMirror(Device* d);

  Status AddUnshapedRemoteMirror(
      std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);
  Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, Device* d);
  Status AddResourceShapeMirror(
      std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d);

  // Return the op_id and output number if the handle refers to a remote
  // tensor.
  Status RemoteAddress(Device* d, int64* op_id, int32* output_num) const;
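
  // Example (illustrative sketch): recovering the remote coordinates of a
  // handle that refers to a remote tensor. `handle` and `device` are assumed
  // to be in scope, and the surrounding function returns Status.
  //
  //   tensorflow::int64 op_id = 0;
  //   tensorflow::int32 output_num = 0;
  //   TF_RETURN_IF_ERROR(handle->RemoteAddress(device, &op_id, &output_num));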

  // Set remote_op_id_ and remote_output_num_ if the handle refers to a local
  // tensor that needs to be copied to remote workers.
  void SetRemoteOpIdAndOutputNumToLocalTensorHandle(const int64 op_id,
                                                    const int32 output_num);

  // Called on an async remote tensor once its shape has been determined. This
  // transitions the tensor handle from a non-ready to a ready state by
  // replacing the backing data abstraction to allow the shape to be queried.
  // This method or Poison must be called exactly once for remote tensors that
  // were created without a known shape.
  Status SetRemoteShape(const TensorShape& shape, tensorflow::Device* d);
#endif

  // Sets the `tensor` for this async non-ready handle, making it ready.
  // This method or Poison must be called exactly once for non-ready async
  // handles to make them ready.
  Status SetTensor(tensorflow::Tensor&& tensor);

  // Poisons this non-ready handle with an error `status`.
  // Poisoning means that the handle will become ready and methods trying
  // to access the actual tensor or shape will return this error `status`.
  // Exactly one of SetTensor, SetRemoteShape, or Poison must be called on a
  // non-ready tensor.
  void Poison(Status status);
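
  // Example (illustrative sketch): the async lifecycle. A handle created via
  // CreateEmptyLocalHandle is non-ready until exactly one of SetTensor or
  // Poison is called on it. `ProduceResult` is a hypothetical producer, and
  // `device`, `ctx`, and `result` (a tensorflow::Tensor) are assumed to be
  // in scope.
  //
  //   tensorflow::TensorHandle* h = nullptr;
  //   TF_RETURN_IF_ERROR(tensorflow::TensorHandle::CreateEmptyLocalHandle(
  //       /*async=*/true, /*d=*/device, /*op_device=*/device,
  //       /*resource_device=*/nullptr, tensorflow::DT_FLOAT, ctx, &h));
  //   tensorflow::Status s = ProduceResult(&result);  // hypothetical
  //   if (s.ok()) {
  //     TF_RETURN_IF_ERROR(h->SetTensor(std::move(result)));
  //   } else {
  //     h->Poison(s);  // Readers of `h` now observe `s`.
  //   }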

  Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* dstd,
                      tensorflow::Tensor* output);

  Status InferenceShape(
      shape_inference::InferenceContext* const inference_context,
      shape_inference::ShapeHandle* shape_handle);
  void SetInferenceShape(
      shape_inference::InferenceContext* const inference_context,
      const shape_inference::ShapeHandle& shape_handle);
  Status CopyInferenceShape(TensorHandle* other);

  // Warning: can return nullptr for CPU tensors.
  // TODO(b/136608821): Move away from nullptr
  EagerContext* Context() { return ctx_; }

  // dtype for the handle. It must be the same as t.dtype() once the handle is
  // ready.
  const DataType dtype;

  // TODO(b/136608821): Move away from nullptr
  bool OnHostCPU() const {
    return device_ == nullptr ||
           (ctx_ != nullptr && ctx_->HostCPU() == device_);
  }

  bool IsRemote() const { return is_remote_; }

  string DebugString() const;

  void SetResourceHandleDtypeAndShape(
      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

  // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
  // return data types and shapes of the underlying resource.
  Status GetResourceHandleDtypesAndShapes(
      std::vector<DtypeAndPartialTensorShape>* result);
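
  // Example (illustrative sketch): querying resource metadata on a local
  // DT_RESOURCE handle. The surrounding function returning Status is assumed
  // for illustration.
  //
  //   std::vector<tensorflow::DtypeAndPartialTensorShape> dtypes_and_shapes;
  //   TF_RETURN_IF_ERROR(
  //       handle->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));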

 private:
  // The TensorHandleData can represent either a local or a remote tensor
  // handle. Further, it can be in a non-ready state; it becomes ready with a
  // call to either SetTensor or SetRemoteShape, which replaces the underlying
  // data with a ready version of the tensor handle data.
  bool IsReady() const;

  // If the contents of the Tensor pointed to by this handle are yet to be
  // computed by an EagerNode, this function will block until that computation
  // is done and the handle is "ready".
  Status WaitReady(const char* caller) const;

  // TODO(b/136608821): device_ == nullptr iff Host CPU:0
  // This was expedient, but perhaps worth revisiting ('device_' should always
  // be a valid pointer?)
  // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
  // provided with the appropriate TFE_Context.
  //
  // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
  // TFE_TensorHandle does not outlive the TFE_Context from which it came?
  tensorflow::Device* const device_;

  // Device on which the op producing this tensor was executed. Equals
  // device_ for constant tensors.
  // Can be nullptr if the op producing this tensor was a function executed
  // via the function library runtime.
  tensorflow::Device* const op_device_;

  // If the tensor dtype is DT_RESOURCE, resource_device_ holds the device
  // backing the resource. Else resource_device_ is nullptr.
  tensorflow::Device* const resource_device_;

  mutable mutex mu_;

#if !defined(IS_MOBILE_PLATFORM)
  // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
  // variable is ready, since we could get the shape locally without remote copy
  // then.
  std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
      resource_shape_mirrors_ GUARDED_BY(mu_);

  // TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
  // Consider replacing the unshaped_remote_mirrors_ map with something more
  // efficient.
  std::map<tensorflow::Device*, std::unique_ptr<UnshapedRemoteTensorHandleData>>
      unshaped_remote_mirrors_ GUARDED_BY(mu_);
  // TODO(gjn): Is std::map the best choice here? Perhaps this should be a
  // fixed-size map.
  std::map<tensorflow::Device*, std::unique_ptr<RemoteTensorHandleData>>
      remote_mirrors_ GUARDED_BY(mu_);

  // IDs required when this class is representing a remote tensor handle.
  int64 remote_op_id_;
  int32 remote_output_num_;
  string remote_task_;
  uint64 remote_context_id_;
#endif

  // `ctx_` is only guaranteed to be set if the handle is not "ready". This is
  // typically true when the handle was produced during async execution.
  // The `ctx_` object is not owned and should outlive this handle.
  EagerContext* const ctx_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, is_poisoned_ is immutable.
  Status is_poisoned_;
  const bool is_remote_;
  const bool is_async_;
  bool is_ready_ GUARDED_BY(mu_);

  // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
  // refers to a remote resource handle, we store data types and shapes for
  // the underlying resource.
  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

  // Does not need synchronization because it can be accessed only after
  // WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
  std::unique_ptr<TensorHandleData> tensor_handle_data_;

  PartialTensorShape inference_shape_;
};

// Returns the device backing the resource if there is one, and nullptr
// otherwise.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_H_