• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/sparse_utils.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <set>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/types.pb.h"
29 #include "tensorflow/core/lib/core/status_test_util.h"
30 #include "tensorflow/core/lib/random/philox_random.h"
31 #include "tensorflow/core/lib/random/simple_philox.h"
32 #include "tensorflow/core/platform/status_matchers.h"
33 #include "tensorflow/core/platform/test.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35 
36 namespace tensorflow {
37 namespace sparse_utils {
38 namespace {
39 
40 using ::tensorflow::testing::StatusIs;
41 using ::testing::MatchesRegex;
42 
TEST(SparseUtilsTest,GetStartIndicesOfEachDenseRow)43 TEST(SparseUtilsTest, GetStartIndicesOfEachDenseRow) {
44   {
45     int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
46     TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
47     // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
48     bool contains_empty_rows;
49     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<int32>(indices_mat,
50                                                      &contains_empty_rows) ==
51                 std::vector<int32>({0, 1, 2, 2, 2, 3, 3, 4, 5, 6, 6, 7, 7, 8}));
52     EXPECT_TRUE(contains_empty_rows);
53   }
54   {
55     int32 data[] = {0, 0, 1, 0, 1, 0, 4, 0, 4, 0, 4, 0,  6, 0,  7,
56                     0, 7, 0, 7, 0, 7, 0, 8, 0, 8, 0, 10, 0, 12, 0};
57     TTypes<int32>::ConstMatrix indices_mat(data, 15, 2);
58     // indices_list = {0, 1, 1, 4, 4, 4,  6, 7, 7, 7, 7, 8, 8, 10, 12};
59     bool contains_empty_rows;
60     EXPECT_TRUE(
61         GetStartIndicesOfEachDenseRow<int32>(indices_mat,
62                                              &contains_empty_rows) ==
63         std::vector<int32>({0, 1, 3, 3, 3, 6, 6, 7, 11, 13, 13, 14, 14, 15}));
64     EXPECT_TRUE(contains_empty_rows);
65   }
66   {
67     int64_t data[] = {3, 0};
68     TTypes<int64_t>::ConstMatrix indices_mat(data, 1, 2);
69     bool contains_empty_rows;
70     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<int64_t>(indices_mat,
71                                                        &contains_empty_rows) ==
72                 std::vector<int64_t>({0, 0, 0, 0, 1}));
73     EXPECT_TRUE(contains_empty_rows);
74   }
75   {
76     uint32 data[] = {3, 0, 3, 0};
77     TTypes<uint32>::ConstMatrix indices_mat(data, 2, 2);
78     bool contains_empty_rows;
79     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<uint32>(indices_mat,
80                                                       &contains_empty_rows) ==
81                 std::vector<uint32>({0, 0, 0, 0, 2}));
82     EXPECT_TRUE(contains_empty_rows);
83   }
84   {
85     uint16 data[] = {0, 0, 0, 0, 0, 0, 1, 0};
86     TTypes<uint16>::ConstMatrix indices_mat(data, 4, 2);
87     // indices_list = {0, 0, 0, 1};
88     bool contains_empty_rows;
89     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<uint16>(indices_mat,
90                                                       &contains_empty_rows) ==
91                 std::vector<uint16>({0, 3, 4}));
92     EXPECT_FALSE(contains_empty_rows);
93   }
94   {
95     uint64 data[] = {0, 0, 0, 0, 0, 0, 3, 0};
96     TTypes<uint64>::ConstMatrix indices_mat(data, 4, 2);
97     bool contains_empty_rows;
98     // indices_list = {0, 0, 0, 3};
99     EXPECT_TRUE(GetStartIndicesOfEachDenseRow<uint64>(indices_mat,
100                                                       &contains_empty_rows) ==
101                 std::vector<uint64>({0, 3, 3, 3, 4}));
102     EXPECT_TRUE(contains_empty_rows);
103   }
104 }
105 
// Verifies ParseRowStartIndices over several index types. In every case the
// expected vector holds the row-start tensor's values with the second argument
// appended as the final entry. EXPECT_EQ prints both vectors on mismatch.
TEST(SparseUtilsTest, ParseRowStartIndices) {
  {
    Tensor t(DataType::DT_INT32, {1});
    t.flat<int32>()(0) = 0;
    EXPECT_EQ(ParseRowStartIndices<int32>(t, 1), std::vector<int32>({0, 1}));
  }
  {
    Tensor t(DataType::DT_INT64, {1});
    t.flat<int64_t>()(0) = 0;
    EXPECT_EQ(ParseRowStartIndices<int64_t>(t, 2),
              std::vector<int64_t>({0, 2}));
  }
  {
    Tensor t(DataType::DT_UINT64, {2});
    auto row_starts = t.flat<uint64>();
    row_starts(0) = 0;
    row_starts(1) = 3;
    EXPECT_EQ(ParseRowStartIndices<uint64>(t, 4),
              std::vector<uint64>({0, 3, 4}));
  }
  {
    Tensor t(DataType::DT_UINT16, {2});
    auto row_starts = t.flat<uint16>();
    row_starts(0) = 0;
    row_starts(1) = 3;
    EXPECT_EQ(ParseRowStartIndices<uint16>(t, 4),
              std::vector<uint16>({0, 3, 4}));
  }
}
144 
TEST(SparseUtilsTest,ContainsEmptyRows)145 TEST(SparseUtilsTest, ContainsEmptyRows) {
146   {
147     int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
148     TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
149     bool contains_empty_rows;
150     const auto segment_indices =
151         GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
152     // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
153     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
154   }
155   {
156     int64_t data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
157     TTypes<int64_t>::ConstMatrix indices_mat(data, 8, 2);
158     bool contains_empty_rows;
159     const auto segment_indices = GetStartIndicesOfEachDenseRow<int64_t>(
160         indices_mat, &contains_empty_rows);
161     // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
162     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
163   }
164   {
165     int32 data[] = {1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
166     TTypes<int32>::ConstMatrix indices_mat(data, 6, 2);
167     bool contains_empty_rows;
168     const auto segment_indices =
169         GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
170     // indices_list = {1, 1, 2, 2, 2, 3};
171     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
172   }
173   {
174     uint16 data[] = {1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
175     TTypes<uint16>::ConstMatrix indices_mat(data, 6, 2);
176     bool contains_empty_rows;
177     const auto segment_indices = GetStartIndicesOfEachDenseRow<uint16>(
178         indices_mat, &contains_empty_rows);
179     // indices_list = {1, 1, 2, 2, 2, 3};
180     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
181   }
182   {
183     int32 data[] = {0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
184     TTypes<int32>::ConstMatrix indices_mat(data, 7, 2);
185     bool contains_empty_rows;
186     const auto segment_indices =
187         GetStartIndicesOfEachDenseRow<int32>(indices_mat, &contains_empty_rows);
188     // indices_list = {0, 1, 1, 2, 2, 2, 3};
189     EXPECT_FALSE(ContainsEmptyRows(segment_indices));
190   }
191   {
192     int64_t data[] = {0, 0, 1, 0, 1, 1, 2, 0, 2, 1, 2, 2, 3, 4};
193     TTypes<int64_t>::ConstMatrix indices_mat(data, 7, 2);
194     bool contains_empty_rows;
195     const auto segment_indices = GetStartIndicesOfEachDenseRow<int64_t>(
196         indices_mat, &contains_empty_rows);
197     // indices_list = {0, 1, 1, 2, 2, 2, 3};
198     EXPECT_FALSE(ContainsEmptyRows(segment_indices));
199   }
200   {
201     uint32 data[] = {0, 0, 0, 1, 0, 2, 2, 0, 2, 1, 2, 2, 3, 4};
202     TTypes<uint32>::ConstMatrix indices_mat(data, 7, 2);
203     bool contains_empty_rows;
204     const auto segment_indices = GetStartIndicesOfEachDenseRow<uint32>(
205         indices_mat, &contains_empty_rows);
206     // indices_list = {0, 0, 0, 2, 2, 2, 3};
207     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
208   }
209   {
210     int64_t data[] = {0, 0, 0, 1, 0, 2, 2, 0, 2, 1, 2, 2, 3, 4};
211     TTypes<int64_t>::ConstMatrix indices_mat(data, 7, 2);
212     bool contains_empty_rows;
213     const auto segment_indices = GetStartIndicesOfEachDenseRow<int64_t>(
214         indices_mat, &contains_empty_rows);
215     // indices_list = {0, 0, 0, 2, 2, 2, 3};
216     EXPECT_TRUE(ContainsEmptyRows(segment_indices));
217   }
218   {
219     uint64 data[] = {0, 0, 0, 1, 0, 2, 1, 0, 2, 1, 2, 2, 3, 4};
220     TTypes<uint64>::ConstMatrix indices_mat(data, 7, 2);
221     bool contains_empty_rows;
222     const auto segment_indices = GetStartIndicesOfEachDenseRow<uint64>(
223         indices_mat, &contains_empty_rows);
224     // indices_list = {0, 0, 0, 1, 2, 2, 3};
225     EXPECT_FALSE(ContainsEmptyRows(segment_indices));
226   }
227 }
228 
// Verifies FindNextDenseRowStartIndex. Given the position of an entry in a
// row-sorted [nnz, 2] index matrix, it returns the position of the first later
// entry in a different row, or nnz when no such entry exists.
TEST(SparseUtilsTest, FindNextDenseRowStartIndex) {
  {
    int32 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int32>::ConstMatrix indices_mat(data, 8, 2);
    // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
    // Every row is distinct, so the next row always starts one entry later.
    int32_t pos = 0;
    while (pos < 8) {
      EXPECT_EQ(pos + 1, FindNextDenseRowStartIndex<int32>(pos, indices_mat));
      ++pos;
    }
  }
  {
    uint16 data[] = {0, 0, 1, 0, 4, 0, 6, 0, 7, 0, 8, 0, 10, 0, 12, 0};
    TTypes<uint16>::ConstMatrix indices_mat(data, 8, 2);
    // indices_list = {0, 1, 4, 6, 7, 8, 10, 12};
    uint16 pos = 0;
    while (pos < 8) {
      EXPECT_EQ(pos + 1, FindNextDenseRowStartIndex<uint16>(pos, indices_mat));
      ++pos;
    }
  }
  {
    int64_t data[] = {0, 0, 1, 0, 1, 0, 4, 0, 4, 0, 4, 0,  6, 0,  7,
                      0, 7, 0, 7, 0, 7, 0, 8, 0, 8, 0, 10, 0, 12, 0};
    TTypes<int64_t>::ConstMatrix indices_mat(data, 15, 2);
    // indices_list = {0, 1, 1, 4, 4, 4,  6, 7, 7, 7, 7, 8, 8, 10, 12};
    // Pairs of (starting position, expected start of the next distinct row).
    const std::pair<int64_t, int64_t> kExpectations[] = {
        {1, 3}, {2, 3}, {3, 6}, {4, 6}, {13, 14}, {14, 15}};
    for (const auto& expectation : kExpectations) {
      EXPECT_EQ(expectation.second, FindNextDenseRowStartIndex<int64_t>(
                                        expectation.first, indices_mat));
    }
  }
}
265 
266 // Returns a shared random number generator.
RandomPhilox()267 ::tensorflow::random::SimplePhilox& RandomPhilox() {
268   // Safe initialization of static random generator.
269   static auto* philox =
270       new ::tensorflow::random::PhiloxRandom(tensorflow::testing::RandomSeed());
271   static auto* rnd = new ::tensorflow::random::SimplePhilox(philox);
272   return *rnd;
273 }
274 
275 // Fills a tensor of indices with a unique set of random index tuples.
276 // The `SetType` must be a std::set-like type (e.g. flat_hash_set, btree_set)
277 // that is used to ensure uniqueness and governs the final index tuple order.
278 // For example, use a hash set for unordered indices, and sorted set for
279 // lexicographically ordered indices. The `shape` is used to ensure proper index
280 // bounds.
281 template <typename SetType>
FillIndicesWithRandomTuples(const TensorShape & shape,Tensor & indices)282 void FillIndicesWithRandomTuples(const TensorShape& shape, Tensor& indices) {
283   const int64_t nnz = indices.dim_size(0);
284   const int64_t ndims = indices.dim_size(1);
285 
286   SetType indices_set;
287   int64_t count = 0;
288   // Generate nnz unique random tuples.
289   while (count < nnz) {
290     std::vector<int64_t> candidate(ndims);
291     for (int64_t d = 0; d < ndims; ++d) {
292       candidate[d] = RandomPhilox().Uniform64(shape.dim_size(d));
293     }
294     auto it = indices_set.insert(std::move(candidate));
295     if (it.second) {
296       ++count;
297     }
298   }
299 
300   // Copy index tuples from set into index tensor.
301   auto indices_mat = indices.matrix<int64_t>();
302   int64_t row = 0;
303   for (const std::vector<int64_t>& idxs : indices_set) {
304     for (int64_t col = 0; col < ndims; ++col) {
305       indices_mat(row, col) = idxs[col];
306     }
307     ++row;
308   }
309 }
310 
311 // Populates components of a sparse random tensor with provided number of
312 // non-zeros `max_nnz` and tensor shape `shape`.  If `ordered`, output indices
313 // are ordered lexicographically.
GenerateRandomSparseTensor(int64_t max_nnz,const TensorShape & shape,bool ordered,Tensor & output_indices,Tensor & output_values,Tensor & output_shape)314 void GenerateRandomSparseTensor(int64_t max_nnz, const TensorShape& shape,
315                                 bool ordered, Tensor& output_indices,
316                                 Tensor& output_values, Tensor& output_shape) {
317   const int64_t ndims = shape.dims();
318   // We cannot generate more elements than the total in the tensor, so
319   // potentially reduce nnz.
320   const int64_t nnz = std::min(shape.num_elements(), max_nnz);
321   output_indices = Tensor(DT_INT64, TensorShape({nnz, ndims}));
322   output_values = Tensor(DT_FLOAT, TensorShape({nnz}));
323   output_shape = Tensor(DT_INT64, TensorShape({ndims}));
324 
325   // Generate random unique sparse indices.
326   if (ordered) {
327     // NOTE: absl::btree_set does not seem to be available in TF OSS.
328     FillIndicesWithRandomTuples<std::set<std::vector<int64_t>>>(shape,
329                                                                 output_indices);
330   } else {
331     FillIndicesWithRandomTuples<absl::flat_hash_set<std::vector<int64_t>>>(
332         shape, output_indices);
333   }
334 
335   auto values_vec = output_values.vec<float>();
336   values_vec.setRandom();
337 
338   auto shape_vec = output_shape.vec<int64_t>();
339   for (int i = 0; i < shape.dims(); ++i) {
340     shape_vec(i) = shape.dim_size(i);
341   }
342 }
343 
344 using ValidateSparseTensorTest = ::testing::TestWithParam<IndexValidation>;
345 
TEST_P(ValidateSparseTensorTest,ValidSparseTensorPasses)346 TEST_P(ValidateSparseTensorTest, ValidSparseTensorPasses) {
347   constexpr int kNumNonZeros = 1000;
348   const TensorShape kTensorShapes[] = {
349       {}, {3}, {4, 5}, {6, 7, 8}, {9, 10, 11, 12}};
350   const IndexValidation index_validation = GetParam();
351   const bool ordered = (index_validation == IndexValidation::kOrdered);
352   for (const TensorShape& test_shape : kTensorShapes) {
353     Tensor indices, values, shape;
354     GenerateRandomSparseTensor(kNumNonZeros, test_shape, ordered, indices,
355                                values, shape);
356     TF_EXPECT_OK((ValidateSparseTensor<int64_t>(indices, values, shape,
357                                                 index_validation)));
358   }
359 }
360 
// An indices tensor of any rank other than 2 must be rejected with
// INVALID_ARGUMENT, regardless of the validation mode.
TEST_P(ValidateSparseTensorTest, InvalidIndicesRankFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;
  const IndexValidation index_validation = GetParam();
  // Values and shape are well-formed; only the indices rank is wrong.
  const Tensor values(DT_FLOAT, TensorShape({kNumNonZeros}));
  const Tensor shape(DT_INT64, TensorShape({kNumDims}));
  // Indices tensor must be rank 2, so try rank 0, 1, 3.
  const TensorShape kInvalidIndicesShapes[] = {
      {}, {kNumNonZeros}, {kNumNonZeros, kNumDims, 4}};
  for (const TensorShape& bad_indices_shape : kInvalidIndicesShapes) {
    const Tensor indices(DT_INT64, bad_indices_shape);
    EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
                                               index_validation)),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex("Sparse indices must be rank 2 .*")));
  }
}
378 
// A values tensor of any rank other than 1 must be rejected with
// INVALID_ARGUMENT, regardless of the validation mode.
TEST_P(ValidateSparseTensorTest, InvalidValuesRankFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;
  const IndexValidation index_validation = GetParam();
  // Indices and shape are well-formed; only the values rank is wrong.
  const Tensor indices(DT_INT64, TensorShape({kNumNonZeros, kNumDims}));
  const Tensor shape(DT_INT64, TensorShape({kNumDims}));
  // Values tensor must be rank 1, so try rank 0, 2.
  const TensorShape kInvalidValuesShapes[] = {{}, {kNumNonZeros, 2}};
  for (const TensorShape& bad_values_shape : kInvalidValuesShapes) {
    const Tensor values(DT_FLOAT, bad_values_shape);
    EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
                                               index_validation)),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex("Sparse values must be rank 1 .*")));
  }
}
396 
// A dense-shape tensor of any rank other than 1 must be rejected with
// INVALID_ARGUMENT, regardless of the validation mode.
TEST_P(ValidateSparseTensorTest, InvalidShapeRankFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;
  const IndexValidation index_validation = GetParam();
  // Indices and values are well-formed; only the shape tensor's rank is wrong.
  const Tensor indices(DT_INT64, TensorShape({kNumNonZeros, kNumDims}));
  const Tensor values(DT_FLOAT, TensorShape({kNumNonZeros}));
  // Shape tensor must be rank 1, so try rank 0, 2.
  const TensorShape kInvalidShapeShapes[] = {{}, {kNumDims, 2}};
  for (const TensorShape& bad_shape_shape : kInvalidShapeShapes) {
    const Tensor shape(DT_INT64, bad_shape_shape);
    EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
                                               index_validation)),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex("Sparse shape must be rank 1 .*")));
  }
}
414 
// Indices whose shape disagrees with the values (nnz) or with the dense shape
// (ndims) must be rejected with INVALID_ARGUMENT under every validation mode.
TEST_P(ValidateSparseTensorTest, IncompatibleShapesFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumDims = 3;
  const IndexValidation index_validation = GetParam();

  const Tensor values(DT_FLOAT, TensorShape({kNumNonZeros}));
  const Tensor shape(DT_INT64, TensorShape({kNumDims}));

  // Each case pairs a mismatched indices shape with the expected error regex.
  struct BadIndicesCase {
    TensorShape indices_shape;
    std::string error_regex;
  };
  const BadIndicesCase kCases[] = {
      // Indices and values must have the same size in dimension 0 (nnz).
      {TensorShape({kNumNonZeros + 1, kNumDims}),
       "Number of elements in indices .* and values .* do not match"},
      // Each index tuple must have the same size in dimension 1 as the dense
      // tensor shape (ndims).
      {TensorShape({kNumNonZeros, kNumDims + 1}),
       "Index rank .* and shape rank .* do not match"},
  };

  for (const BadIndicesCase& bad_case : kCases) {
    const Tensor indices(DT_INT64, bad_case.indices_shape);
    EXPECT_THAT((ValidateSparseTensor<int64_t>(indices, values, shape,
                                               index_validation)),
                StatusIs(error::INVALID_ARGUMENT,
                         MatchesRegex(bad_case.error_regex)));
  }
}
446 
// Pushing one random coordinate of a valid tensor to -1 or to its dimension
// size must be rejected by every mode that checks indices. kNone performs no
// index checks and must still accept the tensor.
TEST_P(ValidateSparseTensorTest, IndexOutOfBoundsFails) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumTests = 100;
  const IndexValidation index_validation = GetParam();
  const bool ordered = (index_validation == IndexValidation::kOrdered);
  const bool expect_rejection = index_validation != IndexValidation::kNone;

  const TensorShape kTensorShapes[] = {{3}, {4, 5}, {6, 7, 8}, {9, 10, 11, 12}};

  for (const TensorShape& dense_shape : kTensorShapes) {
    Tensor indices, values, shape;
    GenerateRandomSparseTensor(kNumNonZeros, dense_shape, ordered, indices,
                               values, shape);
    auto indices_mat = indices.matrix<int64_t>();
    for (int trial = 0; trial < kNumTests; ++trial) {
      // Choose the entry and coordinate to corrupt (row first, then dim).
      const int64_t row = RandomPhilox().Uniform64(indices.dim_size(0));
      const int64_t dim = RandomPhilox().Uniform64(indices.dim_size(1));
      const int64_t saved = indices_mat(row, dim);

      // Values just below and just at the valid range [0, dim_size).
      const int64_t bad_values[] = {-1, dense_shape.dim_size(dim)};
      for (const int64_t bad : bad_values) {
        indices_mat(row, dim) = bad;
        const Status status = ValidateSparseTensor<int64_t>(
            indices, values, shape, index_validation);
        if (!expect_rejection) {
          TF_EXPECT_OK(status);
          continue;
        }
        EXPECT_THAT(status,
                    StatusIs(error::INVALID_ARGUMENT,
                             MatchesRegex("Sparse index tuple .* is out of "
                                          "bounds")))
            << indices_mat;
      }

      // Put the original coordinate back so the next trial starts valid.
      indices_mat(row, dim) = saved;
    }
  }
}
487 
// Swapping two distinct tuples of a valid, duplicate-free index set keeps
// every tuple in bounds and unique but breaks lexicographic order. So only
// kOrdered validation is expected to reject the result.
TEST_P(ValidateSparseTensorTest, IndexOutOfOrderFailsForOrderedValidation) {
  constexpr int kNumNonZeros = 1000;
  constexpr int kNumTests = 100;
  const TensorShape kTensorShapes[] = {{3}, {4, 5}, {6, 7, 8}, {9, 10, 11, 12}};
  const IndexValidation index_validation = GetParam();
  const bool ordered = (index_validation == IndexValidation::kOrdered);

  for (const TensorShape& test_shape : kTensorShapes) {
    Tensor indices, values, shape;
    GenerateRandomSparseTensor(kNumNonZeros, test_shape, ordered, indices,
                               values, shape);
    // Access tensor values.
    auto indices_mat = indices.matrix<int64_t>();
    const int64_t nnz = indices.dim_size(0);
    const int64_t ndims = indices.dim_size(1);
    // Two distinct rows are needed to form an out-of-order pair. With fewer,
    // the do/while below that picks row2 != row1 would never terminate.
    if (nnz < 2) {
      continue;
    }
    for (int test = 0; test < kNumTests; ++test) {
      // Pick two distinct random index entries to swap.
      const int64_t row1 = RandomPhilox().Uniform64(nnz);
      int64_t row2;
      do {
        row2 = RandomPhilox().Uniform64(nnz);
      } while (row1 == row2);
      for (int64_t dim = 0; dim < ndims; ++dim) {
        std::swap(indices_mat(row1, dim), indices_mat(row2, dim));
      }

      Status indices_valid = ValidateSparseTensor<int64_t>(
          indices, values, shape, index_validation);
      if (ordered) {
        EXPECT_THAT(
            indices_valid,
            StatusIs(error::INVALID_ARGUMENT,
                     MatchesRegex("Sparse index tuple .* is out of order")));
      } else {
        TF_EXPECT_OK(indices_valid);
      }

      // Swap back to restore the valid tensor for the next iteration.
      for (int64_t dim = 0; dim < ndims; ++dim) {
        std::swap(indices_mat(row1, dim), indices_mat(row2, dim));
      }
    }
  }
}
532 
533 INSTANTIATE_TEST_SUITE_P(
534     ValidateSparseTensorTestSuite, ValidateSparseTensorTest,
535     ::testing::Values(IndexValidation::kNone, IndexValidation::kOrdered,
536                       IndexValidation::kUnordered),
537     [](const ::testing::TestParamInfo<ValidateSparseTensorTest::ParamType>&
__anon3f4f3a170202(const ::testing::TestParamInfo<ValidateSparseTensorTest::ParamType>& info) 538            info) {
539       switch (info.param) {
540         case IndexValidation::kNone:
541           return "None";
542         case IndexValidation::kUnordered:
543           return "Unordered";
544         case IndexValidation::kOrdered:
545           return "Ordered";
546       }
547     });
548 
549 //==============================================================================
550 // BENCHMARKS
551 //==============================================================================
552 
553 // Benchmark time to validate a valid sparse tensor (the common case, worst-case
554 // latency).
BM_ValidateSparseTensor(::testing::benchmark::State & state,TensorShape dense_shape,IndexValidation index_validation)555 void BM_ValidateSparseTensor(::testing::benchmark::State& state,
556                              TensorShape dense_shape,
557                              IndexValidation index_validation) {
558   Tensor indices, values, shape;
559   const int64_t nnz = state.range(0);
560   GenerateRandomSparseTensor(nnz, dense_shape, /*ordered=*/true, indices,
561                              values, shape);
562   for (auto s : state) {
563     ::benchmark::DoNotOptimize(ValidateSparseTensor<int64_t>(
564         indices, values, shape, index_validation));
565   }
566 }
567 
568 BENCHMARK_CAPTURE(BM_ValidateSparseTensor, Ordered1024, TensorShape({1024}),
569                   IndexValidation::kOrdered)
570     ->Range(8, 512);
571 BENCHMARK_CAPTURE(BM_ValidateSparseTensor, Unordered1024, TensorShape({1024}),
572                   IndexValidation::kUnordered)
573     ->Range(8, 512);
574 BENCHMARK_CAPTURE(BM_ValidateSparseTensor, Ordered1024x1024,
575                   TensorShape({1024, 1024}), IndexValidation::kOrdered)
576     ->Range(8, 1024);
577 BENCHMARK_CAPTURE(BM_ValidateSparseTensor, Unordered1024x1024,
578                   TensorShape({1024, 1024}), IndexValidation::kUnordered)
579     ->Range(8, 1024);
580 BENCHMARK_CAPTURE(BM_ValidateSparseTensor, Ordered1024x1024x1024,
581                   TensorShape({1024, 1024, 1024}), IndexValidation::kOrdered)
582     ->Range(8, 1024 * 32);
583 BENCHMARK_CAPTURE(BM_ValidateSparseTensor, Unordered1024x1024x1024,
584                   TensorShape({1024, 1024, 1024}), IndexValidation::kUnordered)
585     ->Range(8, 1024 * 32);
586 
587 }  // namespace
588 }  // namespace sparse_utils
589 }  // namespace tensorflow
590