/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/batch_dataset_op.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

/* static */ constexpr const char* const BatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const BatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const BatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const BatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const BatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const BatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const BatchDatasetOp::kOutputShapes;

constexpr char kInputImplEmpty[] = "input_impl_empty";
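// Name of the original (v1) batch op; the kernel constructor below uses it to
// infer the op version.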
constexpr char kBatchDataset[] = "BatchDataset";

class BatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
          bool parallel_copy, const DatasetBase* input, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        // Dataset batch is sometimes used to stack all elements in the
        // dataset. In such cases, a very large batch size (e.g., INT32_MAX)
        // is passed with drop_remainder set to false. Avoid OOM in such
        // cases by capping the `reserve()` size at 2**16.
        reserve_size_(drop_remainder ? batch_size
                                     : std::min<int64>(batch_size, 1 << 16)),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If
    // we could tell statically that the input dataset is infinite,
    // then we could always report `batch_size` as the 0th dimension.
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (const auto& input_shape : input_shapes) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.emplace_back(
            PartialTensorShape({batch_size_}).Concatenate(input_shape));
      } else {
        output_shapes_.emplace_back(
            PartialTensorShape({-1}).Concatenate(input_shape));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override {
    int64 n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
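    // Otherwise: one output element per full batch, plus one for the final
    // partial batch unless `drop_remainder_` is set.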
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, batch_size, drop_remainder},
                      {{kParallelCopy, parallel_copy}}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        batch_elements.reserve(dataset()->reserve_size_);
        *end_of_sequence = false;
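        // Pull up to `batch_size_` elements from the input iterator, stopping
        // early if the input is exhausted.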
        for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
          std::vector<Tensor> batch_element_tuple;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &batch_element_tuple, end_of_sequence));
          if (!*end_of_sequence) {
            batch_elements.emplace_back(std::move(batch_element_tuple));
          } else {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return Status::OK();
      }

      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      // Copy the retrieved batch elements into one output tensor per tuple
      // component.
      //
      // NOTE(mrry): If the input or output sizes are statically known, we
      // could potentially read the input values in-place into their
      // respective slice locations. This would require a different GetNext()
      // overload that supports zero-copy, and might make sense in an
      // optimization pass.
      TF_RETURN_IF_ERROR(CopyBatch(/*parallel_copy=*/dataset()->parallel_copy_,
                                   ctx, out_tensors, &batch_elements));

      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
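      // Each output element consumes exactly `batch_size_` input elements, so
      // this iterator is modeled as a known-ratio node for tf.data's
      // performance model.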
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
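      // Write an explicit marker when the input iterator has been exhausted so
      // that RestoreInternal() knows not to restore it.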
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
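      // Only restore the input iterator if it had not been exhausted when the
      // checkpoint was written.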
      if (!reader->Contains(full_name(kInputImplEmpty))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64 batch_size_;
  const int64 reserve_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

BatchDatasetOp::BatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kBatchDataset ? 1 : 2) {
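  // op_version_ is 1 for the original "BatchDataset" op and 2 otherwise (e.g.
  // "BatchDatasetV2", which adds the `drop_remainder` input). The
  // `parallel_copy` attribute is not declared on every version of the op, so
  // read it only when present.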
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void BatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  int64 batch_size = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

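  // `drop_remainder` is an input only in v2 of the op; v1 always keeps the
  // final partial batch.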
  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_, input,
                        op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
                        BatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
                        BatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow