/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"

#include "tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

namespace {

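// Enqueues a DestroyTensorHandleNode on the remote task to decrement the
// reference count of the remote tensor handle identified by (op_id,
// output_num). Returns early if the context has been replaced or the target
// task is no longer reachable.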
void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
                               uint64 context_id, uint64 op_id, int output_num,
                               bool ready) {
  if (ctx->GetContextId() != context_id) {
    // This means that this tensor was pointing to a remote device, which
    // has been changed out from under us. Simply return since there is
    // nothing we can do.
    return;
  }

  core::RefCountPtr<eager::EagerClient> eager_client;
  Status status = ctx->GetClient(remote_task, &eager_client);
  if (!status.ok()) {
    LOG_EVERY_N_SEC(INFO, 60)
        << "Unable to destroy remote tensor handle because the target "
        << remote_task << " is no longer available.";
    return;
  }

  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
  request->set_context_id(context_id);

  auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
  handle_to_decref->set_op_id(op_id);
  handle_to_decref->set_output_num(output_num);

  VLOG(3) << "Sending request to delete " << request->DebugString();
  std::unique_ptr<EagerNode> node(
      absl::make_unique<eager::DestroyTensorHandleNode>(
          std::move(request), std::move(eager_client), ready));
  auto& executor = ctx->Executor();
  if (executor.Async()) {
    Status status = executor.AddOrExecute(std::move(node));
    if (!status.ok()) {
      LOG_EVERY_N_SEC(WARNING, 60)
          << "Unable to destroy remote tensor handles. If you are "
             "running a tf.function, it usually indicates some op in "
             "the graph gets an error: "
          << status.error_message();
    }
  } else {
    // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
    // to send out the destroy request in a new thread to avoid deadlock.
    auto* released_node = node.release();
    (*ctx->runner())([ctx, released_node] {
      Status status =
          ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
      if (!status.ok()) {
        LOG_EVERY_N_SEC(WARNING, 60)
            << "Unable to destroy remote tensor handles. If you are "
               "running a tf.function, it usually indicates some op in "
               "the graph gets an error: "
            << status.error_message();
      }
    });
  }
}
}  // namespace

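// Constructs a handle that refers to an existing remote tensor. No
// EagerContext reference is held, so no destroy request is sent when the
// handle is deleted.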
RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
                                               uint64 context_view_id,
                                               bool is_ready)
    : is_ready_(is_ready),
      op_id_(op_id),
      output_num_(output_num),
      context_view_id_(context_view_id),
      ctx_(nullptr) {
  DCHECK(op_id_ >= 0 && output_num_ >= 0)
      << "Op ID and output num should be >= 0. Op ID: " << op_id
      << ", Output num: " << output_num;
}

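// Constructs an unready handle for the output of an op that is executing on
// `remote_task`. Holds a reference to `ctx` so that a destroy request can be
// sent when the handle is deleted.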
RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
                                               const string& remote_task,
                                               EagerContext* ctx)
    : is_ready_(false),
      op_id_(op_id),
      output_num_(output_num),
      remote_task_(remote_task),
      context_id_(ctx->GetContextId()),
      context_view_id_(ctx->GetContextViewId()),
      ctx_(ctx) {
  DCHECK(op_id_ >= 0 && output_num_ >= 0)
      << "Op ID and output num should be >= 0. Op ID: " << op_id
      << ", Output num: " << output_num;
  ctx_->Ref();
}

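// Sends a destroy request for the remote tensor handle if this handle holds
// a reference to an EagerContext.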
RemoteTensorHandleData::~RemoteTensorHandleData() {
  if (ctx_) {
    DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
                              output_num_, /*ready=*/true);
    ctx_->Unref();
  }
}

Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
  TF_RETURN_IF_ERROR(WaitReady("Shape"));

  tf_shared_lock l(mu_);
  *shape = shape_;

  return Status::OK();
}

Status RemoteTensorHandleData::NumDims(int* num_dims) const {
  TF_RETURN_IF_ERROR(WaitReady("NumDims"));

  tf_shared_lock l(mu_);
  *num_dims = shape_.dims();

  return Status::OK();
}

Status RemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
  TF_RETURN_IF_ERROR(WaitReady("Dim"));

  tf_shared_lock l(mu_);
  *dim = shape_.dim_size(dim_index);

  return Status::OK();
}

Status RemoteTensorHandleData::NumElements(int64* num_elements) const {
  TF_RETURN_IF_ERROR(WaitReady("NumElements"));

  tf_shared_lock l(mu_);
  *num_elements = shape_.num_elements();

  return Status::OK();
}

bool RemoteTensorHandleData::IsReady() const {
  tf_shared_lock l(mu_);
  return is_ready_;
}

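// Records `status` as the poison error and marks the handle ready so that
// pending and future waiters observe the error instead of blocking.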
void RemoteTensorHandleData::Poison(Status status) {
  mutex_lock l(mu_);
  is_poisoned_ = status;
  is_ready_ = true;
}

Status RemoteTensorHandleData::IsPoisoned() const {
  tf_shared_lock l(mu_);
  return is_poisoned_;
}

Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
  return SetShapeAndRemoteTask(shape, /*remote_task=*/"");
}

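// Marks the handle ready with `shape`, optionally updating the owning remote
// task. Returns the poison error if the handle was poisoned, or an internal
// error if the shape has already been set.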
Status RemoteTensorHandleData::SetShapeAndRemoteTask(
    const TensorShape& shape, const string& remote_task) {
  // If `is_ready_` was set previously due to poisoning, return the original
  // error that poisoned this tensor.
  TF_RETURN_IF_ERROR(IsPoisoned());

  mutex_lock l(mu_);
  if (is_ready_) {
    return errors::Internal("SetShape is only called on non-ready handles.");
  }

  shape_ = shape;
  if (!remote_task.empty()) {
    remote_task_ = remote_task;
  }
  is_poisoned_ = Status::OK();
  is_ready_ = true;

  return Status::OK();
}

string RemoteTensorHandleData::DebugString() const {
  return strings::StrCat("RemoteTensorHandleData:", " op_id: ", op_id_,
                         " output_num: ", output_num_);
}

Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_until_ready,
                                                int64* op_id,
                                                int32* output_num) const {
  if (wait_until_ready) {
    TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
  }
  *op_id = op_id_;
  *output_num = output_num_;
  return Status::OK();
}

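// Blocks until the handle is marked ready by SetShapeAndRemoteTask or Poison,
// then returns the poison status (OK for a successfully produced tensor).
// `caller` is used only for tracing and logging.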
Status RemoteTensorHandleData::WaitReady(const char* caller) const {
  tf_shared_lock l(mu_);
  if (!is_ready_) {
    profiler::TraceMe activity(
        [caller] { return absl::StrCat(caller, " WaitReady"); },
        profiler::TraceMeLevel::kInfo);
    DVLOG(3) << "WaitReady: " << caller << " " << this;
    // TODO(b/155493048): add a timeout here if it could cause any hanging
    // issue.
    mu_.Await(Condition(&is_ready_));
  }
  return is_poisoned_;
}

}  // namespace tensorflow