1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 #include "tensorflow/core/kernels/data/padded_batch_dataset_op.h"
13
14 #include "tensorflow/core/kernels/data/dataset_test_base.h"
15
16 namespace tensorflow {
17 namespace data {
18 namespace {
19
// Node name shared by every PaddedBatchDatasetParams built in this file.
constexpr char kNodeName[] = "padded_batch_dataset";

// Op version recorded by the PaddedBatchDatasetParams constructor and later
// used when deriving the expected dataset type string and iterator prefix.
constexpr int kOpVersion = 2;
22
23 class PaddedBatchDatasetOpTest : public DatasetOpsTestBase {};
24
25 class PaddedBatchDatasetParams : public DatasetParams {
26 public:
27 template <typename T>
PaddedBatchDatasetParams(T input_dataset_params,int64 batch_size,std::vector<Tensor> padded_shapes,std::vector<Tensor> padded_values,bool drop_remainder,bool parallel_copy,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,int num_padded_shapes,string node_name)28 PaddedBatchDatasetParams(T input_dataset_params, int64 batch_size,
29 std::vector<Tensor> padded_shapes,
30 std::vector<Tensor> padded_values,
31 bool drop_remainder, bool parallel_copy,
32 DataTypeVector output_dtypes,
33 std::vector<PartialTensorShape> output_shapes,
34 int num_padded_shapes, string node_name)
35 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
36 std::move(node_name)),
37 batch_size_(batch_size),
38 padded_shapes_(std::move(padded_shapes)),
39 padded_values_(std::move(padded_values)),
40 drop_remainder_(drop_remainder),
41 parallel_copy_(parallel_copy),
42 num_padded_shapes_(num_padded_shapes) {
43 input_dataset_params_.push_back(absl::make_unique<T>(input_dataset_params));
44 op_version_ = kOpVersion;
45 iterator_prefix_ =
46 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
47 input_dataset_params.iterator_prefix());
48 }
49
GetInputTensors() const50 std::vector<Tensor> GetInputTensors() const override {
51 std::vector<Tensor> input_tensors;
52 input_tensors.emplace_back(
53 CreateTensor<int64>(TensorShape({}), {batch_size_}));
54 for (auto& padded_shape : padded_shapes_) {
55 input_tensors.emplace_back(padded_shape);
56 }
57 for (auto& padded_value : padded_values_) {
58 input_tensors.emplace_back(padded_value);
59 }
60 input_tensors.emplace_back(
61 CreateTensor<bool>(TensorShape({}), {drop_remainder_}));
62 return input_tensors;
63 }
64
GetInputNames(std::vector<string> * input_names) const65 Status GetInputNames(std::vector<string>* input_names) const override {
66 *input_names = {PaddedBatchDatasetOp::kInputDataset,
67 PaddedBatchDatasetOp::kBatchSize};
68 // Create the input names for the input padded_shapes.
69 for (int i = 0; i < num_padded_shapes_; ++i) {
70 input_names->emplace_back(
71 strings::StrCat(PaddedBatchDatasetOp::kPaddedShapes, "_", i));
72 }
73 // Create the input names for the input padding_values.
74 for (int j = 0; j < padded_values_.size(); ++j) {
75 input_names->emplace_back(
76 strings::StrCat(PaddedBatchDatasetOp::kPaddingValues, "_", j));
77 }
78 input_names->push_back(PaddedBatchDatasetOp::kDropRemainder);
79 return Status::OK();
80 }
81
GetAttributes(AttributeVector * attr_vector) const82 Status GetAttributes(AttributeVector* attr_vector) const override {
83 *attr_vector = {
84 {PaddedBatchDatasetOp::kParallelCopy, parallel_copy_},
85 {PaddedBatchDatasetOp::kToutputTypes, output_dtypes_},
86 {PaddedBatchDatasetOp::kOutputShapes, output_shapes_},
87 {PaddedBatchDatasetOp::kNumPaddedShapes, num_padded_shapes_}};
88 return Status::OK();
89 }
90
dataset_type() const91 string dataset_type() const override {
92 return PaddedBatchDatasetOp::kDatasetType;
93 }
94
95 private:
96 int64 batch_size_;
97 std::vector<Tensor> padded_shapes_;
98 std::vector<Tensor> padded_values_;
99 bool drop_remainder_;
100 bool parallel_copy_;
101 int num_padded_shapes_;
102 };
103
104 // Test case 1: input elements with same shapes.
PaddedBatchDatasetParams1()105 PaddedBatchDatasetParams PaddedBatchDatasetParams1() {
106 auto tensor_slice_dataset_params = TensorSliceDatasetParams(
107 /*components=*/{CreateTensor<int64>(
108 TensorShape{7, 2}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13})},
109 /*node_name=*/"tensor_slice");
110 return PaddedBatchDatasetParams(
111 /*input_dataset_params=*/tensor_slice_dataset_params,
112 /*batch_size=*/2,
113 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {3})},
114 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
115 /*drop_remainder=*/true,
116 /*parallel_copy=*/true,
117 /*output_dtypes=*/{DT_INT64},
118 /*output_shapes=*/{PartialTensorShape({2, 3})},
119 /*num_padded_shapes=*/1,
120 /*node_name=*/kNodeName);
121 }
122
123 // Test case 2: input elements with different shapes.
PaddedBatchDatasetParams2()124 PaddedBatchDatasetParams PaddedBatchDatasetParams2() {
125 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
126 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
127 {{0, 1, 2, 3, 4, 5}}),
128 /*node_name=*/"tensor_slice_0");
129 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
130 /*components=*/CreateTensors<int64>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
131 /*node_name=*/"tensor_slice_1");
132 auto concatenate_dataset_params =
133 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
134 std::move(tensor_slice_dataset_params_1),
135 /*output_dtypes=*/{DT_INT64},
136 /*output_shapes=*/{PartialTensorShape({-1})},
137 /*node_name=*/"concatenate");
138 return PaddedBatchDatasetParams(
139 /*input_dataset_params=*/concatenate_dataset_params,
140 /*batch_size=*/2,
141 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {3})},
142 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
143 /*drop_remainder=*/true,
144 /*parallel_copy=*/true,
145 /*output_dtypes=*/{DT_INT64},
146 /*output_shapes=*/{PartialTensorShape({2, 3})},
147 /*num_padded_shapes=*/1,
148 /*node_name=*/kNodeName);
149 }
150
151 // Test case 3: similar with the test case 2 but drop_remainder = false.
PaddedBatchDatasetParams3()152 PaddedBatchDatasetParams PaddedBatchDatasetParams3() {
153 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
154 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
155 {{0, 1, 2, 3, 4, 5}}),
156 /*node_name=*/"tensor_slice_0");
157 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
158 /*components=*/CreateTensors<int64>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
159 /*node_name=*/"tensor_slice_1");
160 auto concatenate_dataset_params =
161 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
162 std::move(tensor_slice_dataset_params_1),
163 /*output_dtypes=*/{DT_INT64},
164 /*output_shapes=*/{PartialTensorShape({-1})},
165 /*node_name=*/"concatenate");
166 return PaddedBatchDatasetParams(
167 /*input_dataset_params=*/concatenate_dataset_params,
168 /*batch_size=*/2,
169 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {3})},
170 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
171 /*drop_remainder=*/false,
172 /*parallel_copy=*/true,
173 /*output_dtypes=*/{DT_INT64},
174 /*output_shapes=*/{PartialTensorShape({2, 3})},
175 /*num_padded_shapes=*/1,
176 /*node_name=*/kNodeName);
177 }
178
179 // Test case 4: similar with the test case 3 but the input elements can be
180 // divided by the batch size evenly. As drop_remainder = false, the output
181 // shape is still {-1, 3} instead of {2, 3}.
PaddedBatchDatasetParams4()182 PaddedBatchDatasetParams PaddedBatchDatasetParams4() {
183 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
184 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
185 {{0, 1, 2, 3, 4, 5}}),
186 /*node_name=*/"tensor_slice_0");
187 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
188 /*components=*/CreateTensors<int64>(TensorShape{3, 1}, {{6, 7, 8}}),
189 /*node_name=*/"tensor_slice_1");
190 auto concatenate_dataset_params =
191 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
192 std::move(tensor_slice_dataset_params_1),
193 /*output_dtypes=*/{DT_INT64},
194 /*output_shapes=*/{PartialTensorShape({-1})},
195 /*node_name=*/"concatenate");
196 return PaddedBatchDatasetParams(
197 /*input_dataset_params=*/concatenate_dataset_params,
198 /*batch_size=*/2,
199 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {3})},
200 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
201 /*drop_remainder=*/false,
202 /*parallel_copy=*/true,
203 /*output_dtypes=*/{DT_INT64},
204 /*output_shapes=*/{PartialTensorShape({-1, 3})},
205 /*num_padded_shapes=*/1,
206 /*node_name=*/kNodeName);
207 }
208
209 // Test case 5: similar with the test case 3 but padded_shapes = {-1}.
PaddedBatchDatasetParams5()210 PaddedBatchDatasetParams PaddedBatchDatasetParams5() {
211 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
212 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
213 {{0, 1, 2, 3, 4, 5}}),
214 /*node_name=*/"tensor_slice_0");
215 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
216 /*components=*/CreateTensors<int64>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
217 /*node_name=*/"tensor_slice_1");
218 auto concatenate_dataset_params =
219 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
220 std::move(tensor_slice_dataset_params_1),
221 /*output_dtypes=*/{DT_INT64},
222 /*output_shapes=*/{PartialTensorShape({-1})},
223 /*node_name=*/"concatenate");
224 return PaddedBatchDatasetParams(
225 /*input_dataset_params=*/concatenate_dataset_params,
226 /*batch_size=*/2,
227 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {-1})},
228 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
229 /*drop_remainder=*/false,
230 /*parallel_copy=*/false,
231 /*output_dtypes=*/{DT_INT64},
232 /*output_shapes=*/{PartialTensorShape({-1, -1})},
233 /*num_padded_shapes=*/1,
234 /*node_name=*/kNodeName);
235 }
236
237 // Test case 6: similar with the test case 5 but parallel_copy = true.
PaddedBatchDatasetParams6()238 PaddedBatchDatasetParams PaddedBatchDatasetParams6() {
239 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
240 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
241 {{0, 1, 2, 3, 4, 5}}),
242 /*node_name=*/"tensor_slice_0");
243 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
244 /*components=*/CreateTensors<int64>(TensorShape{4, 1}, {{6, 7, 8, 9}}),
245 /*node_name=*/"tensor_slice_1");
246 auto concatenate_dataset_params =
247 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
248 std::move(tensor_slice_dataset_params_1),
249 /*output_dtypes=*/{DT_INT64},
250 /*output_shapes=*/{PartialTensorShape({-1})},
251 /*node_name=*/"concatenate");
252 return PaddedBatchDatasetParams(
253 /*input_dataset_params=*/concatenate_dataset_params,
254 /*batch_size=*/2,
255 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {-1})},
256 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
257 /*drop_remainder=*/false,
258 /*parallel_copy=*/true,
259 /*output_dtypes=*/{DT_INT64},
260 /*output_shapes=*/{PartialTensorShape({-1, -1})},
261 /*num_padded_shapes=*/1,
262 /*node_name=*/kNodeName);
263 }
264
265 // Test case 7: empty input elements.
PaddedBatchDatasetParams7()266 PaddedBatchDatasetParams PaddedBatchDatasetParams7() {
267 return PaddedBatchDatasetParams(
268 /*input_dataset_params=*/RangeDatasetParams(0, 0, 1),
269 /*batch_size=*/2,
270 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {-1})},
271 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
272 /*drop_remainder=*/false,
273 /*parallel_copy=*/true,
274 /*output_dtypes=*/{DT_INT64},
275 /*output_shapes=*/{PartialTensorShape({-1, -1})},
276 /*num_padded_shapes=*/1,
277 /*node_name=*/kNodeName);
278 }
279
280 // Test case 8: short padding shape.
PaddedBatchDatasetParamsWithShortPaddingShape()281 PaddedBatchDatasetParams PaddedBatchDatasetParamsWithShortPaddingShape() {
282 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
283 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
284 {{0, 1, 2, 3, 4, 5}}),
285 /*node_name=*/"tensor_slice_0");
286 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
287 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
288 {{6, 7, 8, 9, 10, 11}}),
289 /*node_name=*/"tensor_slice_1");
290 auto concatenate_dataset_params =
291 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
292 std::move(tensor_slice_dataset_params_1),
293 /*output_dtypes=*/{DT_INT64},
294 /*output_shapes=*/{PartialTensorShape({2})},
295 /*node_name=*/"concatenate");
296 return PaddedBatchDatasetParams(
297 /*input_dataset_params=*/concatenate_dataset_params,
298 /*batch_size=*/2,
299 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {1})},
300 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
301 /*drop_remainder=*/false,
302 /*parallel_copy=*/true,
303 /*output_dtypes=*/{DT_INT64},
304 /*output_shapes=*/{PartialTensorShape({-1, -1})},
305 /*num_padded_shapes=*/1,
306 /*node_name=*/kNodeName);
307 }
308
PaddedBatchDatasetParamsWithInvalidPaddingShape()309 PaddedBatchDatasetParams PaddedBatchDatasetParamsWithInvalidPaddingShape() {
310 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
311 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
312 {{0, 1, 2, 3, 4, 5}}),
313 /*node_name=*/"tensor_slice_0");
314 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
315 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
316 {{6, 7, 8, 9, 10, 11}}),
317 /*node_name=*/"tensor_slice_1");
318 auto concatenate_dataset_params =
319 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
320 std::move(tensor_slice_dataset_params_1),
321 /*output_dtypes=*/{DT_INT64},
322 /*output_shapes=*/{PartialTensorShape({2})},
323 /*node_name=*/"concatenate");
324 return PaddedBatchDatasetParams(
325 /*input_dataset_params=*/concatenate_dataset_params,
326 /*batch_size=*/2,
327 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{2}, {1, 2})},
328 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
329 /*drop_remainder=*/false,
330 /*parallel_copy=*/true,
331 /*output_dtypes=*/{DT_INT64},
332 /*output_shapes=*/{PartialTensorShape({-1, -1})},
333 /*num_padded_shapes=*/1,
334 /*node_name=*/kNodeName);
335 }
336
PaddedBatchDatasetParamsWithInvalidBatchSize()337 PaddedBatchDatasetParams PaddedBatchDatasetParamsWithInvalidBatchSize() {
338 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
339 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
340 {{0, 1, 2, 3, 4, 5}}),
341 /*node_name=*/"tensor_slice_0");
342 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
343 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
344 {{6, 7, 8, 9, 10, 11}}),
345 /*node_name=*/"tensor_slice_1");
346 auto concatenate_dataset_params =
347 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
348 std::move(tensor_slice_dataset_params_1),
349 /*output_dtypes=*/{DT_INT64},
350 /*output_shapes=*/{PartialTensorShape({2})},
351 /*node_name=*/"concatenate");
352 return PaddedBatchDatasetParams(
353 /*input_dataset_params=*/concatenate_dataset_params,
354 /*batch_size=*/-1,
355 /*padded_shapes=*/{CreateTensor<int64>(TensorShape{1}, {3})},
356 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
357 /*drop_remainder=*/false,
358 /*parallel_copy=*/true,
359 /*output_dtypes=*/{DT_INT64},
360 /*output_shapes=*/{PartialTensorShape({-1, -1})},
361 /*num_padded_shapes=*/1,
362 /*node_name=*/kNodeName);
363 }
364
365 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingShapesSize()366 PaddedBatchDatasetParamsWithInvalidPaddingShapesSize() {
367 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
368 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
369 {{0, 1, 2, 3, 4, 5}}),
370 /*node_name=*/"tensor_slice_0");
371 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
372 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
373 {{6, 7, 8, 9, 10, 11}}),
374 /*node_name=*/"tensor_slice_1");
375 auto concatenate_dataset_params =
376 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
377 std::move(tensor_slice_dataset_params_1),
378 /*output_dtypes=*/{DT_INT64},
379 /*output_shapes=*/{PartialTensorShape({2})},
380 /*node_name=*/"concatenate");
381 return PaddedBatchDatasetParams(
382 /*input_dataset_params=*/concatenate_dataset_params,
383 /*batch_size=*/2,
384 /*padded_shapes=*/
385 {CreateTensor<int64>(TensorShape{1}, {3}),
386 CreateTensor<int64>(TensorShape{1}, {3})},
387 /*padded_values=*/{CreateTensor<int64>(TensorShape{}, {1})},
388 /*drop_remainder=*/false,
389 /*parallel_copy=*/true,
390 /*output_dtypes=*/{DT_INT64},
391 /*output_shapes=*/{PartialTensorShape({-1, -1})},
392 /*num_padded_shapes=*/2,
393 /*node_name=*/kNodeName);
394 }
395
396 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingValuesSize()397 PaddedBatchDatasetParamsWithInvalidPaddingValuesSize() {
398 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
399 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
400 {{0, 1, 2, 3, 4, 5}}),
401 /*node_name=*/"tensor_slice_0");
402 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
403 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
404 {{6, 7, 8, 9, 10, 11}}),
405 /*node_name=*/"tensor_slice_1");
406 auto concatenate_dataset_params =
407 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
408 std::move(tensor_slice_dataset_params_1),
409 /*output_dtypes=*/{DT_INT64},
410 /*output_shapes=*/{PartialTensorShape({2})},
411 /*node_name=*/"concatenate");
412 return PaddedBatchDatasetParams(
413 /*input_dataset_params=*/concatenate_dataset_params,
414 /*batch_size=*/2,
415 /*padded_shapes=*/
416 {CreateTensor<int64>(TensorShape{1}, {3})},
417 /*padded_values=*/
418 {CreateTensor<int64>(TensorShape{}, {1}),
419 CreateTensor<int64>(TensorShape{}, {1})},
420 /*drop_remainder=*/false,
421 /*parallel_copy=*/true,
422 /*output_dtypes=*/{DT_INT64},
423 /*output_shapes=*/{PartialTensorShape({-1, -1})},
424 /*num_padded_shapes=*/2,
425 /*node_name=*/kNodeName);
426 }
427
428 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingValuesDType()429 PaddedBatchDatasetParamsWithInvalidPaddingValuesDType() {
430 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
431 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
432 {{0, 1, 2, 3, 4, 5}}),
433 /*node_name=*/"tensor_slice_0");
434 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
435 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
436 {{6, 7, 8, 9, 10, 11}}),
437 /*node_name=*/"tensor_slice_1");
438 auto concatenate_dataset_params =
439 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
440 std::move(tensor_slice_dataset_params_1),
441 /*output_dtypes=*/{DT_INT64},
442 /*output_shapes=*/{PartialTensorShape({2})},
443 /*node_name=*/"concatenate");
444 return PaddedBatchDatasetParams(
445 /*input_dataset_params=*/concatenate_dataset_params,
446 /*batch_size=*/2,
447 /*padded_shapes=*/
448 {CreateTensor<int64>(TensorShape{1}, {3})},
449 /*padded_values=*/
450 {CreateTensor<tstring>(TensorShape{}, {"a"})},
451 /*drop_remainder=*/false,
452 /*parallel_copy=*/true,
453 /*output_dtypes=*/{DT_INT64},
454 /*output_shapes=*/{PartialTensorShape({-1, -1})},
455 /*num_padded_shapes=*/1,
456 /*node_name=*/kNodeName);
457 }
458
459 PaddedBatchDatasetParams
PaddedBatchDatasetParamsWithInvalidPaddingValuesShape()460 PaddedBatchDatasetParamsWithInvalidPaddingValuesShape() {
461 auto tensor_slice_dataset_params_0 = TensorSliceDatasetParams(
462 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
463 {{0, 1, 2, 3, 4, 5}}),
464 /*node_name=*/"tensor_slice_0");
465 auto tensor_slice_dataset_params_1 = TensorSliceDatasetParams(
466 /*components=*/CreateTensors<int64>(TensorShape{3, 2},
467 {{6, 7, 8, 9, 10, 11}}),
468 /*node_name=*/"tensor_slice_1");
469 auto concatenate_dataset_params =
470 ConcatenateDatasetParams(std::move(tensor_slice_dataset_params_0),
471 std::move(tensor_slice_dataset_params_1),
472 /*output_dtypes=*/{DT_INT64},
473 /*output_shapes=*/{PartialTensorShape({2})},
474 /*node_name=*/"concatenate");
475 return PaddedBatchDatasetParams(
476 /*input_dataset_params=*/concatenate_dataset_params,
477 /*batch_size=*/2,
478 /*padded_shapes=*/
479 {CreateTensor<int64>(TensorShape{1}, {3})},
480 /*padded_values=*/
481 {CreateTensor<int64>(TensorShape{1}, {1})},
482 /*drop_remainder=*/false,
483 /*parallel_copy=*/true,
484 /*output_dtypes=*/{DT_INT64},
485 /*output_shapes=*/{PartialTensorShape({-1, -1})},
486 /*num_padded_shapes=*/1,
487 /*node_name=*/kNodeName);
488 }
489
GetNextTestCases()490 std::vector<GetNextTestCase<PaddedBatchDatasetParams>> GetNextTestCases() {
491 return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
492 /*expected_outputs=*/
493 CreateTensors<int64>(
494 TensorShape{2, 3},
495 {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 7, 1}, {8, 9, 1, 10, 11, 1}})},
496 {/*dataset_params=*/PaddedBatchDatasetParams2(),
497 /*expected_outputs=*/
498 CreateTensors<int64>(
499 TensorShape{2, 3},
500 {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
501 {/*dataset_params=*/PaddedBatchDatasetParams3(),
502 /*expected_outputs=*/
503 {CreateTensor<int64>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
504 CreateTensor<int64>(TensorShape{2, 3}, {4, 5, 1, 6, 1, 1}),
505 CreateTensor<int64>(TensorShape{2, 3}, {7, 1, 1, 8, 1, 1}),
506 CreateTensor<int64>(TensorShape{1, 3}, {9, 1, 1})}},
507 {/*dataset_params=*/PaddedBatchDatasetParams4(),
508 /*expected_outputs=*/
509 CreateTensors<int64>(
510 TensorShape{2, 3},
511 {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
512 {/*dataset_params=*/PaddedBatchDatasetParams5(),
513 /*expected_outputs=*/
514 {CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
515 CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
516 CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
517 CreateTensor<int64>(TensorShape{1, 1}, {9})}},
518 {/*dataset_params=*/PaddedBatchDatasetParams6(),
519 /*expected_outputs=*/
520 {CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
521 CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
522 CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
523 CreateTensor<int64>(TensorShape{1, 1}, {9})}},
524 {/*dataset_params=*/PaddedBatchDatasetParams7(),
525 /*expected_outputs=*/{}}};
526 }
527
ITERATOR_GET_NEXT_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,GetNextTestCases ())528 ITERATOR_GET_NEXT_TEST_P(PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
529 GetNextTestCases())
530
531 TEST_F(PaddedBatchDatasetOpTest, DatasetNodeName) {
532 auto dataset_params = PaddedBatchDatasetParams1();
533 TF_ASSERT_OK(Initialize(dataset_params));
534 TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name()));
535 }
536
// The dataset type string must equal the op name derived from kDatasetType and
// the op version stored in the params.
TEST_F(PaddedBatchDatasetOpTest, DatasetTypeString) {
  auto test_params = PaddedBatchDatasetParams1();
  TF_ASSERT_OK(Initialize(test_params));

  name_utils::OpNameParams op_name_params;
  op_name_params.op_version = test_params.op_version();
  const string expected_type_string =
      name_utils::OpName(PaddedBatchDatasetOp::kDatasetType, op_name_params);
  TF_ASSERT_OK(CheckDatasetTypeString(expected_type_string));
}
545
546 std::vector<DatasetOutputDtypesTestCase<PaddedBatchDatasetParams>>
DatasetOutputDtypesTestCases()547 DatasetOutputDtypesTestCases() {
548 return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
549 /*expected_output_dtypes=*/{DT_INT64}},
550 {/*dataset_params=*/PaddedBatchDatasetParams2(),
551 /*expected_output_dtypes=*/{DT_INT64}},
552 {/*dataset_params=*/PaddedBatchDatasetParams3(),
553 /*expected_output_dtypes=*/{DT_INT64}},
554 {/*dataset_params=*/PaddedBatchDatasetParams4(),
555 /*expected_output_dtypes=*/{DT_INT64}},
556 {/*dataset_params=*/PaddedBatchDatasetParams5(),
557 /*expected_output_dtypes=*/{DT_INT64}},
558 {/*dataset_params=*/PaddedBatchDatasetParams6(),
559 /*expected_output_dtypes=*/{DT_INT64}},
560 {/*dataset_params=*/PaddedBatchDatasetParams7(),
561 /*expected_output_dtypes=*/{DT_INT64}}};
562 }
563
DATASET_OUTPUT_DTYPES_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,DatasetOutputDtypesTestCases ())564 DATASET_OUTPUT_DTYPES_TEST_P(PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
565 DatasetOutputDtypesTestCases())
566
567 std::vector<DatasetOutputShapesTestCase<PaddedBatchDatasetParams>>
568 DatasetOutputShapesTestCases() {
569 return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
570 /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
571 {/*dataset_params=*/PaddedBatchDatasetParams2(),
572 /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
573 {/*dataset_params=*/PaddedBatchDatasetParams3(),
574 /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
575 {/*dataset_params=*/PaddedBatchDatasetParams4(),
576 /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
577 {/*dataset_params=*/PaddedBatchDatasetParams5(),
578 /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
579 {/*dataset_params=*/PaddedBatchDatasetParams6(),
580 /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
581 {/*dataset_params=*/PaddedBatchDatasetParams7(),
582 /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}}};
583 }
584
DATASET_OUTPUT_SHAPES_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,DatasetOutputShapesTestCases ())585 DATASET_OUTPUT_SHAPES_TEST_P(PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
586 DatasetOutputShapesTestCases())
587
588 std::vector<CardinalityTestCase<PaddedBatchDatasetParams>>
589 CardinalityTestCases() {
590 return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
591 /*expected_cardinality=*/3},
592 {/*dataset_params=*/PaddedBatchDatasetParams2(),
593 /*expected_cardinality=*/3},
594 {/*dataset_params=*/PaddedBatchDatasetParams3(),
595 /*expected_cardinality=*/4},
596 {/*dataset_params=*/PaddedBatchDatasetParams4(),
597 /*expected_cardinality=*/3},
598 {/*dataset_params=*/PaddedBatchDatasetParams5(),
599 /*expected_cardinality=*/4},
600 {/*dataset_params=*/PaddedBatchDatasetParams6(),
601 /*expected_cardinality=*/4},
602 {/*dataset_params=*/PaddedBatchDatasetParams7(),
603 /*expected_cardinality=*/0}};
604 }
605
DATASET_CARDINALITY_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,CardinalityTestCases ())606 DATASET_CARDINALITY_TEST_P(PaddedBatchDatasetOpTest, PaddedBatchDatasetParams,
607 CardinalityTestCases())
608
609 std::vector<IteratorOutputDtypesTestCase<PaddedBatchDatasetParams>>
610 IteratorOutputDtypesTestCases() {
611 return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
612 /*expected_output_dtypes=*/{DT_INT64}},
613 {/*dataset_params=*/PaddedBatchDatasetParams2(),
614 /*expected_output_dtypes=*/{DT_INT64}},
615 {/*dataset_params=*/PaddedBatchDatasetParams3(),
616 /*expected_output_dtypes=*/{DT_INT64}},
617 {/*dataset_params=*/PaddedBatchDatasetParams4(),
618 /*expected_output_dtypes=*/{DT_INT64}},
619 {/*dataset_params=*/PaddedBatchDatasetParams5(),
620 /*expected_output_dtypes=*/{DT_INT64}},
621 {/*dataset_params=*/PaddedBatchDatasetParams6(),
622 /*expected_output_dtypes=*/{DT_INT64}},
623 {/*dataset_params=*/PaddedBatchDatasetParams7(),
624 /*expected_output_dtypes=*/{DT_INT64}}};
625 }
626
ITERATOR_OUTPUT_DTYPES_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,IteratorOutputDtypesTestCases ())627 ITERATOR_OUTPUT_DTYPES_TEST_P(PaddedBatchDatasetOpTest,
628 PaddedBatchDatasetParams,
629 IteratorOutputDtypesTestCases())
630
631 std::vector<IteratorOutputShapesTestCase<PaddedBatchDatasetParams>>
632 IteratorOutputShapesTestCases() {
633 return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
634 /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
635 {/*dataset_params=*/PaddedBatchDatasetParams2(),
636 /*expected_output_shapes=*/{PartialTensorShape({2, 3})}},
637 {/*dataset_params=*/PaddedBatchDatasetParams3(),
638 /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
639 {/*dataset_params=*/PaddedBatchDatasetParams4(),
640 /*expected_output_shapes=*/{PartialTensorShape({-1, 3})}},
641 {/*dataset_params=*/PaddedBatchDatasetParams5(),
642 /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
643 {/*dataset_params=*/PaddedBatchDatasetParams6(),
644 /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}},
645 {/*dataset_params=*/PaddedBatchDatasetParams7(),
646 /*expected_output_shapes=*/{PartialTensorShape({-1, -1})}}};
647 }
648
ITERATOR_OUTPUT_SHAPES_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,IteratorOutputShapesTestCases ())649 ITERATOR_OUTPUT_SHAPES_TEST_P(PaddedBatchDatasetOpTest,
650 PaddedBatchDatasetParams,
651 IteratorOutputShapesTestCases())
652
653 TEST_F(PaddedBatchDatasetOpTest, IteratorPrefix) {
654 auto dataset_params = PaddedBatchDatasetParams1();
655 TF_ASSERT_OK(Initialize(dataset_params));
656 name_utils::IteratorPrefixParams params;
657 params.op_version = dataset_params.op_version();
658 TF_ASSERT_OK(CheckIteratorPrefix(
659 name_utils::IteratorPrefix(PaddedBatchDatasetOp::kDatasetType,
660 dataset_params.iterator_prefix(), params)));
661 }
662
663 std::vector<IteratorSaveAndRestoreTestCase<PaddedBatchDatasetParams>>
IteratorSaveAndRestoreTestCases()664 IteratorSaveAndRestoreTestCases() {
665 return {{/*dataset_params=*/PaddedBatchDatasetParams1(),
666 /*breakpoints=*/{0, 2, 5},
667 /*expected_outputs=*/
668 CreateTensors<int64>(
669 TensorShape{2, 3},
670 {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 7, 1}, {8, 9, 1, 10, 11, 1}})},
671 {/*dataset_params=*/PaddedBatchDatasetParams2(),
672 /*breakpoints=*/{0, 2, 5},
673 /*expected_outputs=*/
674 CreateTensors<int64>(
675 TensorShape{2, 3},
676 {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
677 {/*dataset_params=*/PaddedBatchDatasetParams3(),
678 /*breakpoints=*/{0, 2, 5},
679 /*expected_outputs=*/
680 {CreateTensor<int64>(TensorShape{2, 3}, {0, 1, 1, 2, 3, 1}),
681 CreateTensor<int64>(TensorShape{2, 3}, {4, 5, 1, 6, 1, 1}),
682 CreateTensor<int64>(TensorShape{2, 3}, {7, 1, 1, 8, 1, 1}),
683 CreateTensor<int64>(TensorShape{1, 3}, {9, 1, 1})}},
684 {/*dataset_params=*/PaddedBatchDatasetParams4(),
685 /*breakpoints=*/{0, 2, 5},
686 /*expected_outputs=*/
687 CreateTensors<int64>(
688 TensorShape{2, 3},
689 {{0, 1, 1, 2, 3, 1}, {4, 5, 1, 6, 1, 1}, {7, 1, 1, 8, 1, 1}})},
690 {/*dataset_params=*/PaddedBatchDatasetParams5(),
691 /*breakpoints=*/{0, 2, 5},
692 /*expected_outputs=*/
693 {CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
694 CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
695 CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
696 CreateTensor<int64>(TensorShape{1, 1}, {9})}},
697 {/*dataset_params=*/PaddedBatchDatasetParams6(),
698 /*breakpoints=*/{0, 2, 5},
699 /*expected_outputs=*/
700 {CreateTensor<int64>(TensorShape{2, 2}, {0, 1, 2, 3}),
701 CreateTensor<int64>(TensorShape{2, 2}, {4, 5, 6, 1}),
702 CreateTensor<int64>(TensorShape{2, 1}, {7, 8}),
703 CreateTensor<int64>(TensorShape{1, 1}, {9})}},
704 {/*dataset_params=*/PaddedBatchDatasetParams7(),
705 /*breakpoints=*/{0, 2, 5},
706 /*expected_outputs=*/{}}};
707 }
708
ITERATOR_SAVE_AND_RESTORE_TEST_P(PaddedBatchDatasetOpTest,PaddedBatchDatasetParams,IteratorSaveAndRestoreTestCases ())709 ITERATOR_SAVE_AND_RESTORE_TEST_P(PaddedBatchDatasetOpTest,
710 PaddedBatchDatasetParams,
711 IteratorSaveAndRestoreTestCases())
712
713 TEST_F(PaddedBatchDatasetOpTest, ShortPadding) {
714 auto dataset_params = PaddedBatchDatasetParamsWithShortPaddingShape();
715 TF_ASSERT_OK(Initialize(dataset_params));
716 bool end_of_sequence = false;
717 std::vector<Tensor> out_tensors;
718 EXPECT_EQ(
719 iterator_->GetNext(iterator_ctx_.get(), &out_tensors, &end_of_sequence)
720 .code(),
721 tensorflow::error::DATA_LOSS);
722 }
723
// With a padded-shape tensor whose length does not match the input elements'
// rank, initialization succeeds but the first GetNext call is expected to
// report INVALID_ARGUMENT.
TEST_F(PaddedBatchDatasetOpTest, InvalidPaddedShapes) {
  auto test_params = PaddedBatchDatasetParamsWithInvalidPaddingShape();
  TF_ASSERT_OK(Initialize(test_params));

  std::vector<Tensor> out_tensors;
  bool end_of_sequence = false;
  const Status next_status =
      iterator_->GetNext(iterator_ctx_.get(), &out_tensors, &end_of_sequence);
  EXPECT_EQ(next_status.code(), tensorflow::error::INVALID_ARGUMENT);
}
734
735 class ParameterizedInvalidArgumentTest
736 : public PaddedBatchDatasetOpTest,
737 public ::testing::WithParamInterface<PaddedBatchDatasetParams> {};
738
// Initialize must reject every parameter set supplied to this suite with
// INVALID_ARGUMENT.
// NOTE(review): the test name mentions a "predicate func", but nothing here
// involves one; the name is kept so existing test filters keep working.
TEST_P(ParameterizedInvalidArgumentTest, InvalidPredicateFunc) {
  PaddedBatchDatasetParams invalid_params = GetParam();
  const Status init_status = Initialize(invalid_params);
  EXPECT_EQ(init_status.code(), tensorflow::error::INVALID_ARGUMENT);
}
744
745 INSTANTIATE_TEST_SUITE_P(
746 PaddedBatchDatasetOpTest, ParameterizedInvalidArgumentTest,
747 ::testing::ValuesIn(
748 {PaddedBatchDatasetParamsWithInvalidBatchSize(),
749 PaddedBatchDatasetParamsWithInvalidPaddingShapesSize(),
750 PaddedBatchDatasetParamsWithInvalidPaddingValuesSize(),
751 PaddedBatchDatasetParamsWithInvalidPaddingValuesDType(),
752 PaddedBatchDatasetParamsWithInvalidPaddingValuesShape()}));
753
754 } // namespace
755 } // namespace data
756 } // namespace tensorflow
757