/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/platform/stringprintf.h"

namespace tensorflow {
namespace data {
namespace experimental {
namespace {

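// Returns the ceiling of the integer division dividend / divisor, assuming
// divisor > 0; e.g. CeilDiv(10, 4) == 3 and CeilDiv(8, 4) == 2.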
inline int64_t CeilDiv(int64_t dividend, int64_t divisor) {
  return (dividend - 1 + divisor) / divisor;
}

constexpr const char* const kDatasetTypeV1 = "Rebatch";
constexpr const char* const kDatasetTypeV2 = "RebatchV2";

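// RebatchDataset splits each batch of its input dataset into `num_replicas`
// consecutive sub-batches. Each sub-batch has CeilDiv(batch_size,
// num_replicas) elements, except that trailing sub-batches may be smaller or
// empty when the input batch size is not evenly divisible.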
class RebatchDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit RebatchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64_t num_replicas;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "num_replicas", &num_replicas));
    OP_REQUIRES(
        ctx, num_replicas > 0,
        errors::InvalidArgument("num_replicas must be greater than zero."));
    *output =
        new Dataset(ctx, input, num_replicas, output_types_, output_shapes_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const int64_t num_replicas, const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          num_replicas_(num_replicas),
          output_types_(output_types),
          output_shapes_(output_shapes),
          traceme_metadata_(
              {{"num_replicas", strings::Printf("%lld", static_cast<long long>(
                                                            num_replicas))}}) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      name_utils::IteratorPrefixParams params;
      return std::make_unique<Iterator>(Iterator::Params{
          this, name_utils::IteratorPrefix(kDatasetTypeV1, prefix, params)});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      name_utils::DatasetDebugStringParams params;
      params.set_args(num_replicas_);
      return name_utils::DatasetDebugString(kDatasetTypeV1, params);
    }

    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return OkStatus();
    }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* num_replicas = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(num_replicas_, &num_replicas));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this, {input_graph_node, num_replicas}, output));
      return OkStatus();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      ~Iterator() override {}

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
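        // A fresh batch is pulled from the input once every `num_replicas_`
        // calls; subsequent calls return successive slices of that batch.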
        if (slice_number_ % dataset()->num_replicas_ == 0) {
          input_descriptors_.clear();
          std::vector<Tensor> input_tensors;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
          if (*end_of_sequence) {
            return OkStatus();
          }

          input_descriptors_.reserve(input_tensors.size());
          for (int i = 0; i < input_tensors.size(); ++i) {
            if (input_tensors[i].dims() == 0) {
              return errors::InvalidArgument(
                  "Cannot rebatch dataset: All components must have at least "
                  "one dimension. Perhaps your input dataset is not batched? "
                  "Component ",
                  i, " is scalar.");
            }

            int64_t original_batch_dim = input_tensors[i].dim_size(0);
            int64_t interval =
                CeilDiv(original_batch_dim, dataset()->num_replicas_);
            input_descriptors_.push_back(
                {std::move(input_tensors[i]), original_batch_dim, interval});
          }
        }

        out_tensors->reserve(input_descriptors_.size());

        // We slice each component independently because they may have
        // different batch dimensions.
        for (const auto& input_desc : input_descriptors_) {
          int64_t start = input_desc.interval * slice_number_;
          int64_t end = std::min(start + input_desc.interval,
                                 input_desc.original_batch_dim);
          if (start >= end) {
            // We can get here if ceil(original_batch_dim_ / new batch dim) <
            // num_replicas_, i.e. the batch isn't big enough to distribute
            // over num replicas. In this case, we return empty tensors for
            // the remaining iterations that correspond to this batch.
            start = end;
          }
          Tensor slice = input_desc.whole_tensor.Slice(start, end);
          if (slice.IsAligned()) {
            out_tensors->push_back(std::move(slice));
          } else {
            out_tensors->push_back(tensor::DeepCopy(std::move(slice)));
          }
        }
        slice_number_ = (slice_number_ + 1) % dataset()->num_replicas_;
        return OkStatus();
      }

     protected:
      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_impl_empty"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        }
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("slice_number"), slice_number_));

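        // The raw input tensors only need to be checkpointed mid-round, i.e.
        // when some slices of the current input batch have not yet been
        // produced.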
        if (slice_number_ % dataset()->num_replicas_ != 0) {
          // Save state of input tensors.
          for (int i = 0; i < input_descriptors_.size(); ++i) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat("tensors[", i, "]")),
                input_descriptors_[i].whole_tensor));
          }
        }
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name("input_impl_empty"))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("slice_number"), &slice_number_));

        input_descriptors_.clear();
        input_descriptors_.resize(dataset()->output_dtypes().size());
        if (slice_number_ % dataset()->num_replicas_ != 0) {
          for (int i = 0; i < input_descriptors_.size(); ++i) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                ctx->flr(), full_name(strings::StrCat("tensors[", i, "]")),
                &input_descriptors_[i].whole_tensor));
            input_descriptors_[i].original_batch_dim =
                input_descriptors_[i].whole_tensor.dim_size(0);
            input_descriptors_[i].interval =
                CeilDiv(input_descriptors_[i].original_batch_dim,
                        dataset()->num_replicas_);
          }
        }
        return OkStatus();
      }

      TraceMeMetadata GetTraceMeMetadata() const override {
        return dataset()->traceme_metadata_;
      }

     private:
      // Describes one component of the input.
      struct InputDescriptor {
        InputDescriptor() {}
        InputDescriptor(Tensor&& whole_tensor, int64_t original_batch_dim,
                        int64_t interval)
            : whole_tensor(std::move(whole_tensor)),
              original_batch_dim(original_batch_dim),
              interval(interval) {}

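        // `whole_tensor` holds one component of the most recent input batch;
        // `interval` is the per-replica slice size,
        // CeilDiv(original_batch_dim, num_replicas_).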
        Tensor whole_tensor;
        int64_t original_batch_dim;
        int64_t interval;
      };

      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_;
      std::vector<InputDescriptor> input_descriptors_ TF_GUARDED_BY(mu_);
      int64_t slice_number_ TF_GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
    const int64_t num_replicas_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
    const TraceMeMetadata traceme_metadata_;
  };

  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

// This dataset rebatches its input batches into batches of different size(s).
//
// This differs from RebatchDatasetOp. Namely, RebatchDatasetV2 rebatches
// incoming batches into batches whose new sizes are specified by the
// `batch_sizes` argument, while RebatchDataset splits its batches based
// on the (dynamic) input batch size and the given number of splits to make (its
// `num_replicas` argument). When used in tf.distribute, this allows
// RebatchDataset to split batches more correctly when the splits are
// distributed across multiple workers and replicas.
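// Illustrative example (not taken from the op's documentation): with input
// batches of size 8, `batch_sizes = [3, 5]`, and `drop_remainder = false`,
// RebatchDatasetV2 produces batches of sizes 3, 5, 3, 5, ..., stitching each
// output batch together from one or more input batches as needed.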
class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
 public:
  explicit RebatchDatasetV2Op(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    const Tensor* batch_sizes_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("batch_sizes", &batch_sizes_tensor));
    OP_REQUIRES(
        ctx, batch_sizes_tensor->dims() <= 1,
        errors::InvalidArgument("`batch_sizes` must be a scalar or a vector."));

    std::vector<int64_t> batch_sizes;
    batch_sizes.reserve(batch_sizes_tensor->NumElements());
    for (int i = 0; i < batch_sizes_tensor->NumElements(); ++i) {
      batch_sizes.push_back(batch_sizes_tensor->flat<int64_t>()(i));
    }

    bool drop_remainder;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));

    *output = new Dataset(ctx, input, std::move(batch_sizes), drop_remainder,
                          output_types_, output_shapes_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            std::vector<int64_t>&& batch_sizes, bool drop_remainder,
            const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          batch_sizes_(std::move(batch_sizes)),
          drop_remainder_(drop_remainder),
          output_types_(output_types),
          output_shapes_(output_shapes),
          traceme_metadata_(
              {{"batch_sizes", absl::StrJoin(batch_sizes_, ",")}}) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      name_utils::IteratorPrefixParams params;
      return std::make_unique<Iterator>(Iterator::Params{
          this, name_utils::IteratorPrefix(kDatasetTypeV2, prefix, params)});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return name_utils::DatasetDebugString(kDatasetTypeV2);
    }

    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return OkStatus();
    }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* batch_sizes = nullptr;
      TF_RETURN_IF_ERROR(b->AddVector(batch_sizes_, &batch_sizes));
      Node* drop_remainder = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_graph_node, batch_sizes, drop_remainder}, output));
      return OkStatus();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      ~Iterator() override {}

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (end_of_sequence_) {
          *end_of_sequence = true;
          return OkStatus();
        }

        *end_of_sequence = false;

        auto desired_batch_size = dataset()->batch_sizes_[batch_sizes_index_];
        // Tracks the size of the current batch as it's built up, possibly from
        // different input tensors.
        int64_t batch_size = 0;

        std::vector<std::vector<Tensor>> slices_to_concatenate;
        // Get slices from input tensors until they make up the whole batch
        // size or we run out of input.
        while (batch_size < desired_batch_size) {
          if (offset_ == -1) {
            // Get new input tensors.
            tensors_.clear();
            TF_RETURN_IF_ERROR(
                input_impl_->GetNext(ctx, &tensors_, &end_of_sequence_));
            if (end_of_sequence_) {
              // Break and return partial batch, if any.
              break;
            }
            TF_RETURN_IF_ERROR(ValidateInputTensors());
            offset_ = 0;
          }

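          // Take as many elements as this batch still needs, but no more than
          // remain in the current input tensors.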
          int64_t slice_end =
              std::min(offset_ + desired_batch_size - batch_size,
                       tensors_[0].dim_size(0));

          std::vector<Tensor> slices;
          slices.reserve(tensors_.size());
          for (const auto& tensor : tensors_) {
            slices.push_back(tensor.Slice(offset_, slice_end));
          }
          slices_to_concatenate.push_back(std::move(slices));

          batch_size += (slice_end - offset_);
          offset_ = slice_end;
          if (offset_ == tensors_[0].dim_size(0)) {
            // Exhausted current input tensors, reset.
            offset_ = -1;
          }
        }

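        // Advance to the next requested batch size, wrapping around the
        // `batch_sizes_` list.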
        batch_sizes_index_++;
        batch_sizes_index_ %= dataset()->batch_sizes_.size();

        // Return end_of_sequence if GetNext is expected to produce a non-empty
        // batch and there are no more inputs, or if drop_remainder is true and
        // we can't make a full batch.
        if ((batch_size == 0 && desired_batch_size > 0) ||
            (dataset()->drop_remainder_ && batch_size < desired_batch_size)) {
          DCHECK(end_of_sequence_);
          *end_of_sequence = true;
          return OkStatus();
        }

        const size_t num_components = dataset()->output_dtypes().size();
        out_tensors->reserve(num_components);

        // Special case: desired batch size == 0. This may be the case when,
        // with distribution strategies, one of the replicas expects an empty
        // batch so that the global batch size adds up correctly.
        if (desired_batch_size == 0) {
          DCHECK_EQ(batch_size, 0);
          DCHECK_EQ(slices_to_concatenate.size(), 0);
          for (int i = 0; i < dataset()->output_dtypes().size(); ++i) {
            if (dataset()->output_shapes()[i].unknown_rank()) {
              // For unknown rank tensors, we just create an empty Tensor
              // since it doesn't matter what shape it is.
              out_tensors->push_back(Tensor(dataset()->output_dtypes()[i]));
            } else {
              auto dim_sizes = dataset()->output_shapes()[i].dim_sizes();

              // The output batch size is always zero since the desired batch
              // size is zero.
              dim_sizes[0] = 0;

              // Handle unknown dimensions by setting any unknown dimensions to
              // zero since there isn't any data anyway.
              for (int j = 1; j < dim_sizes.size(); ++j) {
                if (dim_sizes[j] == -1) dim_sizes[j] = 0;
              }

              TensorShape tensor_shape(dim_sizes);
              out_tensors->push_back(
                  Tensor(dataset()->output_dtypes()[i], tensor_shape));
            }
          }
          return OkStatus();
        }

        // Special case: when there's only one slice, we return the slice
        // directly where possible instead of copying the tensor data.
        if (slices_to_concatenate.size() == 1) {
          auto tensors = std::move(slices_to_concatenate[0]);
          for (size_t i = 0; i < num_components; ++i) {
            // If the slice is aligned, we return it directly.
            if (!tensors[i].IsAligned()) {
              tensors[i] = tensor::DeepCopy(std::move(tensors[i]));
            }
          }
          *out_tensors = std::move(tensors);
          return OkStatus();
        }

        // For each component, concatenate slices into one tensor.
        for (size_t i = 0; i < num_components; ++i) {
          TensorShape component_shape({batch_size});
          TensorShape remaining_shape = slices_to_concatenate[0][i].shape();
          remaining_shape.RemoveDim(0);
          component_shape.AppendShape(remaining_shape);
          out_tensors->emplace_back(ctx->allocator({}),
                                    dataset()->output_dtypes()[i],
                                    component_shape);
          if (!out_tensors->back().IsInitialized()) {
            return errors::ResourceExhausted(
                "Failed to allocate memory for the batch of component ", i);
          }
          int64_t dst_offset = 0;
          for (size_t j = 0; j < slices_to_concatenate.size(); ++j) {
            auto num_slices = slices_to_concatenate[j][i].shape().dim_size(0);
            TF_RETURN_IF_ERROR(batch_util::CopyContiguousSlices(
                slices_to_concatenate[j][i], 0, dst_offset, num_slices,
                &(*out_tensors)[i]));
            dst_offset += num_slices;
          }
        }

        return OkStatus();
      }

     protected:
      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_impl_empty"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_sizes_index"),
                                               batch_sizes_index_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("offset"), offset_));
        if (offset_ != -1) {
          for (int i = 0; i < tensors_.size(); ++i) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat("tensors[", i, "]")), tensors_[i]));
          }
        }
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name("input_impl_empty"))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_sizes_index"),
                                              &batch_sizes_index_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset_));

        tensors_.clear();
        if (offset_ != -1) {
          tensors_.resize(dataset()->output_dtypes().size());
          for (int i = 0; i < tensors_.size(); ++i) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                ctx->flr(), full_name(strings::StrCat("tensors[", i, "]")),
                &tensors_[i]));
          }
        }
        return OkStatus();
      }

      TraceMeMetadata GetTraceMeMetadata() const override {
        return dataset()->traceme_metadata_;
      }

     private:
      Status ValidateInputTensors() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        for (size_t i = 0; i < tensors_.size(); ++i) {
          if (tensors_[i].dims() == 0) {
            return errors::InvalidArgument(
                "Input element must have a non-scalar value in each "
                "component.");
          }
          if (tensors_[i].dim_size(0) != tensors_[0].dim_size(0)) {
            return errors::InvalidArgument(
                "Input element must have the same batch size in each "
                "component. Component 0 had size ",
                tensors_[0].dim_size(0), " but component ", i, " had size ",
                tensors_[i].dim_size(0), ".");
          }
        }
        return OkStatus();
      }

      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_;
      // Whether we have reached the end of the input.
      bool end_of_sequence_ TF_GUARDED_BY(mu_) = false;
      // Represents the current input tensor(s).
      std::vector<Tensor> tensors_ TF_GUARDED_BY(mu_);
      // Represents the offset into the current input tensor(s).
      // An offset of -1 indicates that there is no data left in the current
      // slice.
      int64_t offset_ TF_GUARDED_BY(mu_) = -1;
      // Represents the current index into the batch_sizes list.
      int64_t batch_sizes_index_ TF_GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
    const std::vector<int64_t> batch_sizes_;
    const bool drop_remainder_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
    const TraceMeMetadata traceme_metadata_;
  };

  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU),
                        RebatchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU),
                        RebatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("RebatchDatasetV2").Device(DEVICE_CPU),
                        RebatchDatasetV2Op);

}  // anonymous namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow