• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/task_runner.h"
16 
17 #include <memory>
18 #include <vector>
19 
20 #include "tensorflow/core/data/service/thread_safe_buffer.h"
21 #include "tensorflow/core/data/standalone.h"
22 #include "tensorflow/core/framework/cancellation.h"
23 #include "tensorflow/core/framework/dataset.h"
24 #include "tensorflow/core/framework/tensor_util.h"
25 #include "tensorflow/core/lib/gtl/cleanup.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/mutex.h"
29 #include "tensorflow/core/platform/status.h"
30 #include "tensorflow/core/platform/statusor.h"
31 #include "tensorflow/core/platform/thread_annotations.h"
32 #include "tensorflow/core/protobuf/service_config.pb.h"
33 
34 namespace tensorflow {
35 namespace data {
namespace {
// Time to wait before skipping a round if data still isn't available.
// `constexpr` (rather than `const`) guarantees compile-time initialization
// and makes the constant usable in constant expressions.
constexpr int64_t kWaitBeforeSkipUs = 100 * 1000;  // 100ms.

}  // namespace
41 
// Wraps a standalone dataset/iterator pair. The dataset is retained because
// the iterator's lifetime must not outlive the dataset that produced it.
StandaloneTaskIterator::StandaloneTaskIterator(
    std::unique_ptr<standalone::Dataset> dataset,
    std::unique_ptr<standalone::Iterator> iterator)
    : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}
46 
GetNext(std::vector<Tensor> & element,bool & end_of_sequence)47 Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
48                                        bool& end_of_sequence) {
49   return iterator_->GetNext(&element, &end_of_sequence);
50 }
51 
Cardinality() const52 int64 StandaloneTaskIterator::Cardinality() const {
53   return dataset_->Get()->Cardinality();
54 }
55 
Create(const experimental::WorkerConfig & worker_config,const TaskDef & task_def,std::unique_ptr<TaskIterator> iterator,std::unique_ptr<TaskRunner> & out)56 Status TaskRunner::Create(const experimental::WorkerConfig& worker_config,
57                           const TaskDef& task_def,
58                           std::unique_ptr<TaskIterator> iterator,
59                           std::unique_ptr<TaskRunner>& out) {
60   if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
61     int64_t cardinality = iterator->Cardinality();
62     if (cardinality != kInfiniteCardinality &&
63         cardinality != kUnknownCardinality) {
64       return errors::FailedPrecondition(
65           "Round robin reads require that the input dataset has infinite "
66           "cardinality, but the dataset has cardinality ",
67           cardinality,
68           ". Consider adding a `.repeat()` transformation to the dataset.");
69     }
70     out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
71                                                   task_def.num_consumers(),
72                                                   task_def.worker_address());
73   } else {
74     out =
75         absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
76   }
77   return Status::OK();
78 }
79 
// Starts the background prefetch thread immediately; the buffer holds at
// most one prefetched element at a time (buffer_size=1).
FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
    std::unique_ptr<TaskIterator> iterator)
    : iterator_(std::move(iterator)), buffer_(/*buffer_size=*/1) {
  RunPrefetchThread();
}
85 
~FirstComeFirstServedTaskRunner()86 FirstComeFirstServedTaskRunner::~FirstComeFirstServedTaskRunner() { Cancel(); }
87 
// Blocks until the prefetch thread pushes an element into `buffer_`, then
// moves it into `result`. Returns the buffer's cancellation/error status if
// the buffer was cancelled. `req` is unused: elements are handed out in
// production order regardless of the requester.
Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
                                               GetElementResult& result) {
  TF_ASSIGN_OR_RETURN(result, buffer_.Pop());
  return Status::OK();
}
93 
PrefetchFn()94 Status FirstComeFirstServedTaskRunner::PrefetchFn() {
95   while (true) {
96     TF_RETURN_IF_ERROR(buffer_.Push(GetNextFromInputIterator()));
97   }
98   return Status::OK();
99 }
100 
// Launches the prefetch thread. If the prefetch loop exits with an error
// (including end-of-iteration errors), the buffer is cancelled with that
// status so consumers blocked in GetNext() observe it.
void FirstComeFirstServedTaskRunner::RunPrefetchThread() {
  auto prefetch_fn = [this] {
    Status status = PrefetchFn();
    if (!status.ok()) {
      buffer_.Cancel(status);
    }
  };
  // The returned thread is joined when `prefetch_thread_` is destroyed.
  prefetch_thread_ = absl::WrapUnique(Env::Default()->StartThread(
      /*thread_options=*/{}, /*name=*/"tf_data_service_fcfs_prefetch_thread",
      prefetch_fn));
}
112 
// Pulls one element from the input iterator and packages it as a
// GetElementResult, assigning it the next sequential element index.
// Called only from the prefetch thread, but takes `mu_` since the iterator
// and index counter are shared state.
StatusOr<GetElementResult>
FirstComeFirstServedTaskRunner::GetNextFromInputIterator()
    TF_LOCKS_EXCLUDED(mu_) {
  GetElementResult result;
  std::vector<Tensor> element;
  bool end_of_task;
  result.skip = false;
  {
    // Lock scope covers both the iterator call and the index increment so
    // element indices match production order.
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
    result.end_of_sequence = end_of_task;
    result.element_index = element_index_++;
  }
  if (!end_of_task) {
    result.components = std::move(element);
  }
  return result;
}
131 
Cancel()132 void FirstComeFirstServedTaskRunner::Cancel() {
133   VLOG(2) << "Cancelling tf.data service FCFS task.";
134   buffer_.Cancel(errors::Cancelled("tf.data service FCFS task is cancelled."));
135 }
136 
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,int64_t num_consumers,string worker_address)137 RoundRobinTaskRunner::RoundRobinTaskRunner(
138     std::unique_ptr<TaskIterator> iterator, int64_t num_consumers,
139     string worker_address)
140     : num_consumers_(num_consumers),
141       worker_address_(worker_address),
142       buffer_(num_consumers_),
143       prefetch_thread_(std::move(iterator), num_consumers_) {
144   VLOG(1) << "Creating task runner for distributing data round-robin to "
145           << num_consumers << " consumers";
146 }
147 
ValidateRequest(const GetElementRequest & req)148 Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
149   if (req.consumer_index() < 0 || req.round_index() < 0) {
150     return errors::FailedPrecondition(
151         "RoundRobinTaskRunner needs to know the consumer index and element "
152         "index of each request.");
153   }
154   if (req.consumer_index() >= num_consumers_) {
155     return errors::FailedPrecondition(
156         "Requesting data for consumer index ", req.consumer_index(),
157         ", but the task is configured for only ", num_consumers_, " consumers");
158   }
159   return Status::OK();
160 }
161 
// Fills `buffer_` with one element per consumer for the current round,
// waiting up to `wait_us` microseconds for data (wait_us < 0 means wait
// indefinitely). Must be called with `mu_` held.
Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  VLOG(1) << worker_address_ << ": Preparing full round for round "
          << current_round_;
  // This was the last request to arrive, time to start a new round.
  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
  // An empty buffer here means FillBuffer timed out before a full round was
  // available, so this round is skipped rather than served.
  round_skipped_ = buffer_.empty();
  // Wake consumers blocked in PrepareRound() waiting for this round.
  new_round_cv_.notify_all();
  return Status::OK();
}
172 
// Runs a "partial" first round to bring restarted consumers back in sync:
// only the consumers that requested `first_round_` participate. Must be
// called with `mu_` held.
Status RoundRobinTaskRunner::PreparePartialRound()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  VLOG(1) << worker_address_ << ": Starting partial round " << first_round_
          << " for " << requests_[first_round_].size() << " consumers";
  current_round_ = first_round_;
  new_round_cv_.notify_all();
  // Indicates that we need a partial round to get consumers back in sync.
  // Peek at a request from the following round: if that consumer skipped the
  // previous round, this partial round must be skipped too so all consumers
  // see a consistent skip decision.
  auto next_round_request = *(requests_[first_round_ + 1].begin()->second);
  if (next_round_request.skipped_previous_round()) {
    VLOG(1) << "Skipping partial round";
    round_skipped_ = true;
    return Status::OK();
  }
  // wait_us=-1: wait indefinitely, since a partial round cannot be skipped
  // once we've decided to serve it.
  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
  round_skipped_ = false;
  return Status::OK();
}
190 
PrepareRound(const GetElementRequest & req)191 Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
192   mutex_lock l(mu_);
193   first_round_ = std::min(first_round_, req.round_index());
194   absl::flat_hash_map<int64, const GetElementRequest*>& round =
195       requests_[req.round_index()];
196   round[req.consumer_index()] = &req;
197   auto cleanup = gtl::MakeCleanup([&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
198     requests_[req.round_index()].erase(req.consumer_index());
199   });
200   if (current_round_ < req.round_index() && round.size() == num_consumers_) {
201     current_round_ = req.round_index();
202     int64_t wait_us = kWaitBeforeSkipUs;
203     if (!req.allow_skip()) {
204       wait_us = -1;
205     }
206     TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
207   }
208   if (current_round_ < 0 &&
209       requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
210           num_consumers_) {
211     TF_RETURN_IF_ERROR(PreparePartialRound());
212   }
213   while (!cancelled_ && current_round_ < req.round_index()) {
214     TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
215     new_round_cv_.wait(l);
216   }
217   if (current_round_ < req.round_index() && cancelled_) {
218     return errors::Cancelled("Worker is shutting down.");
219   }
220   if (current_round_ != req.round_index()) {
221     return errors::FailedPrecondition(
222         "Consumer ", req.consumer_index(), " requested data for round ",
223         req.round_index(), ", but the current round has already reached ",
224         current_round_,
225         ". This may indicate that the consumer was restarted with the same job "
226         "name.`");
227   }
228   return prefetch_thread_.GetStatus();
229 }
230 
// Serves the element for `req`'s (round, consumer) slot. Blocks in
// PrepareRound() until the requested round is current, then either reports a
// skipped round or returns a deep copy of the buffered element.
Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
                                     GetElementResult& result) {
  TF_RETURN_IF_ERROR(ValidateRequest(req));
  result.end_of_sequence = false;
  VLOG(2) << worker_address_ << ": Received request from consumer index "
          << req.consumer_index() << " for round " << req.round_index();
  TF_RETURN_IF_ERROR(PrepareRound(req));
  // Shared lock: many consumers read the prepared round concurrently.
  tf_shared_lock l(mu_);
  result.skip = round_skipped_;
  if (round_skipped_) {
    VLOG(1) << worker_address_ << ": Buffer not ready, skipping round "
            << current_round_ << " for consumer " << req.consumer_index();
    return Status::OK();
  }
  auto& buffer_result = buffer_[req.consumer_index()];
  result.element_index = buffer_result->index;
  std::vector<Tensor> element;
  // Deep-copy rather than move: presumably so the buffered element stays
  // intact if the consumer retries this round — TODO confirm.
  for (auto& component : buffer_result->components) {
    element.push_back(tensor::DeepCopy(component));
  }
  if (VLOG_IS_ON(2)) {
    int64_t size = 0;
    for (auto& component : element) {
      size += component.TotalBytes();
    }
    VLOG(2) << worker_address_ << ": Returning element " << result.element_index
            << " to consumer " << req.consumer_index() << " for round "
            << req.round_index() << ". element size " << size;
  }
  result.components = std::move(element);
  return Status::OK();
}
263 
Cancel()264 void RoundRobinTaskRunner::Cancel() {
265   mutex_lock l(mu_);
266   cancelled_ = true;
267   new_round_cv_.notify_all();
268 }
269 
PrefetchThread(std::unique_ptr<TaskIterator> iterator,int64_t round_size)270 PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
271                                int64_t round_size)
272     : iterator_(std::move(iterator)), round_size_(round_size) {
273   thread_ = absl::WrapUnique(
274       Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
275 }
276 
// Signals the prefetch loop to stop. The thread itself is joined after this
// body runs, when the `thread_` member is destroyed (Env-created threads
// join on destruction).
PrefetchThread::~PrefetchThread() {
  mutex_lock l(mu_);
  cancelled_ = true;
  cv_.notify_all();
}
282 
// Prefetch loop: keeps up to `round_size_` elements in `buffer_`. Exits on
// cancellation, on an iterator error, or on end-of-sequence (which is an
// error for round-robin reads, since all consumers must advance forever).
void PrefetchThread::Run() {
  while (true) {
    {
      // Wait until there is room in the buffer (or we are cancelled).
      mutex_lock l(mu_);
      while (!cancelled_ && buffer_.size() >= round_size_) {
        cv_.wait(l);
      }
      if (cancelled_) {
        return;
      }
    }
    // GetNext is called without holding `mu_` so FillBuffer/destruction are
    // not blocked on a potentially slow iterator step.
    std::vector<Tensor> element;
    bool end_of_sequence;
    Status s = iterator_->GetNext(element, end_of_sequence);
    if (!s.ok()) {
      // Record the error and wake any FillBuffer() waiters before exiting.
      mutex_lock l(mu_);
      status_ = s;
      cv_.notify_all();
      return;
    }
    if (end_of_sequence) {
      mutex_lock l(mu_);
      status_ = errors::FailedPrecondition(
          "Encountered end of sequence on a round-robin read iterator. "
          "Please ensure that the dataset used for round-robin reading has "
          "infinite cardinality, e.g. by adding a .repeat() transformation "
          "at the end.");
      cv_.notify_all();
      return;
    }
    // Publish the element with its production index and notify consumers.
    mutex_lock l(mu_);
    buffer_.push_back(absl::make_unique<Element>(std::move(element), index_++));
    cv_.notify_all();
  }
}
318 
FillBuffer(int64_t wait_us,std::vector<std::unique_ptr<Element>> & out)319 Status PrefetchThread::FillBuffer(int64_t wait_us,
320                                   std::vector<std::unique_ptr<Element>>& out) {
321   int64_t start_us = Env::Default()->NowMicros();
322   out.clear();
323   mutex_lock l(mu_);
324   while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
325     int64_t remaining_us = start_us + wait_us - Env::Default()->NowMicros();
326     if (wait_us >= 0 && remaining_us <= 0) {
327       break;
328     }
329     cv_.wait_for(l, std::chrono::microseconds(remaining_us));
330   }
331   TF_RETURN_IF_ERROR(status_);
332   if (cancelled_) {
333     return errors::Cancelled("Prefetch thread cancelled");
334   }
335   if (buffer_.size() < round_size_) {
336     DCHECK_GE(wait_us, 0);
337     return Status::OK();
338   }
339   for (auto& elem : buffer_) {
340     out.push_back(std::move(elem));
341   }
342   buffer_.clear();
343   cv_.notify_all();
344   return Status::OK();
345 }
346 
GetStatus()347 Status PrefetchThread::GetStatus() {
348   mutex_lock l(mu_);
349   return status_;
350 }
351 }  // namespace data
352 }  // namespace tensorflow
353