• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5     http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 #include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"
13 
14 #include "tensorflow/core/data/dataset_test_base.h"
15 
16 namespace tensorflow {
17 namespace data {
18 namespace {
19 
20 constexpr char kNodeName[] = "padded_batch_dataset";
21 constexpr int kOpVersion = 2;
22 
23 class PaddedBatchDatasetOpTest : public DatasetOpsTestBase {};
24 
25 class PaddedBatchDatasetParams : public DatasetParams {
26  public:
27   template <typename T>
PaddedBatchDatasetParams(T input_dataset_params,int64_t batch_size,std::vector<Tensor> padded_shapes,std::vector<Tensor> padded_values,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,int num_padded_shapes,string node_name)28   PaddedBatchDatasetParams(T input_dataset_params, int64_t batch_size,
29                            std::vector<Tensor> padded_shapes,
30                            std::vector<Tensor> padded_values,
31                            bool drop_remainder, bool parallel_copy,
32                            DataTypeVector output_dtypes,
33                            std::vector<PartialTensorShape> output_shapes,
34                            int num_padded_shapes, string node_name)
35       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
36                       std::move(node_name)),
37         batch_size_(batch_size),
38         padded_shapes_(std::move(padded_shapes)),
39         padded_values_(std::move(padded_values)),
40         drop_remainder_(drop_remainder),
41         parallel_copy_(parallel_copy),
42         num_padded_shapes_(num_padded_shapes) {
43     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
44     op_version_ = kOpVersion;
45     iterator_prefix_ =
46         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
47                                    input_dataset_params.iterator_prefix());
48   }
49 
GetInputTensors() const50   std::vector<Tensor> GetInputTensors() const override {
51     std::vector<Tensor> input_tensors;
52     input_tensors.emplace_back(
53         CreateTensor<int64_t>(TensorShape({}), {batch_size_}));
54     for (auto& padded_shape : padded_shapes_) {
55       input_tensors.emplace_back(padded_shape);
56     }
57     for (auto& padded_value : padded_values_) {
58       input_tensors.emplace_back(padded_value);
59     }
60     input_tensors.emplace_back(
61         CreateTensor<bool>(TensorShape({}), {drop_remainder_}));
62     return input_tensors;
63   }
64 
GetInputNames(std::vector<string> * input_names) const65   Status GetInputNames(std::vector<string>* input_names) const override {
66     *input_names = {PaddedBatchDatasetOp::kInputDataset,
67                     PaddedBatchDatasetOp::kBatchSize};
68     // Create the input names for the input padded_shapes.
69     for (int i = 0; i < num_padded_shapes_; ++i) {
70       input_names->emplace_back(
71           strings::StrCat(PaddedBatchDatasetOp::kPaddedShapes, "_", i));
72     }
73     // Create the input names for the input padding_values.
74     for (int j = 0; j < padded_values_.size(); ++j) {
75       input_names->emplace_back(
76           strings::StrCat(PaddedBatchDatasetOp::kPaddingValues, "_", j));
77     }
78     input_names->push_back(PaddedBatchDatasetOp::kDropRemainder);
79     return OkStatus();
80   }
81 
GetAttributes(AttributeVector * attr_vector) const82   Status GetAttributes(AttributeVector* attr_vector) const override {
83     *attr_vector = {{"parallel_copy", parallel_copy_},
84                     {"Toutput_types", output_dtypes_},
85                     {"output_shapes", output_shapes_},
86                     {"N", num_padded_shapes_},
87                     {"metadata", ""}};
88     return OkStatus();
89   }
90 
dataset_type() const91   string dataset_type() const override {
92     return PaddedBatchDatasetOp::kDatasetType;
93   }
94 
95  private:
96   int64_t batch_size_;
97   std::vector<Tensor> padded_shapes_;
98   std::vector<Tensor> padded_values_;
99   bool drop_remainder_;
100   bool parallel_copy_;
101   int num_padded_shapes_;
102 };
103 
104 // Test case 1: input elements with same shapes.
PaddedBatchDatasetParams1()105 PaddedBatchDatasetParams PaddedBatchDatasetParams1() {
106   auto tensor_slice_dataset_params = TensorSliceDatasetParams(
107       /*components=*/{CreateTensor<int64_t>(
108           TensorShape{7, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13})},
109       /*node_name=*/"tensor_slice");
110   return PaddedBatchDatasetParams(
111       /*input_dataset_params=*/tensor_slice_dataset_params,
112       /*batch_size=*/2,
113       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {3})},
114       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
115       /*drop_remainder=*/true,
116       /*parallel_copy=*/true,
117       /*output_dtypes=*/{DT_INT64},
118       /*output_shapes=*/{PartialTensorShape({2, 3})},
119       /*num_padded_shapes=*/1,
120       /*node_name=*/kNodeName);
121 }
122 
123 // Test case 2: input elements with different shapes.
PaddedBatchDatasetParams2()124 PaddedBatchDatasetParams PaddedBatchDatasetParams2() {
125   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
126       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
127                                             {{0, 1, 2, 3, 4, 5}}),
128       /*node_name=*/"tensor_slice_0");
129   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
130       /*components=*/CreateTensors<int64_t>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
131       /*node_name=*/"tensor_slice_1");
132   auto concatenate_dataset_params =
133       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
134                                std::move(tensor_slice_dataset_params_1),
135                                /*output_dtypes=*/{DT_INT64},
136                                /*output_shapes=*/{PartialTensorShape({-1})},
137                                /*node_name=*/"concatenate");
138   return PaddedBatchDatasetParams(
139       /*input_dataset_params=*/concatenate_dataset_params,
140       /*batch_size=*/2,
141       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {3})},
142       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
143       /*drop_remainder=*/true,
144       /*parallel_copy=*/true,
145       /*output_dtypes=*/{DT_INT64},
146       /*output_shapes=*/{PartialTensorShape({2, 3})},
147       /*num_padded_shapes=*/1,
148       /*node_name=*/kNodeName);
149 }
150 
151 // Test case 3: similar with the test case 2 but drop_remainder = false.
PaddedBatchDatasetParams3()152 PaddedBatchDatasetParams PaddedBatchDatasetParams3() {
153   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
154       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
155                                             {{0, 1, 2, 3, 4, 5}}),
156       /*node_name=*/"tensor_slice_0");
157   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
158       /*components=*/CreateTensors<int64_t>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
159       /*node_name=*/"tensor_slice_1");
160   auto concatenate_dataset_params =
161       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
162                                std::move(tensor_slice_dataset_params_1),
163                                /*output_dtypes=*/{DT_INT64},
164                                /*output_shapes=*/{PartialTensorShape({-1})},
165                                /*node_name=*/"concatenate");
166   return PaddedBatchDatasetParams(
167       /*input_dataset_params=*/concatenate_dataset_params,
168       /*batch_size=*/2,
169       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {3})},
170       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
171       /*drop_remainder=*/false,
172       /*parallel_copy=*/true,
173       /*output_dtypes=*/{DT_INT64},
174       /*output_shapes=*/{PartialTensorShape({2, 3})},
175       /*num_padded_shapes=*/1,
176       /*node_name=*/kNodeName);
177 }
178 
179 // Test case 4: similar with the test case 3 but the input elements can be
180 // divided by the batch size evenly. As drop_remainder = false, the output
181 // shape is still {-1, 3} instead of {2, 3}.
PaddedBatchDatasetParams4()182 PaddedBatchDatasetParams PaddedBatchDatasetParams4() {
183   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
184       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
185                                             {{0, 1, 2, 3, 4, 5}}),
186       /*node_name=*/"tensor_slice_0");
187   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
188       /*components=*/CreateTensors<int64_t>(TensorShape{3, 1}, {{6, 7, 8}}),
189       /*node_name=*/"tensor_slice_1");
190   auto concatenate_dataset_params =
191       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
192                                std::move(tensor_slice_dataset_params_1),
193                                /*output_dtypes=*/{DT_INT64},
194                                /*output_shapes=*/{PartialTensorShape({-1})},
195                                /*node_name=*/"concatenate");
196   return PaddedBatchDatasetParams(
197       /*input_dataset_params=*/concatenate_dataset_params,
198       /*batch_size=*/2,
199       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {3})},
200       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
201       /*drop_remainder=*/false,
202       /*parallel_copy=*/true,
203       /*output_dtypes=*/{DT_INT64},
204       /*output_shapes=*/{PartialTensorShape({-1, 3})},
205       /*num_padded_shapes=*/1,
206       /*node_name=*/kNodeName);
207 }
208 
209 // Test case 5: similar with the test case 3 but padded_shapes = {-1}.
PaddedBatchDatasetParams5()210 PaddedBatchDatasetParams PaddedBatchDatasetParams5() {
211   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
212       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
213                                             {{0, 1, 2, 3, 4, 5}}),
214       /*node_name=*/"tensor_slice_0");
215   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
216       /*components=*/CreateTensors<int64_t>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
217       /*node_name=*/"tensor_slice_1");
218   auto concatenate_dataset_params =
219       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
220                                std::move(tensor_slice_dataset_params_1),
221                                /*output_dtypes=*/{DT_INT64},
222                                /*output_shapes=*/{PartialTensorShape({-1})},
223                                /*node_name=*/"concatenate");
224   return PaddedBatchDatasetParams(
225       /*input_dataset_params=*/concatenate_dataset_params,
226       /*batch_size=*/2,
227       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {-1})},
228       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
229       /*drop_remainder=*/false,
230       /*parallel_copy=*/false,
231       /*output_dtypes=*/{DT_INT64},
232       /*output_shapes=*/{PartialTensorShape({-1, -1})},
233       /*num_padded_shapes=*/1,
234       /*node_name=*/kNodeName);
235 }
236 
237 // Test case 6: similar with the test case 5 but parallel_copy = true.
PaddedBatchDatasetParams6()238 PaddedBatchDatasetParams PaddedBatchDatasetParams6() {
239   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
240       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
241                                             {{0, 1, 2, 3, 4, 5}}),
242       /*node_name=*/"tensor_slice_0");
243   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
244       /*components=*/CreateTensors<int64_t>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
245       /*node_name=*/"tensor_slice_1");
246   auto concatenate_dataset_params =
247       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
248                                std::move(tensor_slice_dataset_params_1),
249                                /*output_dtypes=*/{DT_INT64},
250                                /*output_shapes=*/{PartialTensorShape({-1})},
251                                /*node_name=*/"concatenate");
252   return PaddedBatchDatasetParams(
253       /*input_dataset_params=*/concatenate_dataset_params,
254       /*batch_size=*/2,
255       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {-1})},
256       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
257       /*drop_remainder=*/false,
258       /*parallel_copy=*/true,
259       /*output_dtypes=*/{DT_INT64},
260       /*output_shapes=*/{PartialTensorShape({-1, -1})},
261       /*num_padded_shapes=*/1,
262       /*node_name=*/kNodeName);
263 }
264 
265 // Test case 7: empty input elements.
PaddedBatchDatasetParams7()266 PaddedBatchDatasetParams PaddedBatchDatasetParams7() {
267   return PaddedBatchDatasetParams(
268       /*input_dataset_params=*/RangeDatasetParams(0, 0, 1),
269       /*batch_size=*/2,
270       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {-1})},
271       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
272       /*drop_remainder=*/false,
273       /*parallel_copy=*/true,
274       /*output_dtypes=*/{DT_INT64},
275       /*output_shapes=*/{PartialTensorShape({-1, -1})},
276       /*num_padded_shapes=*/1,
277       /*node_name=*/kNodeName);
278 }
279 
280 // Test case 8: short padding shape.
PaddedBatchDatasetParamsWithShortPaddingShape()281 PaddedBatchDatasetParams PaddedBatchDatasetParamsWithShortPaddingShape() {
282   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
283       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
284                                             {{0, 1, 2, 3, 4, 5}}),
285       /*node_name=*/"tensor_slice_0");
286   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
287       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
288                                             {{6, 7, 8, 9, 10, 11}}),
289       /*node_name=*/"tensor_slice_1");
290   auto concatenate_dataset_params =
291       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
292                                std::move(tensor_slice_dataset_params_1),
293                                /*output_dtypes=*/{DT_INT64},
294                                /*output_shapes=*/{PartialTensorShape({2})},
295                                /*node_name=*/"concatenate");
296   return PaddedBatchDatasetParams(
297       /*input_dataset_params=*/concatenate_dataset_params,
298       /*batch_size=*/2,
299       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {1})},
300       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
301       /*drop_remainder=*/false,
302       /*parallel_copy=*/true,
303       /*output_dtypes=*/{DT_INT64},
304       /*output_shapes=*/{PartialTensorShape({-1, -1})},
305       /*num_padded_shapes=*/1,
306       /*node_name=*/kNodeName);
307 }
308 
PaddedBatchDatasetParamsWithInvalidPaddingShape()309 PaddedBatchDatasetParams PaddedBatchDatasetParamsWithInvalidPaddingShape() {
310   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
311       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
312                                             {{0, 1, 2, 3, 4, 5}}),
313       /*node_name=*/"tensor_slice_0");
314   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
315       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
316                                             {{6, 7, 8, 9, 10, 11}}),
317       /*node_name=*/"tensor_slice_1");
318   auto concatenate_dataset_params =
319       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
320                                std::move(tensor_slice_dataset_params_1),
321                                /*output_dtypes=*/{DT_INT64},
322                                /*output_shapes=*/{PartialTensorShape({2})},
323                                /*node_name=*/"concatenate");
324   return PaddedBatchDatasetParams(
325       /*input_dataset_params=*/concatenate_dataset_params,
326       /*batch_size=*/2,
327       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{2}, {1, 2})},
328       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
329       /*drop_remainder=*/false,
330       /*parallel_copy=*/true,
331       /*output_dtypes=*/{DT_INT64},
332       /*output_shapes=*/{PartialTensorShape({-1, -1})},
333       /*num_padded_shapes=*/1,
334       /*node_name=*/kNodeName);
335 }
336 
PaddedBatchDatasetParamsWithInvalidBatchSize()337 PaddedBatchDatasetParams PaddedBatchDatasetParamsWithInvalidBatchSize() {
338   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
339       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
340                                             {{0, 1, 2, 3, 4, 5}}),
341       /*node_name=*/"tensor_slice_0");
342   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
343       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
344                                             {{6, 7, 8, 9, 10, 11}}),
345       /*node_name=*/"tensor_slice_1");
346   auto concatenate_dataset_params =
347       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
348                                std::move(tensor_slice_dataset_params_1),
349                                /*output_dtypes=*/{DT_INT64},
350                                /*output_shapes=*/{PartialTensorShape({2})},
351                                /*node_name=*/"concatenate");
352   return PaddedBatchDatasetParams(
353       /*input_dataset_params=*/concatenate_dataset_params,
354       /*batch_size=*/-1,
355       /*padded_shapes=*/{CreateTensor<int64_t>(TensorShape{1}, {3})},
356       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
357       /*drop_remainder=*/false,
358       /*parallel_copy=*/true,
359       /*output_dtypes=*/{DT_INT64},
360       /*output_shapes=*/{PartialTensorShape({-1, -1})},
361       /*num_padded_shapes=*/1,
362       /*node_name=*/kNodeName);
363 }
364 
365 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingShapesSize()366 PaddedBatchDatasetParamsWithInvalidPaddingShapesSize() {
367   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
368       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
369                                             {{0, 1, 2, 3, 4, 5}}),
370       /*node_name=*/"tensor_slice_0");
371   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
372       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
373                                             {{6, 7, 8, 9, 10, 11}}),
374       /*node_name=*/"tensor_slice_1");
375   auto concatenate_dataset_params =
376       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
377                                std::move(tensor_slice_dataset_params_1),
378                                /*output_dtypes=*/{DT_INT64},
379                                /*output_shapes=*/{PartialTensorShape({2})},
380                                /*node_name=*/"concatenate");
381   return PaddedBatchDatasetParams(
382       /*input_dataset_params=*/concatenate_dataset_params,
383       /*batch_size=*/2,
384       /*padded_shapes=*/
385       {CreateTensor<int64_t>(TensorShape{1}, {3}),
386        CreateTensor<int64_t>(TensorShape{1}, {3})},
387       /*padded_values=*/{CreateTensor<int64_t>(TensorShape{}, {1})},
388       /*drop_remainder=*/false,
389       /*parallel_copy=*/true,
390       /*output_dtypes=*/{DT_INT64},
391       /*output_shapes=*/{PartialTensorShape({-1, -1})},
392       /*num_padded_shapes=*/2,
393       /*node_name=*/kNodeName);
394 }
395 
396 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingValuesSize()397 PaddedBatchDatasetParamsWithInvalidPaddingValuesSize() {
398   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
399       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
400                                             {{0, 1, 2, 3, 4, 5}}),
401       /*node_name=*/"tensor_slice_0");
402   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
403       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
404                                             {{6, 7, 8, 9, 10, 11}}),
405       /*node_name=*/"tensor_slice_1");
406   auto concatenate_dataset_params =
407       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
408                                std::move(tensor_slice_dataset_params_1),
409                                /*output_dtypes=*/{DT_INT64},
410                                /*output_shapes=*/{PartialTensorShape({2})},
411                                /*node_name=*/"concatenate");
412   return PaddedBatchDatasetParams(
413       /*input_dataset_params=*/concatenate_dataset_params,
414       /*batch_size=*/2,
415       /*padded_shapes=*/
416       {CreateTensor<int64_t>(TensorShape{1}, {3})},
417       /*padded_values=*/
418       {CreateTensor<int64_t>(TensorShape{}, {1}),
419        CreateTensor<int64_t>(TensorShape{}, {1})},
420       /*drop_remainder=*/false,
421       /*parallel_copy=*/true,
422       /*output_dtypes=*/{DT_INT64},
423       /*output_shapes=*/{PartialTensorShape({-1, -1})},
424       /*num_padded_shapes=*/2,
425       /*node_name=*/kNodeName);
426 }
427 
428 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingValuesDType()429 PaddedBatchDatasetParamsWithInvalidPaddingValuesDType() {
430   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
431       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
432                                             {{0, 1, 2, 3, 4, 5}}),
433       /*node_name=*/"tensor_slice_0");
434   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
435       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
436                                             {{6, 7, 8, 9, 10, 11}}),
437       /*node_name=*/"tensor_slice_1");
438   auto concatenate_dataset_params =
439       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
440                                std::move(tensor_slice_dataset_params_1),
441                                /*output_dtypes=*/{DT_INT64},
442                                /*output_shapes=*/{PartialTensorShape({2})},
443                                /*node_name=*/"concatenate");
444   return PaddedBatchDatasetParams(
445       /*input_dataset_params=*/concatenate_dataset_params,
446       /*batch_size=*/2,
447       /*padded_shapes=*/
448       {CreateTensor<int64_t>(TensorShape{1}, {3})},
449       /*padded_values=*/
450       {CreateTensor<tstring>(TensorShape{}, {"a"})},
451       /*drop_remainder=*/false,
452       /*parallel_copy=*/true,
453       /*output_dtypes=*/{DT_INT64},
454       /*output_shapes=*/{PartialTensorShape({-1, -1})},
455       /*num_padded_shapes=*/1,
456       /*node_name=*/kNodeName);
457 }
458 
459 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingValuesShape()460 PaddedBatchDatasetParamsWithInvalidPaddingValuesShape() {
461   auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
462       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
463                                             {{0, 1, 2, 3, 4, 5}}),
464       /*node_name=*/"tensor_slice_0");
465   auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
466       /*components=*/CreateTensors<int64_t>(TensorShape{3, 2},
467                                             {{6, 7, 8, 9, 10, 11}}),
468       /*node_name=*/"tensor_slice_1");
469   auto concatenate_dataset_params =
470       ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
471                                std::move(tensor_slice_dataset_params_1),
472                                /*output_dtypes=*/{DT_INT64},
473                                /*output_shapes=*/{PartialTensorShape({2})},
474                                /*node_name=*/"concatenate");
475   return PaddedBatchDatasetParams(
476       /*input_dataset_params=*/concatenate_dataset_params,
477       /*batch_size=*/2,
478       /*padded_shapes=*/
479       {CreateTensor<int64_t>(TensorShape{1}, {3})},
480       /*padded_values=*/
481       {CreateTensor<int64_t>(TensorShape{1}, {1})},
482       /*drop_remainder=*/false,
483       /*parallel_copy=*/true,
484       /*output_dtypes=*/{DT_INT64},
485       /*output_shapes=*/{PartialTensorShape({-1, -1})},
486       /*num_padded_shapes=*/1,
487       /*node_name=*/kNodeName);
488 }
489 
GetNextTestCases()490 std::vector<GetNextTestCase<PaddedBatchDatasetParams>> GetNextTestCases() {
491   return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
492            /*expected_outputs=*/
493            CreateTensors<int64_t>(
494                TensorShape{2, 3},
495                {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 7, 1}, {8, 9, 1, 10, 11, 1}})},
496           {/*dataset_params=*/PaddedBatchDatasetParams2(),
497            /*expected_outputs=*/
498            CreateTensors<int64_t>(
499                TensorShape{2, 3},
500                {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
501           {/*dataset_params=*/PaddedBatchDatasetParams3(),
502            /*expected_outputs=*/
503            {CreateTensor<int64_t>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
504             CreateTensor<int64_t>(TensorShape{2, 3}, {4, 5, 1, 6, 1, 1}),
505             CreateTensor<int64_t>(TensorShape{2, 3}, {7, 1, 1, 8, 1, 1}),
506             CreateTensor<int64_t>(TensorShape{1, 3}, {9, 1, 1})}},
507           {/*dataset_params=*/PaddedBatchDatasetParams4(),
508            /*expected_outputs=*/
509            CreateTensors<int64_t>(
510                TensorShape{2, 3},
511                {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
512           {/*dataset_params=*/PaddedBatchDatasetParams5(),
513            /*expected_outputs=*/
514            {CreateTensor<int64_t>(TensorShape{2, 2}, {0, 1, 2, 3}),
515             CreateTensor<int64_t>(TensorShape{2, 2}, {4, 5, 6, 1}),
516             CreateTensor<int64_t>(TensorShape{2, 1}, {7, 8}),
517             CreateTensor<int64_t>(TensorShape{1, 1}, {9})}},
518           {/*dataset_params=*/PaddedBatchDatasetParams6(),
519            /*expected_outputs=*/
520            {CreateTensor<int64_t>(TensorShape{2, 2}, {0, 1, 2, 3}),
521             CreateTensor<int64_t>(TensorShape{2, 2}, {4, 5, 6, 1}),
522             CreateTensor<int64_t>(TensorShape{2, 1}, {7, 8}),
523             CreateTensor<int64_t>(TensorShape{1, 1}, {9})}},
524           {/*dataset_params=*/PaddedBatchDatasetParams7(),
525            /*expected_outputs=*/{}}};
526 }
527 
ITERATOR_GET_NEXT_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,GetNextTestCases ())528 ITERATOR_GET_NEXT_TEST_P(PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
529                          GetNextTestCases())
530 
531 TEST_F(PaddedBatchDatasetOpTest, DatasetNodeName) {
532   auto dataset_params = PaddedBatchDatasetParams1();
533   TF_ASSERT_OK(Initialize(dataset_params));
534   TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name()));
535 }
536 
// Initializes the dataset from test case 1 and checks its type string against
// the op name derived from kDatasetType and the parameters' op version.
TEST_F(PaddedBatchDatasetOpTest, DatasetTypeString) {
  PaddedBatchDatasetParams padded_batch_params = PaddedBatchDatasetParams1();
  TF_ASSERT_OK(Initialize(padded_batch_params));

  name_utils::OpNameParams op_name_params;
  op_name_params.op_version = padded_batch_params.op_version();
  const auto expected_type_string =
      name_utils::OpName(PaddedBatchDatasetOp::kDatasetType, op_name_params);
  TF_ASSERT_OK(CheckDatasetTypeString(expected_type_string));
}
545 
546 std::vector<DatasetOutputDtypesTestCase<PaddedBatchDatasetParams>>
DatasetOutputDtypesTestCases()547 DatasetOutputDtypesTestCases() {
548   return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
549            /*expected_output_dtypes=*/{DT_INT64}},
550           {/*dataset_params=*/PaddedBatchDatasetParams2(),
551            /*expected_output_dtypes=*/{DT_INT64}},
552           {/*dataset_params=*/PaddedBatchDatasetParams3(),
553            /*expected_output_dtypes=*/{DT_INT64}},
554           {/*dataset_params=*/PaddedBatchDatasetParams4(),
555            /*expected_output_dtypes=*/{DT_INT64}},
556           {/*dataset_params=*/PaddedBatchDatasetParams5(),
557            /*expected_output_dtypes=*/{DT_INT64}},
558           {/*dataset_params=*/PaddedBatchDatasetParams6(),
559            /*expected_output_dtypes=*/{DT_INT64}},
560           {/*dataset_params=*/PaddedBatchDatasetParams7(),
561            /*expected_output_dtypes=*/{DT_INT64}}};
562 }
563 
DATASET_OUTPUT_DTYPES_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,DatasetOutputDtypesTestCases ())564 DATASET_OUTPUT_DTYPES_TEST_P(PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
565                              DatasetOutputDtypesTestCases())
566 
567 std::vector<DatasetOutputShapesTestCase<PaddedBatchDatasetParams>>
568 DatasetOutputShapesTestCases() {
569   return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
570            /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
571           {/*dataset_params=*/PaddedBatchDatasetParams2(),
572            /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
573           {/*dataset_params=*/PaddedBatchDatasetParams3(),
574            /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
575           {/*dataset_params=*/PaddedBatchDatasetParams4(),
576            /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
577           {/*dataset_params=*/PaddedBatchDatasetParams5(),
578            /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
579           {/*dataset_params=*/PaddedBatchDatasetParams6(),
580            /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
581           {/*dataset_params=*/PaddedBatchDatasetParams7(),
582            /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}}};
583 }
584 
DATASET_OUTPUT_SHAPES_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,DatasetOutputShapesTestCases ())585 DATASET_OUTPUT_SHAPES_TEST_P(PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
586                              DatasetOutputShapesTestCases())
587 
588 std::vector<CardinalityTestCase<PaddedBatchDatasetParams>>
589 CardinalityTestCases() {
590   return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
591            /*expected_cardinality=*/3},
592           {/*dataset_params=*/PaddedBatchDatasetParams2(),
593            /*expected_cardinality=*/3},
594           {/*dataset_params=*/PaddedBatchDatasetParams3(),
595            /*expected_cardinality=*/4},
596           {/*dataset_params=*/PaddedBatchDatasetParams4(),
597            /*expected_cardinality=*/3},
598           {/*dataset_params=*/PaddedBatchDatasetParams5(),
599            /*expected_cardinality=*/4},
600           {/*dataset_params=*/PaddedBatchDatasetParams6(),
601            /*expected_cardinality=*/4},
602           {/*dataset_params=*/PaddedBatchDatasetParams7(),
603            /*expected_cardinality=*/0}};
604 }
605 
// Hooks the CardinalityTestCases() table up to the shared cardinality test
// macro for PaddedBatchDatasetOpTest.
// NOTE(review): the macro is defined outside this file (presumably in
// dataset_test_base.h) -- confirm its expansion there.
DATASET_CARDINALITY_TEST_P(
    PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
    CardinalityTestCases())
608 
609 std::vector<IteratorOutputDtypesTestCase<PaddedBatchDatasetParams>>
610 IteratorOutputDtypesTestCases() {
611   return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
612            /*expected_output_dtypes=*/{DT_INT64}},
613           {/*dataset_params=*/PaddedBatchDatasetParams2(),
614            /*expected_output_dtypes=*/{DT_INT64}},
615           {/*dataset_params=*/PaddedBatchDatasetParams3(),
616            /*expected_output_dtypes=*/{DT_INT64}},
617           {/*dataset_params=*/PaddedBatchDatasetParams4(),
618            /*expected_output_dtypes=*/{DT_INT64}},
619           {/*dataset_params=*/PaddedBatchDatasetParams5(),
620            /*expected_output_dtypes=*/{DT_INT64}},
621           {/*dataset_params=*/PaddedBatchDatasetParams6(),
622            /*expected_output_dtypes=*/{DT_INT64}},
623           {/*dataset_params=*/PaddedBatchDatasetParams7(),
624            /*expected_output_dtypes=*/{DT_INT64}}};
625 }
626 
// Hooks the IteratorOutputDtypesTestCases() table up to the shared iterator
// output-dtypes test macro for PaddedBatchDatasetOpTest.
// NOTE(review): the macro is defined outside this file (presumably in
// dataset_test_base.h) -- confirm its expansion there.
ITERATOR_OUTPUT_DTYPES_TEST_P(
    PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
    IteratorOutputDtypesTestCases())
630 
631 std::vector<IteratorOutputShapesTestCase<PaddedBatchDatasetParams>>
632 IteratorOutputShapesTestCases() {
633   return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
634            /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
635           {/*dataset_params=*/PaddedBatchDatasetParams2(),
636            /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
637           {/*dataset_params=*/PaddedBatchDatasetParams3(),
638            /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
639           {/*dataset_params=*/PaddedBatchDatasetParams4(),
640            /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
641           {/*dataset_params=*/PaddedBatchDatasetParams5(),
642            /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
643           {/*dataset_params=*/PaddedBatchDatasetParams6(),
644            /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
645           {/*dataset_params=*/PaddedBatchDatasetParams7(),
646            /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}}};
647 }
648 
// Hooks the IteratorOutputShapesTestCases() table up to the shared iterator
// output-shapes test macro for PaddedBatchDatasetOpTest.
// NOTE(review): the macro is defined outside this file (presumably in
// dataset_test_base.h) -- confirm its expansion there.
ITERATOR_OUTPUT_SHAPES_TEST_P(
    PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
    IteratorOutputShapesTestCases())
652 
653 TEST_F(PaddedBatchDatasetOpTest, IteratorPrefix) {
654   auto dataset_params = PaddedBatchDatasetParams1();
655   TF_ASSERT_OK(Initialize(dataset_params));
656   name_utils::IteratorPrefixParams params;
657   params.op_version = dataset_params.op_version();
658   TF_ASSERT_OK(CheckIteratorPrefix(
659       name_utils::IteratorPrefix(PaddedBatchDatasetOp::kDatasetType,
660                                  dataset_params.iterator_prefix(), params)));
661 }
662 
663 std::vector<IteratorSaveAndRestoreTestCase<PaddedBatchDatasetParams>>
IteratorSaveAndRestoreTestCases()664 IteratorSaveAndRestoreTestCases() {
665   return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
666            /*breakpoints=*/{0, 2, 5},
667            /*expected_outputs=*/
668            CreateTensors<int64_t>(
669                TensorShape{2, 3},
670                {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 7, 1}, {8, 9, 1, 10, 11, 1}})},
671           {/*dataset_params=*/PaddedBatchDatasetParams2(),
672            /*breakpoints=*/{0, 2, 5},
673            /*expected_outputs=*/
674            CreateTensors<int64_t>(
675                TensorShape{2, 3},
676                {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
677           {/*dataset_params=*/PaddedBatchDatasetParams3(),
678            /*breakpoints=*/{0, 2, 5},
679            /*expected_outputs=*/
680            {CreateTensor<int64_t>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
681             CreateTensor<int64_t>(TensorShape{2, 3}, {4, 5, 1, 6, 1, 1}),
682             CreateTensor<int64_t>(TensorShape{2, 3}, {7, 1, 1, 8, 1, 1}),
683             CreateTensor<int64_t>(TensorShape{1, 3}, {9, 1, 1})}},
684           {/*dataset_params=*/PaddedBatchDatasetParams4(),
685            /*breakpoints=*/{0, 2, 5},
686            /*expected_outputs=*/
687            CreateTensors<int64_t>(
688                TensorShape{2, 3},
689                {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
690           {/*dataset_params=*/PaddedBatchDatasetParams5(),
691            /*breakpoints=*/{0, 2, 5},
692            /*expected_outputs=*/
693            {CreateTensor<int64_t>(TensorShape{2, 2}, {0, 1, 2, 3}),
694             CreateTensor<int64_t>(TensorShape{2, 2}, {4, 5, 6, 1}),
695             CreateTensor<int64_t>(TensorShape{2, 1}, {7, 8}),
696             CreateTensor<int64_t>(TensorShape{1, 1}, {9})}},
697           {/*dataset_params=*/PaddedBatchDatasetParams6(),
698            /*breakpoints=*/{0, 2, 5},
699            /*expected_outputs=*/
700            {CreateTensor<int64_t>(TensorShape{2, 2}, {0, 1, 2, 3}),
701             CreateTensor<int64_t>(TensorShape{2, 2}, {4, 5, 6, 1}),
702             CreateTensor<int64_t>(TensorShape{2, 1}, {7, 8}),
703             CreateTensor<int64_t>(TensorShape{1, 1}, {9})}},
704           {/*dataset_params=*/PaddedBatchDatasetParams7(),
705            /*breakpoints=*/{0, 2, 5},
706            /*expected_outputs=*/{}}};
707 }
708 
// Hooks the IteratorSaveAndRestoreTestCases() table up to the shared iterator
// save-and-restore test macro for PaddedBatchDatasetOpTest.
// NOTE(review): the macro is defined outside this file (presumably in
// dataset_test_base.h) -- confirm its expansion there.
ITERATOR_SAVE_AND_RESTORE_TEST_P(
    PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
    IteratorSaveAndRestoreTestCases())
712 
713 TEST_F(PaddedBatchDatasetOpTest, ShortPadding) {
714   auto dataset_params = PaddedBatchDatasetParamsWithShortPaddingShape();
715   TF_ASSERT_OK(Initialize(dataset_params));
716   bool end_of_sequence = false;
717   std::vector<Tensor> out_tensors;
718   EXPECT_EQ(
719       iterator_->GetNext(iterator_ctx_.get(), &out_tensors, &end_of_sequence)
720           .code(),
721       tensorflow::error::DATA_LOSS);
722 }
723 
// Initialization with the "invalid padding shape" params must succeed; the
// failure is expected to surface on the first GetNext() call, which should
// report an INVALID_ARGUMENT error code.
TEST_F(PaddedBatchDatasetOpTest, InvalidPaddedShapes) {
  auto invalid_shape_params = PaddedBatchDatasetParamsWithInvalidPaddingShape();
  TF_ASSERT_OK(Initialize(invalid_shape_params));

  std::vector<Tensor> batch;
  bool reached_end = false;
  const auto next_status =
      iterator_->GetNext(iterator_ctx_.get(), &batch, &reached_end);
  EXPECT_EQ(next_status.code(), tensorflow::error::INVALID_ARGUMENT);
}
734 
735 class ParameterizedInvalidArgumentTest
736     : public PaddedBatchDatasetOpTest,
737       public ::testing::WithParamInterface<PaddedBatchDatasetParams> {};
738 
// For every parameter set supplied to this suite, Initialize() itself must
// fail with an INVALID_ARGUMENT error code.
// NOTE(review): the test name mentions a predicate function, but nothing in
// the body involves one -- the name may be a leftover; kept for compatibility.
TEST_P(ParameterizedInvalidArgumentTest, InvalidPredicateFunc) {
  auto invalid_params = GetParam();
  const auto init_status = Initialize(invalid_params);
  EXPECT_EQ(init_status.code(), tensorflow::error::INVALID_ARGUMENT);
}
744 
745 INSTANTIATE_TEST_SUITE_P(
746     PaddedBatchDatasetOpTest, ParameterizedInvalidArgumentTest,
747     ::testing::ValuesIn(
748         {PaddedBatchDatasetParamsWithInvalidBatchSize(),
749          PaddedBatchDatasetParamsWithInvalidPaddingShapesSize(),
750          PaddedBatchDatasetParamsWithInvalidPaddingValuesSize(),
751          PaddedBatchDatasetParamsWithInvalidPaddingValuesDType(),
752          PaddedBatchDatasetParamsWithInvalidPaddingValuesShape()}));
753 
754 }  // namespace
755 }  // namespace data
756 }  // namespace tensorflow
757