1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/worker.h"
17
18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/process_util.h"
21 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
22 #include "tensorflow/core/common_runtime/step_stats_collector.h"
23 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
24 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
25 #include "tensorflow/core/distributed_runtime/worker_session.h"
26 #include "tensorflow/core/platform/device_tracer.h"
27 #include "tensorflow/core/platform/tracing.h"
28
29 namespace tensorflow {
30
Worker(WorkerEnv * env)31 Worker::Worker(WorkerEnv* env) : env_(env) {}
32
GetStatusAsync(const GetStatusRequest * request,GetStatusResponse * response,StatusCallback done)33 void Worker::GetStatusAsync(const GetStatusRequest* request,
34 GetStatusResponse* response, StatusCallback done) {
35 DeviceMgr* dm = env_->device_mgr;
36 std::vector<DeviceAttributes> devices;
37 dm->ListDeviceAttributes(&devices);
38 response->mutable_device_attributes()->Reserve(devices.size());
39 for (auto& d : devices) {
40 response->add_device_attributes()->Swap(&d);
41 }
42 done(Status::OK());
43 }
44
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)45 void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
46 CreateWorkerSessionResponse* response,
47 StatusCallback done) {
48 Status s = env_->session_mgr->CreateSession(request->session_handle(),
49 request->server_def(),
50 request->isolate_session_state());
51 done(s);
52 }
53
DeleteWorkerSessionAsync(CallOptions * opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)54 void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
55 const DeleteWorkerSessionRequest* request,
56 DeleteWorkerSessionResponse* response,
57 StatusCallback done) {
58 Status s = env_->session_mgr->DeleteSession(request->session_handle());
59 done(s);
60 }
61
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)62 void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
63 RegisterGraphResponse* response,
64 StatusCallback done) {
65 std::shared_ptr<WorkerSession> session;
66 Status s;
67 if (request->create_worker_session_called()) {
68 s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
69 &session);
70 } else {
71 session = env_->session_mgr->LegacySession();
72 }
73 if (s.ok()) {
74 s = session->graph_mgr->Register(
75 request->session_handle(), request->graph_def(),
76 request->graph_options(), request->debug_options(),
77 request->collective_graph_key(), session->cluster_flr.get(),
78 response->mutable_graph_handle());
79 }
80 done(s);
81 }
82
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)83 void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
84 DeregisterGraphResponse* response,
85 StatusCallback done) {
86 std::shared_ptr<WorkerSession> session;
87 Status s;
88 if (request->create_worker_session_called()) {
89 s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
90 &session);
91 } else {
92 session = env_->session_mgr->LegacySession();
93 }
94 if (s.ok()) {
95 s = session->graph_mgr->Deregister(request->graph_handle());
96 }
97
98 done(s);
99 }
100
AbortStep(int64 step_id)101 void Worker::AbortStep(int64 step_id) {
102 Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
103 SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
104 // Delay a bit before aborting the step. This way, the root
105 // cause may return first back to the client instead of this
106 // cancellation generated abort error.
107 rendez->StartAbort(errors::Aborted("Step ", step_id,
108 " cancelled. Cancelling rendezvous."));
109 rendez->Unref();
110 });
111 }
112
PrepareRunGraph(RunGraphRequestWrapper * req,GraphMgr::NamedTensors * in,GraphMgr::NamedTensors * out)113 Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
114 GraphMgr::NamedTensors* in,
115 GraphMgr::NamedTensors* out) {
116 static Tensor empty_tensor(DT_FLOAT);
117 if (req->num_sends() > 0) {
118 Tensor val;
119 for (size_t i = 0; i < req->num_sends(); ++i) {
120 TF_RETURN_IF_ERROR(req->SendValue(i, &val));
121 in->insert({req->send_key(i), val});
122 }
123 }
124 for (size_t i = 0; i < req->num_recvs(); ++i) {
125 out->insert({req->recv_key(i), empty_tensor});
126 }
127 return Status::OK();
128 }
129
RunGraphAsync(CallOptions * opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)130 void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
131 MutableRunGraphResponseWrapper* response,
132 StatusCallback done) {
133 if (request->store_errors_in_response_body()) {
134 done = [response, done](const Status& status) {
135 response->set_status(status);
136 done(Status::OK());
137 };
138 }
139 if (request->is_partial()) {
140 DoPartialRunGraph(opts, request, response, std::move(done));
141 } else {
142 DoRunGraph(opts, request, response, std::move(done));
143 }
144 }
145
CreateRunGraphRequest()146 MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
147 return new InMemoryRunGraphRequest;
148 }
149
CreateRunGraphResponse()150 MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
151 return new InMemoryRunGraphResponse;
152 }
153
// Executes one full (non-partial) step of a registered graph.
//
// Ownership note: `out`, `collector`, `tracer` and `cm` are heap-allocated
// here and must be deleted on every exit path, including the final
// ExecuteAsync completion callback below.
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64 step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  // Resolve the session this graph was registered under; legacy clients
  // that never called CreateWorkerSession share a single session.
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  // `in` holds the tensors fed by the client; `out` is pre-populated with
  // the fetch keys.  `out` is heap-allocated because it must outlive this
  // frame (it is filled in by the async completion callback).
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  // Only create a stats collector when the client asked for some form of
  // per-step statistics; it writes into response->mutable_step_stats().
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceTracer* tracer = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    std::unique_ptr<DeviceTracer> trptr = CreateDeviceTracer();
    if (trptr) {
      tracer = trptr.release();
      Status s = tracer->Start();
      if (!s.ok()) {
        delete tracer;
        if (errors::IsUnavailable(s)) {
          // Tracing is best-effort: continue without it when the hardware
          // tracer cannot be started.
          LOG(WARNING)
              << "Hardware tracing unavailable, continuing without it. " << s;
          tracer = nullptr;
        } else {
          delete collector;
          delete out;
          done(s);
          return;
        }
      }
    }
  }
  // Per-step cancellation: RPC cancellation cancels the step and schedules
  // a (delayed) abort of its rendezvous.
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });
  // Also chain this step onto the worker-wide cancellation manager so a
  // worker-level cancellation cancels it too.
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    // The worker-wide manager was already cancelled: clean up everything
    // allocated above and abort the call.
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete tracer;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector, tracer,
       opts, done](Status s) {
        if (s.ok()) {
          // Execution succeeded: pull the requested outputs for this step.
          s = session->graph_mgr->RecvOutputs(step_id, out);
        }
        // The step is finished; cancellation hooks are no longer needed.
        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (tracer) {
          Status tracer_status = tracer->Stop();
          if (tracer_status.ok()) {
            tracer_status = tracer->Collect(collector);
          }
          if (!tracer_status.ok()) {
            // Tracing failures are logged but do not fail the step.
            LOG(ERROR) << "Bad status from tracer: " << tracer_status;
          }
        }
        if (s.ok()) {
          // Copy the received tensors into the RPC response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (collector) collector->Finalize();
        delete collector;
        delete tracer;
        delete out;
        done(s);
      });
}
260
// TODO(suharshs): Add stats collection support to partial run.
// Runs one segment of a partial-run step: the first request for a step
// starts the executors; later requests for the same step feed additional
// inputs.  Every request receives the fetch keys it named.  Per-step
// bookkeeping lives in `partial_run_mgr_`.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  std::shared_ptr<WorkerSession> session;

  // Resolve the session, as in DoRunGraph.
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  // `out` is heap-allocated so it can outlive this frame; `finish` owns
  // every later exit path's cleanup: it clears the RPC cancel callback,
  // frees `out`, and invokes the caller's callback.
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  // One CancellationManager per partial-run step, shared by all requests
  // for that step via partial_run_mgr_.
  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          // Notify the partial-run manager that the executors finished for
          // this step (with status `s`).
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  // Fetch this request's outputs.  Runs for both new and continuing
  // partial-run requests.
  session->graph_mgr->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the resp.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          // Final request for the step: hand `finish` to the partial-run
          // manager (NOTE(review): presumably it defers the callback until
          // ExecutorDone fires — confirm against PartialRunMgr).
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}
345
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)346 void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
347 CleanupGraphResponse* response,
348 StatusCallback done) {
349 const int64 step_id = request->step_id();
350 env_->rendezvous_mgr->Cleanup(step_id);
351 if (env_->collective_executor_mgr) {
352 env_->collective_executor_mgr->Cleanup(step_id);
353 }
354 for (Device* d : env_->local_devices) {
355 ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
356 if (sam) {
357 sam->Cleanup(step_id);
358 }
359 }
360 done(Status::OK());
361 }
362
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)363 void Worker::CleanupAllAsync(const CleanupAllRequest* request,
364 CleanupAllResponse* response,
365 StatusCallback done) {
366 std::vector<string> containers;
367 for (const auto& c : request->container()) containers.push_back(c);
368 env_->device_mgr->ClearContainers(containers);
369 done(Status::OK());
370 }
371
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)372 void Worker::LoggingAsync(const LoggingRequest* request,
373 LoggingResponse* response, StatusCallback done) {
374 done(errors::Unimplemented("Logging"));
375 }
376
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)377 void Worker::TracingAsync(const TracingRequest* request,
378 TracingResponse* response, StatusCallback done) {
379 done(errors::Unimplemented("Tracing"));
380 }
381
RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)382 void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
383 RecvBufResponse* response, StatusCallback done) {
384 // The base Worker class does not implement RecvBufAsync because
385 // it is not currently used for worker-to-worker communication. Use a
386 // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
387 // instead.
388 done(errors::Unimplemented("Worker::RecvBufAsync()"));
389 }
390
CompleteGroupAsync(CallOptions * opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)391 void Worker::CompleteGroupAsync(CallOptions* opts,
392 const CompleteGroupRequest* request,
393 CompleteGroupResponse* response,
394 StatusCallback done) {
395 if (env_->collective_executor_mgr) {
396 env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
397 request, response, &cancellation_manager_, done);
398 } else {
399 done(
400 errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
401 }
402 }
403
CompleteInstanceAsync(CallOptions * opts,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)404 void Worker::CompleteInstanceAsync(CallOptions* opts,
405 const CompleteInstanceRequest* request,
406 CompleteInstanceResponse* response,
407 StatusCallback done) {
408 if (env_->collective_executor_mgr) {
409 env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
410 request, response, &cancellation_manager_, done);
411 } else {
412 done(
413 errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
414 }
415 }
416
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)417 void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
418 GetStepSequenceResponse* response,
419 StatusCallback done) {
420 if (env_->collective_executor_mgr) {
421 env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
422 done);
423 } else {
424 done(
425 errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
426 }
427 }
428
429 // Helper for RecvTensor. Validates "key" and returns the source
430 // device in "*src_dev".
PrepareRecvTensor(const Rendezvous::ParsedKey & parsed,Device ** src_dev)431 Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
432 Device** src_dev) {
433 // Figures out which device the tensor is hosted on.
434 string local_name = DeviceNameUtils::LocalName(parsed.src_device);
435 TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
436
437 // Does the device have the right incarnation number we expect?
438 if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
439 return errors::Aborted(
440 "RecvTensor expects a different device incarnation: ",
441 parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
442 ". Your worker job (\"",
443 env_->session_mgr->LegacySession()->worker_name,
444 "\") was probably restarted. Check your "
445 "worker job for the reason why it was restarted.");
446 }
447
448 return Status::OK();
449 }
450
RecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)451 void Worker::RecvTensorAsync(CallOptions* opts,
452 const RecvTensorRequest* request,
453 TensorResponse* response, StatusCallback done) {
454 // The base Worker class does not implement RecvTensorAsync, because
455 // it is not currently used for worker-to-worker communication. Use a
456 // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
457 // instead.
458 done(errors::Unimplemented("Worker::RecvTensorAsync()"));
459 }
460
461 } // namespace tensorflow
462