1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/sparse_utils.h"
17
18 #include <vector>
19
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_types.h"
22 #include "tensorflow/core/framework/types.pb.h"
23 #include "tensorflow/core/platform/test.h"
24
25 namespace {
26
27 using tensorflow::DataType;
28 using tensorflow::int32;
29 using tensorflow::int64;
30 using tensorflow::Tensor;
31 using tensorflow::TTypes;
32 using tensorflow::uint16;
33 using tensorflow::uint32;
34 using tensorflow::uint64;
35 using tensorflow::sparse_utils::ContainsEmptyRows;
36 using tensorflow::sparse_utils::FindNextDenseRowStartIndex;
37 using tensorflow::sparse_utils::GetStartIndicesOfEachDenseRow;
38 using tensorflow::sparse_utils::ParseRowStartIndices;
39
// Verifies GetStartIndicesOfEachDenseRow for several index types.
//
// Each case builds a sorted (N x 2) indices matrix whose column 0 holds the
// dense row of each sparse entry (shown as `indices_list`). The expected vector
// has one entry per dense row 0..max_row followed by N, as spelled out in the
// cases below.
//
// The `contains_empty_rows` out-parameter is pre-set to the OPPOSITE of the
// value we expect. This avoids reading an uninitialized bool, and it makes the
// check fail if the function never writes the out-parameter.
TEST(SparseUtilsTest, GetStartIndicesOfEachDenseRow) {
  {
    int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
    // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
    const std::vector<int32> expected = {0, 1, 2, 2, 2, 3, 3,
                                         4, 5, 6, 6, 7, 7, 8};
    bool contains_empty_rows = false;  // Opposite of the expected value.
    EXPECT_EQ(expected, GetStartIndicesOfEachDenseRow<int32>(
                            indices_mat, &contains_empty_rows));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    int32 data[] = {0, 0, 1, 0, 1, 0, 4, 0, 4, 0, 4, 0, 6, 0, 7,
                    0, 7, 0, 7, 0, 7, 0, 8, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int32>::ConstMatrix indices_mat(data, 15, 2);
    // indices_list = {0, 1, 1, 4, 4, 4, 6, 7, 7, 7, 7, 8, 8, 10, 12};
    const std::vector<int32> expected = {0, 1,  3,  3,  3,  6,  6,
                                         7, 11, 13, 13, 14, 14, 15};
    bool contains_empty_rows = false;  // Opposite of the expected value.
    EXPECT_EQ(expected, GetStartIndicesOfEachDenseRow<int32>(
                            indices_mat, &contains_empty_rows));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    int64 data[] = {3, 0};
    TTypes<int64>::ConstMatrix indices_mat(data, 1, 2);
    // indices_list = {3}; rows 0..2 are empty.
    const std::vector<int64> expected = {0, 0, 0, 0, 1};
    bool contains_empty_rows = false;  // Opposite of the expected value.
    EXPECT_EQ(expected, GetStartIndicesOfEachDenseRow<int64>(
                            indices_mat, &contains_empty_rows));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    uint32 data[] = {3, 0, 3, 0};
    TTypes<uint32>::ConstMatrix indices_mat(data, 2, 2);
    // indices_list = {3, 3}; rows 0..2 are empty.
    const std::vector<uint32> expected = {0, 0, 0, 0, 2};
    bool contains_empty_rows = false;  // Opposite of the expected value.
    EXPECT_EQ(expected, GetStartIndicesOfEachDenseRow<uint32>(
                            indices_mat, &contains_empty_rows));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    uint16 data[] = {0, 0, 0, 0, 0, 0, 1, 0};
    TTypes<uint16>::ConstMatrix indices_mat(data, 4, 2);
    // indices_list = {0, 0, 0, 1}; no empty rows.
    const std::vector<uint16> expected = {0, 3, 4};
    bool contains_empty_rows = true;  // Opposite of the expected value.
    EXPECT_EQ(expected, GetStartIndicesOfEachDenseRow<uint16>(
                            indices_mat, &contains_empty_rows));
    EXPECT_FALSE(contains_empty_rows);
  }
  {
    uint64 data[] = {0, 0, 0, 0, 0, 0, 3, 0};
    TTypes<uint64>::ConstMatrix indices_mat(data, 4, 2);
    // indices_list = {0, 0, 0, 3}; rows 1 and 2 are empty.
    const std::vector<uint64> expected = {0, 3, 3, 3, 4};
    bool contains_empty_rows = false;  // Opposite of the expected value.
    EXPECT_EQ(expected, GetStartIndicesOfEachDenseRow<uint64>(
                            indices_mat, &contains_empty_rows));
    EXPECT_TRUE(contains_empty_rows);
  }
}
102
// Verifies ParseRowStartIndices for several index types. As exercised below,
// the values of the 1-D input tensor are returned in order, followed by the
// second argument (e.g. tensor {0, 3} with 4 -> {0, 3, 4}).
//
// EXPECT_EQ is used (rather than EXPECT_TRUE(a == b)) so that a mismatch
// prints both vectors.
TEST(SparseUtilsTest, ParseRowStartIndices) {
  {
    Tensor t(DataType::DT_INT32, {1});
    int indx = 0;
    for (const int32_t v : {0}) {
      t.flat<int32>()(indx++) = v;
    }
    const std::vector<int32> expected = {0, 1};
    EXPECT_EQ(expected, ParseRowStartIndices<int32>(t, 1));
  }
  {
    Tensor t(DataType::DT_INT64, {1});
    int indx = 0;
    for (const int64_t v : {0}) {
      t.flat<int64>()(indx++) = v;
    }
    const std::vector<int64> expected = {0, 2};
    EXPECT_EQ(expected, ParseRowStartIndices<int64>(t, 2));
  }
  {
    Tensor t(DataType::DT_UINT64, {2});
    int indx = 0;
    for (const uint64 v : {0, 3}) {
      t.flat<uint64>()(indx++) = v;
    }
    const std::vector<uint64> expected = {0, 3, 4};
    EXPECT_EQ(expected, ParseRowStartIndices<uint64>(t, 4));
  }
  {
    Tensor t(DataType::DT_UINT16, {2});
    int indx = 0;
    for (const uint16 v : {0, 3}) {
      t.flat<uint16>()(indx++) = v;
    }
    const std::vector<uint16> expected = {0, 3, 4};
    EXPECT_EQ(expected, ParseRowStartIndices<uint16>(t, 4));
  }
}
141
// Verifies ContainsEmptyRows on row-start vectors produced by
// GetStartIndicesOfEachDenseRow. Each case also checks that the
// `contains_empty_rows` out-parameter of GetStartIndicesOfEachDenseRow agrees
// with ContainsEmptyRows. Previously that value was computed but never
// asserted.
//
// The out-parameter is pre-set to the opposite of the expected value. This
// avoids reading an uninitialized bool and catches a function that never
// writes it.
TEST(SparseUtilsTest, ContainsEmptyRows) {
  {
    int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
    // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
    bool contains_empty_rows = false;  // Opposite of the expected value.
    const auto segment_indices =
        GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
    EXPECT_TRUE(ContainsEmptyRows(segment_indices));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    int64 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int64>::ConstMatrix indices_mat(data, 8, 2);
    // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
    bool contains_empty_rows = false;  // Opposite of the expected value.
    const auto segment_indices =
        GetStartIndicesOfEachDenseRow<int64>(indices_mat, &contains_empty_rows);
    EXPECT_TRUE(ContainsEmptyRows(segment_indices));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    int32 data[] = {1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
    TTypes<int32>::ConstMatrix indices_mat(data, 6, 2);
    // indices_list = {1, 1, 2, 2, 2, 3}; row 0 is empty.
    bool contains_empty_rows = false;  // Opposite of the expected value.
    const auto segment_indices =
        GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
    EXPECT_TRUE(ContainsEmptyRows(segment_indices));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    uint16 data[] = {1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
    TTypes<uint16>::ConstMatrix indices_mat(data, 6, 2);
    // indices_list = {1, 1, 2, 2, 2, 3}; row 0 is empty.
    bool contains_empty_rows = false;  // Opposite of the expected value.
    const auto segment_indices = GetStartIndicesOfEachDenseRow<uint16>(
        indices_mat, &contains_empty_rows);
    EXPECT_TRUE(ContainsEmptyRows(segment_indices));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    int32 data[] = {0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
    TTypes<int32>::ConstMatrix indices_mat(data, 7, 2);
    // indices_list = {0, 1, 1, 2, 2, 2, 3}; no empty rows.
    bool contains_empty_rows = true;  // Opposite of the expected value.
    const auto segment_indices =
        GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
    EXPECT_FALSE(ContainsEmptyRows(segment_indices));
    EXPECT_FALSE(contains_empty_rows);
  }
  {
    int64 data[] = {0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
    TTypes<int64>::ConstMatrix indices_mat(data, 7, 2);
    // indices_list = {0, 1, 1, 2, 2, 2, 3}; no empty rows.
    bool contains_empty_rows = true;  // Opposite of the expected value.
    const auto segment_indices =
        GetStartIndicesOfEachDenseRow<int64>(indices_mat, &contains_empty_rows);
    EXPECT_FALSE(ContainsEmptyRows(segment_indices));
    EXPECT_FALSE(contains_empty_rows);
  }
  {
    uint32 data[] = {0, 0, 0, 1, 0, 2, 2, 0, 2, 1, 2, 2, 3, 4};
    TTypes<uint32>::ConstMatrix indices_mat(data, 7, 2);
    // indices_list = {0, 0, 0, 2, 2, 2, 3}; row 1 is empty.
    bool contains_empty_rows = false;  // Opposite of the expected value.
    const auto segment_indices = GetStartIndicesOfEachDenseRow<uint32>(
        indices_mat, &contains_empty_rows);
    EXPECT_TRUE(ContainsEmptyRows(segment_indices));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    int64 data[] = {0, 0, 0, 1, 0, 2, 2, 0, 2, 1, 2, 2, 3, 4};
    TTypes<int64>::ConstMatrix indices_mat(data, 7, 2);
    // indices_list = {0, 0, 0, 2, 2, 2, 3}; row 1 is empty.
    bool contains_empty_rows = false;  // Opposite of the expected value.
    const auto segment_indices =
        GetStartIndicesOfEachDenseRow<int64>(indices_mat, &contains_empty_rows);
    EXPECT_TRUE(ContainsEmptyRows(segment_indices));
    EXPECT_TRUE(contains_empty_rows);
  }
  {
    uint64 data[] = {0, 0, 0, 1, 0, 2, 1, 0, 2, 1, 2, 2, 3, 4};
    TTypes<uint64>::ConstMatrix indices_mat(data, 7, 2);
    // indices_list = {0, 0, 0, 1, 2, 2, 3}; no empty rows.
    bool contains_empty_rows = true;  // Opposite of the expected value.
    const auto segment_indices = GetStartIndicesOfEachDenseRow<uint64>(
        indices_mat, &contains_empty_rows);
    EXPECT_FALSE(ContainsEmptyRows(segment_indices));
    EXPECT_FALSE(contains_empty_rows);
  }
}
225
// Verifies FindNextDenseRowStartIndex. For a sparse position i, the expected
// result (as pinned by the cases below) is the sparse index one past the end
// of the run of entries sharing i's dense row. That is the start of the next
// non-empty dense row, or N when i is in the last row.
TEST(SparseUtilsTest, FindNextDenseRowStartIndex) {
  {
    int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
    // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
    // Every row has exactly one entry, so the next row starts at i + 1.
    for (int32_t i = 0; i < 8; ++i) {
      EXPECT_EQ(i + 1, FindNextDenseRowStartIndex<int32>(i, indices_mat))
          << "i = " << i;
    }
  }
  {
    uint16 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<uint16>::ConstMatrix indices_mat(data, 8, 2);
    // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
    for (uint16 i = 0; i < 8; ++i) {
      EXPECT_EQ(i + 1, FindNextDenseRowStartIndex<uint16>(i, indices_mat))
          << "i = " << i;
    }
  }
  {
    int64 data[] = {0, 0, 1, 0, 1, 0, 4, 0, 4, 0, 4, 0, 6, 0, 7,
                    0, 7, 0, 7, 0, 7, 0, 8, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int64>::ConstMatrix indices_mat(data, 15, 2);
    // indices_list = {0, 1, 1, 4, 4, 4, 6, 7, 7, 7, 7, 8, 8, 10, 12};
    // Cover every sparse position, including the start, middle and end of each
    // multi-entry run (rows 1, 4, 7, 8) and the last row (-> 15).
    const int64 expected[] = {1,  3,  3,  6,  6,  6,  7, 11,
                              11, 11, 11, 13, 13, 14, 15};
    for (int64 i = 0; i < 15; ++i) {
      EXPECT_EQ(expected[i], FindNextDenseRowStartIndex<int64>(i, indices_mat))
          << "i = " << i;
    }
  }
}
262
263 } // namespace
264