• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/sparse_utils.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/platform/test.h"
24 
25 namespace {
26 
27 using tensorflow::DataType;
28 using tensorflow::int32;
29 using tensorflow::int64;
30 using tensorflow::Tensor;
31 using tensorflow::TTypes;
32 using tensorflow::uint16;
33 using tensorflow::uint32;
34 using tensorflow::uint64;
35 using tensorflow::sparse_utils::ContainsEmptyRows;
36 using tensorflow::sparse_utils::FindNextDenseRowStartIndex;
37 using tensorflow::sparse_utils::GetStartIndicesOfEachDenseRow;
38 using tensorflow::sparse_utils::ParseRowStartIndices;
39 
TEST(SparseUtilsTest,GetStartIndicesOfEachDenseRow)40 TEST(SparseUtilsTest, GetStartIndicesOfEachDenseRow) {
41   {
42     int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
43     TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
44     // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
45     bool contains_empty_rows;
46     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<int32>(indices_mat,
47                                                      &contains_empty_rows) ==
48                 std::vector<int32>({0, 1, 2, 2, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8}));
49     EXPECT_TRUE(contains_empty_rows);
50   }
51   {
52     int32 data[] = {0, 0, 1, 0, 1, 0, 4, 0, 4, 0, 4, 0,  6, 0,  7,
53                     0, 7, 0, 7, 0, 7, 0, 8, 0, 8, 0, 10, 0, 12, 0};
54     TTypes<int32>::ConstMatrix indices_mat(data, 15, 2);
55     // indices_list = {0, 1, 1, 4, 4, 4,  6, 7, 7, 7, 7, 8, 8, 10, 12};
56     bool contains_empty_rows;
57     EXPECT_TRUE(
58         GetStartIndicesOfEachDenseRow<int32>(indices_mat,
59                                              &contains_empty_rows) ==
60         std::vector<int32>({0, 1, 3, 3, 3, 6, 6, 7, 11, 13, 13, 14, 14, 15}));
61     EXPECT_TRUE(contains_empty_rows);
62   }
63   {
64     int64 data[] = {3, 0};
65     TTypes<int64>::ConstMatrix indices_mat(data, 1, 2);
66     bool contains_empty_rows;
67     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<int64>(indices_mat,
68                                                      &contains_empty_rows) ==
69                 std::vector<int64>({0, 0, 0, 0, 1}));
70     EXPECT_TRUE(contains_empty_rows);
71   }
72   {
73     uint32 data[] = {3, 0, 3, 0};
74     TTypes<uint32>::ConstMatrix indices_mat(data, 2, 2);
75     bool contains_empty_rows;
76     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<uint32>(indices_mat,
77                                                       &contains_empty_rows) ==
78                 std::vector<uint32>({0, 0, 0, 0, 2}));
79     EXPECT_TRUE(contains_empty_rows);
80   }
81   {
82     uint16 data[] = {0, 0, 0, 0, 0, 0, 1, 0};
83     TTypes<uint16>::ConstMatrix indices_mat(data, 4, 2);
84     // indices_list = {0, 0, 0, 1};
85     bool contains_empty_rows;
86     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<uint16>(indices_mat,
87                                                       &contains_empty_rows) ==
88                 std::vector<uint16>({0, 3, 4}));
89     EXPECT_FALSE(contains_empty_rows);
90   }
91   {
92     uint64 data[] = {0, 0, 0, 0, 0, 0, 3, 0};
93     TTypes<uint64>::ConstMatrix indices_mat(data, 4, 2);
94     bool contains_empty_rows;
95     // indices_list = {0, 0, 0, 3};
96     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<uint64>(indices_mat,
97                                                       &contains_empty_rows) ==
98                 std::vector<uint64>({0, 3, 3, 3, 4}));
99     EXPECT_TRUE(contains_empty_rows);
100   }
101 }
102 
// ParseRowStartIndices(t, n) returns the values of the rank-1 tensor `t`
// followed by `n`. For example, tensor {0, 3} with n = 4 gives {0, 3, 4}.
TEST(SparseUtilsTest, ParseRowStartIndices) {
  {
    Tensor t(DataType::DT_INT32, {1});
    t.flat<int32>()(0) = 0;
    EXPECT_EQ(std::vector<int32>({0, 1}), ParseRowStartIndices<int32>(t, 1));
  }
  {
    Tensor t(DataType::DT_INT64, {1});
    t.flat<int64>()(0) = 0;
    EXPECT_EQ(std::vector<int64>({0, 2}), ParseRowStartIndices<int64>(t, 2));
  }
  {
    Tensor t(DataType::DT_UINT64, {2});
    // flat<>() returns a view, so writes through it fill `t` itself.
    auto values = t.flat<uint64>();
    values(0) = 0;
    values(1) = 3;
    EXPECT_EQ(std::vector<uint64>({0, 3, 4}),
              ParseRowStartIndices<uint64>(t, 4));
  }
  {
    Tensor t(DataType::DT_UINT16, {2});
    auto values = t.flat<uint16>();
    values(0) = 0;
    values(1) = 3;
    EXPECT_EQ(std::vector<uint16>({0, 3, 4}),
              ParseRowStartIndices<uint16>(t, 4));
  }
}
141 
TEST(SparseUtilsTest,ContainsEmptyRows)142 TEST(SparseUtilsTest, ContainsEmptyRows) {
143   {
144     int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
145     TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
146     bool contains_empty_rows;
147     const auto segment_indices =
148         GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
149     // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
150     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
151   }
152   {
153     int64 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
154     TTypes<int64>::ConstMatrix indices_mat(data, 8, 2);
155     bool contains_empty_rows;
156     const auto segment_indices =
157         GetStartIndicesOfEachDenseRow<int64>(indices_mat, &contains_empty_rows);
158     // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
159     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
160   }
161   {
162     int32 data[] = {1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
163     TTypes<int32>::ConstMatrix indices_mat(data, 6, 2);
164     bool contains_empty_rows;
165     const auto segment_indices =
166         GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
167     // indices_list = {1, 1, 2, 2, 2, 3};
168     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
169   }
170   {
171     uint16 data[] = {1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
172     TTypes<uint16>::ConstMatrix indices_mat(data, 6, 2);
173     bool contains_empty_rows;
174     const auto segment_indices = GetStartIndicesOfEachDenseRow<uint16>(
175         indices_mat, &contains_empty_rows);
176     // indices_list = {1, 1, 2, 2, 2, 3};
177     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
178   }
179   {
180     int32 data[] = {0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
181     TTypes<int32>::ConstMatrix indices_mat(data, 7, 2);
182     bool contains_empty_rows;
183     const auto segment_indices =
184         GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
185     // indices_list = {0, 1, 1, 2, 2, 2, 3};
186     EXPECT_FALSE(ContainsEmptyRows(segment_indices));
187   }
188   {
189     int64 data[] = {0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
190     TTypes<int64>::ConstMatrix indices_mat(data, 7, 2);
191     bool contains_empty_rows;
192     const auto segment_indices =
193         GetStartIndicesOfEachDenseRow<int64>(indices_mat, &contains_empty_rows);
194     // indices_list = {0, 1, 1, 2, 2, 2, 3};
195     EXPECT_FALSE(ContainsEmptyRows(segment_indices));
196   }
197   {
198     uint32 data[] = {0, 0, 0, 1, 0, 2, 2, 0, 2, 1, 2, 2, 3, 4};
199     TTypes<uint32>::ConstMatrix indices_mat(data, 7, 2);
200     bool contains_empty_rows;
201     const auto segment_indices = GetStartIndicesOfEachDenseRow<uint32>(
202         indices_mat, &contains_empty_rows);
203     // indices_list = {0, 0, 0, 2, 2, 2, 3};
204     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
205   }
206   {
207     int64 data[] = {0, 0, 0, 1, 0, 2, 2, 0, 2, 1, 2, 2, 3, 4};
208     TTypes<int64>::ConstMatrix indices_mat(data, 7, 2);
209     bool contains_empty_rows;
210     const auto segment_indices =
211         GetStartIndicesOfEachDenseRow<int64>(indices_mat, &contains_empty_rows);
212     // indices_list = {0, 0, 0, 2, 2, 2, 3};
213     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
214   }
215   {
216     uint64 data[] = {0, 0, 0, 1, 0, 2, 1, 0, 2, 1, 2, 2, 3, 4};
217     TTypes<uint64>::ConstMatrix indices_mat(data, 7, 2);
218     bool contains_empty_rows;
219     const auto segment_indices = GetStartIndicesOfEachDenseRow<uint64>(
220         indices_mat, &contains_empty_rows);
221     // indices_list = {0, 0, 0, 1, 2, 2, 3};
222     EXPECT_FALSE(ContainsEmptyRows(segment_indices));
223   }
224 }
225 
// FindNextDenseRowStartIndex(i, indices_mat) returns the offset of the first
// entry after entry i whose column-0 value differs from entry i's. This is the
// start of the next non-empty row. If entry i belongs to the last row, the
// result is N (the number of entries). Start offsets in the middle of a
// multi-entry row are valid inputs.
TEST(SparseUtilsTest, FindNextDenseRowStartIndex) {
  {
    // Column 0: {0, 1, 4, 6, 7, 8, 10, 12}. Every row has exactly one entry,
    // so the next row always starts at i + 1.
    const int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
    for (int32 i = 0; i < 8; ++i) {
      EXPECT_EQ(i + 1, FindNextDenseRowStartIndex<int32>(i, indices_mat))
          << "start index " << i;
    }
  }
  {
    // Same layout as above with uint16 indices.
    const uint16 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<uint16>::ConstMatrix indices_mat(data, 8, 2);
    for (uint16 i = 0; i < 8; ++i) {
      EXPECT_EQ(i + 1, FindNextDenseRowStartIndex<uint16>(i, indices_mat))
          << "start index " << i;
    }
  }
  {
    // Column 0: {0, 1, 1, 4, 4, 4, 6, 7, 7, 7, 7, 8, 8, 10, 12}. Every start
    // offset is checked, including every position inside the multi-entry rows
    // (1, 4, 7 and 8) and the final entry, whose answer is N = 15.
    const int64 data[] = {0, 0, 1, 0, 1, 0, 4, 0, 4, 0, 4, 0,  6, 0,  7,
                          0, 7, 0, 7, 0, 7, 0, 8, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int64>::ConstMatrix indices_mat(data, 15, 2);
    const int64 expected_next[] = {1,  3,  3,  6,  6,  6,  7, 11,
                                   11, 11, 11, 13, 13, 14, 15};
    for (int64 i = 0; i < 15; ++i) {
      EXPECT_EQ(expected_next[i],
                FindNextDenseRowStartIndex<int64>(i, indices_mat))
          << "start index " << i;
    }
  }
}
262 
263 }  // namespace
264