/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_state.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/task_remover.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace data {

// A service which coordinates a pool of workers to serve dataset elements over
// RPC.
//
// Glossary:
// * Dataset: A definition of how to generate a potentially large collection of
//   elements.
// * Job: A coordinated phase of reading from the tf.data service. A job
//   produces some amount of data, and (potentially multiple) consumers consume
//   the data from the job until there is no data left. Each job has a
//   ProcessingModeDef which determines what data it produces.
// * Task: A job is broken into multiple tasks, which each represent
//   iterating over all of or part of the dataset. Workers process tasks.
// * Consumer: A process reading from the tf.data service.
//
// **Adding workers**
//
// tf.data service supports adding workers mid-job. When a new worker connects
// to the dispatcher, the dispatcher creates a new task for the worker, one
// task for each outstanding job. Consumers periodically heartbeat to the
// dispatcher to learn about new tasks.
//
// For non-round-robin reads, there is no coordination among consumers. Each
// consumer will start reading from the new task as soon as it learns about the
// task from its heartbeat. Round-robin reads, on the other hand, require
// consumers to read from the same task at each step. This requires coordination
// to ensure that all consumers start reading from the new task in the same
// round.
//
// The protocol for adding round-robin tasks works as follows:
//
// - The dispatcher keeps track of which round each round-robin job is on. This
//   information is reported by consumers in their heartbeats.
// - When a new worker joins and there is an outstanding round-robin job, we
//   create a new task for the job and assign it to the worker. However, we
//   don't yet report the task in consumer heartbeats. We call the task a
//   "pending task" and add it to its job's "pending tasks" queue.
// - When we create a pending task, we choose a "target round" to try adding
//   the task to. The target round is chosen by adding a "target round delta" to
//   the latest reported round for the job.
// - When a consumer heartbeats for a job and there is a pending task for that
//   job, the dispatcher sends a heartbeat response telling the consumer to
//   block before reading from the target round.
// - When a consumer receives a heartbeat response telling it to block
//   before reading a round, the consumer tries to block the round. If the
//   consumer has already started the round, it is too late to block the
//   round.
// - When consumers heartbeat, they tell the dispatcher their current round and
//   whether they have blocked themselves from reading past a certain round. If
//   a consumer reports a current round exceeding the target round, the target
//   round has failed and needs to be increased. We choose a new target round by
//   doubling the previous target round delta. If the consumer reports that it
//   has blocked before the target round, we record that the consumer is ready
//   to add the new task. Once all consumers are ready to add the new task, we
//   remove the task from the pending tasks list and begin reporting the task to
//   consumers. We set the "starting_round" field of the task to indicate the
//   target round where all consumers should start reading from the task.
// - If a new worker joins while there are already pending tasks, a pending
//   task for the new worker is created and queued behind the existing tasks.
//   The new task won't be considered until all previous pending tasks have been
//   successfully added.
//
// An example of executing this protocol with two consumers could go as follows:
// 1. Consumers read up to round 50 and heartbeat that they are on round 50.
// 2. A new worker joins. Dispatcher chooses round 51 as the target round.
// 3. Consumer 1 heartbeats that its current round is 50. Dispatcher tells it to
//    block round 51.
// 4. Consumer 2 heartbeats that its current round is 51. Dispatcher realizes
//    that it is too late to block round 51 and chooses round 53 as the new
//    target round. Dispatcher tells consumer 2 to block round 53.
// 5. Consumer 1 heartbeats that its current round is 50 and that it has blocked
//    round 51. Dispatcher tells it to block round 53 instead. Dispatcher
//    records that consumer 1 is ready to add a task in round 53.
// 6. Consumer 2 heartbeats that its current round is 52 and it has blocked
//    round 53. Dispatcher realizes that all consumers are blocked on round 53
//    or earlier and promotes the task from pending to regular. Dispatcher sends
//    consumer 2 a task list containing the new task, and tells consumer 2 that
//    it no longer needs to block.
// 7. Consumer 1 heartbeats. Dispatcher sends consumer 1 the task list
//    containing the new task, and tells it that it no longer needs to block.
//
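// A minimal sketch of the target-round arithmetic described above. This is
// illustrative only: `target_round_delta`, `latest_reported_round`, and
// `current_round` are hypothetical local names, not members of this class.
//
//   int64_t target_round_delta = 1;  // assumed initial delta
//   int64_t target_round = latest_reported_round + target_round_delta;
//   // On each consumer heartbeat for the job:
//   if (current_round >= target_round) {
//     // Too late to block the target round; double the delta and choose a
//     // later target relative to the latest reported round.
//     target_round_delta *= 2;
//     target_round = latest_reported_round + target_round_delta;
//   }
//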
class DataServiceDispatcherImpl {
 public:
  explicit DataServiceDispatcherImpl(
      const experimental::DispatcherConfig& config);

  ~DataServiceDispatcherImpl();

  // Starts the dispatcher. If there is a journal, this will read from the
  // journal to restore the dispatcher's state.
  Status Start();
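  //
  // A minimal usage sketch, assuming `port` is a `DispatcherConfig` field
  // (see service_config.proto) and that crashing on error is acceptable:
  //
  //   experimental::DispatcherConfig config;
  //   config.set_port(0);  // Let the system pick an available port.
  //   DataServiceDispatcherImpl dispatcher(config);
  //   TF_CHECK_OK(dispatcher.Start());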

  // See dispatcher.proto for API documentation.

  /// Worker-facing API.
  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
                         WorkerHeartbeatResponse* response);
  Status WorkerUpdate(const WorkerUpdateRequest* request,
                      WorkerUpdateResponse* response);
  Status GetDatasetDef(const GetDatasetDefRequest* request,
                       GetDatasetDefResponse* response);
  Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);

  /// Client-facing API.
  Status GetVersion(const GetVersionRequest* request,
                    GetVersionResponse* response);
  Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
                              GetOrRegisterDatasetResponse* response);
  Status GetElementSpec(const GetElementSpecRequest* request,
                        GetElementSpecResponse* response);
  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
                        GetOrCreateJobResponse* response);
  Status ReleaseJobClient(const ReleaseJobClientRequest* request,
                          ReleaseJobClientResponse* response);
  Status MaybeRemoveTask(const MaybeRemoveTaskRequest* request,
                         MaybeRemoveTaskResponse* response);
  Status ClientHeartbeat(const ClientHeartbeatRequest* request,
                         ClientHeartbeatResponse* response);
  Status GetWorkers(const GetWorkersRequest* request,
                    GetWorkersResponse* response);

 private:
  // Restores split providers from the state in `job` and stores them in
  // `restored`.
  Status RestoreSplitProviders(
      const DispatcherState::Job& job,
      std::vector<std::unique_ptr<SplitProvider>>& restored)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Makes split providers for the specified `dataset_id`, and stores them in
  // `split_providers`.
  Status MakeSplitProviders(
      int64_t dataset_id,
      std::vector<std::unique_ptr<SplitProvider>>& split_providers)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Registers a dataset with the given fingerprint, storing the new dataset's
  // id in `dataset_id`.
  Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
                         int64& dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Sets the element spec of the dataset for the specified `dataset_id`.
  Status SetElementSpec(int64_t dataset_id, const std::string& element_spec)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
  // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
  // stored in `out_stub`.
  Status GetOrCreateWorkerStub(const std::string& worker_address,
                               WorkerService::Stub*& out_stub)
      TF_LOCKS_EXCLUDED(mu_);
  // Creates a job and stores it in `job`. This method updates the
  // dispatcher state with the new job, but does not assign tasks to workers.
  Status CreateJob(const GetOrCreateJobRequest& request,
                   std::shared_ptr<const DispatcherState::Job>& job)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates tasks for the specified worker, one task for every unfinished job.
  Status CreateTasksForWorker(const std::string& worker_address);
  // Finds tasks that should be deleted from a worker, updating the heartbeat
  // response.
  Status FindTasksToDelete(
      const absl::flat_hash_set<int64>& current_tasks,
      const std::vector<std::shared_ptr<const DispatcherState::Task>>
          assigned_tasks,
      WorkerHeartbeatResponse* response);
  // Finds new tasks that should be assigned to a worker and adds them to
  // the heartbeat response.
  Status FindNewTasks(
      const std::string& worker_address,
      const absl::flat_hash_set<int64>& current_tasks,
      std::vector<std::shared_ptr<const DispatcherState::Task>>& assigned_tasks,
      WorkerHeartbeatResponse* response);
  // Acquires a job client id to read from the given job and sets
  // `job_client_id`.
  Status AcquireJobClientId(
      const std::shared_ptr<const DispatcherState::Job>& job,
      int64& job_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates one task for each worker, for the given job. The created tasks are
  // stored in `tasks`. This method only updates dispatcher metadata with the
  // new tasks, but doesn't assign the tasks to the workers.
  Status CreateTasksForJob(
      std::shared_ptr<const DispatcherState::Job> job,
      std::vector<std::shared_ptr<const DispatcherState::Task>>& tasks)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Creates a new task for a job. The created task may be either pending or
  // active.
  Status CreateTask(std::shared_ptr<const DispatcherState::Job> job,
                    const std::string& worker_address,
                    std::shared_ptr<const DispatcherState::Task>& task)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a pending task for a round-robin job. All consumers need to agree
  // on which round to add the task in before the pending task can be promoted
  // to a regular task.
  Status CreatePendingTask(std::shared_ptr<const DispatcherState::Job> job,
                           const std::string& worker_address)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a new active task for a job, storing the created task in `task`.
  Status CreateActiveTask(std::shared_ptr<const DispatcherState::Job> job,
                          const std::string& worker_address,
                          std::shared_ptr<const DispatcherState::Task>& task);
  // Assigns the list of tasks to the workers indicated by their
  // `worker_address` fields.
  Status AssignTasks(
      std::vector<std::shared_ptr<const DispatcherState::Task>> tasks)
      TF_LOCKS_EXCLUDED(mu_);
  // Assigns a task to the worker indicated by its `worker_address` field.
  Status AssignTask(std::shared_ptr<const DispatcherState::Task> task)
      TF_LOCKS_EXCLUDED(mu_);
  // Validates that an existing job matches the requested processing mode,
  // returning an error status describing any difference.
  Status ValidateMatchingJob(std::shared_ptr<const DispatcherState::Job> job,
                             const GetOrCreateJobRequest& request)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Fills out a TaskDef with information about a task.
  Status PopulateTaskDef(std::shared_ptr<const DispatcherState::Task> task,
                         TaskDef* task_def) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Checks that the dispatcher has started, returning UNAVAILABLE if it
  // hasn't.
  Status CheckStarted() TF_LOCKS_EXCLUDED(mu_);
  // Records that a split was produced by a call to `GetSplit`.
  Status RecordSplitProduced(int64_t job_id, int64_t repetition,
                             int64_t split_provider_index, bool finished)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, updating both the journal and the in-memory state.
  Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, but doesn't update the journal. Only meant to be
  // used when recovering state when the dispatcher starts.
  Status ApplyWithoutJournaling(const Update& update)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
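  //
  // A sketch of how startup recovery might drive this method. Illustrative
  // only: `reader` stands in for the journal reader declared in journal.h,
  // whose exact API may differ from what is shown here.
  //
  //   Update update;
  //   bool end_of_journal = false;
  //   while (!end_of_journal) {
  //     TF_RETURN_IF_ERROR(reader.Read(update, end_of_journal));
  //     if (!end_of_journal) {
  //       TF_RETURN_IF_ERROR(ApplyWithoutJournaling(update));
  //     }
  //   }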
  // A thread which periodically checks for jobs to clean up.
  void JobGcThread();
  // Scans for old jobs and marks them as finished.
  Status GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(int64_t dataset_id,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(const DispatcherState::Dataset& dataset,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  const experimental::DispatcherConfig& config_;
  Env* env_;

  mutex mu_;
  bool started_ TF_GUARDED_BY(mu_) = false;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;

  // Cached worker stubs for communicating with workers.
  absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
      worker_stubs_ TF_GUARDED_BY(mu_);
  // Store of dataset definitions.
  std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_);
  // Mapping from job id to the split providers for the job.
  absl::flat_hash_map<int64, std::vector<std::unique_ptr<SplitProvider>>>
      split_providers_ TF_GUARDED_BY(mu_);
  // Mapping from round-robin job id to the round the job is currently on. This
  // is based on the data provided by client heartbeats, and may be stale.
  absl::flat_hash_map<int64, int64> round_robin_rounds_ TF_GUARDED_BY(mu_);
  // Map from task id to a TaskRemover which determines when to remove the
  // task.
  absl::flat_hash_map<int64, std::shared_ptr<TaskRemover>> remove_task_requests_
      TF_GUARDED_BY(mu_);

  absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
      TF_GUARDED_BY(mu_);
  DispatcherState state_ TF_GUARDED_BY(mu_);
  // Condition variable for waking up the job gc thread.
  condition_variable job_gc_thread_cv_;
  std::unique_ptr<Thread> job_gc_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_