• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <utility>
18 
19 #include "tensorflow/core/data/dataset_test_base.h"
20 #include "tensorflow/core/data/dataset_utils.h"
21 #include "tensorflow/core/data/serialization_utils.h"
22 
23 namespace tensorflow {
24 namespace data {
25 namespace {
26 
// Node name given to the dataset built by every params factory in this file.
constexpr char kNodeName[] = "sparse_tensor_slice_dataset";
// Dataset type string; used for `dataset_type()` and for deriving the expected
// op name and iterator prefix in the tests.
constexpr char kDatasetType[] = "SparseTensorSlice";
29 
30 class SparseTensorSliceDatasetParams : public DatasetParams {
31  public:
SparseTensorSliceDatasetParams(Tensor indices,Tensor values,Tensor dense_shape,DataType tvalues,string node_name)32   SparseTensorSliceDatasetParams(Tensor indices, Tensor values,
33                                  Tensor dense_shape, DataType tvalues,
34                                  string node_name)
35       : DatasetParams({tvalues}, {PartialTensorShape({})},
36                       std::move(node_name)),
37         indices_(std::move(indices)),
38         values_(std::move(values)),
39         dense_shape_(std::move(dense_shape)),
40         tvalues_(tvalues) {
41     iterator_prefix_ = "Iterator";
42   }
43 
GetInputTensors() const44   std::vector<Tensor> GetInputTensors() const override {
45     return {indices_, values_, dense_shape_};
46   }
47 
GetInputNames(std::vector<string> * input_names) const48   Status GetInputNames(std::vector<string>* input_names) const override {
49     input_names->clear();
50     input_names->emplace_back("indices");
51     input_names->emplace_back("values");
52     input_names->emplace_back("dense_shape");
53     return Status::OK();
54   }
55 
GetAttributes(AttributeVector * attr_vector) const56   Status GetAttributes(AttributeVector* attr_vector) const override {
57     attr_vector->clear();
58     attr_vector->emplace_back("Tvalues", tvalues_);
59     return Status::OK();
60   }
61 
dataset_type() const62   string dataset_type() const override { return kDatasetType; }
63 
64  private:
65   Tensor indices_;
66   Tensor values_;
67   Tensor dense_shape_;
68   DataType tvalues_;
69 };
70 
71 class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase {};
72 
TwoDimsSparseTensorSliceDatasetParams()73 SparseTensorSliceDatasetParams TwoDimsSparseTensorSliceDatasetParams() {
74   return SparseTensorSliceDatasetParams(
75       /*indices=*/CreateTensor<int64>({2, 2}, {0, 0, 1, 1}),
76       /*values=*/CreateTensor<int32>({2}, {888, 999}),
77       /*dense_shape=*/CreateTensor<int64>({2}, {2, 2}),
78       /*tvalues=*/DT_INT32,
79       /*node_name=*/kNodeName);
80 }
81 
ThreeDimsSparseTensorSliceDatasetParams()82 SparseTensorSliceDatasetParams ThreeDimsSparseTensorSliceDatasetParams() {
83   return SparseTensorSliceDatasetParams(
84       /*indices=*/CreateTensor<int64>({2, 3}, {0, 0, 0, 1, 1, 1}),
85       /*values=*/CreateTensor<double>({2}, {888.0, 999.0}),
86       /*dense_shape=*/CreateTensor<int64>({3}, {2, 2, 2}),
87       /*tvalues=*/DT_DOUBLE,
88       /*node_name=*/kNodeName);
89 }
90 
FourDimsSparseTensorSliceDatasetParams()91 SparseTensorSliceDatasetParams FourDimsSparseTensorSliceDatasetParams() {
92   return SparseTensorSliceDatasetParams(
93       /*indices=*/CreateTensor<int64>({2, 4}, {0, 0, 0, 0, 1, 1, 1, 1}),
94       /*values=*/CreateTensor<tstring>({2}, {"a", "b"}),
95       /*dense_shape=*/CreateTensor<int64>({4}, {3, 2, 2, 2}),
96       /*tvalues=*/DT_STRING,
97       /*node_name=*/kNodeName);
98 }
99 
FiveDimsSparseTensorSliceDatasetParams()100 SparseTensorSliceDatasetParams FiveDimsSparseTensorSliceDatasetParams() {
101   return SparseTensorSliceDatasetParams(
102       /*indices=*/CreateTensor<int64>({2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}),
103       /*values=*/CreateTensor<int32>({2}, {888, 999}),
104       /*dense_shape=*/CreateTensor<int64>({5}, {3, 2, 2, 2, 2}),
105       /*tvalues=*/DT_INT32,
106       /*node_name=*/kNodeName);
107 }
108 
109 template <typename T>
110 struct GetNextTestCase {
111   T dataset_params;
112   std::vector<std::vector<Tensor>> expected_outputs;
113 };
114 
115 std::vector<GetNextTestCase<SparseTensorSliceDatasetParams>>
GetNextTestCases()116 GetNextTestCases() {
117   return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
118            /*expected_outputs=*/
119            {{/*indices*/ CreateTensor<int64>({1, 1}, {0}),
120              /*values*/ CreateTensor<int32>({1}, {888}),
121              /*dense_shape*/ CreateTensor<int64>({1}, {2})},
122             {/*indices*/ CreateTensor<int64>({1, 1}, {1}),
123              /*values*/ CreateTensor<int32>({1}, {999}),
124              /*dense_shape*/ CreateTensor<int64>({1}, {2})}}},
125           {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
126            /*expected_outputs=*/
127            {{/*indices*/ CreateTensor<int64>({1, 2}, {0, 0}),
128              /*values*/ CreateTensor<double>({1}, {888.0}),
129              /*dense_shape*/ CreateTensor<int64>({2}, {2, 2})},
130             {{/*indices*/ CreateTensor<int64>({1, 2}, {1, 1})},
131              {/*values*/ CreateTensor<double>({1}, {999.0})},
132              {/*dense_shape*/ CreateTensor<int64>({2}, {2, 2})}}}},
133           {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
134            /*expected_outputs=*/
135            {{/*indices*/ CreateTensor<int64>({1, 3}, {0, 0, 0}),
136              /*values*/ CreateTensor<tstring>({1}, {"a"}),
137              /*dense_shape*/
138              CreateTensor<int64>({3}, {2, 2, 2})},
139             {/*indices*/ CreateTensor<int64>({1, 3}, {1, 1, 1}),
140              /*values*/ CreateTensor<tstring>({1}, {"b"}),
141              /*dense_shape*/
142              CreateTensor<int64>({3}, {2, 2, 2})},
143             {/*indices*/ CreateTensor<int64>({0, 3}, {}),
144              /*values*/ CreateTensor<tstring>({0}, {}),
145              /*dense_shape*/
146              CreateTensor<int64>({3}, {2, 2, 2})}}},
147           {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
148            /*expected_outputs=*/{
149                {/*indices*/ CreateTensor<int64>({1, 4}, {0, 0, 0, 0}),
150                 /*values*/ CreateTensor<int32>({1}, {888}),
151                 /*dense_shape*/
152                 CreateTensor<int64>({4}, {2, 2, 2, 2})},
153                {/*indices*/ CreateTensor<int64>({1, 4}, {1, 1, 1, 1}),
154                 /*values*/ CreateTensor<int32>({1}, {999}),
155                 /*dense_shape*/
156                 CreateTensor<int64>({4}, {2, 2, 2, 2})},
157                {/*indices*/ CreateTensor<int64>({0, 4}, {}),
158                 /*values*/ CreateTensor<int32>({0}, {}),
159                 /*dense_shape*/
160                 CreateTensor<int64>({4}, {2, 2, 2, 2})}}}};
161 }
162 
163 class ParameterizedGetNextTest
164     : public SparseTensorSliceDatasetOpTest,
165       public ::testing::WithParamInterface<
166           GetNextTestCase<SparseTensorSliceDatasetParams>> {};
167 
// Drains the iterator and checks, in order, that every produced element equals
// the corresponding expected element component by component, and that the
// iterator yields exactly as many elements as expected.
TEST_P(ParameterizedGetNextTest, GetNext) {
  auto test_case = GetParam();
  TF_ASSERT_OK(Initialize(test_case.dataset_params));

  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  auto expected_outputs_it = test_case.expected_outputs.begin();
  while (!end_of_sequence) {
    // Fatal on error: a failed GetNext might never set `end_of_sequence`, which
    // would turn this loop into an infinite one.
    TF_ASSERT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                    &end_of_sequence));
    if (end_of_sequence) break;

    // Guard the dereference below: the dataset must not yield more elements
    // than the test case lists.
    ASSERT_TRUE(expected_outputs_it != test_case.expected_outputs.end())
        << "iterator produced more elements than expected";
    // Guard the indexing below: component counts must match.
    ASSERT_EQ(out_tensors.size(), expected_outputs_it->size());
    for (size_t i = 0; i < out_tensors.size(); ++i) {
      TF_EXPECT_OK(ExpectEqual(out_tensors[i], (*expected_outputs_it)[i]));
    }
    ++expected_outputs_it;
  }
  // All expected elements must have been consumed.
  EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end());
}
187 
188 INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest,
189                         ParameterizedGetNextTest,
190                         ::testing::ValuesIn(GetNextTestCases()));
191 
// The dataset's type string must equal the op name derived from kDatasetType.
TEST_F(SparseTensorSliceDatasetOpTest, DatasetTypeString) {
  SparseTensorSliceDatasetParams dataset_params =
      TwoDimsSparseTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(dataset_params));

  TF_ASSERT_OK(CheckDatasetTypeString(name_utils::OpName(kDatasetType)));
}
197 
// The dataset's node name must equal the one configured in the params.
TEST_F(SparseTensorSliceDatasetOpTest, DatasetNodeName) {
  SparseTensorSliceDatasetParams dataset_params =
      TwoDimsSparseTensorSliceDatasetParams();
  TF_ASSERT_OK(Initialize(dataset_params));

  TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name()));
}
203 
204 std::vector<DatasetOutputDtypesTestCase<SparseTensorSliceDatasetParams>>
DatasetOutputDtypesTestCases()205 DatasetOutputDtypesTestCases() {
206   return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
207            /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}},
208           {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
209            /*expected_output_dtypes=*/{DT_INT64, DT_DOUBLE, DT_INT64}},
210           {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
211            /*expected_output_dtypes=*/{DT_INT64, DT_STRING, DT_INT64}},
212           {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
213            /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}}};
214 }
215 
DATASET_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,DatasetOutputDtypesTestCases ())216 DATASET_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,
217                              SparseTensorSliceDatasetParams,
218                              DatasetOutputDtypesTestCases())
219 
220 std::vector<DatasetOutputShapesTestCase<SparseTensorSliceDatasetParams>>
221 DatasetOutputShapesTestCases() {
222   return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
223            /*expected_output_shapes=*/{PartialTensorShape({1, 1}),
224                                        PartialTensorShape({1}),
225                                        PartialTensorShape({1})}},
226           {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
227            /*expected_output_shapes=*/{PartialTensorShape({1, 2}),
228                                        PartialTensorShape({1}),
229                                        PartialTensorShape({2})}},
230           {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
231            /*expected_output_shapes=*/{PartialTensorShape({1, 3}),
232                                        PartialTensorShape({1}),
233                                        PartialTensorShape({3})}},
234           {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
235            /*expected_output_shapes=*/{PartialTensorShape({1, 4}),
236                                        PartialTensorShape({1}),
237                                        PartialTensorShape({4})}}};
238 }
239 
DATASET_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,DatasetOutputShapesTestCases ())240 DATASET_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,
241                              SparseTensorSliceDatasetParams,
242                              DatasetOutputShapesTestCases())
243 
244 std::vector<CardinalityTestCase<SparseTensorSliceDatasetParams>>
245 CardinalityTestCases() {
246   return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
247            /*expected_cardinality=*/2},
248           {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
249            /*expected_cardinality=*/2},
250           {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
251            /*expected_cardinality=*/3},
252           {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
253            /*expected_cardinality=*/3}};
254 }
255 
DATASET_CARDINALITY_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,CardinalityTestCases ())256 DATASET_CARDINALITY_TEST_P(SparseTensorSliceDatasetOpTest,
257                            SparseTensorSliceDatasetParams,
258                            CardinalityTestCases())
259 
260 std::vector<IteratorOutputDtypesTestCase<SparseTensorSliceDatasetParams>>
261 IteratorOutputDtypesTestCases() {
262   return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
263            /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}},
264           {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
265            /*expected_output_dtypes=*/{DT_INT64, DT_DOUBLE, DT_INT64}},
266           {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
267            /*expected_output_dtypes=*/{DT_INT64, DT_STRING, DT_INT64}},
268           {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
269            /*expected_output_dtypes=*/{DT_INT64, DT_INT32, DT_INT64}}};
270 }
271 
ITERATOR_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,IteratorOutputDtypesTestCases ())272 ITERATOR_OUTPUT_DTYPES_TEST_P(SparseTensorSliceDatasetOpTest,
273                               SparseTensorSliceDatasetParams,
274                               IteratorOutputDtypesTestCases())
275 
276 std::vector<IteratorOutputShapesTestCase<SparseTensorSliceDatasetParams>>
277 IteratorOutputShapesTestCases() {
278   return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
279            /*expected_output_shapes=*/{PartialTensorShape({1, 1}),
280                                        PartialTensorShape({1}),
281                                        PartialTensorShape({1})}},
282           {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
283            /*expected_output_shapes=*/{PartialTensorShape({1, 2}),
284                                        PartialTensorShape({1}),
285                                        PartialTensorShape({2})}},
286           {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
287            /*expected_output_shapes=*/{PartialTensorShape({1, 3}),
288                                        PartialTensorShape({1}),
289                                        PartialTensorShape({3})}},
290           {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
291            /*expected_output_shapes=*/{PartialTensorShape({1, 4}),
292                                        PartialTensorShape({1}),
293                                        PartialTensorShape({4})}}};
294 }
295 
ITERATOR_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,SparseTensorSliceDatasetParams,IteratorOutputShapesTestCases ())296 ITERATOR_OUTPUT_SHAPES_TEST_P(SparseTensorSliceDatasetOpTest,
297                               SparseTensorSliceDatasetParams,
298                               IteratorOutputShapesTestCases())
299 
300 TEST_F(SparseTensorSliceDatasetOpTest, IteratorPrefix) {
301   auto dataset_params = TwoDimsSparseTensorSliceDatasetParams();
302   TF_ASSERT_OK(Initialize(dataset_params));
303   TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
304       kDatasetType, dataset_params.iterator_prefix())));
305 }
306 
307 template <typename T>
308 struct IteratorSaveAndRestoreTestCase {
309   T dataset_params;
310   std::vector<int> breakpoints;
311   std::vector<std::vector<Tensor>> expected_outputs;
312 };
313 
314 std::vector<IteratorSaveAndRestoreTestCase<SparseTensorSliceDatasetParams>>
IteratorSaveAndRestoreTestCases()315 IteratorSaveAndRestoreTestCases() {
316   return {{/*dataset_params=*/TwoDimsSparseTensorSliceDatasetParams(),
317            /*breakpoints=*/{0, 1, 2},
318            /*expected_outputs=*/
319            {{/*indices*/ CreateTensor<int64>({1, 1}, {0}),
320              /*values*/ CreateTensor<int32>({1}, {888}),
321              /*dense_shape*/ CreateTensor<int64>({1}, {2})},
322             {/*indices*/ CreateTensor<int64>({1, 1}, {1}),
323              /*values*/ CreateTensor<int32>({1}, {999}),
324              /*dense_shape*/ CreateTensor<int64>({1}, {2})}}},
325           {/*dataset_params=*/ThreeDimsSparseTensorSliceDatasetParams(),
326            /*breakpoints=*/{0, 1, 2},
327            /*expected_outputs=*/
328            {{/*indices*/ CreateTensor<int64>({1, 2}, {0, 0}),
329              /*values*/ CreateTensor<double>({1}, {888.0}),
330              /*dense_shape*/ CreateTensor<int64>({2}, {2, 2})},
331             {{/*indices*/ CreateTensor<int64>({1, 2}, {1, 1})},
332              {/*values*/ CreateTensor<double>({1}, {999.0})},
333              {/*dense_shape*/ CreateTensor<int64>({2}, {2, 2})}}}},
334           {/*dataset_params=*/FourDimsSparseTensorSliceDatasetParams(),
335            /*breakpoints=*/{0, 1, 3},
336            /*expected_outputs=*/
337            {{/*indices*/ CreateTensor<int64>({1, 3}, {0, 0, 0}),
338              /*values*/ CreateTensor<tstring>({1}, {"a"}),
339              /*dense_shape*/
340              CreateTensor<int64>({3}, {2, 2, 2})},
341             {/*indices*/ CreateTensor<int64>({1, 3}, {1, 1, 1}),
342              /*values*/ CreateTensor<tstring>({1}, {"b"}),
343              /*dense_shape*/
344              CreateTensor<int64>({3}, {2, 2, 2})},
345             {/*indices*/ CreateTensor<int64>({0, 3}, {}),
346              /*values*/ CreateTensor<tstring>({0}, {}),
347              /*dense_shape*/
348              CreateTensor<int64>({3}, {2, 2, 2})}}},
349           {/*dataset_params=*/FiveDimsSparseTensorSliceDatasetParams(),
350            /*breakpoints=*/{0, 1, 2},
351            /*expected_outputs=*/
352            {{/*indices*/ CreateTensor<int64>({1, 4}, {0, 0, 0, 0}),
353              /*values*/ CreateTensor<int32>({1}, {888}),
354              /*dense_shape*/
355              CreateTensor<int64>({4}, {2, 2, 2, 2})},
356             {/*indices*/ CreateTensor<int64>({1, 4}, {1, 1, 1, 1}),
357              /*values*/ CreateTensor<int32>({1}, {999}),
358              /*dense_shape*/
359              CreateTensor<int64>({4}, {2, 2, 2, 2})},
360             {/*indices*/ CreateTensor<int64>({0, 4}, {}),
361              /*values*/ CreateTensor<int32>({0}, {}),
362              /*dense_shape*/
363              CreateTensor<int64>({4}, {2, 2, 2, 2})}}}};
364 }
365 
366 class ParameterizedIteratorSaveAndRestoreTest
367     : public SparseTensorSliceDatasetOpTest,
368       public ::testing::WithParamInterface<
369           IteratorSaveAndRestoreTestCase<SparseTensorSliceDatasetParams>> {};
370 
// Advances the iterator to each breakpoint, checks the state reached there
// (nothing consumed yet, the most recent element, or end of sequence), then
// saves the iterator and restores it from the saved data before continuing.
TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
  auto test_case = GetParam();
  TF_ASSERT_OK(Initialize(test_case.dataset_params));

  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));

  int cur_iteration = 0;
  bool end_of_sequence = false;
  int64_t num_slices = dataset_->Cardinality();
  std::vector<Tensor> out_tensors;

  for (int breakpoint : test_case.breakpoints) {
    while (cur_iteration < breakpoint) {
      // Fatal on error: the checks below would otherwise inspect stale
      // `out_tensors` / `end_of_sequence` values.
      TF_ASSERT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
                                      &end_of_sequence));
      cur_iteration++;
    }

    if (breakpoint == 0) {
      EXPECT_FALSE(end_of_sequence);
    } else if (breakpoint <= num_slices) {
      // The most recently produced element is number `cur_iteration - 1`.
      // Check bounds and component count before indexing.
      ASSERT_GE(cur_iteration, 1);
      ASSERT_LE(static_cast<size_t>(cur_iteration),
                test_case.expected_outputs.size());
      const std::vector<Tensor>& expected =
          test_case.expected_outputs[cur_iteration - 1];
      ASSERT_EQ(out_tensors.size(), expected.size());
      // Compare each component exactly once.
      for (size_t i = 0; i < out_tensors.size(); ++i) {
        TF_EXPECT_OK(ExpectEqual(out_tensors[i], expected[i]));
      }
    } else {
      EXPECT_TRUE(end_of_sequence);
    }

    VariantTensorDataWriter writer;
    TF_ASSERT_OK(iterator_->Save(serialization_ctx.get(), &writer));
    std::vector<const VariantTensorData*> data;
    writer.GetData(&data);
    VariantTensorDataReader reader(data);
    // Fatal on error: continuing with an iterator that failed to restore would
    // only produce misleading follow-up failures.
    TF_ASSERT_OK(RestoreIterator(iterator_ctx_.get(), &reader,
                                 test_case.dataset_params.iterator_prefix(),
                                 *dataset_, &iterator_));
  }
}
415 
416 INSTANTIATE_TEST_CASE_P(SparseTensorSliceDatasetOpTest,
417                         ParameterizedIteratorSaveAndRestoreTest,
418                         ::testing::ValuesIn(IteratorSaveAndRestoreTestCases()));
419 
420 }  // namespace
421 }  // namespace data
422 }  // namespace tensorflow
423