/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
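//
// Illustrative usage (assumption: this kernel backs the Python
// `tf.data.Dataset.padded_batch` transformation; the snippet is for
// orientation only and is not part of this file):
//
//   ds = tf.data.Dataset.range(8).map(lambda x: tf.range(x))
//   ds = ds.padded_batch(4, padded_shapes=[None], padding_values=-1)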

/* static */ constexpr const char* const PaddedBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddedShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddingValues;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kToutputTypes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kNumPaddedShapes;

constexpr char kExhausted[] = "exhausted";

class PaddedBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, std::vector<PartialTensorShape> padded_shapes,
          std::vector<Tensor> padding_values, const DatasetBase* input,
          int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        padded_shapes_(std::move(padded_shapes)),
        padding_values_(std::move(padding_values)),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If we could
    // tell statically that the input dataset is infinite, then we could
    // always report `batch_size` as the 0th dimension.
    //
    // TODO(mrry): Need to validate that the input shape and the padded shape
    // are "compatible" (i.e. that padded shape is >= input shape, with both
    // static and dynamic checks as appropriate).
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (size_t i = 0; i < input_shapes.size(); ++i) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.push_back(
            PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
      } else {
        output_shapes_.push_back(
            PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

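  // Cardinality is floor(n / batch_size) when the remainder is dropped, and
  // ceil(n / batch_size) otherwise.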
  int64 Cardinality() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));

    std::vector<Node*> padded_shapes;
    padded_shapes.reserve(padded_shapes_.size());
    for (int i = 0; i < padded_shapes_.size(); i++) {
      Node* node;
      Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
      for (int j = 0; j < padded_shapes_[i].dims(); j++) {
        t.vec<int64>()(j) = padded_shapes_[i].dim_size(j);
      }
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padded_shapes.emplace_back(node);
    }

    std::vector<Node*> padding_values;
    padding_values.reserve(padding_values_.size());
    for (const Tensor& t : padding_values_) {
      Node* node;
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padding_values.emplace_back(node);
    }

    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));

    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);

    AttrValue output_types;
    b->BuildAttrValue(output_dtypes(), &output_types);

    AttrValue N;
    b->BuildAttrValue<int64>(padded_shapes_.size(), &N);

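    // Inputs 0 (input dataset), 1 (batch_size) and 4 (drop_remainder) are
    // single tensors; inputs 2 (padded_shapes) and 3 (padding_values) are
    // tensor lists.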
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
        {{2, padded_shapes}, {3, padding_values}},
        {{kParallelCopy, parallel_copy},
         {kToutputTypes, output_types},
         {kNumPaddedShapes, N}},
        output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        } else {
          *end_of_sequence = false;
          batch_elements.reserve(dataset()->batch_size_);
          for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              batch_elements.push_back(std::move(batch_element_tuple));
            }
          }
          if (*end_of_sequence) {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return Status::OK();
      }

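      // Emit a smaller final batch unless `drop_remainder_` is set, in which
      // case the partial batch is discarded and the sequence ends here.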
      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      TF_RETURN_IF_ERROR(CopyBatch(ctx, batch_elements, out_tensors));
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

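    // Checkpointing: when the input iterator is exhausted only an
    // `exhausted` marker is written; otherwise the input iterator's own state
    // is saved and restored.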
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impl_)
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kExhausted))) {
        input_impl_.reset();
      } else {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    // Copies the retrieved batch elements into one output tensor per tuple
    // component.
    //
    // NOTE(mrry): If the input or output sizes are statically known, we could
    // potentially read the input values in-place into their respective slice
    // locations. This would require a different GetNext() overload that
    // supports zero-copy, and might make sense in an optimization pass.
    Status CopyBatch(IteratorContext* ctx,
                     const std::vector<std::vector<Tensor>>& batch_elements,
                     std::vector<Tensor>* out_tensors) {
      static bool in_experiment =
          GetExperiments().contains("parallelize_batch_copy");
      const size_t num_tuple_components = batch_elements[0].size();
      const int64_t num_batch_elements = batch_elements.size();
      for (size_t component_index = 0; component_index < num_tuple_components;
           ++component_index) {
        // 1. Determine the shape of the padded tensor.
        TensorShape batch_component_shape({num_batch_elements});
        const PartialTensorShape& padded_shape =
            dataset()->padded_shapes_[component_index];

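        // Unknown (-1) padded dimensions start out at 0 and are grown below
        // to the maximum size observed across the batch elements.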
        for (int dim = 0; dim < padded_shape.dims(); ++dim) {
          if (padded_shape.dim_size(dim) == -1) {
            batch_component_shape.AddDim(0);
          } else {
            batch_component_shape.AddDim(padded_shape.dim_size(dim));
          }
        }

        for (int64_t i = 0; i < num_batch_elements; ++i) {
          const TensorShape& element_shape =
              batch_elements[i][component_index].shape();
          // TODO(mrry): Perform this check in the shape function if
          // enough static information is available to do so.
          if (element_shape.dims() != padded_shape.dims()) {
            return errors::InvalidArgument(
                "All elements in a batch must have the same rank as the "
                "padded shape for component",
                component_index, ": expected rank ", padded_shape.dims(),
                " but got element with rank ", element_shape.dims());
          }
          for (int dim = 0; dim < padded_shape.dims(); ++dim) {
            if (padded_shape.dim_size(dim) == -1) {
              // Take the max of all batch elements in this dimension.
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                batch_component_shape.set_dim(
                    dim + 1,
                    batch_elements[i][component_index].shape().dim_size(dim));
              }
            } else {
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                return errors::DataLoss(
                    "Attempted to pad to a smaller size than the input "
                    "element.");
              }
            }
          }
        }

        // 2. Copy each batch element to the appropriate location in
        // the output component tensor.
        out_tensors->emplace_back(ctx->allocator({}),
                                  output_dtypes()[component_index],
                                  batch_component_shape);
        Tensor& batch_component = out_tensors->back();
        TF_RETURN_IF_ERROR(batch_util::SetElementZero(
            &batch_component, dataset()->padding_values_[component_index]));

        // Build the output tuple component by copying one slice from each input
        // element in the batch.
        TensorShape component_shape({});
        for (int i = 1; i < batch_component_shape.dims(); ++i) {
          component_shape.AddDim(batch_component_shape.dim_size(i));
        }
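        // `copy_element_fn` copies element `index` into row `index` of the
        // output: a straight copy when the element already has the full
        // padded shape, otherwise a copy into a larger, padding-filled slice.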
        auto copy_element_fn = [component_index, &batch_elements,
                                &batch_component, &component_shape](int index) {
          // Take the fast path if possible.
          if (batch_elements[index][component_index].shape() ==
              component_shape) {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          } else {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          }
          return Status::OK();
        };

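        // Parallelize the copy when requested explicitly, or when the
        // "parallelize_batch_copy" experiment is enabled and each element's
        // slice is at least 32KiB.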
        if (dataset()->parallel_copy_ ||
            (in_experiment && (batch_component.AllocatedBytes() /
                               num_batch_elements) >= (1 << 15))) {
          BlockingCounter counter(num_batch_elements);
          Status status;
          mutex status_mu;
          const auto num_threads = ctx->runner_threadpool_size();
          const auto slice_size = num_batch_elements / num_threads;
          int64_t offset = 0;
          for (size_t i = 0; i < num_threads; ++i) {
            int64_t length = slice_size;
            // When the number of threads does not divide the number of elements
            // evenly, the size of some slices is incremented to guarantee their
            // sizes add up to the total number of elements.
            if (i < num_batch_elements % num_threads) ++length;
            (*ctx->runner())([offset, length, &status, &status_mu, &counter,
                              &copy_element_fn]() {
              for (size_t j = offset; j < offset + length; ++j) {
                {
                  Status s = copy_element_fn(j);
                  mutex_lock l(status_mu);
                  status.Update(s);
                }
                counter.DecrementCount();
              }
            });
            offset += length;
          }
          counter.Wait();
          TF_RETURN_IF_ERROR(status);
        } else {
          for (size_t i = 0; i < num_batch_elements; ++i) {
            TF_RETURN_IF_ERROR(copy_element_fn(i));
          }
        }
      }
      return Status::OK();
    }

    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64 batch_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const std::vector<PartialTensorShape> padded_shapes_;
  const std::vector<Tensor> padding_values_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

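// V1 ("PaddedBatchDataset") has no `drop_remainder` input;
// V2 ("PaddedBatchDatasetV2") adds it.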
PaddedBatchDatasetOp::PaddedBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void PaddedBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64_t batch_size;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

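  // `padded_shapes` is one int64 vector per component; a dimension of -1
  // marks a dimension that is padded to the longest element in each batch.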
  OpInputList padded_shape_tensors;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddedShapes, &padded_shape_tensors));
  std::vector<PartialTensorShape> padded_shapes;
  padded_shapes.reserve(padded_shape_tensors.size());
  OP_REQUIRES(ctx, padded_shape_tensors.size() == input->output_shapes().size(),
              errors::InvalidArgument("Number of padded shapes (",
                                      padded_shape_tensors.size(),
                                      ") must match the number of components "
                                      "in the input dataset's elements (",
                                      input->output_shapes().size(), ")"));
  for (const Tensor& padded_shape_t : padded_shape_tensors) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()),
                errors::InvalidArgument("All padded shapes must be vectors"));
    PartialTensorShape padded_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            padded_shape_t.vec<int64>().data(),
                            padded_shape_t.NumElements(), &padded_shape));
    padded_shapes.push_back(std::move(padded_shape));
  }
  OpInputList padding_values_list;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddingValues, &padding_values_list));
  std::vector<Tensor> padding_values;
  OP_REQUIRES(ctx, padding_values_list.size() == input->output_shapes().size(),
              errors::InvalidArgument(
                  "Number of padding values (", padding_values_list.size(),
                  ") must match the number of components in the input "
                  "dataset's elements (",
                  input->output_shapes().size(), ")"));
  for (int i = 0; i < padding_values_list.size(); ++i) {
    const Tensor& padding_value_t = padding_values_list[i];
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()),
                errors::InvalidArgument("All padding values must be scalars"));
    OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i],
                errors::InvalidArgument(
                    "Mismatched type between padding value ", i,
                    " and input dataset's component ", i, ": ",
                    DataTypeString(padding_value_t.dtype()), " vs. ",
                    DataTypeString(input->output_dtypes()[i])));
    padding_values.push_back(tensor::DeepCopy(padding_value_t));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_,
                        std::move(padded_shapes), std::move(padding_values),
                        input, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow