/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/dispatcher_impl.h"

#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#ifdef PLATFORM_GOOGLE
#include "file/logging/log_lines.h"
#endif
#include "grpcpp/create_channel.h"
#include "grpcpp/impl/codegen/server_context.h"
#include "grpcpp/security/credentials.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/hash_utils.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_state.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace {

using ::tensorflow::protobuf::util::MessageDifferencer;

// The name of the journal directory inside the dispatcher's working directory.
// This name is load-bearing; do not change.
constexpr char kJournalDir[] = "tf_data_dispatcher_journal";
// The name of the datasets directory inside the dispatcher's working directory.
constexpr char kDatasetsDir[] = "datasets";

constexpr std::array<const char*, 8> kNodeNameSharingOps = {
    "HashTable",
    "HashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
};

using Dataset = DispatcherState::Dataset;
using Worker = DispatcherState::Worker;
using NamedJobKey = DispatcherState::NamedJobKey;
using Job = DispatcherState::Job;
using Task = DispatcherState::Task;

std::string JournalDir(const std::string& work_dir) {
  return io::JoinPath(work_dir, kJournalDir);
}

std::string DatasetsDir(const std::string& work_dir) {
  return io::JoinPath(work_dir, kDatasetsDir);
}

std::string DatasetKey(int64_t id, uint64 fingerprint) {
  return absl::StrCat("id_", id, "_fp_", fingerprint);
}

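// Creates a gRPC stub for communicating with the worker at `address`, using
// client credentials created for `protocol` and an unlimited receive message
// size.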
Status CreateWorkerStub(const std::string& address, const std::string& protocol,
                        std::unique_ptr<WorkerService::Stub>& stub) {
  ::grpc::ChannelArguments args;
  args.SetMaxReceiveMessageSize(-1);
  std::shared_ptr<::grpc::ChannelCredentials> credentials;
  TF_RETURN_IF_ERROR(
      CredentialsFactory::CreateClientCredentials(protocol, &credentials));
  auto channel = ::grpc::CreateCustomChannel(address, credentials, args);
  stub = WorkerService::NewStub(channel);
  return Status::OK();
}

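// Prepares a dataset graph for registration: enables node name sharing for
// hash table ops, clears per-node device assignments, and strips device
// placement from the function library.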
void PrepareGraph(GraphDef* graph) {
  for (NodeDef& node : *graph->mutable_node()) {
    for (const auto& op : kNodeNameSharingOps) {
      // Set `use_node_name_sharing` to `true` so that resources aren't deleted
      // prematurely. Otherwise, resources may be deleted when their ops are
      // deleted at the end of the GraphRunner::Run used by standalone::Dataset.
      if (node.op() == op) {
        (*node.mutable_attr())["use_node_name_sharing"].set_b(true);
      }
      if (!node.device().empty()) {
        *node.mutable_device() = "";
      }
    }
  }
  StripDevicePlacement(graph->mutable_library());
}
}  // namespace

DataServiceDispatcherImpl::DataServiceDispatcherImpl(
    const experimental::DispatcherConfig& config)
    : config_(config), env_(Env::Default()), state_(config_) {
  if (config_.work_dir().empty()) {
    dataset_store_ = absl::make_unique<MemoryDatasetStore>();
  } else {
    dataset_store_ = absl::make_unique<FileSystemDatasetStore>(
        DatasetsDir(config_.work_dir()));
  }
}

DataServiceDispatcherImpl::~DataServiceDispatcherImpl() {
  {
    mutex_lock l(mu_);
    cancelled_ = true;
    job_gc_thread_cv_.notify_all();
  }
  job_gc_thread_.reset();
}

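// Starts the dispatcher. Launches the job GC thread and, when running in
// fault-tolerant mode, restores dispatcher state from the on-disk journal.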
Status DataServiceDispatcherImpl::Start() {
  mutex_lock l(mu_);
  if (config_.job_gc_timeout_ms() >= 0) {
    job_gc_thread_ = absl::WrapUnique(
        env_->StartThread({}, "job-gc-thread", [&] { JobGcThread(); }));
  }
  if (config_.work_dir().empty()) {
    if (config_.fault_tolerant_mode()) {
      return errors::InvalidArgument(
          "fault_tolerant_mode is True, but no work_dir is configured.");
    }
  } else {
    TF_RETURN_IF_ERROR(
        env_->RecursivelyCreateDir(DatasetsDir(config_.work_dir())));
  }
  if (!config_.fault_tolerant_mode()) {
    LOG(INFO) << "Running with fault_tolerant_mode=False. The dispatcher will "
                 "not be able to recover its state on restart.";
    started_ = true;
    return Status::OK();
  }
  journal_writer_ = absl::make_unique<FileJournalWriter>(
      env_, JournalDir(config_.work_dir()));
  LOG(INFO) << "Attempting to restore dispatcher state from journal in "
            << JournalDir(config_.work_dir());
  Update update;
  bool end_of_journal = false;
  FileJournalReader reader(env_, JournalDir(config_.work_dir()));
  Status s = reader.Read(update, end_of_journal);
  if (errors::IsNotFound(s)) {
    LOG(INFO) << "No journal found. Starting dispatcher from new state.";
  } else if (!s.ok()) {
    return s;
  } else {
    while (!end_of_journal) {
      TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
      TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
    }
  }
  for (const auto& job : state_.ListJobs()) {
    if (IsDynamicShard(job->processing_mode)) {
      TF_RETURN_IF_ERROR(
          RestoreSplitProviders(*job, split_providers_[job->job_id]));
    }
  }
  // Initialize the journal writer in `Start` so that we fail fast in case it
  // can't be initialized.
  TF_RETURN_IF_ERROR(journal_writer_.value()->EnsureInitialized());
  started_ = true;
  return Status::OK();
}

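// Rebuilds the split providers for `job` and fast-forwards each one to the
// split index recorded in the job's distributed-epoch state.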
Status DataServiceDispatcherImpl::RestoreSplitProviders(
    const Job& job, std::vector<std::unique_ptr<SplitProvider>>& restored)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const std::vector<int64>& indices =
      job.distributed_epoch_state.value().indices;
  std::vector<std::unique_ptr<SplitProvider>> split_providers;
  TF_RETURN_IF_ERROR(MakeSplitProviders(job.dataset_id, split_providers));
  for (int provider_index = 0; provider_index < indices.size();
       ++provider_index) {
    int index = indices[provider_index];
    VLOG(1) << "Restoring split provider " << provider_index << " for job "
            << job.job_id << " to index " << index;
    Tensor unused_tensor;
    bool unused_end_of_splits;
    for (int i = 0; i < index; ++i) {
      TF_RETURN_IF_ERROR(split_providers[provider_index]->GetNext(
          &unused_tensor, &unused_end_of_splits));
    }
  }
  restored = std::move(split_providers);
  return Status::OK();
}

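// Adds to `response` the ids of tasks that the worker reports as current but
// that the dispatcher no longer has assigned to it.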
Status DataServiceDispatcherImpl::FindTasksToDelete(
    const absl::flat_hash_set<int64>& current_tasks,
    const std::vector<std::shared_ptr<const Task>> assigned_tasks,
    WorkerHeartbeatResponse* response) {
  absl::flat_hash_set<int64> assigned_ids;
  for (const auto& assigned : assigned_tasks) {
    assigned_ids.insert(assigned->task_id);
  }
  for (int64_t current_task : current_tasks) {
    if (!assigned_ids.contains(current_task)) {
      response->add_tasks_to_delete(current_task);
    }
  }
  return Status::OK();
}

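// Adds to `response` the assigned tasks that the worker does not yet know
// about, creating pending tasks for round-robin jobs whose tasks were removed
// while the worker was away.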
Status DataServiceDispatcherImpl::FindNewTasks(
    const std::string& worker_address,
    const absl::flat_hash_set<int64>& current_tasks,
    std::vector<std::shared_ptr<const Task>>& assigned_tasks,
    WorkerHeartbeatResponse* response) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  // Check for round-robin jobs that had tasks on the worker removed. Now that
  // the worker is back, we create a new pending task for the worker.
  absl::flat_hash_set<int64> assigned_job_ids;
  for (const auto& task : assigned_tasks) {
    assigned_job_ids.insert(task->job->job_id);
  }
  for (const auto& job : state_.ListJobs()) {
    if (!assigned_job_ids.contains(job->job_id) && job->IsRoundRobin() &&
        !job->finished) {
      VLOG(1) << "Creating pending task for reconnected worker "
              << worker_address;
      TF_RETURN_IF_ERROR(CreatePendingTask(job, worker_address));
    }
  }
  // Refresh assigned_tasks to include newly added pending tasks.
  TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
  for (const auto& task : assigned_tasks) {
    if (current_tasks.contains(task->task_id)) {
      continue;
    }
    TaskDef* task_def = response->add_new_tasks();
    TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
  }
  return Status::OK();
}

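// Handles a worker heartbeat: registers the worker if it is new, then reports
// which of its tasks should be deleted and which new tasks it should start.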
Status DataServiceDispatcherImpl::WorkerHeartbeat(
    const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(4) << "Received worker heartbeat request from worker "
          << request->worker_address();
  mutex_lock l(mu_);
  const std::string& worker_address = request->worker_address();
  // Assigned tasks from the perspective of the dispatcher.
  std::vector<std::shared_ptr<const Task>> assigned_tasks;
  Status s = state_.TasksForWorker(worker_address, assigned_tasks);
  if (!s.ok()) {
    if (!errors::IsNotFound(s)) {
      return s;
    }
    VLOG(1) << "Registering new worker at address " << worker_address;
    TF_RETURN_IF_ERROR(state_.ValidateWorker(worker_address));
    Update update;
    update.mutable_register_worker()->set_worker_address(worker_address);
    update.mutable_register_worker()->set_transfer_address(
        request->transfer_address());
    TF_RETURN_IF_ERROR(Apply(update));
    TF_RETURN_IF_ERROR(CreateTasksForWorker(worker_address));
    TF_RETURN_IF_ERROR(state_.TasksForWorker(worker_address, assigned_tasks));
  }
  absl::flat_hash_set<int64> current_tasks;
  current_tasks.insert(request->current_tasks().cbegin(),
                       request->current_tasks().cend());
  TF_RETURN_IF_ERROR(
      FindTasksToDelete(current_tasks, assigned_tasks, response));
  TF_RETURN_IF_ERROR(
      FindNewTasks(worker_address, current_tasks, assigned_tasks, response));

  VLOG(4) << "Finished worker heartbeat for worker at address "
          << request->worker_address();
  return Status::OK();
}

Status DataServiceDispatcherImpl::WorkerUpdate(
    const WorkerUpdateRequest* request, WorkerUpdateResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  for (auto& update : request->updates()) {
    int64_t task_id = update.task_id();
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
    if (update.completed()) {
      if (task->finished) {
        VLOG(1) << "Received completion update for already-finished task "
                << task->task_id << " on worker " << task->worker_address;
        continue;
      }
      Update update;
      update.mutable_finish_task()->set_task_id(task_id);
      TF_RETURN_IF_ERROR(Apply(update));
      VLOG(3) << "Task " << task_id << " from job " << task->job->job_id
              << " completed";
    }
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    const GetDatasetDefRequest* request, GetDatasetDefResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(request->dataset_id(), dataset));
  std::shared_ptr<const DatasetDef> dataset_def;
  TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
  *response->mutable_dataset_def() = *dataset_def;
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request,
                                           GetSplitResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  int64_t job_id = request->job_id();
  int64_t repetition = request->repetition();
  int64_t provider_index = request->split_provider_index();
  VLOG(3) << "Received GetSplit request for job " << job_id << ", repetition "
          << repetition << ", split provider index " << provider_index;
  std::shared_ptr<const Job> job;
  TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
  if (!job->distributed_epoch_state.has_value()) {
    return errors::FailedPrecondition(
        "Cannot get split for job ", job_id,
        ", since it is not a distributed_epoch job.");
  }
  int64_t current_repetition =
      job->distributed_epoch_state.value().repetitions[provider_index];
  if (repetition < current_repetition) {
    response->set_end_of_splits(true);
    VLOG(3) << "Returning end_of_splits since current repetition "
            << current_repetition
            << " is greater than the requested repetition " << repetition;
    return Status::OK();
  }
  SplitProvider* split_provider =
      split_providers_[job_id][provider_index].get();
  DCHECK(split_provider != nullptr);
  Tensor split;
  bool end_of_splits = false;
  TF_RETURN_IF_ERROR(split_provider->GetNext(&split, &end_of_splits));
  TF_RETURN_IF_ERROR(RecordSplitProduced(
      job_id, repetition, request->split_provider_index(), end_of_splits));
  response->set_end_of_splits(end_of_splits);
  if (end_of_splits) {
    // Reset the split provider to prepare for the next repetition.
    TF_RETURN_IF_ERROR(split_providers_[job_id][provider_index]->Reset());
  } else {
    split.AsProtoTensorContent(response->mutable_split());
  }
  VLOG(3) << "Returning from GetSplit, end_of_splits=" << end_of_splits;
  return Status::OK();
}

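// Instantiates the registered dataset `dataset_id` as a standalone dataset and
// extracts its split providers.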
Status DataServiceDispatcherImpl::MakeSplitProviders(
    int64_t dataset_id,
    std::vector<std::unique_ptr<SplitProvider>>& split_providers)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
  std::shared_ptr<const DatasetDef> dataset_def;
  TF_RETURN_IF_ERROR(GetDatasetDef(*dataset, dataset_def));
  standalone::Dataset::Params params;
  std::unique_ptr<standalone::Dataset> standalone_dataset;
  TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
      params, dataset_def->graph(), &standalone_dataset));
  TF_RETURN_IF_ERROR(standalone_dataset->MakeSplitProviders(&split_providers));
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request,
                                             GetVersionResponse* response) {
  response->set_version(kDataServiceVersion);
  return Status::OK();
}

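// Registers the dataset if its fingerprint has not been seen before; otherwise
// returns the id of the previously registered dataset with the same
// fingerprint.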
Status DataServiceDispatcherImpl::GetOrRegisterDataset(
    const GetOrRegisterDatasetRequest* request,
    GetOrRegisterDatasetResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  uint64 fingerprint;
  DatasetDef dataset_def = request->dataset();
  GraphDef* graph = dataset_def.mutable_graph();
  PrepareGraph(graph);
  TF_RETURN_IF_ERROR(HashGraph(*graph, &fingerprint));

  mutex_lock l(mu_);
#if defined(PLATFORM_GOOGLE)
  VLOG_LINES(4,
             absl::StrCat("Registering dataset graph: ", graph->DebugString()));
#else
  VLOG(4) << "Registering dataset graph: " << graph->DebugString();
#endif
  std::shared_ptr<const Dataset> dataset;
  Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
  if (s.ok()) {
    int64_t id = dataset->dataset_id;
    VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
            << fingerprint << ". Returning id " << id;
    response->set_dataset_id(id);
    return Status::OK();
  } else if (!errors::IsNotFound(s)) {
    return s;
  }

  int64_t id;
  TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def, id));
  if (!request->element_spec().empty()) {
    TF_RETURN_IF_ERROR(SetElementSpec(id, request->element_spec()));
  }

  response->set_dataset_id(id);
  VLOG(3) << "Registered new dataset with id " << id;
  return Status::OK();
}

Status DataServiceDispatcherImpl::RegisterDataset(uint64 fingerprint,
                                                  const DatasetDef& dataset,
                                                  int64& dataset_id)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  dataset_id = state_.NextAvailableDatasetId();
  Update update;
  RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
  register_dataset->set_dataset_id(dataset_id);
  register_dataset->set_fingerprint(fingerprint);
  TF_RETURN_IF_ERROR(
      dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
  return Apply(update);
}

Status DataServiceDispatcherImpl::SetElementSpec(
    int64_t dataset_id, const std::string& element_spec)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  Update update;
  SetElementSpecUpdate* set_element_spec = update.mutable_set_element_spec();
  set_element_spec->set_dataset_id(dataset_id);
  set_element_spec->set_element_spec(element_spec);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetElementSpec(
    const GetElementSpecRequest* request, GetElementSpecResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Read the element spec.";
  int64_t dataset_id = request->dataset_id();

  std::string element_spec;
  TF_RETURN_IF_ERROR(state_.GetElementSpec(dataset_id, element_spec));
  VLOG(3) << "Get the `element_spec` for registered dataset with dataset id: "
          << dataset_id << ".";
  *response->mutable_element_spec() = element_spec;
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetOrCreateJob(
    const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")";
  absl::optional<NamedJobKey> key;
  if (request->has_job_key()) {
    key.emplace(request->job_key().job_name(),
                request->job_key().job_name_index());
  }
  std::shared_ptr<const Job> job;
  std::vector<std::shared_ptr<const Task>> tasks;
  {
    mutex_lock l(mu_);
    if (key.has_value()) {
      Status s = state_.NamedJobByKey(key.value(), job);
      if (s.ok()) {
        TF_RETURN_IF_ERROR(ValidateMatchingJob(job, *request));
        // If the matching job was already garbage-collected, we fall through
        // to re-create the job.
        if (!job->garbage_collected) {
          int64_t job_client_id;
          TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
          response->set_job_client_id(job_client_id);
          VLOG(3) << "Found existing job for name=" << key.value().name
                  << ", index=" << key.value().index
                  << ". job_id: " << job->job_id;
          return Status::OK();
        }
      } else if (!errors::IsNotFound(s)) {
        return s;
      }
    }
    TF_RETURN_IF_ERROR(CreateJob(*request, job));
    int64_t job_client_id;
    TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
    response->set_job_client_id(job_client_id);
    TF_RETURN_IF_ERROR(CreateTasksForJob(job, tasks));
  }
  TF_RETURN_IF_ERROR(AssignTasks(tasks));
  VLOG(3) << "Created job " << job->job_id << " for CreateJob("
          << request->DebugString() << ")";
  return Status::OK();
}

Status DataServiceDispatcherImpl::MaybeRemoveTask(
    const MaybeRemoveTaskRequest* request, MaybeRemoveTaskResponse* response) {
  VLOG(1) << "Attempting to remove task. Request: " << request->DebugString();
  std::shared_ptr<TaskRemover> remover;
  std::shared_ptr<const Task> task;
  {
    mutex_lock l(mu_);
    Status s = state_.TaskFromId(request->task_id(), task);
    if (errors::IsNotFound(s)) {
      // Task is already removed.
      response->set_removed(true);
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(s);
    auto& remover_ref = remove_task_requests_[task->task_id];
    if (remover_ref == nullptr) {
      if (!task->job->IsRoundRobin()) {
        return errors::FailedPrecondition(
            "MaybeRemoveTask called on a non-round-robin task.");
      }
      remover_ref =
          std::make_shared<TaskRemover>(task->job->num_consumers.value());
    }
    remover = remover_ref;
  }
  bool removed =
      remover->RequestRemoval(request->consumer_index(), request->round());
  response->set_removed(removed);
  if (!removed) {
    VLOG(1) << "Failed to remove task " << task->task_id;
    return Status::OK();
  }
  mutex_lock l(mu_);
  if (!task->removed) {
    Update update;
    RemoveTaskUpdate* remove_task = update.mutable_remove_task();
    remove_task->set_task_id(request->task_id());
    TF_RETURN_IF_ERROR(Apply(update));
  }
  VLOG(1) << "Task " << task->task_id << " successfully removed";
  return Status::OK();
}

Status DataServiceDispatcherImpl::ReleaseJobClient(
    const ReleaseJobClientRequest* request,
    ReleaseJobClientResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  int64_t job_client_id = request->job_client_id();
  std::shared_ptr<const Job> job;
  TF_RETURN_IF_ERROR(state_.JobForJobClientId(job_client_id, job));
  Update update;
  ReleaseJobClientUpdate* release_job_client =
      update.mutable_release_job_client();
  release_job_client->set_job_client_id(job_client_id);
  release_job_client->set_time_micros(env_->NowMicros());
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

// Validates that an existing job matches the requested processing mode and
// target workers.
Status DataServiceDispatcherImpl::ValidateMatchingJob(
    std::shared_ptr<const Job> job, const GetOrCreateJobRequest& request)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  DCHECK(job->named_job_key.has_value());
  std::string job_name = job->named_job_key->name;

  if (!MessageDifferencer::Equals(job->processing_mode,
                                  request.processing_mode_def())) {
    return errors::FailedPrecondition(
        "Tried to create a job with name ", job_name, " and processing_mode <",
        request.processing_mode_def().ShortDebugString(),
        "> but there is already an existing job with that name using "
        "processing mode <",
        job->processing_mode.ShortDebugString(), ">");
  }

  if (job->target_workers != request.target_workers()) {
    return errors::InvalidArgument(
        "Tried to create job with name ", job_name, " and target_workers <",
        TargetWorkersToString(request.target_workers()),
        ">, but there is already an existing job "
        "with that name using target_workers <",
        TargetWorkersToString(job->target_workers), ">.");
  }
  return Status::OK();
}

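// Creates a new job for `request`, journaling the creation and building split
// providers when the job uses a dynamically sharded processing mode.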
Status DataServiceDispatcherImpl::CreateJob(
    const GetOrCreateJobRequest& request, std::shared_ptr<const Job>& job)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  TF_RETURN_IF_ERROR(ValidateProcessingMode(request.processing_mode_def()));
  int64_t job_id = state_.NextAvailableJobId();
  int64_t num_split_providers = 0;
  if (IsDynamicShard(request.processing_mode_def())) {
    TF_RETURN_IF_ERROR(
        MakeSplitProviders(request.dataset_id(), split_providers_[job_id]));
    num_split_providers = split_providers_[job_id].size();
  }
  Update update;
  CreateJobUpdate* create_job = update.mutable_create_job();
  create_job->set_job_id(job_id);
  create_job->set_dataset_id(request.dataset_id());
  *create_job->mutable_processing_mode_def() = request.processing_mode_def();
  create_job->set_num_split_providers(num_split_providers);
  if (request.has_job_key()) {
    NamedJobKeyDef* key = create_job->mutable_named_job_key();
    key->set_name(request.job_key().job_name());
    key->set_index(request.job_key().job_name_index());
  }
  if (request.optional_num_consumers_case() ==
      GetOrCreateJobRequest::kNumConsumers) {
    create_job->set_num_consumers(request.num_consumers());
  }
  create_job->set_target_workers(request.target_workers());
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));
  return Status::OK();
}

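// Creates a task on `worker_address` for every unfinished job; jobs with a
// fixed number of consumers (round-robin jobs) get a pending task instead.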
Status DataServiceDispatcherImpl::CreateTasksForWorker(
    const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
  for (const auto& job : jobs) {
    if (job->finished) {
      continue;
    }
    if (job->num_consumers.has_value()) {
      TF_RETURN_IF_ERROR(CreatePendingTask(job, worker_address));
      continue;
    }
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::AcquireJobClientId(
    const std::shared_ptr<const Job>& job, int64& job_client_id)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  job_client_id = state_.NextAvailableJobClientId();
  Update update;
  AcquireJobClientUpdate* acquire_job_client =
      update.mutable_acquire_job_client();
  acquire_job_client->set_job_client_id(job_client_id);
  acquire_job_client->set_job_id(job->job_id);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTasksForJob(
    std::shared_ptr<const Job> job,
    std::vector<std::shared_ptr<const Task>>& tasks)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  tasks.clear();
  tasks.reserve(workers.size());
  for (const auto& worker : workers) {
    std::shared_ptr<const Task> task;
    TF_RETURN_IF_ERROR(CreateTask(job, worker->address, task));
    tasks.push_back(task);
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreatePendingTask(
    std::shared_ptr<const Job> job, const std::string& worker_address)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64_t task_id = state_.NextAvailableTaskId();
  Update update;
  CreatePendingTaskUpdate* create_task = update.mutable_create_pending_task();
  create_task->set_task_id(task_id);
  create_task->set_job_id(job->job_id);
  create_task->set_worker_address(worker_address);
  create_task->set_starting_round(round_robin_rounds_[job->job_id] + 1);
  std::shared_ptr<const Worker> worker;
  TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
  create_task->set_transfer_address(worker->transfer_address);
  TF_RETURN_IF_ERROR(Apply(update));
  return Status::OK();
}

Status DataServiceDispatcherImpl::CreateTask(std::shared_ptr<const Job> job,
                                             const std::string& worker_address,
                                             std::shared_ptr<const Task>& task)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  int64_t task_id = state_.NextAvailableTaskId();
  Update update;
  CreateTaskUpdate* create_task = update.mutable_create_task();
  create_task->set_task_id(task_id);
  create_task->set_job_id(job->job_id);
  create_task->set_worker_address(worker_address);
  std::shared_ptr<const Worker> worker;
  TF_RETURN_IF_ERROR(state_.WorkerFromAddress(worker_address, worker));
  create_task->set_transfer_address(worker->transfer_address);
  TF_RETURN_IF_ERROR(Apply(update));
  TF_RETURN_IF_ERROR(state_.TaskFromId(task_id, task));
  return Status::OK();
}

Status DataServiceDispatcherImpl::AssignTasks(
    std::vector<std::shared_ptr<const Task>> tasks) TF_LOCKS_EXCLUDED(mu_) {
  for (const auto& task : tasks) {
    TF_RETURN_IF_ERROR(AssignTask(task));
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetOrCreateWorkerStub(
    const std::string& worker_address, WorkerService::Stub*& out_stub)
    TF_LOCKS_EXCLUDED(mu_) {
  {
    mutex_lock l(mu_);
    auto it = worker_stubs_.find(worker_address);
    if (it != worker_stubs_.end()) {
      out_stub = it->second.get();
      return Status::OK();
    }
  }
  std::unique_ptr<WorkerService::Stub> stub;
  TF_RETURN_IF_ERROR(
      CreateWorkerStub(worker_address, config_.protocol(), stub));
  {
    mutex_lock l(mu_);
    // A concurrent call could have already created the stub.
    auto& worker = worker_stubs_[worker_address];
    if (worker == nullptr) {
      worker = std::move(stub);
    }
    out_stub = worker.get();
  }
  return Status::OK();
}

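// Sends a ProcessTask request to the task's worker. A worker that is
// unreachable is not treated as an error; the task is re-assigned when the
// worker reconnects through its heartbeat.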
Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
    TF_LOCKS_EXCLUDED(mu_) {
  VLOG(2) << "Started assigning task " << task->task_id << " to worker "
          << task->worker_address;
  grpc::ClientContext client_ctx;
  ProcessTaskRequest req;
  TaskDef* task_def = req.mutable_task();
  {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(PopulateTaskDef(task, task_def));
  }
  ProcessTaskResponse resp;
  WorkerService::Stub* stub;
  TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
  grpc::Status s = stub->ProcessTask(&client_ctx, req, &resp);
  if (!s.ok()) {
    if (s.error_code() == grpc::StatusCode::UNAVAILABLE ||
        s.error_code() == grpc::StatusCode::ABORTED ||
        s.error_code() == grpc::StatusCode::CANCELLED) {
      // Worker is presumably preempted. We will assign the task to the worker
      // when it reconnects.
      return Status::OK();
    }
    return grpc_util::WrapError(
        absl::StrCat("Failed to submit task to worker ", task->worker_address),
        s);
  }
  VLOG(2) << "Finished assigning task " << task->task_id << " to worker "
          << task->worker_address;
  return Status::OK();
}

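// Handles a client heartbeat: tracks round-robin round progress, resolves any
// pending task for the job, and returns the job's current tasks and whether
// the job is finished.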
Status DataServiceDispatcherImpl::ClientHeartbeat(
    const ClientHeartbeatRequest* request, ClientHeartbeatResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(4) << "Received heartbeat from client id " << request->job_client_id();
  std::shared_ptr<const Job> job;
  Status s = state_.JobForJobClientId(request->job_client_id(), job);
  if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) {
    return errors::NotFound(
        "Unknown job client id ", request->job_client_id(),
        ". The dispatcher is not configured to be fault tolerant, so this "
        "could be caused by a dispatcher restart.");
  }
  TF_RETURN_IF_ERROR(s);
  if (job->garbage_collected) {
    return errors::FailedPrecondition(
        "The requested job has been garbage collected due to inactivity. "
        "Consider configuring the dispatcher with a higher "
        "`job_gc_timeout_ms`.");
  }
  if (request->optional_current_round_case() ==
      ClientHeartbeatRequest::kCurrentRound) {
    round_robin_rounds_[request->job_client_id()] =
        std::max(round_robin_rounds_[request->job_client_id()],
                 request->current_round());
  }
  if (!job->pending_tasks.empty()) {
    const auto& task = job->pending_tasks.front();
    Update update;
    ClientHeartbeatUpdate* client_heartbeat = update.mutable_client_heartbeat();
    bool apply_update = false;
    client_heartbeat->set_job_client_id(request->job_client_id());
    absl::optional<int64> blocked_round;
    if (request->optional_blocked_round_case() ==
        ClientHeartbeatRequest::kBlockedRound) {
      blocked_round = request->blocked_round();
    }
    VLOG(1) << "Handling pending task in job client heartbeat. job_client_id: "
            << request->job_client_id()
            << ". current_round: " << request->current_round()
            << ". blocked_round: " << blocked_round.value_or(-1)
            << ". target_round: " << task.target_round;
    if (request->current_round() >= task.target_round) {
      TaskRejected* rejected = client_heartbeat->mutable_task_rejected();
      // Exponentially try later and later rounds until consumers all agree.
      int64_t round_offset = 2;
      for (int i = 0; i < task.failures; ++i) {
        round_offset *= 2;
      }
      rejected->set_new_target_round(
          round_robin_rounds_[request->job_client_id()] + round_offset);
      apply_update = true;
    }
    if (blocked_round.has_value() &&
        blocked_round.value() <= task.target_round &&
        !task.ready_consumers.contains(request->job_client_id())) {
      client_heartbeat->set_task_accepted(true);
      apply_update = true;
    }
    if (apply_update) {
      TF_RETURN_IF_ERROR(Apply(update));
    }
  }
  if (!job->pending_tasks.empty()) {
    response->set_block_round(job->pending_tasks.front().target_round);
  }

  std::vector<std::shared_ptr<const Task>> tasks;
  TF_RETURN_IF_ERROR(state_.TasksForJob(job->job_id, tasks));
  for (const auto& task : tasks) {
    TaskInfo* task_info = response->mutable_task_info()->Add();
    task_info->set_worker_address(task->worker_address);
    task_info->set_transfer_address(task->transfer_address);
    task_info->set_task_id(task->task_id);
    task_info->set_job_id(job->job_id);
    task_info->set_starting_round(task->starting_round);
  }
  response->set_job_finished(job->finished);
  VLOG(4) << "Found " << response->task_info_size()
          << " tasks for job client id " << request->job_client_id();
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request,
                                             GetWorkersResponse* response) {
  TF_RETURN_IF_ERROR(CheckStarted());
  mutex_lock l(mu_);
  VLOG(3) << "Enter GetWorkers";
  std::vector<std::shared_ptr<const Worker>> workers = state_.ListWorkers();
  for (const auto& worker : workers) {
    WorkerInfo* info = response->add_workers();
    info->set_address(worker->address);
  }
  VLOG(3) << "Returning list of " << response->workers_size()
          << " workers from GetWorkers";
  return Status::OK();
}

Status DataServiceDispatcherImpl::PopulateTaskDef(
    std::shared_ptr<const Task> task, TaskDef* task_def) const
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  task_def->set_dataset_id(task->job->dataset_id);
  task_def->set_job_id(task->job->job_id);
  task_def->set_worker_address(task->worker_address);
  task_def->set_task_id(task->task_id);
  *task_def->mutable_processing_mode_def() = task->job->processing_mode;
  if (IsStaticShard(task->job->processing_mode)) {
    task_def->set_num_workers(config_.worker_addresses_size());
    TF_ASSIGN_OR_RETURN(int64_t worker_index,
                        state_.GetWorkerIndex(task->worker_address));
    task_def->set_worker_index(worker_index);
  }
  if (task->job->distributed_epoch_state.has_value()) {
    task_def->set_num_split_providers(
        task->job->distributed_epoch_state.value().indices.size());
  }
  if (task->job->num_consumers.has_value()) {
    task_def->set_num_consumers(task->job->num_consumers.value());
  }
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(task->job->dataset_id, dataset));
  std::string dataset_key =
      DatasetKey(dataset->dataset_id, dataset->fingerprint);
  if (config_.work_dir().empty()) {
    std::shared_ptr<const DatasetDef> dataset_def;
    TF_RETURN_IF_ERROR(dataset_store_->Get(dataset_key, dataset_def));
    *task_def->mutable_dataset_def() = *dataset_def;
  } else {
    std::string path =
        io::JoinPath(DatasetsDir(config_.work_dir()), dataset_key);
    task_def->set_path(path);
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) {
  mutex_lock l(mu_);
  if (!started_) {
    return errors::Unavailable("Dispatcher has not started yet.");
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::RecordSplitProduced(
    int64_t job_id, int64_t repetition, int64_t split_provider_index,
    bool finished) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  Update update;
  ProduceSplitUpdate* produce_split = update.mutable_produce_split();
  produce_split->set_job_id(job_id);
  produce_split->set_repetition(repetition);
  produce_split->set_split_provider_index(split_provider_index);
  produce_split->set_finished(finished);
  return Apply(update);
}

Status DataServiceDispatcherImpl::ApplyWithoutJournaling(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  return state_.Apply(update);
}

Status DataServiceDispatcherImpl::Apply(const Update& update)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  if (journal_writer_.has_value()) {
    TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update));
  }
  return state_.Apply(update);
}

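// Background loop that periodically garbage collects old jobs until the
// dispatcher is cancelled.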
void DataServiceDispatcherImpl::JobGcThread() {
  int64_t next_check_micros = 0;
  while (true) {
    mutex_lock l(mu_);
    while (!cancelled_ && env_->NowMicros() < next_check_micros) {
      int64_t remaining_micros = next_check_micros - env_->NowMicros();
      job_gc_thread_cv_.wait_for(l,
                                 std::chrono::microseconds(remaining_micros));
    }
    if (cancelled_) {
      return;
    }
    Status s = GcOldJobs();
    if (!s.ok()) {
      LOG(WARNING) << "Error garbage collecting old jobs: " << s;
    }
    next_check_micros =
        env_->NowMicros() + (config_.job_gc_check_interval_ms() * 1000);
  }
}

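// Garbage collects unfinished jobs that have no clients and whose last client
// was released more than `job_gc_timeout_ms` ago.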
Status DataServiceDispatcherImpl::GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
  int64_t now = env_->NowMicros();
  for (const auto& job : jobs) {
    if (job->finished || job->num_clients > 0 ||
        job->last_client_released_micros < 0 ||
        now < job->last_client_released_micros +
                  (config_.job_gc_timeout_ms() * 1000)) {
      continue;
    }
    Update update;
    update.mutable_garbage_collect_job()->set_job_id(job->job_id);
    TF_RETURN_IF_ERROR(state_.Apply(update));
    LOG(INFO) << "Garbage collected job " << job->DebugString();
  }
  return Status::OK();
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    int64_t dataset_id, std::shared_ptr<const DatasetDef>& dataset_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::shared_ptr<const Dataset> dataset;
  TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
  return GetDatasetDef(*dataset, dataset_def);
}

Status DataServiceDispatcherImpl::GetDatasetDef(
    const Dataset& dataset, std::shared_ptr<const DatasetDef>& dataset_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::string key = DatasetKey(dataset.dataset_id, dataset.fingerprint);
  return dataset_store_->Get(key, dataset_def);
}

}  // namespace data
}  // namespace tensorflow