• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/worker.h"
17 
18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19 #include "tensorflow/core/common_runtime/device_mgr.h"
20 #include "tensorflow/core/common_runtime/process_util.h"
21 #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
22 #include "tensorflow/core/common_runtime/step_stats_collector.h"
23 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
24 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
25 #include "tensorflow/core/distributed_runtime/worker_session.h"
26 #include "tensorflow/core/platform/tracing.h"
27 #include "tensorflow/core/profiler/lib/profiler_session.h"
28 
29 namespace tensorflow {
30 
// Constructs a Worker bound to `env`. The worker does not own `env`.
// `recent_request_ids_` tracks the 100000 most recent request ids so that
// retried RPCs can be detected as duplicates.
Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}
37 
GetStatusAsync(CallOptions * opts,const GetStatusRequest * request,GetStatusResponse * response,bool fail_fast,StatusCallback done)38 void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
39                             GetStatusResponse* response, bool fail_fast,
40                             StatusCallback done) {
41   const DeviceMgr* dm = env_->device_mgr;
42   std::vector<DeviceAttributes> devices;
43   dm->ListDeviceAttributes(&devices);
44   response->mutable_device_attributes()->Reserve(devices.size());
45   for (auto& d : devices) {
46     response->add_device_attributes()->Swap(&d);
47   }
48   done(Status::OK());
49 }
50 
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)51 void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
52                                       CreateWorkerSessionResponse* response,
53                                       StatusCallback done) {
54   Status s = env_->session_mgr->CreateSession(
55       request->session_handle(), request->server_def(),
56       request->cluster_device_attributes(), request->isolate_session_state(),
57       request->master_task(), request->master_incarnation());
58   done(s);
59 }
60 
DeleteWorkerSessionAsync(CallOptions * opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)61 void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
62                                       const DeleteWorkerSessionRequest* request,
63                                       DeleteWorkerSessionResponse* response,
64                                       StatusCallback done) {
65   Status s = env_->session_mgr->DeleteSession(request->session_handle());
66   done(s);
67 }
68 
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)69 void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
70                                 RegisterGraphResponse* response,
71                                 StatusCallback done) {
72   std::shared_ptr<WorkerSession> session;
73   Status s;
74   if (request->create_worker_session_called()) {
75     s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
76                                                    &session);
77   } else {
78     session = env_->session_mgr->LegacySession();
79   }
80   if (s.ok()) {
81     s = session->graph_mgr()->Register(
82         request->session_handle(), request->graph_def(), session.get(),
83         request->graph_options(), request->debug_options(),
84         request->config_proto(), request->collective_graph_key(),
85         session->cluster_flr(), response->mutable_graph_handle());
86   }
87   done(s);
88 }
89 
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)90 void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
91                                   DeregisterGraphResponse* response,
92                                   StatusCallback done) {
93   std::shared_ptr<WorkerSession> session;
94   Status s;
95   if (request->create_worker_session_called()) {
96     s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
97                                                    &session);
98   } else {
99     session = env_->session_mgr->LegacySession();
100   }
101   if (s.ok()) {
102     s = session->graph_mgr()->Deregister(request->graph_handle());
103   }
104 
105   done(s);
106 }
107 
AbortStep(int64 step_id)108 void Worker::AbortStep(int64 step_id) {
109   Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
110   SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
111     // Delay a bit before aborting the step. This way, the root
112     // cause may return first back to the client instead of this
113     // cancellation generated abort error.
114     rendez->StartAbort(errors::Aborted("Step ", step_id,
115                                        " cancelled.  Cancelling rendezvous."));
116     rendez->Unref();
117   });
118 }
119 
PrepareRunGraph(RunGraphRequestWrapper * req,GraphMgr::NamedTensors * in,GraphMgr::NamedTensors * out)120 Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
121                                GraphMgr::NamedTensors* in,
122                                GraphMgr::NamedTensors* out) {
123   static Tensor empty_tensor(DT_FLOAT);
124   if (req->num_sends() > 0) {
125     Tensor val;
126     for (size_t i = 0; i < req->num_sends(); ++i) {
127       TF_RETURN_IF_ERROR(req->SendValue(i, &val));
128       in->insert({req->send_key(i), val});
129     }
130   }
131   for (size_t i = 0; i < req->num_recvs(); ++i) {
132     out->insert({req->recv_key(i), empty_tensor});
133   }
134   return Status::OK();
135 }
136 
RunGraphAsync(CallOptions * opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)137 void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
138                            MutableRunGraphResponseWrapper* response,
139                            StatusCallback done) {
140   if (request->store_errors_in_response_body()) {
141     done = [response, done](const Status& status) {
142       response->set_status(status);
143       done(Status::OK());
144     };
145   }
146   if (request->is_partial()) {
147     DoPartialRunGraph(opts, request, response, std::move(done));
148   } else {
149     DoRunGraph(opts, request, response, std::move(done));
150   }
151 }
152 
CreateRunGraphRequest()153 MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
154   return new InMemoryRunGraphRequest;
155 }
156 
CreateRunGraphResponse()157 MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
158   return new InMemoryRunGraphResponse;
159 }
160 
// Executes one full (non-partial) step of a registered graph and copies the
// requested outputs into `response`.
//
// Ownership notes: `out`, `collector`, `profiler_session` and `cm` are
// heap-allocated here and are deleted either on the early-exit paths or
// inside the ExecuteAsync completion callback, which is the last code to
// touch them.
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64 step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  // Reject duplicate (retried) RPCs so the same step is not executed twice.
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  // Resolve the worker session: either the one created explicitly via
  // CreateWorkerSession, or the shared legacy session.
  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  // `in` holds the feed tensors; `out` is pre-populated with the recv keys
  // and filled with results after execution. `out` must outlive this call,
  // hence the heap allocation.
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  // Only pay for step-stats collection when the client asked for it.
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  ProfilerSession* profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    ProfileOptions options = ProfilerSession::DefaultOptions();
    options.set_host_tracer_level(0);
    profiler_session = ProfilerSession::Create(options).release();
  }
  // Per-step cancellation manager. It can be triggered from two sources:
  // cancellation of this RPC (via `opts`) or worker-wide cancellation (via
  // cancellation_manager_).
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    // Worker-wide cancellation already happened: release everything
    // allocated above and abort before starting execution.
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector,
       profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          // Pull the requested output tensors out of the rendezvous.
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        // Execution finished: detach both cancellation paths before freeing
        // the per-step cancellation manager.
        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        // Merge profiler-captured step stats into the response, ignoring
        // collection errors (profiling is best-effort).
        if (profiler_session) {
          RunMetadata run_metadata;
          profiler_session->CollectData(&run_metadata).IgnoreError();
          response->mutable_step_stats()->MergeFrom(run_metadata.step_stats());
        }

        if (s.ok()) {
          // Copy the received tensors into the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete profiler_session;
        delete out;
        done(s);
      });
}
259 
// TODO(suharshs): Add stats collection support to partial run.
// Executes one request of a partial run: on the first request for a step it
// starts the executors; on later requests it feeds the new inputs; in all
// cases it receives the outputs requested by this particular request.
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  // Reject duplicate (retried) RPCs.
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  // Resolve the worker session: either the one created explicitly via
  // CreateWorkerSession, or the shared legacy session.
  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  // `in` holds this request's feed tensors; `out` is seeded with the recv
  // keys and must outlive this call (it is consumed asynchronously below).
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  // Single completion path: clears the RPC cancel callback, frees `out`,
  // and forwards the final status to the caller.
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  // The CancellationManager for a partial run is shared across all of its
  // requests; partial_run_mgr_ owns it and reports whether the step is new.
  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    // Also propagate worker-wide cancellation into this step.
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  // Receive the outputs requested by this request. On the last partial run
  // request, hand completion to partial_run_mgr_ so executor teardown can
  // finish before `finish` runs.
  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the resp.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}
350 
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)351 void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
352                                CleanupGraphResponse* response,
353                                StatusCallback done) {
354   const int64 step_id = request->step_id();
355   env_->rendezvous_mgr->Cleanup(step_id);
356   if (env_->collective_executor_mgr) {
357     env_->collective_executor_mgr->Cleanup(step_id);
358   }
359   for (Device* d : env_->local_devices) {
360     ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
361     if (sam) {
362       sam->Cleanup(step_id);
363     }
364   }
365   done(Status::OK());
366 }
367 
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)368 void Worker::CleanupAllAsync(const CleanupAllRequest* request,
369                              CleanupAllResponse* response,
370                              StatusCallback done) {
371   std::vector<string> containers;
372   for (const auto& c : request->container()) containers.push_back(c);
373   env_->device_mgr->ClearContainers(containers);
374   done(Status::OK());
375 }
376 
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)377 void Worker::LoggingAsync(const LoggingRequest* request,
378                           LoggingResponse* response, StatusCallback done) {
379   done(errors::Unimplemented("Logging"));
380 }
381 
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)382 void Worker::TracingAsync(const TracingRequest* request,
383                           TracingResponse* response, StatusCallback done) {
384   done(errors::Unimplemented("Tracing"));
385 }
386 
RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)387 void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
388                           RecvBufResponse* response, StatusCallback done) {
389   // The base Worker class does not implement RecvBufAsync because
390   // it is not currently used for worker-to-worker communication. Use a
391   // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
392   // instead.
393   done(errors::Unimplemented("Worker::RecvBufAsync()"));
394 }
395 
CompleteGroupAsync(CallOptions * opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)396 void Worker::CompleteGroupAsync(CallOptions* opts,
397                                 const CompleteGroupRequest* request,
398                                 CompleteGroupResponse* response,
399                                 StatusCallback done) {
400   if (env_->collective_executor_mgr) {
401     env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
402         request, response, &cancellation_manager_, done);
403   } else {
404     done(
405         errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
406   }
407 }
408 
CompleteInstanceAsync(CallOptions * opts,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)409 void Worker::CompleteInstanceAsync(CallOptions* opts,
410                                    const CompleteInstanceRequest* request,
411                                    CompleteInstanceResponse* response,
412                                    StatusCallback done) {
413   if (env_->collective_executor_mgr) {
414     env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
415         request, response, &cancellation_manager_, done);
416   } else {
417     done(
418         errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
419   }
420 }
421 
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)422 void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
423                                   GetStepSequenceResponse* response,
424                                   StatusCallback done) {
425   if (env_->collective_executor_mgr) {
426     env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
427                                                         done);
428   } else {
429     done(
430         errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
431   }
432 }
433 
434 // Helper for RecvTensor. Validates "key" and returns the source
435 // device in "*src_dev".
PrepareRecvTensor(const Rendezvous::ParsedKey & parsed,Device ** src_dev)436 Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
437                                  Device** src_dev) {
438   // Figures out which device the tensor is hosted on.
439   string local_name = DeviceNameUtils::LocalName(parsed.src_device);
440   TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
441 
442   // Does the device have the right incarnation number we expect?
443   if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
444     return errors::Aborted(
445         "RecvTensor expects a different device incarnation: ",
446         parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
447         ". Your worker job (\"",
448         env_->session_mgr->LegacySession()->worker_name(),
449         "\") was probably restarted. Check your "
450         "worker job for the reason why it was restarted.");
451   }
452 
453   return Status::OK();
454 }
455 
RecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)456 void Worker::RecvTensorAsync(CallOptions* opts,
457                              const RecvTensorRequest* request,
458                              TensorResponse* response, StatusCallback done) {
459   // The base Worker class does not implement RecvTensorAsync, because
460   // it is not currently used for worker-to-worker communication. Use a
461   // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
462   // instead.
463   done(errors::Unimplemented("Worker::RecvTensorAsync()"));
464 }
465 
466 }  // namespace tensorflow
467