1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 17 18 #include <memory> 19 #include <queue> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "absl/container/flat_hash_map.h" 25 #include "absl/container/flat_hash_set.h" 26 #include "absl/strings/string_view.h" 27 #include "absl/types/optional.h" 28 #include "tensorflow/core/data/service/auto_shard_rewriter.h" 29 #include "tensorflow/core/data/service/common.h" 30 #include "tensorflow/core/data/service/common.pb.h" 31 #include "tensorflow/core/data/service/journal.h" 32 #include "tensorflow/core/data/service/journal.pb.h" 33 #include "tensorflow/core/platform/status.h" 34 #include "tensorflow/core/protobuf/data_service.pb.h" 35 #include "tensorflow/core/protobuf/service_config.pb.h" 36 37 namespace tensorflow { 38 namespace data { 39 40 // A class encapsulating the journaled state of the dispatcher. All state 41 // modifications must be done via `Apply`. This helps to ensure that 42 // replaying the journal will allow us to restore the exact same state. 43 // 44 // The following usage pattern will keep the journal in sync with the state of 45 // the dispatcher: 46 // { 47 // mutex_lock l(mu_); 48 // Update update = ... // create an update 49 // dispatcher_state.Apply(update); 50 // journal_writer.write(Update); 51 // // Unlock mu_ 52 // } 53 // 54 // The division of functionality between DispatcherImpl and DispatcherState is 55 // as follows: 56 // - DispatcherImpl is responsible for handling RPC requests, reading from 57 // DispatcherState, and deciding what updates to apply to DispatcherState. 58 // DispatcherImpl handles all synchronization. 59 // - DispatcherState is responsible for making the state changes requested by 60 // DispatcherImpl and for providing DispatcherImpl with read-only access to 61 // the state. 62 // 63 // DispatcherState is thread-compatible but not thread-safe. 64 class DispatcherState { 65 public: 66 DispatcherState(); 67 explicit DispatcherState( 68 const experimental::DispatcherConfig& dispatcher_config); 69 DispatcherState(const DispatcherState&) = delete; 70 DispatcherState& operator=(const DispatcherState&) = delete; 71 72 // Applies the given update to the dispatcher's state. 73 Status Apply(const Update& update); 74 75 // A dataset registered with the dispatcher. 76 struct Dataset { DatasetDataset77 explicit Dataset(int64_t dataset_id, int64_t fingerprint) 78 : dataset_id(dataset_id), fingerprint(fingerprint) {} 79 80 const int64 dataset_id; 81 const int64 fingerprint; 82 }; 83 84 // A worker registered with the dispatcher. 85 struct Worker { WorkerWorker86 explicit Worker(const std::string& address, 87 const std::string& transfer_address) 88 : address(address), transfer_address(transfer_address) {} 89 90 const std::string address; 91 const std::string transfer_address; 92 }; 93 94 // A key for identifying a named job. The key contains a user-specified name, 95 // as well as an index describing which iteration of the job we are on. 96 struct NamedJobKey { NamedJobKeyNamedJobKey97 explicit NamedJobKey(absl::string_view name, int64_t index) 98 : name(name), index(index) {} 99 100 friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) { 101 return lhs.name == rhs.name && lhs.index == rhs.index; 102 } 103 104 template <typename H> AbslHashValueNamedJobKey105 friend H AbslHashValue(H h, const NamedJobKey& k) { 106 return H::combine(std::move(h), k.name, k.index); 107 } 108 109 const std::string name; 110 const int64 index; 111 }; 112 113 struct DistributedEpochState { DistributedEpochStateDistributedEpochState114 explicit DistributedEpochState(int64_t num_split_providers) 115 : repetitions(num_split_providers), indices(num_split_providers) {} 116 117 // The current repetition for each split provider. 118 std::vector<int64> repetitions; 119 // Number of splits produced so far by each split provider. 120 std::vector<int64> indices; 121 }; 122 123 struct Task; 124 125 struct PendingTask { PendingTaskPendingTask126 explicit PendingTask(std::shared_ptr<Task> task, int64_t target_round) 127 : task(std::move(task)), target_round(target_round) {} 128 129 std::shared_ptr<Task> task; 130 // The target round where we want to insert the task. 131 int64 target_round; 132 // Which consumers have responded that they have successfully blocked 133 // before the target round. 134 absl::flat_hash_set<int64> ready_consumers; 135 // How many times we have failed to add the task. 136 int64 failures = 0; 137 }; 138 139 // A job for processing a dataset. 140 struct Job { JobJob141 explicit Job(int64_t job_id, int64_t dataset_id, 142 const ProcessingModeDef& processing_mode, 143 int64_t num_split_providers, 144 absl::optional<NamedJobKey> named_job_key, 145 absl::optional<int64> num_consumers, 146 TargetWorkers target_workers) 147 : job_id(job_id), 148 dataset_id(dataset_id), 149 processing_mode(processing_mode), 150 named_job_key(named_job_key), 151 num_consumers(num_consumers), 152 target_workers(target_workers) { 153 if (IsDynamicShard(processing_mode)) { 154 distributed_epoch_state = DistributedEpochState(num_split_providers); 155 } 156 } 157 IsRoundRobinJob158 bool IsRoundRobin() const { return num_consumers.has_value(); } 159 DebugStringJob160 std::string DebugString() const { 161 if (named_job_key.has_value()) { 162 return absl::StrCat(named_job_key.value().name, "_", 163 named_job_key.value().index); 164 } 165 return absl::StrCat(job_id); 166 } 167 168 const int64 job_id; 169 const int64 dataset_id; 170 const ProcessingModeDef processing_mode; 171 const absl::optional<NamedJobKey> named_job_key; 172 absl::optional<DistributedEpochState> distributed_epoch_state; 173 const absl::optional<int64> num_consumers; 174 const TargetWorkers target_workers; 175 std::queue<PendingTask> pending_tasks; 176 int64 num_clients = 0; 177 int64 last_client_released_micros = -1; 178 bool finished = false; 179 // Indicates whether the job was garbage collected. 180 bool garbage_collected = false; 181 }; 182 183 struct Task { TaskTask184 explicit Task(int64_t task_id, const std::shared_ptr<Job>& job, 185 const std::string& worker_address, 186 const std::string& transfer_address) 187 : task_id(task_id), 188 job(job), 189 worker_address(worker_address), 190 transfer_address(transfer_address) {} 191 192 const int64 task_id; 193 const std::shared_ptr<Job> job; 194 const std::string worker_address; 195 const std::string transfer_address; 196 int64 starting_round = 0; 197 bool finished = false; 198 bool removed = false; 199 }; 200 201 using TasksById = absl::flat_hash_map<int64, std::shared_ptr<Task>>; 202 203 // Returns the next available dataset id. 204 int64 NextAvailableDatasetId() const; 205 // Gets the element_spec by searching for the dataset_id key. 206 Status GetElementSpec(int64_t dataset_id, std::string& element_spec) const; 207 // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset. 208 Status DatasetFromId(int64_t id, 209 std::shared_ptr<const Dataset>& dataset) const; 210 // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such 211 // dataset. 212 Status DatasetFromFingerprint(uint64 fingerprint, 213 std::shared_ptr<const Dataset>& dataset) const; 214 215 // Gets a worker by address. Returns NOT_FOUND if there is no such worker. 216 Status WorkerFromAddress(const std::string& address, 217 std::shared_ptr<const Worker>& worker) const; 218 // Lists all workers registered with the dispatcher. 219 std::vector<std::shared_ptr<const Worker>> ListWorkers() const; 220 221 // Returns the next available job id. 222 int64 NextAvailableJobId() const; 223 // Returns a list of all jobs. 224 std::vector<std::shared_ptr<const Job>> ListJobs(); 225 // Gets a job by id. Returns NOT_FOUND if there is no such job. 226 Status JobFromId(int64_t id, std::shared_ptr<const Job>& job) const; 227 // Gets a named job by key. Returns NOT_FOUND if there is no such job. 228 Status NamedJobByKey(NamedJobKey key, std::shared_ptr<const Job>& job) const; 229 230 // Returns the job associated with the given job client id. Returns NOT_FOUND 231 // if the job_client_id is unknown or has been released. 232 Status JobForJobClientId(int64_t job_client_id, 233 std::shared_ptr<const Job>& job); 234 // Returns the next available job client id. 235 int64 NextAvailableJobClientId() const; 236 237 // Returns the next available task id. 238 int64 NextAvailableTaskId() const; 239 // Gets a task by id. Returns NOT_FOUND if there is no such task. 240 Status TaskFromId(int64_t id, std::shared_ptr<const Task>& task) const; 241 // Stores a list of all tasks for the given job to `tasks`. Returns NOT_FOUND 242 // if there is no such job. 243 Status TasksForJob(int64_t job_id, 244 std::vector<std::shared_ptr<const Task>>& tasks) const; 245 // Stores a list of all tasks for the given worker to `tasks`. Returns 246 // NOT_FOUND if there is no such worker. 247 Status TasksForWorker(const absl::string_view worker_address, 248 std::vector<std::shared_ptr<const Task>>& tasks) const; 249 250 // If the dispatcher config explicitly specifies a list of workers, validates 251 // `worker_address` is in the list. 252 Status ValidateWorker(absl::string_view worker_address) const; 253 254 // If the dispatcher config specifies worker addresses, `GetWorkerIndex` 255 // returns the worker index according to the list. This is useful for 256 // deterministically sharding a dataset among a fixed set of workers. 257 StatusOr<int64> GetWorkerIndex(absl::string_view worker_address) const; 258 259 private: 260 void RegisterDataset(const RegisterDatasetUpdate& register_dataset); 261 void RegisterWorker(const RegisterWorkerUpdate& register_worker); 262 void CreateJob(const CreateJobUpdate& create_job); 263 void ProduceSplit(const ProduceSplitUpdate& produce_split); 264 void AcquireJobClient(const AcquireJobClientUpdate& acquire_job_client); 265 void ReleaseJobClient(const ReleaseJobClientUpdate& release_job_client); 266 void GarbageCollectJob(const GarbageCollectJobUpdate& garbage_collect_job); 267 void RemoveTask(const RemoveTaskUpdate& remove_task); 268 void CreatePendingTask(const CreatePendingTaskUpdate& create_pending_task); 269 void ClientHeartbeat(const ClientHeartbeatUpdate& client_heartbeat); 270 void CreateTask(const CreateTaskUpdate& create_task); 271 void FinishTask(const FinishTaskUpdate& finish_task); 272 void SetElementSpec(const SetElementSpecUpdate& set_element_spec); 273 274 int64 next_available_dataset_id_ = 1000; 275 // Registered datasets, keyed by dataset ids. 276 absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_; 277 // Registered datasets, keyed by dataset fingerprints. 278 absl::flat_hash_map<uint64, std::shared_ptr<Dataset>> 279 datasets_by_fingerprint_; 280 // Saved element_spec, keyed by dataset ids. 281 absl::flat_hash_map<int64, std::string> id_element_spec_info_; 282 283 // Registered workers, keyed by address. 284 absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_; 285 286 // Assigns an index to each worker according to worker addresses list 287 // specified in the dispatcher config. 288 WorkerIndexResolver worker_index_resolver_; 289 290 int64 next_available_job_id_ = 2000; 291 // Jobs, keyed by job ids. 292 absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_; 293 // Named jobs, keyed by their names and indices. Not all jobs have names, so 294 // this is a subset of the jobs stored in `jobs_`. 295 absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_; 296 297 int64 next_available_job_client_id_ = 3000; 298 // Mapping from client ids to the jobs they are associated with. 299 absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_; 300 301 int64 next_available_task_id_ = 4000; 302 // Tasks, keyed by task ids. 303 TasksById tasks_; 304 // List of tasks associated with each job. 305 absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_; 306 // Tasks, keyed by worker addresses. The values are a map from task id to 307 // task. 308 absl::flat_hash_map<std::string, TasksById> tasks_by_worker_; 309 }; 310 311 } // namespace data 312 } // namespace tensorflow 313 314 #endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 315