/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_interleave_dataset_op.h"

#include <atomic>
#include <deque>
#include <memory>
#include <utility>

#include "absl/strings/str_format.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/kernels/data/stats_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDatasetType;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kInputDataset;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOtherArguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kCycleLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBlockLength;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kBufferOutputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kPrefetchInputElements;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kFunc;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kTarguments;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputTypes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    ParallelInterleaveDatasetOp::kDeterministic;
/* static */ constexpr const char* const ParallelInterleaveDatasetOp::kSloppy;

namespace {

constexpr char kTfDataParallelInterleaveWorkerPool[] =
    "tf_data_parallel_interleave_worker_pool";
constexpr char kParallelism[] = "parallelism";
constexpr char kBlockIndex[] = "block_index";
constexpr char kCycleIndex[] = "cycle_index";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kElementIdCounter[] = "element_id_counter";
constexpr char kCurrentElements[] = "current_elements";
constexpr char kCurrentElementsSize[] = "current_elements.size";
constexpr char kFutureElements[] = "future_elements";
constexpr char kFutureElementsSize[] = "future_elements.size";
constexpr char kResultsSuffix[] = ".results";
constexpr char kCodeSuffix[] = ".code";
constexpr char kErrorMessageSuffix[] = ".error_message";
constexpr char kIdSuffix[] = ".id";
constexpr char kSizeSuffix[] = ".size";
constexpr char kInputsSuffix[] = ".inputs";
constexpr char kIsReadySuffix[] = ".is_ready";

constexpr char kParallelInterleaveDatasetV2[] = "ParallelInterleaveDatasetV2";
constexpr char kParallelInterleaveDatasetV3[] = "ParallelInterleaveDatasetV3";
constexpr char kParallelInterleaveDatasetV4[] = "ParallelInterleaveDatasetV4";

// `kDefaultCyclePrefetchFactor * cycle_length` is the default number of future
// cycle elements that will be prefetched ahead of time. The purpose of
// prefetching future cycle elements is to overlap expensive initialization
// (e.g. opening of a remote file) with other computation.
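// For example, with cycle_length = 10 the default is to prefetch
// 2.0 * 10 = 20 future cycle elements.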
constexpr double kDefaultCyclePrefetchFactor = 2.0L;

// `kDefaultPerIteratorPrefetchFactor * block_length + 1` is the default number
// of per-iterator results that will be prefetched ahead of time. The `+ 1` is
// to match the behavior of the original implementation.
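// For example, with block_length = 16 each iterator buffers up to
// 2.0 * 16 + 1 = 33 results.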
constexpr double kDefaultPerIteratorPrefetchFactor = 2.0L;

// Period between reporting dataset statistics.
constexpr int kStatsReportingPeriodMillis = 1000;

inline int64 CeilDiv(int64 numerator, int64 denominator) {
  return (numerator + denominator - 1) / denominator;
}

int64 ComputeBufferOutputElements(int64 configured_buffer_output_elements,
                                  int64 block_length) {
  if (configured_buffer_output_elements != model::kAutotune) {
    return configured_buffer_output_elements;
  }
  return kDefaultPerIteratorPrefetchFactor * block_length + 1;
}

int64 ComputePrefetchInputElements(int64 configured_prefetch_input_elements,
                                   int64 cycle_length) {
  if (configured_prefetch_input_elements != model::kAutotune) {
    return configured_prefetch_input_elements;
  }
  return kDefaultCyclePrefetchFactor * cycle_length;
}

int64 OpVersionFromOpName(absl::string_view op_name) {
  if (op_name == kParallelInterleaveDatasetV2) {
    return 2;
  } else if (op_name == kParallelInterleaveDatasetV3) {
    return 3;
  } else {
    DCHECK_EQ(op_name, kParallelInterleaveDatasetV4);
    return 4;
  }
}

}  // namespace

// The motivation for creating an alternative implementation of parallel
// interleave is to decouple the degree of parallelism from the cycle length.
// This makes it possible to change the degree of parallelism (e.g. through
// auto-tuning) without changing the cycle length (which would change the order
// in which elements are produced).
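// Concretely, `num_parallel_calls` bounds how many "current" worker threads
// may produce results at the same time, while `cycle_length` fixes how many
// open input elements the output cycles over, so autotuning the former does
// not reorder the output of a deterministic pipeline.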
class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input,
          std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
          int64 block_length, int64 buffer_output_elements,
          int64 prefetch_input_elements, int64 num_parallel_calls,
          DeterminismPolicy deterministic, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        captured_func_(std::move(captured_func)),
        cycle_length_(cycle_length),
        block_length_(block_length),
        buffer_output_elements_(
            ComputeBufferOutputElements(buffer_output_elements, block_length)),
        prefetch_input_elements_(ComputePrefetchInputElements(
            prefetch_input_elements, cycle_length)),
        num_parallel_calls_(num_parallel_calls),
        deterministic_(deterministic),
        output_types_(output_types),
        output_shapes_(output_shapes),
        op_version_(op_version),
        traceme_metadata_(
            {{"autotune",
              num_parallel_calls == model::kAutotune ? "true" : "false"},
             {"block_length",
              strings::Printf("%lld", static_cast<long long>(block_length))},
             {"cycle_length",
              strings::Printf("%lld", static_cast<long long>(cycle_length))},
             {"deterministic",
              deterministic.IsNondeterministic() ? "false" : "true"}}) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    bool deterministic =
        deterministic_.IsDeterministic() || deterministic_.IsDefault();
    return absl::make_unique<ParallelInterleaveIterator>(
        ParallelInterleaveIterator::Params{
            this,
            name_utils::IteratorPrefix(
                ParallelInterleaveDatasetOp::kDatasetType, prefix, params)},
        deterministic);
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    return name_utils::DatasetDebugString(
        ParallelInterleaveDatasetOp::kDatasetType, params);
  }

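  // The cardinality of the interleaved output is not known statically: it
  // depends on how many elements each dataset produced by the user-defined
  // function yields. An infinite input is reported as producing an infinite
  // output; everything else is reported as unknown.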
  int64 Cardinality() const override {
    int64 n = input_->Cardinality();
    if (n == kInfiniteCardinality) {
      return n;
    }
    return kUnknownCardinality;
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    std::vector<std::pair<size_t, Node*>> inputs;
    std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>> list_inputs;
    int input_index = 0;

    Node* input_node;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
    inputs.emplace_back(input_index++, input_node);

    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    list_inputs.emplace_back(input_index++, other_arguments);

    Node* cycle_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
    inputs.emplace_back(input_index++, cycle_length_node);

    Node* block_length_node;
    TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
    inputs.emplace_back(input_index++, block_length_node);

    if (op_version_ >= 4) {
      Node* buffer_output_elements_node;
      TF_RETURN_IF_ERROR(
          b->AddScalar(buffer_output_elements_, &buffer_output_elements_node));
      inputs.emplace_back(input_index++, buffer_output_elements_node);

      Node* prefetch_input_elements_node;
      TF_RETURN_IF_ERROR(b->AddScalar(prefetch_input_elements_,
                                      &prefetch_input_elements_node));
      inputs.emplace_back(input_index++, prefetch_input_elements_node);
    }

    Node* num_parallel_calls_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
    inputs.emplace_back(input_index++, num_parallel_calls_node);

    std::vector<std::pair<StringPiece, AttrValue>> attrs;
    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    attrs.emplace_back(kFunc, f);

    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    attrs.emplace_back(kTarguments, other_arguments_types_attr);

    if (op_version_ == 2) {
      AttrValue sloppy_attr;
      b->BuildAttrValue(deterministic_.IsNondeterministic(), &sloppy_attr);
      attrs.emplace_back(kSloppy, sloppy_attr);
    }
    if (op_version_ >= 3) {
      AttrValue deterministic_attr;
      b->BuildAttrValue(deterministic_.String(), &deterministic_attr);
      attrs.emplace_back(kDeterministic, deterministic_attr);
    }

    TF_RETURN_IF_ERROR(b->AddDataset(this, inputs, list_inputs, attrs, output));
    return Status::OK();
  }

 private:
  class ParallelInterleaveIterator : public DatasetIterator<Dataset> {
   public:
    ParallelInterleaveIterator(const Params& params, bool deterministic)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          num_parallel_calls_cond_var_(std::make_shared<condition_variable>()),
          num_parallel_calls_(std::make_shared<model::SharedState>(
              params.dataset->num_parallel_calls_, mu_,
              num_parallel_calls_cond_var_)),
          deterministic_(deterministic),
          current_elements_(params.dataset->cycle_length_) {}

    ~ParallelInterleaveIterator() override {
      CancelThreads(/*wait=*/true);
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      // Note that if `ctx->thread_pool()` is non-null, then instead of creating
      // a dedicated thread pool of size `num_threads`, computation will be
      // scheduled into the shared threadpool. The threadpool is guaranteed to
      // support `num_threads` concurrent tasks without blocking indefinitely.
      //
      // Allocate one thread for the worker manager, one thread for stats
      // collection, `cycle_length_` threads for the current workers, and
      // `prefetch_input_elements_ + cycle_length_` threads for the future
      // workers.
      int max_current_workers = dataset()->cycle_length_;
      int future_workers =
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
      int num_threads = 1 + max_current_workers + future_workers;
      if (ctx->stats_aggregator()) {
        num_threads++;
      }
      thread_pool_ = ctx->CreateThreadPool(kTfDataParallelInterleaveWorkerPool,
                                           num_threads);
      if (num_parallel_calls_->value == model::kAutotune) {
        num_parallel_calls_->value = dataset()->cycle_length_;
      }
      // TODO(jsimsa): Register cancellation callback once the implementation is
      // refactored not to hold mu_ while calling `GetNext` on the input.
      ctx_ = std::make_unique<IteratorContext>(*ctx);
      TF_RETURN_IF_ERROR(
          dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

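    // Waits (via `Consume`) until either a buffered result is available for
    // the current cycle position (deterministic mode) or any element has a
    // result ready (nondeterministic mode), then returns it. Reaching the end
    // of all inputs is signaled by a null result.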
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      std::shared_ptr<Result> result;
      {
        mutex_lock l(*mu_);
        EnsureInitialElementsCreated();
        EnsureThreadsStarted();
        while (!cancelled_ && !Consume(&result)) {
          RecordStop(ctx);
          if (deterministic_) {
            VLOG(3) << "Blocked waiting for element "
                    << current_elements_[cycle_index_]->id;
            current_elements_[cycle_index_]->cond_var.wait(l);
          } else {
            any_element_available_cond_var_.wait(l);
          }
          RecordStart(ctx);
        }
        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }
      }
      if (!result) {
        *end_of_sequence = true;
        return Status::OK();
      }
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("ParallelInterleaveConsume",
                                       {{"element_id", result->id}});
      });
      if (result->status.ok()) {
        *out_tensors = std::move(result->return_values);
        RecordBufferDequeue(ctx, *out_tensors);
      }
      *end_of_sequence = false;
      return result->status;
    }

   protected:
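    // Models this iterator as an asynchronous interleave-many node whose
    // `parallelism` parameter ranges from 1 to `cycle_length_`; this is the
    // knob the autotuner adjusts when `num_parallel_calls` is
    // `model::kAutotune`.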
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncInterleaveManyNode(
          std::move(args),
          {model::MakeParameter(kParallelism, num_parallel_calls_, /*min=*/1,
                                /*max=*/dataset()->cycle_length_)});
    }

    // TODO(aaudibert): Refactor the implementations to avoid the need for
    // `IteratorContext` when saving the state of the iterator.
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(*mu_);
      wait_for_checkpoint_ = true;
      // Wait for all in-flight calls to complete.
      while (num_active_workers_ > 0) {
        RecordStop(ctx_.get());
        zero_active_workers_cond_var_.wait(l);
        RecordStart(ctx_.get());
      }
      // Initialize all elements and filter out elements with no input.
      InitializeInputs(element_id_counter_);
      for (auto& element : current_elements_) {
        if (element && element->no_input) {
          element.reset();
        }
      }
      while (!future_elements_.empty() && future_elements_.back()->no_input) {
        future_elements_.pop_back();
      }
      wait_for_checkpoint_ = false;
      DCHECK_EQ(num_active_workers_, 0);
      VLOG(4) << "State before save:\n" << DebugString();
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kBlockIndex, block_index_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(prefix(), kCycleIndex, cycle_index_));
      if (end_of_input_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kEndOfInput, ""));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kElementIdCounter,
                                             element_id_counter_));
      TF_RETURN_IF_ERROR(WriteCurrentElements(ctx, writer));
      TF_RETURN_IF_ERROR(WriteFutureElements(ctx, writer));
      // Wake workers back up.
      current_workers_cond_var_.notify_all();
      future_workers_cond_var_.notify_all();
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      {
        mutex_lock l(*mu_);
        DCHECK(!threads_initialized_);
        DCHECK(!initial_elements_created_);
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kBlockIndex, &block_index_));
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kCycleIndex, &cycle_index_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kElementIdCounter,
                                              &element_id_counter_));
        end_of_input_ = reader->Contains(prefix(), kEndOfInput);
      }
      TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
      TF_RETURN_IF_ERROR(ReadFutureElements(ctx, reader));
      mutex_lock l(*mu_);
      initial_elements_created_ = false;
      for (int i = 0; i < current_elements_.size(); ++i) {
        int index = (cycle_index_ + i) % current_elements_.size();
        auto element = current_elements_[index];
        if (element) {
          elements_to_process_.push_back(index);
          element->initialized = true;
          element->cycle_index = index;
          initial_elements_created_ = true;
        }
      }
      for (const auto& element : future_elements_) {
        element->initialized = true;
      }
      last_valid_current_element_ = current_elements_.size() - 1;
      while (last_valid_current_element_ >= 0 &&
             !current_elements_[last_valid_current_element_]) {
        last_valid_current_element_--;
      }
      VLOG(2) << "Parallel interleave iterator restored";
      VLOG(4) << "State after restore:\n" << DebugString();
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      int64 parallelism = -1;
      // NOTE: We only set the parallelism value if the lock can be acquired
      // right away to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        parallelism = num_parallel_calls_->value;
        mu_->unlock();
      }
      auto result = dataset()->traceme_metadata_;
      result.push_back(std::make_pair(
          "parallelism",
          strings::Printf("%lld", static_cast<long long>(parallelism))));
      return result;
    }

   private:
    // Represents the result of fetching an element from a dataset.
    struct Result {
      Status status;
      int64 id = -1;
      std::vector<Tensor> return_values;
    };

    // The interleave transformation repeatedly inputs elements, applies the
    // user-provided function to transform the input elements to datasets, and
    // interleaves the elements of these datasets as its output.
    //
    // This structure represents an input element and derived state.
    struct Element {
      // Unique identifier, needed to support checkpointing.
      int64 id TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
      // The actual input element. A null value indicates that the element
      // either reached end of input or hasn't been initialized yet.
      std::unique_ptr<std::vector<Tensor>> inputs
          TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
      // Iterator created from the input element. A null value indicates that
      // the element either reached end of input or hasn't been initialized yet.
      std::unique_ptr<IteratorBase> iterator
          TF_GUARDED_BY(&ParallelInterleaveIterator::mu_);
      // Buffer for storing the outputs of `iterator`.
      std::deque<std::shared_ptr<Result>> TF_GUARDED_BY(
          &ParallelInterleaveIterator::mu_) results;
      // The element's index in the cycle, if it is in the current cycle.
      // -1 if the element is not in the current cycle.
      int64 cycle_index TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = -1;
      // Whether the element is currently being processed by a worker thread.
      // This is used to ensure that only one thread at a time tries to process
      // an element.
      bool active TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
      // Whether the inputs and iterator have been initialized.
      bool initialized TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
      // Whether we tried to initialize the element, but the input iterator
      // was exhausted so we could produce no inputs.
      bool no_input TF_GUARDED_BY(&ParallelInterleaveIterator::mu_) = false;
      // Condition variable for communicating between current worker threads
      // and GetNext.
      condition_variable cond_var;

      std::string DebugString()
          TF_EXCLUSIVE_LOCKS_REQUIRED(&ParallelInterleaveIterator::mu_) {
        return absl::StrFormat(
            "Element(id: %d, iterator_null: %d, results_size: %d, "
            "cycle_index: %d, active: %d, initialized: %d, no_input: %d)",
            id, iterator == nullptr, results.size(), cycle_index, active,
            initialized, no_input);
      }
    };

    // Sets the cancellation bit and wakes up all threads that need to be
    // cancelled. Optionally, the method waits until all threads finish
    // executing.
    void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
      mutex_lock l(*mu_);
      cancelled_ = true;
      // Wake up all threads so that they can exit. This will also wake up any
      // threads waiting in GetNextInternal.
      for (const auto& element : current_elements_) {
        if (element) {
          element->cond_var.notify_all();
        }
      }
      current_workers_cond_var_.notify_all();
      future_workers_cond_var_.notify_all();
      num_parallel_calls_cond_var_->notify_all();
      stats_thread_cond_var_.notify_all();
      while (wait && outstanding_threads_ > 0) {
        outstanding_threads_finished_cond_var_.wait(l);
      }
      any_element_available_cond_var_.notify_all();
      zero_active_workers_cond_var_.notify_all();
    }

    void EnsureInitialElementsCreated() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!initial_elements_created_) {
        for (int i = 0; i < dataset()->cycle_length_; ++i) {
          current_elements_[i] = MakeElement();
          if (!current_elements_[i]) {
            break;
          }
          current_elements_[i]->cycle_index = i;
          elements_to_process_.push_back(i);
          last_valid_current_element_ = i;
        }
        initial_elements_created_ = true;
      }
    }

    void EnsureThreadsStarted() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!threads_initialized_) {
        IncrementOutstandingThreads();
        thread_pool_->Schedule([this]() { WorkerManagerThread(); });
        if (ctx_->stats_aggregator()) {
          IncrementOutstandingThreads();
          thread_pool_->Schedule([this]() { StatsThread(); });
        }
        threads_initialized_ = true;
      }
    }

    // Advances the position in the interleave cycle to the next cycle
    // element.
    void AdvanceToNextInCycle() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      DCHECK_NE(last_valid_current_element_, -1);
      block_index_ = 0;
      cycle_index_ = (cycle_index_ + 1) % (last_valid_current_element_ + 1);
    }

    // Advances the position in the interleave cycle by one.
    void AdvancePosition() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      ++block_index_;
      if (block_index_ == dataset()->block_length_) {
        AdvanceToNextInCycle();
      }
    }

    // Consumes a result (if available), returning an indication of whether
    // a result is available. If `true` is returned, `result` either
    // points to a valid result or is null if end of input has been reached.
    bool Consume(std::shared_ptr<Result>* result)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (deterministic_) {
        return ConsumeHelper(result);
      }
      // If we are allowed to be nondeterministic (i.e. return results out of
      // order), try to find an element in the cycle that has a result
      // available.
      for (int i = 0; i < dataset()->cycle_length_; ++i) {
        if (ConsumeHelper(result)) {
          return true;
        }
        AdvanceToNextInCycle();
      }
      return false;
    }

    // Consumes a result (if available), returning an indication of whether
    // a result is available. If `true` is returned, `result` either
    // points to a valid result or is null if end of input has been reached.
    bool ConsumeHelper(std::shared_ptr<Result>* result)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      while (true) {
        if (last_valid_current_element_ == -1) {
          // Reached end of input.
          return true;
        }
        for (int64 i = 0; i < (last_valid_current_element_ + 1); ++i) {
          int64 index = (cycle_index_ + i) % (last_valid_current_element_ + 1);
          if (current_elements_[index]) {
            cycle_index_ = index;
            if (i > 0) {
              block_index_ = 0;
            }
            break;
          }
        }
        DCHECK(current_elements_[cycle_index_]);
        std::shared_ptr<Element> element = current_elements_[cycle_index_];
        if (!element->results.empty()) {
          // We found a result.
          std::swap(*result, element->results.front());
          element->results.pop_front();
          if (!element->active) {
            elements_to_process_.push_back(cycle_index_);
            current_workers_cond_var_.notify_one();
          }
          AdvancePosition();
          return true;
        }
        if (!element->initialized || element->iterator) {
          // The element is still producing results, so we wait.
          return false;
        }
        // We've consumed all results from the element. Get a new element from
        // future_elements, or create a new element if no future elements are
        // available.
        if (!future_elements_.empty()) {
          std::shared_ptr<Element> future_element =
              std::move(future_elements_.front());
          future_elements_.pop_front();
          if (future_element->iterator) {
            EnableAutotune(ctx_.get(), future_element->iterator.get());
          }
          future_element->cycle_index = cycle_index_;
          current_elements_[cycle_index_] = std::move(future_element);
          future_workers_cond_var_.notify_one();
          if (!current_elements_[cycle_index_]->active) {
            current_workers_cond_var_.notify_one();
          }
        } else {
          current_elements_[cycle_index_] = MakeElement();
          if (current_elements_[cycle_index_]) {
            current_elements_[cycle_index_]->cycle_index = cycle_index_;
            elements_to_process_.push_back(cycle_index_);
            element->cycle_index = cycle_index_;
            current_workers_cond_var_.notify_one();
          }
          while (last_valid_current_element_ >= 0 &&
                 !current_elements_[last_valid_current_element_]) {
            last_valid_current_element_--;
            if (cycle_index_ > last_valid_current_element_) {
              // We are about to move the cycle index below in
              // AdvanceToNextInCycle().
              cycle_index_ = last_valid_current_element_;
            }
          }
        }
        if (last_valid_current_element_ != -1) {
          AdvanceToNextInCycle();
        }
      }
    }

    // Creates a new element.
    std::shared_ptr<Element> MakeElement() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (end_of_input_) {
        return nullptr;
      }
      auto element = std::make_shared<Element>();
      element->id = element_id_counter_++;
      uninitialized_elements_.push_back(element);
      return element;
    }

    // Thread responsible for launching all worker threads. The thread stays
    // around after startup in case autotuning increases num_parallel_calls.
    void WorkerManagerThread() TF_LOCKS_EXCLUDED(mu_) {
      RecordStart(ctx_.get());
      auto cleanup = gtl::MakeCleanup([this]() {
        RecordStop(ctx_.get());
        mutex_lock l(*mu_);
        DecrementOutstandingThreads();
      });
      int initial_current_workers;
      // When elements are moved from `future_elements_` to `current_elements_`,
      // the future worker which created the element may continue to process
      // the element for some time. That is why we need an additional
      // `cycle_length_` future workers to guarantee that whenever
      // `future_elements_.size() < prefetch_input_elements_`, there will be a
      // future worker available to create a new future element.
      int future_workers =
          dataset()->prefetch_input_elements_ + dataset()->cycle_length_;
      {
        mutex_lock l(*mu_);
        initial_current_workers = num_parallel_calls_->value;
        outstanding_threads_ += initial_current_workers + future_workers;
        num_current_workers_ += initial_current_workers;
        num_active_workers_ += initial_current_workers + future_workers;
        num_current_active_workers_ += initial_current_workers;
      }
      // Start current workers before future workers to improve startup time.
      for (int i = 0; i < initial_current_workers; ++i) {
        StartCurrentWorkerThread();
      }
      for (int i = 0; i < future_workers; ++i) {
        StartFutureWorkerThread();
      }
      while (true) {
        {
          mutex_lock l(*mu_);
          while (!cancelled_ &&
                 num_current_workers_ >= num_parallel_calls_->value) {
            RecordStop(ctx_.get());
            num_parallel_calls_cond_var_->wait(l);
            RecordStart(ctx_.get());
          }
          if (cancelled_ || end_of_input_) {
            return;
          }
          IncrementOutstandingThreads();
          IncrementCurrentWorkers();
          IncrementActiveWorkers();
          IncrementCurrentActiveWorkers();
          StartCurrentWorkerThread();
        }
      }
    }

    void StartCurrentWorkerThread() {
      thread_pool_->Schedule([this]() { CurrentWorkerThread(); });
    }

    void StartFutureWorkerThread() {
      thread_pool_->Schedule([this]() { FutureWorkerThread(); });
    }

    // Current workers are responsible for keeping elements in
    // `current_elements_` processed. An element is processed if it is either
    // done or its `results` buffer is full (contains `buffer_output_elements_`
    // results).
    //
    // Current workers cycle between two phases: (1) finding an element and (2)
    // processing it. When a worker is processing an element, it will
    // claim the element by setting `element->active`, then continue to produce
    // results for the element until enough results have been computed for the
    // current cycle and the results buffer is full.
    void CurrentWorkerThread() TF_LOCKS_EXCLUDED(mu_) {
      RecordStart(ctx_.get());
      auto done = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        RecordStop(ctx_.get());
        DecrementActiveWorkers();
        DecrementCurrentActiveWorkers();
        DecrementOutstandingThreads();
        DecrementCurrentWorkers();
      };
      while (true) {
        int element_index;
        std::shared_ptr<Element> element;
        // Find an element to process.
        {
          mutex_lock l(*mu_);
          // In case autotune changes num_parallel_calls.
          if (num_current_workers_ > num_parallel_calls_->value) {
            done();
            return;
          }
          // Look for an element that needs processing.
          element.reset();
          while (!cancelled_) {
            while (!elements_to_process_.empty() && !wait_for_checkpoint_) {
              int index = elements_to_process_.front();
              elements_to_process_.pop_front();
              auto& e = current_elements_[index];
              if (NeedsProcessing(e) && !e->active) {
                element_index = index;
                element = e;
                break;
              }
            }
            if (element) {
              break;
            }
            DecrementCurrentActiveWorkers();
            WaitWorkerThread(&current_workers_cond_var_, &l);
            IncrementCurrentActiveWorkers();
          }
          if (cancelled_) {
            done();
            return;
          }
          VLOG(3) << "Current worker woke up to process " << element->id;
          element->active = true;
        }
        // Loop on the element until we fill its results buffer or reach end of
        // input for the element.
        while (true) {
          ProcessElement(element);
          {
            mutex_lock l(*mu_);
            // Check whether we have produced enough results for the current
            // cycle.
            if (!NeedsProcessing(element)) {
              element->active = false;
              break;
            }
          }
        }
      }
    }

    // Future workers process elements after the current interleave cycle. A
    // future worker's job is to keep `future_elements_` filled with elements.
    // Elements in `future_elements_` have had their first
    // `buffer_output_elements_` results computed.
    void FutureWorkerThread() TF_LOCKS_EXCLUDED(mu_) {
      RecordStart(ctx_.get());
      auto done = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        RecordStop(ctx_.get());
        DecrementActiveWorkers();
        DecrementOutstandingThreads();
      };
      std::shared_ptr<Element> element;
      while (true) {
        {
          mutex_lock l(*mu_);
          if (element) {
            element->active = false;
            if (element->cycle_index != -1) {
              element->cond_var.notify_one();
              // A current worker may need to process the element further.
              elements_to_process_.push_back(element->cycle_index);
              current_workers_cond_var_.notify_one();
            }
          }
          while (!cancelled_ && (future_elements_.size() >=
                                     dataset()->prefetch_input_elements_ ||
                                 wait_for_checkpoint_)) {
            WaitWorkerThread(&future_workers_cond_var_, &l);
          }
          if (cancelled_) {
            done();
            return;
          }
          element = MakeElement();
          if (!element) {
            done();
            return;
          }
          VLOG(3) << "Future worker created element " << element->id;
          element->active = true;
          future_elements_.push_back(element);
        }
        ProcessElement(element);
      }
    }

    // Generates results for the given element until the element's results
    // buffer is full or the element is done producing results.
    void ProcessElement(std::shared_ptr<Element> element)
        TF_LOCKS_EXCLUDED(mu_) {
      DCHECK(element != nullptr);
      IteratorBase* iterator;
      int64 input_element_id;
      // Initialize the inputs and iterator if necessary.
      {
        mutex_lock l(*mu_);
        DCHECK(element->active);
        input_element_id = element->id;
        if (!element->iterator) {
          InitializeInputs(input_element_id);
          if (!element->iterator) {
            return;
          }
        }
        // `iterator` will remain valid after releasing the lock because we have
        // marked the element as active, so no other thread will modify its
        // iterator.
        iterator = element->iterator.get();
      }
      DCHECK(iterator != nullptr);
      // Process until the results queue is full or we reach end of input.
      while (true) {
        auto result = std::make_shared<Result>();
        profiler::TraceMe traceme([&] {
          result->id = profiler::TraceMe::NewActivityId();
          return profiler::TraceMeEncode(
              "ParallelInterleaveProduce",
              {{"input_element_id", input_element_id},
               {"element_id", result->id}});
        });
        bool end_of_input = false;
        result->status = iterator->GetNext(ctx_.get(), &result->return_values,
                                           &end_of_input);
        if (end_of_input) {
          mutex_lock l(*mu_);
          element->iterator.reset();
          element->inputs.reset();
          NotifyElementUpdate(element);
          break;
        }
        RecordBufferEnqueue(ctx_.get(), result->return_values);
        mutex_lock l(*mu_);
        element->results.push_back(std::move(result));
        NotifyElementUpdate(element);
        if (element->results.size() == dataset()->buffer_output_elements_) {
          break;
        }
      }
    }

    // Initialize inputs and create an iterator for all elements up to
    // element_id.
    void InitializeInputs(int element_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      while (!uninitialized_elements_.empty() &&
             uninitialized_elements_.front()->id <= element_id) {
        std::shared_ptr<Element> element = uninitialized_elements_.front();
        uninitialized_elements_.pop_front();
        element->initialized = true;
        // Check if we've already reached end of input.
        if (end_of_input_) {
          element->no_input = true;
          NotifyElementUpdate(element);
          continue;
        }
        profiler::TraceMe traceme([input_element_id = element->id] {
          return profiler::TraceMeEncode(
              "ParallelInterleaveInitializeInput",
              {{"input_element_id", input_element_id}});
        });
        std::vector<Tensor> inputs;
        Status status;
        {
          // TODO(aaudibert): Refactor the implementation to move calls of
          // `GetNext` out of the scope of `mu_`.
          status = input_impl_->GetNext(ctx_.get(), &inputs, &end_of_input_);
        }
        if (!status.ok()) {
          AddErrorResult(element, status);
          continue;
        }
        if (end_of_input_) {
          element->no_input = true;
          NotifyElementUpdate(element);
          continue;
        }
        element->inputs =
            absl::make_unique<std::vector<Tensor>>(std::move(inputs));
        status = MakeIteratorFromInputElement(
            ctx_.get(), this, *element->inputs, element->id,
            *instantiated_captured_func_, prefix(), &element->iterator,
            model_node());
        if (!status.ok()) {
          element->inputs.reset();
          element->iterator.reset();
          AddErrorResult(element, status);
          continue;
        }
        if (element->cycle_index == -1) {
          DisableAutotune(ctx_.get(), element->iterator.get());
        }
      }
    }

    // Adds an error result for the given element.
    void AddErrorResult(std::shared_ptr<Element> element, Status status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      auto result = std::make_shared<Result>();
      result->status = status;
      element->results.push_back(std::move(result));
      NotifyElementUpdate(element);
    }

    // Cancels all threads (including the manager) and waits for them to finish.
    void StopAllThreads(mutex_lock* l) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {}

    // Waits on the given cond_var in a worker thread.
    void WaitWorkerThread(condition_variable* cond_var, mutex_lock* l)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      DecrementActiveWorkers();
      RecordStop(ctx_.get());
      cond_var->wait(*l);
      RecordStart(ctx_.get());
      IncrementActiveWorkers();
    }

    void NotifyElementUpdate(std::shared_ptr<Element> element)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (deterministic_) {
        element->cond_var.notify_one();
      } else {
        any_element_available_cond_var_.notify_one();
      }
    }

    bool NeedsProcessing(const std::shared_ptr<Element>& element)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      if (!element) {
        return false;
      }
      if (!element->initialized) {
        return true;
      }
      return element->iterator &&
             element->results.size() < dataset()->buffer_output_elements_;
    }

    inline void IncrementCurrentWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      num_current_workers_++;
    }

    inline void DecrementCurrentWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      num_current_workers_--;
    }

    inline void IncrementActiveWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      num_active_workers_++;
    }

    inline void DecrementActiveWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      num_active_workers_--;
      if (num_active_workers_ == 0) {
        zero_active_workers_cond_var_.notify_one();
      }
    }

    inline void IncrementCurrentActiveWorkers()
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      num_current_active_workers_++;
    }

    inline void DecrementCurrentActiveWorkers()
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      num_current_active_workers_--;
    }

    inline void IncrementOutstandingThreads() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      outstanding_threads_++;
    }

    inline void DecrementOutstandingThreads() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      outstanding_threads_--;
      if (outstanding_threads_ == 0) {
        outstanding_threads_finished_cond_var_.notify_one();
      }
    }

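    // Periodically (every `kStatsReportingPeriodMillis`) reports the ratio of
    // active current workers to scheduled current workers to the stats
    // aggregator, until the iterator is cancelled.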
    void StatsThread() {
      for (int64 step = 0;; ++step) {
        int num_current_active_workers;
        int num_current_workers;
        {
          mutex_lock l(*mu_);
          if (step != 0 && !cancelled_) {
            stats_thread_cond_var_.wait_for(
                l, std::chrono::milliseconds(kStatsReportingPeriodMillis));
          }
          if (cancelled_) {
            DecrementOutstandingThreads();
            return;
          }
          num_current_active_workers = num_current_active_workers_;
          num_current_workers = num_current_workers_;
        }
        if (num_current_workers == 0) {
          // Avoid division by zero.
          num_current_workers = 1;
        }
        ctx_->stats_aggregator()->AddScalar(
            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
            static_cast<float>(num_current_active_workers) /
                static_cast<float>(num_current_workers),
            step);
      }
    }

    Status WriteStatusLocked(IteratorStateWriter* writer,
                             const string& iterator_name, size_t idx,
                             const Status& status)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, CodeKey(idx), static_cast<int64>(status.code())));
      if (!status.ok()) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            iterator_name, ErrorMessageKey(idx), status.error_message()));
      }
      return Status::OK();
    }

    Status ReadStatusLocked(IteratorStateReader* reader,
                            const string& iterator_name, size_t idx,
                            Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      int64 code_int;
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(iterator_name, CodeKey(idx), &code_int));
      error::Code code = static_cast<error::Code>(code_int);

      if (code != error::Code::OK) {
        tstring error_message;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            iterator_name, ErrorMessageKey(idx), &error_message));
        *status = Status(code, error_message);
      } else {
        *status = Status::OK();
      }
      return Status::OK();
    }

    string CodeKey(size_t idx) {
      return absl::StrCat(kResultsSuffix, "[", idx, "]", kCodeSuffix);
    }

    string ErrorMessageKey(size_t idx) {
      return absl::StrCat(kResultsSuffix, "[", idx, "]", kErrorMessageSuffix);
    }

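    // Serializes a single element under the key prefix
    // `prefix()::key_prefix::idx`: its buffered results (status, return
    // values, and a readiness marker) and, if the element still has an open
    // iterator, its input tensors and the iterator state.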
    Status WriteElement(SerializationContext* ctx,
                        std::shared_ptr<Element> element, int idx,
                        const string& key_prefix, IteratorStateWriter* writer)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      const auto& iterator_name =
          absl::StrCat(prefix(), "::", key_prefix, "::", idx);
      if (element->iterator) {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, element->iterator));
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(iterator_name, kIdSuffix, element->id));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            iterator_name, absl::StrCat(kInputsSuffix, kSizeSuffix),
            element->inputs->size()));
        for (int i = 0; i < element->inputs->size(); i++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              iterator_name, absl::StrCat(kInputsSuffix, "[", i, "]"),
              element->inputs->at(i)));
        }
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          iterator_name, absl::StrCat(kResultsSuffix, kSizeSuffix),
          element->results.size()));
      for (size_t i = 0; i < element->results.size(); i++) {
        std::shared_ptr<Result> result = element->results[i];
        TF_RETURN_IF_ERROR(
            WriteStatusLocked(writer, iterator_name, i, result->status));
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            iterator_name,
            absl::StrCat(kResultsSuffix, "[", i, "]", kSizeSuffix),
            result->return_values.size()));
        for (size_t j = 0; j < result->return_values.size(); j++) {
          TF_RETURN_IF_ERROR(writer->WriteTensor(
              iterator_name, absl::StrCat(kResultsSuffix, "[", i, "][", j, "]"),
              result->return_values[j]));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            iterator_name,
            absl::StrCat(kResultsSuffix, "[", i, "]", kIsReadySuffix), ""));
      }
      return Status::OK();
    }

    Status WriteCurrentElements(SerializationContext* ctx,
                                IteratorStateWriter* writer)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentElementsSize,
                                             current_elements_.size()));
      for (int idx = 0; idx < current_elements_.size(); idx++) {
        if (current_elements_[idx]) {
          TF_RETURN_IF_ERROR(WriteElement(ctx, current_elements_[idx], idx,
                                          kCurrentElements, writer));
        }
      }
      return Status::OK();
    }

    Status WriteFutureElements(SerializationContext* ctx,
                               IteratorStateWriter* writer)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kFutureElementsSize,
                                             future_elements_.size()));
      for (int idx = 0; idx < future_elements_.size(); idx++) {
        if (future_elements_[idx]) {
          TF_RETURN_IF_ERROR(WriteElement(ctx, future_elements_[idx], idx,
                                          kFutureElements, writer));
        }
      }
      return Status::OK();
    }

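    // Restores a single element written by `WriteElement`. Leaves `*out`
    // unset if the checkpoint contains no entry for this index, and recreates
    // the element's iterator from its saved inputs when one was serialized.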
    Status ReadElement(IteratorContext* ctx, IteratorStateReader* reader,
                       int idx, const string& key_prefix,
                       std::shared_ptr<Element>* out) {
      std::unique_ptr<IteratorBase> iterator;
      auto element = std::make_shared<Element>();
      {
        mutex_lock l(*mu_);
        const auto& iterator_name =
            absl::StrCat(prefix(), "::", key_prefix, "::", idx);
        if (!reader->Contains(iterator_name,
                              absl::StrCat(kResultsSuffix, kSizeSuffix))) {
          return Status::OK();
        }
        int64 results_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            iterator_name, absl::StrCat(kResultsSuffix, kSizeSuffix),
            &results_size));
        element->results.resize(results_size);
        for (size_t i = 0; i < results_size; i++) {
          auto result = std::make_shared<Result>();
          TF_RETURN_IF_ERROR(
              ReadStatusLocked(reader, iterator_name, i, &result->status));
          int64 num_return_values;
          TF_RETURN_IF_ERROR(reader->ReadScalar(
              iterator_name,
              absl::StrCat(kResultsSuffix, "[", i, "]", kSizeSuffix),
              &num_return_values));
          result->return_values.reserve(num_return_values);
          for (size_t j = 0; j < num_return_values; j++) {
            result->return_values.emplace_back();
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                iterator_name,
                absl::StrCat(kResultsSuffix, "[", i, "][", j, "]"),
                &result->return_values.back()));
          }
          RecordBufferEnqueue(ctx, result->return_values);
          element->results[i] = std::move(result);
        }
        if (!reader->Contains(iterator_name,
                              absl::StrCat(kInputsSuffix, kSizeSuffix))) {
          element->iterator.reset();
          *out = std::move(element);
          return Status::OK();
        }
        int64 inputs_size;
        TF_RETURN_IF_ERROR(reader->ReadScalar(
            iterator_name, absl::StrCat(kInputsSuffix, kSizeSuffix),
            &inputs_size));
        element->inputs = std::make_unique<std::vector<Tensor>>(inputs_size);
        for (int i = 0; i < inputs_size; i++) {
          TF_RETURN_IF_ERROR(reader->ReadTensor(
              iterator_name, absl::StrCat(kInputsSuffix, "[", i, "]"),
              &element->inputs->at(i)));
        }
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(iterator_name, kIdSuffix, &element->id));
        TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
            ctx, this, *element->inputs, element->id,
            *instantiated_captured_func_.get(), prefix(), &iterator,
            model_node()));
      }
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
      mutex_lock l(*mu_);
      element->iterator = std::move(iterator);
      *out = std::move(element);
      return Status::OK();
    }

    Status ReadCurrentElements(IteratorContext* ctx,
                               IteratorStateReader* reader) {
      int64 size;
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kCurrentElementsSize, &size));
        if (current_elements_.size() != size) {
          // This could mean two things: (1) the user created their checkpoint
          // from a dataset with one cycle_length, then changed the cycle_length
          // and tried to restore from the old checkpoint, or (2) the user set
          // the cycle length to tf.data.AUTOTUNE, wrote the checkpoint from one
          // machine, then tried to restore the checkpoint on another machine
          // with a different CPU budget (causing autotune to pick a different
          // cycle length).
          return errors::FailedPrecondition(
              "The iterator cycle length ", current_elements_.size(),
              " is different from the cycle length to restore from the "
              "checkpoint: ",
              size);
        }
      }
      if (size == 0) {
        return Status::OK();
      }
      std::vector<std::shared_ptr<Element>> elements;
      TF_RETURN_IF_ERROR(
          ReadElementsParallel(ctx, reader, size, kCurrentElements, &elements));
      mutex_lock l(*mu_);
      for (auto& element : current_elements_) {
        DCHECK(element == nullptr);
      }
      for (int idx = 0; idx < size; ++idx) {
        current_elements_[idx] = std::move(elements[idx]);
      }
      return Status::OK();
    }

    Status ReadFutureElements(IteratorContext* ctx,
                              IteratorStateReader* reader) {
      int64 size;
      {
        mutex_lock l(*mu_);
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(prefix(), kFutureElementsSize, &size));
        future_elements_.resize(size);
      }
      if (size == 0) {
        return Status::OK();
      }
      std::vector<std::shared_ptr<Element>> elements;
      TF_RETURN_IF_ERROR(
          ReadElementsParallel(ctx, reader, size, kFutureElements, &elements));
      mutex_lock l(*mu_);
      for (auto& element : future_elements_) {
        DCHECK(element == nullptr);
      }
      for (int idx = 0; idx < size; ++idx) {
        future_elements_[idx] = std::move(elements[idx]);
      }
      return Status::OK();
    }

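    // Restores `size` elements of the given name in parallel, scheduling one
    // `ReadElement` call per index on the thread pool and collecting the first
    // error (or a Cancelled status) that occurs.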
    Status ReadElementsParallel(
        IteratorContext* ctx, IteratorStateReader* reader, int64 size,
        const string& name, std::vector<std::shared_ptr<Element>>* elements) {
      elements->resize(size);
      Status s = Status::OK();
      BlockingCounter counter(size);
      for (int idx = 0; idx < size; ++idx) {
        thread_pool_->Schedule(
            [this, ctx, reader, idx, name, &s, &counter, elements] {
              RecordStart(ctx);
              auto cleanup = gtl::MakeCleanup([this, ctx, &counter]() {
                RecordStop(ctx);
                counter.DecrementCount();
              });
              std::shared_ptr<Element> elem;
              Status ret_status = ReadElement(ctx, reader, idx, name, &elem);
              mutex_lock l(*mu_);
              if (cancelled_) {
                s.Update(
                    errors::Cancelled("Cancelled in ReadElementsParallel"));
                return;
              }
              if (!ret_status.ok()) {
                s.Update(ret_status);
                return;
              }
              (*elements)[idx] = elem;
            });
      }
      counter.Wait();
      return s;
    }

    std::string DebugString() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
      std::string result;
      result.append(strings::StrCat("Cycle index: ", cycle_index_, "\n"));
      result.append(strings::StrCat("Block index: ", block_index_, "\n"));
      result.append(strings::StrCat("End of input: ", end_of_input_, "\n"));
      {
        result.append("Current elements:\n");
        for (int i = 0; i < current_elements_.size(); ++i) {
          string element_string = "null";
          if (current_elements_[i]) {
            element_string = current_elements_[i]->DebugString();
          }
          result.append(absl::StrFormat("%d: %s\n", i, element_string));
        }
      }
      {
        result.append("Future elements:\n");
        for (int i = 0; i < future_elements_.size(); ++i) {
          string element_string = "null";
          if (future_elements_[i]) {
            element_string = future_elements_[i]->DebugString();
          }
          result.append(absl::StrFormat("%d: %s\n", i, element_string));
        }
      }
      return result;
    }
1437
1438 // Indices of `current_elements_` which need to be processed by a current
1439 // worker.
1440 std::deque<int> elements_to_process_;
1441
1442 // The last index in `current_elements_` containing a non-null element.
1443 // This allows us to optimize the situation when the cycle_length is large
1444 // but the input dataset doesn't have many elements. By tracking the index
1445 // of the last valid element, GetNext can avoid checking many null entries
1446 // each time through the cycle.
1447 //
1448 // TODO(aaudibert): Generalize this optimization by removing null elements
1449 // from `current_elements_`, e.g. by compacting the vector when x% of
1450 // its elements are null.
1451 int64 last_valid_current_element_ TF_GUARDED_BY(mu_) = -1;
1452
1453 // Identifies whether the current_elements_ vector has been initialized.
1454 bool initial_elements_created_ TF_GUARDED_BY(mu_) = false;
1455
1456 // Identifies whether the element threads have been initialized.
1457 bool threads_initialized_ TF_GUARDED_BY(mu_) = false;

    // Used for coordination between the main thread, the manager threads, and
    // the worker threads.
    //
    // NOTE: We should never call GetNext on the input while holding this
    // mutex.
    const std::shared_ptr<mutex> mu_;

    // Condition variable for waking up current workers.
    condition_variable current_workers_cond_var_;

    // Condition variable for waking up future workers.
    condition_variable future_workers_cond_var_;

    // Condition variable for waking up the stats thread.
    condition_variable stats_thread_cond_var_;

    // Number of active worker threads which might be processing elements,
    // including both current workers and future workers. Used by
    // checkpointing to wait for outstanding work to finish.
    int num_active_workers_ TF_GUARDED_BY(mu_) = 0;

    // Number of active current worker threads.
    int num_current_active_workers_ TF_GUARDED_BY(mu_) = 0;

    // Condition variable notified whenever the total number of active workers
    // drops to zero. Used for checkpointing.
    condition_variable zero_active_workers_cond_var_;

    // Condition variable notified whenever `num_parallel_calls_` changes.
    // Shared so that autotuning can notify us of changes to the value.
    std::shared_ptr<condition_variable> num_parallel_calls_cond_var_;

    // Identifies the maximum number of parallel calls.
    const std::shared_ptr<model::SharedState> num_parallel_calls_;

    // The number of current workers that are alive or scheduled to be
    // started. This includes current workers which are blocked waiting for
    // work.
    int num_current_workers_ TF_GUARDED_BY(mu_) = 0;

    // Condition variable to signal that a result has been produced by some
    // element thread. Only used when `deterministic` is false.
    condition_variable any_element_available_cond_var_;

    // Determines whether outputs can be produced in deterministic order.
    const bool deterministic_;

    // Iterator for input elements.
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

    // Identifies position in the interleave cycle.
    int64 block_index_ TF_GUARDED_BY(mu_) = 0;
    // It is an invariant that either `last_valid_current_element_ == -1` or
    // `cycle_index_ <= last_valid_current_element_`.
    int64 cycle_index_ TF_GUARDED_BY(mu_) = 0;

    // Elements of the current interleave cycle.
    std::vector<std::shared_ptr<Element>> current_elements_ TF_GUARDED_BY(mu_);

    // Elements which still need their inputs and iterators to be initialized.
    // Elements at the front need to be initialized first.
    std::deque<std::shared_ptr<Element>> uninitialized_elements_
        TF_GUARDED_BY(mu_);

    // Elements to be used in the interleave cycle in the future. The element
    // at the front is the next element to add to the interleave cycle when a
    // current element is exhausted.
    std::deque<std::shared_ptr<Element>> future_elements_ TF_GUARDED_BY(mu_);

    // Identifies whether the global end of input has been reached.
    bool end_of_input_ TF_GUARDED_BY(mu_) = false;

    // The number of outstanding element threads.
    int outstanding_threads_ TF_GUARDED_BY(mu_) = 0;

    // Condition variable notified when outstanding_threads_ drops to 0.
    condition_variable outstanding_threads_finished_cond_var_;

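    // Thread pool used to parallelize background work, e.g. restoring
    // checkpointed elements in `ReadElementsParallel()` above.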
    std::unique_ptr<thread::ThreadPool> thread_pool_;

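    // Monotonically increasing counter used to assign each element a unique
    // id, e.g. for tracing element processing.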
    int64 element_id_counter_ TF_GUARDED_BY(mu_) = 0;

    // Iterator context used in worker threads.
    std::unique_ptr<IteratorContext> ctx_;

    // Set to true during checkpointing to alert element threads that they
    // should pause operation. This is needed to prevent constantly-active
    // worker threads from blocking checkpointing indefinitely.
    bool wait_for_checkpoint_ = false;

    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

    // Identifies whether background threads should be cancelled.
    bool cancelled_ TF_GUARDED_BY(mu_) = false;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;
  };

  const DatasetBase* const input_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const int64 cycle_length_;
  const int64 block_length_;
  const int64 buffer_output_elements_;
  const int64 prefetch_input_elements_;
  const int64 num_parallel_calls_;
  const DeterminismPolicy deterministic_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const int op_version_;
  const TraceMeMetadata traceme_metadata_;
};

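// Parses the op-version-dependent attrs: V2 ops carry a boolean `sloppy`
// attr, which is translated into a DeterminismPolicy here, while V3 and later
// ops carry a string `deterministic` attr that is parsed directly.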
ParallelInterleaveDatasetOp::ParallelInterleaveDatasetOp(
    OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(OpVersionFromOpName(ctx->def().op())) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  if (op_version_ == 2) {
    bool sloppy;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kSloppy, &sloppy));
    if (sloppy) {
      deterministic_ =
          DeterminismPolicy(DeterminismPolicy::Type::kNondeterministic);
    } else {
      deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault);
    }
  }
  if (op_version_ >= 3) {
    std::string deterministic;
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic));
    OP_REQUIRES_OK(
        ctx, DeterminismPolicy::FromString(deterministic, &deterministic_));
  }
}

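// Validates the runtime inputs (block length, buffering parameters, and
// parallelism), resolves an autotuned `cycle_length` to a concrete value, and
// constructs the Dataset.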
void ParallelInterleaveDatasetOp::MakeDataset(OpKernelContext* ctx,
                                              DatasetBase* input,
                                              DatasetBase** output) {
  int64 block_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBlockLength, &block_length));
  OP_REQUIRES(ctx, block_length > 0,
              errors::InvalidArgument("`block_length` must be > 0"));

  int64 buffer_output_elements = model::kAutotune;
  int64 prefetch_input_elements = model::kAutotune;
  if (op_version_ >= 4) {
    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBufferOutputElements,
                                            &buffer_output_elements));
    OP_REQUIRES(ctx,
                buffer_output_elements == model::kAutotune ||
                    buffer_output_elements > 0,
                errors::InvalidArgument("`buffer_output_elements` must be ",
                                        model::kAutotune, " or > 0 but is ",
                                        buffer_output_elements));

    OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kPrefetchInputElements,
                                            &prefetch_input_elements));
    OP_REQUIRES(ctx,
                prefetch_input_elements == model::kAutotune ||
                    prefetch_input_elements >= 0,
                errors::InvalidArgument("`prefetch_input_elements` must be ",
                                        model::kAutotune, " or >= 0 but is ",
                                        prefetch_input_elements));
  }

  int64 num_parallel_calls = 0;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
  OP_REQUIRES(
      ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
      errors::InvalidArgument("`num_parallel_calls` must be ",
                              model::kAutotune, " or > 0 but is ",
                              num_parallel_calls));
  int64 cycle_length = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kCycleLength, &cycle_length));
  if (cycle_length == model::kAutotune) {
    if (num_parallel_calls != model::kAutotune) {
      cycle_length = std::min(num_parallel_calls,
                              static_cast<int64>(port::MaxParallelism()));
    } else {
      // If parallelism is to be autotuned, we set the cycle length so that
      // the number of threads created for the current and future cycle
      // elements roughly matches the number of schedulable cores.
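      // For illustration, assuming `kDefaultCyclePrefetchFactor` is 2: each
      // cycle element is budgeted one current worker thread plus
      // `kDefaultCyclePrefetchFactor` future worker threads, so with
      // port::MaxParallelism() == 12 this yields
      // cycle_length = CeilDiv(12, 2 + 1) = 4.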
      const int num_threads_per_cycle_length = kDefaultCyclePrefetchFactor + 1;
      cycle_length =
          CeilDiv(port::MaxParallelism(), num_threads_per_cycle_length);
    }
  }
  OP_REQUIRES(ctx, cycle_length > 0,
              errors::InvalidArgument("`cycle_length` must be > 0"));

  OP_REQUIRES(
      ctx, num_parallel_calls <= cycle_length,
      errors::InvalidArgument(
          "num_parallel_calls must be less than or equal to cycle_length."));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  if (num_parallel_calls == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(
      ctx, input, std::move(captured_func), cycle_length, block_length,
      buffer_output_elements, prefetch_input_elements, num_parallel_calls,
      deterministic_, output_types_, output_shapes_, op_version_);
}

namespace {
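// Kernel registrations for the supported op versions: V2 (boolean `sloppy`
// attr), V3 (string `deterministic` attr), and V4 (additionally takes
// `buffer_output_elements` and `prefetch_input_elements` inputs).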
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV2).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV3).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_KERNEL_BUILDER(Name(kParallelInterleaveDatasetV4).Device(DEVICE_CPU),
                        ParallelInterleaveDatasetOp);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV2);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV3);
REGISTER_INPUT_COLOCATION_EXEMPTION(kParallelInterleaveDatasetV4);
}  // namespace
}  // namespace data
}  // namespace tensorflow
