/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

namespace {
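// Sentinels for remote handles whose op id / output number have not been
// assigned yet; RemoteAddress below rejects handles still set to these.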
#if !defined(IS_MOBILE_PLATFORM)
const int64 kInvalidOpId = -1;
const int32 kInvalidOutputNum = -1;
#endif
}  // namespace

void TensorHandle::SetResourceHandleDtypeAndShape(
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
  handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
}

Status TensorHandle::GetResourceHandleDtypesAndShapes(
    std::vector<DtypeAndPartialTensorShape>* result) {
  if (dtype != DT_RESOURCE) {
    return errors::InvalidArgument(
        "TensorHandle::GetResourceDtypeAndShape should be called on tensor "
        "handles with data type DT_RESOURCE. Actual tensor: ",
        dtype);
  }

  if (IsRemote()) {
    *result = handle_dtypes_and_shapes_;
    return Status::OK();
  }

  // Wait for this TensorHandle to be ready.
  profiler::TraceMe activity(
      "TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
      profiler::TraceMeLevel::kInfo);
  TF_RETURN_IF_ERROR(
      WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));

  *result = handle_dtypes_and_shapes_;
  return Status::OK();
}

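// The shorter CreateLocalHandle overloads forward to the full variant below;
// DT_RESOURCE tensors additionally record the dtypes and shapes carried by
// their ResourceHandle.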
Status TensorHandle::CreateLocalHandle(const class Tensor& t,
                                       TensorHandle** h) {
  // TODO(b/136608821): Move away from nullptr
  return CreateLocalHandle(t, /*d=*/nullptr, /*op_device=*/nullptr,
                           /*ctx=*/nullptr, h);
}

Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
                                       EagerContext* ctx, TensorHandle** h) {
  return CreateLocalHandle(t, d, d, ctx, h);
}

Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
                                       Device* op_device, EagerContext* ctx,
                                       TensorHandle** h) {
  if (t.dtype() != DT_RESOURCE) {
    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
                          t.dtype(), d, op_device, ctx);
  } else {
    const ResourceHandle& resource_handle = t.flat<class ResourceHandle>()(0);
    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
                          resource_handle, d, op_device, ctx);
  }

  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
                           DataType dtype, Device* d, Device* op_device,
                           EagerContext* ctx)
    : dtype(dtype),
      device_(d),
      op_device_(op_device),
      resource_device_(nullptr),
#if !defined(IS_MOBILE_PLATFORM)
      remote_op_id_(kInvalidOpId),
      remote_output_num_(kInvalidOutputNum),
#endif
      ctx_(ctx),
      is_remote_(false),
      is_async_(false),
      is_ready_(true),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
}

TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
                           const ResourceHandle& resource_handle, Device* d,
                           Device* op_device, EagerContext* ctx)
    : dtype(DT_RESOURCE),
      device_(d),
      op_device_(op_device),
      resource_device_(GetResourceDevice(resource_handle, ctx)),
#if !defined(IS_MOBILE_PLATFORM)
      remote_op_id_(kInvalidOpId),
      remote_output_num_(kInvalidOutputNum),
#endif
      ctx_(ctx),
      is_remote_(false),
      is_async_(false),
      is_ready_(true),
      handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
}

Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
                                            Device* op_device,
                                            Device* resource_device,
                                            DataType dtype, EagerContext* ctx,
                                            TensorHandle** h) {
  *h = new TensorHandle(absl::make_unique<EmptyLocalTensorHandleData>(), async,
                        d, op_device, resource_device, dtype, ctx);

  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
                           bool async, Device* d, Device* op_device,
                           Device* resource_device, DataType dtype,
                           EagerContext* ctx)
    : dtype(dtype),
      device_(d),
      op_device_(op_device),
      resource_device_(resource_device),
#if !defined(IS_MOBILE_PLATFORM)
      remote_op_id_(kInvalidOpId),
      remote_output_num_(kInvalidOutputNum),
#endif
      ctx_(ctx),
      is_remote_(false),
      is_async_(async),
      is_ready_(!async),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Async Local TensorHandle: " << this
           << " device: " << device_;
}

#if !defined(IS_MOBILE_PLATFORM)
Status TensorHandle::CreateRemoteHandle(
    std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d,
    Device* resource_device, EagerContext* ctx, TensorHandle** h) {
  *h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx);

  return Status::OK();
}

Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
                                        const TensorShape& shape,
                                        const string& remote_task,
                                        uint64 context_id, DataType dtype,
                                        Device* d, Device* resource_device,
                                        EagerContext* ctx, TensorHandle** h) {
  *h = new TensorHandle(
      absl::make_unique<RemoteTensorHandleData>(op_id, output_num, shape,
                                                remote_task, context_id, ctx),
      dtype, d, resource_device, ctx);
  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
                           DataType dtype, Device* d, Device* resource_device,
                           EagerContext* ctx)
    : dtype(dtype),
      device_(d),
      op_device_(d),
      resource_device_(resource_device),
      remote_op_id_(t->op_id()),
      remote_output_num_(t->output_num()),
      ctx_(ctx),
      is_remote_(true),
      is_async_(false),
      is_ready_(true),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Remote TensorHandle: " << this
           << " device: " << device_;
}

Status TensorHandle::CreateUnshapedRemoteHandle(
    std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
    Device* d, EagerContext* ctx, TensorHandle** h) {
  *h = new TensorHandle(std::move(t), dtype, d, ctx);

  return Status::OK();
}

Status TensorHandle::CreateUnshapedRemoteHandle(
    int64 op_id, int32 output_num, const string& remote_task,
    uint64 context_id, DataType dtype, Device* device, EagerContext* ctx,
    TensorHandle** h) {
  *h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>(
                            op_id, output_num, remote_task, context_id, ctx),
                        dtype, device, ctx);
  return Status::OK();
}

TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
                           DataType dtype, Device* device, EagerContext* ctx)
    : dtype(dtype),
      device_(device),
      op_device_(device),
      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
      remote_op_id_(t->op_id()),
      remote_output_num_(t->output_num()),
      remote_task_(t->remote_task()),
      remote_context_id_(t->context_id()),
      ctx_(ctx),
      is_remote_(true),
      is_async_(true),
      is_ready_(false),
      tensor_handle_data_(std::move(t)) {
  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
           << " device: " << device_;
}
#endif

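// Readiness protocol: local synchronous handles are born ready; async and
// remote handles become ready only once SetTensor, SetRemoteShape, or Poison
// flips is_ready_ under mu_, which unblocks waiters in WaitReady.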
bool TensorHandle::IsReady() const {
  // Avoid mutex acquisition for local sync handles
  if (!is_async_ && !is_remote_) {
    return true;
  }

  tf_shared_lock l(mu_);
  return is_ready_;
}

Status TensorHandle::WaitReady(const char* caller) const {
  if (!IsReady()) {
    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
                               profiler::TraceMeLevel::kInfo);
    tf_shared_lock l(mu_);
    mu_.Await(Condition(&is_ready_));
  }
  return is_poisoned_;
}

Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor"));
  return tensor_handle_data_->Tensor(t);
}

Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
  return tensor_handle_data_->TensorValue(t);
}

Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
  return (device_ == nullptr) ? ctx.HostCPU() : device_;
}

Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    bool fill = inference_shape_.AsTensorShape(shape);
    DCHECK(fill);
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Shape"));
    return tensor_handle_data_->Shape(shape);
  }
}

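// inference_shape_ caches the statically inferred shape so that shape queries
// on a not-yet-ready handle (see Shape, NumDims, Dim, NumElements) can be
// answered without blocking whenever the cache is defined enough.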
Status TensorHandle::InferenceShape(
    shape_inference::InferenceContext* const inference_context,
    shape_inference::ShapeHandle* shape_handle) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    std::vector<shape_inference::DimensionHandle> dims_handle;
    int num_dims;
    TF_RETURN_IF_ERROR(NumDims(&num_dims));
    for (int i = 0; i < num_dims; i++) {
      int64 dims;
      TF_RETURN_IF_ERROR(Dim(i, &dims));
      dims_handle.push_back(inference_context->MakeDim(dims));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  } else {
    if (inference_shape_.unknown_rank()) {
      *shape_handle = inference_context->UnknownShape();
      return Status::OK();
    }
    std::vector<shape_inference::DimensionHandle> dims_handle(
        inference_shape_.dims());
    for (int i = 0; i < dims_handle.size(); i++) {
      dims_handle[i] = inference_context->MakeDim(inference_shape_.dim_size(i));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  }
}

void TensorHandle::SetInferenceShape(
    shape_inference::InferenceContext* const inference_context,
    const shape_inference::ShapeHandle& shape_handle) {
  auto num_dims = inference_context->Rank(shape_handle);
  std::vector<int64> dims;
  if (num_dims == shape_inference::InferenceContext::kUnknownRank) {
    inference_shape_ = PartialTensorShape();
    return;
  }
  DCHECK_GE(num_dims, 0);
  dims.resize(num_dims);
  for (size_t i = 0; i < num_dims; ++i) {
    dims[i] = inference_context->Value(inference_context->Dim(shape_handle, i));
  }
  auto s = PartialTensorShape::MakePartialShape(dims.data(), num_dims,
                                                &inference_shape_);
  DCHECK(s.ok());
}

Status TensorHandle::CopyInferenceShape(TensorHandle* other) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    return Status::OK();
  }
  if (other->IsReady()) {
    TensorShape other_shape;
    TF_RETURN_IF_ERROR(other->Shape(&other_shape));
    inference_shape_ = other_shape;
  } else {
    inference_shape_ = other->inference_shape_;
  }
  return Status::OK();
}

Status TensorHandle::NumDims(int* num_dims) const {
  DCHECK(num_dims != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank()) {
    *num_dims = inference_shape_.dims();
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumDims"));
    return tensor_handle_data_->NumDims(num_dims);
  }
}

Status TensorHandle::Dim(int dim_index, int64* dim) const {
  DCHECK(dim != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank() &&
      inference_shape_.dim_size(dim_index) != -1) {
    *dim = inference_shape_.dim_size(dim_index);
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Dim"));
    return tensor_handle_data_->Dim(dim_index, dim);
  }
}

Status TensorHandle::NumElements(int64* num_elements) const {
  DCHECK(num_elements != nullptr);
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    *num_elements = inference_shape_.num_elements();
    return Status::OK();
  } else {
    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumElements"));
    return tensor_handle_data_->NumElements(num_elements);
  }
}

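// Remote mirror bookkeeping: a handle can track per-device copies of its
// remote tensor. Shaped mirrors live in remote_mirrors_, in-flight ones in
// unshaped_remote_mirrors_, and resource shapes in resource_shape_mirrors_.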
#if !defined(IS_MOBILE_PLATFORM)
Status TensorHandle::RemoteAddress(Device* d, int64* op_id,
                                   int32* output_num) const {
  if (d != device_) {
    tf_shared_lock l(mu_);
    auto mirror = remote_mirrors_.find(d);
    if (mirror != remote_mirrors_.end()) {
      *op_id = mirror->second->op_id();
      *output_num = mirror->second->output_num();
      return Status::OK();
    }

    auto unshaped_mirror = unshaped_remote_mirrors_.find(d);
    if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
      *op_id = unshaped_mirror->second->op_id();
      *output_num = unshaped_mirror->second->output_num();
      return Status::OK();
    }

    return errors::FailedPrecondition(
        "Could not find remote mirror for specified device");
  }

  if (remote_op_id_ == kInvalidOpId ||
      remote_output_num_ == kInvalidOutputNum) {
    return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_,
                                   ", output_num:", remote_output_num_,
                                   ") is not set.");
  }
  *op_id = remote_op_id_;
  *output_num = remote_output_num_;
  return Status::OK();
}

void TensorHandle::SetRemoteOpIdAndOutputNumToLocalTensorHandle(
    const int64 op_id, const int32 output_num) {
  DCHECK(!is_remote_);
  remote_op_id_ = op_id;
  remote_output_num_ = output_num;
}

bool TensorHandle::HasRemoteMirror(Device* d) {
  tf_shared_lock l(mu_);
  auto mirror = remote_mirrors_.find(d);
  if (mirror != remote_mirrors_.end()) {
    return true;
  }

  auto unshaped_mirror = unshaped_remote_mirrors_.find(d);
  if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
    return true;
  }

  return false;
}

bool TensorHandle::HasResourceShapeMirror(Device* d) {
  tf_shared_lock l(mu_);
  auto mirror = resource_shape_mirrors_.find(d);
  if (mirror != resource_shape_mirrors_.end()) {
    return true;
  }
  return false;
}

Status TensorHandle::AddUnshapedRemoteMirror(
    std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
  mutex_lock l(mu_);
  if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
    return errors::Internal("Attempted to duplicate a remote mirror.");
  }

  auto ret = unshaped_remote_mirrors_.insert(std::make_pair(d, std::move(t)));
  if (!ret.second) {
    return errors::Internal(
        "Attempted to duplicate an unshaped remote mirror.");
  }

  return Status::OK();
}

Status TensorHandle::AddResourceShapeMirror(
    std::unique_ptr<UnshapedRemoteTensorHandleData> t, Device* d) {
  mutex_lock l(mu_);
  auto ret = resource_shape_mirrors_.insert(std::make_pair(d, std::move(t)));
  if (!ret.second) {
    return errors::Internal("Attempted to duplicate a resource shape mirror.");
  }

  return Status::OK();
}

Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
                                     Device* d) {
  mutex_lock l(mu_);
  auto ret = remote_mirrors_.insert(std::make_pair(d, std::move(t)));
  if (!ret.second) {
    return errors::Internal("Attempted to duplicate a remote mirror.");
  }

  return Status::OK();
}

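// When the remote shape arrives, the waiting UnshapedRemoteTensorHandleData
// (for this handle or one of its mirrors) is replaced by a shaped
// RemoteTensorHandleData and the handle is marked ready.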
Status TensorHandle::SetRemoteShape(const TensorShape& shape,
                                    tensorflow::Device* d) {
  DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d;

  if (d != device_) {
    mutex_lock l(mu_);
    if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
      return errors::Internal(
          "Attempted to set remote shape for existing mirror.");
    }

    auto elem = unshaped_remote_mirrors_.find(d);
    if (elem == unshaped_remote_mirrors_.end()) {
      return errors::Internal(
          "Attempted to set remote shape for non-waiting mirror.");
    }

    auto& data = elem->second;
    data->ReleaseRemoteTensorHandle();
    remote_mirrors_[d] = absl::make_unique<RemoteTensorHandleData>(
        data->op_id(), data->output_num(), shape, data->remote_task(),
        data->context_id(), data->ctx());
    unshaped_remote_mirrors_.erase(elem);

    return Status::OK();
  }

  DCHECK(is_remote_) << "SetRemoteShape is only called on remote handles.";
  DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";

  UnshapedRemoteTensorHandleData* p =
      reinterpret_cast<UnshapedRemoteTensorHandleData*>(
          tensor_handle_data_.get());
  p->ReleaseRemoteTensorHandle();
  tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>(
      remote_op_id_, remote_output_num_, shape, remote_task_,
      remote_context_id_, ctx_);
  is_poisoned_ = Status::OK();
  mutex_lock l(mu_);
  is_ready_ = true;

  return Status::OK();
}
#endif

Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
  DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
  DCHECK(!is_async_ || !IsReady())
      << "SetTensor is only called on non-ready handles.";

  DVLOG(3) << "SetTensor on TensorHandle: " << this;

  if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
    auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
    handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
  }
  tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
  if (is_async_) {
    is_poisoned_ = Status::OK();
    mutex_lock l(mu_);
    is_ready_ = true;
  }

  return Status::OK();
}

void TensorHandle::Poison(Status status) {
  DCHECK(!is_async_ || !IsReady())
      << "Poison(status) can only be called on non-ready handle: " << this;

  DVLOG(3) << "Poison on TensorHandle: " << this;

  is_poisoned_ = status;
  mutex_lock l(mu_);
  is_ready_ = true;
}

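// Copies the tensor behind this handle to dstd. Same-device copies (including
// CPU-to-CPU) just return the source tensor; cross-device copies go through
// CopyTensor::ViaDMA after a srcd->Sync().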
Status TensorHandle::CopyToDevice(const EagerContext& ctx,
                                  tensorflow::Device* dstd,
                                  tensorflow::Tensor* output) {
  tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
  const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
  const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
  bool is_same_device =
      (srcd == dstd) || (srcd->name() == dstd->name()) || (dst_cpu && src_cpu);

  const tensorflow::Tensor* src = nullptr;
  TF_RETURN_IF_ERROR(Tensor(&src));
  if (is_same_device) {
    *output = *src;
    return Status::OK();
  }
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    return tensorflow::errors::InvalidArgument(
        "Can't copy Tensor with type ",
        tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(),
        ".");
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    *output = dst;
    return Status::OK();
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  TF_RETURN_IF_ERROR(srcd->Sync());
  tensorflow::Notification n;
  tensorflow::Status status;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 0 /*dev_to_dev_stream_index*/,
                                 [&status, &n](const tensorflow::Status& s) {
                                   status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  if (status.ok()) {
    *output = dst;
    return Status::OK();
  }
  return status;
}

Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
  }
  Device* device = nullptr;
  if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
    LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
    return nullptr;
  }
  return device;
}

string TensorHandle::DebugString() const {
  DVLOG(1) << "Calling TensorHandle::DebugString() on " << this;

  string out;
  strings::StrAppend(&out, "Device: ", device_ ? device_->DebugString() : "[]");
  // Consider supporting non-CPU tensors (when device_ is non-NULL) if needed.
  strings::StrAppend(&out, ", Tensor: ",
                     device_ ? "?" : tensor_handle_data_->DebugString(), "\n");
  return out;
}

}  // namespace tensorflow