/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"

namespace tensorflow {

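// The WorkerEnv is owned by the enclosing server, not by the Worker, and is
// assumed to outlive it.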
Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}

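// Responds with the attributes of every device known to this worker's
// DeviceMgr.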
void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}

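// Creates a WorkerSession for the given session handle by delegating to the
// SessionMgr.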
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(
      request->session_handle(), request->server_def(),
      request->cluster_device_attributes(), request->isolate_session_state(),
      request->master_task(), request->master_incarnation());
  done(s);
}

void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

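// Registers a graph partition with the session's GraphMgr and returns the
// resulting graph handle in the response. If CreateWorkerSession was never
// called for this session, the legacy (shared) session is used instead.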
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(), session.get(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Deregister(request->graph_handle());
  }

  done(s);
}

void Worker::AbortStep(int64 step_id) {
  Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step. This way, the root
    // cause may return first back to the client instead of this
    // cancellation generated abort error.
    rendez->StartAbort(errors::Aborted("Step ", step_id,
                                       " cancelled. Cancelling rendezvous."));
    rendez->Unref();
  });
}

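// Copies the request's feed ("send") tensors into `in` and seeds `out` with
// one placeholder entry per fetch ("recv") key, so the output keys are known
// before the step executes.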
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return Status::OK();
}

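// Entry point for RunGraph requests. If the caller asked for errors to be
// stored in the response body, wraps `done` so the step status is recorded on
// the response and the RPC itself succeeds, then dispatches to either the
// partial-run or the full-run path.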
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(Status::OK());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

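// Runs a previously registered graph for a single step: checks for duplicate
// request ids, resolves the WorkerSession, prepares the feed/fetch tensors,
// optionally attaches a StepStatsCollector and ProfilerSession, wires up
// cancellation, and hands execution to the session's GraphMgr. The fetched
// outputs are copied into the response when execution succeeds.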
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64 step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  ProfilerSession* profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If timeline was requested, assume we want hardware level tracing.
    ProfileOptions options = ProfilerSession::DefaultOptions();
    options.set_host_tracer_level(0);
    profiler_session = ProfilerSession::Create(options).release();
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
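  // Also register this step with the worker-wide cancellation manager so that
  // cancelling the worker cancels the step. If cancellation was already
  // requested, abort before executing anything.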
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector,
       profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (profiler_session) {
          RunMetadata run_metadata;
          profiler_session->CollectData(&run_metadata).IgnoreError();
          response->mutable_step_stats()->MergeFrom(run_metadata.step_stats());
        }

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete profiler_session;
        delete out;
        done(s);
      });
}

// TODO(suharshs): Add stats collection support to partial run.
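// Handles one RPC of a partial run. The first request for a step starts the
// executors via GraphMgr::ExecuteAsync; subsequent requests feed new inputs
// with SendInputs. Every request waits for its fetches with RecvOutputsAsync,
// and the last partial run for the step finishes through the PartialRunMgr.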
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before we start doing anything, we set the RPC cancellation.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run request, the request will need to start the
  // executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the resp.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

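// Releases per-step state once the step has finished: the step's rendezvous,
// its collective executor (if any), and per-device scoped allocators.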
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64 step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->Cleanup(step_id);
  }
  for (Device* d : env_->local_devices) {
    ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
    if (sam) {
      sam->Cleanup(step_id);
    }
  }
  done(Status::OK());
}

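// Clears the requested resource containers on all devices owned by this
// worker's DeviceMgr.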
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}

void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}

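// The collective-ops RPCs below forward to the CollectiveExecutorMgr. They
// fail with an Internal error if the runtime was not initialized with a
// CollectiveExecutorMgr.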
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
                                                        done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figures out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Does the device have the right incarnation number we expect?
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::Aborted(
        "RecvTensor expects a different device incarnation: ",
        parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
        ". Your worker job (\"",
        env_->session_mgr->LegacySession()->worker_name(),
        "\") was probably restarted. Check your "
        "worker job for the reason why it was restarted.");
  }

  return Status::OK();
}

void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvTensorAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}

}  // namespace tensorflow