/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/device_profiler_session.h"

namespace tensorflow {

Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}

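// Lists the attributes of every device managed by this worker's DeviceMgr.
// Each DeviceAttributes proto is swapped (not copied) into the response to
// avoid copying potentially large protos.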
void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}

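// Creates a new WorkerSession keyed by the request's session handle. The
// session captures the cluster's ServerDef, the remote devices' attributes,
// whether session state is isolated, and the identity and incarnation of the
// master that owns it.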
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(
      request->session_handle(), request->server_def(),
      request->cluster_device_attributes(), request->isolate_session_state(),
      request->master_task(), request->master_incarnation());
  done(s);
}

void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

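// Registers a graph partition with the session's GraphMgr. On success the
// response carries a graph handle that subsequent RunGraph and
// DeregisterGraph calls refer to. If the master never called
// CreateWorkerSession, the shared legacy session is used instead.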
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(), session.get(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Deregister(request->graph_handle());
  }

  done(s);
}

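// Aborts the rendezvous associated with `step_id`, releasing any waiting
// senders and receivers with an Aborted error.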
void Worker::AbortStep(int64_t step_id) {
  Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
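  // The delay argument is in microseconds, so the abort fires roughly one
  // second after cancellation is requested.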
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step. This way, the root-cause error
    // may reach the client before this cancellation-generated abort error
    // does.
    rendez->StartAbort(errors::Aborted("Step ", step_id,
                                       " cancelled.  Cancelling rendezvous."));
    rendez->Unref();
  });
}

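// Copies the feed tensors out of the request into `in` and seeds `out` with
// one placeholder entry per fetch key; the placeholders are overwritten when
// the step's outputs are received.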
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return Status::OK();
}

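// Entry point for RunGraph RPCs: dispatches to DoPartialRunGraph or
// DoRunGraph. When the caller asked for errors to be stored in the response
// body, the callback is wrapped so the RPC itself always completes with OK
// and the real status travels inside the response. A typical master-side
// sequence driving this API (a sketch, not verbatim master code) is:
//
//   RegisterGraph(graph_def)         --> graph_handle
//   RunGraph(graph_handle, step_id)  --> fetched tensors
//   CleanupGraph(step_id)
//   DeregisterGraph(graph_handle)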
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(Status::OK());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

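// Runs one full step of a registered graph: de-duplicates retried requests,
// resolves the worker session, prepares feeds and fetches, optionally
// attaches a StepStatsCollector and device profiler, wires up cancellation,
// and finally hands execution to GraphMgr::ExecuteAsync. Fetched tensors are
// recovered with RecvOutputs once execution finishes.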
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64_t step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceProfilerSession* device_profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    device_profiler_session = DeviceProfilerSession::Create().release();
  }
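  // Two cancellation paths feed the per-step CancellationManager `cm`: the
  // RPC's own cancel callback (the client disconnects or gives up) and the
  // worker-wide cancellation_manager_ (e.g. worker shutdown), which is
  // bridged to `cm` via the token registered below.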
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete device_profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (device_profiler_session) {
          device_profiler_session->CollectData(response->mutable_step_stats())
              .IgnoreError();
        }

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete device_profiler_session;
        delete out;
        done(s);
      });
}

// TODO(suharshs): Add stats collection support to partial run.
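// Runs one piece of a partial run. The first request for a given step starts
// the executors; later requests feed additional inputs into the still-running
// step. Each request receives the fetches it asked for, and the final
// request (is_last_partial_run) tears the step down via PartialRunMgr.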
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

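// Releases per-step state once a step has fully finished: the step's
// rendezvous, any collective executor state, and any scoped-allocator
// state held by local devices.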
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->Cleanup(step_id);
  }
  for (Device* d : env_->local_devices) {
    ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
    if (sam) {
      sam->Cleanup(step_id);
    }
  }
  done(Status::OK());
}

void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}

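// Logging and tracing are not supported by the base Worker; transport-
// specific subclasses may override these methods to implement them.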
void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}

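// Resolves collective group parameters (group key, size, device type, member
// devices) through the collective param resolver, and copies the resolved
// parameters into the response on success.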
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  if (!request->has_device_attributes()) {
    done(errors::Internal(
        "CompleteGroupRequest device_attributes is not set. Make sure you're "
        "running the same version of TensorFlow on all workers."));
    return;
  }
  if (env_->collective_executor_mgr) {
    auto group_params = new CollGroupParams();
    group_params->group_key = request->group_key();
    group_params->group_size = request->group_size();
    group_params->device_type = DeviceType(request->device_type());
    env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
        request->device_attributes(), group_params, &cancellation_manager_,
        [response, group_params, done = std::move(done)](const Status& s) {
          if (s.ok()) {
            response->set_group_key(group_params->group_key);
            response->set_group_size(group_params->group_size);
            response->set_device_type(group_params->device_type.type_string());
            response->set_num_tasks(group_params->num_tasks);
            for (const DeviceAttributes& device : group_params->devices) {
              *response->add_device_attributes() = device;
            }
            response->set_communicator_key(
                group_params->runtime_details.communicator_key);
          } else {
            LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
          }
          delete group_params;
          done(s);
        });
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
                                                        done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figures out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Does the device have the right incarnation number we expect?
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::Aborted(
        "RecvTensor expects a different device incarnation: ",
        parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
        ". Your worker job (\"",
        env_->session_mgr->LegacySession()->worker_name(),
        "\") was probably restarted. Check your "
        "worker job for the reason why it was restarted.");
  }

  return Status::OK();
}

void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}

}  // namespace tensorflow