/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const PaddedBatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddedShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kPaddingValues;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kToutputTypes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kOutputShapes;
/* static */ constexpr const char* const PaddedBatchDatasetOp::kNumPaddedShapes;

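// Checkpoint key recording that the input iterator has been exhausted, so
// that restore can skip recreating it.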
constexpr char kExhausted[] = "exhausted";

class PaddedBatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, std::vector<PartialTensorShape> padded_shapes,
          std::vector<Tensor> padding_values, const DatasetBase* input,
          int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        padded_shapes_(std::move(padded_shapes)),
        padding_values_(std::move(padding_values)),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If we could
    // tell statically that the input dataset is infinite, then we could
    // always report `batch_size` as the 0th dimension.
    //
    // TODO(mrry): Need to validate that the input shape and the padded shape
    // are "compatible" (i.e. that padded shape is >= input shape, with both
    // static and dynamic checks as appropriate).
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (size_t i = 0; i < input_shapes.size(); ++i) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.push_back(
            PartialTensorShape({batch_size_}).Concatenate(padded_shapes_[i]));
      } else {
        output_shapes_.push_back(
            PartialTensorShape({-1}).Concatenate(padded_shapes_[i]));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
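    // Otherwise, count the number of full batches, plus one partial batch
    // unless `drop_remainder_` discards it.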
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));

    std::vector<Node*> padded_shapes;
    padded_shapes.reserve(padded_shapes_.size());
    for (int i = 0; i < padded_shapes_.size(); i++) {
      Node* node;
      Tensor t(DT_INT64, TensorShape({padded_shapes_[i].dims()}));
      for (int j = 0; j < padded_shapes_[i].dims(); j++) {
        t.vec<int64>()(j) = padded_shapes_[i].dim_size(j);
      }
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padded_shapes.emplace_back(node);
    }

    std::vector<Node*> padding_values;
    padding_values.reserve(padding_values_.size());
    for (const Tensor& t : padding_values_) {
      Node* node;
      TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
      padding_values.emplace_back(node);
    }

    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));

    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);

    AttrValue output_types;
    b->BuildAttrValue(output_dtypes(), &output_types);

    AttrValue N;
    b->BuildAttrValue<int64>(padded_shapes_.size(), &N);

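    // Single-tensor inputs occupy argument slots 0 (input dataset), 1
    // (batch_size), and 4 (drop_remainder); the list inputs `padded_shapes`
    // and `padding_values` occupy slots 2 and 3.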
    TF_RETURN_IF_ERROR(b->AddDataset(
        this, {{0, input_graph_node}, {1, batch_size}, {4, drop_remainder}},
        {{2, padded_shapes}, {3, padding_values}},
        {{kParallelCopy, parallel_copy},
         {kToutputTypes, output_types},
         {kNumPaddedShapes, N}},
        output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        } else {
          *end_of_sequence = false;
          batch_elements.reserve(dataset()->batch_size_);
          for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
               ++i) {
            std::vector<Tensor> batch_element_tuple;
            TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &batch_element_tuple,
                                                    end_of_sequence));
            if (!*end_of_sequence) {
              batch_elements.push_back(std::move(batch_element_tuple));
            }
          }
          if (*end_of_sequence) {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return Status::OK();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      TF_RETURN_IF_ERROR(CopyBatch(ctx, batch_elements, out_tensors));
      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
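    // Each output batch consumes exactly `batch_size_` input elements, so this
    // iterator is modeled as a known-ratio node.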
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (input_impl_)
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      else
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name(kExhausted))) {
        input_impl_.reset();
      } else {
        TF_RETURN_IF_ERROR(
            dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    // Copies the retrieved batch elements into one output tensor per tuple
    // component.
    //
    // NOTE(mrry): If the input or output sizes are statically known, we could
    // potentially read the input values in-place into their respective slice
    // locations. This would require a different GetNext() overload that
    // supports zero-copy, and might make sense in an optimization pass.
    Status CopyBatch(IteratorContext* ctx,
                     const std::vector<std::vector<Tensor>>& batch_elements,
                     std::vector<Tensor>* out_tensors) {
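      // Whether the "parallelize_batch_copy" experiment is enabled for this
      // job; it allows parallel copies even when `parallel_copy_` is false.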
      static bool in_experiment =
          GetExperiments().contains("parallelize_batch_copy");
      const size_t num_tuple_components = batch_elements[0].size();
      const int64_t num_batch_elements = batch_elements.size();
      for (size_t component_index = 0; component_index < num_tuple_components;
           ++component_index) {
        // 1. Determine the shape of the padded tensor.
        TensorShape batch_component_shape({num_batch_elements});
        const PartialTensorShape& padded_shape =
            dataset()->padded_shapes_[component_index];

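        // Unknown (-1) dimensions start at size 0 and are grown below to the
        // maximum size observed across the batch elements.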
        for (int dim = 0; dim < padded_shape.dims(); ++dim) {
          if (padded_shape.dim_size(dim) == -1) {
            batch_component_shape.AddDim(0);
          } else {
            batch_component_shape.AddDim(padded_shape.dim_size(dim));
          }
        }

        for (int64_t i = 0; i < num_batch_elements; ++i) {
          const TensorShape& element_shape =
              batch_elements[i][component_index].shape();
          // TODO(mrry): Perform this check in the shape function if
          // enough static information is available to do so.
          if (element_shape.dims() != padded_shape.dims()) {
            return errors::InvalidArgument(
                "All elements in a batch must have the same rank as the "
                "padded shape for component",
                component_index, ": expected rank ", padded_shape.dims(),
                " but got element with rank ", element_shape.dims());
          }
          for (int dim = 0; dim < padded_shape.dims(); ++dim) {
            if (padded_shape.dim_size(dim) == -1) {
              // Take the max of all batch elements in this dimension.
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                batch_component_shape.set_dim(
                    dim + 1,
                    batch_elements[i][component_index].shape().dim_size(dim));
              }
            } else {
              if (batch_elements[i][component_index].shape().dim_size(dim) >
                  batch_component_shape.dim_size(dim + 1)) {
                return errors::DataLoss(
                    "Attempted to pad to a smaller size than the input "
                    "element.");
              }
            }
          }
        }

        // 2. Copy each batch element to the appropriate location in
        // the output component tensor.
        out_tensors->emplace_back(ctx->allocator({}),
                                  output_dtypes()[component_index],
                                  batch_component_shape);
        Tensor& batch_component = out_tensors->back();
        TF_RETURN_IF_ERROR(batch_util::SetElementZero(
            &batch_component, dataset()->padding_values_[component_index]));

        // Build the output tuple component by copying one slice from each
        // input element in the batch.
        TensorShape component_shape({});
        for (int i = 1; i < batch_component_shape.dims(); ++i) {
          component_shape.AddDim(batch_component_shape.dim_size(i));
        }
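        // `component_shape` is the fully padded per-element shape (the batch
        // component shape minus its leading batch dimension); elements that
        // already match it can be copied without extra padding.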
        auto copy_element_fn = [component_index, &batch_elements,
                                &batch_component, &component_shape](int index) {
          // Take the fast path if possible.
          if (batch_elements[index][component_index].shape() ==
              component_shape) {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          } else {
            TF_RETURN_IF_ERROR(batch_util::CopyElementToLargerSlice(
                batch_elements[index][component_index], &batch_component,
                index));
          }
          return Status::OK();
        };

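        // Copy in parallel if requested, or if the experiment is enabled and
        // each element's slice is at least 32KiB (1 << 15 bytes).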
        if (dataset()->parallel_copy_ ||
            (in_experiment && (batch_component.AllocatedBytes() /
                               num_batch_elements) >= (1 << 15))) {
          BlockingCounter counter(num_batch_elements);
          Status status;
          mutex status_mu;
          const auto num_threads = ctx->runner_threadpool_size();
          const auto slice_size = num_batch_elements / num_threads;
          int64_t offset = 0;
          for (size_t i = 0; i < num_threads; ++i) {
            int64_t length = slice_size;
            // When the number of threads does not divide the number of
            // elements evenly, the size of some slices is incremented to
            // guarantee their sizes add up to the total number of elements.
            if (i < num_batch_elements % num_threads) ++length;
            (*ctx->runner())([offset, length, &status, &status_mu, &counter,
                              &copy_element_fn]() {
              for (size_t j = offset; j < offset + length; ++j) {
                {
                  Status s = copy_element_fn(j);
                  mutex_lock l(status_mu);
                  status.Update(s);
                }
                counter.DecrementCount();
              }
            });
            offset += length;
          }
          counter.Wait();
          TF_RETURN_IF_ERROR(status);
        } else {
          for (size_t i = 0; i < num_batch_elements; ++i) {
            TF_RETURN_IF_ERROR(copy_element_fn(i));
          }
        }
      }
      return Status::OK();
    }

    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64 batch_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const std::vector<PartialTensorShape> padded_shapes_;
  const std::vector<Tensor> padding_values_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

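// "PaddedBatchDataset" is the V1 op; "PaddedBatchDatasetV2" (op_version 2)
// adds the `drop_remainder` input.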
PaddedBatchDatasetOp::PaddedBatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == "PaddedBatchDataset" ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void PaddedBatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                       DatasetBase** output) {
  int64_t batch_size;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

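  // The `drop_remainder` input only exists in PaddedBatchDatasetV2; for V1 it
  // defaults to false, so the final partial batch is emitted.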
  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  OpInputList padded_shape_tensors;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddedShapes, &padded_shape_tensors));
  std::vector<PartialTensorShape> padded_shapes;
  padded_shapes.reserve(padded_shape_tensors.size());
  OP_REQUIRES(ctx, padded_shape_tensors.size() == input->output_shapes().size(),
              errors::InvalidArgument("Number of padded shapes (",
                                      padded_shape_tensors.size(),
                                      ") must match the number of components "
                                      "in the input dataset's elements (",
                                      input->output_shapes().size(), ")"));
  for (const Tensor& padded_shape_t : padded_shape_tensors) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(padded_shape_t.shape()),
                errors::InvalidArgument("All padded shapes must be vectors"));
    PartialTensorShape padded_shape;
    OP_REQUIRES_OK(ctx, PartialTensorShape::MakePartialShape(
                            padded_shape_t.vec<int64>().data(),
                            padded_shape_t.NumElements(), &padded_shape));
    padded_shapes.push_back(std::move(padded_shape));
  }
  OpInputList padding_values_list;
  OP_REQUIRES_OK(ctx, ctx->input_list(kPaddingValues, &padding_values_list));
  std::vector<Tensor> padding_values;
  OP_REQUIRES(ctx, padding_values_list.size() == input->output_shapes().size(),
              errors::InvalidArgument(
                  "Number of padding values (", padding_values_list.size(),
                  ") must match the number of components in the input "
                  "dataset's elements (",
                  input->output_shapes().size(), ")"));
  for (int i = 0; i < padding_values_list.size(); ++i) {
    const Tensor& padding_value_t = padding_values_list[i];
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(padding_value_t.shape()),
                errors::InvalidArgument("All padding values must be scalars"));
    OP_REQUIRES(ctx, padding_value_t.dtype() == input->output_dtypes()[i],
                errors::InvalidArgument(
                    "Mismatched type between padding value ", i,
                    " and input dataset's component ", i, ": ",
                    DataTypeString(padding_value_t.dtype()), " vs. ",
                    DataTypeString(input->output_dtypes()[i])));
    padding_values.push_back(tensor::DeepCopy(padding_value_t));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_,
                        std::move(padded_shapes), std::move(padding_values),
                        input, op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("PaddedBatchDataset").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
                        PaddedBatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow