/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h"

#include <atomic>
#include <deque>
#include <set>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace data {
namespace experimental {

/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kInputDataset;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kTarguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputShapes;

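// String keys used below to serialize and restore per-iterator state. Each key
// is written under this iterator's checkpoint prefix (see SaveInternal /
// RestoreInternal), so the names only need to be unique within this op.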
constexpr char kInputExhausted[] = "input_exhausted";
constexpr char kNextIndex[] = "next_index";
constexpr char kBlockCount[] = "block_count";
constexpr char kWorkersSize[] = "workers_size";
constexpr char kInterleaveSize[] = "interleave_size";
constexpr char kInterleaveIndices[] = "interleave_indices";
constexpr char kStagingSize[] = "staging_size";
constexpr char kStagingIndices[] = "staging_indices";
constexpr char kWorkerThreadsRunning[] = "worker_threads_running";
constexpr char kDataParallelInterleaveWorker[] =
    "data_parallel_interleave_worker";
constexpr char kWorker[] = "worker";
constexpr char kInputSize[] = "input_size";
constexpr char kInput[] = "input";
constexpr char kOutputsSize[] = "outputs_size";
constexpr char kOutputs[] = "outputs";
constexpr char kIsProducing[] = "is_producing";
constexpr char kWorkerThread[] = "worker_thread";
constexpr char kIteratorExhausted[] = "iterator_exhausted";
constexpr char kIteratorCreationStatus[] = "iterator_creation_status";
constexpr char kOutput[] = "output";
constexpr char kEndOfSequence[] = "end_of_sequence";
constexpr char kStatus[] = "status";
constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char kMessage[] = "msg";

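// This kernel implements the legacy parallel interleave transformation. At the
// Python level it is typically created via
// `tf.data.experimental.parallel_interleave(map_func, cycle_length,
// block_length, sloppy, buffer_output_elements, prefetch_input_elements)`
// (an assumption based on the op names registered at the bottom of this file),
// which has since been superseded by `Dataset.interleave` with
// `num_parallel_calls`.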
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
          int64 block_length, DeterminismPolicy deterministic,
          int64 buffer_output_elements, int64 prefetch_input_elements,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        deterministic_(deterministic),
        buffer_output_elements_(buffer_output_elements),
        prefetch_input_elements_(prefetch_input_elements),
        output_types_(output_types),
        output_shapes_(output_shapes),
        traceme_metadata_(
            {{"block_length",
              strings::Printf("%lld", static_cast<long long>(block_length))},
             {"cycle_length",
              strings::Printf("%lld", static_cast<long long>(cycle_length))},
             {"deterministic",
              deterministic.IsDeterministic() || deterministic.IsDefault()
                  ? "true"
                  : "false"}}),
        op_version_(op_version) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    bool deterministic =
        deterministic_.IsDeterministic() || deterministic_.IsDefault();
    return absl::make_unique<Iterator>(
        Iterator::Params{
            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
        deterministic);
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    std::vector<std::pair<size_t, Node*>> inputs;
    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
    int input_index = 0;

    Node* input_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    inputs.emplace_back(input_index++, input_node);

    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    list_inputs.emplace_back(input_index++, other_arguments);

    Node* cycle_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    inputs.emplace_back(input_index++, cycle_length_node);

    Node* block_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    inputs.emplace_back(input_index++, block_length_node);

    if (op_version_ == 1) {
      Node* sloppy_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(deterministic_.IsNondeterministic(), &sloppy_node));
      inputs.emplace_back(input_index++, sloppy_node);
    }

    Node* buffer_output_elements_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
    inputs.emplace_back(input_index++, buffer_output_elements_node);

    Node* prefetch_input_elements_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
    inputs.emplace_back(input_index++, prefetch_input_elements_node);

    std::vector<std::pair<StringPiece, AttrValue>> attrs;

    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    attrs.emplace_back(kFunc, f);

    if (op_version_ == 2) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
    return Status::OK();
  }

 private:
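  // One worker thread is created per interleaved iterator (`cycle_length_`)
  // plus one per prefetched/staging iterator (`prefetch_input_elements_`).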
  int64 num_threads() const { return cycle_length_ + prefetch_input_elements_; }

  // Parallel interleave's implementation is designed around a few principles:
  //  1. Thread creation is relatively expensive. (Not reusing
  //     threads causes a number of indirect costs such as poorer tcmalloc
  //     performance due to thread-local caches, etc.) We allocate a fixed
  //     number of threads at the start and never change that number. This is
  //     why we've fused functionality that is theoretically orthogonal
  //     (e.g. .prefetch()) into the implementation.
  //  2. Drop-in replacement for standard interleave. The goal is to
  //     transparently switch users to an optimized implementation without any
  //     work on their part. We therefore take great pains to maintain
  //     identical iteration order and full determinism (which can be disabled
  //     only via a flag).
  //  3. Performance across a variety of environments and I/O envelopes.
  //
  // The actual implementation centers around a collection of worker threads
  // and their corresponding worker state (tracked in the `workers_` vector).
  // Worker threads repeatedly receive a vector of Tensors that are used as
  // input to the flat-map function (`captured_func_`). The output of this
  // function must be a dataset. The worker thread then repeatedly calls
  // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
  // that a caller will block waiting for an element to be produced.
  //
  // Pointers to these worker states are kept in 2 disjoint data structures:
  //  1. `interleave_indices_` is a vector containing indices of WorkerStates
  //     in `workers_` that we are interleaving. Worker threads backing these
  //     WorkerStates should be regularly producing values.
  //  2. `staging_indices_` is a deque containing indices of WorkerStates in
  //     `workers_` that we will move to `interleave_indices_` when an
  //     iterator in `interleave_indices_` is exhausted.
  //
  // The client calls `GetNext[Internal]()` to retrieve an output element. The
  // internal implementation updates the state of `interleave_indices_` and
  // `staging_indices_` as output iterators (run by the worker threads) are
  // exhausted.
  //
  // `input_impl_` is the input iterator that generates arguments for the
  // flat-map function (`captured_func_`). It is set to an iterator at
  // Iterator construction, and is fixed until we consume all input elements.
  // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
  // memory.
  //
  // A few invariants are maintained:
  //  1. No element in `interleave_indices_` should be -1 unless
  //     `staging_indices_` is empty and `input_impl_` is empty.
  //  2. Every `workers_` element is pointed to by at most one element of the
  //     union of `interleave_indices_` and `staging_indices_`.
  //  3. Unless `input_impl_` is empty, every `workers_` element must be
  //     pointed to by an element in `interleave_indices_` or
  //     `staging_indices_`.
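  //
  // For intuition (an illustrative example, not taken from the original
  // comments): with cycle_length = 2, block_length = 1, inputs A, B, C, ...
  // and mapped datasets producing A1, A2, ..., B1, B2, ..., the deterministic
  // output order is A1, B1, A2, B2, ..., identical to the sequential
  // interleave, while sloppy mode may reorder elements whenever the next
  // deterministic element is not yet available.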
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
          deterministic_(deterministic),
          workers_(dataset()->num_threads()),
          worker_thread_states_(dataset()->num_threads()) {}

    ~Iterator() override { CancelThreads(); }

    Status Initialize(IteratorContext* ctx) override {
      // TODO(jsimsa): Register cancellation callback once the implementation is
      // refactored not to hold mu_ while calling `GetNext` on the input.
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

    // It is implemented so that it matches the deterministic interleave
    // unless getting the next element would block and we are allowed to be
    // nondeterministic.
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
      while (!cancelled_) {
        // Wait for an item to become available, blocking if necessary. If we
        // are allowed to be nondeterministic, we can skip over input datasets
        // that do not have an item readily available.
        bool can_produce_elements = false;
        bool must_wait_for_input = true;
        for (int64 i = 0; i < interleave_indices_.size(); ++i) {
          int64 index = (next_index_ + i) % interleave_indices_.size();
          int64 current_worker_index = interleave_indices_[index];
          if (current_worker_index < 0) {
            continue;  // Empty interleave elements.
          }
          WorkerState* current_worker = &workers_[current_worker_index];
          can_produce_elements |= current_worker->MayHaveElements();
          if (!current_worker->outputs.empty()) {
            // We have an element!
            next_index_ = index;
            const bool element_acquired_sloppily = !deterministic_ && i > 1;
            if (!element_acquired_sloppily) {
              // If the element was acquired in the regular (deterministic)
              // order, then advance the current block and cycle pointers to
              // the next element in the regular order.
              block_count_++;
              if (block_count_ == dataset()->block_length_) {
                next_index_ = (index + 1) % interleave_indices_.size();
                block_count_ = 0;
              }
            } else {
              block_count_ = 0;
            }
            *end_of_sequence = false;
            Status s = current_worker->outputs.front().status;
            profiler::TraceMe traceme([&] {
              return profiler::TraceMeEncode(
                  "ParallelInterleaveConsume",
                  {{"element_id", current_worker->outputs.front().id}});
            });
            current_worker->outputs.front().output.swap(*out_tensors);
            current_worker->outputs.pop_front();
            current_worker->cond_var.notify_one();
            return s;
          } else if (current_worker->is_producing && deterministic_) {
            // current_worker.outputs.empty(), and we must wait for this
            // iterator.
            if (next_index_ != index) {
              // We have advanced to a new iterator; reset block counts.
              next_index_ = index;
              block_count_ = 0;
            }
            break;
          } else if (!current_worker->is_producing) {
            // This iterator has reached end of input.
            interleave_indices_[index] = -1;
            if (input_impl_) {
              // Start prefetching a new iterator.
              std::vector<Tensor> args;
              bool end_of_input = false;
              Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
              if (end_of_input) {
                input_impl_.reset();
              } else {
                current_worker->SetInputs(s, std::move(args));
                staging_indices_.emplace_back(current_worker_index);
              }
            }

            if (!staging_indices_.empty()) {
              // Move a worker from `staging_indices_` to
              // `interleave_indices_`.
              interleave_indices_[index] = staging_indices_.front();
              staging_indices_.pop_front();

              next_index_ = (index + 1) % interleave_indices_.size();
              block_count_ = 0;
              // Restart the inner [for] loop
              can_produce_elements = true;
              must_wait_for_input = false;
              break;
            }
          }
        }

        if (!can_produce_elements && !input_impl_) {
          // No potential for future values.
          *end_of_sequence = true;
          return Status::OK();
        }

        if (must_wait_for_input) {
          // Wait for elements to become available.
          RecordStop(ctx);
          if (deterministic_) {
            workers_[interleave_indices_[next_index_]].cond_var.wait(l);
          } else {
            any_element_available_cond_var_.wait(l);
          }
          RecordStart(ctx);
        }
      }
      return errors::Cancelled(
          "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
    }

   protected:
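    // Represents this iterator in the tf.data performance model as an
    // asynchronous "interleave many" node so that autotuning can account for
    // the parallelism introduced by the worker threads.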
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncInterleaveManyNode(std::move(args),
                                                /*parameters=*/{});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      // The order of locking is important here to avoid deadlock.
      mutex_lock l(mu_);
      mutex_lock ckpt_l(ckpt_mu_);
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      } else {
        TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInputExhausted, ""));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kNextIndex, next_index_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBlockCount, block_count_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kWorkersSize, workers_.size()));
      for (int i = 0; i < workers_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
      }
      for (int i = 0; i < worker_thread_states_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(ctx, writer, i));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInterleaveSize,
                                             interleave_indices_.size()));
      for (int i = 0; i < interleave_indices_.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            prefix(), strings::StrCat(kInterleaveIndices, "_", i),
            interleave_indices_[i]));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kStagingSize, staging_indices_.size()));
      for (int i = 0; i < staging_indices_.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            prefix(), strings::StrCat(kStagingIndices, "_", i),
            staging_indices_[i]));
      }
      if (!worker_threads_.empty()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(prefix(), kWorkerThreadsRunning, ""));
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      {
        // The order of locking is important here to avoid deadlock.
        mutex_lock l(mu_);
        mutex_lock ckpt_l(ckpt_mu_);
        if (!reader->Contains(prefix(), kInputExhausted)) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kNextIndex, &temp));
        next_index_ = size_t(temp);
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBlockCount, &temp));
        block_count_ = size_t(temp);

        // Restore WorkerStates.
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kWorkersSize, &temp));
        if (temp != dataset()->num_threads()) {
          return errors::Internal("Expected ", dataset()->num_threads(),
                                  " worker states but found ", temp, ".");
        }
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
        }
      }
      std::unique_ptr<thread::ThreadPool> threadpool = ctx->CreateThreadPool(
          "read_worker_thread_state", dataset()->num_threads());
      Status s = Status::OK();
      BlockingCounter counter(dataset()->num_threads());
      for (size_t i = 0; i < dataset()->num_threads(); ++i) {
        threadpool->Schedule([this, i, ctx, reader, &s, &counter] {
          WorkerThreadState state;
          Status result = ReadWorkerThreadStateLocked(reader, i, ctx, &state);
          mutex_lock l(mu_);
          mutex_lock ckpt_l(ckpt_mu_);
          if (!result.ok()) {
            s.Update(result);
            counter.DecrementCount();
            return;
          }
          worker_thread_states_[i] = std::move(state);
          counter.DecrementCount();
        });
      }
      counter.Wait();
      if (!s.ok()) {
        return s;
      }

      mutex_lock l(mu_);
      mutex_lock ckpt_l(ckpt_mu_);
      // Restore `interleave_indices_`.
      std::set<int64> all_indices;
      {
        int64 interleave_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kInterleaveSize, &interleave_size));
        interleave_indices_.reserve(interleave_size);
        for (int64 i = 0; i < interleave_size; ++i) {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              prefix(), strings::StrCat(kInterleaveIndices, "_", i), &temp));
          if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
            return errors::Internal(
                "Duplicate entry for ", temp,
                " found when reading interleave and staging indices.");
          }
          if (temp >= 0) {
            all_indices.insert(temp);
          }
          interleave_indices_.emplace_back(temp);
        }
      }

      // Restore `staging_indices_`.
      {
        int64 staging_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kStagingSize, &staging_size));
        for (int i = 0; i < staging_size; ++i) {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              prefix(), strings::StrCat(kStagingIndices, "_", i), &temp));
          if (all_indices.find(temp) != all_indices.end()) {
            return errors::Internal(
                "Duplicate entry for ", temp,
                " found when reading interleave and staging indices.");
          }
          if (temp >= 0) {
            all_indices.insert(temp);
          }
          staging_indices_.emplace_back(temp);
        }
      }

      // Start Worker threads.
      if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
        worker_threads_.reserve(dataset()->num_threads());
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
          worker_threads_.emplace_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
        }
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    // OutputElem contains the information from a call to GetNext by an output
    // iterator.
    struct OutputElem {
      // The output iterator sets `status` if getting the output element
      // fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> output;
      int64 id = -1;

      explicit OutputElem(const Status& s) : status(s) {}
      OutputElem(const Status& s, int64 id) : status(s), id(id) {}
    };

    // Worker threads operate on their relevant WorkerState structs.
    //
    // WorkerState's fields are all protected by mu_.
    struct WorkerState {
      // The arguments to be used to construct an output iterator.
      std::vector<Tensor> input;
      // The buffered output elements.
      std::deque<OutputElem> outputs;
      // Set to true iff the worker thread expects to append more elements to
      // outputs. is_producing can be false despite !outputs.empty().
      // Concretely, all output elements will have been consumed only when:
      // is_producing == false && outputs.empty().
      bool is_producing = false;
      // Condition variable used to coordinate between threads. The worker
      // thread waits on this condition variable when it is either (1) waiting
      // for the main thread to add arguments to `input`, or (2) waiting for
      // the main thread to consume an element of `outputs`. The main thread
      // waits on cond_var if it is waiting for the worker thread to produce
      // an element into `outputs` (this implies deterministic == true).
      condition_variable cond_var;

      inline bool MayHaveElements() const {
        return is_producing || !outputs.empty();
      }

      // Sets inputs for a worker thread and notifies it to start processing.
      void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
        if (s.ok()) {
          DCHECK(!MayHaveElements())
              << "Tried to start inputs, despite already producing!";
          input = std::move(input_arguments);
          is_producing = true;
          cond_var.notify_one();
        } else {
          outputs.emplace_back(s);
        }
      }
    };

    // The internal state of a worker thread that is not already captured
    // in its `WorkerState`.
    //
    // This is needed only for checkpointing purposes. We keep this
    // separate from `WorkerState` and guard its fields using a separate
    // lock `ckpt_mu_` so as not to affect the performance of the main
    // pipeline.
    struct WorkerThreadState {
      // The output element that has been produced from the input iterator
      // and is waiting to be added to `WorkerState.outputs`.
      OutputElem output_elem;

      // Whether the input iterator returned an `end_of_sequence`.
      bool end_of_sequence = false;

      // Status returned from `MakeIteratorFromInputElement`.
      Status iterator_creation_status;

      // The arguments to be used to construct `iterator`.
      std::vector<Tensor> input;

      std::unique_ptr<IteratorBase> iterator;

      WorkerThreadState() : output_elem(Status::OK()) {}
    };

    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      cancelled_ = true;
      for (auto& worker : workers_) {
        worker.cond_var.notify_all();
      }
    }

    Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (worker_threads_.empty() && input_impl_) {
        worker_threads_.reserve(dataset()->num_threads());
        for (int64 i = 0; i < dataset()->num_threads(); ++i) {
          std::vector<Tensor> args;
          bool end_of_input = false;
          Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
          if (end_of_input) {
            input_impl_.reset();
            return Status::OK();
          }
          workers_[i].SetInputs(s, std::move(args));
          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
          worker_threads_.push_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
          if (i < dataset()->cycle_length_) {
            interleave_indices_.push_back(i);
          } else {
            staging_indices_.push_back(i);
          }
        }
        DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
        DCHECK(staging_indices_.size() == dataset()->prefetch_input_elements_);
      }
      return Status::OK();
    }

    // Produces elements into the worker's output buffers.
    void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
                      const int64 thread_index) {
      // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
      //
      // 1. Any local state that may need to be checkpointed should be kept
      //    in `worker_thread_states_[thread_index]`.
      // 2. `WorkerThreadState` should contain state that is needed only for
      //    checkpointing, i.e., if we were to remove checkpointing support,
      //    we could keep that state as local variables in this thread.
      // 3. This thread should only read/write state at `thread_index`
      //    and should not access other thread states.
      // 4. When restoring from checkpoint, threads are started only after
      //    the restore is complete.
      // 5. Once restored from a checkpoint, the local state is edited only
      //    by this thread. 3 & 4 allow making assumptions like temporarily
      //    caching local state in this thread and using it outside a lock
      //    e.g. `make_new_iterator`.
      // 6. `ckpt_mu_` should be wisely used to create *consistent*
      //    checkpoint markers.

      // std::function arguments are copy-constructable, so we pass raw
      // pointers, and then immediately wrap them to ensure correct ownership.
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
        mutex_lock l(mu_);
        workers_[thread_index].cond_var.notify_all();
        RecordStop(ctx.get());
      });
      bool make_new_iterator;
      {
        tf_shared_lock l(ckpt_mu_);
        // Decide whether a new iterator should be built.
        // 1. If there is an existing iterator, we use it.
        // 2. If there was an error in iterator creation that could not be
        //    notified to the client, we attempt to send that error to the
        //    client first.
        make_new_iterator =
            worker_thread_states_[thread_index].iterator == nullptr &&
            worker_thread_states_[thread_index].iterator_creation_status.ok();
      }
      // Even though `make_new_iterator` has cached values from
      // `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
      // it is safe to *read* `make_new_iterator` outside of a lock without
      // worrying about concurrent changes to values in
      // `worker_thread_states_[thread_index]`. See comment at the start of
      // this function for details.
      while (true) {
        // Whether creation of the iterator succeeded.
        Status iterator_creation_status;
        // 1. Build a new iterator or use the existing one.
        if (make_new_iterator) {
          // 1a. Get new input tensors or use the existing ones.
          bool read_new_input;
          {
            tf_shared_lock l(ckpt_mu_);
            // worker_thread_states_[thread_index].input will be non-empty
            // if checkpointing happened at CHECKPOINT_MARKER_A.
            read_new_input = worker_thread_states_[thread_index].input.empty();
          }

          if (read_new_input) {
            mutex_lock l(mu_);
            while (!cancelled_ && !workers_[thread_index].is_producing) {
              RecordStop(ctx.get());
              workers_[thread_index].cond_var.wait(l);
              RecordStart(ctx.get());
            }
            if (cancelled_) return;
            // Copy the input tensors so that we do not need to block on `mu_`
            // when building the iterator.
            // We keep a copy of the input tensors in
            // `WorkerThreadState.input` till the iterator is in use. This is
            // used in `RestoreInternal` to re-build the iterator.
            // TODO(b/78046638): Explore ways to avoid tracking the input
            // tensors.
            tf_shared_lock ckpt_l(ckpt_mu_);
            worker_thread_states_[thread_index].input.swap(
                workers_[thread_index].input);
            // CHECKPOINT_MARKER_A
            // We have the input tensors but have not built the iterator yet.
          }

          // 1b. Run the user defined function to produce a new iterator.
          {
            tf_shared_lock l(ckpt_mu_);
            worker_thread_states_[thread_index].iterator_creation_status =
                MakeIteratorFromInputElement(
                    ctx.get(), this, worker_thread_states_[thread_index].input,
                    thread_index, *instantiated_captured_func_, prefix(),
                    &worker_thread_states_[thread_index].iterator,
                    model_node());
            iterator_creation_status =
                worker_thread_states_[thread_index].iterator_creation_status;
            if (!iterator_creation_status.ok()) {
              worker_thread_states_[thread_index].input.clear();
            }
            // CHECKPOINT_MARKER_B
            // Either an iterator has been successfully built and placed in
            // `worker_thread_states_[thread_index].iterator` or it failed and
            // a non-OK status has been put in
            // `worker_thread_states_[thread_index].iterator_creation_status`.
          }
        } else {
          tf_shared_lock l(ckpt_mu_);
          iterator_creation_status =
              worker_thread_states_[thread_index].iterator_creation_status;
          // Mark that we have used up the restored iterator.
          make_new_iterator = true;
        }
        // 2. Start producing elements or send error state to client if
        //    iterator creation failed.
        if (!iterator_creation_status.ok()) {
          mutex_lock l(mu_);
          // Wait for space in the prefetch queue.
          while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                    dataset()->buffer_output_elements_) {
            RecordStop(ctx.get());
            workers_[thread_index].cond_var.wait(l);
            RecordStart(ctx.get());
          }
          if (cancelled_) return;
          tf_shared_lock ckpt_l(ckpt_mu_);
          workers_[thread_index].outputs.emplace_back(iterator_creation_status);
          workers_[thread_index].is_producing = false;
          worker_thread_states_[thread_index].iterator_creation_status =
              Status::OK();
          // CHECKPOINT_MARKER_C
          // Non-OK iterator creation status has been notified to the
          // client.
          if (deterministic_) {
            workers_[thread_index].cond_var.notify_one();
          } else {
            any_element_available_cond_var_.notify_one();
          }
        } else {
          bool end_of_sequence = false;
          while (!end_of_sequence) {
            // 3.a Produce an element!
            {
              tf_shared_lock ckpt_l(ckpt_mu_);
              if (worker_thread_states_[thread_index].output_elem.status.ok() &&
                  worker_thread_states_[thread_index]
                      .output_elem.output.empty() &&
                  !worker_thread_states_[thread_index].end_of_sequence) {
                int64& id = worker_thread_states_[thread_index].output_elem.id;
                profiler::TraceMe traceme(
                    [&] {
                      id = profiler::TraceMe::NewActivityId();
                      return profiler::TraceMeEncode(
                          "ParallelInterleaveProduce", {{"element_id", id}});
                    },
                    profiler::kInfo);
                worker_thread_states_[thread_index].output_elem.status =
                    worker_thread_states_[thread_index].iterator->GetNext(
                        ctx.get(),
                        &worker_thread_states_[thread_index].output_elem.output,
                        &worker_thread_states_[thread_index].end_of_sequence);
                end_of_sequence =
                    worker_thread_states_[thread_index].end_of_sequence;
              } else {
                end_of_sequence =
                    worker_thread_states_[thread_index].end_of_sequence;
              }
              // CHECKPOINT_MARKER_D
              // An element has been read or an error or end_of_sequence has
              // been received from the input iterator and is waiting to be
              // sent to client.
            }

            // 3.b Make it available to the client.
            {
              mutex_lock l(mu_);

              // Wait for space in the prefetch queue.
              while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                        dataset()->buffer_output_elements_) {
                RecordStop(ctx.get());
                workers_[thread_index].cond_var.wait(l);
                RecordStart(ctx.get());
              }
              if (cancelled_) return;

              tf_shared_lock ckpt_l(ckpt_mu_);
              workers_[thread_index].is_producing = !end_of_sequence;

              // Output the element.

              // Move the temporary state in WorkerThreadState to WorkerState
              // and mark it as used.
              if (end_of_sequence) {
                worker_thread_states_[thread_index].iterator.reset();
                worker_thread_states_[thread_index].input.clear();
                worker_thread_states_[thread_index].end_of_sequence = false;
              } else {
                workers_[thread_index].outputs.emplace_back(
                    worker_thread_states_[thread_index].output_elem.status,
                    worker_thread_states_[thread_index].output_elem.id);
                workers_[thread_index].outputs.back().output.swap(
                    worker_thread_states_[thread_index].output_elem.output);
              }
              worker_thread_states_[thread_index].output_elem.status =
                  Status::OK();
              if (deterministic_) {
                workers_[thread_index].cond_var.notify_one();
              } else {
                any_element_available_cond_var_.notify_one();
              }
              // CHECKPOINT_MARKER_E
              // Output element or iterator status has been sent to the
              // client.
            }
          }
        }
      }
    }

    Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string iterator_name =
          strings::StrCat(prefix(), "::", kWorker, "_", index);
      TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kInputSize,
                                             workers_[index].input.size()));
      for (int i = 0; i < workers_[index].input.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(iterator_name,
                                               strings::StrCat(kInput, "_", i),
                                               workers_[index].input[i]));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kOutputsSize,
                                             workers_[index].outputs.size()));
      for (int i = 0; i < workers_[index].outputs.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteOutputElemLocked(
            writer, workers_[index].outputs[i], iterator_name,
            strings::StrCat(kOutputs, "_", i)));
      }
      if (workers_[index].is_producing) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kIsProducing, ""));
      }
      return Status::OK();
    }

    Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
                                 IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string worker_prefix =
          strings::StrCat(prefix(), "::", kWorker, "_", index);
      // Restore inputs.
      int64 input_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kInputSize, &input_size));
      workers_[index].input.reserve(input_size);
      for (int i = 0; i < input_size; ++i) {
        workers_[index].input.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(worker_prefix,
                                              strings::StrCat(kInput, "_", i),
                                              &workers_[index].input.back()));
      }
      int64 outputs_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kOutputsSize, &outputs_size));
      for (int i = 0; i < outputs_size; ++i) {
        workers_[index].outputs.emplace_back(Status::OK());
        TF_RETURN_IF_ERROR(ReadOutputElemLocked(
            reader, &workers_[index].outputs.back(), worker_prefix,
            strings::StrCat(kOutputs, "_", i)));
      }
      if (reader->Contains(worker_prefix, kIsProducing)) {
        workers_[index].is_producing = true;
      } else {
        workers_[index].is_producing = false;
      }
      return Status::OK();
    }

    Status WriteWorkerThreadStateLocked(SerializationContext* ctx,
                                        IteratorStateWriter* writer, int index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string iterator_name =
          strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
      if (worker_thread_states_[index].iterator != nullptr) {
        TF_RETURN_IF_ERROR(
            SaveInput(ctx, writer, worker_thread_states_[index].iterator));
      } else {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kIteratorExhausted, ""));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(iterator_name, kInputSize,
                              worker_thread_states_[index].input.size()));
      for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
        TF_RETURN_IF_ERROR(
            writer->WriteTensor(iterator_name, strings::StrCat(kInput, "_", i),
                                worker_thread_states_[index].input[i]));
      }
      TF_RETURN_IF_ERROR(WriteStatusLocked(
          writer, iterator_name, kIteratorCreationStatus,
          worker_thread_states_[index].iterator_creation_status));
      TF_RETURN_IF_ERROR(WriteOutputElemLocked(
          writer, worker_thread_states_[index].output_elem, iterator_name,
          kOutput));
      if (worker_thread_states_[index].end_of_sequence) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kEndOfSequence, ""));
      }
      return Status::OK();
    }

    Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
                                       IteratorContext* ctx,
                                       WorkerThreadState* state) {
      string worker_prefix =
          strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
      // Restore inputs.
      int64 input_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kInputSize, &input_size));
      state->input.reserve(input_size);
      for (int i = 0; i < input_size; ++i) {
        state->input.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(worker_prefix,
                                              strings::StrCat(kInput, "_", i),
                                              &state->input.back()));
      }
      // Restore iterator
      if (reader->Contains(worker_prefix, kIteratorExhausted)) {
        state->iterator.reset();
      } else {
        std::unique_ptr<IteratorBase> iterator;
        // NOTE: We intentionally ignore resource modeling outside GetNext().
        TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
            ctx, this, state->input, index, *instantiated_captured_func_,
            prefix(), &iterator, /*node=*/nullptr));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
        state->iterator.swap(iterator);
      }
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, worker_prefix,
                                          kIteratorCreationStatus,
                                          &state->iterator_creation_status));
      TF_RETURN_IF_ERROR(ReadOutputElemLocked(reader, &state->output_elem,
                                              worker_prefix, kOutput));
      if (reader->Contains(worker_prefix, kEndOfSequence)) {
        state->end_of_sequence = true;
      } else {
        state->end_of_sequence = false;
      }
      return Status::OK();
    }

    Status WriteOutputElemLocked(IteratorStateWriter* writer,
                                 const OutputElem& output_elem,
                                 const string& iterator_name,
                                 const string& prefix)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      TF_RETURN_IF_ERROR(WriteStatusLocked(
          writer, iterator_name, strings::StrCat(prefix, "_", kStatus),
          output_elem.status));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, strings::StrCat(prefix, "_", kOutputSize),
          output_elem.output.size()));
      for (int i = 0; i < output_elem.output.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            iterator_name, strings::StrCat(prefix, "_", kOutput, "_", i),
            output_elem.output[i]));
      }
      return Status::OK();
    }

    Status ReadOutputElemLocked(IteratorStateReader* reader,
                                OutputElem* output_elem,
                                const string& iterator_name,
                                const string& prefix) {
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, iterator_name,
                                          strings::StrCat(prefix, "_", kStatus),
                                          &output_elem->status));
      int64 output_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          iterator_name, strings::StrCat(prefix, "_", kOutputSize),
          &output_size));
      output_elem->output.reserve(output_size);
      for (int i = 0; i < output_size; ++i) {
        output_elem->output.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            iterator_name, strings::StrCat(prefix, "_", kOutput, "_", i),
            &output_elem->output.back()));
      }
      return Status::OK();
    }

    Status WriteStatusLocked(IteratorStateWriter* writer,
                             const string& iterator_name, const string& prefix,
                             const Status& status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, strings::StrCat(prefix, "_", kCode),
          static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            iterator_name, strings::StrCat(prefix, "_", kMessage),
            status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatusLocked(IteratorStateReader* reader,
                            const string& iterator_name, const string& prefix,
                            Status* status) {
      int64 code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          iterator_name, strings::StrCat(prefix, "_", kCode), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            iterator_name, strings::StrCat(prefix, "_", kMessage),
            &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
      }
      return Status::OK();
    }

    // Mutex & condition variable to guard mutable iterator internals and
    // coordinate among worker threads and client thread[s].
    mutex mu_ TF_ACQUIRED_BEFORE(ckpt_mu_);
    // The main thread waits on this condition variable if running in
    // nondeterministic mode and no values are available.
    condition_variable any_element_available_cond_var_;
    // Whether outputs must be produced in deterministic order.
    const bool deterministic_;
    // Mutex used to wait for a consistent state while checkpointing.
    // Only Save and Restore require an exclusive lock on this mutex. In
    // other scenarios we just acquire a shared lock so the pipeline's
    // performance should not be affected in the absence of checkpointing.
    // A thread must not wait on any condition variable while holding
    // `ckpt_mu_` in either shared or exclusive modes.
    mutex ckpt_mu_;

    // The iterator producing elements which are converted to datasets by
    // `dataset()->captured_func_` and then interleaved together.
    // `input_impl_` is reset when we have exhausted its input.
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

    // The WorkerState structs the worker threads operate on.
    // `workers_` elements are in at most one of `interleave_indices_` and
    // `staging_indices_`.
    std::vector<WorkerState> workers_ TF_GUARDED_BY(mu_);

    // Stores the temporary state of WorkerThreads which is not stored in
    // WorkerState. This is used for checkpointing purposes only.
    std::vector<WorkerThreadState> worker_thread_states_
        TF_GUARDED_BY(ckpt_mu_);

    // Indices in `workers_` of iterators to interleave.
    std::vector<int64> interleave_indices_ TF_GUARDED_BY(mu_);
    // Indices in `workers_` of prefetched iterators.
    std::deque<int64> staging_indices_ TF_GUARDED_BY(mu_);

    // The index into `interleave_indices_` of the next iterator to pull an
    // element from.
    size_t next_index_ TF_GUARDED_BY(mu_) = 0;
    // The number of items produced so far within the current block.
    size_t block_count_ TF_GUARDED_BY(mu_) = 0;
    // Flag to instruct the worker threads to exit.
    bool cancelled_ TF_GUARDED_BY(mu_) = false;
    // The worker threads. This must be last to ensure the
    // threads have exited before any other members are deallocated.
    // TODO(b/65178177): Avoid allocating additional threads.
    std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);
  };

  const DatasetBase* const input_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int64 cycle_length_;
  const int64 block_length_;
  const DeterminismPolicy deterministic_;
  const int64 buffer_output_elements_;
  const int64 prefetch_input_elements_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
  const int op_version_;
};

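// Op version 2 (presumably the `LegacyParallelInterleaveDatasetV2` op
// registered below) exposes a `deterministic` attr, whereas version 1 takes a
// `sloppy` input tensor instead; the constructor distinguishes the two by
// checking whether the `deterministic` attr is present.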
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->HasAttr(kDeterministic) ? 2 : 1) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  if (op_version_ == 2) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase* input,
                                              DatasetBase** output) {
  int64 cycle_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
  OP_REQUIRES(ctx, cycle_length > 0,
              errors::InvalidArgument("`cycle_length` must be > 0"));

  int64 block_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));

  if (op_version_ == 1) {
    bool sloppy = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kDeterministic);
    }
  }

  int64 buffer_output_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
                                          &buffer_output_elements));
  OP_REQUIRES(ctx, buffer_output_elements > 0,
              errors::InvalidArgument("`buffer_output_elements` must be > 0"));

  int64 prefetch_input_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
                                          &prefetch_input_elements));
  OP_REQUIRES(
      ctx, prefetch_input_elements >= 0,
      errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
                        block_length, deterministic_, buffer_output_elements,
                        prefetch_input_elements, output_types_, output_shapes_,
                        op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("LegacyParallelInterleaveDatasetV2").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("LegacyParallelInterleaveDatasetV2");

}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow