/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.h"

#include <atomic>
#include <utility>

#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/stats_utils.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

namespace tensorflow {
namespace data {
namespace experimental {

/* static */ constexpr const char* const MapAndBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kOtherArguments;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const
    MapAndBatchDatasetOp::kNumParallelCalls;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kFunc;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kTarguments;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const MapAndBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
    MapAndBatchDatasetOp::kPreserveCardinality;

namespace {

// Maximum number of batch results to buffer.
constexpr int64_t kMaxBatchResults = 16;
constexpr char kParallelism[] = "parallelism";
constexpr char kCallCounter[] = "call_counter";
constexpr char kBatchResultsSize[] = "batch_results_size";
constexpr char kTFDataMapAndBatch[] = "tf_data_map_and_batch";
constexpr char kBatchResults[] = "batch_results";
constexpr char kEndOfInput[] = "end_of_input";
constexpr char kNumCalls[] = "num_calls";
constexpr char kNumElements[] = "num_elements";
constexpr char kOutputAllocated[] = "output_allocated";
constexpr char kStatus[] = "status";

// Computes ceil(x / y).
inline int64_t CeilDiv(int64_t x, int64_t y) { return (x + y - 1) / y; }

}  // namespace

class MapAndBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, const DatasetBase* input, int64_t batch_size,
          int64_t num_parallel_calls, bool drop_remainder,
          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes,
          std::unique_ptr<CapturedFunction> captured_func,
          bool preserve_cardinality)
      : DatasetBase(DatasetContext(ctx)),
        input_(input),
        batch_size_(batch_size),
        num_parallel_calls_(num_parallel_calls),
        drop_remainder_(drop_remainder),
        output_types_(output_types),
        output_shapes_(output_shapes),
        captured_func_(std::move(captured_func)),
        preserve_cardinality_(preserve_cardinality),
        traceme_metadata_(
            {{"autotune",
              num_parallel_calls == model::kAutotune ? "true" : "false"},
             {"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"}}) {
    input_->Ref();
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
  }

  const DataTypeVector& output_dtypes() const override { return output_types_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    return name_utils::DatasetDebugString(kDatasetType);
  }

  int64_t CardinalityInternal() const override {
    if (!preserve_cardinality_) {
      return kUnknownCardinality;
    }
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
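    // For example, with n = 10 and batch_size_ = 3 this returns 4 (three
    // full batches plus a partial one), or 3 when drop_remainder_ is true.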
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return OkStatus();
  }

  Status CheckExternalState() const override {
    TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size_node;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size_node));
    Node* num_parallel_calls_node;
    TF_RETURN_IF_ERROR(
        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
    Node* drop_remainder_node;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
    std::vector<Node*> other_arguments;
    DataTypeVector other_arguments_types;
    TF_RETURN_IF_ERROR(captured_func_->AddToGraph(ctx, b, &other_arguments,
                                                  &other_arguments_types));
    AttrValue f;
    b->BuildAttrValue(captured_func_->func(), &f);
    AttrValue other_arguments_types_attr;
    b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
    AttrValue preserve_cardinality_attr;
    b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this,
        {std::make_pair(0, input_graph_node),
         std::make_pair(2, batch_size_node),
         std::make_pair(3, num_parallel_calls_node),
         std::make_pair(4, drop_remainder_node)},  // Single tensor inputs.
        {std::make_pair(1, other_arguments)},      // Tensor list inputs.
        {std::make_pair(kFunc, f),
         std::make_pair(kTarguments, other_arguments_types_attr),
         std::make_pair(kPreserveCardinality,
                        preserve_cardinality_attr)},  // Attrs
        output));
    return OkStatus();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          mu_(std::make_shared<mutex>()),
          cond_var_(std::make_shared<condition_variable>()),
          num_parallel_calls_(std::make_shared<model::SharedState>(
              params.dataset->num_parallel_calls_, mu_, cond_var_)) {
      // To mitigate the effect of stragglers (i.e. map invocations that take
      // much longer than others), we allow the kernel to pre-compute batches
      // ahead of time and store them in an internal buffer. The maximum number
      // of batches to buffer is a trade-off between performance and memory and
      // we derive it from the degree of parallelism and the batch size.
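      //
      // For example, with num_parallel_calls = 32 and batch_size = 4, the
      // iterator buffers at most min(kMaxBatchResults, CeilDiv(32, 4)) = 8
      // batches.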
      //
      // TODO(b/178059273): If we handle RAM budget correctly, the upper bound
      // should be removed.
      max_batch_results_ = std::min(
          kMaxBatchResults,
          CeilDiv(params.dataset->num_parallel_calls_ == model::kAutotune
                      ? GetCpuBudget()  // maximum parallelism
                      : params.dataset->num_parallel_calls_,
                  params.dataset->batch_size_));
    }

    ~Iterator() override {
      CancelThreads(/*wait=*/true);
      if (deregister_fn_) deregister_fn_();
    }

    Status Initialize(IteratorContext* ctx) override {
      mutex_lock l(*mu_);
      interleave_depth_ = ctx->interleave_depth();

      if (num_parallel_calls_->value == model::kAutotune) {
        num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx);
      }
      cancellation_manager_ = std::make_unique<CancellationManager>();
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
          ctx->cancellation_manager(),
          [this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
      IteratorContext::Params params(ctx);
      params.cancellation_manager = cancellation_manager_.get();
      TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
          IteratorContext(params), this, prefix(), &input_impl_));
      return dataset()->captured_func_->Instantiate(
          ctx, &instantiated_captured_func_);
    }

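    // Blocks until the oldest buffered batch has no outstanding calls, then
    // removes it from `batch_results_` and hands it to `ProcessBatch` to
    // produce the output batch.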
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      std::shared_ptr<BatchResult> result;
      {
        mutex_lock l(*mu_);
        EnsureRunnerThreadStarted(ctx);
        while (!cancelled_ && (batch_results_.empty() ||
                               batch_results_.front()->num_calls > 0)) {
          ++waiting_;
          RecordStop(ctx);
          cond_var_->wait(l);
          RecordStart(ctx);
          --waiting_;
        }
        if (cancelled_) {
          return errors::Cancelled("Iterator was cancelled");
        }
        std::swap(result, batch_results_.front());
        batch_results_.pop_front();
        cond_var_->notify_all();
      }
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("MapAndBatchConsume",
                                       {{"element_id", result->uid}});
      });
      // Deallocate tensors allocated for the output.
      auto cleanup = gtl::MakeCleanup([result] { result->output.clear(); });
      mutex_lock l(result->mu);
      if (result->output_allocated) {
        RecordBufferDequeue(ctx, result->output);
      }
      TF_RETURN_IF_ERROR(
          ProcessBatch(dataset()->batch_size_, result->num_elements,
                       dataset()->drop_remainder_, result->status, ctx,
                       out_tensors, end_of_sequence, &result->output));
      return OkStatus();
    }

   protected:
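    // Reports this iterator to the tf.data autotuning model as an
    // asynchronous node that consumes `batch_size_` input elements per
    // output element and whose parallelism is tunable.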
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeAsyncKnownRatioNode(
          std::move(args), dataset()->batch_size_,
          {model::MakeParameter(kParallelism, num_parallel_calls_, /*min=*/1,
                                /*max=*/ctx->runner_threadpool_size())});
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
          dataset()->captured_func_->CheckExternalState()));
      mutex_lock l(*mu_);
      // Wait for all in-flight calls to complete.
      while (num_calls_ > 0) {
        cond_var_->wait(l);
      }
      DCHECK_EQ(num_calls_, 0);
      TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kCallCounter), call_counter_));
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kBatchResultsSize),
                                             batch_results_.size()));
      for (size_t i = 0; i < batch_results_.size(); ++i) {
        TF_RETURN_IF_ERROR(WriteBatchResult(writer, i));
      }
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(*mu_);
      TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kCallCounter), &call_counter_));
      int64_t batch_results_size;
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kBatchResultsSize),
                                            &batch_results_size));
      DCHECK(batch_results_.empty());
      for (int i = 0; i < batch_results_size; ++i) {
        TF_RETURN_IF_ERROR(ReadBatchResult(ctx, reader, i));
      }
      return OkStatus();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      int64_t parallelism = -1;
      int64_t max_batch_results = -1;
      // NOTE: We only set the parallelism value if the lock can be acquired
      // right away to avoid introducing tracing overhead.
      if (mu_->try_lock()) {
        parallelism = num_parallel_calls_->value;
        max_batch_results = max_batch_results_;
        mu_->unlock();
      }
      auto result = dataset()->traceme_metadata_;
      result.push_back(std::make_pair(
          "max_batch_results",
          strings::Printf("%lld", static_cast<long long>(max_batch_results))));
      result.push_back(std::make_pair(
          "parallelism",
          parallelism == -1
              ? kTraceInfoUnavailable
              : strings::Printf("%lld", static_cast<long long>(parallelism))));
      result.push_back(std::make_pair(
          "interleave_depth",
          strings::Printf("%lld", static_cast<long long>(interleave_depth_))));
      return result;
    }

   private:
    // BatchResult encapsulates the output batch, as well as ancillary
    // metadata required to execute the fused map-and-batch operation.
    struct BatchResult {
      explicit BatchResult(int64_t batch_size)
          : end_of_input(false),
            num_elements(0),
            output_allocated(false),
            status(OkStatus()),
            status_offset(-1),
            num_calls(batch_size),
            uid(tensorflow::EnvTime::NowNanos()) {}

      // UpdateStatus updates the batch's aggregate Status.
      //
      // In order to ensure that exactly the first non-OK status is returned
      // (required to make the behavior observably identical to a sequential
      // execution of map followed by batch), we must also keep track of the
      // offset into the batch that produced `s`.
      void UpdateStatus(const Status& s, int64_t offset) {
        if (TF_PREDICT_FALSE(!s.ok())) {
          mutex_lock l(mu);
          if (status.ok() || offset < status_offset) {
            status = s;
            status_offset = offset;
          }
        }
      }

      mutex mu;
      bool end_of_input TF_GUARDED_BY(mu);
      int64_t num_elements TF_GUARDED_BY(mu);
      std::vector<Tensor> output;
      bool output_allocated TF_GUARDED_BY(mu);
      Status status TF_GUARDED_BY(mu);
      int64_t status_offset TF_GUARDED_BY(mu);
      // Counts the number of outstanding calls for this batch.
      int64_t num_calls TF_GUARDED_BY(&Iterator::mu_);
      const uint64 uid = -1;
    };

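    // Marks one in-flight call as completed, records thread utilization
    // stats, and wakes up any threads waiting on `cond_var_`.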
    void CallCompleted(const std::shared_ptr<IteratorContext>& ctx,
                       const std::shared_ptr<BatchResult>& result)
        TF_LOCKS_EXCLUDED(*mu_) {
      mutex_lock l(*mu_);
      num_calls_--;
      result->num_calls--;
      const auto& stats_aggregator = ctx->stats_aggregator();
      if (stats_aggregator) {
        stats_aggregator->AddScalar(
            stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
            static_cast<float>(num_calls_) /
                static_cast<float>(num_parallel_calls_->value),
            num_elements());
      }
      cond_var_->notify_all();
    }

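    // Fetches the next input element and asynchronously applies the map
    // function to it, writing the mapped tensors into slot `offset` of
    // `result`'s output batch.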
    void CallFunction(std::shared_ptr<IteratorContext> ctx,
                      const std::shared_ptr<BatchResult>& result,
                      int64_t offset) TF_LOCKS_EXCLUDED(*mu_) {
      profiler::TraceMe traceme([&] {
        return profiler::TraceMeEncode("MapAndBatchProduce",
                                       {{"element_id", result->uid}});
      });
      // Get the next input element.
      std::vector<Tensor> input_element;
      bool end_of_input = false;
      Status status =
          input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
      bool return_early;
      {
        mutex_lock l(result->mu);
        result->end_of_input = result->end_of_input || end_of_input;
        result->status.Update(status);
        return_early = result->end_of_input || !result->status.ok();
      }
      if (return_early) {
        CallCompleted(ctx, result);
        return;
      }

      std::shared_ptr<std::vector<Tensor>> return_values =
          std::make_shared<std::vector<Tensor>>();
      auto done = [this, ctx, result, return_values, offset](Status status) {
        if (dataset()->preserve_cardinality_ && errors::IsOutOfRange(status)) {
          // To guarantee that the transformation preserves the cardinality of
          // the dataset, we convert `OutOfRange` to `InvalidArgument` as the
          // former may be interpreted by a caller as the end of sequence.
          status = errors::InvalidArgument(
              "Function invocation produced OutOfRangeError: ",
              status.error_message());
        }
        result->UpdateStatus(status, offset);
        if (status.ok()) {
          Status allocate_status =
              EnsureOutputAllocated(ctx, result, return_values);
          if (!allocate_status.ok()) {
            result->UpdateStatus(allocate_status, offset);
          } else {
            for (size_t i = 0; i < return_values->size(); ++i) {
              Tensor& tensor = return_values->at(i);
              Tensor* batch = &(result->output)[i];
              if (tensor.NumElements() !=
                  (batch->NumElements() / batch->dim_size(0))) {
                TensorShape batch_shape = batch->shape();
                batch_shape.RemoveDim(0);
                result->UpdateStatus(
                    errors::InvalidArgument(
                        "Cannot add tensor to the batch: number of elements "
                        "does not match. Shapes are: [tensor]: ",
                        tensor.shape().DebugString(),
                        ", [batch]: ", batch_shape.DebugString()),
                    offset);
                break;
              }
              // TODO(mrry): Add a version of DoParallelConcat that allows us
              // to move `tensor` where possible, to speed up string tensor
              // batching.
              Status copy_status = batch_util::CopyElementToSlice(
                  std::move(tensor), batch, offset);
              if (!copy_status.ok()) {
                result->UpdateStatus(copy_status, offset);
                break;
              }
            }
          }
          {
            mutex_lock l(result->mu);
            result->num_elements++;
          }
        }
        CallCompleted(ctx, result);
      };

      // Apply the map function on `input_element`, storing the result in
      // `return_values`, and invoking `done` when finished.
      instantiated_captured_func_->RunAsync(ctx.get(), std::move(input_element),
                                            return_values.get(),
                                            std::move(done), model_node());
    }

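    // Cancels the input pipeline and wakes up all waiting threads. If `wait`
    // is true, blocks until all in-flight calls have completed.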
    void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
      cancellation_manager_->StartCancel();
      mutex_lock l(*mu_);
      cancelled_ = true;
      cond_var_->notify_all();
      // Wait for all in-flight calls to complete.
      while (wait && num_calls_ > 0) {
        cond_var_->wait(l);
      }
    }

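    // Starts the background runner thread if it is not already running.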
    void EnsureRunnerThreadStarted(IteratorContext* ctx)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      if (!runner_thread_) {
        auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
        runner_thread_ = ctx->StartThread(
            kTFDataMapAndBatch,
            std::bind(&Iterator::RunnerThread, this, ctx_copy));
      }
    }

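    // Allocates the output batch tensors the first time a map invocation for
    // `result` completes, sizing each component as `[batch_size, ...element
    // shape]` based on the shapes in `return_values`.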
    Status EnsureOutputAllocated(
        const std::shared_ptr<IteratorContext>& ctx,
        const std::shared_ptr<BatchResult>& result,
        const std::shared_ptr<std::vector<Tensor>>& return_values) {
      mutex_lock l(result->mu);
      if (result->output_allocated) {
        return OkStatus();
      }
      const size_t num_components = return_values->size();
      result->output.reserve(num_components);
      for (size_t i = 0; i < num_components; ++i) {
        TensorShape component_shape({dataset()->batch_size_});
        component_shape.AppendShape(return_values->at(i).shape());
        AllocatorAttributes attr;
        attr.set_gpu_compatible(true);
        result->output.emplace_back(ctx->allocator(attr),
                                    return_values->at(i).dtype(),
                                    component_shape);
        if (!result->output.back().IsInitialized()) {
          return errors::ResourceExhausted(
              "Failed to allocate memory for the batch of component ", i);
        }
      }
      RecordBufferEnqueue(ctx.get(), result->output);
      result->output_allocated = true;
      return OkStatus();
    }

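    // Background loop that keeps up to `num_parallel_calls_->value` map
    // invocations in flight and groups their results into batches of
    // `batch_size_` elements.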
    void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
        TF_LOCKS_EXCLUDED(*mu_) {
      std::vector<std::pair<std::shared_ptr<BatchResult>, int64_t>> new_calls;
      RecordStart(ctx.get());
      auto stop_cleanup =
          gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
      {
        tf_shared_lock l(*mu_);  // mu_ == num_parallel_calls_->mu
        new_calls.reserve(num_parallel_calls_->value);
      }
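      // A new call should not be scheduled if the number of in-flight calls
      // has reached the parallelism limit, or if the `batch_results_` buffer
      // is at capacity and the next call would start a new batch.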
      auto busy = [this]() TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
        int64_t num_parallel_calls = num_parallel_calls_->value;
        return num_calls_ >= num_parallel_calls ||
               (batch_results_.size() > max_batch_results_ ||
                (batch_results_.size() == max_batch_results_ &&
                 call_counter_ % dataset()->batch_size_ == 0));
      };
      while (true) {
        {
          mutex_lock l(*mu_);
          while (!cancelled_ && busy()) {
            if (waiting_ > 0 && num_calls_ < num_parallel_calls_->value &&
                max_batch_results_ < kMaxBatchResults) {
              // If there is a caller waiting for a batch and the number of
              // outstanding calls is not maxed out, it means we are out of
              // `batch_results_` slots. Instead of waiting for a slot to open
              // up, we create a new one to utilize the CPU efficiently.
              max_batch_results_++;
              continue;
            }
            RecordStop(ctx.get());
            cond_var_->wait(l);
            RecordStart(ctx.get());
          }

          if (cancelled_) {
            return;
          }

          while (!busy()) {
            if (call_counter_ % dataset()->batch_size_ == 0) {
              batch_results_.push_back(
                  std::make_shared<BatchResult>(dataset()->batch_size_));
            }
            int64_t offset = call_counter_++ % dataset()->batch_size_;
            new_calls.emplace_back(batch_results_.back(), offset);
            num_calls_++;
          }
        }
        const auto& stats_aggregator = ctx->stats_aggregator();
        if (stats_aggregator) {
          mutex_lock l(*mu_);
          stats_aggregator->AddScalar(
              stats_utils::ThreadUtilizationScalarName(dataset()->node_name()),
              static_cast<float>(num_calls_) /
                  static_cast<float>(num_parallel_calls_->value),
              num_elements());
        }
        for (const auto& call : new_calls) {
          CallFunction(ctx, call.first, call.second);
        }
        new_calls.clear();
      }
    }

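    // Restores the serialized `BatchResult` at position `index` from `reader`
    // and appends it to `batch_results_`.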
    Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
                           size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      batch_results_.push_back(
          std::make_shared<BatchResult>(dataset()->batch_size_));
      std::shared_ptr<BatchResult> result = batch_results_.back();
      string batch_prefix = strings::StrCat(kBatchResults, "_", index);
      mutex_lock l(result->mu);
      result->end_of_input = reader->Contains(
          full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)));
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumCalls)),
          &result->num_calls));
      TF_RETURN_IF_ERROR(reader->ReadScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
          &result->num_elements));
      result->output_allocated = reader->Contains(
          full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated)));

      TF_RETURN_IF_ERROR(ReadBatch(ctx, reader, dataset()->batch_size_,
                                   prefix(), batch_prefix, &result->output));
      TF_RETURN_IF_ERROR(ReadStatus(prefix(),
                                    strings::StrCat(batch_prefix, "_", kStatus),
                                    reader, &result->status));
      if (result->output_allocated) {
        RecordBufferEnqueue(ctx, result->output);
      }
      return OkStatus();
    }

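    // Serializes the `BatchResult` at position `index` of `batch_results_`
    // to `writer`.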
    Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
        TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
      std::shared_ptr<BatchResult> result = batch_results_[index];
      string batch_prefix = strings::StrCat(kBatchResults, "_", index);
      mutex_lock l(result->mu);
      if (result->end_of_input) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(batch_prefix, "_", kEndOfInput)), ""));
      }
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumCalls)),
          result->num_calls));
      TF_RETURN_IF_ERROR(writer->WriteScalar(
          full_name(strings::StrCat(batch_prefix, "_", kNumElements)),
          result->num_elements));
      if (result->output_allocated) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(
            full_name(strings::StrCat(batch_prefix, "_", kOutputAllocated)),
            ""));
      }

      TF_RETURN_IF_ERROR(WriteBatch(dataset()->batch_size_,
                                    result->num_elements, prefix(),
                                    batch_prefix, writer, &result->output));
      TF_RETURN_IF_ERROR(
          WriteStatus(prefix(), strings::StrCat(batch_prefix, "_", kStatus),
                      result->status, writer));
      return OkStatus();
    }

    // Used for coordination between the main thread, the runner thread, and
    // the callback threads.
    const std::shared_ptr<mutex> mu_;
    // Used for coordination between the main thread, the runner thread, and
    // the callback threads. In particular, the runner thread should only
    // schedule new calls when the number of in-flight calls is less than
    // `num_parallel_calls_->value` and there are slots available in the
    // `batch_results_` buffer.
    const std::shared_ptr<condition_variable> cond_var_;
    // Identifies the maximum number of parallel calls.
    const std::shared_ptr<model::SharedState> num_parallel_calls_;

    // Controls cancellation of `input_impl_`. Must be ordered before
    // `input_impl_` so that `input_impl_` is destroyed first.
    std::unique_ptr<CancellationManager> cancellation_manager_;
    // Counts the number of in-flight calls across all batch results.
    int64_t num_calls_ TF_GUARDED_BY(*mu_) = 0;
    // Counts the total number of calls.
    int64_t call_counter_ TF_GUARDED_BY(*mu_) = 0;
    std::unique_ptr<IteratorBase> input_impl_;
    // Buffer for storing the (intermediate) batch results. Whenever an
    // output-allocated batch result is added to or removed from
    // `batch_results_`, call `RecordBufferEnqueue` or `RecordBufferDequeue`
    // respectively.
    std::deque<std::shared_ptr<BatchResult>> batch_results_ TF_GUARDED_BY(*mu_);
    // Determines whether the transformation has been cancelled.
    bool cancelled_ TF_GUARDED_BY(*mu_) = false;
    // Identifies the number of callers currently waiting for a batch result.
    int64_t waiting_ TF_GUARDED_BY(*mu_) = 0;
    // Identifies the maximum number of batch results to store.
    int64_t max_batch_results_ TF_GUARDED_BY(*mu_);
    std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;

    // Method for deregistering the cancellation callback.
    std::function<void()> deregister_fn_;

    // Records the number of ParallelInterleave operations in the path from the
    // root node to this node (not including this node) in the input pipeline
    // tree. We record the interleave depth so that it can be included in the
    // trace metadata.
    int64 interleave_depth_ = -1;
    // Background thread used for coordinating input processing. The thread
    // should be destroyed before the variables it accesses are destroyed.
    std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
  };

  const DatasetBase* const input_;
  const int64_t batch_size_;
  const int64_t num_parallel_calls_;
  const bool drop_remainder_;
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
  const std::unique_ptr<CapturedFunction> captured_func_;
  const bool preserve_cardinality_;
  const TraceMeMetadata traceme_metadata_;
};

MapAndBatchDatasetOp::MapAndBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{},
                                               &func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_));
}

void MapAndBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64_t batch_size = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("batch_size must be greater than zero."));

  int64_t num_parallel_calls = 0;
  OP_REQUIRES_OK(
      ctx, ParseScalarArgument(ctx, kNumParallelCalls, &num_parallel_calls));
  OP_REQUIRES(
      ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutotune,
      errors::InvalidArgument("num_parallel_calls must be greater than zero."));

  bool drop_remainder;
  OP_REQUIRES_OK(ctx,
                 ParseScalarArgument(ctx, kDropRemainder, &drop_remainder));

  std::unique_ptr<CapturedFunction> captured_func;
  OP_REQUIRES_OK(ctx,
                 CapturedFunction::Create(ctx, func_metadata_, kOtherArguments,
                                          &captured_func));

  if (num_parallel_calls == model::kAutotune) {
    metrics::RecordTFDataAutotune(kDatasetType);
  }

  *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
                        drop_remainder, output_types_, output_shapes_,
                        std::move(captured_func), preserve_cardinality_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("MapAndBatchDataset").Device(DEVICE_CPU),
                        MapAndBatchDatasetOp);
REGISTER_KERNEL_BUILDER(
    Name("ExperimentalMapAndBatchDataset").Device(DEVICE_CPU),
    MapAndBatchDatasetOp);

REGISTER_INPUT_COLOCATION_EXEMPTION("MapAndBatchDataset");
REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalMapAndBatchDataset");
}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow