1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ 17 18 #include <memory> 19 #include <vector> 20 21 #include "tensorflow/core/data/service/common.pb.h" 22 #include "tensorflow/core/data/service/data_transfer.h" 23 #include "tensorflow/core/data/service/thread_safe_buffer.h" 24 #include "tensorflow/core/data/service/worker.pb.h" 25 #include "tensorflow/core/data/standalone.h" 26 #include "tensorflow/core/platform/macros.h" 27 #include "tensorflow/core/platform/mutex.h" 28 #include "tensorflow/core/platform/status.h" 29 #include "tensorflow/core/platform/statusor.h" 30 #include "tensorflow/core/platform/thread_annotations.h" 31 #include "tensorflow/core/protobuf/service_config.pb.h" 32 33 namespace tensorflow { 34 namespace data { 35 36 // Iterator over a task's elements. 37 class TaskIterator { 38 public: 39 virtual ~TaskIterator() = default; 40 // If the iterator is not yet exhausted, `GetNext` stores the next element in 41 // `element` and sets `end_of_sequence` to `false`. Otherwise, sets 42 // `end_of_sequence to `true`. 43 virtual Status GetNext(std::vector<Tensor>& element, 44 bool& end_of_sequence) = 0; 45 // Reports the cardinality of the dataset that created this iterator. 46 virtual int64 Cardinality() const = 0; 47 }; 48 49 // Implementation of TaskIterator wrapping a standalone iterator. 50 class StandaloneTaskIterator : public TaskIterator { 51 public: 52 // `dataset` should be the dataset that created `iterator`. 53 // StandaloneTaskIterator takes ownership of the dataset to ensures it 54 // lives as long as `iterator`. 55 StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset, 56 std::unique_ptr<standalone::Iterator> iterator); 57 Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override; 58 int64 Cardinality() const override; 59 60 private: 61 std::unique_ptr<standalone::Dataset> dataset_; 62 std::unique_ptr<standalone::Iterator> iterator_; 63 }; 64 65 // Interface for providing elements to task consumers. 66 class TaskRunner { 67 public: 68 // Creates a `TaskRunner` and stores it in `out`. 69 static Status Create(const experimental::WorkerConfig& worker_config, 70 const TaskDef& task_def, 71 std::unique_ptr<TaskIterator> iterator, 72 std::unique_ptr<TaskRunner>& out); 73 virtual ~TaskRunner() = default; 74 // Gets the next element for the given request. 75 virtual Status GetNext(const GetElementRequest& req, 76 GetElementResult& result) = 0; 77 // Cancels in-progress `GetNext` requests. 78 virtual void Cancel() = 0; 79 }; 80 81 // A task runner which provides elements on a first-come first-served basis. 82 // It does not consider which consumer is making the request. 83 class FirstComeFirstServedTaskRunner : public TaskRunner { 84 public: 85 explicit FirstComeFirstServedTaskRunner( 86 std::unique_ptr<TaskIterator> iterator); 87 ~FirstComeFirstServedTaskRunner() override; 88 89 Status GetNext(const GetElementRequest& req, 90 GetElementResult& result) override; 91 void Cancel() override; 92 93 private: 94 // Function to continually prefetch the next element. Returns an error if the 95 // task has been cancelled. 96 Status PrefetchFn(); 97 98 // Runs `PrefetchFn` on a dedicated thread. 99 void RunPrefetchThread(); 100 101 // Gets the next element from the input iterator. 102 StatusOr<GetElementResult> GetNextFromInputIterator() TF_LOCKS_EXCLUDED(mu_); 103 104 mutex mu_; 105 std::unique_ptr<TaskIterator> iterator_ TF_GUARDED_BY(mu_); 106 int64 element_index_ TF_GUARDED_BY(mu_) = 0; 107 108 ThreadSafeBuffer<GetElementResult> buffer_; 109 std::unique_ptr<Thread> prefetch_thread_; 110 111 TF_DISALLOW_COPY_AND_ASSIGN(FirstComeFirstServedTaskRunner); 112 }; 113 114 // An element produced by a task. 115 struct Element { ElementElement116 explicit Element(std::vector<Tensor>&& components, int64_t index) 117 : components(components), index(index) {} 118 // The components of the element. 119 std::vector<Tensor> components; 120 // The element's index within the task, e.g. 0 for the first element produced 121 // by the task, 1 for the second element, etc. 122 int64 index; 123 }; 124 125 // Thread for prefetching a round worth of elements. 126 class PrefetchThread { 127 public: 128 explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator, 129 int64_t round_size); 130 ~PrefetchThread(); 131 // Runs the prefetch thread. It runs until an error is encountered or the 132 // destructor is called. 133 void Run(); 134 // Fills `out` with a round of data. Waits for up to `wait_us` micoseconds 135 // before giving up and returning with `out` empty. A negative `wait_us` 136 // signals to wait indefinitely. 137 Status FillBuffer(int64_t wait_us, 138 std::vector<std::unique_ptr<Element>>& out); 139 // Returns the status for any failures encountered by the prefetch thread. 140 Status GetStatus(); 141 142 private: 143 const std::unique_ptr<TaskIterator> iterator_; 144 const int64 round_size_; 145 mutex mu_; 146 int64 index_ TF_GUARDED_BY(mu_) = 0; 147 // Buffered results for the next round. 148 std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_); 149 // The status if the prefetch thread fails. 150 Status status_ TF_GUARDED_BY(mu_) = Status::OK(); 151 // Thread which constantly tries to fill `buffer_` up with 152 // `num_consumers` elements. 153 std::unique_ptr<Thread> thread_; 154 // Condition variable notified when elements are added to or removed from 155 // `buffer_`, or when `status_` is changed. 156 condition_variable cv_; 157 bool cancelled_ TF_GUARDED_BY(mu_) = false; 158 }; 159 160 // A task runner which enforces round-robin order for consuming a task's 161 // elements. `RoundRobinTaskRunner` provides elements in a series of "rounds". 162 // In each successive round, the runner waits to receive requests from all 163 // consumers. These requests are blocked until all requests arrive. Once all 164 // requests arrive, the runner hands out elements to consumers in order of their 165 // consumer indices. 166 // 167 // Consumers are expected to successively request consecutive element indices, 168 // starting at 0. The same element can be requested multiple times by the same 169 // consumer, as long as the consumer hasn't yet requested the next element (at 170 // the start of each round we discard elements from the previous round). 171 // 172 // If the worker restarts mid-round, a situation arises where some consumers 173 // are requesting element index `n` while others are requesting element index 174 // `n + 1`. To remedy this, the first round after restart may be a partial 175 // round, where we only serve elements to consumers requesting data for element 176 // index `n`, blocking other consumers until the second round. 177 class RoundRobinTaskRunner : public TaskRunner { 178 public: 179 RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator, 180 int64_t num_consumers, string worker_address); 181 182 Status GetNext(const GetElementRequest& req, 183 GetElementResult& result) override; 184 void Cancel() override; 185 186 private: 187 // Prepares a full round of data. `wait_us` indicates how long to wait before 188 // skipping if a full round of data is not yet ready. 189 Status PrepareFullRound(int64_t wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 190 // Prepares a partial round to get consumers back in sync. 191 Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 192 Status ValidateRequest(const GetElementRequest& req); 193 // Prepares data for the next round, blocking until the round is ready to 194 // start. 195 Status PrepareRound(const GetElementRequest& req); 196 const int64 num_consumers_; 197 const string worker_address_; 198 mutex mu_; 199 bool cancelled_ TF_GUARDED_BY(mu_) = false; 200 // Condition variable notified whenever we start a new round of round-robin. 201 condition_variable new_round_cv_; 202 // Outstanding requests, indexed by round number and then consumer index. 203 absl::flat_hash_map<int64, 204 absl::flat_hash_map<int64, const GetElementRequest*>> 205 requests_ TF_GUARDED_BY(mu_); 206 // Index of the first round we plan to serve. At startup, this is the minimum 207 // of all requested element indices. 208 int64 first_round_ TF_GUARDED_BY(mu_) = kint64max; 209 int64 current_round_ TF_GUARDED_BY(mu_) = -1; 210 bool round_skipped_ TF_GUARDED_BY(mu_) = false; 211 // Buffered results for the current round. 212 std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_); 213 // Thread which constantly tries to prepare `num_consumers` elements for the 214 // next round. 215 PrefetchThread prefetch_thread_; 216 }; 217 218 } // namespace data 219 } // namespace tensorflow 220 221 #endif // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ 222