• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
16 
17 #include "tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h"
18 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/gtl/cleanup.h"
21 #include "tensorflow/core/lib/strings/strcat.h"
22 #include "tensorflow/core/profiler/lib/traceme.h"
23 
24 namespace tensorflow {
25 
namespace {

// Best-effort notification to the remote task that the tensor identified by
// (op_id, output_num) can be released. Failures are logged (rate-limited)
// rather than returned, since this runs during handle teardown where the
// caller cannot act on an error.
void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
                               uint64 context_id, uint64 op_id, int output_num,
                               bool ready) {
  if (ctx->GetContextId() != context_id) {
    // This means that this tensor was pointing to a remote device, which
    // has been changed out from under us. Simply return since there is
    // nothing we can do.
    return;
  }

  core::RefCountPtr<eager::EagerClient> eager_client;
  Status status = ctx->GetClient(remote_task, &eager_client);
  if (!status.ok()) {
    // The remote worker may already be gone (e.g. cluster shrink); the handle
    // is simply dropped on our side.
    LOG_EVERY_N_SEC(INFO, 60)
        << "Unable to destroy remote tensor handle because the target "
        << remote_task << " is no longer available.";
    return;
  }

  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
  request->set_context_id(context_id);

  // Queue a "decref" item so the remote side releases its reference to the
  // tensor produced by op `op_id` at output slot `output_num`.
  auto* handle_to_decref = request->add_queue()->mutable_handle_to_decref();
  handle_to_decref->set_op_id(op_id);
  handle_to_decref->set_output_num(output_num);

  VLOG(3) << "Sending request to delete " << request->DebugString();
  std::unique_ptr<EagerNode> node(
      absl::make_unique<eager::DestroyTensorHandleNode>(
          std::move(request), std::move(eager_client), ready));
  auto& executor = ctx->Executor();
  if (executor.Async()) {
    Status status = executor.AddOrExecute(std::move(node));
    if (!status.ok()) {
      LOG_EVERY_N_SEC(WARNING, 60)
          << "Unable to destroy remote tensor handles. If you are "
             "running a tf.function, it usually indicates some op in "
             "the graph gets an error: "
          << status.error_message();
    }
  } else {
    // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
    // to send out the destroy request in a new thread to avoid deadlock.
    // Ownership of `node` is transferred to the lambda via a raw release and
    // re-wrapped inside it; the runner is expected to always invoke the
    // closure, otherwise the node would leak.
    auto* released_node = node.release();
    (*ctx->runner())([ctx, released_node] {
      Status status =
          ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
      if (!status.ok()) {
        LOG_EVERY_N_SEC(WARNING, 60)
            << "Unable to destroy remote tensor handles. If you are "
               "running a tf.function, it usually indicates some op in "
               "the graph gets an error: "
            << status.error_message();
      }
    });
  }
}
}  // namespace
86 
// Constructs handle data that merely records the coordinates
// (op_id, output_num, context_view_id) of a remote tensor. `ctx_` stays
// null, so the destructor will NOT send a remote destroy request for
// handles created through this constructor.
RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
                                               uint64 context_view_id,
                                               bool is_ready)
    : is_ready_(is_ready),
      op_id_(op_id),
      output_num_(output_num),
      context_view_id_(context_view_id),
      ctx_(nullptr) {
  DCHECK(op_id_ >= 0 && output_num_ >= 0)
      << "Op ID and output num should be >= 0. Op ID: " << op_id
      << ", Output num: " << output_num;
}
99 
// Constructs handle data for a remote tensor owned by `remote_task` under
// `ctx`'s current context id/view. The handle starts out not-ready; it
// becomes ready via SetShape()/SetShapeAndRemoteTask() or Poison().
// Takes a reference on `ctx`; the destructor releases it and also issues
// the remote destroy request.
RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
                                               const string& remote_task,
                                               EagerContext* ctx)
    : is_ready_(false),
      op_id_(op_id),
      output_num_(output_num),
      remote_task_(remote_task),
      context_id_(ctx->GetContextId()),
      context_view_id_(ctx->GetContextViewId()),
      ctx_(ctx) {
  DCHECK(op_id_ >= 0 && output_num_ >= 0)
      << "Op ID and output num should be >= 0. Op ID: " << op_id
      << ", Output num: " << output_num;
  // Keep the context alive for the lifetime of this handle; paired with
  // Unref() in ~RemoteTensorHandleData().
  ctx_->Ref();
}
115 
~RemoteTensorHandleData()116 RemoteTensorHandleData::~RemoteTensorHandleData() {
117   if (ctx_) {
118     DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
119                               output_num_, /*ready=*/true);
120     ctx_->Unref();
121   }
122 }
123 
Shape(TensorShape * shape) const124 Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
125   TF_RETURN_IF_ERROR(WaitReady("Shape"));
126 
127   tf_shared_lock l(mu_);
128   *shape = shape_;
129 
130   return Status::OK();
131 }
132 
NumDims(int * num_dims) const133 Status RemoteTensorHandleData::NumDims(int* num_dims) const {
134   TF_RETURN_IF_ERROR(WaitReady("NumDims"));
135 
136   tf_shared_lock l(mu_);
137   *num_dims = shape_.dims();
138 
139   return Status::OK();
140 }
141 
Dim(int dim_index,int64 * dim) const142 Status RemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
143   TF_RETURN_IF_ERROR(WaitReady("Dim"));
144 
145   tf_shared_lock l(mu_);
146   *dim = shape_.dim_size(dim_index);
147 
148   return Status::OK();
149 }
150 
NumElements(int64 * num_elements) const151 Status RemoteTensorHandleData::NumElements(int64* num_elements) const {
152   TF_RETURN_IF_ERROR(WaitReady("NumElements"));
153 
154   tf_shared_lock l(mu_);
155   *num_elements = shape_.num_elements();
156 
157   return Status::OK();
158 }
159 
IsReady() const160 bool RemoteTensorHandleData::IsReady() const {
161   tf_shared_lock l(mu_);
162   return is_ready_;
163 }
164 
Poison(Status status)165 void RemoteTensorHandleData::Poison(Status status) {
166   mutex_lock l(mu_);
167   is_poisoned_ = status;
168   is_ready_ = true;
169 }
170 
IsPoisoned() const171 Status RemoteTensorHandleData::IsPoisoned() const {
172   tf_shared_lock l(mu_);
173   return is_poisoned_;
174 }
175 
SetShape(const TensorShape & shape)176 Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
177   return SetShapeAndRemoteTask(shape, /*remote_task=*/"");
178 }
179 
// Records the remote tensor's shape (and, if non-empty, which task now owns
// it) and transitions the handle to the ready state, waking WaitReady()
// callers. Expected to be called at most once on a non-ready handle.
Status RemoteTensorHandleData::SetShapeAndRemoteTask(
    const TensorShape& shape, const string& remote_task) {
  // If `is_ready_` is set previously due to poisoning, return the original
  // error that poisoned this tensor.
  // NOTE(review): IsPoisoned() takes and releases a shared lock before the
  // exclusive lock below is acquired; a Poison() racing into that window
  // would surface as the Internal error below rather than the poison status
  // — confirm this is acceptable to callers.
  TF_RETURN_IF_ERROR(IsPoisoned());

  mutex_lock l(mu_);
  if (is_ready_) {
    return errors::Internal("SetShape is only called on non-ready handles.");
  }

  shape_ = shape;
  // An empty `remote_task` means "leave the recorded owning task unchanged".
  if (!remote_task.empty()) {
    remote_task_ = remote_task;
  }
  is_poisoned_ = Status::OK();
  is_ready_ = true;

  return Status::OK();
}
200 
DebugString() const201 string RemoteTensorHandleData::DebugString() const {
202   return strings::StrCat("RemoteTensorHandleData:", " op_id: ", op_id_,
203                          " output_num: ", output_num_);
204 }
205 
OpIdAndOutputNum(const bool wait_util_ready,int64 * op_id,int32 * output_num) const206 Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_util_ready,
207                                                 int64* op_id,
208                                                 int32* output_num) const {
209   if (wait_util_ready) {
210     TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady"));
211   }
212   *op_id = op_id_;
213   *output_num = output_num_;
214   return Status::OK();
215 }
216 
// Blocks the calling thread until the handle becomes ready (shape recorded
// or poisoned), then returns the poison status — OK for a healthy handle.
// `caller` is used only to label trace/log output.
Status RemoteTensorHandleData::WaitReady(const char* caller) const {
  tf_shared_lock l(mu_);
  if (!is_ready_) {
    // Emit a trace span while blocked so profiles show who is waiting.
    profiler::TraceMe activity(
        [caller] { return absl::StrCat(caller, " WaitReady"); },
        profiler::TraceMeLevel::kInfo);
    DVLOG(3) << "WaitReady: " << caller << " " << this;
    // TODO(b/155493048): add a timeout here if it could cause any hanging
    // issue.
    // Await releases mu_ while blocked and re-acquires it once another
    // thread (SetShapeAndRemoteTask or Poison) sets is_ready_.
    mu_.Await(Condition(&is_ready_));
  }
  // Read under the lock: either OK or the status passed to Poison().
  return is_poisoned_;
}
230 
231 }  // namespace tensorflow
232