/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
/* static */ constexpr const char* const PaddedBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddedShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddingValues;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kToutputTypes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kNumPaddedShapes;

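// Checkpoint key written by SaveInternal() when the input iterator has been
// exhausted, so that RestoreInternal() can avoid recreating it.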
constexpr char kExhausted[] = "exhausted";

class PaddedBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
          bool parallel_copy, std::vector<PartialTensorShape> padded_shapes,
          std::vector<Tensor> padding_values, const DatasetBase* input,
          int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        padded_shapes_(std::move(padded_shapes)),
        padding_values_(std::move(padding_values)),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If we could
    // tell statically that the input dataset is infinite, then we could
    // always report `batch_size` as the 0th dimension.
    //
    // TODO(mrry): Need to validate that the input shape and the padded shape
    // are "compatible" (i.e. that padded shape is >= input shape, with both
    // static and dynamic checks as appropriate).
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (size_t i = 0; i < input_shapes.size(); ++i) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.push_back(
            PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
      } else {
        output_shapes_.push_back(
            PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override {
    int64 n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
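    // Number of full batches, plus one partial batch unless `drop_remainder`
    // is set (i.e. ceiling division of `n` by `batch_size_`).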
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));

    std::vector<Node*> padded_shapes;
    padded_shapes.reserve(padded_shapes_.size());
    for (int i = 0; i < padded_shapes_.size(); i++) {
      Node* node;
      Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
      for (int j = 0; j < padded_shapes_[i].dims(); j++) {
        t.vec<int64>()(j) = padded_shapes_[i].dim_size(j);
      }
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padded_shapes.emplace_back(node);
    }

    std::vector<Node*> padding_values;
    padding_values.reserve(padding_values_.size());
    for (const Tensor& t : padding_values_) {
      Node* node;
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padding_values.emplace_back(node);
    }

    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));

    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);

    AttrValue output_types;
    b->BuildAttrValue(output_dtypes(), &output_types);

    AttrValue N;
    b->BuildAttrValue<int64>(padded_shapes_.size(), &N);

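    // Inputs are wired up by their argument index in the op definition: 0
    // (`input_dataset`), 1 (`batch_size`), and 4 (`drop_remainder`) are
    // single tensors, while 2 (`padded_shapes`) and 3 (`padding_values`)
    // are tensor lists.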
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
        {{2, padded_shapes}, {3, padding_values}},
        {{kParallelCopy, parallel_copy},
         {kToutputTypes, output_types},
         {kNumPaddedShapes, N}},
        output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        } else {
          *end_of_sequence = false;
          batch_elements.reserve(dataset()->batch_size_);
          for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              batch_elements.push_back(std::move(batch_element_tuple));
            }
          }
          if (*end_of_sequence) {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return Status::OK();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      // Copy the retrieved batch elements into one output tensor per tuple
      // component.
      //
      // NOTE(mrry): If the input or output sizes are statically known, we
      // could potentially read the input values in-place into their
      // respective slice locations. This would require a different GetNext()
      // overload that supports zero-copy, and might make sense in an
      // optimization pass.
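      //
      // Illustrative example: with padded_shape [-1] and two elements of
      // shapes [2] and [4], the batch component shape becomes [2, 4]; the
      // shorter element is padded out with the corresponding padding value.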
      const size_t num_tuple_components = batch_elements[0].size();
      const int64 num_batch_elements = batch_elements.size();
      for (size_t component_index = 0; component_index < num_tuple_components;
           ++component_index) {
        // 1. Determine the shape of the padded tensor.
        TensorShape batch_component_shape({num_batch_elements});
        const PartialTensorShape& padded_shape =
            dataset()->padded_shapes_[component_index];

        for (int dim = 0; dim < padded_shape.dims(); ++dim) {
          if (padded_shape.dim_size(dim) == -1) {
            batch_component_shape.AddDim(0);
          } else {
            batch_component_shape.AddDim(padded_shape.dim_size(dim));
          }
        }

        for (int64 i = 0; i < num_batch_elements; ++i) {
          const TensorShape& element_shape =
              batch_elements[i][component_index].shape();
          // TODO(mrry): Perform this check in the shape function if
          // enough static information is available to do so.
          if (element_shape.dims() != padded_shape.dims()) {
            return errors::InvalidArgument(
                "All elements in a batch must have the same rank as the "
                "padded shape for component",
                component_index, ": expected rank ", padded_shape.dims(),
                " but got element with rank ", element_shape.dims());
          }
          for (int dim = 0; dim < padded_shape.dims(); ++dim) {
            if (padded_shape.dim_size(dim) == -1) {
              // Take the max of all batch elements in this dimension.
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                batch_component_shape.set_dim(
                    dim + 1,
                    batch_elements[i][component_index].shape().dim_size(dim));
              }
            } else {
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                return errors::DataLoss(
                    "Attempted to pad to a smaller size than the input "
                    "element.");
              }
            }
          }
        }

        // 2. Copy each batch element to the appropriate location in
        // the output component tensor.
        out_tensors->emplace_back(ctx->allocator({}),
                                  output_dtypes()[component_index],
                                  batch_component_shape);
        Tensor& batch_component = out_tensors->back();
        TF_RETURN_IF_ERROR(batch_util::SetElementZero(
            &batch_component, dataset()->padding_values_[component_index]));

        // Build the output tuple component by copying one slice
        // from each input element in the batch.
        TensorShape component_shape({});
        for (int i = 1; i < batch_component_shape.dims(); ++i) {
          component_shape.AddDim(batch_component_shape.dim_size(i));
        }
        auto copy_element_fn = [component_index, &batch_elements,
                                &batch_component, &component_shape](int index) {
          // Take the fast path if possible.
          if (batch_elements[index][component_index].shape() ==
              component_shape) {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          } else {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          }
          return Status::OK();
        };
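        // If `parallel_copy` is set, dispatch each element copy to the
        // runner's thread pool and wait for all copies to complete;
        // otherwise, copy the elements sequentially on this thread.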
        BlockingCounter counter(num_batch_elements);
        Status status;
        mutex status_mu;
        for (size_t i = 0; i < num_batch_elements; ++i) {
          if (TF_PREDICT_FALSE(dataset()->parallel_copy_)) {
            (*ctx->runner())(
                [i, &status, &status_mu, &counter, &copy_element_fn]() {
                  Status s = copy_element_fn(i);
                  {
                    mutex_lock l(status_mu);
                    status.Update(s);
                  }
                  counter.DecrementCount();
                });
          } else {
            status.Update(copy_element_fn(i));
            counter.DecrementCount();
          }
        }
        counter.Wait();
        TF_RETURN_IF_ERROR(status);
      }
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
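    // Each output batch consumes `batch_size_` input elements, so this
    // iterator is modeled as a known-ratio node for autotuning purposes.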
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impl_)
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kExhausted))) {
        input_impl_.reset();
      } else {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64 batch_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const std::vector<PartialTensorShape> padded_shapes_;
  const std::vector<Tensor> padding_values_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

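// Op version 1 corresponds to "PaddedBatchDataset"; version 2
// ("PaddedBatchDatasetV2") additionally takes a `drop_remainder` input,
// which is parsed in MakeDataset() below.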
PaddedBatchDatasetOp::PaddedBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void PaddedBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64 batch_size;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

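  // The `drop_remainder` input only exists in the V2 op; for V1 it defaults
  // to false (i.e. the final partial batch is kept).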
  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  OpInputList padded_shape_tensors;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddedShapes, &padded_shape_tensors));
  std::vector<PartialTensorShape> padded_shapes;
  padded_shapes.reserve(padded_shape_tensors.size());
  OP_REQUIRES(ctx, padded_shape_tensors.size() == input->output_shapes().size(),
              errors::InvalidArgument("Number of padded shapes (",
                                      padded_shape_tensors.size(),
                                      ") must match the number of components "
                                      "in the input dataset's elements (",
                                      input->output_shapes().size(), ")"));
  for (const Tensor& padded_shape_t : padded_shape_tensors) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()),
                errors::InvalidArgument("All padded shapes must be vectors"));
    PartialTensorShape padded_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            padded_shape_t.vec<int64>().data(),
                            padded_shape_t.NumElements(), &padded_shape));
    padded_shapes.push_back(std::move(padded_shape));
  }
  OpInputList padding_values_list;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddingValues, &padding_values_list));
  std::vector<Tensor> padding_values;
  OP_REQUIRES(ctx, padding_values_list.size() == input->output_shapes().size(),
              errors::InvalidArgument(
                  "Number of padding values (", padding_values_list.size(),
                  ") must match the number of components in the input "
                  "dataset's elements (",
                  input->output_shapes().size(), ")"));
  for (int i = 0; i < padding_values_list.size(); ++i) {
    const Tensor& padding_value_t = padding_values_list[i];
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()),
                errors::InvalidArgument("All padding values must be scalars"));
    OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i],
                errors::InvalidArgument(
                    "Mismatched type between padding value ", i,
                    " and input dataset's component ", i, ": ",
                    DataTypeString(padding_value_t.dtype()), " vs. ",
                    DataTypeString(input->output_dtypes()[i])));
    padding_values.push_back(tensor::DeepCopy(padding_value_t));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_,
                        std::move(padded_shapes), std::move(padding_values),
                        input, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow