/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/batch_dataset_op.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const BatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const BatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const BatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const BatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const BatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const BatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const BatchDatasetOp::kOutputShapes;

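// Checkpoint key used to record that the input iterator was exhausted.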
constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kBatchDataset[] = "BatchDataset";

class BatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, const DatasetBase* input, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        // Dataset batch is sometimes used to stack all elements in the
        // dataset. In such cases, a very large batch size (e.g., INT32_MAX)
        // is passed with drop_remainder set to false. Avoid OOM in such case
        // by limiting `reserve()` size by 2**16.
        reserve_size_(drop_remainder ? batch_size
                                     : std::min<int64>(batch_size, 1 << 16)),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If
    // we could tell statically that the input dataset is infinite,
    // then we could always report `batch_size` as the 0th dimension.
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
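    // The batch dimension can be reported statically only when partial
    // batches are impossible (remainders are dropped or the input is known
    // to be infinite); otherwise it is left unknown (-1).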
    for (const auto& input_shape : input_shapes) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.emplace_back(
            PartialTensorShape({batch_size_}).Concatenate(input_shape));
      } else {
        output_shapes_.emplace_back(
            PartialTensorShape({-1}).Concatenate(input_shape));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
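    // Round up when a partial final batch is kept, e.g. n = 10 with
    // batch_size_ = 3 yields 4 batches without drop_remainder_ and 3 with it.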
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, batch_size, drop_remainder},
                      {{kParallelCopy, parallel_copy}}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
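      // Hold `mu_` only while pulling elements from the input iterator; the
      // (potentially expensive) batch copy below happens outside the lock.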
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        batch_elements.reserve(dataset()->reserve_size_);
        *end_of_sequence = false;
        for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
          std::vector<Tensor> batch_element_tuple;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &batch_element_tuple, end_of_sequence));
          if (!*end_of_sequence) {
            batch_elements.emplace_back(std::move(batch_element_tuple));
          } else {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return Status::OK();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      // Copy the retrieved batch elements into one output tensor per tuple
      // component.
      //
      // NOTE(mrry): If the input or output sizes are statically known, we
      // could potentially read the input values in-place into their
      // respective slice locations. This would require a different GetNext()
      // overload that supports zero-copy, and might make sense in an
      // optimization pass.
      TF_RETURN_IF_ERROR(
          CopyBatch(ctx, batch_elements, dataset()->parallel_copy_,
                    /*allocation_callback=*/nullptr, out_tensors));

      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
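    // Every output batch consumes exactly `batch_size_` input elements, so
    // the iterator is modeled as a known-ratio node in the tf.data
    // performance model.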
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      }
      return Status::OK();
    }

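    // Restores the nested input iterator unless the checkpoint recorded that
    // it had already been exhausted.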
    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name(kInputImplEmpty))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64 batch_size_;
  const int64 reserve_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

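// "BatchDataset" is the V1 op; any other op name (here "BatchDatasetV2") maps
// to op version 2, which adds the `drop_remainder` input handled in
// MakeDataset below.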
BatchDatasetOp::BatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kBatchDataset ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void BatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  int64_t batch_size = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

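  // `drop_remainder` is only an input of the V2 op; the V1 op always keeps
  // partial final batches.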
  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_, input,
                        op_version_);
}

namespace {
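// Both ops are backed by the same kernel; the constructor distinguishes them
// by op name.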
REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
                        BatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
                        BatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow