1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17 #include <utility>
18
19 #include "tensorflow/core/data/dataset_test_base.h"
20 #include "tensorflow/core/data/dataset_utils.h"
21 #include "tensorflow/core/data/serialization_utils.h"
22
23 namespace tensorflow {
24 namespace data {
25 namespace {
26
27 constexpr char kNodeName[] = "sparse_tensor_slice_dataset";
28 constexpr char kDatasetType[] = "SparseTensorSlice";
29
30 class SparseTensorSliceDatasetParams : public DatasetParams {
31 public:
SparseTensorSliceDatasetParams(Tensor indices,Tensor values,Tensor dense_shape,DataType tvalues,string node_name)32 SparseTensorSliceDatasetParams(Tensor indices, Tensor values,
33 Tensor dense_shape, DataType tvalues,
34 string node_name)
35 : DatasetParams({tvalues}, {PartialTensorShape({})},
36 std::move(node_name)),
37 indices_(std::move(indices)),
38 values_(std::move(values)),
39 dense_shape_(std::move(dense_shape)),
40 tvalues_(tvalues) {
41 iterator_prefix_ = "Iterator";
42 }
43
GetInputTensors() const44 std::vector<Tensor> GetInputTensors() const override {
45 return {indices_, values_, dense_shape_};
46 }
47
GetInputNames(std::vector<string> * input_names) const48 Status GetInputNames(std::vector<string>* input_names) const override {
49 input_names->clear();
50 input_names->emplace_back("indices");
51 input_names->emplace_back("values");
52 input_names->emplace_back("dense_shape");
53 return Status::OK();
54 }
55
GetAttributes(AttributeVector * attr_vector) const56 Status GetAttributes(AttributeVector* attr_vector) const override {
57 attr_vector->clear();
58 attr_vector->emplace_back("Tvalues", tvalues_);
59 return Status::OK();
60 }
61
dataset_type() const62 string dataset_type() const override { return kDatasetType; }
63
64 private:
65 Tensor indices_;
66 Tensor values_;
67 Tensor dense_shape_;
68 DataType tvalues_;
69 };
70
71 class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase {};
72
TwoDimsSparseTensorSliceDatasetParams()73 SparseTensorSliceDatasetParams TwoDimsSparseTensorSliceDatasetParams() {
74 return SparseTensorSliceDatasetParams(
75 /*indices=*/CreateTensor<int64>({2, 2}, {0, 0, 1, 1}),
76 /*values=*/CreateTensor<int32>({2}, {888, 999}),
77 /*dense_shape=*/CreateTensor<int64>({2}, {2, 2}),
78 /*tvalues=*/DT_INT32,
79 /*node_name=*/kNodeName);
80 }
81
ThreeDimsSparseTensorSliceDatasetParams()82 SparseTensorSliceDatasetParams ThreeDimsSparseTensorSliceDatasetParams() {
83 return SparseTensorSliceDatasetParams(
84 /*indices=*/CreateTensor<int64>({2, 3}, {0, 0, 0, 1, 1, 1}),
85 /*values=*/CreateTensor<double>({2}, {888.0, 999.0}),
86 /*dense_shape=*/CreateTensor<int64>({3}, {2, 2, 2}),
87 /*tvalues=*/DT_DOUBLE,
88 /*node_name=*/kNodeName);
89 }
90
FourDimsSparseTensorSliceDatasetParams()91 SparseTensorSliceDatasetParams FourDimsSparseTensorSliceDatasetParams() {
92 return SparseTensorSliceDatasetParams(
93 /*indices=*/CreateTensor<int64>({2, 4}, {0, 0, 0, 0, 1, 1, 1, 1}),
94 /*values=*/CreateTensor<tstring>({2}, {"a", "b"}),
95 /*dense_shape=*/CreateTensor<int64>({4}, {3, 2, 2, 2}),
96 /*tvalues=*/DT_STRING,
97 /*node_name=*/kNodeName);
98 }
99
FiveDimsSparseTensorSliceDatasetParams()100 SparseTensorSliceDatasetParams FiveDimsSparseTensorSliceDatasetParams() {
101 return SparseTensorSliceDatasetParams(
102 /*indices=*/CreateTensor<int64>({2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}),
103 /*values=*/CreateTensor<int32>({2}, {888, 999}),
104 /*dense_shape=*/CreateTensor<int64>({5}, {3, 2, 2, 2, 2}),
105 /*tvalues=*/DT_INT32,
106 /*node_name=*/kNodeName);
107 }
108
109 template <typename T>
110 struct GetNextTestCase {
111 T dataset_params;
112 std::vector<std::vector<Tensor>> expected_outputs;
113 };
114
115 std::vector<GetNextTestCase<SparseTensorSliceDatasetParams>>
GetNextTestCases()116 GetNextTestCases() {
117 return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
118 /*expected_outputs=*/
119 {{/*indices*/ CreateTensor<int64>({1, 1}, {0}),
120 /*values*/ CreateTensor<int32>({1}, {888}),
121 /*dense_shape*/ CreateTensor<int64>({1}, {2})},
122 {/*indices*/ CreateTensor<int64>({1, 1}, {1}),
123 /*values*/ CreateTensor<int32>({1}, {999}),
124 /*dense_shape*/ CreateTensor<int64>({1}, {2})}}},
125 {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
126 /*expected_outputs=*/
127 {{/*indices*/ CreateTensor<int64>({1, 2}, {0, 0}),
128 /*values*/ CreateTensor<double>({1}, {888.0}),
129 /*dense_shape*/ CreateTensor<int64>({2}, {2, 2})},
130 {{/*indices*/ CreateTensor<int64>({1, 2}, {1, 1})},
131 {/*values*/ CreateTensor<double>({1}, {999.0})},
132 {/*dense_shape*/ CreateTensor<int64>({2}, {2, 2})}}}},
133 {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
134 /*expected_outputs=*/
135 {{/*indices*/ CreateTensor<int64>({1, 3}, {0, 0, 0}),
136 /*values*/ CreateTensor<tstring>({1}, {"a"}),
137 /*dense_shape*/
138 CreateTensor<int64>({3}, {2, 2, 2})},
139 {/*indices*/ CreateTensor<int64>({1, 3}, {1, 1, 1}),
140 /*values*/ CreateTensor<tstring>({1}, {"b"}),
141 /*dense_shape*/
142 CreateTensor<int64>({3}, {2, 2, 2})},
143 {/*indices*/ CreateTensor<int64>({0, 3}, {}),
144 /*values*/ CreateTensor<tstring>({0}, {}),
145 /*dense_shape*/
146 CreateTensor<int64>({3}, {2, 2, 2})}}},
147 {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
148 /*expected_outputs=*/{
149 {/*indices*/ CreateTensor<int64>({1, 4}, {0, 0, 0, 0}),
150 /*values*/ CreateTensor<int32>({1}, {888}),
151 /*dense_shape*/
152 CreateTensor<int64>({4}, {2, 2, 2, 2})},
153 {/*indices*/ CreateTensor<int64>({1, 4}, {1, 1, 1, 1}),
154 /*values*/ CreateTensor<int32>({1}, {999}),
155 /*dense_shape*/
156 CreateTensor<int64>({4}, {2, 2, 2, 2})},
157 {/*indices*/ CreateTensor<int64>({0, 4}, {}),
158 /*values*/ CreateTensor<int32>({0}, {}),
159 /*dense_shape*/
160 CreateTensor<int64>({4}, {2, 2, 2, 2})}}}};
161 }
162
163 class ParameterizedGetNextTest
164 : public SparseTensorSliceDatasetOpTest,
165 public ::testing::WithParamInterface<
166 GetNextTestCase<SparseTensorSliceDatasetParams>> {};
167
// Drains the iterator and checks every produced element, component by
// component, against the expected outputs, then checks that exactly as many
// elements were produced as expected.
TEST_P(ParameterizedGetNextTest, GetNext) {
  auto test_case = GetParam();
  TF_ASSERT_OK(Initialize(test_case.dataset_params));

  const auto& expected_outputs = test_case.expected_outputs;
  auto expected_outputs_it = expected_outputs.begin();
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  while (true) {
    // Fatal on error: a failing GetNext may never set end_of_sequence, which
    // would otherwise spin this loop forever.
    TF_ASSERT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                    &end_of_sequence));
    if (end_of_sequence) break;
    // Stop before dereferencing past the end if the iterator yields more
    // elements than expected.
    ASSERT_NE(expected_outputs_it, expected_outputs.end());
    ASSERT_EQ(out_tensors.size(), expected_outputs_it->size());
    for (size_t i = 0; i < out_tensors.size(); ++i) {
      TF_EXPECT_OK(ExpectEqual(out_tensors[i], (*expected_outputs_it)[i]));
    }
    ++expected_outputs_it;
  }
  EXPECT_EQ(expected_outputs_it, expected_outputs.end());
}
187
188 INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest,
189 ParameterizedGetNextTest,
190 ::testing::ValuesIn(GetNextTestCases()));
191
// The dataset reports the op name derived from kDatasetType.
TEST_F(SparseTensorSliceDatasetOpTest, DatasetTypeString) {
  auto params = TwoDimsSparseTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(params));
  const string expected_type_string = name_utils::OpName(kDatasetType);
  TF_ASSERT_OK(CheckDatasetTypeString(expected_type_string));
}
197
// The dataset node carries the name supplied in its params.
TEST_F(SparseTensorSliceDatasetOpTest, DatasetNodeName) {
  auto params = TwoDimsSparseTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(params));
  const string expected_node_name = params.node_name();
  TF_ASSERT_OK(CheckDatasetNodeName(expected_node_name));
}
203
204 std::vector<DatasetOutputDtypesTestCase<SparseTensorSliceDatasetParams>>
DatasetOutputDtypesTestCases()205 DatasetOutputDtypesTestCases() {
206 return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
207 /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}},
208 {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
209 /*expected_output_dtypes=*/{DT_INT64, DT_DOUBLE, DT_INT64}},
210 {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
211 /*expected_output_dtypes=*/{DT_INT64, DT_STRING, DT_INT64}},
212 {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
213 /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}}};
214 }
215
DATASET_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,DatasetOutputDtypesTestCases ())216 DATASET_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,
217 SparseTensorSliceDatasetParams,
218 DatasetOutputDtypesTestCases())
219
220 std::vector<DatasetOutputShapesTestCase<SparseTensorSliceDatasetParams>>
221 DatasetOutputShapesTestCases() {
222 return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
223 /*expected_output_shapes=*/{PartialTensorShape({1, 1}),
224 PartialTensorShape({1}),
225 PartialTensorShape({1})}},
226 {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
227 /*expected_output_shapes=*/{PartialTensorShape({1, 2}),
228 PartialTensorShape({1}),
229 PartialTensorShape({2})}},
230 {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
231 /*expected_output_shapes=*/{PartialTensorShape({1, 3}),
232 PartialTensorShape({1}),
233 PartialTensorShape({3})}},
234 {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
235 /*expected_output_shapes=*/{PartialTensorShape({1, 4}),
236 PartialTensorShape({1}),
237 PartialTensorShape({4})}}};
238 }
239
DATASET_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,DatasetOutputShapesTestCases ())240 DATASET_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,
241 SparseTensorSliceDatasetParams,
242 DatasetOutputShapesTestCases())
243
244 std::vector<CardinalityTestCase<SparseTensorSliceDatasetParams>>
245 CardinalityTestCases() {
246 return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
247 /*expected_cardinality=*/2},
248 {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
249 /*expected_cardinality=*/2},
250 {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
251 /*expected_cardinality=*/3},
252 {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
253 /*expected_cardinality=*/3}};
254 }
255
DATASET_CARDINALITY_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,CardinalityTestCases ())256 DATASET_CARDINALITY_TEST_P(SparseTensorSliceDatasetOpTest,
257 SparseTensorSliceDatasetParams,
258 CardinalityTestCases())
259
260 std::vector<IteratorOutputDtypesTestCase<SparseTensorSliceDatasetParams>>
261 IteratorOutputDtypesTestCases() {
262 return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
263 /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}},
264 {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
265 /*expected_output_dtypes=*/{DT_INT64, DT_DOUBLE, DT_INT64}},
266 {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
267 /*expected_output_dtypes=*/{DT_INT64, DT_STRING, DT_INT64}},
268 {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
269 /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}}};
270 }
271
ITERATOR_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,IteratorOutputDtypesTestCases ())272 ITERATOR_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,
273 SparseTensorSliceDatasetParams,
274 IteratorOutputDtypesTestCases())
275
276 std::vector<IteratorOutputShapesTestCase<SparseTensorSliceDatasetParams>>
277 IteratorOutputShapesTestCases() {
278 return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
279 /*expected_output_shapes=*/{PartialTensorShape({1, 1}),
280 PartialTensorShape({1}),
281 PartialTensorShape({1})}},
282 {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
283 /*expected_output_shapes=*/{PartialTensorShape({1, 2}),
284 PartialTensorShape({1}),
285 PartialTensorShape({2})}},
286 {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
287 /*expected_output_shapes=*/{PartialTensorShape({1, 3}),
288 PartialTensorShape({1}),
289 PartialTensorShape({3})}},
290 {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
291 /*expected_output_shapes=*/{PartialTensorShape({1, 4}),
292 PartialTensorShape({1}),
293 PartialTensorShape({4})}}};
294 }
295
ITERATOR_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,IteratorOutputShapesTestCases ())296 ITERATOR_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,
297 SparseTensorSliceDatasetParams,
298 IteratorOutputShapesTestCases())
299
300 TEST_F(SparseTensorSliceDatasetOpTest, IteratorPrefix) {
301 auto dataset_params = TwoDimsSparseTensorSliceDatasetParams();
302 TF_ASSERT_OK(Initialize(dataset_params));
303 TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
304 kDatasetType, dataset_params.iterator_prefix())));
305 }
306
307 template <typename T>
308 struct IteratorSaveAndRestoreTestCase {
309 T dataset_params;
310 std::vector<int> breakpoints;
311 std::vector<std::vector<Tensor>> expected_outputs;
312 };
313
314 std::vector<IteratorSaveAndRestoreTestCase<SparseTensorSliceDatasetParams>>
IteratorSaveAndRestoreTestCases()315 IteratorSaveAndRestoreTestCases() {
316 return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
317 /*breakpoints=*/{0, 1, 2},
318 /*expected_outputs=*/
319 {{/*indices*/ CreateTensor<int64>({1, 1}, {0}),
320 /*values*/ CreateTensor<int32>({1}, {888}),
321 /*dense_shape*/ CreateTensor<int64>({1}, {2})},
322 {/*indices*/ CreateTensor<int64>({1, 1}, {1}),
323 /*values*/ CreateTensor<int32>({1}, {999}),
324 /*dense_shape*/ CreateTensor<int64>({1}, {2})}}},
325 {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
326 /*breakpoints=*/{0, 1, 2},
327 /*expected_outputs=*/
328 {{/*indices*/ CreateTensor<int64>({1, 2}, {0, 0}),
329 /*values*/ CreateTensor<double>({1}, {888.0}),
330 /*dense_shape*/ CreateTensor<int64>({2}, {2, 2})},
331 {{/*indices*/ CreateTensor<int64>({1, 2}, {1, 1})},
332 {/*values*/ CreateTensor<double>({1}, {999.0})},
333 {/*dense_shape*/ CreateTensor<int64>({2}, {2, 2})}}}},
334 {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
335 /*breakpoints=*/{0, 1, 3},
336 /*expected_outputs=*/
337 {{/*indices*/ CreateTensor<int64>({1, 3}, {0, 0, 0}),
338 /*values*/ CreateTensor<tstring>({1}, {"a"}),
339 /*dense_shape*/
340 CreateTensor<int64>({3}, {2, 2, 2})},
341 {/*indices*/ CreateTensor<int64>({1, 3}, {1, 1, 1}),
342 /*values*/ CreateTensor<tstring>({1}, {"b"}),
343 /*dense_shape*/
344 CreateTensor<int64>({3}, {2, 2, 2})},
345 {/*indices*/ CreateTensor<int64>({0, 3}, {}),
346 /*values*/ CreateTensor<tstring>({0}, {}),
347 /*dense_shape*/
348 CreateTensor<int64>({3}, {2, 2, 2})}}},
349 {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
350 /*breakpoints=*/{0, 1, 2},
351 /*expected_outputs=*/
352 {{/*indices*/ CreateTensor<int64>({1, 4}, {0, 0, 0, 0}),
353 /*values*/ CreateTensor<int32>({1}, {888}),
354 /*dense_shape*/
355 CreateTensor<int64>({4}, {2, 2, 2, 2})},
356 {/*indices*/ CreateTensor<int64>({1, 4}, {1, 1, 1, 1}),
357 /*values*/ CreateTensor<int32>({1}, {999}),
358 /*dense_shape*/
359 CreateTensor<int64>({4}, {2, 2, 2, 2})},
360 {/*indices*/ CreateTensor<int64>({0, 4}, {}),
361 /*values*/ CreateTensor<int32>({0}, {}),
362 /*dense_shape*/
363 CreateTensor<int64>({4}, {2, 2, 2, 2})}}}};
364 }
365
366 class ParameterizedIteratorSaveAndRestoreTest
367 : public SparseTensorSliceDatasetOpTest,
368 public ::testing::WithParamInterface<
369 IteratorSaveAndRestoreTestCase<SparseTensorSliceDatasetParams>> {};
370
// For each breakpoint: advances the iterator to that iteration count, checks
// the most recent element (or end-of-sequence state), then saves the iterator
// and replaces it with one restored from the saved state.
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
  auto test_case = GetParam();
  TF_ASSERT_OK(Initialize(test_case.dataset_params));

  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));

  int cur_iteration = 0;
  bool end_of_sequence = false;
  int64_t num_slices = dataset_->Cardinality();
  std::vector<Tensor> out_tensors;

  for (int breakpoint : test_case.breakpoints) {
    while (cur_iteration < breakpoint) {
      TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                      &end_of_sequence));
      cur_iteration++;
    }

    if (breakpoint == 0) {
      EXPECT_FALSE(end_of_sequence);
    } else if (breakpoint <= num_slices) {
      // Compare every component of the latest element with the expected
      // slice; guard the indexing so malformed data fails instead of
      // reading out of bounds.
      ASSERT_GT(cur_iteration, 0);
      ASSERT_LE(static_cast<size_t>(cur_iteration),
                test_case.expected_outputs.size());
      const std::vector<Tensor>& expected =
          test_case.expected_outputs[cur_iteration - 1];
      ASSERT_EQ(out_tensors.size(), expected.size());
      for (size_t i = 0; i < out_tensors.size(); ++i) {
        TF_EXPECT_OK(ExpectEqual(out_tensors[i], expected[i]));
      }
    } else {
      EXPECT_TRUE(end_of_sequence);
    }

    // Round-trip the iterator through its serialized state.
    VariantTensorDataWriter writer;
    TF_ASSERT_OK(iterator_->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    TF_EXPECT_OK(RestoreIterator(iterator_ctx_.get(), &reader,
                                 test_case.dataset_params.iterator_prefix(),
                                 *dataset_, &iterator_));
  }
}
415
416 INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest,
417 ParameterizedIteratorSaveAndRestoreTest,
418 ::testing::ValuesIn(IteratorSaveAndRestoreTestCases()));
419
420 } // namespace
421 } // namespace data
422 } // namespace tensorflow
423