1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/parallel_interleave_dataset_op.h"
16 
17 #include <atomic>
18 #include <deque>
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/strings/str_format.h"
23 #include "tensorflow/core/common_runtime/function.h"
24 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
25 #include "tensorflow/core/common_runtime/metrics.h"
26 #include "tensorflow/core/framework/dataset.h"
27 #include "tensorflow/core/framework/model.h"
28 #include "tensorflow/core/framework/partial_tensor_shape.h"
29 #include "tensorflow/core/framework/stats_aggregator.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/kernels/data/captured_function.h"
33 #include "tensorflow/core/kernels/data/dataset_utils.h"
34 #include "tensorflow/core/kernels/data/name_utils.h"
35 #include "tensorflow/core/kernels/data/stats_utils.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/core/threadpool.h"
38 #include "tensorflow/core/lib/gtl/cleanup.h"
39 #include "tensorflow/core/lib/random/random.h"
40 #include "tensorflow/core/lib/strings/strcat.h"
41 #include "tensorflow/core/lib/strings/stringprintf.h"
42 #include "tensorflow/core/platform/blocking_counter.h"
43 #include "tensorflow/core/platform/cpu_info.h"
44 #include "tensorflow/core/platform/errors.h"
45 #include "tensorflow/core/platform/stringprintf.h"
46 #include "tensorflow/core/profiler/lib/traceme.h"
47 #include "tensorflow/core/profiler/lib/traceme_encode.h"
48 
49 namespace tensorflow {
50 namespace data {
51 
52 // See documentation in ../../ops/dataset_ops.cc for a high-level
53 // description of the following op.
54 
55 /* static */ constexpr const char* const
56     ParallelInterleaveDatasetOp::kDatasetType;
57 /* static */ constexpr const char* const
58     ParallelInterleaveDatasetOp::kInputDataset;
59 /* static */ constexpr const char* const
60     ParallelInterleaveDatasetOp::kOtherArguments;
61 /* static */ constexpr const char* const
62     ParallelInterleaveDatasetOp::kCycleLength;
63 /* static */ constexpr const char* const
64     ParallelInterleaveDatasetOp::kBlockLength;
65 /* static */ constexpr const char* const
66     ParallelInterleaveDatasetOp::kBufferOutputElements;
67 /* static */ constexpr const char* const
68     ParallelInterleaveDatasetOp::kPrefetchInputElements;
69 /* static */ constexpr const char* const
70     ParallelInterleaveDatasetOp::kNumParallelCalls;
71 /* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
72 /* static */ constexpr const char* const
73     ParallelInterleaveDatasetOp::kTarguments;
74 /* static */ constexpr const char* const
75     ParallelInterleaveDatasetOp::kOutputTypes;
76 /* static */ constexpr const char* const
77     ParallelInterleaveDatasetOp::kOutputShapes;
78 /* static */ constexpr const char* const
79     ParallelInterleaveDatasetOp::kDeterministic;
80 /* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;
81 
82 namespace {
83 
84 constexpr char kTfDataParallelInterleaveWorkerPool[] =
85     "tf_data_parallel_interleave_worker_pool";
86 constexpr char kParallelism[] = "parallelism";
87 constexpr char kBlockIndex[] = "block_index";
88 constexpr char kCycleIndex[] = "cycle_index";
89 constexpr char kEndOfInput[] = "end_of_input";
90 constexpr char kElementIdCounter[] = "element_id_counter";
91 constexpr char kCurrentElements[] = "current_elements";
92 constexpr char kCurrentElementsSize[] = "current_elements.size";
93 constexpr char kFutureElements[] = "future_elements";
94 constexpr char kFutureElementsSize[] = "future_elements.size";
95 constexpr char kResultsSuffix[] = ".results";
96 constexpr char kCodeSuffix[] = ".code";
97 constexpr char kErrorMessageSuffix[] = ".error_message";
98 constexpr char kIdSuffix[] = ".id";
99 constexpr char kSizeSuffix[] = ".size";
100 constexpr char kInputsSuffix[] = ".inputs";
101 constexpr char kIsReadySuffix[] = ".is_ready";
102 
103 constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
104 constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
105 constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";
106 
107 // `kDefaultCyclePrefetchFactor * cycle_length` is the default number of future cycle
108 // elements that will be prefetched ahead of time. The purpose of prefetching
109 // future cycle elements is to overlap expensive initialization (e.g. opening of
110 // a remote file) with other computation.
111 constexpr double kDefaultCyclePrefetchFactor = 2.0L;
112 
113 // `kDefaultPerIteratorPrefetchFactor * block_length + 1` is the default number of
114 // per-iterator results that will be prefetched ahead of time. The `+ 1` is to
115 // match the behavior of the original implementation.
116 constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;
117 
118 // Period between reporting dataset statistics.
119 constexpr int kStatsReportingPeriodMillis = 1000;
120 
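// For example, CeilDiv(10, 3) == 4 and CeilDiv(9, 3) == 3: the helper below
// rounds the integer quotient up instead of truncating it.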
121 inline int64 CeilDiv(int64 numerator, int64 denominator) {
122   return (numerator + denominator - 1) / denominator;
123 }
124 
125 int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
126                                   int64 block_length) {
127   if (configured_buffer_output_elements != model::kAutotune) {
128     return configured_buffer_output_elements;
129   }
130   return kDefaultPerIteratorPrefetchFactor * block_length + 1;
131 }
132 
133 int64 ComputePrefetchInputElements(int64 configured_prefetch_input_elements,
134                                    int64 cycle_length) {
135   if (configured_prefetch_input_elements != model::kAutotune) {
136     return configured_prefetch_input_elements;
137   }
138   return kDefaultCyclePrefetchFactor * cycle_length;
139 }
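// As a worked example of the defaults above (concrete sizes assumed only for
// illustration): with both knobs left at model::kAutotune, cycle_length = 4
// and block_length = 1,
//
//   ComputeBufferOutputElements(model::kAutotune, /*block_length=*/1);   // 2 * 1 + 1 == 3
//   ComputePrefetchInputElements(model::kAutotune, /*cycle_length=*/4);  // 2 * 4 == 8
//
// i.e. each open iterator buffers up to 3 results and 8 future input elements
// are prefetched ahead of the current cycle.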
140 
141 int64 OpVersionFromOpName(absl::string_view op_name) {
142   if (op_name == kParallelInterleaveDatasetV2) {
143     return 2;
144   } else if (op_name == kParallelInterleaveDatasetV3) {
145     return 3;
146   } else {
147     DCHECK_EQ(op_name, kParallelInterleaveDatasetV4);
148     return 4;
149   }
150 }
151 
152 }  // namespace
153 
154 // The motivation for creating an alternative implementation of parallel
155 // interleave is to decouple the degree of parallelism from the cycle length.
156 // This makes it possible to change the degree of parallelism (e.g. through
157 // auto-tuning) without changing the cycle length (which would change the order
158 // in which elements are produced).
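// For illustration only (example inputs assumed, not part of the op itself):
// with cycle_length = 2, block_length = 1, and a deterministic iterator, if
// the user-provided function maps input elements A and B to datasets
// {a0, a1, ...} and {b0, b1, ...}, the output order is a0, b0, a1, b1, ...
// no matter how many parallel calls are in flight; only nondeterministic mode
// is allowed to reorder results across cycle elements.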
159 class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
160  public:
161   Dataset(OpKernelContext* ctx, const DatasetBase* input,
162           std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
163           int64 block_length, int64 buffer_output_elements,
164           int64 prefetch_input_elements, int64 num_parallel_calls,
165           DeterminismPolicy deterministic, const DataTypeVector& output_types,
166           const std::vector<PartialTensorShape>& output_shapes, int op_version)
167       : DatasetBase(DatasetContext(ctx)),
168         input_(input),
169         captured_func_(std::move(captured_func)),
170         cycle_length_(cycle_length),
171         block_length_(block_length),
172         buffer_output_elements_(
173             ComputeBufferOutputElements(buffer_output_elements, block_length)),
174         prefetch_input_elements_(ComputePrefetchInputElements(
175             prefetch_input_elements, cycle_length)),
176         num_parallel_calls_(num_parallel_calls),
177         deterministic_(deterministic),
178         output_types_(output_types),
179         output_shapes_(output_shapes),
180         op_version_(op_version),
181         traceme_metadata_(
182             {{"autotune",
183               num_parallel_calls == model::kAutotune ? "true" : "false"},
184              {"block_length",
185               strings::Printf("%lld", static_cast<long long>(block_length))},
186              {"cycle_length",
187               strings::Printf("%lld", static_cast<long long>(cycle_length))},
188              {"deterministic",
189               deterministic.IsNondeterministic() ? "false" : "true"}}) {
190     input_->Ref();
191   }
192 
193   ~Dataset() override { input_->Unref(); }
194 
195   std::unique_ptr<IteratorBase> MakeIteratorInternal(
196       const string& prefix) const override {
197     name_utils::IteratorPrefixParams params;
198     params.op_version = op_version_;
199     bool deterministic =
200         deterministic_.IsDeterministic() || deterministic_.IsDefault();
201     return absl::make_unique<ParallelInterleaveIterator>(
202         ParallelInterleaveIterator::Params{
203             this,
204             name_utils::IteratorPrefix(
205                 ParallelInterleaveDatasetOp::kDatasetType, prefix, params)},
206         deterministic);
207   }
208 
209   const DataTypeVector& output_dtypes() const override { return output_types_; }
210 
211   const std::vector<PartialTensorShape>& output_shapes() const override {
212     return output_shapes_;
213   }
214 
215   string DebugString() const override {
216     name_utils::DatasetDebugStringParams params;
217     params.op_version = op_version_;
218     return name_utils::DatasetDebugString(
219         ParallelInterleaveDatasetOp::kDatasetType, params);
220   }
221 
222   int64 Cardinality() const override {
223     int64 n = input_->Cardinality();
224     if (n == kInfiniteCardinality) {
225       return n;
226     }
227     return kUnknownCardinality;
228   }
229 
230   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
231     inputs->push_back(input_);
232     return Status::OK();
233   }
234 
235   Status CheckExternalState() const override {
236     TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
237     return input_->CheckExternalState();
238   }
239 
240  protected:
241   Status AsGraphDefInternal(SerializationContext* ctx,
242                             DatasetGraphDefBuilder* b,
243                             Node** output) const override {
244     std::vector<std::pair<size_t, Node*>> inputs;
245     std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
246     int input_index = 0;
247 
248     Node* input_node;
249     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
250     inputs.emplace_back(input_index++, input_node);
251 
252     std::vector<Node*> other_arguments;
253     DataTypeVector other_arguments_types;
254     TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
255                                                   &other_arguments_types));
256     list_inputs.emplace_back(input_index++, other_arguments);
257 
258     Node* cycle_length_node;
259     TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
260     inputs.emplace_back(input_index++, cycle_length_node);
261 
262     Node* block_length_node;
263     TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
264     inputs.emplace_back(input_index++, block_length_node);
265 
266     if (op_version_ >= 4) {
267       Node* buffer_output_elements_node;
268       TF_RETURN_IF_ERROR(
269           b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
270       inputs.emplace_back(input_index++, buffer_output_elements_node);
271 
272       Node* prefetch_input_elements_node;
273       TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
274                                       &prefetch_input_elements_node));
275       inputs.emplace_back(input_index++, prefetch_input_elements_node);
276     }
277 
278     Node* num_parallel_calls_node;
279     TF_RETURN_IF_ERROR(
280         b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
281     inputs.emplace_back(input_index++, num_parallel_calls_node);
282 
283     std::vector<std::pair<StringPiece, AttrValue>> attrs;
284     AttrValue f;
285     b->BuildAttrValue(captured_func_->func(), &f);
286     attrs.emplace_back(kFunc, f);
287 
288     AttrValue other_arguments_types_attr;
289     b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
290     attrs.emplace_back(kTarguments, other_arguments_types_attr);
291 
292     if (op_version_ == 2) {
293       AttrValue sloppy_attr;
294       b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
295       attrs.emplace_back(kSloppy, sloppy_attr);
296     }
297     if (op_version_ >= 3) {
298       AttrValue deterministic_attr;
299       b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
300       attrs.emplace_back(kDeterministic, deterministic_attr);
301     }
302 
303     TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
304     return Status::OK();
305   }
306 
307  private:
308   class ParallelInterleaveIterator : public DatasetIterator<Dataset> {
309    public:
310     ParallelInterleaveIterator(const Params& params, bool deterministic)
311         : DatasetIterator<Dataset>(params),
312           mu_(std::make_shared<mutex>()),
313           num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
314           num_parallel_calls_(std::make_shared<model::SharedState>(
315               params.dataset->num_parallel_calls_, mu_,
316               num_parallel_calls_cond_var_)),
317           deterministic_(deterministic),
318           current_elements_(params.dataset->cycle_length_) {}
319 
320     ~ParallelInterleaveIterator() override {
321       CancelThreads(/*wait=*/true);
322       if (deregister_fn_) deregister_fn_();
323     }
324 
325     Status Initialize(IteratorContext* ctx) override {
326       mutex_lock l(*mu_);
327       // Note that if `ctx->thread_pool()` is non-null, then instead of creating
328       // a dedicated thread pool of size `num_threads`, computation will be
329       // scheduled into the shared threadpool. The threadpool is guaranteed to
330       // support `num_threads` concurrent tasks without blocking indefinitely.
331       //
332       // Allocate one thread for the worker manager, one thread for stats
333       // collection, `cycle_length_` threads for the current workers, and
334       // `prefetch_input_elements_ + cycle_length_` threads for the future workers.
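      // For example (sizes assumed only for illustration): with
      // cycle_length_ = 4 and the default prefetch_input_elements_ of 8, this
      // comes to 1 + 4 + (8 + 4) = 17 threads, plus one more when a stats
      // aggregator is configured.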
335       int max_current_workers = dataset()->cycle_length_;
336       int future_workers =
337           dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
338       int num_threads = 1 + max_current_workers + future_workers;
339       if (ctx->stats_aggregator()) {
340         num_threads++;
341       }
342       thread_pool_ = ctx->CreateThreadPool(kTfDataParallelInterleaveWorkerPool,
343                                            num_threads);
344       if (num_parallel_calls_->value == model::kAutotune) {
345         num_parallel_calls_->value = dataset()->cycle_length_;
346       }
347       // TODO(jsimsa): Register cancellation callback once the implementation is
348       // refactored not to hold mu_ while calling `GetNext` on the input.
349       ctx_ = std::make_unique<IteratorContext>(*ctx);
350       TF_RETURN_IF_ERROR(
351           dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
352       return dataset()->captured_func_->Instantiate(
353           ctx, &instantiated_captured_func_);
354     }
355 
356     Status GetNextInternal(IteratorContext* ctx,
357                            std::vector<Tensor>* out_tensors,
358                            bool* end_of_sequence) override {
359       std::shared_ptr<Result> result;
360       {
361         mutex_lock l(*mu_);
362         EnsureInitialElementsCreated();
363         EnsureThreadsStarted();
364         while (!cancelled_ && !Consume(&result)) {
365           RecordStop(ctx);
366           if (deterministic_) {
367             VLOG(3) << "Blocked waiting for element "
368                     << current_elements_[cycle_index_]->id;
369             current_elements_[cycle_index_]->cond_var.wait(l);
370           } else {
371             any_element_available_cond_var_.wait(l);
372           }
373           RecordStart(ctx);
374         }
375         if (cancelled_) {
376           return errors::Cancelled("Iterator was cancelled");
377         }
378       }
379       if (!result) {
380         *end_of_sequence = true;
381         return Status::OK();
382       }
383       profiler::TraceMe traceme([&] {
384         return profiler::TraceMeEncode("ParallelInterleaveConsume",
385                                        {{"element_id", result->id}});
386       });
387       if (result->status.ok()) {
388         *out_tensors = std::move(result->return_values);
389         RecordBufferDequeue(ctx, *out_tensors);
390       }
391       *end_of_sequence = false;
392       return result->status;
393     }
394 
395    protected:
396     std::shared_ptr<model::Node> CreateNode(
397         IteratorContext* ctx, model::Node::Args args) const override {
398       return model::MakeAsyncInterleaveManyNode(
399           std::move(args),
400           {model::MakeParameter(kParallelism, num_parallel_calls_, /*min=*/1,
401                                 /*max=*/dataset()->cycle_length_)});
402     }
403 
404     // TODO(aaudibert): Refactor the implementations to avoid the need for
405     // `IteratorContext` when saving the state of the iterator.
406     Status SaveInternal(SerializationContext* ctx,
407                         IteratorStateWriter* writer) override {
408       TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
409           dataset()->captured_func_->CheckExternalState()));
410       mutex_lock l(*mu_);
411       wait_for_checkpoint_ = true;
412       // Wait for all in-flight calls to complete.
413       while (num_active_workers_ > 0) {
414         RecordStop(ctx_.get());
415         zero_active_workers_cond_var_.wait(l);
416         RecordStart(ctx_.get());
417       }
418       // Initialize all elements and filter out elements with no input.
419       InitializeInputs(element_id_counter_);
420       for (auto& element : current_elements_) {
421         if (element && element->no_input) {
422           element.reset();
423         }
424       }
425       while (!future_elements_.empty() && future_elements_.back()->no_input) {
426         future_elements_.pop_back();
427       }
428       wait_for_checkpoint_ = false;
429       DCHECK_EQ(num_active_workers_, 0);
430       VLOG(4) << "State before save:\n" << DebugString();
431       TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
432       TF_RETURN_IF_ERROR(
433           writer->WriteScalar(prefix(), kBlockIndex, block_index_));
434       TF_RETURN_IF_ERROR(
435           writer->WriteScalar(prefix(), kCycleIndex, cycle_index_));
436       if (end_of_input_) {
437         TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kEndOfInput, ""));
438       }
439       TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kElementIdCounter,
440                                              element_id_counter_));
441       TF_RETURN_IF_ERROR(WriteCurrentElements(ctx, writer));
442       TF_RETURN_IF_ERROR(WriteFutureElements(ctx, writer));
443       // Wake workers back up.
444       current_workers_cond_var_.notify_all();
445       future_workers_cond_var_.notify_all();
446       return Status::OK();
447     }
448 
449     Status RestoreInternal(IteratorContext* ctx,
450                            IteratorStateReader* reader) override {
451       {
452         mutex_lock l(*mu_);
453         DCHECK(!threads_initialized_);
454         DCHECK(!initial_elements_created_);
455         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
456         TF_RETURN_IF_ERROR(
457             reader->ReadScalar(prefix(), kBlockIndex, &block_index_));
458         TF_RETURN_IF_ERROR(
459             reader->ReadScalar(prefix(), kCycleIndex, &cycle_index_));
460         TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kElementIdCounter,
461                                               &element_id_counter_));
462         end_of_input_ = reader->Contains(prefix(), kEndOfInput);
463       }
464       TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
465       TF_RETURN_IF_ERROR(ReadFutureElements(ctx, reader));
466       mutex_lock l(*mu_);
467       initial_elements_created_ = false;
468       for (int i = 0; i < current_elements_.size(); ++i) {
469         int index = (cycle_index_ + i) % current_elements_.size();
470         auto element = current_elements_[index];
471         if (element) {
472           elements_to_process_.push_back(index);
473           element->initialized = true;
474           element->cycle_index = index;
475           initial_elements_created_ = true;
476         }
477       }
478       for (const auto& element : future_elements_) {
479         element->initialized = true;
480       }
481       last_valid_current_element_ = current_elements_.size() - 1;
482       while (last_valid_current_element_ >= 0 &&
483              !current_elements_[last_valid_current_element_]) {
484         last_valid_current_element_--;
485       }
486       VLOG(2) << "Parallel interleave iterator restored";
487       VLOG(4) << "State after restore:\n" << DebugString();
488       return Status::OK();
489     }
490 
491     TraceMeMetadata GetTraceMeMetadata() const override {
492       int64 parallelism = -1;
493       // NOTE: We only set the parallelism value if the lock can be acquired
494       // right away to avoid introducing tracing overhead.
495       if (mu_->try_lock()) {
496         parallelism = num_parallel_calls_->value;
497         mu_->unlock();
498       }
499       auto result = dataset()->traceme_metadata_;
500       result.push_back(std::make_pair(
501           "parallelism",
502           strings::Printf("%lld", static_cast<long long>(parallelism))));
503       return result;
504     }
505 
506    private:
507     // Represents the result of fetching an element from a dataset.
508     struct Result {
509       Status status;
510       int64 id = -1;
511       std::vector<Tensor> return_values;
512     };
513 
514     // The interleave transformation repeatedly inputs elements, applies the
515     // user-provided function to transform the input elements to datasets, and
516     // interleaves the elements of these datasets as its output.
517     //
518     // This structure represents an input element and derived state.
519     struct Element {
520       // Unique identifier, needed to support checkpointing.
521       int64 id TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
522       // The actual input element, i.e. the input tensors from which the
523       // iterator below is created. A null value indicates that the element
524       // either reached end of input or hasn't been initialized yet.
525       std::unique_ptr<std::vector<Tensor>> inputs
526           TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
527       // Iterator created from the input element. A null value indicates that
528       // the element either reached end of input or hasn't been initialized yet.
529       std::unique_ptr<IteratorBase> iterator
530           TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
531       // Buffer for storing the outputs of `iterator`.
532       std::deque<std::shared_ptr<Result>> TF_GUARDED_BY(
533           &ParallelInterleaveIterator::mu_) results;
534       // The element's index in the cycle, if it is in the current cycle.
535       // -1 if the element is not in the current cycle.
536       int64 cycle_index TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = -1;
537       // Whether the element is currently being processed by a worker thread.
538       // This is used to ensure that only one thread at a time tries to process
539       // an element.
540       bool active TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
541       // Whether the inputs and iterator have been initialized.
542       bool initialized TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
543       // Whether we tried to initialize the element, but the input iterator
544       // was exhausted so we could produce no inputs.
545       bool no_input TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
546       // Condition variable for communicating between current worker threads
547       // and GetNext.
548       condition_variable cond_var;
549 
550       std::string DebugString()
551           TF_EXCLUSIVE_LOCKS_REQUIRED(&ParallelInterleaveIterator::mu_) {
552         return absl::StrFormat(
553             "Element(id: %d, iterator_null: %d, results_size: %d, "
554             "cycle_index: %d, active: %d, initialized: %d, no_input: %d)",
555             id, iterator == nullptr, results.size(), cycle_index, active,
556             initialized, no_input);
557       }
558     };
559 
560     // Sets the cancellation bit and wakes up all threads that need to be
561     // cancelled. Optionally, the method waits until all threads finish
562     // executing.
563     void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
564       mutex_lock l(*mu_);
565       cancelled_ = true;
566       // Wake up all threads so that they can exit. This will also wake up any
567       // threads waiting in GetNextInternal.
568       for (const auto& element : current_elements_) {
569         if (element) {
570           element->cond_var.notify_all();
571         }
572       }
573       current_workers_cond_var_.notify_all();
574       future_workers_cond_var_.notify_all();
575       num_parallel_calls_cond_var_->notify_all();
576       stats_thread_cond_var_.notify_all();
577       while (wait && outstanding_threads_ > 0) {
578         outstanding_threads_finished_cond_var_.wait(l);
579       }
580       any_element_available_cond_var_.notify_all();
581       zero_active_workers_cond_var_.notify_all();
582     }
583 
584     void EnsureInitialElementsCreated() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
585       if (!initial_elements_created_) {
586         for (int i = 0; i < dataset()->cycle_length_; ++i) {
587           current_elements_[i] = MakeElement();
588           if (!current_elements_[i]) {
589             break;
590           }
591           current_elements_[i]->cycle_index = i;
592           elements_to_process_.push_back(i);
593           last_valid_current_element_ = i;
594         }
595         initial_elements_created_ = true;
596       }
597     }
598 
599     void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
600       if (!threads_initialized_) {
601         IncrementOutstandingThreads();
602         thread_pool_->Schedule([this]() { WorkerManagerThread(); });
603         if (ctx_->stats_aggregator()) {
604           IncrementOutstandingThreads();
605           thread_pool_->Schedule([this]() { StatsThread(); });
606         }
607         threads_initialized_ = true;
608       }
609     }
610 
611     // Advances the position in the interleave cycle to the next cycle
612     // element.
613     void AdvanceToNextInCycle() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
614       DCHECK_NE(last_valid_current_element_, -1);
615       block_index_ = 0;
616       cycle_index_ = (cycle_index_ + 1) % (last_valid_current_element_ + 1);
617     }
618 
619     // Advances the position in the interleave cycle by one.
620     void AdvancePosition() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
621       ++block_index_;
622       if (block_index_ == dataset()->block_length_) {
623         AdvanceToNextInCycle();
624       }
625     }
626 
627     // Consumes a result (if available), returning an indication of whether
628     // a result is available. If `true` is returned, `result` either
629     // points to a valid result or is null if end of input has been reached.
630     bool Consume(std::shared_ptr<Result>* result)
631         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
632       if (deterministic_) {
633         return ConsumeHelper(result);
634       }
635       // If we are allowed to be nondeterministic (i.e. return results out of
636       // order), try to find an element in the cycle that has a result
637       // available.
638       for (int i = 0; i < dataset()->cycle_length_; ++i) {
639         if (ConsumeHelper(result)) {
640           return true;
641         }
642         AdvanceToNextInCycle();
643       }
644       return false;
645     }
646 
647     // Consumes a result (if available), returning an indication of whether
648     // a result is available. If `true` is returned, `result` either
649     // points to a valid result or is null if end of input has been reached.
650     bool ConsumeHelper(std::shared_ptr<Result>* result)
651         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
652       while (true) {
653         if (last_valid_current_element_ == -1) {
654           // Reached end of input.
655           return true;
656         }
657         for (int64 i = 0; i < (last_valid_current_element_ + 1); ++i) {
658           int64 index = (cycle_index_ + i) % (last_valid_current_element_ + 1);
659           if (current_elements_[index]) {
660             cycle_index_ = index;
661             if (i > 0) {
662               block_index_ = 0;
663             }
664             break;
665           }
666         }
667         DCHECK(current_elements_[cycle_index_]);
668         std::shared_ptr<Element> element = current_elements_[cycle_index_];
669         if (!element->results.empty()) {
670           // We found a result.
671           std::swap(*result, element->results.front());
672           element->results.pop_front();
673           if (!element->active) {
674             elements_to_process_.push_back(cycle_index_);
675             current_workers_cond_var_.notify_one();
676           }
677           AdvancePosition();
678           return true;
679         }
680         if (!element->initialized || element->iterator) {
681           // The element is still producing results, so we wait.
682           return false;
683         }
684         // We've consumed all results from the element. Get a new element from
685         // future_elements, or create a new element if no future elements are
686         // available.
687         if (!future_elements_.empty()) {
688           std::shared_ptr<Element> future_element =
689               std::move(future_elements_.front());
690           future_elements_.pop_front();
691           if (future_element->iterator) {
692             EnableAutotune(ctx_.get(), future_element->iterator.get());
693           }
694           future_element->cycle_index = cycle_index_;
695           current_elements_[cycle_index_] = std::move(future_element);
696           future_workers_cond_var_.notify_one();
697           if (!current_elements_[cycle_index_]->active) {
698             current_workers_cond_var_.notify_one();
699           }
700         } else {
701           current_elements_[cycle_index_] = MakeElement();
702           if (current_elements_[cycle_index_]) {
703             current_elements_[cycle_index_]->cycle_index = cycle_index_;
704             elements_to_process_.push_back(cycle_index_);
705             element->cycle_index = cycle_index_;
706             current_workers_cond_var_.notify_one();
707           }
708           while (last_valid_current_element_ >= 0 &&
709                  !current_elements_[last_valid_current_element_]) {
710             last_valid_current_element_--;
711             if (cycle_index_ > last_valid_current_element_) {
712               // We are about to move the cycle index below in
713               // AdvanceToNextInCycle().
714               cycle_index_ = last_valid_current_element_;
715             }
716           }
717         }
718         if (last_valid_current_element_ != -1) {
719           AdvanceToNextInCycle();
720         }
721       }
722     }
723 
724     // Creates a new element.
725     std::shared_ptr<Element> MakeElement() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
726       if (end_of_input_) {
727         return nullptr;
728       }
729       auto element = std::make_shared<Element>();
730       element->id = element_id_counter_++;
731       uninitialized_elements_.push_back(element);
732       return element;
733     }
734 
735     // Thread responsible for launching all worker threads. The thread stays
736     // around after startup in case autotuning increases num_parallel_calls.
737     void WorkerManagerThread() TF_LOCKS_EXCLUDED(mu_) {
738       RecordStart(ctx_.get());
739       auto cleanup = gtl::MakeCleanup([this]() {
740         RecordStop(ctx_.get());
741         mutex_lock l(*mu_);
742         DecrementOutstandingThreads();
743       });
744       int initial_current_workers;
745       // When elements are moved from `future_elements_` to `current_elements_`,
746       // the future worker which created the element may continue to process
747       // the element for some time. That is why we need an additional
748       // `cycle_length_` future workers to guarantee that whenever
749       // `future_elements_.size() < prefetch_input_elements_`, there will be a
750       // future worker available to create a new future element.
751       int future_workers =
752           dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
753       {
754         mutex_lock l(*mu_);
755         initial_current_workers = num_parallel_calls_->value;
756         outstanding_threads_ += initial_current_workers + future_workers;
757         num_current_workers_ += initial_current_workers;
758         num_active_workers_ += initial_current_workers + future_workers;
759         num_current_active_workers_ += initial_current_workers;
760       }
761       // Start current workers before future workers to improve startup time.
762       for (int i = 0; i < initial_current_workers; ++i) {
763         StartCurrentWorkerThread();
764       }
765       for (int i = 0; i < future_workers; ++i) {
766         StartFutureWorkerThread();
767       }
768       while (true) {
769         {
770           mutex_lock l(*mu_);
771           while (!cancelled_ &&
772                  num_current_workers_ >= num_parallel_calls_->value) {
773             RecordStop(ctx_.get());
774             num_parallel_calls_cond_var_->wait(l);
775             RecordStart(ctx_.get());
776           }
777           if (cancelled_ || end_of_input_) {
778             return;
779           }
780           IncrementOutstandingThreads();
781           IncrementCurrentWorkers();
782           IncrementActiveWorkers();
783           IncrementCurrentActiveWorkers();
784           StartCurrentWorkerThread();
785         }
786       }
787     }
788 
789     void StartCurrentWorkerThread() {
790       thread_pool_->Schedule([this]() { CurrentWorkerThread(); });
791     }
792 
793     void StartFutureWorkerThread() {
794       thread_pool_->Schedule([this]() { FutureWorkerThread(); });
795     }
796 
797     // Current workers are responsible for keeping elements in
798     // `current_elements_` processed. An element is processed if it is either
799     // done or its `results` buffer is full (contains `buffer_output_elements_`
800     // results).
801     //
802     // Current workers cycle between two phases: (1) finding an element and (2)
803     // processing it. When a worker is processing an element, it will
804     // claim the element by setting `element->active`, then continue to produce
805     // results for the element until enough results have been computed for the
806     // current cycle and the results buffer is full.
807     void CurrentWorkerThread() TF_LOCKS_EXCLUDED(mu_) {
808       RecordStart(ctx_.get());
809       auto done = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
810         RecordStop(ctx_.get());
811         DecrementActiveWorkers();
812         DecrementCurrentActiveWorkers();
813         DecrementOutstandingThreads();
814         DecrementCurrentWorkers();
815       };
816       while (true) {
817         int element_index;
818         std::shared_ptr<Element> element;
819         // Find an element to process.
820         {
821           mutex_lock l(*mu_);
822           // In case autotune changes num_parallel_calls.
823           if (num_current_workers_ > num_parallel_calls_->value) {
824             done();
825             return;
826           }
827           // Look for an element that needs processing.
828           element.reset();
829           while (!cancelled_) {
830             while (!elements_to_process_.empty() && !wait_for_checkpoint_) {
831               int index = elements_to_process_.front();
832               elements_to_process_.pop_front();
833               auto& e = current_elements_[index];
834               if (NeedsProcessing(e) && !e->active) {
835                 element_index = index;
836                 element = e;
837                 break;
838               }
839             }
840             if (element) {
841               break;
842             }
843             DecrementCurrentActiveWorkers();
844             WaitWorkerThread(&current_workers_cond_var_, &l);
845             IncrementCurrentActiveWorkers();
846           }
847           if (cancelled_) {
848             done();
849             return;
850           }
851           VLOG(3) << "Current worker woke up to process " << element->id;
852           element->active = true;
853         }
854         // Loop on the element until we fill its results buffer or reach end of
855         // input for the element.
856         while (true) {
857           ProcessElement(element);
858           {
859             mutex_lock l(*mu_);
860             // Check whether we have produced enough results for the current
861             // cycle.
862             if (!NeedsProcessing(element)) {
863               element->active = false;
864               break;
865             }
866           }
867         }
868       }
869     }
870 
871     // Future workers process elements after the current interleave cycle. A
872     // future worker's job is to keep `future_elements_` filled with elements.
873     // Elements in `future_elements_` have had their first
874     // `buffer_output_elements_` results computed.
875     void FutureWorkerThread() TF_LOCKS_EXCLUDED(mu_) {
876       RecordStart(ctx_.get());
877       auto done = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
878         RecordStop(ctx_.get());
879         DecrementActiveWorkers();
880         DecrementOutstandingThreads();
881       };
882       std::shared_ptr<Element> element;
883       while (true) {
884         {
885           mutex_lock l(*mu_);
886           if (element) {
887             element->active = false;
888             if (element->cycle_index != -1) {
889               element->cond_var.notify_one();
890               // A current worker may need to process the element further.
891               elements_to_process_.push_back(element->cycle_index);
892               current_workers_cond_var_.notify_one();
893             }
894           }
895           while (!cancelled_ && (future_elements_.size() >=
896                                      dataset()->prefetch_input_elements_ ||
897                                  wait_for_checkpoint_)) {
898             WaitWorkerThread(&future_workers_cond_var_, &l);
899           }
900           if (cancelled_) {
901             done();
902             return;
903           }
904           element = MakeElement();
905           if (!element) {
906             done();
907             return;
908           }
909           VLOG(3) << "Future worker created element " << element->id;
910           element->active = true;
911           future_elements_.push_back(element);
912         }
913         ProcessElement(element);
914       }
915     }
916 
917     // Generates results for the given element until the element's results
918     // buffer is full or the element is done producing results.
919     void ProcessElement(std::shared_ptr<Element> element)
920         TF_LOCKS_EXCLUDED(mu_) {
921       DCHECK(element != nullptr);
922       IteratorBase* iterator;
923       int64 input_element_id;
924       // Initialize the inputs and iterator if necessary.
925       {
926         mutex_lock l(*mu_);
927         DCHECK(element->active);
928         input_element_id = element->id;
929         if (!element->iterator) {
930           InitializeInputs(input_element_id);
931           if (!element->iterator) {
932             return;
933           }
934         }
935         // `iterator` will remain valid after releasing the lock because we have
936         // marked the element as active, so no other thread will modify its
937         // iterator.
938         iterator = element->iterator.get();
939       }
940       DCHECK(iterator != nullptr);
941       // Process until the results queue is full or we reach end of input.
942       while (true) {
943         auto result = std::make_shared<Result>();
944         profiler::TraceMe traceme([&] {
945           result->id = profiler::TraceMe::NewActivityId();
946           return profiler::TraceMeEncode(
947               "ParallelInterleaveProduce",
948               {{"input_element_id", input_element_id},
949                {"element_id", result->id}});
950         });
951         bool end_of_input = false;
952         result->status = iterator->GetNext(ctx_.get(), &result->return_values,
953                                            &end_of_input);
954         if (end_of_input) {
955           mutex_lock l(*mu_);
956           element->iterator.reset();
957           element->inputs.reset();
958           NotifyElementUpdate(element);
959           break;
960         }
961         RecordBufferEnqueue(ctx_.get(), result->return_values);
962         mutex_lock l(*mu_);
963         element->results.push_back(std::move(result));
964         NotifyElementUpdate(element);
965         if (element->results.size() == dataset()->buffer_output_elements_) {
966           break;
967         }
968       }
969     }
970 
971     // Initialize inputs and create an iterator for all elements up to
972     // element_id.
973     void InitializeInputs(int element_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
974       while (!uninitialized_elements_.empty() &&
975              uninitialized_elements_.front()->id <= element_id) {
976         std::shared_ptr<Element> element = uninitialized_elements_.front();
977         uninitialized_elements_.pop_front();
978         element->initialized = true;
979         // Check if we've already reached end of input.
980         if (end_of_input_) {
981           element->no_input = true;
982           NotifyElementUpdate(element);
983           continue;
984         }
985         profiler::TraceMe traceme([input_element_id = element->id] {
986           return profiler::TraceMeEncode(
987               "ParallelInterleaveInitializeInput",
988               {{"input_element_id", input_element_id}});
989         });
990         std::vector<Tensor> inputs;
991         Status status;
992         {
993           // TODO(aaudibert): Refactor the implementation to move calls of
994           // `GetNext` out of the scope of `mu_`.
995           status = input_impl_->GetNext(ctx_.get(), &inputs, &end_of_input_);
996         }
997         if (!status.ok()) {
998           AddErrorResult(element, status);
999           continue;
1000         }
1001         if (end_of_input_) {
1002           element->no_input = true;
1003           NotifyElementUpdate(element);
1004           continue;
1005         }
1006         element->inputs =
1007             absl::make_unique<std::vector<Tensor>>(std::move(inputs));
1008         status = MakeIteratorFromInputElement(
1009             ctx_.get(), this, *element->inputs, element->id,
1010             *instantiated_captured_func_, prefix(), &element->iterator,
1011             model_node());
1012         if (!status.ok()) {
1013           element->inputs.reset();
1014           element->iterator.reset();
1015           AddErrorResult(element, status);
1016           continue;
1017         }
1018         if (element->cycle_index == -1) {
1019           DisableAutotune(ctx_.get(), element->iterator.get());
1020         }
1021       }
1022     }
1023 
1024     // Adds an error result for the given element.
1025     void AddErrorResult(std::shared_ptr<Element> element, Status status)
1026         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1027       auto result = std::make_shared<Result>();
1028       result->status = status;
1029       element->results.push_back(std::move(result));
1030       NotifyElementUpdate(element);
1031     }
1032 
1033     // Cancels all threads (including the manager) and waits for them to finish.
1034     void StopAllThreads(mutex_lock* l) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {}
1035 
1036     // Waits on the given cond_var in a worker thread.
1037     void WaitWorkerThread(condition_variable* cond_var, mutex_lock* l)
1038         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1039       DecrementActiveWorkers();
1040       RecordStop(ctx_.get());
1041       cond_var->wait(*l);
1042       RecordStart(ctx_.get());
1043       IncrementActiveWorkers();
1044     }
1045 
1046     void NotifyElementUpdate(std::shared_ptr<Element> element)
1047         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1048       if (deterministic_) {
1049         element->cond_var.notify_one();
1050       } else {
1051         any_element_available_cond_var_.notify_one();
1052       }
1053     }
1054 
1055     bool NeedsProcessing(const std::shared_ptr<Element>& element)
1056         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1057       if (!element) {
1058         return false;
1059       }
1060       if (!element->initialized) {
1061         return true;
1062       }
1063       return element->iterator &&
1064              element->results.size() < dataset()->buffer_output_elements_;
1065     }
1066 
1067     inline void IncrementCurrentWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1068       num_current_workers_++;
1069     }
1070 
1071     inline void DecrementCurrentWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1072       num_current_workers_--;
1073     }
1074 
1075     inline void IncrementActiveWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1076       num_active_workers_++;
1077     }
1078 
1079     inline void DecrementActiveWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1080       num_active_workers_--;
1081       if (num_active_workers_ == 0) {
1082         zero_active_workers_cond_var_.notify_one();
1083       }
1084     }
1085 
1086     inline void IncrementCurrentActiveWorkers()
1087         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1088       num_current_active_workers_++;
1089     }
1090 
1091     inline void DecrementCurrentActiveWorkers()
1092         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1093       num_current_active_workers_--;
1094     }
1095 
1096     inline void IncrementOutstandingThreads() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1097       outstanding_threads_++;
1098     }
1099 
1100     inline void DecrementOutstandingThreads() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1101       outstanding_threads_--;
1102       if (outstanding_threads_ == 0) {
1103         outstanding_threads_finished_cond_var_.notify_one();
1104       }
1105     }
1106 
1107     void StatsThread() {
1108       for (int64 step = 0;; ++step) {
1109         int num_current_active_workers;
1110         int num_current_workers;
1111         {
1112           mutex_lock l(*mu_);
1113           if (step != 0 && !cancelled_) {
1114             stats_thread_cond_var_.wait_for(
1115                 l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
1116           }
1117           if (cancelled_) {
1118             DecrementOutstandingThreads();
1119             return;
1120           }
1121           num_current_active_workers = num_current_active_workers_;
1122           num_current_workers = num_current_workers_;
1123         }
1124         if (num_current_workers == 0) {
1125           // Avoid division by zero.
1126           num_current_workers = 1;
1127         }
1128         ctx_->stats_aggregator()->AddScalar(
1129             stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
1130             static_cast<float>(num_current_active_workers) /
1131                 static_cast<float>(num_current_workers),
1132             step);
1133       }
1134     }
1135 
1136     Status WriteStatusLocked(IteratorStateWriter* writer,
1137                              const string& iterator_name, size_t idx,
1138                              const Status& status)
1139         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1140       TF_RETURN_IF_ERROR(writer->WriteScalar(
1141           iterator_name, CodeKey(idx), static_cast<int64>(status.code())));
1142       if (!status.ok()) {
1143         TF_RETURN_IF_ERROR(writer->WriteScalar(
1144             iterator_name, ErrorMessageKey(idx), status.error_message()));
1145       }
1146       return Status::OK();
1147     }
1148 
1149     Status ReadStatusLocked(IteratorStateReader* reader,
1150                             const string& iterator_name, size_t idx,
1151                             Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1152       int64 code_int;
1153       TF_RETURN_IF_ERROR(
1154           reader->ReadScalar(iterator_name, CodeKey(idx), &code_int));
1155       error::Code code = static_cast<error::Code>(code_int);
1156 
1157       if (code != error::Code::OK) {
1158         tstring error_message;
1159         TF_RETURN_IF_ERROR(reader->ReadScalar(
1160             iterator_name, ErrorMessageKey(idx), &error_message));
1161         *status = Status(code, error_message);
1162       } else {
1163         *status = Status::OK();
1164       }
1165       return Status::OK();
1166     }
1167 
1168     string CodeKey(size_t idx) {
1169       return absl::StrCat(kResultsSuffix, "[", idx, "]", kCodeSuffix);
1170     }
1171 
1172     string ErrorMessageKey(size_t idx) {
1173       return absl::StrCat(kResultsSuffix, "[", idx, "]", kErrorMessageSuffix);
1174     }
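    // For a result at index 2, for example, these helpers produce the keys
    // ".results[2].code" and ".results[2].error_message" respectively.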
1175 
1176     Status WriteElement(SerializationContext* ctx,
1177                         std::shared_ptr<Element> element, int idx,
1178                         const string& key_prefix, IteratorStateWriter* writer)
1179         TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
1180       const auto& iterator_name =
1181           absl::StrCat(prefix(), "::", key_prefix, "::", idx);
1182       if (element->iterator) {
1183         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, element->iterator));
1184         TF_RETURN_IF_ERROR(
1185             writer->WriteScalar(iterator_name, kIdSuffix, element->id));
1186         TF_RETURN_IF_ERROR(writer->WriteScalar(
1187             iterator_name, absl::StrCat(kInputsSuffix, kSizeSuffix),
1188             element->inputs->size()));
1189         for (int i = 0; i < element->inputs->size(); i++) {
1190           TF_RETURN_IF_ERROR(writer->WriteTensor(
1191               iterator_name, absl::StrCat(kInputsSuffix, "[", i, "]"),
1192               element->inputs->at(i)));
1193         }
1194       }
1195       TF_RETURN_IF_ERROR(writer->WriteScalar(
1196           iterator_name, absl::StrCat(kResultsSuffix, kSizeSuffix),
1197           element->results.size()));
1198       for (size_t i = 0; i < element->results.size(); i++) {
1199         std::shared_ptr<Result> result = element->results[i];
1200         TF_RETURN_IF_ERROR(
1201             WriteStatusLocked(writer, iterator_name, i, result->status));
1202         TF_RETURN_IF_ERROR(writer->WriteScalar(
1203             iterator_name,
1204             absl::StrCat(kResultsSuffix, "[", i, "]", kSizeSuffix),
1205             result->return_values.size()));
1206         for (size_t j = 0; j < result->return_values.size(); j++) {
1207           TF_RETURN_IF_ERROR(writer->WriteTensor(
1208               iterator_name, absl::StrCat(kResultsSuffix, "[", i, "][", j, "]"),
1209               result->return_values[j]));
1210         }
1211         TF_RETURN_IF_ERROR(writer->WriteScalar(
1212             iterator_name,
1213             absl::StrCat(kResultsSuffix, "[", i, "]", kIsReadySuffix), ""));
1214       }
1215       return Status::OK();
1216     }
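    // For example, the element saved as current element 2 is written under the
    // iterator name "<prefix>::current_elements::2", with each of its buffered
    // result tensors keyed as ".results[i][j]" within that name (layout as
    // implied by the WriteScalar/WriteTensor calls above).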
1217 
1218     Status WriteCurrentElements(SerializationContext* ctx,
1219                                 IteratorStateWriter* writer)
1220         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1221       TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentElementsSize,
1222                                              current_elements_.size()));
1223       for (int idx = 0; idx < current_elements_.size(); idx++) {
1224         if (current_elements_[idx]) {
1225           TF_RETURN_IF_ERROR(WriteElement(ctx, current_elements_[idx], idx,
1226                                           kCurrentElements, writer));
1227         }
1228       }
1229       return Status::OK();
1230     }
1231 
1232     Status WriteFutureElements(SerializationContext* ctx,
1233                                IteratorStateWriter* writer)
1234         TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1235       TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kFutureElementsSize,
1236                                              future_elements_.size()));
1237       for (int idx = 0; idx < future_elements_.size(); idx++) {
1238         if (future_elements_[idx]) {
1239           TF_RETURN_IF_ERROR(WriteElement(ctx, future_elements_[idx], idx,
1240                                           kFutureElements, writer));
1241         }
1242       }
1243       return Status::OK();
1244     }
1245 
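    // Restores one element: first the buffered results (including their
    // statuses), then the input tensors and element id, from which the
    // per-element iterator is recreated. RestoreInput runs with `mu_`
    // released so that parallel restores are not serialized on the lock.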
1246     Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader,
1247                        int idx, const string& key_prefix,
1248                        std::shared_ptr<Element>* out) {
1249       std::unique_ptr<IteratorBase> iterator;
1250       auto element = std::make_shared<Element>();
1251       {
1252         mutex_lock l(*mu_);
1253         const auto& iterator_name =
1254             absl::StrCat(prefix(), "::", key_prefix, "::", idx);
1255         if (!reader->Contains(iterator_name,
1256                               absl::StrCat(kResultsSuffix, kSizeSuffix))) {
1257           return Status::OK();
1258         }
1259         int64 results_size;
1260         TF_RETURN_IF_ERROR(reader->ReadScalar(
1261             iterator_name, absl::StrCat(kResultsSuffix, kSizeSuffix),
1262             &results_size));
1263         element->results.resize(results_size);
1264         for (size_t i = 0; i < results_size; i++) {
1265           auto result = std::make_shared<Result>();
1266           TF_RETURN_IF_ERROR(
1267               ReadStatusLocked(reader, iterator_name, i, &result->status));
1268           int64 num_return_values;
1269           TF_RETURN_IF_ERROR(reader->ReadScalar(
1270               iterator_name,
1271               absl::StrCat(kResultsSuffix, "[", i, "]", kSizeSuffix),
1272               &num_return_values));
1273           result->return_values.reserve(num_return_values);
1274           for (size_t j = 0; j < num_return_values; j++) {
1275             result->return_values.emplace_back();
1276             TF_RETURN_IF_ERROR(reader->ReadTensor(
1277                 iterator_name,
1278                 absl::StrCat(kResultsSuffix, "[", i, "][", j, "]"),
1279                 &result->return_values.back()));
1280           }
1281           RecordBufferEnqueue(ctx, result->return_values);
1282           element->results[i] = std::move(result);
1283         }
1284         if (!reader->Contains(iterator_name,
1285                               absl::StrCat(kInputsSuffix, kSizeSuffix))) {
1286           element->iterator.reset();
1287           *out = std::move(element);
1288           return Status::OK();
1289         }
1290         int64 inputs_size;
1291         TF_RETURN_IF_ERROR(reader->ReadScalar(
1292             iterator_name, absl::StrCat(kInputsSuffix, kSizeSuffix),
1293             &inputs_size));
1294         element->inputs = std::make_unique<std::vector<Tensor>>(inputs_size);
1295         for (int i = 0; i < inputs_size; i++) {
1296           TF_RETURN_IF_ERROR(reader->ReadTensor(
1297               iterator_name, absl::StrCat(kInputsSuffix, "[", i, "]"),
1298               &element->inputs->at(i)));
1299         }
1300         TF_RETURN_IF_ERROR(
1301             reader->ReadScalar(iterator_name, kIdSuffix, &element->id));
1302         TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
1303             ctx, this, *element->inputs, element->id,
1304             *instantiated_captured_func_.get(), prefix(), &iterator,
1305             model_node()));
1306       }
1307       TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
1308       mutex_lock l(*mu_);
1309       element->iterator = std::move(iterator);
1310       *out = std::move(element);
1311       return Status::OK();
1312     }
1313 
1314     Status ReadCurrentElements(IteratorContext* ctx,
1315                                IteratorStateReader* reader) {
1316       int64 size;
1317       {
1318         mutex_lock l(*mu_);
1319         TF_RETURN_IF_ERROR(
1320             reader->ReadScalar(prefix(), kCurrentElementsSize, &size));
1321         if (current_elements_.size() != size) {
1322           // This could mean two things: (1) the user created their checkpoint
1323           // from a dataset with one cycle_length, then changed the cycle_length
1324           // and tried to restore from the old checkpoint, or (2) the user set
1325           // the cycle length to tf.data.AUTOTUNE, wrote the checkpoint from one
1326           // machine, then tried to restore the checkpoint on another machine
1327           // with a different CPU budget (causing autotune to pick a different
1328           // cycle length).
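          // For example, a checkpoint written with cycle_length = 4 cannot be
          // restored into an iterator whose cycle length resolved to 8.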
1329           return errors::FailedPrecondition(
1330               "The iterator cycle length ", current_elements_.size(),
1331               " is different from the cycle length recorded in the "
1332               "checkpoint: ",
1333               size);
1334         }
1335       }
1336       if (size == 0) {
1337         return Status::OK();
1338       }
1339       std::vector<std::shared_ptr<Element>> elements;
1340       TF_RETURN_IF_ERROR(
1341           ReadElementsParallel(ctx, reader, size, kCurrentElements, &elements));
1342       mutex_lock l(*mu_);
1343       for (auto& element : current_elements_) {
1344         DCHECK(element == nullptr);
1345       }
1346       for (int idx = 0; idx < size; ++idx) {
1347         current_elements_[idx] = std::move(elements[idx]);
1348       }
1349       return Status::OK();
1350     }
1351 
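    // Unlike the cycle elements, the number of future (prefetched) elements is
    // not tied to cycle_length, so `future_elements_` is simply resized to
    // whatever size the checkpoint recorded.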
1352     Status ReadFutureElements(IteratorContext* ctx,
1353                               IteratorStateReader* reader) {
1354       int64 size;
1355       {
1356         mutex_lock l(*mu_);
1357         TF_RETURN_IF_ERROR(
1358             reader->ReadScalar(prefix(), kFutureElementsSize, &size));
1359         future_elements_.resize(size);
1360       }
1361       if (size == 0) {
1362         return Status::OK();
1363       }
1364       std::vector<std::shared_ptr<Element>> elements;
1365       TF_RETURN_IF_ERROR(
1366           ReadElementsParallel(ctx, reader, size, kFutureElements, &elements));
1367       mutex_lock l(*mu_);
1368       for (auto& element : future_elements_) {
1369         DCHECK(element == nullptr);
1370       }
1371       for (int idx = 0; idx < size; ++idx) {
1372         future_elements_[idx] = std::move(elements[idx]);
1373       }
1374       return Status::OK();
1375     }
1376 
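    // Restores `size` elements concurrently, scheduling one task per slot on
    // `thread_pool_`. The first failure (or a cancellation) wins via
    // `s.Update()`, and the BlockingCounter ensures every scheduled task has
    // finished before the status is returned.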
1377     Status ReadElementsParallel(
1378         IteratorContext* ctx, IteratorStateReader* reader, int64 size,
1379         const string& name, std::vector<std::shared_ptr<Element>>* elements) {
1380       elements->resize(size);
1381       Status s = Status::OK();
1382       BlockingCounter counter(size);
1383       for (int idx = 0; idx < size; ++idx) {
1384         thread_pool_->Schedule(
1385             [this, ctx, reader, idx, name, &s, &counter, elements] {
1386               RecordStart(ctx);
1387               auto cleanup = gtl::MakeCleanup([this, ctx, &counter]() {
1388                 RecordStop(ctx);
1389                 counter.DecrementCount();
1390               });
1391               std::shared_ptr<Element> elem;
1392               Status ret_status = ReadElement(ctx, reader, idx, name, &elem);
1393               mutex_lock l(*mu_);
1394               if (cancelled_) {
1395                 s.Update(
1396                     errors::Cancelled("Cancelled in ReadElementsParallel"));
1397                 return;
1398               }
1399               if (!ret_status.ok()) {
1400                 s.Update(ret_status);
1401                 return;
1402               }
1403               (*elements)[idx] = elem;
1404             });
1405       }
1406       counter.Wait();
1407       return s;
1408     }
1409 
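    // Produces a human-readable dump of the iterator state, roughly of the
    // form (values illustrative):
    //   Cycle index: 2
    //   Block index: 0
    //   End of input: 0
    //   Current elements:
    //   0: <element debug string>
    //   Future elements:
    //   0: null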
1410     std::string DebugString() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
1411       std::string result;
1412       result.append(strings::StrCat("Cycle index: ", cycle_index_, "\n"));
1413       result.append(strings::StrCat("Block index: ", block_index_, "\n"));
1414       result.append(strings::StrCat("End of input: ", end_of_input_, "\n"));
1415       {
1416         result.append("Current elements:\n");
1417         for (int i = 0; i < current_elements_.size(); ++i) {
1418           string element_string = "null";
1419           if (current_elements_[i]) {
1420             element_string = current_elements_[i]->DebugString();
1421           }
1422           result.append(absl::StrFormat("%d: %s\n", i, element_string));
1423         }
1424       }
1425       {
1426         result.append("Future elements:\n");
1427         for (int i = 0; i < future_elements_.size(); ++i) {
1428           string element_string = "null";
1429           if (future_elements_[i]) {
1430             element_string = future_elements_[i]->DebugString();
1431           }
1432           result.append(absl::StrFormat("%d: %s\n", i, element_string));
1433         }
1434       }
1435       return result;
1436     }
1437 
1438     // Indices of `current_elements_` which need to be processed by a current
1439     // worker.
1440     std::deque<int> elements_to_process_;
1441 
1442     // The last index in `current_elements_` containing a non-null element.
1443     // This allows us to optimize the situation when the cycle_length is large
1444     // but the input dataset doesn't have many elements. By tracking the index
1445     // of the last valid element, GetNext can avoid checking many null entries
1446     // each time through the cycle.
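    // For example, with cycle_length = 1024 but an input dataset of only 10
    // elements, at most slots 0..9 ever hold an element, so the scan can stop
    // at index 9 instead of walking all 1024 slots.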
1447     //
1448     // TODO(aaudibert): Generalize this optimization by removing null elements
1449     // from `current_elements_`, e.g. by compacting the vector when x% of
1450     // its elements are null.
1451     int64 last_valid_current_element_ TF_GUARDED_BY(mu_) = -1;
1452 
1453     // Identifies whether the current_elements_ vector has been initialized.
1454     bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
1455 
1456     // Identifies whether the element threads have been initialized.
1457     bool threads_initialized_ TF_GUARDED_BY(mu_) = false;
1458 
1459     // Used for coordination between the main thread, the manager threads, and
1460     // the worker threads.
1461     //
1462     // NOTE: We should never call GetNext on the input while holding this mutex.
1463     const std::shared_ptr<mutex> mu_;
1464 
1465     // Condition variable for waking up current workers.
1466     condition_variable current_workers_cond_var_;
1467 
1468     // Condition variable for waking up future workers.
1469     condition_variable future_workers_cond_var_;
1470 
1471     // Condition variable for waking up the stats thread.
1472     condition_variable stats_thread_cond_var_;
1473 
1474     // Number of active worker threads which might be processing elements,
1475     // including both current workers and future workers. Used by
1476     // checkpointing to wait for outstanding work to finish.
1477     int num_active_workers_ TF_GUARDED_BY(mu_) = 0;
1478 
1479     // Number of active current worker threads.
1480     int num_current_active_workers_ TF_GUARDED_BY(mu_) = 0;
1481 
1482     // Condition variable notified whenever the total number of active workers
1483     // drops to zero. Used for checkpointing.
1484     condition_variable zero_active_workers_cond_var_;
1485 
1486     // Condition notified whenever num_parallel_calls_ changes. Shared so that
1487     // autotuning can notify us when num_parallel_calls_ changes.
1488     std::shared_ptr<condition_variable> num_parallel_calls_cond_var_;
1489 
1490     // Identifies the maximum number of parallel calls.
1491     const std::shared_ptr<model::SharedState> num_parallel_calls_;
1492 
1493     // The number of current workers that are alive or scheduled to be started.
1494     // This includes current workers which are blocked waiting for work.
1495     int num_current_workers_ TF_GUARDED_BY(mu_) = 0;
1496 
1497     // Condition variable to signal that a result has been produced by some
1498     // element thread. Only used when `deterministic` is false.
1499     condition_variable any_element_available_cond_var_;
1500 
1501     // Determines whether outputs can be produced in deterministic order.
1502     const bool deterministic_;
1503 
1504     // Iterator for input elements.
1505     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
1506 
1507     // Identifies position in the interleave cycle.
1508     int64 block_index_ TF_GUARDED_BY(mu_) = 0;
1509     // It is an invariant that either `last_valid_current_element_ == -1` or
1510     // `cycle_index_ <= last_valid_current_element_`.
1511     int64 cycle_index_ TF_GUARDED_BY(mu_) = 0;
1512 
1513     // Elements of the current interleave cycle.
1514     std::vector<std::shared_ptr<Element>> current_elements_ TF_GUARDED_BY(mu_);
1515 
1516     // Elements which still need their inputs and iterators to be initialized.
1517     // Elements at the front need to be initialized first.
1518     std::deque<std::shared_ptr<Element>> uninitialized_elements_
1519         TF_GUARDED_BY(mu_);
1520 
1521     // Elements to be used in the interleave cycle in the future. The element
1522     // at the front is the next element to add to the interleave cycle when a
1523     // current element is exhausted.
1524     std::deque<std::shared_ptr<Element>> future_elements_ TF_GUARDED_BY(mu_);
1525 
1526     // Identifies whether the global end of input has been reached.
1527     bool end_of_input_ TF_GUARDED_BY(mu_) = false;
1528 
1529     // The number of outstanding element threads.
1530     int outstanding_threads_ TF_GUARDED_BY(mu_) = 0;
1531 
1532     // Condition variable notified when outstanding_threads_ drops to 0.
1533     condition_variable outstanding_threads_finished_cond_var_;
1534 
1535     std::unique_ptr<thread::ThreadPool> thread_pool_;
1536 
1537     int64 element_id_counter_ TF_GUARDED_BY(mu_) = 0;
1538 
1539     // Iterator context used in worker threads.
1540     std::unique_ptr<IteratorContext> ctx_;
1541 
1542     // Set to true during checkpointing to alert element threads that they
1543     // should pause operation. This is needed to prevent constantly-active
1544     // worker threads from blocking checkpointing indefinitely.
1545     bool wait_for_checkpoint_ = false;
1546 
1547     std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
1548 
1549     // Identifies whether background threads should be cancelled.
1550     bool cancelled_ TF_GUARDED_BY(mu_) = false;
1551 
1552     // Method for deregistering the cancellation callback.
1553     std::function<void()> deregister_fn_;
1554   };
1555 
1556   const DatasetBase* const input_;
1557   const std::unique_ptr<CapturedFunction> captured_func_;
1558   const int64 cycle_length_;
1559   const int64 block_length_;
1560   const int64 buffer_output_elements_;
1561   const int64 prefetch_input_elements_;
1562   const int64 num_parallel_calls_;
1563   const DeterminismPolicy deterministic_;
1564   const DataTypeVector output_types_;
1565   const std::vector<PartialTensorShape> output_shapes_;
1566   const int op_version_;
1567   const TraceMeMetadata traceme_metadata_;
1568 };
1569 
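// The kernel is shared by several op versions: version 2 exposes the boolean
// `sloppy` attr, versions 3+ replace it with the string `deterministic` attr,
// and version 4+ additionally take `buffer_output_elements` and
// `prefetch_input_elements` as inputs (handled in MakeDataset below).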
1570 ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
1571     OpKernelConstruction* ctx)
1572     : UnaryDatasetOpKernel(ctx),
1573       op_version_(OpVersionFromOpName(ctx->def().op())) {
1574   OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
1575                                                &func_metadata_));
1576   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
1577   OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
1578   if (op_version_ == 2) {
1579     bool sloppy;
1580     OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy));
1581     if (sloppy) {
1582       deterministic_ =
1583           DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
1584     } else {
1585       deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
1586     }
1587   }
1588   if (op_version_ >= 3) {
1589     std::string deterministic;
1590     OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
1591     OP_REQUIRES_OK(
1592         ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
1593   }
1594 }
1595 
1596 void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
1597                                               DatasetBase* input,
1598                                               DatasetBase** output) {
1599   int64 block_length = 0;
1600   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
1601   OP_REQUIRES(ctx, block_length > 0,
1602               errors::InvalidArgument("`block_length` must be > 0"));
1603 
1604   int64 buffer_output_elements = model::kAutotune;
1605   int64 prefetch_input_elements = model::kAutotune;
1606   if (op_version_ >= 4) {
1607     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
1608                                             &buffer_output_elements));
1609     OP_REQUIRES(ctx,
1610                 buffer_output_elements == model::kAutotune ||
1611                     buffer_output_elements > 0,
1612                 errors::InvalidArgument("`buffer_output_elements` must be ",
1613                                         model::kAutotune, " or > 0 but is ",
1614                                         buffer_output_elements));
1615 
1616     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
1617                                             &prefetch_input_elements));
1618     OP_REQUIRES(ctx,
1619                 prefetch_input_elements == model::kAutotune ||
1620                     prefetch_input_elements >= 0,
1621                 errors::InvalidArgument("`prefetch_input_elements` must be ",
1622                                         model::kAutotune, " or >= 0 but is ",
1623                                         prefetch_input_elements));
1624   }
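  // For op versions below 4 these two values stay at model::kAutotune, leaving
  // the per-element output buffer and input prefetch sizes to autotuning.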
1625 
1626   int64 num_parallel_calls = 0;
1627   OP_REQUIRES_OK(
1628       ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
1629   OP_REQUIRES(
1630       ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
1631       errors::InvalidArgument("num_parallel_calls must be greater than zero."));
1632   int64 cycle_length = 0;
1633   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
1634   if (cycle_length == model::kAutotune) {
1635     if (num_parallel_calls != model::kAutotune) {
1636       cycle_length = std::min(num_parallel_calls,
1637                               static_cast<int64>(port::MaxParallelism()));
1638     } else {
1639       // If parallelism is to be autotuned, we set the cycle length so that
1640       // the number of threads created for the current and future cycle elements
1641       // roughly matches the number of schedulable cores.
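      // For example, if kDefaultCyclePrefetchFactor were 2 (i.e. three threads
      // per cycle element) and port::MaxParallelism() returned 12, this would
      // pick cycle_length = CeilDiv(12, 3) = 4. (Values illustrative.)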
1642       const int num_threads_per_cycle_length = kDefaultCyclePrefetchFactor + 1;
1643       cycle_length =
1644           CeilDiv(port::MaxParallelism(), num_threads_per_cycle_length);
1645     }
1646   }
1647   OP_REQUIRES(ctx, cycle_length > 0,
1648               errors::InvalidArgument("`cycle_length` must be > 0"));
1649 
1650   OP_REQUIRES(
1651       ctx, num_parallel_calls <= cycle_length,
1652       errors::InvalidArgument(
1653           "num_parallel_calls must be less than or equal to cycle_length."));
1654 
1655   std::unique_ptr<CapturedFunction> captured_func;
1656   OP_REQUIRES_OK(ctx,
1657                  CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
1658                                           &captured_func));
1659 
1660   if (num_parallel_calls == model::kAutotune) {
1661     metrics::RecordTFDataAutotune(kDatasetType);
1662   }
1663 
1664   *output = new Dataset(
1665       ctx, input, std::move(captured_func), cycle_length, block_length,
1666       buffer_output_elements, prefetch_input_elements, num_parallel_calls,
1667       deterministic_, output_types_, output_shapes_, op_version_);
1668 }
1669 
1670 namespace {
1671 REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
1672                         ParallelInterleaveDatasetOp);
1673 REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
1674                         ParallelInterleaveDatasetOp);
1675 REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
1676                         ParallelInterleaveDatasetOp);
1677 REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
1678 REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
1679 REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
1680 }  // namespace
1681 }  // namespace data
1682 }  // namespace tensorflow
1683