/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/batch_dataset_op.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
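//
// Informally, this kernel backs `tf.data.Dataset.batch(...)`: it gathers up
// to `batch_size` consecutive elements of the input dataset and stacks each
// tuple component along a new leading dimension.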

/* static */ constexpr const char* const BatchDatasetOp::kDatasetType;
/* static */ constexpr const char* const BatchDatasetOp::kInputDataset;
/* static */ constexpr const char* const BatchDatasetOp::kBatchSize;
/* static */ constexpr const char* const BatchDatasetOp::kDropRemainder;
/* static */ constexpr const char* const BatchDatasetOp::kParallelCopy;
/* static */ constexpr const char* const BatchDatasetOp::kOutputTypes;
/* static */ constexpr const char* const BatchDatasetOp::kOutputShapes;

constexpr char kInputImplEmpty[] = "input_impl_empty";
constexpr char kBatchDataset[] = "BatchDataset";

class BatchDatasetOp::Dataset : public DatasetBase {
 public:
  Dataset(OpKernelContext* ctx, int64_t batch_size, bool drop_remainder,
          bool parallel_copy, const DatasetBase* input, int op_version)
      : DatasetBase(DatasetContext(ctx)),
        batch_size_(batch_size),
        // Dataset batch is sometimes used to stack all elements in the
        // dataset. In such cases, a very large batch size (e.g., INT32_MAX)
        // is passed with drop_remainder set to false. Avoid OOM in such
        // cases by limiting the `reserve()` size to 2**16.
        reserve_size_(drop_remainder ? batch_size
                                     : std::min<int64>(batch_size, 1 << 16)),
        drop_remainder_(drop_remainder),
        parallel_copy_(parallel_copy),
        input_(input),
        op_version_(op_version),
        traceme_metadata_(
            {{"batch_size",
              strings::Printf("%lld", static_cast<long long>(batch_size))},
             {"drop_remainder", drop_remainder ? "true" : "false"},
             {"parallel_copy", parallel_copy ? "true" : "false"}}) {
    input_->Ref();

    // NOTE(mrry): Currently we implement "batch up to" semantics. If
    // we could tell statically that the input dataset is infinite,
    // then we could always report `batch_size` as the 0th dimension.
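    //
    // For example: with input element shape [3] and batch_size = 4, the
    // reported output shape is [4, 3] when drop_remainder is true (or the
    // input is known to be infinite), and [?, 3] otherwise, since the final
    // batch may be smaller.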
    const auto& input_shapes = input_->output_shapes();
    output_shapes_.reserve(input_shapes.size());
    for (const auto& input_shape : input_shapes) {
      if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
        output_shapes_.emplace_back(
            PartialTensorShape({batch_size_}).Concatenate(input_shape));
      } else {
        output_shapes_.emplace_back(
            PartialTensorShape({-1}).Concatenate(input_shape));
      }
    }
  }

  ~Dataset() override { input_->Unref(); }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    name_utils::IteratorPrefixParams params;
    params.op_version = op_version_;
    return absl::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
  }

  const DataTypeVector& output_dtypes() const override {
    return input_->output_dtypes();
  }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return output_shapes_;
  }

  string DebugString() const override {
    name_utils::DatasetDebugStringParams params;
    params.op_version = op_version_;
    params.set_args(batch_size_);
    return name_utils::DatasetDebugString(kDatasetType, params);
  }

  int64 Cardinality() const override {
    int64_t n = input_->Cardinality();
    if (n == kInfiniteCardinality || n == kUnknownCardinality) {
      return n;
    }
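    // Ceiling division: one batch per full group of `batch_size_` elements,
    // plus one more for a final partial batch unless drop_remainder_ is set
    // (e.g. n = 10, batch_size_ = 3 yields 4 when remainders are kept).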
    return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->push_back(input_);
    return Status::OK();
  }

  Status CheckExternalState() const override {
    return input_->CheckExternalState();
  }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** output) const override {
    Node* input_graph_node = nullptr;
    TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
    Node* batch_size = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
    Node* drop_remainder = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
    AttrValue parallel_copy;
    b->BuildAttrValue(parallel_copy_, &parallel_copy);
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, {input_graph_node, batch_size, drop_remainder},
                      {{kParallelCopy, parallel_copy}}, output));
    return Status::OK();
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params) {}

    Status Initialize(IteratorContext* ctx) override {
      return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
    }

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      // Each row of `batch_elements` is a tuple of tensors from the
      // input iterator.
      std::vector<std::vector<Tensor>> batch_elements;
      {
        mutex_lock l(mu_);
        if (!input_impl_) {
          *end_of_sequence = true;
          return Status::OK();
        }
        batch_elements.reserve(dataset()->reserve_size_);
        *end_of_sequence = false;
        for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
          std::vector<Tensor> batch_element_tuple;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &batch_element_tuple, end_of_sequence));
          if (!*end_of_sequence) {
            batch_elements.emplace_back(std::move(batch_element_tuple));
          } else {
            input_impl_.reset();
          }
        }
      }

      if (batch_elements.empty()) {
        DCHECK(*end_of_sequence);
        return Status::OK();
      }

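      // A final partial batch (fewer than `batch_size_` elements) is
      // discarded rather than returned when drop_remainder_ is set.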
      if (dataset()->drop_remainder_ &&
          batch_elements.size() < dataset()->batch_size_) {
        *end_of_sequence = true;
        return Status::OK();
      }

      // Copy the retrieved batch elements into one output tensor per tuple
      // component.
      //
      // NOTE(mrry): If the input or output sizes are statically known, we
      // could potentially read the input values in-place into their
      // respective slice locations. This would require a different GetNext()
      // overload that supports zero-copy, and might make sense in an
      // optimization pass.
      TF_RETURN_IF_ERROR(
          CopyBatch(ctx, batch_elements, dataset()->parallel_copy_,
                    /*allocation_callback=*/nullptr, out_tensors));

      *end_of_sequence = false;
      return Status::OK();
    }

   protected:
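    // For the autotuning performance model, this iterator is a known-ratio
    // node: it consumes `batch_size_` input elements per output element.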
    std::shared_ptr<model::Node> CreateNode(
        IteratorContext* ctx, model::Node::Args args) const override {
      return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
    }

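    // Checkpointing: when the input iterator has been exhausted (and reset),
    // only the `input_impl_empty` marker is written, so that restore skips
    // recreating it.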
    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
      }
      return Status::OK();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (!reader->Contains(full_name(kInputImplEmpty))) {
        TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
      } else {
        input_impl_.reset();
      }
      return Status::OK();
    }

    TraceMeMetadata GetTraceMeMetadata() const override {
      return dataset()->traceme_metadata_;
    }

   private:
    mutex mu_;
    std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
  };

  const int64 batch_size_;
  const int64 reserve_size_;
  const bool drop_remainder_;
  const bool parallel_copy_;
  const DatasetBase* const input_;
  const int op_version_;
  std::vector<PartialTensorShape> output_shapes_;
  const TraceMeMetadata traceme_metadata_;
};

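// Op version 1 corresponds to the original "BatchDataset" op, which takes
// only a `batch_size` input; version 2 ("BatchDatasetV2") also takes a
// `drop_remainder` input and may carry a `parallel_copy` attr.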
BatchDatasetOp::BatchDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      op_version_(ctx->def().op() == kBatchDataset ? 1 : 2) {
  if (ctx->HasAttr(kParallelCopy)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
  }
}

void BatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                 DatasetBase** output) {
  int64_t batch_size = 0;
  OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
  OP_REQUIRES(ctx, batch_size > 0,
              errors::InvalidArgument("Batch size must be greater than zero."));

  bool drop_remainder = false;
  if (op_version_ > 1) {
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
  }

  *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_, input,
                        op_version_);
}

namespace {
REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
                        BatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
                        BatchDatasetOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow