/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"

#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "absl/types/variant.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif  // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

namespace {
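// Returns the incarnation of a remote device, or 0 for local (or null)
// devices. The incarnation is a per-startup identifier carried in the device
// attributes, so it can be used to detect that a remote device has restarted.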
int64 GetRemoteDeviceIncarnation(Device* device) {
  if (device == nullptr || device->IsLocal()) return 0;
  return device->attributes().incarnation();
}

string SafeDeviceDebugString(Device* device) {
  if (device == nullptr) {
    return "[]";
  } else {
    return device->DebugString();
  }
}
}  // namespace

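// A packed handle holds a reference on each component handle for its entire
// lifetime, so the components cannot be destroyed while the pack is alive.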
TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
    std::vector<TensorHandle*>&& handles, const TensorShape& shape)
    : handles_(std::move(handles)), shape_(shape) {
  for (auto* handle : handles_) {
    handle->Ref();
  }
}

TensorHandle::PackedTensorHandleData::~PackedTensorHandleData() {
  for (auto* handle : handles_) {
    handle->Unref();
  }
}

Status TensorHandle::PackedTensorHandleData::Shape(TensorShape* shape) const {
  *shape = shape_;
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::NumDims(int* num_dims) const {
  *num_dims = shape_.dims();
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::Dim(int dim_index,
                                                 int64* dim) const {
  *dim = shape_.dim_size(dim_index);
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::NumElements(
    int64* num_elements) const {
  *num_elements = shape_.num_elements();
  return Status::OK();
}

Status TensorHandle::PackedTensorHandleData::Unprotect() {
  for (auto* handle : handles_) {
    TF_RETURN_IF_ERROR(absl::visit([](auto& data) { return data.Unprotect(); },
                                   handle->data_));
  }
  return Status::OK();
}

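// A poisoned pack reports itself as ready immediately so that the stored
// error status can be returned without blocking; otherwise the pack is ready
// only once every component handle is ready.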
bool TensorHandle::PackedTensorHandleData::IsReady() const {
  {
    tf_shared_lock l(mu_);
    if (!is_poisoned_.ok()) {
      return true;
    }
  }
  for (auto* handle : handles_) {
    if (!handle->IsReady()) {
      return false;
    }
  }
  return true;
}

Status TensorHandle::PackedTensorHandleData::WaitReady(
    const char* caller) const {
  {
    tf_shared_lock l(mu_);
    if (!is_poisoned_.ok()) {
      return is_poisoned_;
    }
  }
  for (auto* handle : handles_) {
    TF_RETURN_IF_ERROR(handle->WaitReady(caller));
  }
  return Status::OK();
}

void TensorHandle::PackedTensorHandleData::Poison(Status status) {
  mutex_lock l(mu_);
  is_poisoned_ = status;
}

string TensorHandle::PackedTensorHandleData::DebugString() const {
  string debug_str = "PackedTensorHandleData: ";
  for (const auto* handle : handles_) {
    debug_str.append(
        absl::StrCat(absl::visit([](auto& data) { return data.DebugString(); },
                                 handle->data_),
                     "; "));
  }
  return debug_str;
}

int TensorHandle::PackedTensorHandleData::NumPackedHandles() const {
  return handles_.size();
}

Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle(
    const int index, TensorHandle** handle) const {
  if (index < 0 || index >= handles_.size()) {
    return errors::InvalidArgument("Expect an index within [0, ",
                                   handles_.size(), "), but got ", index);
  }
  *handle = handles_.at(index);
  return Status::OK();
}

void TensorHandle::SetResourceHandleDtypeAndShape(
    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
  handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
}

Status TensorHandle::GetResourceHandleDtypesAndShapes(
    std::vector<DtypeAndPartialTensorShape>* result) {
  if (dtype != DT_RESOURCE) {
    return errors::InvalidArgument(
        "TensorHandle::GetResourceDtypeAndShape should be called on tensor "
        "handles with data type DT_RESOURCE. Actual tensor: ",
        dtype);
  }

  if (Type() != LOCAL) {
    *result = handle_dtypes_and_shapes_;
    return Status::OK();
  }

  // Wait for this TensorHandle to be ready.
  profiler::TraceMe activity("TensorHandle::GetResourceHandleInfo WaitReady",
                             profiler::TraceMeLevel::kInfo);
  auto& data = absl::get<LocalTensorHandleData>(data_);
  TF_RETURN_IF_ERROR(data.WaitReady("TensorHandle::GetResourceHandleInfo"));

  *result = handle_dtypes_and_shapes_;
  return Status::OK();
}

int TensorHandle::NumPackedHandles() const {
  if (Type() != PACKED) {
    return 0;
  }
  return absl::get<PackedTensorHandleData>(data_).NumPackedHandles();
}

Status TensorHandle::ExtractPackedHandle(const int index,
                                         TensorHandle** handle) const {
  if (Type() != PACKED) {
    return errors::Internal("Invalid ExtractPackedHandleOnDevice call on a ",
                            TypeString(), " handle: ", this);
  }
  return absl::get<PackedTensorHandleData>(data_).ExtractPackedHandle(index,
                                                                      handle);
}

TensorHandle* TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t) {
  // TODO(b/136608821): Move away from nullptr
  tensorflow::Tensor tensor = t;
  return CreateLocalHandle(std::move(tensor),
                           /*d=*/nullptr,
                           /*op_device=*/nullptr,
                           /*ctx=*/nullptr);
}

TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                              Device* op_device,
                                              EagerContext* ctx) {
  return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx);
}

TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                              Device* op_device,
                                              Device* resource_device,
                                              EagerContext* ctx) {
  if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
    return new TensorHandle(std::move(t), d, op_device, ctx);
  } else {
    return new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
  }
}

TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
                           Device* resource_device, EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(t.dtype()),
      device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
      op_device_(op_device),
      resource_device_(resource_device),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_)
           << " tensor: " << t.DeviceSafeDebugString();
}

TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
                           EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(DT_RESOURCE),
      device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
      op_device_(op_device),
      resource_device_(
          GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      handle_dtypes_and_shapes_(
          t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_)
           << " tensor: " << t.DeviceSafeDebugString();
}

TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
                                                   Device* resource_device,
                                                   tensorflow::DataType dtype,
                                                   EagerContext* ctx) {
  return new TensorHandle(d, op_device, resource_device, dtype, ctx);
}

TensorHandle::TensorHandle(Device* d, Device* op_device,
                           Device* resource_device, tensorflow::DataType dtype,
                           EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_((d == ctx->HostCPU()) ? nullptr : d),
      op_device_(op_device),
      resource_device_(resource_device),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<LocalTensorHandleData>) {
  DVLOG(3) << "Creating empty Local TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}

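// Creates a handle that packs multiple component handles into one. The pack
// is placed on a CompositeDevice built from the component handles' op
// devices; for DT_RESOURCE packs, the dtypes and shapes of the underlying
// resource are copied from the first component handle.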
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                        const tensorflow::DataType dtype,
                                        const tensorflow::TensorShape& shape,
                                        const string& device_name,
                                        EagerContext* ctx,
                                        TensorHandle** packed_handle) {
  if (handles.empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
  if (dtype == DT_RESOURCE) {
    TF_RETURN_IF_ERROR(
        handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
  }
  std::vector<string> devices;
  devices.reserve(handles.size());
  for (auto* handle : handles) {
    devices.push_back(handle->op_device() ? handle->op_device()->name()
                                          : ctx->HostCPU()->name());
  }

  CompositeDevice* composite_device = nullptr;
  TF_RETURN_IF_ERROR(ctx->FindOrCreateCompositeDevice(devices, device_name,
                                                      &composite_device));
  *packed_handle =
      new TensorHandle(std::move(handles), composite_device, dtype, shape, ctx);
  (*packed_handle)
      ->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
  return Status::OK();
}

Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
                                        EagerContext* ctx,
                                        TensorHandle** packed_handle) {
  if (handles.empty()) {
    return errors::InvalidArgument("Handles should not be empty.");
  }

  // Get the dtype and shape from the first handle since all handles have the
  // same dtype and shape.
  tensorflow::DataType dtype = handles.at(0)->dtype;
  tensorflow::TensorShape shape;
  TF_RETURN_IF_ERROR(handles.at(0)->Shape(&shape));
  return CreatePackedHandle(std::move(handles), dtype, shape,
                            /*device_name=*/"", ctx, packed_handle);
}

TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
                           const tensorflow::DataType dtype,
                           const tensorflow::TensorShape& shape,
                           EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_(device),
      op_device_(device),
      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
            shape) {
  DVLOG(3) << "Creating a packed TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}

#if !defined(IS_MOBILE_PLATFORM)
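// An unshaped remote handle refers to a tensor produced by an op on a remote
// worker, addressed by (op_id, output_num). Its shape is not known at
// creation time and is filled in later via SetRemoteShape once the remote op
// has run.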
TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
    int64 op_id, int32 output_num, const string& remote_task,
    tensorflow::DataType dtype, Device* d, EagerContext* ctx,
    const bool unknown_device) {
  return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx,
                          unknown_device);
}

TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                           const string& remote_task,
                           tensorflow::DataType dtype, Device* d,
                           EagerContext* ctx, const bool unknown_device)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_(d),
      op_device_(d),
      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      unknown_device_(unknown_device),
      ctx_(ctx),
      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
            remote_task, ctx) {
  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}

TensorHandle* TensorHandle::CreateLazyRemoteHandle(
    int64 op_id, int32 output_num, tensorflow::DataType dtype, Device* d,
    const bool is_ready, EagerContext* ctx) {
  return new TensorHandle(op_id, output_num, dtype, d, is_ready, ctx);
}

TensorHandle::TensorHandle(int64 op_id, int32 output_num,
                           tensorflow::DataType dtype, Device* d,
                           const bool is_ready, EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(dtype),
      device_(d),
      op_device_(d),
      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
      resource_remote_device_incarnation_(
          GetRemoteDeviceIncarnation(resource_device_)),
      ctx_(ctx),
      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
            ctx->GetContextViewId(), is_ready) {
  DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
           << " device: " << SafeDeviceDebugString(device_);
}
#endif  // IS_MOBILE_PLATFORM

TensorHandle::~TensorHandle() { DVLOG(3) << "Deleting tensor handle " << this; }

void TensorHandle::Release() {
  DVLOG(3) << "Releasing tensor handle " << this;
  Unref();
}

tensorflow::DataType TensorHandle::DataType() const { return dtype; }

bool TensorHandle::IsReady() const {
  return absl::visit([](auto& data) { return data.IsReady(); }, data_);
}

Status TensorHandle::WaitReady(const char* caller) const {
  return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
                     data_);
}

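// The handle type is derived from the active alternative of the data_
// variant; this mapping assumes data_ is declared in the header as
// absl::variant<LocalTensorHandleData, PackedTensorHandleData,
// RemoteTensorHandleData>, in that order.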
TensorHandle::HandleType TensorHandle::Type() const {
  if (data_.index() == 0) {
    return LOCAL;
  } else if (data_.index() == 1) {
    return PACKED;
  } else {
    return REMOTE;
  }
}

string TensorHandle::TypeString() const {
  if (data_.index() == 0) {
    return "LOCAL";
  } else if (data_.index() == 1) {
    return "PACKED";
  } else {
    return "REMOTE";
  }
}

Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
  DVLOG(3) << "Tensor on TensorHandle: " << this;

  if (Type() != LOCAL) {
    return errors::Internal("Invalid Tensor call on a ", TypeString(),
                            " handle: ", this);
  }

  auto& data = absl::get<LocalTensorHandleData>(data_);
  return data.Tensor(t);
}

Status TensorHandle::TensorFromDevice(const Device* d,
                                      const tensorflow::Tensor** t) const {
  DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    if (Type() != LOCAL) {
      return errors::Internal("Invalid Tensor call on a ", TypeString(),
                              " handle: ", this);
    }

    auto& data = absl::get<LocalTensorHandleData>(data_);
    return data.Tensor(t);
  }

  tf_shared_lock l(mu_);
  auto elem = local_mirrors_.find(d);
  if (elem == local_mirrors_.end()) {
    return errors::Internal("Invalid device: ", d,
                            " in Tensor call to handle: ", this);
  }

  auto& mirror = elem->second;
  return mirror.Tensor(t);
}

Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
  DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    if (Type() != LOCAL) {
      return errors::Internal("Invalid TensorValue call on a ", TypeString(),
                              " handle: ", this);
    }

    auto& data = absl::get<LocalTensorHandleData>(data_);
    return data.TensorValue(t);
  }

  tf_shared_lock l(mu_);
  auto elem = local_mirrors_.find(d);
  if (elem == local_mirrors_.end()) {
    return errors::Internal("Invalid device: ", d,
                            " in TensorValue call to handle: ", this);
  }

  auto& mirror = elem->second;
  return mirror.TensorValue(t);
}

Status TensorHandle::WaitUnknownDevice() const {
  if (unknown_device_) {
    TF_RETURN_IF_ERROR(absl::visit(
        [](auto& data) {
          return data.WaitReady("TensorHandle::UnknownDevice");
        },
        data_));
  }
  return Status::OK();
}

Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
  return (device_ == nullptr) ? ctx.HostCPU() : device_;
}

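// Shape queries below prefer the statically inferred shape when the handle is
// not yet ready, so shape inference does not have to block on remote or
// pending computation; once the handle is ready, the authoritative shape
// comes from the underlying handle data.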
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    bool fill = inference_shape_.AsTensorShape(shape);
    DCHECK(fill);
    return Status::OK();
  } else {
    return absl::visit([shape](auto& data) { return data.Shape(shape); },
                       data_);
  }
}

Status TensorHandle::InferenceShape(
    shape_inference::InferenceContext* const inference_context,
    shape_inference::ShapeHandle* shape_handle) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    std::vector<shape_inference::DimensionHandle> dims_handle;
    int num_dims;
    TF_RETURN_IF_ERROR(NumDims(&num_dims));
    for (int i = 0; i < num_dims; i++) {
      int64 dims;
      TF_RETURN_IF_ERROR(Dim(i, &dims));
      dims_handle.push_back(inference_context->MakeDim(dims));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  } else {
    if (inference_shape_.unknown_rank()) {
      *shape_handle = inference_context->UnknownShape();
      return Status::OK();
    }
    std::vector<shape_inference::DimensionHandle> dims_handle(
        inference_shape_.dims());
    for (int i = 0; i < dims_handle.size(); i++) {
      dims_handle[i] = inference_context->MakeDim(inference_shape_.dim_size(i));
    }
    *shape_handle = inference_context->MakeShape(dims_handle);
    return Status::OK();
  }
}

void TensorHandle::SetInferenceShape(
    shape_inference::InferenceContext* const inference_context,
    const shape_inference::ShapeHandle& shape_handle) {
  auto num_dims = inference_context->Rank(shape_handle);
  std::vector<int64> dims;
  if (num_dims == shape_inference::InferenceContext::kUnknownRank) {
    inference_shape_ = PartialTensorShape();
    return;
  }
  DCHECK_GE(num_dims, 0);
  dims.resize(num_dims);
  for (size_t i = 0; i < num_dims; ++i) {
    dims[i] = inference_context->Value(inference_context->Dim(shape_handle, i));
  }
  auto s = PartialTensorShape::MakePartialShape(dims.data(), num_dims,
                                                &inference_shape_);
  DCHECK(s.ok());
}

Status TensorHandle::CopyInferenceShape(TensorHandle* other) {
  if (IsReady()) {
    TF_RETURN_IF_ERROR(is_poisoned_);
    return Status::OK();
  }
  if (other->IsReady()) {
    TensorShape other_shape;
    TF_RETURN_IF_ERROR(other->Shape(&other_shape));
    inference_shape_ = other_shape;
  } else {
    inference_shape_ = other->inference_shape_;
  }
  return Status::OK();
}

Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const {
  DCHECK(shape != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank()) {
    *shape = inference_shape_;
    return Status::OK();
  } else {
    auto result = absl::visit(
        [](auto& data) {
          TensorShape shape;
          Status s = data.Shape(&shape);
          return std::make_pair(shape, s);
        },
        data_);
    TF_RETURN_IF_ERROR(result.second);
    *shape = result.first;
  }
  return Status::OK();
}

Status TensorHandle::NumDims(int* num_dims) const {
  DCHECK(num_dims != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank()) {
    *num_dims = inference_shape_.dims();
    return Status::OK();
  } else {
    return absl::visit(
        [num_dims](auto& data) { return data.NumDims(num_dims); }, data_);
  }
}

Status TensorHandle::Dim(int dim_index, int64* dim) const {
  DCHECK(dim != nullptr);
  if (!IsReady() && !inference_shape_.unknown_rank() &&
      inference_shape_.dim_size(dim_index) != -1) {
    *dim = inference_shape_.dim_size(dim_index);
    return Status::OK();
  } else {
    return absl::visit(
        [dim_index, dim](auto& data) { return data.Dim(dim_index, dim); },
        data_);
  }
}

Status TensorHandle::NumElements(int64* num_elements) const {
  DCHECK(num_elements != nullptr);
  if (!IsReady() && inference_shape_.IsFullyDefined()) {
    *num_elements = inference_shape_.num_elements();
    return Status::OK();
  } else {
    return absl::visit(
        [num_elements](auto& data) { return data.NumElements(num_elements); },
        data_);
  }
}

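// Unprotect releases the forwarding protection on the underlying buffer so
// the runtime may reuse (forward) it, e.g. for in-place ops; for a mirror it
// only releases the copy held on device d.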
Status TensorHandle::Unprotect(const Device* d) {
  DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    return absl::visit([](auto& data) { return data.Unprotect(); }, data_);
  }

  tf_shared_lock l(mu_);
  auto elem = local_mirrors_.find(d);
  if (elem == local_mirrors_.end()) {
    return errors::Internal("Invalid device: ", d,
                            " in Unprotect call to handle: ", this);
  }

  // Check if the handle is non-empty
  auto& mirror = elem->second;
  return mirror.Unprotect();
}

bool TensorHandle::HasLocalMirror(const Device* d) const {
  DVLOG(3) << "HasLocalMirror on TensorHandle: " << this << " device: " << d;

  tf_shared_lock l(mu_);
  return local_mirrors_.find(d) != local_mirrors_.end();
}

Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
  DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this
           << " device: " << d;

  if (d == device_) {
    return errors::Internal("Cannot add mirror for primary device.");
  }

  mutex_lock l(mu_);
  if (local_mirrors_.find(d) != local_mirrors_.end()) {
    return errors::AlreadyExists("Attempted to duplicate a local mirror.");
  }

  local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
                         std::forward_as_tuple());

  return Status::OK();
}

#if !defined(IS_MOBILE_PLATFORM)
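// A remote tensor is addressed by the id of the operation that produced it
// (op_id) and the index of that op's output (output_num). Remote mirrors are
// keyed by device name and validated against the current context view id.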
Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready,
                                   int64* op_id, int32* output_num) const {
  DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (d != device_) {
    tf_shared_lock l(mu_);
    auto mirror = remote_mirrors_.find(d->name());
    if (mirror != remote_mirrors_.end()) {
      return mirror->second.OpIdAndOutputNum(wait_until_ready, op_id,
                                             output_num);
    }

    return errors::FailedPrecondition(
        "Could not find remote mirror for specified device");
  }

  if (Type() != REMOTE) {
    return errors::InvalidArgument("Primary device is not remote");
  }

  auto& data = absl::get<RemoteTensorHandleData>(data_);
  return data.OpIdAndOutputNum(wait_until_ready, op_id, output_num);
}

bool TensorHandle::HasRemoteMirror(const Device* d,
                                   uint64 context_view_id) const {
  DVLOG(3) << "HasRemoteMirror on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  tf_shared_lock l(mu_);
  auto mirror = remote_mirrors_.find(d->name());
  if (mirror != remote_mirrors_.end()) {
    // Check if mirror is stale
    if (mirror->second.context_view_id() != context_view_id) {
      return false;
    }
    return true;
  }

  return false;
}

bool TensorHandle::HasResourceShapeMirror(const Device* d,
                                          uint64 context_view_id) const {
  DVLOG(3) << "HasResourceShapeMirror on TensorHandle: " << this
           << " device: " << d << " " << d->name();

  tf_shared_lock l(mu_);
  auto mirror = resource_shape_mirrors_.find(d->name());
  if (mirror != resource_shape_mirrors_.end()) {
    // Check if mirror is stale
    if (mirror->second.context_view_id() != context_view_id) {
      return false;
    }
    return true;
  }
  return false;
}

Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64 op_id,
                                             int output_num,
                                             const string& remote_task,
                                             EagerContext* ctx) {
  DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this
           << " device: " << d << " " << d->name() << " op_id: " << op_id
           << " output_num: " << output_num;

  mutex_lock l(mu_);
  auto remote_mirror = remote_mirrors_.find(d->name());
  if (remote_mirror != remote_mirrors_.end()) {
    if (remote_mirror->second.context_view_id() >= ctx->GetContextId()) {
      return errors::Internal("Attempted to duplicate a remote mirror.");
    }
    // Remove stale mirror
    remote_mirrors_.erase(remote_mirror);
  }

  remote_mirrors_.emplace(
      std::piecewise_construct, std::forward_as_tuple(d->name()),
      std::forward_as_tuple(op_id, output_num, remote_task, ctx));

  return Status::OK();
}

Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
                                            int output_num, EagerContext* ctx) {
  DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this;

  mutex_lock l(mu_);
  auto mirror = resource_shape_mirrors_.find(d->name());
  if (mirror != resource_shape_mirrors_.end()) {
    if (mirror->second.context_view_id() == ctx->GetContextViewId()) {
      return errors::Internal(
          "Attempted to duplicate a resource shape mirror.");
    }
    // Remove stale mirror
    resource_shape_mirrors_.erase(mirror);
  }

  resource_shape_mirrors_.emplace(
      std::piecewise_construct, std::forward_as_tuple(d->name()),
      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId(),
                            /*is_ready=*/true));

  return Status::OK();
}

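// SetRemoteShape forwards to SetRemoteShapeAndDevice with an empty op_device,
// which sets the shape while leaving the handle's device unchanged.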
Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
                                    uint64 context_view_id) {
  return SetRemoteShapeAndDevice(shape, d, context_view_id, /*op_device=*/"");
}

Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape,
                                             const Device* d,
                                             uint64 context_view_id,
                                             string op_device) {
  DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (d != device_) {
    tf_shared_lock l(mu_);
    auto remote_mirror = remote_mirrors_.find(d->name());
    if (remote_mirror == remote_mirrors_.end()) {
      return Status::OK();
    }
    auto& mirror = remote_mirror->second;
    if (mirror.context_view_id() == context_view_id) {
      return mirror.SetShape(shape);
    } else if (mirror.context_view_id() < context_view_id) {
      return errors::Internal(
          absl::Substitute("Unexpected context_view_id ($0) which should not "
                           "be newer than the "
                           "one ($1) associated to the remote mirror.",
                           context_view_id, mirror.context_view_id()));
    } else {
      LOG(WARNING) << "SetRemoteShape is ignored for a remote mirror that is "
                      "associated with a newer context_view_id.";
    }
    return Status::OK();
  }

  DCHECK(Type() == REMOTE)
      << "SetRemoteShape is only called on remote handles.";

  auto& data = absl::get<RemoteTensorHandleData>(data_);
  // context_view_id is currently used to validate mirrors. The shape of
  // RemoteTensorHandleData should be set without checking context_view_id.
  // The reason behind it is that for the primary copy of data, if the remote
  // worker / device is removed, the consumer should report a connection error
  // indicating the remote tensor is no longer available.
  // For mirrors, this is not the case because they colocate with the data
  // consuming op/function device, and we (for now) have to aggressively
  // invalidate those copies to avoid any false positives during cluster update.
  if (op_device.empty()) {
    return data.SetShape(shape);
  } else {
    if (!unknown_device_) {
      return errors::Internal("Cannot reset known devices.");
    }
    Device* device;
    TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(op_device.c_str(), &device));
    device_ = device;
    op_device_ = device;
    resource_device_ = dtype == DT_RESOURCE ? device : nullptr;
    resource_remote_device_incarnation_ =
        GetRemoteDeviceIncarnation(resource_device_);
    string remote_task;
    if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
      return errors::InvalidArgument(
          "Unable to find remote task corresponding to device ",
          device->name());
    }
    return data.SetShapeAndRemoteTask(shape, remote_task);
  }
}

void TensorHandle::PoisonRemote(Status status, const Device* d,
                                uint64 context_view_id) {
  DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (d == device_) {
    DCHECK(Type() == REMOTE)
        << "Poison can only be on remote handles: " << this;

    auto& data = absl::get<RemoteTensorHandleData>(data_);
    data.Poison(status);
  } else {
    tf_shared_lock l(mu_);
    auto mirror = remote_mirrors_.find(d->name());
    if (mirror != remote_mirrors_.end()) {
      if (mirror->second.context_view_id() == context_view_id) {
        mirror->second.Poison(status);
      }
    }
  }
}
#endif  // IS_MOBILE_PLATFORM

Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
                                    const Device* d) {
  if (d == device_) {
    return errors::Internal(
        "Local mirror assign conflicts with primary device.");
  }

  mutex_lock l(mu_);
  auto elem =
      local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
                             std::forward_as_tuple(std::move(tensor)));
  if (!elem.second) {
    return errors::AlreadyExists("Attempted to add existing mirror.");
  }

  return Status::OK();
}

Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
  DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    DCHECK(Type() == LOCAL) << "SetTensor is only called on local handles.";

    if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
      auto& resource_handle = t.flat<class ResourceHandle>()(0);
      handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
    }
    auto& data = absl::get<LocalTensorHandleData>(data_);
    return data.SetTensor(std::move(t));
  } else {
    tf_shared_lock l(mu_);
    auto elem = local_mirrors_.find(d);
    if (elem == local_mirrors_.end()) {
      return errors::Internal(
          "Attempted to set tensor for non-existent local mirror.");
    }

    auto& mirror = elem->second;
    return mirror.SetTensor(std::move(t));
  }

  return Status::OK();
}

void TensorHandle::Poison(Status status, const Device* d) {
  DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;

  if (d == device_) {
    DCHECK(Type() != REMOTE) << "Poison can only be on local handles: " << this;
    absl::visit([status](auto& data) { data.Poison(status); }, data_);
  } else {
    tf_shared_lock l(mu_);
    auto elem = local_mirrors_.find(d);
    DCHECK(elem != local_mirrors_.end())
        << "Attempted to poison non-existent local mirror, handle: " << this
        << " device: " << d;

    auto& mirror = elem->second;
    mirror.Poison(status);
  }
}

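// Synchronously copies the handle's tensor to device d (nullptr means host
// CPU). Same-device copies, including CPU-to-CPU, alias the source tensor
// instead of copying; cross-device copies go through CopyTensor::ViaDMA after
// draining the source device's streams with Sync().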
Status TensorHandle::CopyToDevice(const EagerContext& ctx,
                                  tensorflow::Device* d,
                                  tensorflow::Tensor* output) const {
  tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d;
  tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
  const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
  const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
  bool is_same_device =
      (srcd == dstd) || (srcd->name() == dstd->name()) || (dst_cpu && src_cpu);

  const tensorflow::Tensor* src = nullptr;
  TF_RETURN_IF_ERROR(Tensor(&src));
  if (is_same_device) {
    *output = *src;
    return Status::OK();
  }
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    return tensorflow::errors::InvalidArgument(
        "Can't copy Tensor with type ",
        tensorflow::DataTypeString(src->dtype()), " to device ", dstd->name(),
        ".");
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    *output = dst;
    return Status::OK();
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  TF_RETURN_IF_ERROR(srcd->Sync());
  tensorflow::Notification n;
  tensorflow::Status status;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 0 /*dev_to_dev_stream_index*/,
                                 [&status, &n](const tensorflow::Status& s) {
                                   status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  if (status.ok()) {
    *output = dst;
    return Status::OK();
  }
  return status;
}

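// Resolves the device named in a ResourceHandle within the given context.
// Returns nullptr (interpreted elsewhere as host CPU) if there is no context
// or the device cannot be found.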
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
  }
  Device* device = nullptr;
  if (!ctx->FindDeviceFromName(handle.device().c_str(), &device).ok()) {
    LOG(ERROR) << "Cannot find resource device: " << handle.device() << ".";
    return nullptr;
  }
  return device;
}

string TensorHandle::DebugString() const {
  DVLOG(4) << "Calling TensorHandle::DebugString() on " << this;

  string out;
  string device_debug = SafeDeviceDebugString(device_);
  strings::StrAppend(&out, "Device: ", device_debug);
  bool is_cpu = device_ != nullptr;
  // Consider supporting non-CPU tensors and CPU tensors with a device_ set to
  // non-NULL if needed.
  strings::StrAppend(
      &out, ", Tensor: ",
      is_cpu ? absl::visit([](auto& data) { return data.DebugString(); }, data_)
             : "?",
      "\n");
  return out;
}

const char* TensorHandle::DeviceName(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

const char* TensorHandle::BackingDeviceName(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

const char* TensorHandle::DeviceType(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str();
}

int TensorHandle::DeviceId(Status* status) const {
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? 0 : d->parsed_name().id;
}

tensorflow::ImmediateExecutionTensorHandle* TensorHandle::Copy() {
  Ref();
  return this;
}

}  // namespace tensorflow