• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
17 
#include <cstdint>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/auto_shard_rewriter.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/journal.h"
#include "tensorflow/core/data/service/journal.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
36 
37 namespace tensorflow {
38 namespace data {
39 
40 // A class encapsulating the journaled state of the dispatcher. All state
41 // modifications must be done via `Apply`. This helps to ensure that
42 // replaying the journal will allow us to restore the exact same state.
43 //
44 // The following usage pattern will keep the journal in sync with the state of
45 // the dispatcher:
46 // {
47 //   mutex_lock l(mu_);
48 //   Update update = ...  // create an update
49 //   dispatcher_state.Apply(update);
50 //   journal_writer.write(Update);
51 //   // Unlock mu_
52 // }
53 //
54 // The division of functionality between DispatcherImpl and DispatcherState is
55 // as follows:
56 //   - DispatcherImpl is responsible for handling RPC requests, reading from
57 //     DispatcherState, and deciding what updates to apply to DispatcherState.
58 //     DispatcherImpl handles all synchronization.
59 //   - DispatcherState is responsible for making the state changes requested by
60 //     DispatcherImpl and for providing DispatcherImpl with read-only access to
61 //     the state.
62 //
63 // DispatcherState is thread-compatible but not thread-safe.
64 class DispatcherState {
65  public:
66   DispatcherState();
67   explicit DispatcherState(
68       const experimental::DispatcherConfig& dispatcher_config);
69   DispatcherState(const DispatcherState&) = delete;
70   DispatcherState& operator=(const DispatcherState&) = delete;
71 
72   // Applies the given update to the dispatcher's state.
73   Status Apply(const Update& update);
74 
75   // A dataset registered with the dispatcher.
76   struct Dataset {
DatasetDataset77     explicit Dataset(int64_t dataset_id, int64_t fingerprint)
78         : dataset_id(dataset_id), fingerprint(fingerprint) {}
79 
80     const int64 dataset_id;
81     const int64 fingerprint;
82   };
83 
84   // A worker registered with the dispatcher.
85   struct Worker {
WorkerWorker86     explicit Worker(const std::string& address,
87                     const std::string& transfer_address)
88         : address(address), transfer_address(transfer_address) {}
89 
90     const std::string address;
91     const std::string transfer_address;
92   };
93 
94   // A key for identifying a named job. The key contains a user-specified name,
95   // as well as an index describing which iteration of the job we are on.
96   struct NamedJobKey {
NamedJobKeyNamedJobKey97     explicit NamedJobKey(absl::string_view name, int64_t index)
98         : name(name), index(index) {}
99 
100     friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) {
101       return lhs.name == rhs.name && lhs.index == rhs.index;
102     }
103 
104     template <typename H>
AbslHashValueNamedJobKey105     friend H AbslHashValue(H h, const NamedJobKey& k) {
106       return H::combine(std::move(h), k.name, k.index);
107     }
108 
109     const std::string name;
110     const int64 index;
111   };
112 
113   struct DistributedEpochState {
DistributedEpochStateDistributedEpochState114     explicit DistributedEpochState(int64_t num_split_providers)
115         : repetitions(num_split_providers), indices(num_split_providers) {}
116 
117     // The current repetition for each split provider.
118     std::vector<int64> repetitions;
119     // Number of splits produced so far by each split provider.
120     std::vector<int64> indices;
121   };
122 
123   struct Task;
124 
125   struct PendingTask {
PendingTaskPendingTask126     explicit PendingTask(std::shared_ptr<Task> task, int64_t target_round)
127         : task(std::move(task)), target_round(target_round) {}
128 
129     std::shared_ptr<Task> task;
130     // The target round where we want to insert the task.
131     int64 target_round;
132     // Which consumers have responded that they have successfully blocked
133     // before the target round.
134     absl::flat_hash_set<int64> ready_consumers;
135     // How many times we have failed to add the task.
136     int64 failures = 0;
137   };
138 
139   // A job for processing a dataset.
140   struct Job {
JobJob141     explicit Job(int64_t job_id, int64_t dataset_id,
142                  const ProcessingModeDef& processing_mode,
143                  int64_t num_split_providers,
144                  absl::optional<NamedJobKey> named_job_key,
145                  absl::optional<int64> num_consumers,
146                  TargetWorkers target_workers)
147         : job_id(job_id),
148           dataset_id(dataset_id),
149           processing_mode(processing_mode),
150           named_job_key(named_job_key),
151           num_consumers(num_consumers),
152           target_workers(target_workers) {
153       if (IsDynamicShard(processing_mode)) {
154         distributed_epoch_state = DistributedEpochState(num_split_providers);
155       }
156     }
157 
IsRoundRobinJob158     bool IsRoundRobin() const { return num_consumers.has_value(); }
159 
DebugStringJob160     std::string DebugString() const {
161       if (named_job_key.has_value()) {
162         return absl::StrCat(named_job_key.value().name, "_",
163                             named_job_key.value().index);
164       }
165       return absl::StrCat(job_id);
166     }
167 
168     const int64 job_id;
169     const int64 dataset_id;
170     const ProcessingModeDef processing_mode;
171     const absl::optional<NamedJobKey> named_job_key;
172     absl::optional<DistributedEpochState> distributed_epoch_state;
173     const absl::optional<int64> num_consumers;
174     const TargetWorkers target_workers;
175     std::queue<PendingTask> pending_tasks;
176     int64 num_clients = 0;
177     int64 last_client_released_micros = -1;
178     bool finished = false;
179     // Indicates whether the job was garbage collected.
180     bool garbage_collected = false;
181   };
182 
183   struct Task {
TaskTask184     explicit Task(int64_t task_id, const std::shared_ptr<Job>& job,
185                   const std::string& worker_address,
186                   const std::string& transfer_address)
187         : task_id(task_id),
188           job(job),
189           worker_address(worker_address),
190           transfer_address(transfer_address) {}
191 
192     const int64 task_id;
193     const std::shared_ptr<Job> job;
194     const std::string worker_address;
195     const std::string transfer_address;
196     int64 starting_round = 0;
197     bool finished = false;
198     bool removed = false;
199   };
200 
201   using TasksById = absl::flat_hash_map<int64, std::shared_ptr<Task>>;
202 
203   // Returns the next available dataset id.
204   int64 NextAvailableDatasetId() const;
205   // Gets the element_spec by searching for the dataset_id key.
206   Status GetElementSpec(int64_t dataset_id, std::string& element_spec) const;
207   // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset.
208   Status DatasetFromId(int64_t id,
209                        std::shared_ptr<const Dataset>& dataset) const;
210   // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such
211   // dataset.
212   Status DatasetFromFingerprint(uint64 fingerprint,
213                                 std::shared_ptr<const Dataset>& dataset) const;
214 
215   // Gets a worker by address. Returns NOT_FOUND if there is no such worker.
216   Status WorkerFromAddress(const std::string& address,
217                            std::shared_ptr<const Worker>& worker) const;
218   // Lists all workers registered with the dispatcher.
219   std::vector<std::shared_ptr<const Worker>> ListWorkers() const;
220 
221   // Returns the next available job id.
222   int64 NextAvailableJobId() const;
223   // Returns a list of all jobs.
224   std::vector<std::shared_ptr<const Job>> ListJobs();
225   // Gets a job by id. Returns NOT_FOUND if there is no such job.
226   Status JobFromId(int64_t id, std::shared_ptr<const Job>& job) const;
227   // Gets a named job by key. Returns NOT_FOUND if there is no such job.
228   Status NamedJobByKey(NamedJobKey key, std::shared_ptr<const Job>& job) const;
229 
230   // Returns the job associated with the given job client id. Returns NOT_FOUND
231   // if the job_client_id is unknown or has been released.
232   Status JobForJobClientId(int64_t job_client_id,
233                            std::shared_ptr<const Job>& job);
234   // Returns the next available job client id.
235   int64 NextAvailableJobClientId() const;
236 
237   // Returns the next available task id.
238   int64 NextAvailableTaskId() const;
239   // Gets a task by id. Returns NOT_FOUND if there is no such task.
240   Status TaskFromId(int64_t id, std::shared_ptr<const Task>& task) const;
241   // Stores a list of all tasks for the given job to `tasks`. Returns NOT_FOUND
242   // if there is no such job.
243   Status TasksForJob(int64_t job_id,
244                      std::vector<std::shared_ptr<const Task>>& tasks) const;
245   // Stores a list of all tasks for the given worker to `tasks`. Returns
246   // NOT_FOUND if there is no such worker.
247   Status TasksForWorker(const absl::string_view worker_address,
248                         std::vector<std::shared_ptr<const Task>>& tasks) const;
249 
250   // If the dispatcher config explicitly specifies a list of workers, validates
251   // `worker_address` is in the list.
252   Status ValidateWorker(absl::string_view worker_address) const;
253 
254   // If the dispatcher config specifies worker addresses, `GetWorkerIndex`
255   // returns the worker index according to the list. This is useful for
256   // deterministically sharding a dataset among a fixed set of workers.
257   StatusOr<int64> GetWorkerIndex(absl::string_view worker_address) const;
258 
259  private:
260   void RegisterDataset(const RegisterDatasetUpdate& register_dataset);
261   void RegisterWorker(const RegisterWorkerUpdate& register_worker);
262   void CreateJob(const CreateJobUpdate& create_job);
263   void ProduceSplit(const ProduceSplitUpdate& produce_split);
264   void AcquireJobClient(const AcquireJobClientUpdate& acquire_job_client);
265   void ReleaseJobClient(const ReleaseJobClientUpdate& release_job_client);
266   void GarbageCollectJob(const GarbageCollectJobUpdate& garbage_collect_job);
267   void RemoveTask(const RemoveTaskUpdate& remove_task);
268   void CreatePendingTask(const CreatePendingTaskUpdate& create_pending_task);
269   void ClientHeartbeat(const ClientHeartbeatUpdate& client_heartbeat);
270   void CreateTask(const CreateTaskUpdate& create_task);
271   void FinishTask(const FinishTaskUpdate& finish_task);
272   void SetElementSpec(const SetElementSpecUpdate& set_element_spec);
273 
274   int64 next_available_dataset_id_ = 1000;
275   // Registered datasets, keyed by dataset ids.
276   absl::flat_hash_map<int64, std::shared_ptr<Dataset>> datasets_by_id_;
277   // Registered datasets, keyed by dataset fingerprints.
278   absl::flat_hash_map<uint64, std::shared_ptr<Dataset>>
279       datasets_by_fingerprint_;
280   // Saved element_spec, keyed by dataset ids.
281   absl::flat_hash_map<int64, std::string> id_element_spec_info_;
282 
283   // Registered workers, keyed by address.
284   absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
285 
286   // Assigns an index to each worker according to worker addresses list
287   // specified in the dispatcher config.
288   WorkerIndexResolver worker_index_resolver_;
289 
290   int64 next_available_job_id_ = 2000;
291   // Jobs, keyed by job ids.
292   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_;
293   // Named jobs, keyed by their names and indices. Not all jobs have names, so
294   // this is a subset of the jobs stored in `jobs_`.
295   absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_;
296 
297   int64 next_available_job_client_id_ = 3000;
298   // Mapping from client ids to the jobs they are associated with.
299   absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_for_client_ids_;
300 
301   int64 next_available_task_id_ = 4000;
302   // Tasks, keyed by task ids.
303   TasksById tasks_;
304   // List of tasks associated with each job.
305   absl::flat_hash_map<int64, std::vector<std::shared_ptr<Task>>> tasks_by_job_;
306   // Tasks, keyed by worker addresses. The values are a map from task id to
307   // task.
308   absl::flat_hash_map<std::string, TasksById> tasks_by_worker_;
309 };
310 
311 }  // namespace data
312 }  // namespace tensorflow
313 
314 #endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
315