/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "absl/types/variant.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

namespace {
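// Returns the incarnation of a remote device; 0 means the device is local
// (or unset), since only remote devices carry a meaningful incarnation here.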
int64 GetRemoteDeviceIncarnation(Device* device) {
  if (device == nullptr || device->IsLocal()) return 0;
  return device->attributes().incarnation();
}

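// Null-safe wrapper around Device::DebugString() for logging.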
string SafeDeviceDebugString(Device* device) {
  if (device == nullptr) {
    return "[]";
  } else {
    return device->DebugString();
  }
}
}  // namespace

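// A packed handle pins each component handle with a reference for the
// lifetime of the PackedTensorHandleData; the destructor drops them.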
TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
    std::vector<TensorHandle*>&& handles, const TensorShape& shape)
    : handles_(std::move(handles)), shape_(shape) {
  for (auto* handle : handles_) {
    handle->Ref();
  }
}

TensorHandle::PackedTensorHandleData::~PackedTensorHandleData() {
  for (auto* handle : handles_) {
    handle->Unref();
  }
}

Status TensorHandle::PackedTensorHandleData::Shape(TensorShape* shape) const {
  *shape = shape_;
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::NumDims(int* num_dims) const {
  *num_dims = shape_.dims();
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::Dim(int dim_index,
                                                 int64* dim) const {
  *dim = shape_.dim_size(dim_index);
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::NumElements(
    int64* num_elements) const {
  *num_elements = shape_.num_elements();
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::Unprotect() {
  for (auto* handle : handles_) {
    TF_RETURN_IF_ERROR(absl::visit([](auto& data) { return data.Unprotect(); },
                                   handle->data_));
  }
  return Status::OK();
}

bool TensorHandle::PackedTensorHandleData::IsReady() const {
  {
    tf_shared_lock l(mu_);
    if (!is_poisoned_.ok()) {
      return true;
    }
  }
  for (auto* handle : handles_) {
    if (!handle->IsReady()) {
      return false;
    }
  }
  return true;
}

Status TensorHandle::PackedTensorHandleData::WaitReady(
    const char* caller) const {
  {
    tf_shared_lock l(mu_);
    if (!is_poisoned_.ok()) {
      return is_poisoned_;
    }
  }
  for (auto* handle : handles_) {
    TF_RETURN_IF_ERROR(handle->WaitReady(caller));
  }
  return Status::OK();
}

void TensorHandle::PackedTensorHandleData::Poison(Status status) {
  mutex_lock l(mu_);
  is_poisoned_ = status;
}

string TensorHandle::PackedTensorHandleData::DebugString() const {
  string debug_str = "PackedTensorHandleData: ";
  for (const auto* handle : handles_) {
    debug_str.append(
        absl::StrCat(absl::visit([](auto& data) { return data.DebugString(); },
                                 handle->data_),
                     "; "));
  }
  return debug_str;
}

int TensorHandle::PackedTensorHandleData::NumPackedHandles() const {
  return handles_.size();
}

Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
    const int index, TensorHandle** handle) const {
  if (index < 0 || index >= handles_.size()) {
    return errors::InvalidArgument("Expect an index within [0, ",
                                   handles_.size(), "), but got ", index);
  }
  *handle = handles_.at(index);
  return Status::OK();
}

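// For DT_RESOURCE handles, the dtypes and shapes of the pointed-to resource
// are cached on the handle itself, so callers can answer metadata queries
// without dereferencing the (possibly remote) resource tensor.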
void TensorHandle::SetResourceHandleDtypeAndShape(
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
  handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
}

Status TensorHandle::GetResourceHandleDtypesAndShapes(
    std::vector<DtypeAndPartialTensorShape>* result) {
  if (dtype != DT_RESOURCE) {
    return errors::InvalidArgument(
        "TensorHandle::GetResourceDtypeAndShape should be called on tensor "
        "handles with data type DT_RESOURCE. Actual tensor: ",
        dtype);
  }

  if (Type() != LOCAL) {
    *result = handle_dtypes_and_shapes_;
    return Status::OK();
  }

  // Wait for this TensorHandle to be ready.
  profiler::TraceMe activity("TensorHandle::GetResourceHandleInfo WaitReady",
                             profiler::TraceMeLevel::kInfo);
  auto& data = absl::get<LocalTensorHandleData>(data_);
  TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo"));

  *result = handle_dtypes_and_shapes_;
  return Status::OK();
}

int TensorHandle::NumPackedHandles() const {
  if (Type() != PACKED) {
    return 0;
  }
  return absl::get<PackedTensorHandleData>(data_).NumPackedHandles();
}

Status TensorHandle::ExtractPackedHandle(const int index,
                                         TensorHandle** handle) const {
  if (Type() != PACKED) {
    return errors::Internal("Invalid ExtractPackedHandleOnDevice call on a ",
                            TypeString(), " handle: ", this);
  }
  return absl::get<PackedTensorHandleData>(data_).ExtractPackedHandle(index,
                                                                      handle);
}

TensorHandle* TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t) {
  // TODO(b/136608821): Move away from nullptr
  tensorflow::Tensor tensor = t;
  return CreateLocalHandle(std::move(tensor),
                           /*d=*/nullptr,
                           /*op_device=*/nullptr,
                           /*ctx=*/nullptr);
}

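// A minimal usage sketch (illustrative only; surrounding eager bookkeeping
// is out of scope here):
//
//   tensorflow::Tensor t(DT_FLOAT, TensorShape({2, 2}));
//   TensorHandle* h = TensorHandle::CreateLocalHandle(t);
//   const tensorflow::Tensor* view = nullptr;
//   TF_CHECK_OK(h->Tensor(&view));  // LOCAL handle, so this does not block.
//   h->Unref();                     // Release the caller's reference.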
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                              Device* op_device,
                                              EagerContext* ctx) {
  return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx);
}

TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                              Device* op_device,
                                              Device* resource_device,
                                              EagerContext* ctx) {
  if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
    return new TensorHandle(std::move(t), d, op_device, ctx);
  } else {
    return new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
  }
}

TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
                           Device* resource_device, EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(t.dtype()),
      device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
      op_device_(op_device),
      resource_device_(resource_device),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_)
           << " tensor: " << t.DeviceSafeDebugString();
}

TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
                           EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(DT_RESOURCE),
      device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
      op_device_(op_device),
      resource_device_(
          GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      handle_dtypes_and_shapes_(
          t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_)
           << " tensor: " << t.DeviceSafeDebugString();
}

TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
                                                   Device* resource_device,
                                                   tensorflow::DataType dtype,
                                                   EagerContext* ctx) {
  return new TensorHandle(d, op_device, resource_device, dtype, ctx);
}

TensorHandle::TensorHandle(Device* d, Device* op_device,
                           Device* resource_device, tensorflow::DataType dtype,
                           EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_((d == ctx->HostCPU()) ? nullptr : d),
      op_device_(op_device),
      resource_device_(resource_device),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<LocalTensorHandleData>) {
  DVLOG(3) << "Creating empty Local TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}

Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                        const tensorflow::DataType dtype,
                                        const tensorflow::TensorShape& shape,
                                        const string& device_name,
                                        EagerContext* ctx,
                                        TensorHandle** packed_handle) {
  if (handles.empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
  if (dtype == DT_RESOURCE) {
    TF_RETURN_IF_ERROR(
        handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
  }
  std::vector<string> devices;
  devices.reserve(handles.size());
  for (auto* handle : handles) {
    devices.push_back(handle->op_device() ? handle->op_device()->name()
                                          : ctx->HostCPU()->name());
  }

  CompositeDevice* composite_device = nullptr;
  TF_RETURN_IF_ERROR(ctx->FindOrCreateCompositeDevice(devices, device_name,
                                                      &composite_device));
  *packed_handle =
      new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
  (*packed_handle)
      ->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
  return Status::OK();
}

Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                        EagerContext* ctx,
                                        TensorHandle** packed_handle) {
  if (handles.empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  // Get the dtype and shape from the first handle since all handles have the
  // same dtype and shape.
  tensorflow::DataType dtype = handles.at(0)->dtype;
  tensorflow::TensorShape shape;
  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
  return CreatePackedHandle(std::move(handles), dtype, shape,
                            /*device_name*/ "", ctx, packed_handle);
}

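// A minimal usage sketch (illustrative only; `h0` and `h1` are hypothetical
// same-dtype/same-shape handles on different devices, `ctx` a live context):
//
//   std::vector<TensorHandle*> parts = {h0, h1};
//   TensorHandle* packed = nullptr;
//   TF_RETURN_IF_ERROR(
//       TensorHandle::CreatePackedHandle(std::move(parts), ctx, &packed));
//   // `packed` lives on a composite device and holds a ref on h0 and h1.
//   packed->Unref();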
TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
                           const tensorflow::DataType dtype,
                           const tensorflow::TensorShape& shape,
                           EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_(device),
      op_device_(device),
      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
            shape) {
  DVLOG(3) << "Creating a packed TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}

#if !defined(IS_MOBILE_PLATFORM)
TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
    int64 op_id, int32 output_num, const string& remote_task,
    tensorflow::DataType dtype, Device* d, EagerContext* ctx,
    const bool unknown_device) {
  return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx,
                          unknown_device);
}

TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                           const string& remote_task,
                           tensorflow::DataType dtype, Device* d,
                           EagerContext* ctx, const bool unknown_device)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_(d),
      op_device_(d),
      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      unknown_device_(unknown_device),
      ctx_(ctx),
      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
            remote_task, ctx) {
  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}

TensorHandle* TensorHandle::CreateLazyRemoteHandle(
    int64 op_id, int32 output_num, tensorflow::DataType dtype, Device* d,
    const bool is_ready, EagerContext* ctx) {
  return new TensorHandle(op_id, output_num, dtype, d, is_ready, ctx);
}

TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                           tensorflow::DataType dtype, Device* d,
                           const bool is_ready, EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_(d),
      op_device_(d),
      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
            ctx->GetContextViewId(), is_ready) {
  DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}
#endif

TensorHandle::~TensorHandle() { DVLOG(3) << "Deleting tensor handle " << this; }

void TensorHandle::Release() {
  DVLOG(3) << "Releasing tensor handle " << this;
  Unref();
}

tensorflow::DataType TensorHandle::DataType() const { return dtype; }

bool TensorHandle::IsReady() const {
  return absl::visit([](auto& data) { return data.IsReady(); }, data_);
}

Status TensorHandle::WaitReady(const char* caller) const {
  return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
                     data_);
}

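// data_ holds exactly one of LocalTensorHandleData, PackedTensorHandleData,
// or RemoteTensorHandleData; the variant index encodes the handle type in
// that order.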
TensorHandle::HandleType TensorHandle::Type() const {
  if (data_.index() == 0) {
    return LOCAL;
  } else if (data_.index() == 1) {
    return PACKED;
  } else {
    return REMOTE;
  }
}

string TensorHandle::TypeString() const {
  if (data_.index() == 0) {
    return "LOCAL";
  } else if (data_.index() == 1) {
    return "PACKED";
  } else {
    return "REMOTE";
  }
}

Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
  DVLOG(3) << "Tensor on TensorHandle: " << this;

  if (Type() != LOCAL) {
    return errors::Internal("Invalid Tensor call on a ", TypeString(),
                            " handle: ", this);
  }

  auto& data = absl::get<LocalTensorHandleData>(data_);
  return data.Tensor(t);
}

Status TensorHandle::TensorFromDevice(const Device* d,
                                      const tensorflow::Tensor** t) const {
  DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    if (Type() != LOCAL) {
      return errors::Internal("Invalid Tensor call on a ", TypeString(),
                              " handle: ", this);
    }

    auto& data = absl::get<LocalTensorHandleData>(data_);
    return data.Tensor(t);
  }

  tf_shared_lock l(mu_);
  auto elem = local_mirrors_.find(d);
  if (elem == local_mirrors_.end()) {
    return errors::Internal("Invalid device: ", d,
                            " in Tensor call to handle: ", this);
  }

  auto& mirror = elem->second;
  return mirror.Tensor(t);
}

Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
  DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    if (Type() != LOCAL) {
      return errors::Internal("Invalid TensorValue call on a ", TypeString(),
                              " handle: ", this);
    }

    auto& data = absl::get<LocalTensorHandleData>(data_);
    return data.TensorValue(t);
  }

  tf_shared_lock l(mu_);
  auto elem = local_mirrors_.find(d);
  if (elem == local_mirrors_.end()) {
    return errors::Internal("Invalid device: ", d,
                            " in TensorValue call to handle: ", this);
  }

  auto& mirror = elem->second;
  return mirror.TensorValue(t);
}

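// For handles created with an unknown device, block until the producer has
// filled the device in (see SetRemoteShapeAndDevice below).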
Status TensorHandle::WaitUnknownDevice() const {
  if (unknown_device_) {
    TF_RETURN_IF_ERROR(absl::visit(
        [](auto& data) {
          return data.WaitReady("TensorHandle::UnknownDevice");
        },
        data_));
  }
  return Status::OK();
}

Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
  return (device_ == nullptr) ? ctx.HostCPU() : device_;
}

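// The shape queries below prefer the statically inferred shape when the
// handle is not ready yet, so shape-inference results can be served without
// blocking on the (possibly remote) computation.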
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    bool fill = inference_shape_.AsTensorShape(shape);
    DCHECK(fill);
    return Status::OK();
  } else {
    return absl::visit([shape](auto& data) { return data.Shape(shape); },
                       data_);
  }
}

Status TensorHandle::InferenceShape(
    shape_inference::InferenceContext* const inference_context,
    shape_inference::ShapeHandle* shape_handle) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    std::vector<shape_inference::DimensionHandle> dims_handle;
    int num_dims;
    TF_RETURN_IF_ERROR(NumDims(&num_dims));
    for (int i = 0; i < num_dims; i++) {
      int64 dims;
      TF_RETURN_IF_ERROR(Dim(i, &dims));
      dims_handle.push_back(inference_context->MakeDim(dims));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  } else {
    if (inference_shape_.unknown_rank()) {
      *shape_handle = inference_context->UnknownShape();
      return Status::OK();
    }
    std::vector<shape_inference::DimensionHandle> dims_handle(
        inference_shape_.dims());
    for (int i = 0; i < dims_handle.size(); i++) {
      dims_handle[i] = inference_context->MakeDim(inference_shape_.dim_size(i));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  }
}

void TensorHandle::SetInferenceShape(
    shape_inference::InferenceContext* const inference_context,
    const shape_inference::ShapeHandle& shape_handle) {
  auto num_dims = inference_context->Rank(shape_handle);
  std::vector<int64> dims;
  if (num_dims == shape_inference::InferenceContext::kUnknownRank) {
    inference_shape_ = PartialTensorShape();
    return;
  }
  DCHECK_GE(num_dims, 0);
  dims.resize(num_dims);
  for (size_t i = 0; i < num_dims; ++i) {
    dims[i] = inference_context->Value(inference_context->Dim(shape_handle, i));
  }
  auto s = PartialTensorShape::MakePartialShape(dims.data(), num_dims,
                                                &inference_shape_);
  DCHECK(s.ok());
}

Status TensorHandle::CopyInferenceShape(TensorHandle* other) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    return Status::OK();
  }
  if (other->IsReady()) {
    TensorShape other_shape;
    TF_RETURN_IF_ERROR(other->Shape(&other_shape));
    inference_shape_ = other_shape;
  } else {
    inference_shape_ = other->inference_shape_;
  }
  return Status::OK();
}

Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const {
  DCHECK(shape != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank()) {
    *shape = inference_shape_;
    return Status::OK();
  } else {
    auto result = absl::visit(
        [](auto& data) {
          TensorShape shape;
          Status s = data.Shape(&shape);
          return std::make_pair(shape, s);
        },
        data_);
    TF_RETURN_IF_ERROR(result.second);
    *shape = result.first;
  }
  return Status::OK();
}

Status TensorHandle::NumDims(int* num_dims) const {
  DCHECK(num_dims != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank()) {
    *num_dims = inference_shape_.dims();
    return Status::OK();
  } else {
    return absl::visit(
        [num_dims](auto& data) { return data.NumDims(num_dims); }, data_);
  }
}

Status TensorHandle::Dim(int dim_index, int64* dim) const {
  DCHECK(dim != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank() &&
      inference_shape_.dim_size(dim_index) != -1) {
    *dim = inference_shape_.dim_size(dim_index);
    return Status::OK();
  } else {
    return absl::visit(
        [dim_index, dim](auto& data) { return data.Dim(dim_index, dim); },
        data_);
  }
}

Status TensorHandle::NumElements(int64* num_elements) const {
  DCHECK(num_elements != nullptr);
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    *num_elements = inference_shape_.num_elements();
    return Status::OK();
  } else {
    return absl::visit(
        [num_elements](auto& data) { return data.NumElements(num_elements); },
        data_);
  }
}

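// Unprotect lifts the forwarding protection on the underlying buffer so a
// consuming op may reuse it in place; the precise semantics are delegated
// to the per-type data implementations.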
Status TensorHandle::Unprotect(const Device* d) {
  DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    return absl::visit([](auto& data) { return data.Unprotect(); }, data_);
  }

  tf_shared_lock l(mu_);
  auto elem = local_mirrors_.find(d);
  if (elem == local_mirrors_.end()) {
    return errors::Internal("Invalid device: ", d,
                            " in Unprotect call to handle: ", this);
  }

  // Check if the handle is non-empty
  auto& mirror = elem->second;
  return mirror.Unprotect();
}

bool TensorHandle::HasLocalMirror(const Device* d) const {
  DVLOG(3) << "HasLocalMirror on TensorHandle: " << this << " device: " << d;

  tf_shared_lock l(mu_);
  return local_mirrors_.find(d) != local_mirrors_.end();
}

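// Registers an empty mirror for device d; its contents are provided later
// via SetTensor. Mirroring onto the primary device is rejected.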
Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
  DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this
           << " device: " << d;

  if (d == device_) {
    return errors::Internal("Cannot add mirror for primary device.");
  }

  mutex_lock l(mu_);
  if (local_mirrors_.find(d) != local_mirrors_.end()) {
    return errors::AlreadyExists("Attempted to duplicate a local mirror.");
  }

  local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
                         std::forward_as_tuple());

  return Status::OK();
}

#if !defined(IS_MOBILE_PLATFORM)
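// Returns the (op_id, output_num) pair identifying this tensor on the
// remote worker, taken from the primary remote data or from the remote
// mirror registered for device d.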
Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready,
                                   int64* op_id, int32* output_num) const {
  DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (d != device_) {
    tf_shared_lock l(mu_);
    auto mirror = remote_mirrors_.find(d->name());
    if (mirror != remote_mirrors_.end()) {
      return mirror->second.OpIdAndOutputNum(wait_until_ready, op_id,
                                             output_num);
    }

    return errors::FailedPrecondition(
        "Could not find remote mirror for specified device");
  }

  if (Type() != REMOTE) {
    return errors::InvalidArgument("Primary device is not remote");
  }

  auto& data = absl::get<RemoteTensorHandleData>(data_);
  return data.OpIdAndOutputNum(wait_until_ready, op_id, output_num);
}

bool TensorHandle::HasRemoteMirror(const Device* d,
                                   uint64 context_view_id) const {
  DVLOG(3) << "HasRemoteMirror on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  tf_shared_lock l(mu_);
  auto mirror = remote_mirrors_.find(d->name());
  if (mirror != remote_mirrors_.end()) {
    // Check if mirror is stale
    if (mirror->second.context_view_id() != context_view_id) {
      return false;
    }
    return true;
  }

  return false;
}

bool TensorHandle::HasResourceShapeMirror(const Device* d,
                                          uint64 context_view_id) const {
  DVLOG(3) << "HasResourceShapeMirror on TensorHandle: " << this
           << " device: " << d << " " << d->name();

  tf_shared_lock l(mu_);
  auto mirror = resource_shape_mirrors_.find(d->name());
  if (mirror != resource_shape_mirrors_.end()) {
    // Check if mirror is stale
    if (mirror->second.context_view_id() != context_view_id) {
      return false;
    }
    return true;
  }
  return false;
}

Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64 op_id,
                                             int output_num,
                                             const string& remote_task,
                                             EagerContext* ctx) {
  DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this
           << " device: " << d << " " << d->name() << " op_id: " << op_id
           << " output_num: " << output_num;

  mutex_lock l(mu_);
  auto remote_mirror = remote_mirrors_.find(d->name());
  if (remote_mirror != remote_mirrors_.end()) {
    if (remote_mirror->second.context_view_id() >= ctx->GetContextViewId()) {
      return errors::Internal("Attempted to duplicate a remote mirror.");
    }
    // Remove stale mirror
    remote_mirrors_.erase(remote_mirror);
  }

  remote_mirrors_.emplace(
      std::piecewise_construct, std::forward_as_tuple(d->name()),
      std::forward_as_tuple(op_id, output_num, remote_task, ctx));

  return Status::OK();
}

Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
                                            int output_num, EagerContext* ctx) {
  DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this;

  mutex_lock l(mu_);
  auto mirror = resource_shape_mirrors_.find(d->name());
  if (mirror != resource_shape_mirrors_.end()) {
    if (mirror->second.context_view_id() == ctx->GetContextViewId()) {
      return errors::Internal(
          "Attempted to duplicate a resource shape mirror.");
    }
    // Remove stale mirror
    resource_shape_mirrors_.erase(mirror);
  }

  resource_shape_mirrors_.emplace(
      std::piecewise_construct, std::forward_as_tuple(d->name()),
      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId(),
                            /*is_ready=*/true));

  return Status::OK();
}

Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
                                    uint64 context_view_id) {
  return SetRemoteShapeAndDevice(shape, d, context_view_id, /*op_device=*/"");
}

Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape,
                                             const Device* d,
                                             uint64 context_view_id,
                                             string op_device) {
  DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (d != device_) {
    tf_shared_lock l(mu_);
    auto remote_mirror = remote_mirrors_.find(d->name());
    if (remote_mirror == remote_mirrors_.end()) {
      return Status::OK();
    }
    auto& mirror = remote_mirror->second;
    if (mirror.context_view_id() == context_view_id) {
      return mirror.SetShape(shape);
    } else if (mirror.context_view_id() < context_view_id) {
      return errors::Internal(
          absl::Substitute("Unexpected context_view_id ($0) which should not "
                           "be newer than the "
                           "one ($1) associated to the remote mirror.",
                           context_view_id, mirror.context_view_id()));
    } else {
      LOG(WARNING) << "SetRemoteShape is ignored for a remote mirror that is "
                      "associated with a newer context_view_id.";
    }
    return Status::OK();
  }

  DCHECK(Type() == REMOTE)
      << "SetRemoteShape is only called on remote handles.";

  auto& data = absl::get<RemoteTensorHandleData>(data_);
  // context_view_id is currently used to validate mirrors. The shape of
  // RemoteTensorHandleData should be set without checking context_view_id.
  // The reason behind it is that for the primary copy of data, if the remote
  // worker / device is removed, the consumer should report a connection error
  // indicating the remote tensor is no longer available.
  // For mirrors, this is not the case because they colocate with the data
  // consuming op/function device, and we (for now) have to aggressively
  // invalidate those copies to avoid any false positives during cluster update.
  if (op_device.empty()) {
    return data.SetShape(shape);
  } else {
    if (!unknown_device_) {
      return errors::Internal("Cannot reset known devices.");
    }
    Device* device;
    TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(op_device.c_str(), &device));
    device_ = device;
    op_device_ = device;
    resource_device_ = dtype == DT_RESOURCE ? device : nullptr;
    resource_remote_device_incarnation_ =
        GetRemoteDeviceIncarnation(resource_device_);
    string remote_task;
    if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
      return errors::InvalidArgument(
          "Unable to find remote task corresponding to device ",
          device->name());
    }
    return data.SetShapeAndRemoteTask(shape, remote_task);
  }
}

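// Records an error status on the remote data (or on the matching remote
// mirror) so later waiters observe the failure instead of hanging.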
void TensorHandle::PoisonRemote(Status status, const Device* d,
                                uint64 context_view_id) {
  DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (d == device_) {
    DCHECK(Type() == REMOTE)
        << "Poison can only be on remote handles: " << this;

    auto& data = absl::get<RemoteTensorHandleData>(data_);
    data.Poison(status);
  } else {
    tf_shared_lock l(mu_);
    auto mirror = remote_mirrors_.find(d->name());
    if (mirror != remote_mirrors_.end()) {
      if (mirror->second.context_view_id() == context_view_id) {
        mirror->second.Poison(status);
      }
    }
  }
}
#endif

Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
                                    const Device* d) {
  if (d == device_) {
    return errors::Internal(
        "Local mirror assign conflicts with primary device.");
  }

  mutex_lock l(mu_);
  auto elem =
      local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
                             std::forward_as_tuple(std::move(tensor)));
  if (!elem.second) {
    return errors::AlreadyExists("Attempted to add existing mirror.");
  }

  return Status::OK();
}

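// Materializes the handle by moving the tensor into the local data when d
// is the primary device, or into a previously registered local mirror
// otherwise.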
Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
  DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    DCHECK(Type() == LOCAL) << "SetTensor is only called on local handles.";

    if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
      auto& resource_handle = t.flat<class ResourceHandle>()(0);
      handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
    }
    auto& data = absl::get<LocalTensorHandleData>(data_);
    return data.SetTensor(std::move(t));
  } else {
    tf_shared_lock l(mu_);
    auto elem = local_mirrors_.find(d);
    if (elem == local_mirrors_.end()) {
      return errors::Internal(
          "Attempted to set tensor for non-existent local mirror.");
    }

    auto& mirror = elem->second;
    return mirror.SetTensor(std::move(t));
  }
}

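// Local counterpart of PoisonRemote: stores the error on the local data or
// on the local mirror for device d.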
void TensorHandle::Poison(Status status, const Device* d) {
  DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    DCHECK(Type() != REMOTE) << "Poison can only be on local handles: " << this;
    absl::visit([status](auto& data) { data.Poison(status); }, data_);
  } else {
    tf_shared_lock l(mu_);
    auto elem = local_mirrors_.find(d);
    DCHECK(elem != local_mirrors_.end())
        << "Attempted to poison non-existent local mirror, handle: " << this
        << " device: " << d;

    auto& mirror = elem->second;
    mirror.Poison(status);
  }
}

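// Synchronously copies the tensor behind this handle to device d (host CPU
// when d is null), allocating the destination and transferring via DMA when
// the devices differ.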
Status TensorHandle::CopyToDevice(const EagerContext& ctx,
                                  tensorflow::Device* d,
                                  tensorflow::Tensor* output) const {
  tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d;
  tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
  const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
  const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
  bool is_same_device =
      (srcd == dstd) || (srcd->name() == dstd->name()) || (dst_cpu && src_cpu);

  const tensorflow::Tensor* src = nullptr;
  TF_RETURN_IF_ERROR(Tensor(&src));
  if (is_same_device) {
    *output = *src;
    return Status::OK();
  }
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    return tensorflow::errors::InvalidArgument(
        "Can't copy Tensor with type ",
        tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(),
        ".");
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    *output = dst;
    return Status::OK();
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  TF_RETURN_IF_ERROR(srcd->Sync());
  tensorflow::Notification n;
  tensorflow::Status status;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 0 /*dev_to_dev_stream_index*/,
                                 [&status, &n](const tensorflow::Status& s) {
                                   status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  if (status.ok()) {
    *output = dst;
    return Status::OK();
  }
  return status;
}

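// Resolves the device that owns the resource named by `handle`; returns
// nullptr when ctx is null or the device lookup fails.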
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
  }
  Device* device = nullptr;
  if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
    LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
    return nullptr;
  }
  return device;
}

string TensorHandle::DebugString() const {
  DVLOG(4) << "Calling TensorHandle::DebugString() on " << this;

  string out;
  string device_debug = SafeDeviceDebugString(device_);
  strings::StrAppend(&out, "Device: ", device_debug);
  bool is_cpu = device_ == nullptr;
  // Consider supporting non-CPU tensors and CPU tensors with a device_ set to
  // non-NULL if needed.
  strings::StrAppend(
      &out, ", Tensor: ",
      is_cpu ? absl::visit([](auto& data) { return data.DebugString(); }, data_)
             : "?",
      "\n");
  return out;
}

const char* TensorHandle::DeviceName(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

const char* TensorHandle::BackingDeviceName(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

const char* TensorHandle::DeviceType(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str();
}

int TensorHandle::DeviceId(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? 0 : d->parsed_name().id;
}

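// Copy() is shallow: the handle is reference-counted, so copying just adds
// a reference and returns the same object.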
tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
  Ref();
  return this;
}

}  // namespace tensorflow