1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/dispatcher_state.h"
16 
17 #include <memory>
18 #include <string>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/core/data/service/common.h"
25 #include "tensorflow/core/data/service/journal.h"
26 #include "tensorflow/core/data/service/journal.pb.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/protobuf/data_service.pb.h"
30 #include "tensorflow/core/protobuf/service_config.pb.h"
31 
32 namespace tensorflow {
33 namespace data {
34 
// Default constructor. Initializes the worker index resolver from an empty
// list of worker addresses.
DispatcherState::DispatcherState()
    : worker_index_resolver_(std::vector<std::string>{}) {}
37 
// Constructs the state from `dispatcher_config`. The worker index resolver is
// initialized from the worker addresses listed in the config.
DispatcherState::DispatcherState(
    const experimental::DispatcherConfig& dispatcher_config)
    : worker_index_resolver_(dispatcher_config.worker_addresses()) {}
41 
Apply(const Update & update)42 Status DispatcherState::Apply(const Update& update) {
43   switch (update.update_type_case()) {
44     case Update::kRegisterDataset:
45       RegisterDataset(update.register_dataset());
46       break;
47     case Update::kRegisterWorker:
48       RegisterWorker(update.register_worker());
49       break;
50     case Update::kCreateJob:
51       CreateJob(update.create_job());
52       break;
53     case Update::kProduceSplit:
54       ProduceSplit(update.produce_split());
55       break;
56     case Update::kAcquireJobClient:
57       AcquireJobClient(update.acquire_job_client());
58       break;
59     case Update::kReleaseJobClient:
60       ReleaseJobClient(update.release_job_client());
61       break;
62     case Update::kGarbageCollectJob:
63       GarbageCollectJob(update.garbage_collect_job());
64       break;
65     case Update::kRemoveTask:
66       RemoveTask(update.remove_task());
67       break;
68     case Update::kCreatePendingTask:
69       CreatePendingTask(update.create_pending_task());
70       break;
71     case Update::kClientHeartbeat:
72       ClientHeartbeat(update.client_heartbeat());
73       break;
74     case Update::kCreateTask:
75       CreateTask(update.create_task());
76       break;
77     case Update::kFinishTask:
78       FinishTask(update.finish_task());
79       break;
80     case Update::kSetElementSpec:
81       SetElementSpec(update.set_element_spec());
82       break;
83     case Update::UPDATE_TYPE_NOT_SET:
84       return errors::Internal("Update type not set.");
85   }
86 
87   return Status::OK();
88 }
89 
RegisterDataset(const RegisterDatasetUpdate & register_dataset)90 void DispatcherState::RegisterDataset(
91     const RegisterDatasetUpdate& register_dataset) {
92   int64_t id = register_dataset.dataset_id();
93   int64_t fingerprint = register_dataset.fingerprint();
94   auto dataset = std::make_shared<Dataset>(id, fingerprint);
95   DCHECK(!datasets_by_id_.contains(id));
96   datasets_by_id_[id] = dataset;
97   DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
98   datasets_by_fingerprint_[fingerprint] = dataset;
99   next_available_dataset_id_ = std::max(next_available_dataset_id_, id + 1);
100 }
101 
RegisterWorker(const RegisterWorkerUpdate & register_worker)102 void DispatcherState::RegisterWorker(
103     const RegisterWorkerUpdate& register_worker) {
104   std::string address = register_worker.worker_address();
105   DCHECK(!workers_.contains(address));
106   workers_[address] =
107       std::make_shared<Worker>(address, register_worker.transfer_address());
108   tasks_by_worker_[address] =
109       absl::flat_hash_map<int64, std::shared_ptr<Task>>();
110   worker_index_resolver_.AddWorker(address);
111 }
112 
CreateJob(const CreateJobUpdate & create_job)113 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
114   int64_t job_id = create_job.job_id();
115   absl::optional<NamedJobKey> named_job_key;
116   if (create_job.has_named_job_key()) {
117     named_job_key.emplace(create_job.named_job_key().name(),
118                           create_job.named_job_key().index());
119   }
120   absl::optional<int64> num_consumers;
121   if (create_job.optional_num_consumers_case() ==
122       CreateJobUpdate::kNumConsumers) {
123     num_consumers = create_job.num_consumers();
124   }
125   auto job = std::make_shared<Job>(
126       job_id, create_job.dataset_id(), create_job.processing_mode_def(),
127       create_job.num_split_providers(), named_job_key, num_consumers,
128       create_job.target_workers());
129   DCHECK(!jobs_.contains(job_id));
130   jobs_[job_id] = job;
131   tasks_by_job_[job_id] = std::vector<std::shared_ptr<Task>>();
132   if (named_job_key.has_value()) {
133     DCHECK(!named_jobs_.contains(named_job_key.value()) ||
134            named_jobs_[named_job_key.value()]->garbage_collected);
135     named_jobs_[named_job_key.value()] = job;
136   }
137   next_available_job_id_ = std::max(next_available_job_id_, job_id + 1);
138 }
139 
ProduceSplit(const ProduceSplitUpdate & produce_split)140 void DispatcherState::ProduceSplit(const ProduceSplitUpdate& produce_split) {
141   std::shared_ptr<Job> job = jobs_[produce_split.job_id()];
142   DCHECK(job->distributed_epoch_state.has_value());
143   DistributedEpochState& state = job->distributed_epoch_state.value();
144   int64_t provider_index = produce_split.split_provider_index();
145   DCHECK_EQ(produce_split.repetition(), state.repetitions[provider_index]);
146   if (produce_split.finished()) {
147     state.repetitions[provider_index]++;
148     state.indices[provider_index] = 0;
149     return;
150   }
151   state.indices[provider_index]++;
152 }
153 
AcquireJobClient(const AcquireJobClientUpdate & acquire_job_client)154 void DispatcherState::AcquireJobClient(
155     const AcquireJobClientUpdate& acquire_job_client) {
156   int64_t job_client_id = acquire_job_client.job_client_id();
157   std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
158   DCHECK(!job);
159   job = jobs_[acquire_job_client.job_id()];
160   DCHECK(job);
161   job->num_clients++;
162   next_available_job_client_id_ =
163       std::max(next_available_job_client_id_, job_client_id + 1);
164 }
165 
ReleaseJobClient(const ReleaseJobClientUpdate & release_job_client)166 void DispatcherState::ReleaseJobClient(
167     const ReleaseJobClientUpdate& release_job_client) {
168   int64_t job_client_id = release_job_client.job_client_id();
169   std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
170   DCHECK(job);
171   job->num_clients--;
172   DCHECK_GE(job->num_clients, 0);
173   job->last_client_released_micros = release_job_client.time_micros();
174   jobs_for_client_ids_.erase(job_client_id);
175 }
176 
GarbageCollectJob(const GarbageCollectJobUpdate & garbage_collect_job)177 void DispatcherState::GarbageCollectJob(
178     const GarbageCollectJobUpdate& garbage_collect_job) {
179   int64_t job_id = garbage_collect_job.job_id();
180   for (auto& task : tasks_by_job_[job_id]) {
181     task->finished = true;
182     tasks_by_worker_[task->worker_address].erase(task->task_id);
183   }
184   jobs_[job_id]->finished = true;
185   jobs_[job_id]->garbage_collected = true;
186 }
187 
RemoveTask(const RemoveTaskUpdate & remove_task)188 void DispatcherState::RemoveTask(const RemoveTaskUpdate& remove_task) {
189   std::shared_ptr<Task>& task = tasks_[remove_task.task_id()];
190   DCHECK(task);
191   task->removed = true;
192   auto& tasks_for_job = tasks_by_job_[task->job->job_id];
193   for (auto it = tasks_for_job.begin(); it != tasks_for_job.end(); ++it) {
194     if ((*it)->task_id == task->task_id) {
195       tasks_for_job.erase(it);
196       break;
197     }
198   }
199   tasks_by_worker_[task->worker_address].erase(task->task_id);
200   tasks_.erase(task->task_id);
201   VLOG(1) << "Removed task " << remove_task.task_id() << " from worker "
202           << task->worker_address;
203 }
204 
CreatePendingTask(const CreatePendingTaskUpdate & create_pending_task)205 void DispatcherState::CreatePendingTask(
206     const CreatePendingTaskUpdate& create_pending_task) {
207   int64_t task_id = create_pending_task.task_id();
208   auto& task = tasks_[task_id];
209   DCHECK_EQ(task, nullptr);
210   auto& job = jobs_[create_pending_task.job_id()];
211   DCHECK_NE(job, nullptr);
212   task =
213       std::make_shared<Task>(task_id, job, create_pending_task.worker_address(),
214                              create_pending_task.transfer_address());
215   job->pending_tasks.emplace(task, create_pending_task.starting_round());
216   tasks_by_worker_[create_pending_task.worker_address()][task->task_id] = task;
217   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
218 }
219 
ClientHeartbeat(const ClientHeartbeatUpdate & client_heartbeat)220 void DispatcherState::ClientHeartbeat(
221     const ClientHeartbeatUpdate& client_heartbeat) {
222   int64_t job_client_id = client_heartbeat.job_client_id();
223   auto& job = jobs_for_client_ids_[job_client_id];
224   DCHECK(!job->pending_tasks.empty());
225   auto& task = job->pending_tasks.front();
226   if (client_heartbeat.has_task_rejected()) {
227     task.failures++;
228     task.ready_consumers.clear();
229     task.target_round = client_heartbeat.task_rejected().new_target_round();
230   }
231   if (client_heartbeat.task_accepted()) {
232     task.ready_consumers.insert(job_client_id);
233     if (task.ready_consumers.size() == job->num_consumers.value()) {
234       VLOG(1) << "Promoting task " << task.task->task_id
235               << " from pending to active";
236       task.task->starting_round = task.target_round;
237       tasks_by_job_[job->job_id].push_back(task.task);
238       job->pending_tasks.pop();
239     }
240   }
241 }
242 
CreateTask(const CreateTaskUpdate & create_task)243 void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
244   int64_t task_id = create_task.task_id();
245   auto& task = tasks_[task_id];
246   DCHECK_EQ(task, nullptr);
247   auto& job = jobs_[create_task.job_id()];
248   DCHECK_NE(job, nullptr);
249   task = std::make_shared<Task>(task_id, job, create_task.worker_address(),
250                                 create_task.transfer_address());
251   tasks_by_job_[create_task.job_id()].push_back(task);
252   tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
253   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
254 }
255 
FinishTask(const FinishTaskUpdate & finish_task)256 void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
257   VLOG(2) << "Marking task " << finish_task.task_id() << " as finished";
258   int64_t task_id = finish_task.task_id();
259   auto& task = tasks_[task_id];
260   DCHECK(task != nullptr);
261   task->finished = true;
262   tasks_by_worker_[task->worker_address].erase(task->task_id);
263   bool all_finished = true;
264   for (const auto& task_for_job : tasks_by_job_[task->job->job_id]) {
265     if (!task_for_job->finished) {
266       all_finished = false;
267     }
268   }
269   VLOG(3) << "Job " << task->job->job_id << " finished: " << all_finished;
270   jobs_[task->job->job_id]->finished = all_finished;
271 }
272 
SetElementSpec(const SetElementSpecUpdate & set_element_spec)273 void DispatcherState::SetElementSpec(
274     const SetElementSpecUpdate& set_element_spec) {
275   int64_t dataset_id = set_element_spec.dataset_id();
276   std::string element_spec = set_element_spec.element_spec();
277   DCHECK(!id_element_spec_info_.contains(dataset_id));
278   id_element_spec_info_[dataset_id] = element_spec;
279 }
280 
GetElementSpec(int64_t dataset_id,std::string & element_spec) const281 Status DispatcherState::GetElementSpec(int64_t dataset_id,
282                                        std::string& element_spec) const {
283   auto it = id_element_spec_info_.find(dataset_id);
284   if (it == id_element_spec_info_.end()) {
285     return errors::NotFound("Element_spec with key ", dataset_id, " not found");
286   }
287   element_spec = it->second;
288   return Status::OK();
289 }
290 
// Returns the current value of `next_available_dataset_id_`.
int64 DispatcherState::NextAvailableDatasetId() const {
  return next_available_dataset_id_;
}
294 
DatasetFromId(int64_t id,std::shared_ptr<const Dataset> & dataset) const295 Status DispatcherState::DatasetFromId(
296     int64_t id, std::shared_ptr<const Dataset>& dataset) const {
297   auto it = datasets_by_id_.find(id);
298   if (it == datasets_by_id_.end()) {
299     return errors::NotFound("Dataset id ", id, " not found");
300   }
301   dataset = it->second;
302   return Status::OK();
303 }
304 
DatasetFromFingerprint(uint64 fingerprint,std::shared_ptr<const Dataset> & dataset) const305 Status DispatcherState::DatasetFromFingerprint(
306     uint64 fingerprint, std::shared_ptr<const Dataset>& dataset) const {
307   auto it = datasets_by_fingerprint_.find(fingerprint);
308   if (it == datasets_by_fingerprint_.end()) {
309     return errors::NotFound("Dataset fingerprint ", fingerprint, " not found");
310   }
311   dataset = it->second;
312   return Status::OK();
313 }
314 
WorkerFromAddress(const std::string & address,std::shared_ptr<const Worker> & worker) const315 Status DispatcherState::WorkerFromAddress(
316     const std::string& address, std::shared_ptr<const Worker>& worker) const {
317   auto it = workers_.find(address);
318   if (it == workers_.end()) {
319     return errors::NotFound("Worker with address ", address, " not found.");
320   }
321   worker = it->second;
322   return Status::OK();
323 }
324 
325 std::vector<std::shared_ptr<const DispatcherState::Worker>>
ListWorkers() const326 DispatcherState::ListWorkers() const {
327   std::vector<std::shared_ptr<const Worker>> workers;
328   workers.reserve(workers_.size());
329   for (const auto& it : workers_) {
330     workers.push_back(it.second);
331   }
332   return workers;
333 }
334 
335 std::vector<std::shared_ptr<const DispatcherState::Job>>
ListJobs()336 DispatcherState::ListJobs() {
337   std::vector<std::shared_ptr<const DispatcherState::Job>> jobs;
338   jobs.reserve(jobs_.size());
339   for (const auto& it : jobs_) {
340     jobs.push_back(it.second);
341   }
342   return jobs;
343 }
344 
JobFromId(int64_t id,std::shared_ptr<const Job> & job) const345 Status DispatcherState::JobFromId(int64_t id,
346                                   std::shared_ptr<const Job>& job) const {
347   auto it = jobs_.find(id);
348   if (it == jobs_.end()) {
349     return errors::NotFound("Job id ", id, " not found");
350   }
351   job = it->second;
352   return Status::OK();
353 }
354 
NamedJobByKey(NamedJobKey named_job_key,std::shared_ptr<const Job> & job) const355 Status DispatcherState::NamedJobByKey(NamedJobKey named_job_key,
356                                       std::shared_ptr<const Job>& job) const {
357   auto it = named_jobs_.find(named_job_key);
358   if (it == named_jobs_.end()) {
359     return errors::NotFound("Named job key (", named_job_key.name, ", ",
360                             named_job_key.index, ") not found");
361   }
362   job = it->second;
363   return Status::OK();
364 }
365 
// Returns the current value of `next_available_job_id_`.
int64 DispatcherState::NextAvailableJobId() const {
  return next_available_job_id_;
}
369 
JobForJobClientId(int64_t job_client_id,std::shared_ptr<const Job> & job)370 Status DispatcherState::JobForJobClientId(int64_t job_client_id,
371                                           std::shared_ptr<const Job>& job) {
372   job = jobs_for_client_ids_[job_client_id];
373   if (!job) {
374     return errors::NotFound("Job client id not found: ", job_client_id);
375   }
376   return Status::OK();
377 }
378 
// Returns the current value of `next_available_job_client_id_`.
int64 DispatcherState::NextAvailableJobClientId() const {
  return next_available_job_client_id_;
}
382 
TaskFromId(int64_t id,std::shared_ptr<const Task> & task) const383 Status DispatcherState::TaskFromId(int64_t id,
384                                    std::shared_ptr<const Task>& task) const {
385   auto it = tasks_.find(id);
386   if (it == tasks_.end()) {
387     return errors::NotFound("Task ", id, " not found");
388   }
389   task = it->second;
390   return Status::OK();
391 }
392 
TasksForJob(int64_t job_id,std::vector<std::shared_ptr<const Task>> & tasks) const393 Status DispatcherState::TasksForJob(
394     int64_t job_id, std::vector<std::shared_ptr<const Task>>& tasks) const {
395   auto it = tasks_by_job_.find(job_id);
396   if (it == tasks_by_job_.end()) {
397     return errors::NotFound("Job ", job_id, " not found");
398   }
399   tasks.clear();
400   tasks.reserve(it->second.size());
401   for (const auto& task : it->second) {
402     tasks.push_back(task);
403   }
404   return Status::OK();
405 }
406 
TasksForWorker(absl::string_view worker_address,std::vector<std::shared_ptr<const Task>> & tasks) const407 Status DispatcherState::TasksForWorker(
408     absl::string_view worker_address,
409     std::vector<std::shared_ptr<const Task>>& tasks) const {
410   tasks.clear();
411   auto it = tasks_by_worker_.find(worker_address);
412   if (it == tasks_by_worker_.end()) {
413     return errors::NotFound("Worker ", worker_address, " not found");
414   }
415   const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
416       it->second;
417   tasks.reserve(worker_tasks.size());
418   for (const auto& task : worker_tasks) {
419     tasks.push_back(task.second);
420   }
421   return Status::OK();
422 }
423 
// Returns the current value of `next_available_task_id_`.
int64 DispatcherState::NextAvailableTaskId() const {
  return next_available_task_id_;
}
427 
// Forwards `worker_address` to the worker index resolver's ValidateWorker and
// returns its status unchanged.
Status DispatcherState::ValidateWorker(absl::string_view worker_address) const {
  return worker_index_resolver_.ValidateWorker(worker_address);
}
431 
// Forwards `worker_address` to the worker index resolver's GetWorkerIndex and
// returns its result unchanged.
StatusOr<int64> DispatcherState::GetWorkerIndex(
    absl::string_view worker_address) const {
  return worker_index_resolver_.GetWorkerIndex(worker_address);
}
436 
437 }  // namespace data
438 }  // namespace tensorflow
439