1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/kernels/data/tensor_slice_dataset_op.h"
16
17 #include "tensorflow/core/kernels/data/dataset_test_base.h"
18 #include "tensorflow/core/kernels/data/dataset_utils.h"
19
20 namespace tensorflow {
21 namespace data {
22 namespace {
23
24 constexpr char kNodeName[] = "tensor_slice_dataset";
25
26 class TensorSliceDatasetOpTest : public DatasetOpsTestBase {};
27
PlainTensorSliceDatasetParams()28 TensorSliceDatasetParams PlainTensorSliceDatasetParams() {
29 std::vector<Tensor> components = {
30 CreateTensor<int64>(TensorShape({2}), {1, 2}),
31 CreateTensor<int64>(TensorShape({2, 2}), {1, 2, 3, 4}),
32 CreateTensor<uint32>(TensorShape({2}), {2, 3}),
33 CreateTensor<uint32>(TensorShape({2, 2}), {2, 3, 4, 5}),
34 CreateTensor<uint64>(TensorShape({2}), {3, 4}),
35 CreateTensor<uint64>(TensorShape({2, 2}), {3, 4, 5, 6}),
36 CreateTensor<double>(TensorShape({2, 1}), {37.0, 38.0}),
37 CreateTensor<tstring>(TensorShape({2, 1}), {"a", "b"})};
38
39 return {std::move(components), kNodeName};
40 }
41
NestedTensorSliceDatasetParams()42 TensorSliceDatasetParams NestedTensorSliceDatasetParams() {
43 std::vector<Tensor> components = {
44 CreateTensor<Variant>(
45 TensorShape({2, 1}),
46 {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0}),
47 CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
48 CreateTensor<Variant>(
49 TensorShape({2, 1}),
50 {CreateTensor<tstring>(TensorShape({1, 2}), {"a", "b"}),
51 CreateTensor<tstring>(TensorShape({1, 2}), {"c", "d"})}),
52 CreateTensor<int64>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})};
53
54 return {std::move(components), kNodeName};
55 }
56
GetNextTestCases()57 std::vector<GetNextTestCase<TensorSliceDatasetParams>> GetNextTestCases() {
58 return {
59 {/*dataset_params=*/PlainTensorSliceDatasetParams(),
60 /*expected_outputs=*/{CreateTensor<int64>(TensorShape({}), {1}),
61 CreateTensor<int64>(TensorShape({2}), {1, 2}),
62 CreateTensor<uint32>(TensorShape({}), {2}),
63 CreateTensor<uint32>(TensorShape({2}), {2, 3}),
64 CreateTensor<uint64>(TensorShape({}), {3}),
65 CreateTensor<uint64>(TensorShape({2}), {3, 4}),
66 CreateTensor<double>(TensorShape({1}), {37.0}),
67 CreateTensor<tstring>(TensorShape({1}), {"a"}),
68 CreateTensor<int64>(TensorShape({}), {2}),
69 CreateTensor<int64>(TensorShape({2}), {3, 4}),
70 CreateTensor<uint32>(TensorShape({}), {3}),
71 CreateTensor<uint32>(TensorShape({2}), {4, 5}),
72 CreateTensor<uint64>(TensorShape({}), {4}),
73 CreateTensor<uint64>(TensorShape({2}), {5, 6}),
74 CreateTensor<double>(TensorShape({1}), {38.0}),
75 CreateTensor<tstring>(TensorShape({1}), {"b"})}},
76 {/*dataset_params=*/NestedTensorSliceDatasetParams(),
77 /*expected_outputs=*/
78 {CreateTensor<Variant>(
79 TensorShape({1}),
80 {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
81 CreateTensor<Variant>(
82 TensorShape({1}),
83 {CreateTensor<tstring>(TensorShape({1, 2}), {"a", "b"})}),
84 CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
85 CreateTensor<Variant>(
86 TensorShape({1}),
87 {CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
88 CreateTensor<Variant>(
89 TensorShape({1}),
90 {CreateTensor<tstring>(TensorShape({1, 2}), {"c", "d"})}),
91 CreateTensor<int64>(TensorShape({3}), {4, 5, 6})}}};
92 }
93
94 class ParameterizedGetNextTest
95 : public TensorSliceDatasetOpTest,
96 public ::testing::WithParamInterface<
97 GetNextTestCase<TensorSliceDatasetParams>> {};
98
TEST_P(ParameterizedGetNextTest,GetNext)99 TEST_P(ParameterizedGetNextTest, GetNext) {
100 auto test_case = GetParam();
101 TF_ASSERT_OK(Initialize(test_case.dataset_params));
102
103 std::vector<string> input_names;
104 TF_ASSERT_OK(test_case.dataset_params.GetInputNames(&input_names));
105 size_t num_tensors_per_slice = input_names.size();
106 bool end_of_sequence = false;
107 std::vector<Tensor> out_tensors;
108 int cur_slice = 0;
109
110 while (!end_of_sequence) {
111 TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
112 &end_of_sequence));
113 for (int i = 0; i < out_tensors.size(); ++i) {
114 EXPECT_LT(i + num_tensors_per_slice * cur_slice,
115 test_case.expected_outputs.size());
116 if (out_tensors[i].dtype() == DT_VARIANT) {
117 // Currently `ExpectEqual()` does not support the variant tensor
118 // yet, so we manually cast the variant to numeric/string tensor.
119 const Tensor* output = out_tensors[i].scalar<Variant>()().get<Tensor>();
120 const Tensor* expected_output =
121 test_case.expected_outputs[i + num_tensors_per_slice * cur_slice]
122 .scalar<Variant>()()
123 .get<Tensor>();
124 TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
125 } else {
126 TF_EXPECT_OK(ExpectEqual(
127 out_tensors[i],
128 test_case.expected_outputs[i + num_tensors_per_slice * cur_slice]));
129 }
130 }
131 out_tensors.clear();
132 cur_slice++;
133 }
134 }
135
136 INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest, ParameterizedGetNextTest,
137 ::testing::ValuesIn(GetNextTestCases()));
138
// The initialized dataset must report the node name it was built with.
TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) {
  TensorSliceDatasetParams params = PlainTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(params));
  TF_ASSERT_OK(CheckDatasetNodeName(params.node_name()));
}
144
// The dataset's type string must be the op name derived from
// TensorSliceDatasetOp::kDatasetType.
TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) {
  TensorSliceDatasetParams params = PlainTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(params));
  const string expected_type_string =
      name_utils::OpName(TensorSliceDatasetOp::kDatasetType);
  TF_ASSERT_OK(CheckDatasetTypeString(expected_type_string));
}
151
152 std::vector<DatasetOutputDtypesTestCase<TensorSliceDatasetParams>>
DatasetOutputTypesTestCases()153 DatasetOutputTypesTestCases() {
154 return {{PlainTensorSliceDatasetParams(),
155 PlainTensorSliceDatasetParams().output_dtypes()},
156 {NestedTensorSliceDatasetParams(),
157 NestedTensorSliceDatasetParams().output_dtypes()}};
158 }
159
DATASET_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,DatasetOutputTypesTestCases ())160 DATASET_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest, TensorSliceDatasetParams,
161 DatasetOutputTypesTestCases())
162
163 std::vector<DatasetOutputShapesTestCase<TensorSliceDatasetParams>>
164 DatasetOutputShapesTestCases() {
165 return {{PlainTensorSliceDatasetParams(),
166 PlainTensorSliceDatasetParams().output_shapes()},
167 {NestedTensorSliceDatasetParams(),
168 NestedTensorSliceDatasetParams().output_shapes()}};
169 }
170
DATASET_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,DatasetOutputShapesTestCases ())171 DATASET_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest, TensorSliceDatasetParams,
172 DatasetOutputShapesTestCases())
173
174 std::vector<CardinalityTestCase<TensorSliceDatasetParams>>
175 DatasetCardinalityTestCases() {
176 return {{PlainTensorSliceDatasetParams(), /*expected_cardinality=*/2},
177 {NestedTensorSliceDatasetParams(), /*expected_cardinality=*/2}};
178 }
179
DATASET_CARDINALITY_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,DatasetCardinalityTestCases ())180 DATASET_CARDINALITY_TEST_P(TensorSliceDatasetOpTest, TensorSliceDatasetParams,
181 DatasetCardinalityTestCases())
182
183 std::vector<IteratorOutputDtypesTestCase<TensorSliceDatasetParams>>
184 IteratorOutputTypesTestCases() {
185 return {{PlainTensorSliceDatasetParams(),
186 PlainTensorSliceDatasetParams().output_dtypes()},
187 {NestedTensorSliceDatasetParams(),
188 NestedTensorSliceDatasetParams().output_dtypes()}};
189 }
190
ITERATOR_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,IteratorOutputTypesTestCases ())191 ITERATOR_OUTPUT_DTYPES_TEST_P(TensorSliceDatasetOpTest,
192 TensorSliceDatasetParams,
193 IteratorOutputTypesTestCases())
194
195 std::vector<IteratorOutputShapesTestCase<TensorSliceDatasetParams>>
196 IteratorOutputShapesTestCases() {
197 return {{PlainTensorSliceDatasetParams(),
198 PlainTensorSliceDatasetParams().output_shapes()},
199 {NestedTensorSliceDatasetParams(),
200 NestedTensorSliceDatasetParams().output_shapes()}};
201 }
202
ITERATOR_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest,TensorSliceDatasetParams,IteratorOutputShapesTestCases ())203 ITERATOR_OUTPUT_SHAPES_TEST_P(TensorSliceDatasetOpTest,
204 TensorSliceDatasetParams,
205 IteratorOutputShapesTestCases())
206
207 TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) {
208 auto dataset_params = PlainTensorSliceDatasetParams();
209 TF_ASSERT_OK(Initialize(dataset_params));
210 TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
211 TensorSliceDatasetOp::kDatasetType, dataset_params.iterator_prefix())));
212 }
213
214 std::vector<IteratorSaveAndRestoreTestCase<TensorSliceDatasetParams>>
IteratorSaveAndRestoreTestCases()215 IteratorSaveAndRestoreTestCases() {
216 return {
217 {/*dataset_params=*/PlainTensorSliceDatasetParams(),
218 /*breakpoints=*/{0, 1, 2},
219 /*expected_outputs=*/
220 {CreateTensor<int64>(TensorShape({}), {1}),
221 CreateTensor<int64>(TensorShape({2}), {1, 2}),
222 CreateTensor<uint32>(TensorShape({}), {2}),
223 CreateTensor<uint32>(TensorShape({2}), {2, 3}),
224 CreateTensor<uint64>(TensorShape({}), {3}),
225 CreateTensor<uint64>(TensorShape({2}), {3, 4}),
226 CreateTensor<double>(TensorShape({1}), {37.0}),
227 CreateTensor<tstring>(TensorShape({1}), {"a"}),
228 CreateTensor<int64>(TensorShape({}), {2}),
229 CreateTensor<int64>(TensorShape({2}), {3, 4}),
230 CreateTensor<uint32>(TensorShape({}), {3}),
231 CreateTensor<uint32>(TensorShape({2}), {4, 5}),
232 CreateTensor<uint64>(TensorShape({}), {4}),
233 CreateTensor<uint64>(TensorShape({2}), {5, 6}),
234 CreateTensor<double>(TensorShape({1}), {38.0}),
235 CreateTensor<tstring>(TensorShape({1}), {"b"})}},
236 {/*dataset_params=*/NestedTensorSliceDatasetParams(),
237 /*breakpoints=*/{0, 1, 2},
238 /*expected_outputs=*/
239 {CreateTensor<Variant>(
240 TensorShape({1}),
241 {CreateTensor<double>(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}),
242 CreateTensor<Variant>(
243 TensorShape({1}),
244 {CreateTensor<tstring>(TensorShape({1, 2}), {"a", "b"})}),
245 CreateTensor<int64>(TensorShape({3}), {1, 2, 3}),
246 CreateTensor<Variant>(
247 TensorShape({1}),
248 {CreateTensor<double>(TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}),
249 CreateTensor<Variant>(
250 TensorShape({1}),
251 {CreateTensor<tstring>(TensorShape({1, 2}), {"c", "d"})}),
252 CreateTensor<int64>(TensorShape({3}), {4, 5, 6})}}};
253 }
254
255 class ParameterizedIteratorSaveAndRestoreTest
256 : public TensorSliceDatasetOpTest,
257 public ::testing::WithParamInterface<
258 IteratorSaveAndRestoreTestCase<TensorSliceDatasetParams>> {};
259
TEST_P(ParameterizedIteratorSaveAndRestoreTest,SaveAndRestore)260 TEST_P(ParameterizedIteratorSaveAndRestoreTest, SaveAndRestore) {
261 auto test_case = GetParam();
262 TF_ASSERT_OK(Initialize(test_case.dataset_params));
263
264 std::unique_ptr<SerializationContext> serialization_context;
265 TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
266
267 int cur_iteration = 0;
268 bool end_of_sequence = false;
269
270 auto params =
271 static_cast<TensorSliceDatasetParams&>(test_case.dataset_params);
272 int64 num_slices = params.num_slices();
273 size_t num_tensors_per_slice = params.num_tensors_per_slice();
274 std::vector<Tensor> out_tensors;
275 const std::vector<int>& breakpoints = test_case.breakpoints;
276 for (int breakpoint : breakpoints) {
277 while (cur_iteration < breakpoint) {
278 TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
279 &end_of_sequence));
280 cur_iteration++;
281 }
282
283 if (breakpoint == 0) {
284 EXPECT_FALSE(end_of_sequence);
285 } else if (breakpoint <= num_slices) {
286 for (int i = 0; i < out_tensors.size(); ++i) {
287 if (out_tensors[i].dtype() == DT_VARIANT) {
288 const Tensor* output =
289 out_tensors[i].scalar<Variant>()().get<Tensor>();
290 const Tensor* expected_output =
291 test_case
292 .expected_outputs[i +
293 num_tensors_per_slice * (cur_iteration - 1)]
294 .scalar<Variant>()()
295 .get<Tensor>();
296 TF_EXPECT_OK(ExpectEqual(*output, *expected_output));
297 } else {
298 TF_EXPECT_OK(ExpectEqual(
299 out_tensors[i],
300 test_case.expected_outputs[i + num_tensors_per_slice *
301 (cur_iteration - 1)]));
302 }
303 }
304 } else {
305 EXPECT_TRUE(end_of_sequence);
306 }
307
308 VariantTensorDataWriter writer;
309 TF_ASSERT_OK(iterator_->Save(serialization_context.get(), &writer));
310 std::vector<const VariantTensorData*> data;
311 writer.GetData(&data);
312 VariantTensorDataReader reader(data);
313 TF_EXPECT_OK(RestoreIterator(iterator_ctx_.get(), &reader, "Iterator",
314 *dataset_, &iterator_));
315 }
316 }
317
318 INSTANTIATE_TEST_SUITE_P(
319 TensorSliceDatasetOpTest, ParameterizedIteratorSaveAndRestoreTest,
320 ::testing::ValuesIn(IteratorSaveAndRestoreTestCases()));
321
// A 7-element vector is split into scalar elements. A full pass yields all
// seven in order. Shard 1 of 3 is expected to yield the values at positions
// 1 and 4 (values 2 and 7).
TEST_F(TensorSliceDatasetOpTest, SplitProvider) {
  TensorSliceDatasetParams params(
      CreateTensors<int64>(TensorShape({7}), {{6, 2, 3, 8, 7, 0, 10}}),
      kNodeName);
  TF_ASSERT_OK(InitializeRuntime(params));

  std::vector<Tensor> full_expected = CreateTensors<int64>(
      TensorShape({}), {{6}, {2}, {3}, {8}, {7}, {0}, {10}});
  TF_EXPECT_OK(CheckSplitProviderFullIteration(params, full_expected));

  std::vector<Tensor> shard_expected =
      CreateTensors<int64>(TensorShape({}), {{2}, {7}});
  TF_EXPECT_OK(CheckSplitProviderShardedIteration(
      params, /*num_shards=*/3, /*shard_index=*/1, shard_expected));
}
334
// A zero-length component produces no splits, both for a full pass and for a
// single shard.
TEST_F(TensorSliceDatasetOpTest, SplitProviderEmpty) {
  TensorSliceDatasetParams params(CreateTensors<int64>(TensorShape({0}), {{}}),
                                  kNodeName);
  TF_ASSERT_OK(InitializeRuntime(params));

  std::vector<Tensor> no_outputs = CreateTensors<int64>(TensorShape({}), {});
  TF_EXPECT_OK(CheckSplitProviderFullIteration(params, no_outputs));
  TF_EXPECT_OK(CheckSplitProviderShardedIteration(
      params, /*num_shards=*/3, /*shard_index=*/1, no_outputs));
}
345
346 } // namespace
347 } // namespace data
348 } // namespace tensorflow
349