/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

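// Constructs a handle around an already-materialized local tensor. The handle
// is ready immediately (node_id_ == 0), so accessors never block on it.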
TensorHandle::TensorHandle(const class Tensor& t, Device* d, Device* op_device,
                           EagerContext* ctx)
    : dtype(t.dtype()),
      node_id_(0),
      tensor_(t),
      device_(d),
      op_device_(op_device),
      resource_device_(GetResourceDevice(t, ctx)),
      remote_op_id_(-1),
      remote_output_num_(-1),
      remote_shape_node_id_(-1),
      ctx_(ctx),
      is_ready_(true) {}

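// Constructs a handle whose tensor will be produced later by the executor node
// identified by `node_id`. The handle becomes ready once SetTensor() is
// called; with a null context it is treated as ready from the start.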
TensorHandle::TensorHandle(uint64 node_id, Device* d, Device* op_device,
                           Device* resource_device, DataType dtype,
                           EagerContext* ctx)
    : dtype(dtype),
      node_id_(node_id),
      tensor_(dtype),
      device_(d),
      op_device_(op_device),
      resource_device_(resource_device),
      remote_op_id_(-1),
      remote_output_num_(-1),
      remote_shape_node_id_(-1),
      ctx_(ctx),
      is_ready_(ctx == nullptr) {
  DCHECK_GT(node_id_, 0);
  DCHECK(dtype == DT_RESOURCE ? resource_device_ != nullptr
                              : resource_device_ == nullptr);
}

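// Constructs a handle that refers to a tensor living on a remote worker,
// identified by the (op_id, output_num) pair. Its shape becomes available once
// the node `remote_shape_node_id` completes; `call_on_destroy` is stored as a
// cleanup callback.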
TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                           uint64 remote_shape_node_id, DataType dtype,
                           std::function<void()> call_on_destroy, Device* d,
                           Device* op_device, Device* resource_device,
                           EagerContext* ctx)
    : dtype(dtype),
      node_id_(0),
      device_(d),
      op_device_(op_device),
      resource_device_(resource_device),
      remote_op_id_(op_id),
      remote_output_num_(output_num),
      remote_shape_node_id_(remote_shape_node_id),
      call_on_destroy_(std::move(call_on_destroy)),
      ctx_(ctx),
      is_ready_(true) {
  DCHECK(IsRemote()) << "Op ID and output num should be >= 0. Op ID: " << op_id
                     << ", Output num: " << output_num;
  DCHECK(dtype == DT_RESOURCE ? resource_device_ != nullptr
                              : resource_device_ == nullptr);
}

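// Constructs a handle for a symbolic (graph-mode) tensor. No device, context,
// or concrete tensor is associated with it; DebugString() prints the TF_Output.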
TensorHandle::TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype)
    : dtype(dtype),
      node_id_(0),
      device_(nullptr),
      op_device_(nullptr),
      resource_device_(nullptr),
      remote_op_id_(-1),
      remote_output_num_(-1),
      remote_shape_node_id_(-1),
      ctx_(nullptr),
      is_ready_(true),
      symbolic_tensor(new OutputGraphNode(symbolic_tensor)) {}

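// A handle created with node_id_ == 0 starts out ready; otherwise readiness is
// flipped by SetTensor() and must be read under ctx_mutex_.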
bool TensorHandle::IsReady() {
  if (node_id_ == 0) return true;
  mutex_lock l(ctx_mutex_);
  return is_ready_;
}

bool TensorHandle::IsRemote() {
  return remote_op_id_ >= 0 && remote_output_num_ >= 0;
}

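// Blocks until the executor node `node_id` has finished running, unless
// `return_if_is_ready` is set and the handle is already ready.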
Status TensorHandle::WaitForNode(uint64 node_id, bool return_if_is_ready) {
  if (node_id == 0) return Status::OK();
  EagerExecutor* executor = nullptr;
  {
    mutex_lock l(ctx_mutex_);
    if (return_if_is_ready && is_ready_) return Status::OK();
    executor = ctx_->Executor();
  }
  return executor->WaitFor(node_id);
}

Status TensorHandle::WaitReady() { return WaitForNode(node_id_, true); }

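// The accessors below block until the underlying tensor is ready. Remote
// handles have no local tensor, so Tensor() and TensorAndDevice() fail for
// them and the caller must first copy the handle to a local device.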
Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
  if (IsRemote()) {
    return errors::Unavailable(
        "Unable to get a tensor for a remote device. Please copy the tensor "
        "handle to a local device using TFE_TensorHandleCopyToDevice");
  }
  TF_RETURN_IF_ERROR(WaitReady());
  DCHECK(IsReady());
  *t = &tensor_;
  return Status::OK();
}

Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
  TF_RETURN_IF_ERROR(WaitReady());
  DCHECK(IsReady());
  *t = tensorflow::TensorValue(&tensor_);
  return Status::OK();
}

Status TensorHandle::TensorAndDevice(const tensorflow::Tensor** tensor,
                                     tensorflow::Device** device,
                                     tensorflow::Device** op_device) {
  if (IsRemote()) {
    return errors::Unavailable(
        "Unable to get a tensor for a remote device. Please copy the tensor "
        "handle to a local device using TFE_TensorHandleCopyToDevice");
  }
  TF_RETURN_IF_ERROR(WaitReady());
  DCHECK(IsReady());
  *tensor = &tensor_;
  *device = device_;
  *op_device = op_device_;
  return Status::OK();
}

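// Shape accessors. For remote handles these wait on the remote shape node and
// read remote_shape_; for local handles they wait for readiness and read the
// shape of tensor_.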
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
  if (IsRemote()) {
    TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
    CHECK(remote_shape_ != nullptr);
    *shape = *(remote_shape_.get());
  } else {
    TF_RETURN_IF_ERROR(WaitReady());
    DCHECK(IsReady());
    *shape = tensor_.shape();
  }
  return Status::OK();
}

Status TensorHandle::NumDims(int* num_dims) {
  if (IsRemote()) {
    TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
    *num_dims = remote_shape_->dims();
  } else {
    TF_RETURN_IF_ERROR(WaitReady());
    DCHECK(IsReady());
    DCHECK(num_dims != nullptr);

    *num_dims = tensor_.dims();
  }

  return Status::OK();
}

Status TensorHandle::Dim(int dim_index, int64* dim) {
  if (IsRemote()) {
    TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
    *dim = remote_shape_->dim_size(dim_index);
  } else {
    TF_RETURN_IF_ERROR(WaitReady());
    DCHECK(IsReady());
    DCHECK(dim != nullptr);

    *dim = tensor_.dim_size(dim_index);
  }

  return Status::OK();
}

Status TensorHandle::NumElements(int64* num_elements) {
  if (IsRemote()) {
    TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
    *num_elements = remote_shape_->num_elements();
  } else {
    TF_RETURN_IF_ERROR(WaitReady());
    DCHECK(IsReady());
    DCHECK(num_elements != nullptr);

    *num_elements = tensor_.NumElements();
  }

  return Status::OK();
}

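// Returns the (op_id, output_num) pair identifying the tensor on the remote
// worker; fails for handles that are not remote.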
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
  if (!IsRemote()) {
    return errors::FailedPrecondition(
        "This TensorHandle refers to a local tensor handle");
  }
  *op_id = remote_op_id_;
  *output_num = remote_output_num_;

  return Status::OK();
}

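// Fills in the tensor for a not-yet-ready handle and marks it ready. Only
// valid for handles created with a nonzero node_id_ that are still pending.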
void TensorHandle::SetTensor(const tensorflow::Tensor& tensor) {
  mutex_lock l(ctx_mutex_);
  DCHECK(node_id_ > 0 && !is_ready_) << "SetTensor should be only called "
                                     << "on non-ready handles.";
  is_ready_ = true;
  tensor_ = tensor;
}

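// Copies this handle's tensor to `dstd` and returns a new handle in `*output`.
// If source and destination name the same device the tensor buffer is reused
// without a copy; zero-element tensors skip the DMA transfer; otherwise the
// data is moved with CopyTensor::ViaDMA after synchronizing the source device.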
Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
                                  TensorHandle** output) {
  const tensorflow::Tensor* src = nullptr;
  tensorflow::Device* srcd = nullptr;
  // TODO(agarwal): src_opd is unused. Perhaps allow TensorAndDevice to accept
  // nullptr.
  tensorflow::Device* src_opd = nullptr;
  TF_RETURN_IF_ERROR(TensorAndDevice(&src, &srcd, &src_opd));
  if (srcd == nullptr) srcd = ctx->HostCPU();
  bool is_same_device = (srcd == dstd) || (srcd->name() == dstd->name());
  const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
  const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
  if (is_same_device) {
    *output = new tensorflow::TensorHandle(*src, dstd, dstd, ctx);
    return tensorflow::Status::OK();
  }
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    return tensorflow::errors::InvalidArgument(
        "Can't copy Tensor with type ",
        tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(),
        ".");
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    dstd = dst_cpu ? nullptr : dstd;
    *output = new tensorflow::TensorHandle(dst, dstd, dstd, ctx);
    return tensorflow::Status::OK();
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  TF_RETURN_IF_ERROR(srcd->Sync());
  tensorflow::Notification n;
  tensorflow::Status status;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 0 /*dev_to_dev_stream_index*/,
                                 [&status, &n](const tensorflow::Status& s) {
                                   status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  if (status.ok()) {
    dstd = dst_cpu ? nullptr : dstd;
    *output = new tensorflow::TensorHandle(dst, dstd, dstd, ctx);
  }
  return status;
}

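// For DT_RESOURCE tensors, returns the device named in the tensor's
// ResourceHandle by looking it up in the context's device map; returns nullptr
// for all other dtypes.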
Device* GetResourceDevice(const Tensor& t, EagerContext* ctx) {
  if (t.dtype() != DT_RESOURCE) {
    return nullptr;
  }
  const ResourceHandle& resource_handle = t.flat<ResourceHandle>()(0);
  const auto& map = *ctx->device_map();
  auto it = map.find(resource_handle.device());
  DCHECK(it != map.end());
  return it->second;
}

string TensorHandle::DebugString() const {
  VLOG(1) << "Calling TensorHandle::DebugString() on " << this;

  if (symbolic_tensor) {
    return absl::Substitute("TF_Output($0, $1)", symbolic_tensor->oper,
                            symbolic_tensor->index);
  }

  string out;
  strings::StrAppend(&out, "Device: ", device_ ? device_->DebugString() : "[]");
  // Consider supporting non-CPU tensors (when device_ is non-NULL) if needed.
  strings::StrAppend(&out, ", Tensor: ", device_ ? "?" : tensor_.DebugString(),
                     "\n");
  return out;
}

}  // namespace tensorflow