• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
16 
17 #include <algorithm>
18 #include <cstddef>
19 #include <map>
20 #include <memory>
21 #include <queue>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/core/common_runtime/copy_tensor.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_factory.h"
29 #include "tensorflow/core/common_runtime/eager/context.h"
30 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
31 #include "tensorflow/core/common_runtime/function.h"
32 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
33 #include "tensorflow/core/framework/rendezvous.h"
34 #include "tensorflow/core/framework/tensor.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/lib/core/stringpiece.h"
37 #include "tensorflow/core/lib/gtl/inlined_vector.h"
38 #include "tensorflow/core/lib/gtl/map_util.h"
39 #include "tensorflow/core/lib/gtl/stl_util.h"
40 #include "tensorflow/core/platform/fingerprint.h"
41 #include "tensorflow/core/platform/mutex.h"
42 #include "tensorflow/core/platform/thread_annotations.h"
43 #include "tensorflow/core/public/session_options.h"
44 #include "tensorflow/core/public/version.h"
45 
46 namespace tensorflow {
47 
// Constructs a local handle that is ready immediately: it wraps the already
// computed tensor `t`. `d` is recorded as the handle's device and
// `op_device` as its op device. resource_device_ comes from
// GetResourceDevice(t, ctx), which returns nullptr for non-DT_RESOURCE
// tensors and otherwise looks up the resource's device in ctx's device map.
// node_id_ is 0, so IsReady() returns true without taking the lock; the
// remote fields are -1, so IsRemote() is false.
TensorHandle(const class Tensor & t,Device * d,Device * op_device,EagerContext * ctx)48 TensorHandle::TensorHandle(const class Tensor& t, Device* d, Device* op_device,
49                            EagerContext* ctx)
50     : dtype(t.dtype()),
51       node_id_(0),  // 0 == "no executor node to wait for"
52       tensor_(t),
53       device_(d),
54       op_device_(op_device),
55       resource_device_(GetResourceDevice(t, ctx)),  // nullptr unless DT_RESOURCE
56       remote_op_id_(-1),
57       remote_output_num_(-1),
58       remote_shape_node_id_(-1),  // NOTE(review): if this member is uint64 (the remote ctor takes it as uint64), -1 wraps to the max value; it is only read for remote handles
59       ctx_(ctx),
60       is_ready_(true) {}
61 
// Constructs a local handle whose value is produced later by executor node
// `node_id`, which must be > 0. tensor_ is initialized from `dtype` alone
// (no data). SetTensor() replaces it and marks the handle ready; until then
// WaitReady() blocks on `node_id` through ctx's executor.
// If `ctx` is nullptr the handle starts out ready. WaitReady() then returns
// before WaitForNode() would dereference ctx_.
// `resource_device` must be non-null exactly when `dtype` is DT_RESOURCE
// (DCHECK'd below).
TensorHandle(uint64 node_id,Device * d,Device * op_device,Device * resource_device,DataType dtype,EagerContext * ctx)62 TensorHandle::TensorHandle(uint64 node_id, Device* d, Device* op_device,
63                            Device* resource_device, DataType dtype,
64                            EagerContext* ctx)
65     : dtype(dtype),
66       node_id_(node_id),
67       tensor_(dtype),  // value filled in by SetTensor()
68       device_(d),
69       op_device_(op_device),
70       resource_device_(resource_device),
71       remote_op_id_(-1),
72       remote_output_num_(-1),
73       remote_shape_node_id_(-1),
74       ctx_(ctx),
75       is_ready_(ctx == nullptr) {  // no context => nothing to wait on
76   DCHECK_GT(node_id_, 0);
77   DCHECK(dtype == DT_RESOURCE ? resource_device_ != nullptr
78                               : resource_device_ == nullptr);
79 }
80 
// Constructs a handle for a remote tensor identified by (`op_id`,
// `output_num`), the pair RemoteAddress() reports. Both must be
// non-negative, which is DCHECK'd via IsRemote().
// The handle is created ready (node_id_ == 0). Shape queries instead wait on
// executor node `remote_shape_node_id` and then read remote_shape_. That
// member is populated outside this file section; presumably a setter is
// declared in the header, so confirm.
// `call_on_destroy` is stored in call_on_destroy_. Presumably the destructor
// invokes it, but the destructor is not visible here, so confirm.
// tensor_ is left default-constructed; Tensor()/TensorAndDevice() return
// Unavailable for remote handles.
// `resource_device` must be non-null exactly when `dtype` is DT_RESOURCE.
TensorHandle(int64 op_id,int32 output_num,uint64 remote_shape_node_id,DataType dtype,std::function<void ()> call_on_destroy,Device * d,Device * op_device,Device * resource_device,EagerContext * ctx)81 TensorHandle::TensorHandle(int64 op_id, int32 output_num,
82                            uint64 remote_shape_node_id, DataType dtype,
83                            std::function<void()> call_on_destroy, Device* d,
84                            Device* op_device, Device* resource_device,
85                            EagerContext* ctx)
86     : dtype(dtype),
87       node_id_(0),
88       device_(d),
89       op_device_(op_device),
90       resource_device_(resource_device),
91       remote_op_id_(op_id),
92       remote_output_num_(output_num),
93       remote_shape_node_id_(remote_shape_node_id),
94       call_on_destroy_(std::move(call_on_destroy)),
95       ctx_(ctx),
96       is_ready_(true) {
97   DCHECK(IsRemote()) << "Op ID and output num should be >= 0. Op ID: " << op_id
98                      << ", Output num: " << output_num;
99   DCHECK(dtype == DT_RESOURCE ? resource_device_ != nullptr
100                               : resource_device_ == nullptr);
101 }
102 
// Constructs a symbolic handle that refers to a graph output rather than to
// a concrete value. A heap copy of `symbolic_tensor` is stored in the
// `symbolic_tensor` member. The member's ownership type is declared in the
// header; presumably it is a smart pointer, so confirm.
// No device, context or executor node is associated, and the handle reports
// ready. DebugString() prints such handles as TF_Output(oper, index).
TensorHandle(OutputGraphNode symbolic_tensor,DataType dtype)103 TensorHandle::TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype)
104     : dtype(dtype),
105       node_id_(0),
106       device_(nullptr),
107       op_device_(nullptr),
108       resource_device_(nullptr),
109       remote_op_id_(-1),
110       remote_output_num_(-1),
111       remote_shape_node_id_(-1),
112       ctx_(nullptr),
113       is_ready_(true),
114       symbolic_tensor(new OutputGraphNode(symbolic_tensor)) {}  // inside the parentheses `symbolic_tensor` names the parameter, which shadows the member
IsReady()116 bool TensorHandle::IsReady() {
117   if (node_id_ == 0) return true;
118   mutex_lock l(ctx_mutex_);
119   return is_ready_;
120 }
121 
IsRemote()122 bool TensorHandle::IsRemote() {
123   return remote_op_id_ >= 0 && remote_output_num_ >= 0;
124 }
125 
WaitForNode(uint64 node_id,bool return_if_is_ready)126 Status TensorHandle::WaitForNode(uint64 node_id, bool return_if_is_ready) {
127   if (node_id == 0) return Status::OK();
128   EagerExecutor* executor = nullptr;
129   {
130     mutex_lock l(ctx_mutex_);
131     if (return_if_is_ready && is_ready_) return Status::OK();
132     executor = ctx_->Executor();
133   }
134   return executor->WaitFor(node_id);
135 }
136 
WaitReady()137 Status TensorHandle::WaitReady() { return WaitForNode(node_id_, true); }
138 
Tensor(const tensorflow::Tensor ** t)139 Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
140   if (IsRemote()) {
141     return errors::Unavailable(
142         "Unable to get a tensor for a remote device. Please copy the tensor "
143         "handle to a local device using TFE_TensorHandleCopyToDevice");
144   }
145   TF_RETURN_IF_ERROR(WaitReady());
146   DCHECK(IsReady());
147   *t = &tensor_;
148   return Status::OK();
149 }
150 
TensorValue(tensorflow::TensorValue * t)151 Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
152   TF_RETURN_IF_ERROR(WaitReady());
153   DCHECK(IsReady());
154   *t = tensorflow::TensorValue(&tensor_);
155   return Status::OK();
156 }
157 
TensorAndDevice(const tensorflow::Tensor ** tensor,tensorflow::Device ** device,tensorflow::Device ** op_device)158 Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
159                                      tensorflow::Device** device,
160                                      tensorflow::Device** op_device) {
161   if (IsRemote()) {
162     return errors::Unavailable(
163         "Unable to get a tensor for a remote device. Please copy the tensor "
164         "handle to a local device using TFE_TensorHandleCopyToDevice");
165   }
166   TF_RETURN_IF_ERROR(WaitReady());
167   DCHECK(IsReady());
168   *tensor = &tensor_;
169   *device = device_;
170   *op_device = op_device_;
171   return Status::OK();
172 }
173 
Shape(tensorflow::TensorShape * shape)174 Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
175   if (IsRemote()) {
176     TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
177     CHECK(remote_shape_ != nullptr);
178     *shape = *(remote_shape_.get());
179   } else {
180     TF_RETURN_IF_ERROR(WaitReady());
181     DCHECK(IsReady());
182     *shape = tensor_.shape();
183   }
184   return Status::OK();
185 }
186 
NumDims(int * num_dims)187 Status TensorHandle::NumDims(int* num_dims) {
188   if (IsRemote()) {
189     TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
190     *num_dims = remote_shape_->dims();
191   } else {
192     TF_RETURN_IF_ERROR(WaitReady());
193     DCHECK(IsReady());
194     DCHECK(num_dims != nullptr);
195 
196     *num_dims = tensor_.dims();
197   }
198 
199   return Status::OK();
200 }
201 
Dim(int dim_index,int64 * dim)202 Status TensorHandle::Dim(int dim_index, int64* dim) {
203   if (IsRemote()) {
204     TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
205     *dim = remote_shape_->dim_size(dim_index);
206   } else {
207     TF_RETURN_IF_ERROR(WaitReady());
208     DCHECK(IsReady());
209     DCHECK(dim != nullptr);
210 
211     *dim = tensor_.dim_size(dim_index);
212   }
213 
214   return Status::OK();
215 }
216 
NumElements(int64 * num_elements)217 Status TensorHandle::NumElements(int64* num_elements) {
218   if (IsRemote()) {
219     TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
220     *num_elements = remote_shape_->num_elements();
221   } else {
222     TF_RETURN_IF_ERROR(WaitReady());
223     DCHECK(IsReady());
224     DCHECK(num_elements != nullptr);
225 
226     *num_elements = tensor_.NumElements();
227   }
228 
229   return Status::OK();
230 }
231 
RemoteAddress(int64 * op_id,int32 * output_num)232 Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
233   if (!IsRemote()) {
234     return errors::FailedPrecondition(
235         "This TensorHandle refers to a local tensor handle");
236   }
237   *op_id = remote_op_id_;
238   *output_num = remote_output_num_;
239 
240   return Status::OK();
241 }
242 
SetTensor(const tensorflow::Tensor & tensor)243 void TensorHandle::SetTensor(const tensorflow::Tensor& tensor) {
244   mutex_lock l(ctx_mutex_);
245   DCHECK(node_id_ > 0 && !is_ready_) << "SetTensor should be only called  "
246                                      << "on non-ready handles.";
247   is_ready_ = true;
248   tensor_ = tensor;
249 }
250 
// Creates, via `*output`, a new local TensorHandle that holds this handle's
// tensor on device `dstd`. The handle is allocated with `new`; presumably
// the caller takes ownership of the reference, so confirm against the header.
//
// Errors:
//   - Unavailable for remote handles (propagated from TensorAndDevice).
//   - InvalidArgument when the destination is a non-CPU device and the dtype
//     is neither DT_VARIANT nor memcpy-able.
// The function blocks until the copy has completed.
//
// Below, a device without tensorflow_gpu_device_info() is treated as CPU.
CopyToDevice(EagerContext * ctx,tensorflow::Device * dstd,TensorHandle ** output)251 Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
252                                   TensorHandle** output) {
253   const tensorflow::Tensor* src = nullptr;
254   tensorflow::Device* srcd = nullptr;
255   // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
256   // nullptr.
257   tensorflow::Device* src_opd = nullptr;
258   TF_RETURN_IF_ERROR(TensorAndDevice(&src, &srcd, &src_opd));
  // A null source device is treated as the context's host CPU.
259   if (srcd == nullptr) srcd = ctx->HostCPU();
  // Devices count as the same if either the pointers or the names match.
260   bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
261   const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
262   const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
  // Same device: no data transfer; the new handle wraps *src directly.
  // NOTE(review): unlike the paths below, this records `dstd` as the device
  // even for a CPU destination, whereas the other paths record nullptr for
  // CPU destinations. Confirm this asymmetry is intended.
263   if (is_same_device) {
264     *output = new tensorflow::TensorHandle(*src, dstd, dstd, ctx);
265     return tensorflow::Status::OK();
266   }
  // Non-CPU destinations accept only memcpy-able dtypes, plus DT_VARIANT.
267   if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
268                    !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
269     return tensorflow::errors::InvalidArgument(
270         "Can't copy Tensor with type ",
271         tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(),
272         ".");
273   }
  // DT_VARIANT destination buffers are requested with on_host set.
274   tensorflow::AllocatorAttributes attr;
275   if (src->dtype() == tensorflow::DT_VARIANT) {
276     attr.set_on_host(true);
277   }
278   tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  // Empty tensors have nothing to transfer. As on the success path below,
  // CPU destinations are recorded with a null device.
279   if (src->shape().num_elements() == 0) {
280     dstd = dst_cpu ? nullptr : dstd;
281     *output = new tensorflow::TensorHandle(dst, dstd, dstd, ctx);
282     return tensorflow::Status::OK();
283   }
  // GPU sides pass their default DeviceContext to ViaDMA; CPU sides pass null.
284   tensorflow::DeviceContext* src_device_context = nullptr;
285   if (!src_cpu) {
286     src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
287   }
288   tensorflow::DeviceContext* dst_device_context = nullptr;
289   if (!dst_cpu) {
290     dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
291   }
292   // TODO(ashankar): The Sync() call below may be more aggressive than
293   // necessary. It is based on knowledge of implementation details - that
294   // GPU devices are implemented using 3 streams - one for host->device copies,
295   // one for device->host copies and one for sending operations to the GPU.
296   // With that setup, Sync()ing across all 3 streams should be sufficient
297   // but more than necessary (since it waits for operations that might have
298   // nothing to do with this tensor to complete).
299   TF_RETURN_IF_ERROR(srcd->Sync());
  // ViaDMA reports completion through the callback, which may run
  // asynchronously. The callback stores the status and notifies `n`, and this
  // function blocks on `n`, so the by-reference captures of `status` and `n`
  // remain valid for the callback's lifetime.
  // NOTE(review): `dst` was allocated with `attr` (on_host for DT_VARIANT),
  // yet default AllocatorAttributes are passed for both sides. Confirm.
300   tensorflow::Notification n;
301   tensorflow::Status status;
302   tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
303                                  srcd, dstd, tensorflow::AllocatorAttributes(),
304                                  tensorflow::AllocatorAttributes(), src, &dst,
305                                  0 /*dev_to_dev_stream_index*/,
306                                  [&status, &n](const tensorflow::Status& s) {
307                                    status = s;
308                                    n.Notify();
309                                  });
310   n.WaitForNotification();
  // On failure, *output is left untouched and the copy error is returned.
311   if (status.ok()) {
312     dstd = dst_cpu ? nullptr : dstd;
313     *output = new tensorflow::TensorHandle(dst, dstd, dstd, ctx);
314   }
315   return status;
316 }
317 
GetResourceDevice(const Tensor & t,EagerContext * ctx)318 Device* GetResourceDevice(const Tensor& t, EagerContext* ctx) {
319   if (t.dtype() != DT_RESOURCE) {
320     return nullptr;
321   }
322   const ResourceHandle& resource_handle = t.flat<ResourceHandle>()(0);
323   const auto& map = *ctx->device_map();
324   auto it = map.find(resource_handle.device());
325   DCHECK(it != map.end());
326   return it->second;
327 }
328 
DebugString() const329 string TensorHandle::DebugString() const {
330   VLOG(1) << "Calling TensorHandle::DebugString() on " << this;
331 
332   if (symbolic_tensor) {
333     return absl::Substitute("TF_Output($0, $1)", symbolic_tensor->oper,
334                             symbolic_tensor->index);
335   }
336 
337   string out;
338   strings::StrAppend(&out, "Device: ", device_ ? device_->DebugString() : "[]");
339   // Consider supporting non-CPU tensors (when device_ is non-NULL) if needed.
340   strings::StrAppend(&out, ", Tensor: ", device_ ? "?" : tensor_.DebugString(),
341                      "\n");
342   return out;
343 }
344 
345 }  // namespace tensorflow
346