/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/device_tracer.h"
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {

Worker::Worker(WorkerEnv* env) : env_(env) {}

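// Responds with the attributes of every device managed by this worker's
// DeviceMgr.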
void Worker::GetStatusAsync(const GetStatusRequest* request,
                            GetStatusResponse* response, StatusCallback done) {
  DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}

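// Creates a WorkerSession identified by the request's session handle,
// delegating to the SessionMgr.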
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(request->session_handle(),
                                              request->server_def(),
                                              request->isolate_session_state());
  done(s);
}

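// Deletes the WorkerSession identified by the request's session handle.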
void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

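// Registers a graph partition with the session's GraphMgr and returns the
// resulting graph handle in the response.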
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr->Register(
        request->session_handle(), request->graph_def(),
        request->graph_options(), request->debug_options(),
        request->collective_graph_key(), session->cluster_flr.get(),
        response->mutable_graph_handle());
  }
  done(s);
}

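// Removes a previously registered graph, identified by its graph handle,
// from the session's GraphMgr.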
void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr->Deregister(request->graph_handle());
  }

  done(s);
}

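// Aborts the rendezvous for `step_id` after a short delay, unblocking any
// Send/Recv operations still waiting on it.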
void Worker::AbortStep(int64 step_id) {
  Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step. This way, the error that is the
    // root cause of the failure may reach the client before the abort error
    // generated by this cancellation does.
    rendez->StartAbort(errors::Aborted("Step ", step_id,
                                       " cancelled.  Cancelling rendezvous."));
    rendez->Unref();
  });
}

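// Collects the tensors to be sent into the step (`in`) and pre-populates
// `out` with an empty placeholder tensor for every key the caller wants to
// receive back.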
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return Status::OK();
}

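// Entry point for RunGraph requests; dispatches to DoRunGraph or
// DoPartialRunGraph depending on whether the request is a partial run.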
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(Status::OK());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

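// Factory methods for the mutable RunGraph request/response wrappers; the
// base class returns the in-memory (non-proto) implementations.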
MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

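// Runs a full (non-partial) step: sends the inputs into the executors, runs
// the registered graph, and receives the requested outputs, with optional
// stats collection and hardware tracing.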
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64 step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceTracer* tracer = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    std::unique_ptr<DeviceTracer> trptr = CreateDeviceTracer();
    if (trptr) {
      tracer = trptr.release();
      Status s = tracer->Start();
      if (!s.ok()) {
        delete tracer;
        if (errors::IsUnavailable(s)) {
          LOG(WARNING)
              << "Hardware tracing unavailable, continuing without it. " << s;
          tracer = nullptr;
        } else {
          delete collector;
          delete out;
          done(s);
          return;
        }
      }
    }
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete tracer;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector, tracer,
       opts, done](Status s) {
        if (s.ok()) {
          s = session->graph_mgr->RecvOutputs(step_id, out);
        }
        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (tracer) {
          Status tracer_status = tracer->Stop();
          if (tracer_status.ok()) {
            tracer_status = tracer->Collect(collector);
          }
          if (!tracer_status.ok()) {
            LOG(ERROR) << "Bad status from tracer: " << tracer_status;
          }
        }
        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (collector) collector->Finalize();
        delete collector;
        delete tracer;
        delete out;
        done(s);
      });
}

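// Runs one phase of a partial run: the first request for a step starts the
// executors; subsequent requests for the same step_id feed additional
// inputs and fetch further outputs.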
// TODO(suharshs): Add stats collection support to partial run.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  std::shared_ptr<WorkerSession> session;

  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  session->graph_mgr->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

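// Releases per-step resources: the step's rendezvous, any collective
// executor state, and any scoped-allocator state on local devices.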
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->Cleanup(step_id);
  }
  for (Device* d : env_->local_devices) {
    ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
    if (sam) {
      sam->Cleanup(step_id);
    }
  }
  done(Status::OK());
}

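// Clears the resource containers named in the request on the devices
// managed by this worker's DeviceMgr.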
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}

void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}

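// Collective ops: resolves collective group parameters via the
// CollectiveExecutorMgr's param resolver, if one was configured.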
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

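// Collective ops: resolves collective instance parameters via the
// CollectiveExecutorMgr's param resolver, if one was configured.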
void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

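// Collective ops: returns the step-id sequence for the requested graph
// keys, if a CollectiveExecutorMgr was configured.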
void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
                                                        done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

// Helper for RecvTensor. Validates the parsed rendezvous key and returns
// the source device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figures out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Does the device have the incarnation number we expect?
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::Aborted(
        "RecvTensor expects a different device incarnation: ",
        parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
        ". Your worker job (\"",
        env_->session_mgr->LegacySession()->worker_name,
        "\") was probably restarted. Check your "
        "worker job for the reason why it was restarted.");
  }

  return Status::OK();
}

void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}

}  // namespace tensorflow