• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
17 
#include <memory>
#include <utility>
#include <vector>
20 
21 #include "tensorflow/core/data/service/common.pb.h"
22 #include "tensorflow/core/data/service/data_transfer.h"
23 #include "tensorflow/core/data/service/thread_safe_buffer.h"
24 #include "tensorflow/core/data/service/worker.pb.h"
25 #include "tensorflow/core/data/standalone.h"
26 #include "tensorflow/core/platform/macros.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/status.h"
29 #include "tensorflow/core/platform/statusor.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31 #include "tensorflow/core/protobuf/service_config.pb.h"
32 
33 namespace tensorflow {
34 namespace data {
35 
36 // Iterator over a task's elements.
37 class TaskIterator {
38  public:
39   virtual ~TaskIterator() = default;
40   // If the iterator is not yet exhausted, `GetNext` stores the next element in
41   // `element` and sets `end_of_sequence` to `false`. Otherwise, sets
42   // `end_of_sequence to `true`.
43   virtual Status GetNext(std::vector<Tensor>& element,
44                          bool& end_of_sequence) = 0;
45   // Reports the cardinality of the dataset that created this iterator.
46   virtual int64 Cardinality() const = 0;
47 };
48 
49 // Implementation of TaskIterator wrapping a standalone iterator.
50 class StandaloneTaskIterator : public TaskIterator {
51  public:
52   // `dataset` should be the dataset that created `iterator`.
53   // StandaloneTaskIterator takes ownership of the dataset to ensures it
54   // lives as long as `iterator`.
55   StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,
56                          std::unique_ptr<standalone::Iterator> iterator);
57   Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override;
58   int64 Cardinality() const override;
59 
60  private:
61   std::unique_ptr<standalone::Dataset> dataset_;
62   std::unique_ptr<standalone::Iterator> iterator_;
63 };
64 
65 // Interface for providing elements to task consumers.
66 class TaskRunner {
67  public:
68   // Creates a `TaskRunner` and stores it in `out`.
69   static Status Create(const experimental::WorkerConfig& worker_config,
70                        const TaskDef& task_def,
71                        std::unique_ptr<TaskIterator> iterator,
72                        std::unique_ptr<TaskRunner>& out);
73   virtual ~TaskRunner() = default;
74   // Gets the next element for the given request.
75   virtual Status GetNext(const GetElementRequest& req,
76                          GetElementResult& result) = 0;
77   // Cancels in-progress `GetNext` requests.
78   virtual void Cancel() = 0;
79 };
80 
81 // A task runner which provides elements on a first-come first-served basis.
82 // It does not consider which consumer is making the request.
83 class FirstComeFirstServedTaskRunner : public TaskRunner {
84  public:
85   explicit FirstComeFirstServedTaskRunner(
86       std::unique_ptr<TaskIterator> iterator);
87   ~FirstComeFirstServedTaskRunner() override;
88 
89   Status GetNext(const GetElementRequest& req,
90                  GetElementResult& result) override;
91   void Cancel() override;
92 
93  private:
94   // Function to continually prefetch the next element. Returns an error if the
95   // task has been cancelled.
96   Status PrefetchFn();
97 
98   // Runs `PrefetchFn` on a dedicated thread.
99   void RunPrefetchThread();
100 
101   // Gets the next element from the input iterator.
102   StatusOr<GetElementResult> GetNextFromInputIterator() TF_LOCKS_EXCLUDED(mu_);
103 
104   mutex mu_;
105   std::unique_ptr<TaskIterator> iterator_ TF_GUARDED_BY(mu_);
106   int64 element_index_ TF_GUARDED_BY(mu_) = 0;
107 
108   ThreadSafeBuffer<GetElementResult> buffer_;
109   std::unique_ptr<Thread> prefetch_thread_;
110 
111   TF_DISALLOW_COPY_AND_ASSIGN(FirstComeFirstServedTaskRunner);
112 };
113 
114 // An element produced by a task.
115 struct Element {
ElementElement116   explicit Element(std::vector<Tensor>&& components, int64_t index)
117       : components(components), index(index) {}
118   // The components of the element.
119   std::vector<Tensor> components;
120   // The element's index within the task, e.g. 0 for the first element produced
121   // by the task, 1 for the second element, etc.
122   int64 index;
123 };
124 
125 // Thread for prefetching a round worth of elements.
126 class PrefetchThread {
127  public:
128   explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,
129                           int64_t round_size);
130   ~PrefetchThread();
131   // Runs the prefetch thread. It runs until an error is encountered or the
132   // destructor is called.
133   void Run();
134   // Fills `out` with a round of data. Waits for up to `wait_us` micoseconds
135   // before giving up and returning with `out` empty. A negative `wait_us`
136   // signals to wait indefinitely.
137   Status FillBuffer(int64_t wait_us,
138                     std::vector<std::unique_ptr<Element>>& out);
139   // Returns the status for any failures encountered by the prefetch thread.
140   Status GetStatus();
141 
142  private:
143   const std::unique_ptr<TaskIterator> iterator_;
144   const int64 round_size_;
145   mutex mu_;
146   int64 index_ TF_GUARDED_BY(mu_) = 0;
147   // Buffered results for the next round.
148   std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_);
149   // The status if the prefetch thread fails.
150   Status status_ TF_GUARDED_BY(mu_) = Status::OK();
151   // Thread which constantly tries to fill `buffer_` up with
152   // `num_consumers` elements.
153   std::unique_ptr<Thread> thread_;
154   // Condition variable notified when elements are added to or removed from
155   // `buffer_`, or when `status_` is changed.
156   condition_variable cv_;
157   bool cancelled_ TF_GUARDED_BY(mu_) = false;
158 };
159 
160 // A task runner which enforces round-robin order for consuming a task's
161 // elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
162 // In each successive round, the runner waits to receive requests from all
163 // consumers. These requests are blocked until all requests arrive. Once all
164 // requests arrive, the runner hands out elements to consumers in order of their
165 // consumer indices.
166 //
167 // Consumers are expected to successively request consecutive element indices,
168 // starting at 0. The same element can be requested multiple times by the same
169 // consumer, as long as the consumer hasn't yet requested the next element (at
170 // the start of each round we discard elements from the previous round).
171 //
172 // If the worker restarts mid-round, a situation arises where some consumers
173 // are requesting element index `n` while others are requesting element index
174 // `n + 1`. To remedy this, the first round after restart may be a partial
175 // round, where we only serve elements to consumers requesting data for element
176 // index `n`, blocking other consumers until the second round.
177 class RoundRobinTaskRunner : public TaskRunner {
178  public:
179   RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
180                        int64_t num_consumers, string worker_address);
181 
182   Status GetNext(const GetElementRequest& req,
183                  GetElementResult& result) override;
184   void Cancel() override;
185 
186  private:
187   // Prepares a full round of data. `wait_us` indicates how long to wait before
188   // skipping if a full round of data is not yet ready.
189   Status PrepareFullRound(int64_t wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
190   // Prepares a partial round to get consumers back in sync.
191   Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
192   Status ValidateRequest(const GetElementRequest& req);
193   // Prepares data for the next round, blocking until the round is ready to
194   // start.
195   Status PrepareRound(const GetElementRequest& req);
196   const int64 num_consumers_;
197   const string worker_address_;
198   mutex mu_;
199   bool cancelled_ TF_GUARDED_BY(mu_) = false;
200   // Condition variable notified whenever we start a new round of round-robin.
201   condition_variable new_round_cv_;
202   // Outstanding requests, indexed by round number and then consumer index.
203   absl::flat_hash_map<int64,
204                       absl::flat_hash_map<int64, const GetElementRequest*>>
205       requests_ TF_GUARDED_BY(mu_);
206   // Index of the first round we plan to serve. At startup, this is the minimum
207   // of all requested element indices.
208   int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
209   int64 current_round_ TF_GUARDED_BY(mu_) = -1;
210   bool round_skipped_ TF_GUARDED_BY(mu_) = false;
211   // Buffered results for the current round.
212   std::vector<std::unique_ptr<Element>> buffer_ TF_GUARDED_BY(mu_);
213   // Thread which constantly tries to prepare `num_consumers` elements for the
214   // next round.
215   PrefetchThread prefetch_thread_;
216 };
217 
218 }  // namespace data
219 }  // namespace tensorflow
220 
221 #endif  // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
222