/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

namespace {
#if !defined(IS_MOBILE_PLATFORM)
const int64 kInvalidOpId = -1;
const int32 kInvalidOutputNum = -1;
#endif
}  // namespace

void TensorHandle::SetResourceHandleDtypeAndShape(
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
  handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
}

Status TensorHandle::GetResourceHandleDtypesAndShapes(
    std::vector<DtypeAndPartialTensorShape>* result) {
  if (dtype != DT_RESOURCE) {
    return errors::InvalidArgument(
        "TensorHandle::GetResourceHandleDtypesAndShapes should be called on "
        "tensor handles with data type DT_RESOURCE. Actual tensor dtype: ",
        dtype);
  }

  if (IsRemote()) {
    *result = handle_dtypes_and_shapes_;
    return Status::OK();
  }

  // Wait for this TensorHandle to be ready.
  profiler::TraceMe activity(
      "TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(
      WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));

  *result = handle_dtypes_and_shapes_;
  return Status::OK();
}

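// A minimal usage sketch (illustrative only; the calling code below is an
// assumption about a typical caller, not part of this file). It wraps a host
// tensor in a handle and, assuming the handle is reference counted as usual,
// drops the reference when done:
//
//   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
//   TensorHandle* handle = nullptr;
//   TF_CHECK_OK(TensorHandle::CreateLocalHandle(t, &handle));
//   // ... use handle ...
//   handle->Unref();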
Status TensorHandle::CreateLocalHandle(const class Tensor& t,
                                       TensorHandle** h) {
  // TODO(b/136608821): Move away from nullptr
  return CreateLocalHandle(t, /*d=*/nullptr, /*op_device=*/nullptr,
                           /*ctx=*/nullptr, h);
}

Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
                                       EagerContext* ctx, TensorHandle** h) {
  return CreateLocalHandle(t, d, d, ctx, h);
}

Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
                                       Device* op_device, EagerContext* ctx,
                                       TensorHandle** h) {
  if (t.dtype() != DT_RESOURCE) {
    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
                          t.dtype(), d, op_device, ctx);
  } else {
    const ResourceHandle& resource_handle = t.flat<class ResourceHandle>()(0);
    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
                          resource_handle, d, op_device, ctx);
  }

  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
                           DataType dtype, Device* d, Device* op_device,
                           EagerContext* ctx)
    : dtype(dtype),
      device_(d),
      op_device_(op_device),
      resource_device_(nullptr),
#if !defined(IS_MOBILE_PLATFORM)
      remote_op_id_(kInvalidOpId),
      remote_output_num_(kInvalidOutputNum),
#endif
      ctx_(ctx),
      is_remote_(false),
      is_async_(false),
      is_ready_(true),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
}

TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
                           const ResourceHandle& resource_handle, Device* d,
                           Device* op_device, EagerContext* ctx)
    : dtype(DT_RESOURCE),
      device_(d),
      op_device_(op_device),
      resource_device_(GetResourceDevice(resource_handle, ctx)),
#if !defined(IS_MOBILE_PLATFORM)
      remote_op_id_(kInvalidOpId),
      remote_output_num_(kInvalidOutputNum),
#endif
      ctx_(ctx),
      is_remote_(false),
      is_async_(false),
      is_ready_(true),
      handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
}

Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
                                            Device* op_device,
                                            Device* resource_device,
                                            DataType dtype, EagerContext* ctx,
                                            TensorHandle** h) {
  *h = new TensorHandle(absl::make_unique<EmptyLocalTensorHandleData>(), async,
                        d, op_device, resource_device, dtype, ctx);

  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
                           bool async, Device* d, Device* op_device,
                           Device* resource_device, DataType dtype,
                           EagerContext* ctx)
    : dtype(dtype),
      device_(d),
      op_device_(op_device),
      resource_device_(resource_device),
#if !defined(IS_MOBILE_PLATFORM)
      remote_op_id_(kInvalidOpId),
      remote_output_num_(kInvalidOutputNum),
#endif
      ctx_(ctx),
      is_remote_(false),
      is_async_(async),
      is_ready_(!async),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Async Local TensorHandle: " << this
           << " device: " << device_;
}

#if !defined(IS_MOBILE_PLATFORM)
Status TensorHandle::CreateRemoteHandle(
    std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d,
    Device* resource_device, EagerContext* ctx, TensorHandle** h) {
  *h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx);

  return Status::OK();
}

Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
                                        const TensorShape& shape,
                                        const string& remote_task,
                                        uint64 context_id, DataType dtype,
                                        Device* d, Device* resource_device,
                                        EagerContext* ctx, TensorHandle** h) {
  *h = new TensorHandle(
      absl::make_unique<RemoteTensorHandleData>(op_id, output_num, shape,
                                                remote_task, context_id, ctx),
      dtype, d, resource_device, ctx);
  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
                           DataType dtype, Device* d, Device* resource_device,
                           EagerContext* ctx)
    : dtype(dtype),
      device_(d),
      op_device_(d),
      resource_device_(resource_device),
      remote_op_id_(t->op_id()),
      remote_output_num_(t->output_num()),
      ctx_(ctx),
      is_remote_(true),
      is_async_(false),
      is_ready_(true),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Remote TensorHandle: " << this
           << " device: " << device_;
}

Status TensorHandle::CreateUnshapedRemoteHandle(
    std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
    Device* d, EagerContext* ctx, TensorHandle** h) {
  *h = new TensorHandle(std::move(t), dtype, d, ctx);

  return Status::OK();
}

Status TensorHandle::CreateUnshapedRemoteHandle(
    int64 op_id, int32 output_num, const string& remote_task,
    uint64 context_id, DataType dtype, Device* device, EagerContext* ctx,
    TensorHandle** h) {
  *h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>(
                            op_id, output_num, remote_task, context_id, ctx),
                        dtype, device, ctx);
  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
                           DataType dtype, Device* device, EagerContext* ctx)
    : dtype(dtype),
      device_(device),
      op_device_(device),
      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
      remote_op_id_(t->op_id()),
      remote_output_num_(t->output_num()),
      remote_task_(t->remote_task()),
      remote_context_id_(t->context_id()),
      ctx_(ctx),
      is_remote_(true),
      is_async_(true),
      is_ready_(false),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
           << " device: " << device_;
}
#endif

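// A handle that is neither asynchronous nor remote is created ready and stays
// ready, so the common local/sync case can answer without touching mu_.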
bool TensorHandle::IsReady() const {
  // Avoid mutex acquisition for local sync handles
  if (!is_async_ && !is_remote_) {
    return true;
  }

  tf_shared_lock l(mu_);
  return is_ready_;
}

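// Blocks until is_ready_ is set (by SetTensor, SetRemoteShape or Poison) and
// then returns is_poisoned_, i.e. OK unless the producing op failed.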
Status TensorHandle::WaitReady(const char* caller) const {
  if (!IsReady()) {
    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
                               profiler::TraceMeLevel::kInfo);
    tf_shared_lock l(mu_);
    mu_.Await(Condition(&is_ready_));
  }
  return is_poisoned_;
}

Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor"));
  return tensor_handle_data_->Tensor(t);
}

Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
  return tensor_handle_data_->TensorValue(t);
}

Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
  return (device_ == nullptr) ? ctx.HostCPU() : device_;
}

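// Shape, NumDims, Dim and NumElements below answer from the inferred shape
// when the handle is not yet ready and the inferred value is sufficient,
// which avoids blocking the caller in WaitReady.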
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    bool fill = inference_shape_.AsTensorShape(shape);
    DCHECK(fill);
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Shape"));
    return tensor_handle_data_->Shape(shape);
  }
}

Status TensorHandle::InferenceShape(
    shape_inference::InferenceContext* const inference_context,
    shape_inference::ShapeHandle* shape_handle) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    std::vector<shape_inference::DimensionHandle> dims_handle;
    int num_dims;
    TF_RETURN_IF_ERROR(NumDims(&num_dims));
    for (int i = 0; i < num_dims; i++) {
      int64 dims;
      TF_RETURN_IF_ERROR(Dim(i, &dims));
      dims_handle.push_back(inference_context->MakeDim(dims));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  } else {
    if (inference_shape_.unknown_rank()) {
      *shape_handle = inference_context->UnknownShape();
      return Status::OK();
    }
    std::vector<shape_inference::DimensionHandle> dims_handle(
        inference_shape_.dims());
    for (int i = 0; i < dims_handle.size(); i++) {
      dims_handle[i] = inference_context->MakeDim(inference_shape_.dim_size(i));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  }
}

void TensorHandle::SetInferenceShape(
    shape_inference::InferenceContext* const inference_context,
    const shape_inference::ShapeHandle& shape_handle) {
  auto num_dims = inference_context->Rank(shape_handle);
  std::vector<int64> dims;
  if (num_dims == shape_inference::InferenceContext::kUnknownRank) {
    inference_shape_ = PartialTensorShape();
    return;
  }
  DCHECK_GE(num_dims, 0);
  dims.resize(num_dims);
  for (size_t i = 0; i < num_dims; ++i) {
    dims[i] = inference_context->Value(inference_context->Dim(shape_handle, i));
  }
  auto s = PartialTensorShape::MakePartialShape(dims.data(), num_dims,
                                                &inference_shape_);
  DCHECK(s.ok());
}

Status TensorHandle::CopyInferenceShape(TensorHandle* other) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    return Status::OK();
  }
  if (other->IsReady()) {
    TensorShape other_shape;
    TF_RETURN_IF_ERROR(other->Shape(&other_shape));
    inference_shape_ = other_shape;
  } else {
    inference_shape_ = other->inference_shape_;
  }
  return Status::OK();
}

Status TensorHandle::NumDims(int* num_dims) const {
  DCHECK(num_dims != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank()) {
    *num_dims = inference_shape_.dims();
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumDims"));
    return tensor_handle_data_->NumDims(num_dims);
  }
}

Status TensorHandle::Dim(int dim_index, int64* dim) const {
  DCHECK(dim != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank() &&
      inference_shape_.dim_size(dim_index) != -1) {
    *dim = inference_shape_.dim_size(dim_index);
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Dim"));
    return tensor_handle_data_->Dim(dim_index, dim);
  }
}

Status TensorHandle::NumElements(int64* num_elements) const {
  DCHECK(num_elements != nullptr);
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    *num_elements = inference_shape_.num_elements();
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumElements"));
    return tensor_handle_data_->NumElements(num_elements);
  }
}

#if !defined(IS_MOBILE_PLATFORM)
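// Returns the op id / output number that identify this handle's tensor on
// device `d`. For a device other than the handle's own, the answer comes from
// the (shaped or unshaped) remote mirror registered for that device.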
Status TensorHandle::RemoteAddress(Device* d, int64* op_id,
                                   int32* output_num) const {
  if (d != device_) {
    tf_shared_lock l(mu_);
    auto mirror = remote_mirrors_.find(d);
    if (mirror != remote_mirrors_.end()) {
      *op_id = mirror->second->op_id();
      *output_num = mirror->second->output_num();
      return Status::OK();
    }

    auto unshaped_mirror = unshaped_remote_mirrors_.find(d);
    if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
      *op_id = unshaped_mirror->second->op_id();
      *output_num = unshaped_mirror->second->output_num();
      return Status::OK();
    }

    return errors::FailedPrecondition(
        "Could not find remote mirror for specified device");
  }

  if (remote_op_id_ == kInvalidOpId ||
      remote_output_num_ == kInvalidOutputNum) {
    return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_,
                                   ", output_num:", remote_output_num_,
                                   ") is not set.");
  }
  *op_id = remote_op_id_;
  *output_num = remote_output_num_;
  return Status::OK();
}

void TensorHandle::SetRemoteOpIdAndOutputNumToLocalTensorHandle(
    const int64 op_id, const int32 output_num) {
  DCHECK(!is_remote_);
  remote_op_id_ = op_id;
  remote_output_num_ = output_num;
}

bool TensorHandle::HasRemoteMirror(Device* d) {
  tf_shared_lock l(mu_);
  auto mirror = remote_mirrors_.find(d);
  if (mirror != remote_mirrors_.end()) {
    return true;
  }

  auto unshaped_mirror = unshaped_remote_mirrors_.find(d);
  if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
    return true;
  }

  return false;
}

bool TensorHandle::HasResourceShapeMirror(Device* d) {
  tf_shared_lock l(mu_);
  auto mirror = resource_shape_mirrors_.find(d);
  if (mirror != resource_shape_mirrors_.end()) {
    return true;
  }
  return false;
}

Status TensorHandle::AddUnshapedRemoteMirror(
    std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
  mutex_lock l(mu_);
  if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
    return errors::Internal("Attempted to duplicate a remote mirror.");
  }

  auto ret = unshaped_remote_mirrors_.insert(std::make_pair(d, std::move(t)));
  if (!ret.second) {
    return errors::Internal(
        "Attempted to duplicate an unshaped remote mirror.");
  }

  return Status::OK();
}

Status TensorHandle::AddResourceShapeMirror(
    std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
  mutex_lock l(mu_);
  auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t)));
  if (!ret.second) {
    return errors::Internal("Attempted to duplicate a resource shape mirror.");
  }

  return Status::OK();
}

Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
                                     Device* d) {
  mutex_lock l(mu_);
  auto ret = remote_mirrors_.insert(std::make_pair(d, std::move(t)));
  if (!ret.second) {
    return errors::Internal("Attempted to duplicate a remote mirror.");
  }

  return Status::OK();
}

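// Records the shape reported by the remote worker. For a mirror on another
// device this promotes the unshaped mirror into a shaped
// RemoteTensorHandleData; for the handle itself it swaps in shaped data,
// clears the poison status and marks the handle ready.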
Status TensorHandle::SetRemoteShape(const TensorShape& shape,
                                    tensorflow::Device* d) {
  DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d;

  if (d != device_) {
    mutex_lock l(mu_);
    if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
      return errors::Internal(
          "Attempted to set remote shape for existing mirror.");
    }

    auto elem = unshaped_remote_mirrors_.find(d);
    if (elem == unshaped_remote_mirrors_.end()) {
      return errors::Internal(
          "Attempted to set remote shape for non-waiting mirror.");
    }

    auto& data = elem->second;
    data->ReleaseRemoteTensorHandle();
    remote_mirrors_[d] = absl::make_unique<RemoteTensorHandleData>(
        data->op_id(), data->output_num(), shape, data->remote_task(),
        data->context_id(), data->ctx());
    unshaped_remote_mirrors_.erase(elem);

    return Status::OK();
  }

  DCHECK(is_remote_) << "SetRemoteShape is only called on remote handles.";
  DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";

  UnshapedRemoteTensorHandleData* p =
      reinterpret_cast<UnshapedRemoteTensorHandleData*>(
          tensor_handle_data_.get());
  p->ReleaseRemoteTensorHandle();
  tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>(
      remote_op_id_, remote_output_num_, shape, remote_task_,
      remote_context_id_, ctx_);
  is_poisoned_ = Status::OK();
  mutex_lock l(mu_);
  is_ready_ = true;

  return Status::OK();
}
#endif

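// Fills in the tensor produced for this handle. For async handles this also
// marks the handle ready, unblocking any thread waiting in WaitReady.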
Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
  DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
  DCHECK(!is_async_ || !IsReady())
      << "SetTensor is only called on non-ready handles.";

  DVLOG(3) << "SetTensor on TensorHandle: " << this;

  if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
    auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
    handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
  }
  tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
  if (is_async_) {
    is_poisoned_ = Status::OK();
    mutex_lock l(mu_);
    is_ready_ = true;
  }

  return Status::OK();
}

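// Marks a non-ready handle as ready but failed: WaitReady, and therefore the
// Tensor/Shape accessors above, will return `status` instead of a value.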
void TensorHandle::Poison(Status status) {
  DCHECK(!is_async_ || !IsReady())
      << "Poison(status) can only be called on non-ready handle: " << this;

  DVLOG(3) << "Poison on TensorHandle: " << this;

  is_poisoned_ = status;
  mutex_lock l(mu_);
  is_ready_ = true;
}

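// Synchronously copies the tensor backing this handle to `dstd` and returns
// the result in `output`. Same-device (or CPU-to-CPU) copies alias the source
// tensor; cross-device copies go through CopyTensor::ViaDMA and block until
// the copy callback fires.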
Status TensorHandle::CopyToDevice(const EagerContext& ctx,
                                  tensorflow::Device* dstd,
                                  tensorflow::Tensor* output) {
  tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
  const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
  const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
  bool is_same_device =
      (srcd == dstd) || (srcd->name() == dstd->name()) || (dst_cpu && src_cpu);

  const tensorflow::Tensor* src = nullptr;
  TF_RETURN_IF_ERROR(Tensor(&src));
  if (is_same_device) {
    *output = *src;
    return Status::OK();
  }
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    return tensorflow::errors::InvalidArgument(
        "Can't copy Tensor with type ",
        tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(),
        ".");
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    *output = dst;
    return Status::OK();
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  TF_RETURN_IF_ERROR(srcd->Sync());
  tensorflow::Notification n;
  tensorflow::Status status;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 0 /*dev_to_dev_stream_index*/,
                                 [&status, &n](const tensorflow::Status& s) {
                                   status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  if (status.ok()) {
    *output = dst;
    return Status::OK();
  }
  return status;
}

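// Resolves the device named inside a resource handle through the
// EagerContext's device lookup. Returns nullptr if no context is available or
// the device cannot be found (the latter case is logged as an error).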
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
  }
  Device* device = nullptr;
  if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
    LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
    return nullptr;
  }
  return device;
}

string TensorHandle::DebugString() const {
  DVLOG(1) << "Calling TensorHandle::DebugString() on " << this;

  string out;
  strings::StrAppend(&out, "Device: ", device_ ? device_->DebugString() : "[]");
  // Consider supporting non-CPU tensors (when device_ is non-NULL) if needed.
  strings::StrAppend(&out, ", Tensor: ",
                     device_ ? "?" : tensor_handle_data_->DebugString(), "\n");
  return out;
}

}  // namespace tensorflow