/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
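//
// This kernel backs `tf.data.Dataset.padded_batch` in the Python API: it
// combines consecutive input elements into batches, padding each component
// up to `padded_shapes` (or, for unknown (-1) dimensions, up to the largest
// size observed within the batch) with the corresponding `padding_values`.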

/* static */ constexpr const char* const PaddedBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddedShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddingValues;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kToutputTypes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kNumPaddedShapes;

constexpr char kExhausted[] = "exhausted";

class PaddedBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
          bool parallel_copy, std::vector<PartialTensorShape> padded_shapes,
          std::vector<Tensor> padding_values, const DatasetBase* input,
          int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        padded_shapes_(std::move(padded_shapes)),
        padding_values_(std::move(padding_values)),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If we could
    // tell statically that the input dataset is infinite, then we could
    // always report `batch_size` as the 0th dimension.
    //
    // TODO(mrry): Need to validate that the input shape and the padded shape
    // are "compatible" (i.e. that padded shape is >= input shape, with both
    // static and dynamic checks as appropriate).
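    //
    // For example, with batch_size = 4 and padded_shapes = {[-1]}:
    //   - if drop_remainder is true (or the input is known to be infinite),
    //     the reported output shape is [4, ?];
    //   - otherwise the final batch may be partial, so the batch dimension is
    //     also unknown and the reported output shape is [?, ?].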
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (size_t i = 0; i < input_shapes.size(); ++i) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.push_back(
            PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
      } else {
        output_shapes_.push_back(
            PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

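  // The number of batches is the input cardinality divided by `batch_size_`,
  // rounded up unless `drop_remainder_` is set. For example, an input with
  // cardinality 10 and batch_size 3 yields 4 batches (3 full batches plus a
  // partial one), or 3 batches when drop_remainder is true.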
  int64 Cardinality() const override {
    int64 n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));

    std::vector<Node*> padded_shapes;
    padded_shapes.reserve(padded_shapes_.size());
    for (int i = 0; i < padded_shapes_.size(); i++) {
      Node* node;
      Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
      for (int j = 0; j < padded_shapes_[i].dims(); j++) {
        t.vec<int64>()(j) = padded_shapes_[i].dim_size(j);
      }
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padded_shapes.emplace_back(node);
    }

    std::vector<Node*> padding_values;
    padding_values.reserve(padding_values_.size());
    for (const Tensor& t : padding_values_) {
      Node* node;
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padding_values.emplace_back(node);
    }

    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));

    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);

    AttrValue output_types;
    b->BuildAttrValue(output_dtypes(), &output_types);

    AttrValue N;
    b->BuildAttrValue<int64>(padded_shapes_.size(), &N);

    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
        {{2, padded_shapes}, {3, padding_values}},
        {{kParallelCopy, parallel_copy},
         {kToutputTypes, output_types},
         {kNumPaddedShapes, N}},
        output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        } else {
          *end_of_sequence = false;
          batch_elements.reserve(dataset()->batch_size_);
          for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              batch_elements.push_back(std::move(batch_element_tuple));
            }
          }
          if (*end_of_sequence) {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return Status::OK();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      // Copy the retrieved batch elements into one output tensor per tuple
      // component.
      //
      // NOTE(mrry): If the input or output sizes are statically known, we
      // could potentially read the input values in-place into their
      // respective slice locations. This would require a different GetNext()
      // overload that supports zero-copy, and might make sense in an
      // optimization pass.
      const size_t num_tuple_components = batch_elements[0].size();
      const int64 num_batch_elements = batch_elements.size();
      for (size_t component_index = 0; component_index < num_tuple_components;
           ++component_index) {
        // 1. Determine the shape of the padded tensor.
        TensorShape batch_component_shape({num_batch_elements});
        const PartialTensorShape& padded_shape =
            dataset()->padded_shapes_[component_index];

        for (int dim = 0; dim < padded_shape.dims(); ++dim) {
          if (padded_shape.dim_size(dim) == -1) {
            batch_component_shape.AddDim(0);
          } else {
            batch_component_shape.AddDim(padded_shape.dim_size(dim));
          }
        }
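        // At this point each unknown (-1) dimension of `padded_shape` is
        // represented as 0 in `batch_component_shape`; the loop below grows
        // it to the largest size observed across the batch. For example,
        // padding a batch of vectors with lengths {3, 5} to padded_shape
        // [-1] produces a batch component of shape [num_batch_elements, 5].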

        for (int64 i = 0; i < num_batch_elements; ++i) {
          const TensorShape& element_shape =
              batch_elements[i][component_index].shape();
          // TODO(mrry): Perform this check in the shape function if
          // enough static information is available to do so.
          if (element_shape.dims() != padded_shape.dims()) {
            return errors::InvalidArgument(
                "All elements in a batch must have the same rank as the "
                "padded shape for component",
                component_index, ": expected rank ", padded_shape.dims(),
                " but got element with rank ", element_shape.dims());
          }
          for (int dim = 0; dim < padded_shape.dims(); ++dim) {
            if (padded_shape.dim_size(dim) == -1) {
              // Take the max of all batch elements in this dimension.
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                batch_component_shape.set_dim(
                    dim + 1,
                    batch_elements[i][component_index].shape().dim_size(dim));
              }
            } else {
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                return errors::DataLoss(
                    "Attempted to pad to a smaller size than the input "
                    "element.");
              }
            }
          }
        }

        // 2. Copy each batch element to the appropriate location in
        // the output component tensor.
        out_tensors->emplace_back(ctx->allocator({}),
                                  output_dtypes()[component_index],
                                  batch_component_shape);
        Tensor& batch_component = out_tensors->back();
        TF_RETURN_IF_ERROR(batch_util::SetElementZero(
            &batch_component, dataset()->padding_values_[component_index]));

        // Build the output tuple component by copying one slice
        // from each input element in the batch.
        TensorShape component_shape({});
        for (int i = 1; i < batch_component_shape.dims(); ++i) {
          component_shape.AddDim(batch_component_shape.dim_size(i));
        }
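        // `component_shape` is the shape of one fully padded slice of the
        // output, i.e. `batch_component_shape` without the leading batch
        // dimension. Elements that already match it can be copied directly;
        // smaller elements are copied into a slice that has been pre-filled
        // with the padding value above.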
        auto copy_element_fn = [component_index, &batch_elements,
                                &batch_component, &component_shape](int index) {
          // Take the fast path if possible.
          if (batch_elements[index][component_index].shape() ==
              component_shape) {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          } else {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          }
          return Status::OK();
        };
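        // Run one copy per batch element. When `parallel_copy` is enabled,
        // the copies are dispatched to the runner's threadpool and the
        // BlockingCounter waits for all of them; otherwise they run inline
        // on this thread. Per-element errors are merged into `status` under
        // `status_mu`.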
        BlockingCounter counter(num_batch_elements);
        Status status;
        mutex status_mu;
        for (size_t i = 0; i < num_batch_elements; ++i) {
          if (TF_PREDICT_FALSE(dataset()->parallel_copy_)) {
            (*ctx->runner())(
                [i, &status, &status_mu, &counter, &copy_element_fn]() {
                  Status s = copy_element_fn(i);
                  {
                    mutex_lock l(status_mu);
                    status.Update(s);
                  }
                  counter.DecrementCount();
                });
          } else {
            status.Update(copy_element_fn(i));
            counter.DecrementCount();
          }
        }
        counter.Wait();
        TF_RETURN_IF_ERROR(status);
      }
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impl_)
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kExhausted))) {
        input_impl_.reset();
      } else {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64 batch_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const std::vector<PartialTensorShape> padded_shapes_;
  const std::vector<Tensor> padding_values_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

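// `PaddedBatchDataset` (V1) has no `drop_remainder` input;
// `PaddedBatchDatasetV2` adds one, so the op version is derived from the
// registered op name and the extra input is only parsed for V2 in
// MakeDataset below.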
PaddedBatchDatasetOp::PaddedBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void PaddedBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64 batch_size;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  OpInputList padded_shape_tensors;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddedShapes, &padded_shape_tensors));
  std::vector<PartialTensorShape> padded_shapes;
  padded_shapes.reserve(padded_shape_tensors.size());
  OP_REQUIRES(ctx, padded_shape_tensors.size() == input->output_shapes().size(),
              errors::InvalidArgument("Number of padded shapes (",
                                      padded_shape_tensors.size(),
                                      ") must match the number of components "
                                      "in the input dataset's elements (",
                                      input->output_shapes().size(), ")"));
  for (const Tensor& padded_shape_t : padded_shape_tensors) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()),
                errors::InvalidArgument("All padded shapes must be vectors"));
    PartialTensorShape padded_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            padded_shape_t.vec<int64>().data(),
                            padded_shape_t.NumElements(), &padded_shape));
    padded_shapes.push_back(std::move(padded_shape));
  }
  OpInputList padding_values_list;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddingValues, &padding_values_list));
  std::vector<Tensor> padding_values;
  OP_REQUIRES(ctx, padding_values_list.size() == input->output_shapes().size(),
              errors::InvalidArgument(
                  "Number of padding values (", padding_values_list.size(),
                  ") must match the number of components in the input "
                  "dataset's elements (",
                  input->output_shapes().size(), ")"));
  for (int i = 0; i < padding_values_list.size(); ++i) {
    const Tensor& padding_value_t = padding_values_list[i];
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()),
                errors::InvalidArgument("All padding values must be scalars"));
    OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i],
                errors::InvalidArgument(
                    "Mismatched type between padding value ", i,
                    " and input dataset's component ", i, ": ",
                    DataTypeString(padding_value_t.dtype()), " vs. ",
                    DataTypeString(input->output_dtypes()[i])));
    padding_values.push_back(tensor::DeepCopy(padding_value_t));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_,
                        std::move(padded_shapes), std::move(padding_values),
                        input, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow