/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.h"

#include <atomic>
#include <deque>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace data {
namespace experimental {

/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kInputDataset;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kTarguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputShapes;

constexpr char kInputExhausted[] = "input_exhausted";
constexpr char kNextIndex[] = "next_index";
constexpr char kBlockCount[] = "block_count";
constexpr char kWorkersSize[] = "workers_size";
constexpr char kInterleaveSize[] = "interleave_size";
constexpr char kInterleaveIndices[] = "interleave_indices";
constexpr char kStagingSize[] = "staging_size";
constexpr char kStagingIndices[] = "staging_indices";
constexpr char kWorkerThreadsRunning[] = "worker_threads_running";
constexpr char kDataParallelInterleaveWorker[] =
    "data_parallel_interleave_worker";
constexpr char kWorker[] = "worker";
constexpr char kInputSize[] = "input_size";
constexpr char kInput[] = "input";
constexpr char kOutputsSize[] = "outputs_size";
constexpr char kOutputs[] = "outputs";
constexpr char kIsProducing[] = "is_producing";
constexpr char kWorkerThread[] = "worker_thread";
constexpr char kIteratorExhausted[] = "iterator_exhausted";
constexpr char kIteratorCreationStatus[] = "iterator_creation_status";
constexpr char kOutput[] = "output";
constexpr char kEndOfSequence[] = "end_of_sequence";
constexpr char kStatus[] = "status";
constexpr char kOutputSize[] = "output_size";
constexpr char kCode[] = "code";
constexpr char KMessage[] = "msg";

class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
          int64 block_length, DeterminismPolicy deterministic,
          int64 buffer_output_elements, int64 prefetch_input_elements,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        deterministic_(deterministic),
        buffer_output_elements_(buffer_output_elements),
        prefetch_input_elements_(prefetch_input_elements),
        output_types_(output_types),
        output_shapes_(output_shapes),
        traceme_metadata_(
            {{"block_length",
              strings::Printf("%lld", static_cast<long long>(block_length))},
             {"cycle_length",
              strings::Printf("%lld", static_cast<long long>(cycle_length))},
             {"deterministic",
              deterministic.IsDeterministic() || deterministic.IsDefault()
                  ? "true"
                  : "false"}}),
        op_version_(op_version) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    bool deterministic =
        deterministic_.IsDeterministic() || deterministic_.IsDefault();
    return absl::make_unique<Iterator>(
        Iterator::Params{
            this, name_utils::IteratorPrefix(kDatasetType, prefix, params)},
        deterministic);
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    std::vector<std::pair<size_t, Node*>> inputs;
    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
    int input_index = 0;

    Node* input_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    inputs.emplace_back(input_index++, input_node);

    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    list_inputs.emplace_back(input_index++, other_arguments);

    Node* cycle_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    inputs.emplace_back(input_index++, cycle_length_node);

    Node* block_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    inputs.emplace_back(input_index++, block_length_node);

    if (op_version_ == 1) {
      Node* sloppy_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(deterministic_.IsNondeterministic(), &sloppy_node));
      inputs.emplace_back(input_index++, sloppy_node);
    }

    Node* buffer_output_elements_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
    inputs.emplace_back(input_index++, buffer_output_elements_node);

    Node* prefetch_input_elements_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(prefetch_input_elements_, &prefetch_input_elements_node));
    inputs.emplace_back(input_index++, prefetch_input_elements_node);

    std::vector<std::pair<StringPiece, AttrValue>> attrs;

    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    attrs.emplace_back(kFunc, f);

    if (op_version_ == 2) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
    return Status::OK();
  }

 private:
  int64 num_threads() const { return cycle_length_ + prefetch_input_elements_; }

  // Parallel interleave's implementation is designed around a few principles:
  //  1. Thread creation is relatively expensive. (Not reusing
  //     threads causes a number of indirect costs such as poorer tcmalloc
  //     performance due to thread-local caches, etc.) We allocate a fixed
  //     number of threads at the start and never change it. This is why we've
  //     fused functionality that is theoretically orthogonal (i.e.
  //     .prefetch()) into the implementation.
  //  2. Drop-in replacement for standard interleave. The goal is to
  //     auto-opt people into an optimized implementation without any work
  //     on the customer's part. We thus take great pains to maintain
  //     identical iteration orders and full determinism (unless explicitly
  //     disabled via a flag).
  //  3. Performance across a variety of environments and I/O envelopes.
  //
  // The actual implementation centers around a collection of worker threads
  // and their corresponding worker state (tracked in the `workers_` vector).
  // Worker threads repeatedly receive a vector of Tensors that are used as
  // input to the flat-map function (`captured_func_`). The output of this
  // function must be a dataset. The worker thread then repeatedly calls
  // `GetNext()`, maintaining a buffer of elements to minimize the likelihood
  // that a caller will block waiting for an element to be produced.
  //
  // Pointers to these worker states are kept in 2 disjoint data structures:
  //  1. `interleave_indices_` is a vector containing indices of WorkerStates
  //     in `workers_` that we are interleaving. Worker threads backing these
  //     WorkerStates should be regularly producing values.
  //  2. `staging_indices_` is a deque containing indices of WorkerStates in
  //     `workers_` that we will move to `interleave_indices_` when an
  //     iterator in `interleave_indices_` is exhausted.
  //
  // The client calls `GetNext[Internal]()` to retrieve an output element. The
  // internal implementation updates the state of `interleave_indices_` and
  // `staging_indices_` as output iterators (run by the worker threads) are
  // exhausted.
  //
  // `input_impl_` is the input iterator that generates arguments for the
  // flat-map function (`captured_func_`). It is set to an iterator at
  // Iterator construction, and is fixed until we consume all input elements.
  // Once it is exhausted, we reset the unique_ptr to eagerly deallocate
  // memory.
  //
  // A few invariants are maintained:
  //  1. No element in interleave_indices_ should be a -1 unless
  //     `staging_indices_` is empty and `input_impl_` is empty.
  //  2. Every `workers_` element is pointed to by at most one element of the
  //     union of `interleave_indices_` and `staging_indices_`.
  //  3. Unless `input_impl_` is empty, every `workers_` element must be
  //     pointed to by an element in `interleave_indices_` or
  //     `staging_indices_`.
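  //
  // As a concrete illustration (hypothetical inputs, not taken from a test):
  // suppose cycle_length_ = 2, block_length_ = 2, and the flat-map function
  // turns each input x into a dataset with elements {10*x, 10*x+1, 10*x+2}.
  // With deterministic output the interleave yields
  //   10, 11, 20, 21, 12, 22, 30, 31, ...
  // i.e. `block_length_` consecutive elements are taken from the worker at
  // `interleave_indices_[next_index_]` before `next_index_` advances, and a
  // worker whose output iterator is exhausted is replaced by the first index
  // waiting in `staging_indices_`.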
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
          deterministic_(deterministic),
          workers_(dataset()->num_threads()),
          worker_thread_states_(dataset()->num_threads()) {}

    ~Iterator() override { CancelThreads(); }

    Status Initialize(IteratorContext* ctx) override {
      // TODO(jsimsa): Register cancellation callback once the implementation is
      // refactored not to hold mu_ while calling `GetNext` on the input.
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

    // It is implemented so that it matches the deterministic interleave
    // unless getting the next element would block and we are allowed to be
    // nondeterministic.
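    //
    // For example (a hypothetical scenario, not taken from a test): when
    // `deterministic_` is false and the worker at
    // `interleave_indices_[next_index_]` has nothing buffered yet, an element
    // already buffered by another worker further along the cycle can be
    // returned immediately instead, trading exact ordering for lower latency.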
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx));
      while (!cancelled_) {
        // Wait for an item to become available, blocking if necessary. If we
        // are allowed to be nondeterministic, we can skip over input datasets
        // that do not have an item readily available.
        bool can_produce_elements = false;
        bool must_wait_for_input = true;
        for (int64 i = 0; i < interleave_indices_.size(); ++i) {
          int64 index = (next_index_ + i) % interleave_indices_.size();
          int64 current_worker_index = interleave_indices_[index];
          if (current_worker_index < 0) {
            continue;  // Empty interleave elements.
          }
          WorkerState* current_worker = &workers_[current_worker_index];
          can_produce_elements |= current_worker->MayHaveElements();
          if (!current_worker->outputs.empty()) {
            // We have an element!
            next_index_ = index;
            const bool element_acquired_sloppily = !deterministic_ && i > 1;
            if (!element_acquired_sloppily) {
              // If the element was acquired in the regular (deterministic)
              // order, then advance the current block and cycle pointers to
              // the next element in the regular order.
              block_count_++;
              if (block_count_ == dataset()->block_length_) {
                next_index_ = (index + 1) % interleave_indices_.size();
                block_count_ = 0;
              }
            } else {
              block_count_ = 0;
            }
            *end_of_sequence = false;
            Status s = current_worker->outputs.front().status;
            profiler::TraceMe traceme([&] {
              return profiler::TraceMeEncode(
                  "ParallelInterleaveConsume",
                  {{"element_id", current_worker->outputs.front().id}});
            });
            current_worker->outputs.front().output.swap(*out_tensors);
            current_worker->outputs.pop_front();
            current_worker->cond_var.notify_one();
            return s;
          } else if (current_worker->is_producing && deterministic_) {
            // current_worker.outputs.empty(), and we must wait for this
            // iterator.
            if (next_index_ != index) {
              // We have advanced to a new iterator; reset block counts.
              next_index_ = index;
              block_count_ = 0;
            }
            break;
          } else if (!current_worker->is_producing) {
            // This iterator has reached end of input.
            interleave_indices_[index] = -1;
            if (input_impl_) {
              // Start prefetching a new iterator.
              std::vector<Tensor> args;
              bool end_of_input = false;
              Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
              if (end_of_input) {
                input_impl_.reset();
              } else {
                current_worker->SetInputs(s, std::move(args));
                staging_indices_.emplace_back(current_worker_index);
              }
            }

            if (!staging_indices_.empty()) {
              // Move a worker from `staging_indices_` to
              // `interleave_indices_`.
              interleave_indices_[index] = staging_indices_.front();
              staging_indices_.pop_front();

              next_index_ = (index + 1) % interleave_indices_.size();
              block_count_ = 0;
              // Restart the inner [for] loop
              can_produce_elements = true;
              must_wait_for_input = false;
              break;
            }
          }
        }

        if (!can_produce_elements && !input_impl_) {
          // No potential for future values.
          *end_of_sequence = true;
          return Status::OK();
        }

        if (must_wait_for_input) {
          // Wait for elements to become available.
          RecordStop(ctx);
          if (deterministic_) {
            workers_[interleave_indices_[next_index_]].cond_var.wait(l);
          } else {
            any_element_available_cond_var_.wait(l);
          }
          RecordStart(ctx);
        }
      }
      return errors::Cancelled(
          "ParallelInterleaveDatasetOp::Dataset::Iterator::GetNext");
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncInterleaveManyNode(std::move(args),
                                                /*parameters=*/{});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      // The order of locking is important here to avoid deadlock.
      mutex_lock l(mu_);
      mutex_lock ckpt_l(ckpt_mu_);
      if (input_impl_) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      } else {
        TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInputExhausted, ""));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kNextIndex, next_index_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBlockCount, block_count_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kWorkersSize, workers_.size()));
      for (int i = 0; i < workers_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
      }
      for (int i = 0; i < worker_thread_states_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(ctx, writer, i));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInterleaveSize,
                                             interleave_indices_.size()));
      for (int i = 0; i < interleave_indices_.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            prefix(), strings::StrCat(kInterleaveIndices, "_", i),
            interleave_indices_[i]));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kStagingSize, staging_indices_.size()));
      for (int i = 0; i < staging_indices_.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            prefix(), strings::StrCat(kStagingIndices, "_", i),
            staging_indices_[i]));
      }
      if (!worker_threads_.empty()) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(prefix(), kWorkerThreadsRunning, ""));
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      {
        // The order of locking is important here to avoid deadlock.
        mutex_lock l(mu_);
        mutex_lock ckpt_l(ckpt_mu_);
        if (!reader->Contains(prefix(), kInputExhausted)) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        int64 temp;
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kNextIndex, &temp));
        next_index_ = size_t(temp);
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kBlockCount, &temp));
        block_count_ = size_t(temp);

        // Restore WorkerStates.
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kWorkersSize, &temp));
        if (temp != dataset()->num_threads()) {
          return errors::Internal("Expected ", dataset()->num_threads(),
                                  " worker states but found ", temp, ".");
        }
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          TF_RETURN_IF_ERROR(ReadWorkerStateLocked(reader, i, ctx));
        }
      }
      std::unique_ptr<thread::ThreadPool> threadpool = ctx->CreateThreadPool(
          "read_worker_thread_state", dataset()->num_threads());
      Status s = Status::OK();
      BlockingCounter counter(dataset()->num_threads());
      for (size_t i = 0; i < dataset()->num_threads(); ++i) {
        threadpool->Schedule([this, i, ctx, reader, &s, &counter] {
          WorkerThreadState state;
          Status result = ReadWorkerThreadStateLocked(reader, i, ctx, &state);
          mutex_lock l(mu_);
          mutex_lock ckpt_l(ckpt_mu_);
          if (!result.ok()) {
            s.Update(result);
            counter.DecrementCount();
            return;
          }
          worker_thread_states_[i] = std::move(state);
          counter.DecrementCount();
        });
      }
      counter.Wait();
      if (!s.ok()) {
        return s;
      }

      mutex_lock l(mu_);
      mutex_lock ckpt_l(ckpt_mu_);
      // Restore `interleave_indices_`.
      std::set<int64> all_indices;
      {
        int64 interleave_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kInterleaveSize, &interleave_size));
        interleave_indices_.reserve(interleave_size);
        for (int64 i = 0; i < interleave_size; ++i) {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              prefix(), strings::StrCat(kInterleaveIndices, "_", i), &temp));
          if (temp >= 0 && all_indices.find(temp) != all_indices.end()) {
            return errors::Internal(
                "Duplicate entry for ", temp,
                " found when reading interleave and staging indices.");
          }
          if (temp >= 0) {
            all_indices.insert(temp);
          }
          interleave_indices_.emplace_back(temp);
        }
      }

      // Restore `staging_indices_`.
      {
        int64 staging_size;
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kStagingSize, &staging_size));
        for (int i = 0; i < staging_size; ++i) {
          int64 temp;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              prefix(), strings::StrCat(kStagingIndices, "_", i), &temp));
          if (all_indices.find(temp) != all_indices.end()) {
            return errors::Internal(
                "Duplicate entry for ", temp,
                " found when reading interleave and staging indices.");
          }
          if (temp >= 0) {
            all_indices.insert(temp);
          }
          staging_indices_.emplace_back(temp);
        }
      }

      // Start Worker threads.
      if (reader->Contains(prefix(), kWorkerThreadsRunning)) {
        worker_threads_.reserve(dataset()->num_threads());
        for (size_t i = 0; i < dataset()->num_threads(); ++i) {
          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
          worker_threads_.emplace_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
        }
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    // OutputElem contains the information from a call to GetNext by an output
    // iterator.
    struct OutputElem {
      // The output iterator sets `status` if getting the output element
      // fails.
      Status status;
      // The buffered data element.
      std::vector<Tensor> output;
      int64 id = -1;

      explicit OutputElem(const Status& s) : status(s) {}
      OutputElem(const Status& s, int64 id) : status(s), id(id) {}
    };

    // Worker threads operate on their relevant WorkerState structs.
    //
    // WorkerState's fields are all protected by mu_.
    struct WorkerState {
      // The arguments to be used to construct an output iterator.
      std::vector<Tensor> input;
      // The buffered output elements.
      std::deque<OutputElem> outputs;
      // Set to true iff the worker thread expects to append more elements to
      // outputs. is_producing can be false despite !outputs.empty().
      // Concretely, all output elements will have been consumed only when:
      // is_producing == false && outputs.empty();
      bool is_producing = false;
      // Condition variable used to coordinate between threads. The worker
      // thread waits on this condition variable when it is either (1) waiting
      // for the main thread to add arguments to `input`, or (2) waiting for
      // the main thread to consume an element of `outputs`. The main thread
      // waits on cond_var if it is waiting for the worker thread to produce
      // an element into `outputs` (this implies deterministic==true).
      condition_variable cond_var;

      inline bool MayHaveElements() const {
        return is_producing || !outputs.empty();
      }

      // Sets inputs for a worker thread and notifies it to start processing.
      void SetInputs(const Status& s, std::vector<Tensor> input_arguments) {
        if (s.ok()) {
          DCHECK(!MayHaveElements())
              << "Tried to start inputs, despite already producing!";
          input = std::move(input_arguments);
          is_producing = true;
          cond_var.notify_one();
        } else {
          outputs.emplace_back(s);
        }
      }
    };
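    // For intuition, a hypothetical exchange on a single WorkerState (all
    // names refer to members defined in this file):
    //   main thread:   SetInputs(Status::OK(), args) stores `input`, sets
    //                  `is_producing` and notifies `cond_var`.
    //   worker thread: builds an output iterator from `input` and pushes
    //                  OutputElems into `outputs` until it holds
    //                  `buffer_output_elements_` items, then waits on
    //                  `cond_var`.
    //   main thread:   GetNextInternal() pops `outputs.front()` and notifies
    //                  `cond_var` so the worker can refill the buffer.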

    // The internal state of a worker thread that is not already captured
    // in its `WorkerState`.
    //
    // This is needed only for checkpointing purposes. We keep this
    // separate from `WorkerState` and guard its fields using a separate
    // lock `ckpt_mu_` so as to not affect the performance of the main
    // pipeline.
    struct WorkerThreadState {
      // The output element that has been produced from the input iterator
      // and is waiting to be added to `WorkerState.outputs`.
      OutputElem output_elem;

      // Whether the input iterator returned an `end_of_sequence`.
      bool end_of_sequence = false;

      // Status returned from `MakeIteratorFromInputElement`.
      Status iterator_creation_status;

      // The arguments to be used to construct `iterator`.
      std::vector<Tensor> input;

      std::unique_ptr<IteratorBase> iterator;

      WorkerThreadState() : output_elem(Status::OK()) {}
    };

    void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
      mutex_lock l(mu_);
      cancelled_ = true;
      for (auto& worker : workers_) {
        worker.cond_var.notify_all();
      }
    }

    Status EnsureWorkerThreadsStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (worker_threads_.empty() && input_impl_) {
        worker_threads_.reserve(dataset()->num_threads());
        for (int64 i = 0; i < dataset()->num_threads(); ++i) {
          std::vector<Tensor> args;
          bool end_of_input = false;
          Status s = input_impl_->GetNext(ctx, &args, &end_of_input);
          if (end_of_input) {
            input_impl_.reset();
            return Status::OK();
          }
          workers_[i].SetInputs(s, std::move(args));
          std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
          worker_threads_.push_back(ctx->StartThread(
              strings::StrCat(kDataParallelInterleaveWorker, "_", i),
              [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
          if (i < dataset()->cycle_length_) {
            interleave_indices_.push_back(i);
          } else {
            staging_indices_.push_back(i);
          }
        }
        DCHECK(interleave_indices_.size() == dataset()->cycle_length_);
        DCHECK(staging_indices_.size() == dataset()->prefetch_input_elements_);
      }
      return Status::OK();
    }

    // Produces elements into the worker's output buffers.
    void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
                      const int64 thread_index) {
      // Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
      //
      // 1. Any local state that may need to be checkpointed should be kept
      //    in `worker_thread_states_[thread_index]`.
      // 2. `WorkerThreadState` should contain state that is needed only for
      //    checkpointing, i.e., if we were to remove checkpointing support,
      //    we could keep that state as local variables in this thread.
      // 3. This thread should only read/write state at `thread_index`
      //    and should not access other thread states.
      // 4. When restoring from checkpoint, threads are started only after
      //    the restore is complete.
      // 5. Once restored from a checkpoint, the local state is edited only
      //    by this thread. 3 & 4 allow making assumptions like temporarily
      //    caching local state in this thread and using it outside a lock
      //    e.g. `make_new_iterator`.
      // 6. `ckpt_mu_` should be wisely used to create *consistent*
      //    checkpoint markers.

      // std::function arguments are copy-constructible, so we pass raw
      // pointers, and then immediately wrap them to ensure correct ownership.
      RecordStart(ctx.get());
      auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
        mutex_lock l(mu_);
        workers_[thread_index].cond_var.notify_all();
        RecordStop(ctx.get());
      });
      bool make_new_iterator;
      {
        tf_shared_lock l(ckpt_mu_);
        // Decide whether a new iterator should be built.
        // 1. If there is an existing iterator, we use it.
        // 2. If there was an error in iterator creation that could not be
        //    notified to the client we attempt to send that to the client
        //    first.
        make_new_iterator =
            worker_thread_states_[thread_index].iterator == nullptr &&
            worker_thread_states_[thread_index].iterator_creation_status.ok();
      }
      // Even though `make_new_iterator` has cached values from
      // `worker_thread_states_[thread_index]` which is guarded by ckpt_mu_,
      // it is safe to *read* `make_new_iterator` outside of a lock without
      // worrying about concurrent changes to values in
      // `worker_thread_states_[thread_index]`. See comment at the start of
      // this function for details.
      while (true) {
        // Whether creation of the iterator succeeded.
        Status iterator_creation_status;
        // 1. Build a new iterator or use the existing one.
        if (make_new_iterator) {
          // 1a. Get new input tensors or use the existing ones.
          bool read_new_input;
          {
            tf_shared_lock l(ckpt_mu_);
            // worker_thread_states_[thread_index].input will be non-empty
            // if checkpointing happened at CHECKPOINT_MARKER_A.
            read_new_input = worker_thread_states_[thread_index].input.empty();
          }

          if (read_new_input) {
            mutex_lock l(mu_);
            while (!cancelled_ && !workers_[thread_index].is_producing) {
              RecordStop(ctx.get());
              workers_[thread_index].cond_var.wait(l);
              RecordStart(ctx.get());
            }
            if (cancelled_) return;
            // Copy the input tensors so that we do not need to block on `mu_`
            // when building the iterator.
            // We keep a copy of the input tensors in
            // `WorkerThreadState.input` till the iterator is in use. This is
            // used in `RestoreInternal` to re-build the iterator.
            // TODO(b/78046638): Explore ways to avoid tracking the input
            // tensors.
            tf_shared_lock ckpt_l(ckpt_mu_);
            worker_thread_states_[thread_index].input.swap(
                workers_[thread_index].input);
            // CHECKPOINT_MARKER_A
            // We have the input tensors but have not built the iterator yet.
          }

          // 1b. Run the user defined function to produce a new iterator.
          {
            tf_shared_lock l(ckpt_mu_);
            worker_thread_states_[thread_index].iterator_creation_status =
                MakeIteratorFromInputElement(
                    ctx.get(), this, worker_thread_states_[thread_index].input,
                    thread_index, *instantiated_captured_func_, prefix(),
                    &worker_thread_states_[thread_index].iterator,
                    model_node());
            iterator_creation_status =
                worker_thread_states_[thread_index].iterator_creation_status;
            if (!iterator_creation_status.ok()) {
              worker_thread_states_[thread_index].input.clear();
            }
            // CHECKPOINT_MARKER_B
            // Either an iterator has been successfully built and placed in
            // `worker_thread_states_[thread_index].iterator` or it failed and
            // a non-OK status has been put in
            // `worker_thread_states_[thread_index].iterator_creation_status`.
          }
        } else {
          tf_shared_lock l(ckpt_mu_);
          iterator_creation_status =
              worker_thread_states_[thread_index].iterator_creation_status;
          // Mark that we have used up the restored iterator.
          make_new_iterator = true;
        }
        // 2. Start producing elements or send error state to client if
        //    iterator creation failed.
        if (!iterator_creation_status.ok()) {
          mutex_lock l(mu_);
          // Wait for space in the prefetch queue.
          while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                    dataset()->buffer_output_elements_) {
            RecordStop(ctx.get());
            workers_[thread_index].cond_var.wait(l);
            RecordStart(ctx.get());
          }
          if (cancelled_) return;
          tf_shared_lock ckpt_l(ckpt_mu_);
          workers_[thread_index].outputs.emplace_back(iterator_creation_status);
          workers_[thread_index].is_producing = false;
          worker_thread_states_[thread_index].iterator_creation_status =
              Status::OK();
          // CHECKPOINT_MARKER_C
          // Non-OK iterator creation status has been notified to the
          // client.
          if (deterministic_) {
            workers_[thread_index].cond_var.notify_one();
          } else {
            any_element_available_cond_var_.notify_one();
          }
        } else {
          bool end_of_sequence = false;
          while (!end_of_sequence) {
            // 3.a Produce an element!
            {
              tf_shared_lock ckpt_l(ckpt_mu_);
              if (worker_thread_states_[thread_index].output_elem.status.ok() &&
                  worker_thread_states_[thread_index]
                      .output_elem.output.empty() &&
                  !worker_thread_states_[thread_index].end_of_sequence) {
                int64& id = worker_thread_states_[thread_index].output_elem.id;
                profiler::TraceMe traceme(
                    [&] {
                      id = profiler::TraceMe::NewActivityId();
                      return profiler::TraceMeEncode(
                          "ParallelInterleaveProduce", {{"element_id", id}});
                    },
                    profiler::kInfo);
                worker_thread_states_[thread_index].output_elem.status =
                    worker_thread_states_[thread_index].iterator->GetNext(
                        ctx.get(),
                        &worker_thread_states_[thread_index].output_elem.output,
                        &worker_thread_states_[thread_index].end_of_sequence);
                end_of_sequence =
                    worker_thread_states_[thread_index].end_of_sequence;
              } else {
                end_of_sequence =
                    worker_thread_states_[thread_index].end_of_sequence;
              }
              // CHECKPOINT_MARKER_D
              // An element has been read or an error or end_of_sequence has
              // been received from the input iterator and is waiting to be
              // sent to client.
            }

            // 3.b Make it available to the client.
            {
              mutex_lock l(mu_);

              // Wait for space in the prefetch queue.
              while (!cancelled_ && workers_[thread_index].outputs.size() ==
                                        dataset()->buffer_output_elements_) {
                RecordStop(ctx.get());
                workers_[thread_index].cond_var.wait(l);
                RecordStart(ctx.get());
              }
              if (cancelled_) return;

              tf_shared_lock ckpt_l(ckpt_mu_);
              workers_[thread_index].is_producing = !end_of_sequence;

              // Output the element.

              // Move the temporary state in WorkerThreadState to WorkerState
              // and mark it as used.
              if (end_of_sequence) {
                worker_thread_states_[thread_index].iterator.reset();
                worker_thread_states_[thread_index].input.clear();
                worker_thread_states_[thread_index].end_of_sequence = false;
              } else {
                workers_[thread_index].outputs.emplace_back(
                    worker_thread_states_[thread_index].output_elem.status,
                    worker_thread_states_[thread_index].output_elem.id);
                workers_[thread_index].outputs.back().output.swap(
                    worker_thread_states_[thread_index].output_elem.output);
              }
              worker_thread_states_[thread_index].output_elem.status =
                  Status::OK();
              if (deterministic_) {
                workers_[thread_index].cond_var.notify_one();
              } else {
                any_element_available_cond_var_.notify_one();
              }
              // CHECKPOINT_MARKER_E
              // Output element or iterator status has been sent to the
              // client.
            }
          }
        }
      }
    }

    Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string iterator_name =
          strings::StrCat(prefix(), "::", kWorker, "_", index);
      TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kInputSize,
                                             workers_[index].input.size()));
      for (int i = 0; i < workers_[index].input.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(iterator_name,
                                               strings::StrCat(kInput, "_", i),
                                               workers_[index].input[i]));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(iterator_name, kOutputsSize,
                                             workers_[index].outputs.size()));
      for (int i = 0; i < workers_[index].outputs.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteOutputElemLocked(
            writer, workers_[index].outputs[i], iterator_name,
            strings::StrCat(kOutputs, "_", i)));
      }
      if (workers_[index].is_producing) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kIsProducing, ""));
      }
      return Status::OK();
    }

    Status ReadWorkerStateLocked(IteratorStateReader* reader, int index,
                                 IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string worker_prefix =
          strings::StrCat(prefix(), "::", kWorker, "_", index);
      // Restore inputs.
      int64 input_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kInputSize, &input_size));
      workers_[index].input.reserve(input_size);
      for (int i = 0; i < input_size; ++i) {
        workers_[index].input.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(worker_prefix,
                                              strings::StrCat(kInput, "_", i),
                                              &workers_[index].input.back()));
      }
      int64 outputs_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kOutputsSize, &outputs_size));
      for (int i = 0; i < outputs_size; ++i) {
        workers_[index].outputs.emplace_back(Status::OK());
        TF_RETURN_IF_ERROR(ReadOutputElemLocked(
            reader, &workers_[index].outputs.back(), worker_prefix,
            strings::StrCat(kOutputs, "_", i)));
      }
      if (reader->Contains(worker_prefix, kIsProducing)) {
        workers_[index].is_producing = true;
      } else {
        workers_[index].is_producing = false;
      }
      return Status::OK();
    }

    Status WriteWorkerThreadStateLocked(SerializationContext* ctx,
                                        IteratorStateWriter* writer, int index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      string iterator_name =
          strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
      if (worker_thread_states_[index].iterator != nullptr) {
        TF_RETURN_IF_ERROR(
            SaveInput(ctx, writer, worker_thread_states_[index].iterator));
      } else {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kIteratorExhausted, ""));
      }
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(iterator_name, kInputSize,
                              worker_thread_states_[index].input.size()));
      for (int i = 0; i < worker_thread_states_[index].input.size(); ++i) {
        TF_RETURN_IF_ERROR(
            writer->WriteTensor(iterator_name, strings::StrCat(kInput, "_", i),
                                worker_thread_states_[index].input[i]));
      }
      TF_RETURN_IF_ERROR(WriteStatusLocked(
          writer, iterator_name, kIteratorCreationStatus,
          worker_thread_states_[index].iterator_creation_status));
      TF_RETURN_IF_ERROR(WriteOutputElemLocked(
          writer, worker_thread_states_[index].output_elem, iterator_name,
          kOutput));
      if (worker_thread_states_[index].end_of_sequence) {
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kEndOfSequence, ""));
      }
      return Status::OK();
    }

    Status ReadWorkerThreadStateLocked(IteratorStateReader* reader, int index,
                                       IteratorContext* ctx,
                                       WorkerThreadState* state) {
      string worker_prefix =
          strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
      // Restore inputs.
      int64 input_size;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(worker_prefix, kInputSize, &input_size));
      state->input.reserve(input_size);
      for (int i = 0; i < input_size; ++i) {
        state->input.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(worker_prefix,
                                              strings::StrCat(kInput, "_", i),
                                              &state->input.back()));
      }
      // Restore iterator
      if (reader->Contains(worker_prefix, kIteratorExhausted)) {
        state->iterator.reset();
      } else {
        std::unique_ptr<IteratorBase> iterator;
        // NOTE: We intentionally ignore resource modeling outside GetNext().
        TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
            ctx, this, state->input, index, *instantiated_captured_func_,
            prefix(), &iterator, /*node=*/nullptr));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
        state->iterator.swap(iterator);
      }
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, worker_prefix,
                                          kIteratorCreationStatus,
                                          &state->iterator_creation_status));
      TF_RETURN_IF_ERROR(ReadOutputElemLocked(reader, &state->output_elem,
                                              worker_prefix, kOutput));
      if (reader->Contains(worker_prefix, kEndOfSequence)) {
        state->end_of_sequence = true;
      } else {
        state->end_of_sequence = false;
      }
      return Status::OK();
    }

    Status WriteOutputElemLocked(IteratorStateWriter* writer,
                                 const OutputElem& output_elem,
                                 const string& iterator_name,
                                 const string& prefix)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      TF_RETURN_IF_ERROR(WriteStatusLocked(
          writer, iterator_name, strings::StrCat(prefix, "_", kStatus),
          output_elem.status));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, strings::StrCat(prefix, "_", kOutputSize),
          output_elem.output.size()));
      for (int i = 0; i < output_elem.output.size(); ++i) {
        TF_RETURN_IF_ERROR(writer->WriteTensor(
            iterator_name, strings::StrCat(prefix, "_", kOutput, "_", i),
            output_elem.output[i]));
      }
      return Status::OK();
    }

    Status ReadOutputElemLocked(IteratorStateReader* reader,
                                OutputElem* output_elem,
                                const string& iterator_name,
                                const string& prefix) {
      TF_RETURN_IF_ERROR(ReadStatusLocked(reader, iterator_name,
                                          strings::StrCat(prefix, "_", kStatus),
                                          &output_elem->status));
      int64 output_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          iterator_name, strings::StrCat(prefix, "_", kOutputSize),
          &output_size));
      output_elem->output.reserve(output_size);
      for (int i = 0; i < output_size; ++i) {
        output_elem->output.emplace_back();
        TF_RETURN_IF_ERROR(reader->ReadTensor(
            iterator_name, strings::StrCat(prefix, "_", kOutput, "_", i),
            &output_elem->output.back()));
      }
      return Status::OK();
    }

    Status WriteStatusLocked(IteratorStateWriter* writer,
                             const string& iterator_name, const string& prefix,
                             const Status& status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, strings::StrCat(prefix, "_", kCode),
          static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            iterator_name, strings::StrCat(prefix, "_", KMessage),
            status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatusLocked(IteratorStateReader* reader,
                            const string& iterator_name, const string& prefix,
                            Status* status) {
      int64 code_int;
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          iterator_name, strings::StrCat(prefix, "_", kCode), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            iterator_name, strings::StrCat(prefix, "_", KMessage),
            &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
      }
      return Status::OK();
    }
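    // Putting the helpers above together: the checkpoint keys written for the
    // first buffered output of worker 0 look roughly like the following
    // (illustrative only; the exact `prefix()` value depends on where this
    // iterator sits in the dataset graph):
    //   <prefix>::worker_0 : outputs_0_status_code
    //   <prefix>::worker_0 : outputs_0_status_msg   (only for a non-OK status)
    //   <prefix>::worker_0 : outputs_0_output_size
    //   <prefix>::worker_0 : outputs_0_output_0, outputs_0_output_1, ...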

    // Mutex & condition variable to guard mutable iterator internals and
    // coordinate among worker threads and client thread[s].
    mutex mu_ TF_ACQUIRED_BEFORE(ckpt_mu_);
    // The main thread waits on this condition variable if running in
    // nondeterministic mode and no values are available.
    condition_variable any_element_available_cond_var_;
    // Whether outputs must be produced in deterministic order.
    const bool deterministic_;
    // Mutex used to wait for a consistent state while checkpointing.
    // Only Save and Restore require an exclusive lock on this mutex. In
    // other scenarios we just acquire a shared lock so the pipeline's
    // performance should not be affected in the absence of checkpointing.
    // A thread must not wait on any condition variable while holding
    // `ckpt_mu_` in either shared or exclusive modes.
    mutex ckpt_mu_;

    // The iterator producing elements which are converted to datasets by
    // `dataset()->captured_func_` and then interleaved together.
    // input_impl_ is reset when we have exhausted its input.
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

    // The WorkerState structs the worker threads operate on.
    // Each `workers_` element is referenced by at most one index in
    // `interleave_indices_` and `staging_indices_`.
    std::vector<WorkerState> workers_ TF_GUARDED_BY(mu_);

    // Stores the temporary state of WorkerThreads which is not stored in
    // WorkerState. This is used for checkpointing purposes only.
    std::vector<WorkerThreadState> worker_thread_states_
        TF_GUARDED_BY(ckpt_mu_);

    // Indices in `workers_` of iterators to interleave.
    std::vector<int64> interleave_indices_ TF_GUARDED_BY(mu_);
    // Indices in `workers_` of prefetched iterators.
    std::deque<int64> staging_indices_ TF_GUARDED_BY(mu_);

    // The index into `interleave_indices_` of the next element to produce.
    size_t next_index_ TF_GUARDED_BY(mu_) = 0;
    // The number of items produced so far within the current block.
    size_t block_count_ TF_GUARDED_BY(mu_) = 0;
    // Flag to instruct the worker threads to exit.
    bool cancelled_ TF_GUARDED_BY(mu_) = false;
    // The worker threads. This must be last to ensure the
    // threads have exited before any other members are deallocated.
    // TODO(b/65178177): Avoid allocating additional threads.
    std::vector<std::unique_ptr<Thread>> worker_threads_ TF_GUARDED_BY(mu_);
  };

  const DatasetBase* const input_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int64 cycle_length_;
  const int64 block_length_;
  const DeterminismPolicy deterministic_;
  const int64 buffer_output_elements_;
  const int64 prefetch_input_elements_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
  const int op_version_;
};

ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->HasAttr(kDeterministic) ? 2 : 1) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  if (op_version_ == 2) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase* input,
                                              DatasetBase** output) {
  int64 cycle_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
  OP_REQUIRES(ctx, cycle_length > 0,
              errors::InvalidArgument("`cycle_length` must be > 0"));

  int64 block_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));

  if (op_version_ == 1) {
    bool sloppy = false;
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kDeterministic);
    }
  }

  int64 buffer_output_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
                                          &buffer_output_elements));
  OP_REQUIRES(ctx, buffer_output_elements > 0,
              errors::InvalidArgument("`buffer_output_elements` must be > 0"));

  int64 prefetch_input_elements = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
                                          &prefetch_input_elements));
  OP_REQUIRES(
      ctx, prefetch_input_elements >= 0,
      errors::InvalidArgument("`prefetch_input_elements` must be >= 0"));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  *output = new Dataset(ctx, input, std::move(captured_func), cycle_length,
                        block_length, deterministic_, buffer_output_elements,
                        prefetch_input_elements, output_types_, output_shapes_,
                        op_version_);
}
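
// For reference, the scalar inputs parsed in MakeDataset above correspond to
// the (now legacy) Python-level transformation, which is typically applied
// roughly as follows (illustrative only; the Python API is defined elsewhere,
// and `map_func` is a placeholder for a user-supplied function):
//
//   dataset = dataset.apply(
//       tf.data.experimental.parallel_interleave(
//           map_func,
//           cycle_length=4,
//           block_length=1,
//           sloppy=False,             # op_version 1; v2 exposes `deterministic`
//           buffer_output_elements=2,
//           prefetch_input_elements=2))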

namespace {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalParallelInterleaveDataset").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("LegacyParallelInterleaveDatasetV2").Device(DEVICE_CPU),
    ParallelInterleaveDatasetOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("ParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalParallelInterleaveDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("LegacyParallelInterleaveDatasetV2");

}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow