• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/batch_dataset_op.h"
16 
17 #include <algorithm>
18 #include <utility>
19 
20 #include "tensorflow/core/framework/dataset.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/kernels/data/dataset_utils.h"
25 #include "tensorflow/core/kernels/data/name_utils.h"
26 #include "tensorflow/core/lib/gtl/cleanup.h"
27 #include "tensorflow/core/platform/macros.h"
28 #include "tensorflow/core/platform/stringprintf.h"
29 #include "tensorflow/core/util/batch_util.h"
30 
31 namespace tensorflow {
32 namespace data {
33 
34 // See documentation in ../../ops/dataset_ops.cc for a high-level
35 // description of the following op.
36 
37 /* static */ constexpr const char* const BatchDatasetOp::kDatasetType;
38 /* static */ constexpr const char* const BatchDatasetOp::kInputDataset;
39 /* static */ constexpr const char* const BatchDatasetOp::kBatchSize;
40 /* static */ constexpr const char* const BatchDatasetOp::kDropRemainder;
41 /* static */ constexpr const char* const BatchDatasetOp::kParallelCopy;
42 /* static */ constexpr const char* const BatchDatasetOp::kOutputTypes;
43 /* static */ constexpr const char* const BatchDatasetOp::kOutputShapes;
44 
45 constexpr char kInputImplEmpty[] = "input_impl_empty";
46 constexpr char kBatchDataset[] = "BatchDataset";
47 
48 class BatchDatasetOp::Dataset : public DatasetBase {
49  public:
Dataset(OpKernelContext * ctx,int64 batch_size,bool drop_remainder,bool parallel_copy,const DatasetBase * input,int op_version)50   Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
51           bool parallel_copy, const DatasetBase* input, int op_version)
52       : DatasetBase(DatasetContext(ctx)),
53         batch_size_(batch_size),
54         // Dataset batch is sometimes used to stack all elements in the
55         // dataset. In such cases, a very large batch size (e.g., INT32_MAX)
56         // is passed with drop_remainder set to false. Avoid OOM in such case
57         // by limiting `reserve()` size by 2**16.
58         reserve_size_(drop_remainder ? batch_size
59                                      : std::min<int64>(batch_size, 1 << 16)),
60         drop_remainder_(drop_remainder),
61         parallel_copy_(parallel_copy),
62         input_(input),
63         op_version_(op_version),
64         traceme_metadata_(
65             {{"batch_size",
66               strings::Printf("%lld", static_cast<long long>(batch_size))},
67              {"drop_remainder", drop_remainder ? "true" : "false"},
68              {"parallel_copy", parallel_copy ? "true" : "false"}}) {
69     input_->Ref();
70 
71     // NOTE(mrry): Currently we implement "batch up to" semantics. If
72     // we could tell statically that the input dataset is infinite,
73     // then we could always report `batch_size` as the 0th dimension.
74     const auto& input_shapes = input_->output_shapes();
75     output_shapes_.reserve(input_shapes.size());
76     for (const auto& input_shape : input_shapes) {
77       if (drop_remainder_ || input_->Cardinality() == kInfiniteCardinality) {
78         output_shapes_.emplace_back(
79             PartialTensorShape({batch_size_}).Concatenate(input_shape));
80       } else {
81         output_shapes_.emplace_back(
82             PartialTensorShape({-1}).Concatenate(input_shape));
83       }
84     }
85   }
86 
~Dataset()87   ~Dataset() override { input_->Unref(); }
88 
MakeIteratorInternal(const string & prefix) const89   std::unique_ptr<IteratorBase> MakeIteratorInternal(
90       const string& prefix) const override {
91     name_utils::IteratorPrefixParams params;
92     params.op_version = op_version_;
93     return absl::make_unique<Iterator>(Iterator::Params{
94         this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
95   }
96 
output_dtypes() const97   const DataTypeVector& output_dtypes() const override {
98     return input_->output_dtypes();
99   }
100 
output_shapes() const101   const std::vector<PartialTensorShape>& output_shapes() const override {
102     return output_shapes_;
103   }
104 
DebugString() const105   string DebugString() const override {
106     name_utils::DatasetDebugStringParams params;
107     params.op_version = op_version_;
108     params.set_args(batch_size_);
109     return name_utils::DatasetDebugString(kDatasetType, params);
110   }
111 
Cardinality() const112   int64 Cardinality() const override {
113     int64 n = input_->Cardinality();
114     if (n == kInfiniteCardinality || n == kUnknownCardinality) {
115       return n;
116     }
117     return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
118   }
119 
InputDatasets(std::vector<const DatasetBase * > * inputs) const120   Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
121     inputs->push_back(input_);
122     return Status::OK();
123   }
124 
CheckExternalState() const125   Status CheckExternalState() const override {
126     return input_->CheckExternalState();
127   }
128 
129  protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const130   Status AsGraphDefInternal(SerializationContext* ctx,
131                             DatasetGraphDefBuilder* b,
132                             Node** output) const override {
133     Node* input_graph_node = nullptr;
134     TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
135     Node* batch_size = nullptr;
136     TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
137     Node* drop_remainder = nullptr;
138     TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
139     AttrValue parallel_copy;
140     b->BuildAttrValue(parallel_copy_, &parallel_copy);
141     TF_RETURN_IF_ERROR(
142         b->AddDataset(this, {input_graph_node, batch_size, drop_remainder},
143                       {{kParallelCopy, parallel_copy}}, output));
144     return Status::OK();
145   }
146 
147  private:
148   class Iterator : public DatasetIterator<Dataset> {
149    public:
Iterator(const Params & params)150     explicit Iterator(const Params& params)
151         : DatasetIterator<Dataset>(params) {}
152 
Initialize(IteratorContext * ctx)153     Status Initialize(IteratorContext* ctx) override {
154       return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_);
155     }
156 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)157     Status GetNextInternal(IteratorContext* ctx,
158                            std::vector<Tensor>* out_tensors,
159                            bool* end_of_sequence) override {
160       // Each row of `batch_elements` is a tuple of tensors from the
161       // input iterator.
162       std::vector<std::vector<Tensor>> batch_elements;
163       {
164         mutex_lock l(mu_);
165         if (!input_impl_) {
166           *end_of_sequence = true;
167           return Status::OK();
168         }
169         batch_elements.reserve(dataset()->reserve_size_);
170         *end_of_sequence = false;
171         for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence; ++i) {
172           std::vector<Tensor> batch_element_tuple;
173           TF_RETURN_IF_ERROR(
174               input_impl_->GetNext(ctx, &batch_element_tuple, end_of_sequence));
175           if (!*end_of_sequence) {
176             batch_elements.emplace_back(std::move(batch_element_tuple));
177           } else {
178             input_impl_.reset();
179           }
180         }
181       }
182 
183       if (batch_elements.empty()) {
184         DCHECK(*end_of_sequence);
185         return Status::OK();
186       }
187 
188       if (dataset()->drop_remainder_ &&
189           batch_elements.size() < dataset()->batch_size_) {
190         *end_of_sequence = true;
191         return Status::OK();
192       }
193 
194       // Copy the retrieved batch elements into one output tensor per tuple
195       // component.
196       //
197       // NOTE(mrry): If the input or output sizes are statically known, we
198       // could potentially read the input values in-place into their
199       // respective slice locations. This would require a different GetNext()
200       // overload that supports zero-copy, and might make sense in an
201       // optimization pass.
202       TF_RETURN_IF_ERROR(CopyBatch(/*parallel_copy=*/dataset()->parallel_copy_,
203                                    ctx, out_tensors, &batch_elements));
204 
205       *end_of_sequence = false;
206       return Status::OK();
207     }
208 
209    protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const210     std::shared_ptr<model::Node> CreateNode(
211         IteratorContext* ctx, model::Node::Args args) const override {
212       return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
213     }
214 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)215     Status SaveInternal(SerializationContext* ctx,
216                         IteratorStateWriter* writer) override {
217       mutex_lock l(mu_);
218       if (!input_impl_) {
219         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
220       } else {
221         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
222       }
223       return Status::OK();
224     }
225 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)226     Status RestoreInternal(IteratorContext* ctx,
227                            IteratorStateReader* reader) override {
228       mutex_lock l(mu_);
229       if (!reader->Contains(full_name(kInputImplEmpty))) {
230         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
231       } else {
232         input_impl_.reset();
233       }
234       return Status::OK();
235     }
236 
GetTraceMeMetadata() const237     TraceMeMetadata GetTraceMeMetadata() const override {
238       return dataset()->traceme_metadata_;
239     }
240 
241    private:
242     mutex mu_;
243     std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
244   };
245 
246   const int64 batch_size_;
247   const int64 reserve_size_;
248   const bool drop_remainder_;
249   const bool parallel_copy_;
250   const DatasetBase* const input_;
251   const int op_version_;
252   std::vector<PartialTensorShape> output_shapes_;
253   const TraceMeMetadata traceme_metadata_;
254 };
255 
BatchDatasetOp(OpKernelConstruction * ctx)256 BatchDatasetOp::BatchDatasetOp(OpKernelConstruction* ctx)
257     : UnaryDatasetOpKernel(ctx),
258       op_version_(ctx->def().op() == kBatchDataset ? 1 : 2) {
259   if (ctx->HasAttr(kParallelCopy)) {
260     OP_REQUIRES_OK(ctx, ctx->GetAttr(kParallelCopy, &parallel_copy_));
261   }
262 }
263 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)264 void BatchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
265                                  DatasetBase** output) {
266   int64 batch_size = 0;
267   OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, kBatchSize, &batch_size));
268   OP_REQUIRES(ctx, batch_size > 0,
269               errors::InvalidArgument("Batch size must be greater than zero."));
270 
271   bool drop_remainder = false;
272   if (op_version_ > 1) {
273     OP_REQUIRES_OK(
274         ctx, ParseScalarArgument<bool>(ctx, kDropRemainder, &drop_remainder));
275   }
276 
277   *output = new Dataset(ctx, batch_size, drop_remainder, parallel_copy_, input,
278                         op_version_);
279 }
280 
281 namespace {
282 REGISTER_KERNEL_BUILDER(Name("BatchDataset").Device(DEVICE_CPU),
283                         BatchDatasetOp);
284 
285 REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
286                         BatchDatasetOp);
287 }  // namespace
288 }  // namespace data
289 }  // namespace tensorflow
290