/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {
namespace data {
namespace experimental {
namespace {

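// Returns the ceiling of `dividend / divisor` for a non-negative dividend and
// a positive divisor.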
inline int64_t CeilDiv(int64_t dividend, int64_t divisor) {
  return (dividend - 1 + divisor) / divisor;
}

constexpr const char* const kDatasetTypeV1 = "Rebatch";
constexpr const char* const kDatasetTypeV2 = "RebatchV2";

class RebatchDatasetOp : public UnaryDatasetOpKernel {
 public:
  explicit RebatchDatasetOp(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    int64_t num_replicas;
    OP_REQUIRES_OK(ctx,
                   ParseScalarArgument(ctx, "num_replicas", &num_replicas));
    OP_REQUIRES(
        ctx, num_replicas > 0,
        errors::InvalidArgument("num_replicas must be greater than zero."));
    *output =
        new Dataset(ctx, input, num_replicas, output_types_, output_shapes_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            const int64_t num_replicas, const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          num_replicas_(num_replicas),
          output_types_(output_types),
          output_shapes_(output_shapes),
          traceme_metadata_(
              {{"num_replicas", strings::Printf("%lld", static_cast<long long>(
                                                            num_replicas))}}) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      name_utils::IteratorPrefixParams params;
      return std::make_unique<Iterator>(Iterator::Params{
          this, name_utils::IteratorPrefix(kDatasetTypeV1, prefix, params)});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      name_utils::DatasetDebugStringParams params;
      params.set_args(num_replicas_);
      return name_utils::DatasetDebugString(kDatasetTypeV1, params);
    }

    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return OkStatus();
    }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* num_replicas = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(num_replicas_, &num_replicas));
      TF_RETURN_IF_ERROR(
          b->AddDataset(this, {input_graph_node, num_replicas}, output));
      return OkStatus();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      ~Iterator() override {}

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        *end_of_sequence = false;
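        // A new batch is pulled from the input every `num_replicas_` calls;
        // each call then returns the `slice_number_`-th slice of that batch.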
        if (slice_number_ % dataset()->num_replicas_ == 0) {
          input_descriptors_.clear();
          std::vector<Tensor> input_tensors;
          TF_RETURN_IF_ERROR(
              input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
          if (*end_of_sequence) {
            return OkStatus();
          }

          input_descriptors_.reserve(input_tensors.size());
          for (int i = 0; i < input_tensors.size(); ++i) {
            if (input_tensors[i].dims() == 0) {
              return errors::InvalidArgument(
                  "Cannot rebatch dataset: All components must have at least "
                  "one dimension. Perhaps your input dataset is not batched? "
                  "Component ",
                  i, " is scalar.");
            }

            int64_t original_batch_dim = input_tensors[i].dim_size(0);
            int64_t interval =
                CeilDiv(original_batch_dim, dataset()->num_replicas_);
            input_descriptors_.push_back(
                {std::move(input_tensors[i]), original_batch_dim, interval});
          }
        }

        out_tensors->reserve(input_descriptors_.size());

        // We slice each component independently because they may have
        // different batch dimensions.
        for (const auto& input_desc : input_descriptors_) {
          int64_t start = input_desc.interval * slice_number_;
          int64_t end = std::min(start + input_desc.interval,
                                 input_desc.original_batch_dim);
          if (start >= end) {
            // We can get here if ceil(original_batch_dim_ / new batch dim) <
            // num_replicas_, i.e. the batch isn't big enough to distribute
            // over num replicas. In this case, we return empty tensors for
            // the remaining iterations that correspond to this batch.
            start = end;
          }
          Tensor slice = input_desc.whole_tensor.Slice(start, end);
          if (slice.IsAligned()) {
            out_tensors->push_back(std::move(slice));
          } else {
            out_tensors->push_back(tensor::DeepCopy(std::move(slice)));
          }
        }
        slice_number_ = (slice_number_ + 1) % dataset()->num_replicas_;
        return OkStatus();
      }

     protected:
      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_impl_empty"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        }
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("slice_number"), slice_number_));

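        // Input tensors only need to be saved when we are mid-way through
        // slicing a batch; otherwise they are re-fetched from the input on
        // the next call to GetNextInternal.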
        if (slice_number_ % dataset()->num_replicas_ != 0) {
          // Save state of input tensors.
          for (int i = 0; i < input_descriptors_.size(); ++i) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat("tensors[", i, "]")),
                input_descriptors_[i].whole_tensor));
          }
        }
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name("input_impl_empty"))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        TF_RETURN_IF_ERROR(
            reader->ReadScalar(full_name("slice_number"), &slice_number_));

        input_descriptors_.clear();
        input_descriptors_.resize(dataset()->output_dtypes().size());
        if (slice_number_ % dataset()->num_replicas_ != 0) {
          for (int i = 0; i < input_descriptors_.size(); ++i) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                ctx->flr(), full_name(strings::StrCat("tensors[", i, "]")),
                &input_descriptors_[i].whole_tensor));
            input_descriptors_[i].original_batch_dim =
                input_descriptors_[i].whole_tensor.dim_size(0);
            input_descriptors_[i].interval =
                CeilDiv(input_descriptors_[i].original_batch_dim,
                        dataset()->num_replicas_);
          }
        }
        return OkStatus();
      }

      TraceMeMetadata GetTraceMeMetadata() const override {
        return dataset()->traceme_metadata_;
      }

     private:
      // Describes one component of the input.
      struct InputDescriptor {
        InputDescriptor() {}
        InputDescriptor(Tensor&& whole_tensor, int64_t original_batch_dim,
                        int64_t interval)
            : whole_tensor(std::move(whole_tensor)),
              original_batch_dim(original_batch_dim),
              interval(interval) {}

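        // `whole_tensor` is the full input component; `interval` is the
        // per-replica slice size, i.e. CeilDiv(original_batch_dim,
        // num_replicas).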
        Tensor whole_tensor;
        int64_t original_batch_dim;
        int64_t interval;
      };

      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_;
      std::vector<InputDescriptor> input_descriptors_ TF_GUARDED_BY(mu_);
      int64_t slice_number_ TF_GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
    const int64_t num_replicas_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
    const TraceMeMetadata traceme_metadata_;
  };

  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

// This dataset rebatches its input batches into batches of different size(s).
//
// This differs from RebatchDatasetOp. Namely, RebatchDatasetV2 rebatches
// incoming batches into batches whose new sizes are specified by the
// `batch_sizes` argument, while RebatchDataset splits its batches based
// on the (dynamic) input batch size and the given number of splits to make
// (its `num_replicas` argument). When used in tf.distribute, this allows
// RebatchDataset to split batches more correctly when the splits are
// distributed across multiple workers and replicas.
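//
// For example, given an input batch of 4 elements, RebatchDatasetV2 with
// `batch_sizes = [2, 1, 1]` produces batches of sizes 2, 1, and 1, whereas
// RebatchDataset with `num_replicas = 3` slices it into batches of sizes
// 2, 2, and 0, since each slice is CeilDiv(4, 3) = 2 elements.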
class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
 public:
  explicit RebatchDatasetV2Op(OpKernelConstruction* ctx)
      : UnaryDatasetOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
  }

 protected:
  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                   DatasetBase** output) override {
    const Tensor* batch_sizes_tensor;
    OP_REQUIRES_OK(ctx, ctx->input("batch_sizes", &batch_sizes_tensor));
    OP_REQUIRES(
        ctx, batch_sizes_tensor->dims() <= 1,
        errors::InvalidArgument("`batch_sizes` must be a scalar or a vector."));

    std::vector<int64_t> batch_sizes;
    batch_sizes.reserve(batch_sizes_tensor->NumElements());
    for (int i = 0; i < batch_sizes_tensor->NumElements(); ++i) {
      batch_sizes.push_back(batch_sizes_tensor->flat<int64_t>()(i));
    }

    bool drop_remainder;
    OP_REQUIRES_OK(
        ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));

    *output = new Dataset(ctx, input, std::move(batch_sizes), drop_remainder,
                          output_types_, output_shapes_);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const DatasetBase* input,
            std::vector<int64_t>&& batch_sizes, bool drop_remainder,
            const DataTypeVector& output_types,
            const std::vector<PartialTensorShape>& output_shapes)
        : DatasetBase(DatasetContext(ctx)),
          input_(input),
          batch_sizes_(std::move(batch_sizes)),
          drop_remainder_(drop_remainder),
          output_types_(output_types),
          output_shapes_(output_shapes),
          traceme_metadata_(
              {{"batch_sizes", absl::StrJoin(batch_sizes_, ",")}}) {
      input_->Ref();
    }

    ~Dataset() override { input_->Unref(); }

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      name_utils::IteratorPrefixParams params;
      return std::make_unique<Iterator>(Iterator::Params{
          this, name_utils::IteratorPrefix(kDatasetTypeV2, prefix, params)});
    }

    const DataTypeVector& output_dtypes() const override {
      return output_types_;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      return output_shapes_;
    }

    string DebugString() const override {
      return name_utils::DatasetDebugString(kDatasetTypeV2);
    }

    Status InputDatasets(
        std::vector<const DatasetBase*>* inputs) const override {
      inputs->push_back(input_);
      return OkStatus();
    }

    Status CheckExternalState() const override {
      return input_->CheckExternalState();
    }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* input_graph_node = nullptr;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
      Node* batch_sizes = nullptr;
      TF_RETURN_IF_ERROR(b->AddVector(batch_sizes_, &batch_sizes));
      Node* drop_remainder = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {input_graph_node, batch_sizes, drop_remainder}, output));
      return OkStatus();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params) {}

      ~Iterator() override {}

      Status Initialize(IteratorContext* ctx) override {
        return dataset()->input_->MakeIterator(ctx, this, prefix(),
                                               &input_impl_);
      }

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (end_of_sequence_) {
          *end_of_sequence = true;
          return OkStatus();
        }

        *end_of_sequence = false;

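        // `batch_sizes_` is cycled through: each call to GetNextInternal
        // produces a batch of the next size in the list, wrapping around
        // once the end is reached.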
        auto desired_batch_size = dataset()->batch_sizes_[batch_sizes_index_];
        // Tracks the size of the current batch as it's built up, possibly from
        // different input tensors.
        int64_t batch_size = 0;

        std::vector<std::vector<Tensor>> slices_to_concatenate;
        // Get slices from input tensors until they make up the whole batch
        // size or we run out of input.
        while (batch_size < desired_batch_size) {
          if (offset_ == -1) {
            // Get new input tensors.
            tensors_.clear();
            TF_RETURN_IF_ERROR(
                input_impl_->GetNext(ctx, &tensors_, &end_of_sequence_));
            if (end_of_sequence_) {
              // Break and return partial batch, if any.
              break;
            }
            TF_RETURN_IF_ERROR(ValidateInputTensors());
            offset_ = 0;
          }

          int64_t slice_end =
              std::min(offset_ + desired_batch_size - batch_size,
                       tensors_[0].dim_size(0));

          std::vector<Tensor> slices;
          slices.reserve(tensors_.size());
          for (const auto& tensor : tensors_) {
            slices.push_back(tensor.Slice(offset_, slice_end));
          }
          slices_to_concatenate.push_back(std::move(slices));

          batch_size += (slice_end - offset_);
          offset_ = slice_end;
          if (offset_ == tensors_[0].dim_size(0)) {
            // Exhausted current input tensors, reset.
            offset_ = -1;
          }
        }

        batch_sizes_index_++;
        batch_sizes_index_ %= dataset()->batch_sizes_.size();

        // Return end_of_sequence if GetNext is expected to produce a non-empty
        // batch and there are no more inputs, or if drop_remainder is true and
        // we can't make a full batch.
        if ((batch_size == 0 && desired_batch_size > 0) ||
            (dataset()->drop_remainder_ && batch_size < desired_batch_size)) {
          DCHECK(end_of_sequence_);
          *end_of_sequence = true;
          return OkStatus();
        }

        const size_t num_components = dataset()->output_dtypes().size();
        out_tensors->reserve(num_components);

        // Special case: desired batch size == 0. This may be the case when,
        // with distribution strategies, one of the replicas expects an empty
        // batch so that the global batch size adds up correctly.
        if (desired_batch_size == 0) {
          DCHECK_EQ(batch_size, 0);
          DCHECK_EQ(slices_to_concatenate.size(), 0);
          for (int i = 0; i < dataset()->output_dtypes().size(); ++i) {
            if (dataset()->output_shapes()[i].unknown_rank()) {
              // For unknown-rank tensors, we just create an empty Tensor since
              // it doesn't matter what shape it is.
              out_tensors->push_back(Tensor(dataset()->output_dtypes()[i]));
            } else {
              auto dim_sizes = dataset()->output_shapes()[i].dim_sizes();

              // The output batch size is always zero since the desired batch
              // size is zero.
              dim_sizes[0] = 0;

              // Handle unknown dimensions by setting any unknown dimensions to
              // zero since there isn't any data anyway.
              for (int j = 1; j < dim_sizes.size(); ++j) {
                if (dim_sizes[j] == -1) dim_sizes[j] = 0;
              }

              TensorShape tensor_shape(dim_sizes);
              out_tensors->push_back(
                  Tensor(dataset()->output_dtypes()[i], tensor_shape));
            }
          }
          return OkStatus();
        }

        // Special case: when there's only one slice, we return the slice
        // directly where possible instead of copying the tensor data.
        if (slices_to_concatenate.size() == 1) {
          auto tensors = std::move(slices_to_concatenate[0]);
          for (size_t i = 0; i < num_components; ++i) {
            // If the slice is aligned, we return it directly.
            if (!tensors[i].IsAligned()) {
              tensors[i] = tensor::DeepCopy(std::move(tensors[i]));
            }
          }
          *out_tensors = std::move(tensors);
          return OkStatus();
        }

        // For each component, concatenate slices into one tensor.
        for (size_t i = 0; i < num_components; ++i) {
          TensorShape component_shape({batch_size});
          TensorShape remaining_shape = slices_to_concatenate[0][i].shape();
          remaining_shape.RemoveDim(0);
          component_shape.AppendShape(remaining_shape);
          out_tensors->emplace_back(ctx->allocator({}),
                                    dataset()->output_dtypes()[i],
                                    component_shape);
          if (!out_tensors->back().IsInitialized()) {
            return errors::ResourceExhausted(
                "Failed to allocate memory for the batch of component ", i);
          }
          int64_t dst_offset = 0;
          for (size_t j = 0; j < slices_to_concatenate.size(); ++j) {
            auto num_slices = slices_to_concatenate[j][i].shape().dim_size(0);
            TF_RETURN_IF_ERROR(batch_util::CopyContiguousSlices(
                slices_to_concatenate[j][i], 0, dst_offset, num_slices,
                &(*out_tensors)[i]));
            dst_offset += num_slices;
          }
        }

        return OkStatus();
      }

     protected:
      Status SaveInternal(SerializationContext* ctx,
                          IteratorStateWriter* writer) override {
        mutex_lock l(mu_);
        if (!input_impl_) {
          TF_RETURN_IF_ERROR(
              writer->WriteScalar(full_name("input_impl_empty"), ""));
        } else {
          TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
        }
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("batch_sizes_index"),
                                               batch_sizes_index_));
        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("offset"), offset_));
        if (offset_ != -1) {
          for (int i = 0; i < tensors_.size(); ++i) {
            TF_RETURN_IF_ERROR(writer->WriteTensor(
                full_name(strings::StrCat("tensors[", i, "]")), tensors_[i]));
          }
        }
        return OkStatus();
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        mutex_lock l(mu_);
        if (!reader->Contains(full_name("input_impl_empty"))) {
          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
        } else {
          input_impl_.reset();
        }
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("batch_sizes_index"),
                                              &batch_sizes_index_));
        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("offset"), &offset_));

        tensors_.clear();
        if (offset_ != -1) {
          tensors_.resize(dataset()->output_dtypes().size());
          for (int i = 0; i < tensors_.size(); ++i) {
            TF_RETURN_IF_ERROR(reader->ReadTensor(
                ctx->flr(), full_name(strings::StrCat("tensors[", i, "]")),
                &tensors_[i]));
          }
        }
        return OkStatus();
      }

      TraceMeMetadata GetTraceMeMetadata() const override {
        return dataset()->traceme_metadata_;
      }

     private:
      Status ValidateInputTensors() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        for (size_t i = 0; i < tensors_.size(); ++i) {
          if (tensors_[i].dims() == 0) {
            return errors::InvalidArgument(
                "Input element must have a non-scalar value in each "
                "component.");
          }
          if (tensors_[i].dim_size(0) != tensors_[0].dim_size(0)) {
            return errors::InvalidArgument(
                "Input element must have the same batch size in each "
                "component. Component 0 had size ",
                tensors_[0].dim_size(0), " but component ", i, " had size ",
                tensors_[i].dim_size(0), ".");
          }
        }
        return OkStatus();
      }

      mutex mu_;
      std::unique_ptr<IteratorBase> input_impl_;
      // Whether we have reached the end of the input.
      bool end_of_sequence_ TF_GUARDED_BY(mu_) = false;
      // Represents the current input tensor(s).
      std::vector<Tensor> tensors_ TF_GUARDED_BY(mu_);
      // Represents the offset into the current input tensor(s).
      // An offset of -1 indicates that there is no data left in the current
      // slice.
      int64_t offset_ TF_GUARDED_BY(mu_) = -1;
      // Represents the current index into the batch_sizes list.
      int64_t batch_sizes_index_ TF_GUARDED_BY(mu_) = 0;
    };

    const DatasetBase* const input_;
    const std::vector<int64_t> batch_sizes_;
    const bool drop_remainder_;
    const DataTypeVector output_types_;
    const std::vector<PartialTensorShape> output_shapes_;
    const TraceMeMetadata traceme_metadata_;
  };

  DataTypeVector output_types_;
  std::vector<PartialTensorShape> output_shapes_;
};

REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU),
                        RebatchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU),
                        RebatchDatasetOp);

REGISTER_KERNEL_BUILDER(Name("RebatchDatasetV2").Device(DEVICE_CPU),
                        RebatchDatasetV2Op);

}  // anonymous namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow