/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/dispatcher_state.h"
#include "tensorflow/core/data/service/task_remover.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace data {

// A service which coordinates a pool of workers to serve dataset elements over
// RPC.
//
// Glossary:
// * Dataset: A definition of how to generate a potentially large collection of
//   elements.
// * Job: A coordinated phase of reading from the tf.data service. A job
//   produces some amount of data, and (potentially multiple) consumers consume
//   the data from the job until there is no data left. Each job has a
//   ProcessingModeDef which determines what data it produces.
// * Task: A job is broken into multiple tasks, each of which represents
//   iterating over all of or part of the dataset. Workers process tasks.
// * Consumer: A process reading from the tf.data service.
//
// **Adding workers**
//
// tf.data service supports adding workers mid-job. When a new worker connects
// to the dispatcher, the dispatcher creates a new task for the worker, one
// task for each outstanding job. Consumers periodically heartbeat to the
// dispatcher to learn about new tasks.
//
// For non-round-robin reads, there is no coordination among consumers. Each
// consumer will start reading from the new task as soon as it learns about the
// task from its heartbeat. Round-robin reads, on the other hand, require
// consumers to read from the same task at each step. This requires
// coordination to ensure that all consumers start reading from the new task in
// the same round.
//
// The protocol for adding round-robin tasks works as follows (a sketch of the
// target-round computation appears after the worked example below):
//
// - The dispatcher keeps track of which round each round-robin job is on. This
//   information is reported by consumers in their heartbeats.
// - When a new worker joins and there is an outstanding round-robin job, we
//   create a new task for the job and assign it to the worker. However, we
//   don't yet report the task in consumer heartbeats. We call the task a
//   "pending task" and add it to its job's "pending tasks" queue.
// - When we create a pending task, we choose a "target round" to try adding
//   the task to. The target round is chosen by adding a "target round delta"
//   to the latest round reported for the job.
// - When a consumer heartbeats for a job and there is a pending task for that
//   job, the dispatcher sends a heartbeat response telling the consumer to
//   block before reading from the target round.
// - When a consumer receives a heartbeat response telling it to block before
//   reading from a round, the consumer tries to block that round. If the
//   consumer has already started the round, it is too late to block it.
// - When consumers heartbeat, they tell the dispatcher their current round and
//   whether they have blocked themselves from reading past a certain round. If
//   a consumer reports a current round exceeding the target round, the target
//   round has failed and needs to be increased. We choose a new target round
//   by doubling the previous target round delta. If the consumer reports that
//   it has blocked before the target round, we record that the consumer is
//   ready to add the new task. Once all consumers are ready to add the new
//   task, we remove the task from the pending tasks list and begin reporting
//   the task to consumers. We set the "starting_round" field of the task to
//   indicate the target round where all consumers should start reading from
//   the task.
// - If a new worker joins while there are already pending tasks, a pending
//   task for the new worker is created and queued behind the existing tasks.
//   The new task won't be considered until all previous pending tasks have
//   been successfully added.
//
// An example of executing this protocol with two consumers could go as
// follows:
// 1. Consumers read up to round 50 and heartbeat that they are on round 50.
// 2. A new worker joins. Dispatcher chooses round 51 as the target round.
// 3. Consumer 1 heartbeats that its current round is 50. Dispatcher tells it
//    to block round 51.
// 4. Consumer 2 heartbeats that its current round is 51. Dispatcher realizes
//    that it is too late to block round 51 and chooses round 53 as the new
//    target round. Dispatcher tells consumer 2 to block round 53.
// 5. Consumer 1 heartbeats that its current round is 50 and that it has
//    blocked round 51. Dispatcher tells it to block round 53 instead.
//    Dispatcher records that consumer 1 is ready to add a task in round 53.
// 6. Consumer 2 heartbeats that its current round is 52 and it has blocked
//    round 53. Dispatcher realizes that all consumers are blocked on round 53
//    or earlier and promotes the task from pending to regular. Dispatcher
//    sends consumer 2 a task list containing the new task, and tells
//    consumer 2 that it no longer needs to block.
// 7. Consumer 1 heartbeats. Dispatcher sends consumer 1 the task list
//    containing the new task, and tells it that it no longer needs to block.
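//
// As an illustration only (all names below are hypothetical, not members of
// this class), the target-round selection above amounts to exponential
// backoff on the target round delta:
//
//   int64_t delta = 1;  // initial target round delta (illustrative value)
//   int64_t target_round = latest_reported_round + delta;
//   // On each consumer heartbeat: if the consumer has already started the
//   // target round, the target failed; double the delta and pick a new
//   // target past the consumer's current round.
//   if (consumer_current_round >= target_round) {
//     delta *= 2;
//     target_round = consumer_current_round + delta;
//   }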
//
class DataServiceDispatcherImpl {
 public:
  explicit DataServiceDispatcherImpl(
      const experimental::DispatcherConfig& config);

  ~DataServiceDispatcherImpl();

  // Starts the dispatcher. If there is a journal, this will read from the
  // journal to restore the dispatcher's state.
  Status Start();

  // See dispatcher.proto for API documentation.

  /// Worker-facing API.
  Status WorkerHeartbeat(const WorkerHeartbeatRequest* request,
                         WorkerHeartbeatResponse* response);
  Status WorkerUpdate(const WorkerUpdateRequest* request,
                      WorkerUpdateResponse* response);
  Status GetDatasetDef(const GetDatasetDefRequest* request,
                       GetDatasetDefResponse* response);
  Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response);

  /// Client-facing API.
  Status GetVersion(const GetVersionRequest* request,
                    GetVersionResponse* response);
  Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request,
                              GetOrRegisterDatasetResponse* response);
  Status GetElementSpec(const GetElementSpecRequest* request,
                        GetElementSpecResponse* response);
  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
                        GetOrCreateJobResponse* response);
  Status ReleaseJobClient(const ReleaseJobClientRequest* request,
                          ReleaseJobClientResponse* response);
  Status MaybeRemoveTask(const MaybeRemoveTaskRequest* request,
                         MaybeRemoveTaskResponse* response);
  Status ClientHeartbeat(const ClientHeartbeatRequest* request,
                         ClientHeartbeatResponse* response);
  Status GetWorkers(const GetWorkersRequest* request,
                    GetWorkersResponse* response);

 private:
  // Restores split providers from the state in `job` and stores them in
  // `restored`.
  Status RestoreSplitProviders(
      const DispatcherState::Job& job,
      std::vector<std::unique_ptr<SplitProvider>>& restored)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Makes split providers for the specified `dataset_id`, and stores them in
  // `split_providers`.
  Status MakeSplitProviders(
      int64_t dataset_id,
      std::vector<std::unique_ptr<SplitProvider>>& split_providers)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Registers a dataset with the given fingerprint, storing the new dataset's
  // id in `dataset_id`.
  Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
                         int64& dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Sets the element spec of the dataset for the specified `dataset_id`.
  Status SetElementSpec(int64_t dataset_id, const std::string& element_spec)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
  // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
  // stored in `out_stub`.
  Status GetOrCreateWorkerStub(const std::string& worker_address,
                               WorkerService::Stub*& out_stub)
      TF_LOCKS_EXCLUDED(mu_);
  // Creates a job and stores it in `job`. This method updates the
  // dispatcher state with the new job, but does not assign tasks to workers.
  Status CreateJob(const GetOrCreateJobRequest& request,
                   std::shared_ptr<const DispatcherState::Job>& job)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates tasks for the specified worker, one task for every unfinished
  // job.
  Status CreateTasksForWorker(const std::string& worker_address);
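  // A plausible sketch of how `CreateTasksForWorker` ties into the
  // pending-task protocol described above (hypothetical body assuming a
  // `ListJobs()` accessor on `state_`; the real logic lives in
  // dispatcher_impl.cc): round-robin jobs get a pending task, all other
  // jobs get an active task right away.
  //
  //   for (const auto& job : state_.ListJobs()) {
  //     if (job->finished) continue;
  //     if (/* job uses round-robin reads */) {
  //       TF_RETURN_IF_ERROR(CreatePendingTask(job, worker_address));
  //     } else {
  //       std::shared_ptr<const DispatcherState::Task> task;
  //       TF_RETURN_IF_ERROR(CreateTask(job, worker_address, task));
  //     }
  //   }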
  // Finds tasks that should be deleted from a worker, updating the heartbeat
  // response.
  Status FindTasksToDelete(
      const absl::flat_hash_set<int64>& current_tasks,
      const std::vector<std::shared_ptr<const DispatcherState::Task>>
          assigned_tasks,
      WorkerHeartbeatResponse* response);
  // Finds new tasks that should be assigned to a worker and adds them to
  // the heartbeat response.
  Status FindNewTasks(
      const std::string& worker_address,
      const absl::flat_hash_set<int64>& current_tasks,
      std::vector<std::shared_ptr<const DispatcherState::Task>>&
          assigned_tasks,
      WorkerHeartbeatResponse* response);
  // Acquires a job client id to read from the given job and sets
  // `job_client_id`.
  Status AcquireJobClientId(
      const std::shared_ptr<const DispatcherState::Job>& job,
      int64& job_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates one task for each worker, for the given job. The created tasks
  // are stored in `tasks`. This method only updates dispatcher metadata with
  // the new tasks, but doesn't assign the tasks to the workers.
  Status CreateTasksForJob(
      std::shared_ptr<const DispatcherState::Job> job,
      std::vector<std::shared_ptr<const DispatcherState::Task>>& tasks)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Creates a new task for a job. The created task may be either pending or
  // active.
  Status CreateTask(std::shared_ptr<const DispatcherState::Job> job,
                    const std::string& worker_address,
                    std::shared_ptr<const DispatcherState::Task>& task)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a pending task for a round-robin job. All consumers need to agree
  // on which round to add the task in before the pending task can be promoted
  // to a regular task.
  Status CreatePendingTask(std::shared_ptr<const DispatcherState::Job> job,
                           const std::string& worker_address)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Creates a new active task for a job, storing the created task in `task`.
  Status CreateActiveTask(std::shared_ptr<const DispatcherState::Job> job,
                          const std::string& worker_address,
                          std::shared_ptr<const DispatcherState::Task>& task);
  // Assigns the list of tasks to the workers indicated by their
  // `worker_address` fields.
  Status AssignTasks(
      std::vector<std::shared_ptr<const DispatcherState::Task>> tasks)
      TF_LOCKS_EXCLUDED(mu_);
  // Assigns a task to the worker indicated by its `worker_address` field.
  Status AssignTask(std::shared_ptr<const DispatcherState::Task> task)
      TF_LOCKS_EXCLUDED(mu_);
  // Validates that an existing job matches the requested processing mode,
  // returning an error status describing any difference.
  Status ValidateMatchingJob(std::shared_ptr<const DispatcherState::Job> job,
                             const GetOrCreateJobRequest& request)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Fills out a TaskDef with information about a task.
  Status PopulateTaskDef(std::shared_ptr<const DispatcherState::Task> task,
                         TaskDef* task_def) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Checks that the dispatcher has started, returning UNAVAILABLE if it
  // hasn't.
  Status CheckStarted() TF_LOCKS_EXCLUDED(mu_);
  // Records that a split was produced by a call to `GetSplit`.
  Status RecordSplitProduced(int64_t job_id, int64_t repetition,
                             int64_t split_provider_index, bool finished)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Applies a state update, updating both the journal and the in-memory
  // state.
  Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
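  // A plausible sketch (hypothetical body; the real implementation lives in
  // dispatcher_impl.cc) of the journal-then-apply pattern `Apply` follows,
  // so that every journaled update is also reflected in `state_`:
  //
  //   if (journal_writer_.has_value()) {
  //     TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update));
  //   }
  //   return state_.Apply(update);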
  // Applies a state update, but doesn't update the journal. Only meant to be
  // used when recovering state when the dispatcher starts.
  Status ApplyWithoutJournaling(const Update& update)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // A thread which periodically checks for jobs to clean up.
  void JobGcThread();
  // Scans for old jobs and marks them as finished.
  Status GcOldJobs() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(int64_t dataset_id,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and
  // stores it in `dataset_def`.
  Status GetDatasetDef(const DispatcherState::Dataset& dataset,
                       std::shared_ptr<const DatasetDef>& dataset_def)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  const experimental::DispatcherConfig& config_;
  Env* env_;

  mutex mu_;
  bool started_ TF_GUARDED_BY(mu_) = false;
  bool cancelled_ TF_GUARDED_BY(mu_) = false;

  // Cached worker stubs for communicating with workers.
  absl::flat_hash_map<std::string, std::unique_ptr<WorkerService::Stub>>
      worker_stubs_ TF_GUARDED_BY(mu_);
  // Store of dataset definitions.
  std::unique_ptr<DatasetStore> dataset_store_ TF_GUARDED_BY(mu_);
  // Mapping from job id to the split providers for the job.
  absl::flat_hash_map<int64, std::vector<std::unique_ptr<SplitProvider>>>
      split_providers_ TF_GUARDED_BY(mu_);
  // Mapping from round-robin job id to the round the job is currently on.
  // This is based on the data provided by client heartbeats, and may be
  // stale.
  absl::flat_hash_map<int64, int64> round_robin_rounds_ TF_GUARDED_BY(mu_);
  // Map from task id to a TaskRemover which determines when to remove the
  // task.
  absl::flat_hash_map<int64, std::shared_ptr<TaskRemover>>
      remove_task_requests_ TF_GUARDED_BY(mu_);

  absl::optional<std::unique_ptr<JournalWriter>> journal_writer_
      TF_GUARDED_BY(mu_);
  DispatcherState state_ TF_GUARDED_BY(mu_);
  // Condition variable for waking up the job gc thread.
  condition_variable job_gc_thread_cv_;
  std::unique_ptr<Thread> job_gc_thread_;

  TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_