1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/task_runner.h"
16
17 #include <memory>
18 #include <vector>
19
20 #include "tensorflow/core/data/service/thread_safe_buffer.h"
21 #include "tensorflow/core/data/standalone.h"
22 #include "tensorflow/core/framework/cancellation.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/tensor_util.h"
25 #include "tensorflow/core/lib/gtl/cleanup.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/platform/statusor.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/protobuf/service_config.pb.h"
33
34 namespace tensorflow {
35 namespace data {
36 namespace {
37 // Time to wait before skipping a round if data still isn't available.
38 const int64_t kWaitBeforeSkipUs = 100 * 1000; // 100ms.
39
40 } // namespace
41
// Wraps a standalone dataset together with an iterator over it. The dataset
// is stored alongside the iterator — presumably because the iterator
// references it and must not outlive it (confirm against standalone.h).
StandaloneTaskIterator::StandaloneTaskIterator(
    std::unique_ptr<standalone::Dataset> dataset,
    std::unique_ptr<standalone::Iterator> iterator)
    : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}
46
GetNext(std::vector<Tensor> & element,bool & end_of_sequence)47 Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
48 bool& end_of_sequence) {
49 return iterator_->GetNext(&element, &end_of_sequence);
50 }
51
Cardinality() const52 int64 StandaloneTaskIterator::Cardinality() const {
53 return dataset_->Get()->Cardinality();
54 }
55
Create(const experimental::WorkerConfig & worker_config,const TaskDef & task_def,std::unique_ptr<TaskIterator> iterator,std::unique_ptr<TaskRunner> & out)56 Status TaskRunner::Create(const experimental::WorkerConfig& worker_config,
57 const TaskDef& task_def,
58 std::unique_ptr<TaskIterator> iterator,
59 std::unique_ptr<TaskRunner>& out) {
60 if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
61 int64_t cardinality = iterator->Cardinality();
62 if (cardinality != kInfiniteCardinality &&
63 cardinality != kUnknownCardinality) {
64 return errors::FailedPrecondition(
65 "Round robin reads require that the input dataset has infinite "
66 "cardinality, but the dataset has cardinality ",
67 cardinality,
68 ". Consider adding a `.repeat()` transformation to the dataset.");
69 }
70 out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
71 task_def.num_consumers(),
72 task_def.worker_address());
73 } else {
74 out =
75 absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
76 }
77 return Status::OK();
78 }
79
// Starts prefetching immediately: a background thread fills buffer_
// (capacity 1) so that GetNext() can pop elements as soon as they are ready.
FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
    std::unique_ptr<TaskIterator> iterator)
    : iterator_(std::move(iterator)), buffer_(/*buffer_size=*/1) {
  RunPrefetchThread();
}
85
~FirstComeFirstServedTaskRunner()86 FirstComeFirstServedTaskRunner::~FirstComeFirstServedTaskRunner() { Cancel(); }
87
GetNext(const GetElementRequest & req,GetElementResult & result)88 Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
89 GetElementResult& result) {
90 TF_ASSIGN_OR_RETURN(result, buffer_.Pop());
91 return Status::OK();
92 }
93
PrefetchFn()94 Status FirstComeFirstServedTaskRunner::PrefetchFn() {
95 while (true) {
96 TF_RETURN_IF_ERROR(buffer_.Push(GetNextFromInputIterator()));
97 }
98 return Status::OK();
99 }
100
RunPrefetchThread()101 void FirstComeFirstServedTaskRunner::RunPrefetchThread() {
102 auto prefetch_fn = [this] {
103 Status status = PrefetchFn();
104 if (!status.ok()) {
105 buffer_.Cancel(status);
106 }
107 };
108 prefetch_thread_ = absl::WrapUnique(Env::Default()->StartThread(
109 /*thread_options=*/{}, /*name=*/"tf_data_service_fcfs_prefetch_thread",
110 prefetch_fn));
111 }
112
113 StatusOr<GetElementResult>
GetNextFromInputIterator()114 FirstComeFirstServedTaskRunner::GetNextFromInputIterator()
115 TF_LOCKS_EXCLUDED(mu_) {
116 GetElementResult result;
117 std::vector<Tensor> element;
118 bool end_of_task;
119 result.skip = false;
120 {
121 mutex_lock l(mu_);
122 TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
123 result.end_of_sequence = end_of_task;
124 result.element_index = element_index_++;
125 }
126 if (!end_of_task) {
127 result.components = std::move(element);
128 }
129 return result;
130 }
131
Cancel()132 void FirstComeFirstServedTaskRunner::Cancel() {
133 VLOG(2) << "Cancelling tf.data service FCFS task.";
134 buffer_.Cancel(errors::Cancelled("tf.data service FCFS task is cancelled."));
135 }
136
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,int64_t num_consumers,string worker_address)137 RoundRobinTaskRunner::RoundRobinTaskRunner(
138 std::unique_ptr<TaskIterator> iterator, int64_t num_consumers,
139 string worker_address)
140 : num_consumers_(num_consumers),
141 worker_address_(worker_address),
142 buffer_(num_consumers_),
143 prefetch_thread_(std::move(iterator), num_consumers_) {
144 VLOG(1) << "Creating task runner for distributing data round-robin to "
145 << num_consumers << " consumers";
146 }
147
ValidateRequest(const GetElementRequest & req)148 Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
149 if (req.consumer_index() < 0 || req.round_index() < 0) {
150 return errors::FailedPrecondition(
151 "RoundRobinTaskRunner needs to know the consumer index and element "
152 "index of each request.");
153 }
154 if (req.consumer_index() >= num_consumers_) {
155 return errors::FailedPrecondition(
156 "Requesting data for consumer index ", req.consumer_index(),
157 ", but the task is configured for only ", num_consumers_, " consumers");
158 }
159 return Status::OK();
160 }
161
// Fills buffer_ with one element per consumer for the new round, waiting up
// to `wait_us` microseconds (indefinitely if negative). If the buffer could
// not be filled in time, the round is marked as skipped. Wakes all consumers
// blocked in PrepareRound().
Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  VLOG(1) << worker_address_ << ": Preparing full round for round "
          << current_round_;
  // This was the last request to arrive, time to start a new round.
  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
  // FillBuffer leaves `buffer_` empty on timeout, which signals a skip.
  round_skipped_ = buffer_.empty();
  new_round_cv_.notify_all();
  return Status::OK();
}
172
// Starts a partial round serving only the consumers registered for
// first_round_, used when consumers are split across two adjacent rounds
// (e.g. after a worker restart) to bring them back in sync.
Status RoundRobinTaskRunner::PreparePartialRound()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  VLOG(1) << worker_address_ << ": Starting partial round " << first_round_
          << " for " << requests_[first_round_].size() << " consumers";
  current_round_ = first_round_;
  new_round_cv_.notify_all();
  // Indicates that we need a partial round to get consumers back in sync.
  // NOTE(review): this inspects one request from the following round; if that
  // consumer skipped the previous round, the partial round is skipped too so
  // all consumers stay consistent.
  auto next_round_request = *(requests_[first_round_ + 1].begin()->second);
  if (next_round_request.skipped_previous_round()) {
    VLOG(1) << "Skipping partial round";
    round_skipped_ = true;
    return Status::OK();
  }
  // wait_us = -1: wait as long as needed for a full buffer of elements.
  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
  round_skipped_ = false;
  return Status::OK();
}
190
// Registers `req` for its round and blocks until that round becomes current
// (or the task is cancelled). A full round begins once all num_consumers_
// requests for a round beyond the current one have arrived.
Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
  mutex_lock l(mu_);
  first_round_ = std::min(first_round_, req.round_index());
  absl::flat_hash_map<int64, const GetElementRequest*>& round =
      requests_[req.round_index()];
  round[req.consumer_index()] = &req;
  // The stored pointer is only valid for the duration of this call, so
  // deregister the request on every exit path.
  auto cleanup = gtl::MakeCleanup([&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    requests_[req.round_index()].erase(req.consumer_index());
  });
  if (current_round_ < req.round_index() && round.size() == num_consumers_) {
    // This consumer is the last to arrive for the round: advance and fill.
    current_round_ = req.round_index();
    int64_t wait_us = kWaitBeforeSkipUs;
    if (!req.allow_skip()) {
      // The consumer does not allow skipping, so wait indefinitely for data.
      wait_us = -1;
    }
    TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
  }
  // current_round_ < 0 means no round has started yet. If consumers are
  // split across first_round_ and first_round_ + 1 and all have checked in,
  // run a partial round to re-synchronize them.
  if (current_round_ < 0 &&
      requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
          num_consumers_) {
    TF_RETURN_IF_ERROR(PreparePartialRound());
  }
  // Wait for the requested round to start; surface prefetch errors rather
  // than blocking forever on a broken prefetch thread.
  while (!cancelled_ && current_round_ < req.round_index()) {
    TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
    new_round_cv_.wait(l);
  }
  if (current_round_ < req.round_index() && cancelled_) {
    return errors::Cancelled("Worker is shutting down.");
  }
  // The round was passed over while waiting — e.g. a restarted consumer
  // re-requesting an old round.
  if (current_round_ != req.round_index()) {
    return errors::FailedPrecondition(
        "Consumer ", req.consumer_index(), " requested data for round ",
        req.round_index(), ", but the current round has already reached ",
        current_round_,
        ". This may indicate that the consumer was restarted with the same job "
        "name.`");
  }
  return prefetch_thread_.GetStatus();
}
230
GetNext(const GetElementRequest & req,GetElementResult & result)231 Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
232 GetElementResult& result) {
233 TF_RETURN_IF_ERROR(ValidateRequest(req));
234 result.end_of_sequence = false;
235 VLOG(2) << worker_address_ << ": Received request from consumer index "
236 << req.consumer_index() << " for round " << req.round_index();
237 TF_RETURN_IF_ERROR(PrepareRound(req));
238 tf_shared_lock l(mu_);
239 result.skip = round_skipped_;
240 if (round_skipped_) {
241 VLOG(1) << worker_address_ << ": Buffer not ready, skipping round "
242 << current_round_ << " for consumer " << req.consumer_index();
243 return Status::OK();
244 }
245 auto& buffer_result = buffer_[req.consumer_index()];
246 result.element_index = buffer_result->index;
247 std::vector<Tensor> element;
248 for (auto& component : buffer_result->components) {
249 element.push_back(tensor::DeepCopy(component));
250 }
251 if (VLOG_IS_ON(2)) {
252 int64_t size = 0;
253 for (auto& component : element) {
254 size += component.TotalBytes();
255 }
256 VLOG(2) << worker_address_ << ": Returning element " << result.element_index
257 << " to consumer " << req.consumer_index() << " for round "
258 << req.round_index() << ". element size " << size;
259 }
260 result.components = std::move(element);
261 return Status::OK();
262 }
263
// Marks the runner cancelled and wakes every consumer blocked in
// PrepareRound() waiting for a new round to start.
void RoundRobinTaskRunner::Cancel() {
  mutex_lock l(mu_);
  cancelled_ = true;
  new_round_cv_.notify_all();
}
269
PrefetchThread(std::unique_ptr<TaskIterator> iterator,int64_t round_size)270 PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
271 int64_t round_size)
272 : iterator_(std::move(iterator)), round_size_(round_size) {
273 thread_ = absl::WrapUnique(
274 Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
275 }
276
// Signals cancellation and wakes the prefetch thread. The thread is joined
// by thread_'s destructor, which runs after this body (and after the lock is
// released).
PrefetchThread::~PrefetchThread() {
  mutex_lock l(mu_);
  cancelled_ = true;
  cv_.notify_all();
}
282
// Background loop: pulls elements from the iterator into buffer_ until
// cancellation, an iterator error, or end-of-sequence terminates the task.
void PrefetchThread::Run() {
  while (true) {
    {
      mutex_lock l(mu_);
      // Pause while a full round is already buffered; FillBuffer() drains
      // the buffer and signals cv_ to resume prefetching.
      while (!cancelled_ && buffer_.size() >= round_size_) {
        cv_.wait(l);
      }
      if (cancelled_) {
        return;
      }
    }
    // GetNext runs outside the lock so readers are not blocked while the
    // iterator produces the next element.
    std::vector<Tensor> element;
    bool end_of_sequence;
    Status s = iterator_->GetNext(element, end_of_sequence);
    if (!s.ok()) {
      // Record the error for GetStatus()/FillBuffer() callers and stop.
      mutex_lock l(mu_);
      status_ = s;
      cv_.notify_all();
      return;
    }
    if (end_of_sequence) {
      // Round-robin reads require an infinite dataset; running out of data
      // is reported as an error.
      mutex_lock l(mu_);
      status_ = errors::FailedPrecondition(
          "Encountered end of sequence on a round-robin read iterator. "
          "Please ensure that the dataset used for round-robin reading has "
          "infinite cardinality, e.g. by adding a .repeat() transformation "
          "at the end.");
      cv_.notify_all();
      return;
    }
    mutex_lock l(mu_);
    buffer_.push_back(absl::make_unique<Element>(std::move(element), index_++));
    cv_.notify_all();
  }
}
318
FillBuffer(int64_t wait_us,std::vector<std::unique_ptr<Element>> & out)319 Status PrefetchThread::FillBuffer(int64_t wait_us,
320 std::vector<std::unique_ptr<Element>>& out) {
321 int64_t start_us = Env::Default()->NowMicros();
322 out.clear();
323 mutex_lock l(mu_);
324 while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
325 int64_t remaining_us = start_us + wait_us - Env::Default()->NowMicros();
326 if (wait_us >= 0 && remaining_us <= 0) {
327 break;
328 }
329 cv_.wait_for(l, std::chrono::microseconds(remaining_us));
330 }
331 TF_RETURN_IF_ERROR(status_);
332 if (cancelled_) {
333 return errors::Cancelled("Prefetch thread cancelled");
334 }
335 if (buffer_.size() < round_size_) {
336 DCHECK_GE(wait_us, 0);
337 return Status::OK();
338 }
339 for (auto& elem : buffer_) {
340 out.push_back(std::move(elem));
341 }
342 buffer_.clear();
343 cv_.notify_all();
344 return Status::OK();
345 }
346
// Returns the terminal status recorded by the prefetch thread, or OK while
// the thread is still running normally. Locks mu_ because status_ is written
// by the prefetch thread.
Status PrefetchThread::GetStatus() {
  mutex_lock l(mu_);
  return status_;
}
351 } // namespace data
352 } // namespace tensorflow
353