/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/worker.h"

#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/device_profiler_session.h"

namespace tensorflow {

Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
  // Enable log history collection in StatusGroup so that recent warning and
  // error log messages will be attached to the root error status to be
  // forwarded to the master.
  StatusGroup::ConfigureLogHistory();
}

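// Reports the attributes of all devices managed by this worker's DeviceMgr.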
void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) {
  const DeviceMgr* dm = env_->device_mgr;
  std::vector<DeviceAttributes> devices;
  dm->ListDeviceAttributes(&devices);
  response->mutable_device_attributes()->Reserve(devices.size());
  for (auto& d : devices) {
    response->add_device_attributes()->Swap(&d);
  }
  done(Status::OK());
}

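// Creates a WorkerSession for `request->session_handle()` by delegating to
// the SessionMgr.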
void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->CreateSession(
      request->session_handle(), request->server_def(),
      request->cluster_device_attributes(), request->isolate_session_state(),
      request->master_task(), request->master_incarnation());
  done(s);
}

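// Deletes the WorkerSession identified by `request->session_handle()`.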
void Worker::DeleteWorkerSessionAsync(CallOptions* opts,
                                      const DeleteWorkerSessionRequest* request,
                                      DeleteWorkerSessionResponse* response,
                                      StatusCallback done) {
  Status s = env_->session_mgr->DeleteSession(request->session_handle());
  done(s);
}

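// Registers a graph partition with this worker's GraphMgr, returning a graph
// handle that subsequent RunGraph calls refer to. If the master never called
// CreateWorkerSession, the legacy (default) session is used.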
void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Register(
        request->session_handle(), request->graph_def(), session.get(),
        request->graph_options(), request->debug_options(),
        request->config_proto(), request->collective_graph_key(),
        session->cluster_flr(), response->mutable_graph_handle());
  }
  done(s);
}

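// Removes the graph identified by `request->graph_handle()` from the
// session's GraphMgr.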
void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                  DeregisterGraphResponse* response,
                                  StatusCallback done) {
  std::shared_ptr<WorkerSession> session;
  Status s;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (s.ok()) {
    s = session->graph_mgr()->Deregister(request->graph_handle());
  }

  done(s);
}

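// Schedules an abort of the rendezvous for `step_id` so that pending
// Send/Recv operations for the step fail instead of hanging.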
void Worker::AbortStep(int64_t step_id) {
  Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
  SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
    // Delay a bit before aborting the step. This way, the error from the
    // root cause is more likely to reach the client before the Aborted
    // error generated by this cancellation.
    rendez->StartAbort(errors::Aborted("Step ", step_id,
                                       " cancelled. Cancelling rendezvous."));
    rendez->Unref();
  });
}

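// Moves the input tensors ("sends") from `req` into `in`, and seeds `out`
// with a placeholder entry for every requested output ("recv").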
Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req,
                               GraphMgr::NamedTensors* in,
                               GraphMgr::NamedTensors* out) {
  static Tensor empty_tensor(DT_FLOAT);
  if (req->num_sends() > 0) {
    Tensor val;
    for (size_t i = 0; i < req->num_sends(); ++i) {
      TF_RETURN_IF_ERROR(req->SendValue(i, &val));
      in->insert({req->send_key(i), val});
    }
  }
  for (size_t i = 0; i < req->num_recvs(); ++i) {
    out->insert({req->recv_key(i), empty_tensor});
  }
  return Status::OK();
}

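// Entry point for the RunGraph RPC. If the caller asked for errors to be
// stored in the response body, wraps `done` so the RPC itself succeeds and
// the step's status travels inside the response; then dispatches to the
// partial-run or full-run path.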
void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) {
  if (request->store_errors_in_response_body()) {
    done = [response, done](const Status& status) {
      response->set_status(status);
      done(Status::OK());
    };
  }
  if (request->is_partial()) {
    DoPartialRunGraph(opts, request, response, std::move(done));
  } else {
    DoRunGraph(opts, request, response, std::move(done));
  }
}

MutableRunGraphRequestWrapper* Worker::CreateRunGraphRequest() {
  return new InMemoryRunGraphRequest;
}

MutableRunGraphResponseWrapper* Worker::CreateRunGraphResponse() {
  return new InMemoryRunGraphResponse;
}

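// Executes one registered graph partition for one step: records the request
// id for duplicate detection, resolves the session, feeds the "send"
// tensors, optionally collects step stats and a device trace, wires up RPC-
// and worker-level cancellation, and finally copies the "recv" tensors into
// the response.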
void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                        MutableRunGraphResponseWrapper* response,
                        StatusCallback done) {
  const int64_t step_id = request->step_id();
  TRACEPRINTF("RunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }
  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  if (!s.ok()) {
    delete out;
    done(s);
    return;
  }
  StepStatsCollector* collector = nullptr;
  if (request->exec_opts().report_tensor_allocations_upon_oom() ||
      request->exec_opts().record_timeline() ||
      request->exec_opts().record_costs()) {
    collector = new StepStatsCollector(response->mutable_step_stats());
  }
  DeviceProfilerSession* device_profiler_session = nullptr;
  if (collector && request->exec_opts().record_timeline()) {
    // If a timeline was requested, assume we also want hardware-level
    // tracing.
    device_profiler_session = DeviceProfilerSession::Create().release();
  }
  CancellationManager* cm = new CancellationManager;
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for RunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });
  CancellationToken token;
  token = cancellation_manager_.get_cancellation_token();
  bool already_cancelled = !cancellation_manager_.RegisterCallback(
      token, [cm]() { cm->StartCancel(); });
  if (already_cancelled) {
    opts->ClearCancelCallback();
    delete cm;
    delete collector;
    delete device_profiler_session;
    delete out;
    done(errors::Aborted("Call was aborted"));
    return;
  }
  session->graph_mgr()->ExecuteAsync(
      request->graph_handle(), step_id, session.get(), request->exec_opts(),
      collector, response, cm, in,
      [this, step_id, response, session, cm, out, token, collector,
       device_profiler_session, opts, done](const Status& status) {
        Status s = status;
        if (s.ok()) {
          s = session->graph_mgr()->RecvOutputs(step_id, out);
        }

        opts->ClearCancelCallback();
        cancellation_manager_.DeregisterCallback(token);
        delete cm;

        if (device_profiler_session) {
          device_profiler_session->CollectData(response->mutable_step_stats())
              .IgnoreError();
        }

        if (s.ok()) {
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }

        if (collector) collector->Finalize();
        delete collector;
        delete device_profiler_session;
        delete out;
        done(s);
      });
}

// TODO(suharshs): Add stats collection support to partial run.
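// Handles one RunGraph call of a partial run. The first call for a step
// starts the executors; later calls feed additional inputs into the running
// step. Every call waits for its requested outputs, and the last call lets
// PartialRunMgr tear the step down.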
void Worker::DoPartialRunGraph(CallOptions* opts,
                               RunGraphRequestWrapper* request,
                               MutableRunGraphResponseWrapper* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  const string& graph_handle = request->graph_handle();
  TRACEPRINTF("PartialRunGraph: %lld", step_id);
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "PartialRunGraph (Worker)", request);
  if (!s.ok()) {
    done(s);
    return;
  }

  std::shared_ptr<WorkerSession> session;
  if (request->create_worker_session_called()) {
    s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                   &session);
  } else {
    session = env_->session_mgr->LegacySession();
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  GraphMgr::NamedTensors in;
  GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
  s = PrepareRunGraph(request, &in, out);
  auto finish = [done, out, opts](const Status& s) {
    opts->ClearCancelCallback();
    delete out;
    done(s);
  };
  if (!s.ok()) {
    finish(s);
    return;
  }

  CancellationManager* cm = nullptr;
  bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);

  // Before doing anything else, set the RPC cancellation callback.
  opts->SetCancelCallback([this, cm, step_id]() {
    LOG(INFO) << "Cancellation requested for PartialRunGraph.";
    cm->StartCancel();
    AbortStep(step_id);
  });

  // If this is a new partial run, this request must start the executors.
  if (is_new_partial_run) {
    CancellationToken token;
    token = cancellation_manager_.get_cancellation_token();
    cancellation_manager_.RegisterCallback(token,
                                           [cm]() { cm->StartCancel(); });
    session->graph_mgr()->ExecuteAsync(
        graph_handle, step_id, session.get(), request->exec_opts(),
        nullptr /* collector */, nullptr /* response */, cm, in,
        [this, token, step_id, session](Status s) {
          cancellation_manager_.DeregisterCallback(token);
          partial_run_mgr_.ExecutorDone(step_id, s);
        });
  } else {
    // Send the partial run's new inputs.
    s = session->graph_mgr()->SendInputs(step_id, in);
    if (!s.ok()) {
      finish(s);
      return;
    }
  }

  session->graph_mgr()->RecvOutputsAsync(
      step_id, out, [this, out, request, response, step_id, finish](Status s) {
        if (s.ok()) {
          // Construct and return the response.
          for (const auto& p : *out) {
            const string& key = p.first;
            const Tensor& val = p.second;
            response->AddRecv(key, val);
          }
        }
        if (request->is_last_partial_run()) {
          partial_run_mgr_.PartialRunDone(step_id, finish, s);
        } else {
          finish(s);
        }
      });
}

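// Releases per-step state once a step is done: the step's rendezvous, its
// collective executor (if any), and any scoped-allocator state on local
// devices.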
void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) {
  const int64_t step_id = request->step_id();
  env_->rendezvous_mgr->Cleanup(step_id);
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->Cleanup(step_id);
  }
  for (Device* d : env_->local_devices) {
    ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
    if (sam) {
      sam->Cleanup(step_id);
    }
  }
  done(Status::OK());
}

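// Clears the resource containers named in the request on all local devices.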
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                             CleanupAllResponse* response,
                             StatusCallback done) {
  std::vector<string> containers;
  for (const auto& c : request->container()) containers.push_back(c);
  env_->device_mgr->ClearContainers(containers);
  done(Status::OK());
}

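// The base Worker class does not implement Logging; transports may override
// this with their own implementation.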
void Worker::LoggingAsync(const LoggingRequest* request,
                          LoggingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Logging"));
}

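// The base Worker class does not implement Tracing; transports may override
// this with their own implementation.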
void Worker::TracingAsync(const TracingRequest* request,
                          TracingResponse* response, StatusCallback done) {
  done(errors::Unimplemented("Tracing"));
}

void Worker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvBufAsync because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as `GrpcWorker::RecvBufAsync()`)
  // instead.
  done(errors::Unimplemented("Worker::RecvBufAsync()"));
}

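// Resolves collective group membership: forwards the request to the
// collective parameter resolver and, on success, copies the completed group
// parameters (key, size, device type, tasks, devices, communicator key) into
// the response.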
void Worker::CompleteGroupAsync(CallOptions* opts,
                                const CompleteGroupRequest* request,
                                CompleteGroupResponse* response,
                                StatusCallback done) {
  if (!request->has_device_attributes()) {
    done(errors::Internal(
        "CompleteGroupRequest device_attributes is not set. Make sure you're "
        "running the same version of TensorFlow on all workers."));
    return;
  }
  if (env_->collective_executor_mgr) {
    auto group_params = new CollGroupParams();
    group_params->group_key = request->group_key();
    group_params->group_size = request->group_size();
    group_params->device_type = DeviceType(request->device_type());
    env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync(
        request->device_attributes(), group_params, &cancellation_manager_,
        [response, group_params, done = std::move(done)](const Status& s) {
          if (s.ok()) {
            response->set_group_key(group_params->group_key);
            response->set_group_size(group_params->group_size);
            response->set_device_type(group_params->device_type.type_string());
            response->set_num_tasks(group_params->num_tasks);
            for (const DeviceAttributes& device : group_params->devices) {
              *response->add_device_attributes() = device;
            }
            response->set_communicator_key(
                group_params->runtime_details.communicator_key);
          } else {
            LOG(ERROR) << "Bad status from CompleteGroupDistributed: " << s;
          }
          delete group_params;
          done(s);
        });
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

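// Resolves per-instance collective parameters by delegating to the
// collective parameter resolver.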
void Worker::CompleteInstanceAsync(CallOptions* opts,
                                   const CompleteInstanceRequest* request,
                                   CompleteInstanceResponse* response,
                                   StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetParamResolver()->CompleteInstanceAsync(
        request, response, &cancellation_manager_, done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

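// Returns the step id sequence for the requested graph keys, so that all
// workers agree on the step ids used by collective ops.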
void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                  GetStepSequenceResponse* response,
                                  StatusCallback done) {
  if (env_->collective_executor_mgr) {
    env_->collective_executor_mgr->GetStepSequenceAsync(request, response,
                                                        done);
  } else {
    done(
        errors::Internal("Runtime not initialized with CollectiveExecutorMgr"));
  }
}

// Helper for RecvTensor. Validates the parsed rendezvous key and returns the
// source device in "*src_dev".
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                                 Device** src_dev) {
  // Figure out which device the tensor is hosted on.
  string local_name = DeviceNameUtils::LocalName(parsed.src_device);
  TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));

  // Check that the device has the incarnation number we expect.
  if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {
    return errors::Aborted(
        "RecvTensor expects a different device incarnation: ",
        parsed.src_incarnation, " vs. ", (*src_dev)->attributes().incarnation(),
        ". Your worker job (\"",
        env_->session_mgr->LegacySession()->worker_name(),
        "\") was probably restarted. Check your "
        "worker job for the reason why it was restarted.");
  }

  return Status::OK();
}


void Worker::RecvTensorAsync(CallOptions* opts,
                             const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) {
  // The base Worker class does not implement RecvTensorAsync, because
  // it is not currently used for worker-to-worker communication. Use a
  // transport-specific implementation (such as
  // `GrpcWorker::RecvTensorAsync()`) instead.
  done(errors::Unimplemented("Worker::RecvTensorAsync()"));
}

}  // namespace tensorflow