1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/dispatcher_state.h"
16
17 #include <memory>
18 #include <string>
19 #include <vector>
20
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/optional.h"
24 #include "tensorflow/core/data/service/common.h"
25 #include "tensorflow/core/data/service/journal.h"
26 #include "tensorflow/core/data/service/journal.pb.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/protobuf/data_service.pb.h"
30 #include "tensorflow/core/protobuf/service_config.pb.h"
31
32 namespace tensorflow {
33 namespace data {
34
DispatcherState()35 DispatcherState::DispatcherState()
36 : worker_index_resolver_(std::vector<std::string>{}) {}
37
DispatcherState(const experimental::DispatcherConfig & dispatcher_config)38 DispatcherState::DispatcherState(
39 const experimental::DispatcherConfig& dispatcher_config)
40 : worker_index_resolver_(dispatcher_config.worker_addresses()) {}
41
Apply(const Update & update)42 Status DispatcherState::Apply(const Update& update) {
43 switch (update.update_type_case()) {
44 case Update::kRegisterDataset:
45 RegisterDataset(update.register_dataset());
46 break;
47 case Update::kRegisterWorker:
48 RegisterWorker(update.register_worker());
49 break;
50 case Update::kCreateJob:
51 CreateJob(update.create_job());
52 break;
53 case Update::kProduceSplit:
54 ProduceSplit(update.produce_split());
55 break;
56 case Update::kAcquireJobClient:
57 AcquireJobClient(update.acquire_job_client());
58 break;
59 case Update::kReleaseJobClient:
60 ReleaseJobClient(update.release_job_client());
61 break;
62 case Update::kGarbageCollectJob:
63 GarbageCollectJob(update.garbage_collect_job());
64 break;
65 case Update::kRemoveTask:
66 RemoveTask(update.remove_task());
67 break;
68 case Update::kCreatePendingTask:
69 CreatePendingTask(update.create_pending_task());
70 break;
71 case Update::kClientHeartbeat:
72 ClientHeartbeat(update.client_heartbeat());
73 break;
74 case Update::kCreateTask:
75 CreateTask(update.create_task());
76 break;
77 case Update::kFinishTask:
78 FinishTask(update.finish_task());
79 break;
80 case Update::kSetElementSpec:
81 SetElementSpec(update.set_element_spec());
82 break;
83 case Update::UPDATE_TYPE_NOT_SET:
84 return errors::Internal("Update type not set.");
85 }
86
87 return Status::OK();
88 }
89
RegisterDataset(const RegisterDatasetUpdate & register_dataset)90 void DispatcherState::RegisterDataset(
91 const RegisterDatasetUpdate& register_dataset) {
92 int64_t id = register_dataset.dataset_id();
93 int64_t fingerprint = register_dataset.fingerprint();
94 auto dataset = std::make_shared<Dataset>(id, fingerprint);
95 DCHECK(!datasets_by_id_.contains(id));
96 datasets_by_id_[id] = dataset;
97 DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
98 datasets_by_fingerprint_[fingerprint] = dataset;
99 next_available_dataset_id_ = std::max(next_available_dataset_id_, id + 1);
100 }
101
RegisterWorker(const RegisterWorkerUpdate & register_worker)102 void DispatcherState::RegisterWorker(
103 const RegisterWorkerUpdate& register_worker) {
104 std::string address = register_worker.worker_address();
105 DCHECK(!workers_.contains(address));
106 workers_[address] =
107 std::make_shared<Worker>(address, register_worker.transfer_address());
108 tasks_by_worker_[address] =
109 absl::flat_hash_map<int64, std::shared_ptr<Task>>();
110 worker_index_resolver_.AddWorker(address);
111 }
112
CreateJob(const CreateJobUpdate & create_job)113 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
114 int64_t job_id = create_job.job_id();
115 absl::optional<NamedJobKey> named_job_key;
116 if (create_job.has_named_job_key()) {
117 named_job_key.emplace(create_job.named_job_key().name(),
118 create_job.named_job_key().index());
119 }
120 absl::optional<int64> num_consumers;
121 if (create_job.optional_num_consumers_case() ==
122 CreateJobUpdate::kNumConsumers) {
123 num_consumers = create_job.num_consumers();
124 }
125 auto job = std::make_shared<Job>(
126 job_id, create_job.dataset_id(), create_job.processing_mode_def(),
127 create_job.num_split_providers(), named_job_key, num_consumers,
128 create_job.target_workers());
129 DCHECK(!jobs_.contains(job_id));
130 jobs_[job_id] = job;
131 tasks_by_job_[job_id] = std::vector<std::shared_ptr<Task>>();
132 if (named_job_key.has_value()) {
133 DCHECK(!named_jobs_.contains(named_job_key.value()) ||
134 named_jobs_[named_job_key.value()]->garbage_collected);
135 named_jobs_[named_job_key.value()] = job;
136 }
137 next_available_job_id_ = std::max(next_available_job_id_, job_id + 1);
138 }
139
ProduceSplit(const ProduceSplitUpdate & produce_split)140 void DispatcherState::ProduceSplit(const ProduceSplitUpdate& produce_split) {
141 std::shared_ptr<Job> job = jobs_[produce_split.job_id()];
142 DCHECK(job->distributed_epoch_state.has_value());
143 DistributedEpochState& state = job->distributed_epoch_state.value();
144 int64_t provider_index = produce_split.split_provider_index();
145 DCHECK_EQ(produce_split.repetition(), state.repetitions[provider_index]);
146 if (produce_split.finished()) {
147 state.repetitions[provider_index]++;
148 state.indices[provider_index] = 0;
149 return;
150 }
151 state.indices[provider_index]++;
152 }
153
AcquireJobClient(const AcquireJobClientUpdate & acquire_job_client)154 void DispatcherState::AcquireJobClient(
155 const AcquireJobClientUpdate& acquire_job_client) {
156 int64_t job_client_id = acquire_job_client.job_client_id();
157 std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
158 DCHECK(!job);
159 job = jobs_[acquire_job_client.job_id()];
160 DCHECK(job);
161 job->num_clients++;
162 next_available_job_client_id_ =
163 std::max(next_available_job_client_id_, job_client_id + 1);
164 }
165
ReleaseJobClient(const ReleaseJobClientUpdate & release_job_client)166 void DispatcherState::ReleaseJobClient(
167 const ReleaseJobClientUpdate& release_job_client) {
168 int64_t job_client_id = release_job_client.job_client_id();
169 std::shared_ptr<Job>& job = jobs_for_client_ids_[job_client_id];
170 DCHECK(job);
171 job->num_clients--;
172 DCHECK_GE(job->num_clients, 0);
173 job->last_client_released_micros = release_job_client.time_micros();
174 jobs_for_client_ids_.erase(job_client_id);
175 }
176
GarbageCollectJob(const GarbageCollectJobUpdate & garbage_collect_job)177 void DispatcherState::GarbageCollectJob(
178 const GarbageCollectJobUpdate& garbage_collect_job) {
179 int64_t job_id = garbage_collect_job.job_id();
180 for (auto& task : tasks_by_job_[job_id]) {
181 task->finished = true;
182 tasks_by_worker_[task->worker_address].erase(task->task_id);
183 }
184 jobs_[job_id]->finished = true;
185 jobs_[job_id]->garbage_collected = true;
186 }
187
RemoveTask(const RemoveTaskUpdate & remove_task)188 void DispatcherState::RemoveTask(const RemoveTaskUpdate& remove_task) {
189 std::shared_ptr<Task>& task = tasks_[remove_task.task_id()];
190 DCHECK(task);
191 task->removed = true;
192 auto& tasks_for_job = tasks_by_job_[task->job->job_id];
193 for (auto it = tasks_for_job.begin(); it != tasks_for_job.end(); ++it) {
194 if ((*it)->task_id == task->task_id) {
195 tasks_for_job.erase(it);
196 break;
197 }
198 }
199 tasks_by_worker_[task->worker_address].erase(task->task_id);
200 tasks_.erase(task->task_id);
201 VLOG(1) << "Removed task " << remove_task.task_id() << " from worker "
202 << task->worker_address;
203 }
204
CreatePendingTask(const CreatePendingTaskUpdate & create_pending_task)205 void DispatcherState::CreatePendingTask(
206 const CreatePendingTaskUpdate& create_pending_task) {
207 int64_t task_id = create_pending_task.task_id();
208 auto& task = tasks_[task_id];
209 DCHECK_EQ(task, nullptr);
210 auto& job = jobs_[create_pending_task.job_id()];
211 DCHECK_NE(job, nullptr);
212 task =
213 std::make_shared<Task>(task_id, job, create_pending_task.worker_address(),
214 create_pending_task.transfer_address());
215 job->pending_tasks.emplace(task, create_pending_task.starting_round());
216 tasks_by_worker_[create_pending_task.worker_address()][task->task_id] = task;
217 next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
218 }
219
ClientHeartbeat(const ClientHeartbeatUpdate & client_heartbeat)220 void DispatcherState::ClientHeartbeat(
221 const ClientHeartbeatUpdate& client_heartbeat) {
222 int64_t job_client_id = client_heartbeat.job_client_id();
223 auto& job = jobs_for_client_ids_[job_client_id];
224 DCHECK(!job->pending_tasks.empty());
225 auto& task = job->pending_tasks.front();
226 if (client_heartbeat.has_task_rejected()) {
227 task.failures++;
228 task.ready_consumers.clear();
229 task.target_round = client_heartbeat.task_rejected().new_target_round();
230 }
231 if (client_heartbeat.task_accepted()) {
232 task.ready_consumers.insert(job_client_id);
233 if (task.ready_consumers.size() == job->num_consumers.value()) {
234 VLOG(1) << "Promoting task " << task.task->task_id
235 << " from pending to active";
236 task.task->starting_round = task.target_round;
237 tasks_by_job_[job->job_id].push_back(task.task);
238 job->pending_tasks.pop();
239 }
240 }
241 }
242
CreateTask(const CreateTaskUpdate & create_task)243 void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
244 int64_t task_id = create_task.task_id();
245 auto& task = tasks_[task_id];
246 DCHECK_EQ(task, nullptr);
247 auto& job = jobs_[create_task.job_id()];
248 DCHECK_NE(job, nullptr);
249 task = std::make_shared<Task>(task_id, job, create_task.worker_address(),
250 create_task.transfer_address());
251 tasks_by_job_[create_task.job_id()].push_back(task);
252 tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
253 next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
254 }
255
FinishTask(const FinishTaskUpdate & finish_task)256 void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
257 VLOG(2) << "Marking task " << finish_task.task_id() << " as finished";
258 int64_t task_id = finish_task.task_id();
259 auto& task = tasks_[task_id];
260 DCHECK(task != nullptr);
261 task->finished = true;
262 tasks_by_worker_[task->worker_address].erase(task->task_id);
263 bool all_finished = true;
264 for (const auto& task_for_job : tasks_by_job_[task->job->job_id]) {
265 if (!task_for_job->finished) {
266 all_finished = false;
267 }
268 }
269 VLOG(3) << "Job " << task->job->job_id << " finished: " << all_finished;
270 jobs_[task->job->job_id]->finished = all_finished;
271 }
272
SetElementSpec(const SetElementSpecUpdate & set_element_spec)273 void DispatcherState::SetElementSpec(
274 const SetElementSpecUpdate& set_element_spec) {
275 int64_t dataset_id = set_element_spec.dataset_id();
276 std::string element_spec = set_element_spec.element_spec();
277 DCHECK(!id_element_spec_info_.contains(dataset_id));
278 id_element_spec_info_[dataset_id] = element_spec;
279 }
280
GetElementSpec(int64_t dataset_id,std::string & element_spec) const281 Status DispatcherState::GetElementSpec(int64_t dataset_id,
282 std::string& element_spec) const {
283 auto it = id_element_spec_info_.find(dataset_id);
284 if (it == id_element_spec_info_.end()) {
285 return errors::NotFound("Element_spec with key ", dataset_id, " not found");
286 }
287 element_spec = it->second;
288 return Status::OK();
289 }
290
NextAvailableDatasetId() const291 int64 DispatcherState::NextAvailableDatasetId() const {
292 return next_available_dataset_id_;
293 }
294
DatasetFromId(int64_t id,std::shared_ptr<const Dataset> & dataset) const295 Status DispatcherState::DatasetFromId(
296 int64_t id, std::shared_ptr<const Dataset>& dataset) const {
297 auto it = datasets_by_id_.find(id);
298 if (it == datasets_by_id_.end()) {
299 return errors::NotFound("Dataset id ", id, " not found");
300 }
301 dataset = it->second;
302 return Status::OK();
303 }
304
DatasetFromFingerprint(uint64 fingerprint,std::shared_ptr<const Dataset> & dataset) const305 Status DispatcherState::DatasetFromFingerprint(
306 uint64 fingerprint, std::shared_ptr<const Dataset>& dataset) const {
307 auto it = datasets_by_fingerprint_.find(fingerprint);
308 if (it == datasets_by_fingerprint_.end()) {
309 return errors::NotFound("Dataset fingerprint ", fingerprint, " not found");
310 }
311 dataset = it->second;
312 return Status::OK();
313 }
314
WorkerFromAddress(const std::string & address,std::shared_ptr<const Worker> & worker) const315 Status DispatcherState::WorkerFromAddress(
316 const std::string& address, std::shared_ptr<const Worker>& worker) const {
317 auto it = workers_.find(address);
318 if (it == workers_.end()) {
319 return errors::NotFound("Worker with address ", address, " not found.");
320 }
321 worker = it->second;
322 return Status::OK();
323 }
324
325 std::vector<std::shared_ptr<const DispatcherState::Worker>>
ListWorkers() const326 DispatcherState::ListWorkers() const {
327 std::vector<std::shared_ptr<const Worker>> workers;
328 workers.reserve(workers_.size());
329 for (const auto& it : workers_) {
330 workers.push_back(it.second);
331 }
332 return workers;
333 }
334
335 std::vector<std::shared_ptr<const DispatcherState::Job>>
ListJobs()336 DispatcherState::ListJobs() {
337 std::vector<std::shared_ptr<const DispatcherState::Job>> jobs;
338 jobs.reserve(jobs_.size());
339 for (const auto& it : jobs_) {
340 jobs.push_back(it.second);
341 }
342 return jobs;
343 }
344
JobFromId(int64_t id,std::shared_ptr<const Job> & job) const345 Status DispatcherState::JobFromId(int64_t id,
346 std::shared_ptr<const Job>& job) const {
347 auto it = jobs_.find(id);
348 if (it == jobs_.end()) {
349 return errors::NotFound("Job id ", id, " not found");
350 }
351 job = it->second;
352 return Status::OK();
353 }
354
NamedJobByKey(NamedJobKey named_job_key,std::shared_ptr<const Job> & job) const355 Status DispatcherState::NamedJobByKey(NamedJobKey named_job_key,
356 std::shared_ptr<const Job>& job) const {
357 auto it = named_jobs_.find(named_job_key);
358 if (it == named_jobs_.end()) {
359 return errors::NotFound("Named job key (", named_job_key.name, ", ",
360 named_job_key.index, ") not found");
361 }
362 job = it->second;
363 return Status::OK();
364 }
365
NextAvailableJobId() const366 int64 DispatcherState::NextAvailableJobId() const {
367 return next_available_job_id_;
368 }
369
JobForJobClientId(int64_t job_client_id,std::shared_ptr<const Job> & job)370 Status DispatcherState::JobForJobClientId(int64_t job_client_id,
371 std::shared_ptr<const Job>& job) {
372 job = jobs_for_client_ids_[job_client_id];
373 if (!job) {
374 return errors::NotFound("Job client id not found: ", job_client_id);
375 }
376 return Status::OK();
377 }
378
NextAvailableJobClientId() const379 int64 DispatcherState::NextAvailableJobClientId() const {
380 return next_available_job_client_id_;
381 }
382
TaskFromId(int64_t id,std::shared_ptr<const Task> & task) const383 Status DispatcherState::TaskFromId(int64_t id,
384 std::shared_ptr<const Task>& task) const {
385 auto it = tasks_.find(id);
386 if (it == tasks_.end()) {
387 return errors::NotFound("Task ", id, " not found");
388 }
389 task = it->second;
390 return Status::OK();
391 }
392
TasksForJob(int64_t job_id,std::vector<std::shared_ptr<const Task>> & tasks) const393 Status DispatcherState::TasksForJob(
394 int64_t job_id, std::vector<std::shared_ptr<const Task>>& tasks) const {
395 auto it = tasks_by_job_.find(job_id);
396 if (it == tasks_by_job_.end()) {
397 return errors::NotFound("Job ", job_id, " not found");
398 }
399 tasks.clear();
400 tasks.reserve(it->second.size());
401 for (const auto& task : it->second) {
402 tasks.push_back(task);
403 }
404 return Status::OK();
405 }
406
TasksForWorker(absl::string_view worker_address,std::vector<std::shared_ptr<const Task>> & tasks) const407 Status DispatcherState::TasksForWorker(
408 absl::string_view worker_address,
409 std::vector<std::shared_ptr<const Task>>& tasks) const {
410 tasks.clear();
411 auto it = tasks_by_worker_.find(worker_address);
412 if (it == tasks_by_worker_.end()) {
413 return errors::NotFound("Worker ", worker_address, " not found");
414 }
415 const absl::flat_hash_map<int64, std::shared_ptr<Task>>& worker_tasks =
416 it->second;
417 tasks.reserve(worker_tasks.size());
418 for (const auto& task : worker_tasks) {
419 tasks.push_back(task.second);
420 }
421 return Status::OK();
422 }
423
NextAvailableTaskId() const424 int64 DispatcherState::NextAvailableTaskId() const {
425 return next_available_task_id_;
426 }
427
ValidateWorker(absl::string_view worker_address) const428 Status DispatcherState::ValidateWorker(absl::string_view worker_address) const {
429 return worker_index_resolver_.ValidateWorker(worker_address);
430 }
431
GetWorkerIndex(absl::string_view worker_address) const432 StatusOr<int64> DispatcherState::GetWorkerIndex(
433 absl::string_view worker_address) const {
434 return worker_index_resolver_.GetWorkerIndex(worker_address);
435 }
436
437 } // namespace data
438 } // namespace tensorflow
439