1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/kernels/sparse_utils.h"
17
18 #include <cstddef>
19
20 #include "tensorflow/core/framework/tensor_shape.h"
21
22 namespace tensorflow {
23 namespace sparse_utils {
24
25 template <typename Tindices>
FindNextDenseRowStartIndex(const Tindices sparse_index_begin,const typename TTypes<Tindices>::ConstMatrix & indices_mat)26 Tindices FindNextDenseRowStartIndex(
27 const Tindices sparse_index_begin,
28 const typename TTypes<Tindices>::ConstMatrix& indices_mat) {
29 // Search in the index range [begin, end) of indices_mat.
30 Tindices begin = sparse_index_begin;
31 Tindices end = indices_mat.dimension(0);
32 const Tindices orig_sparse_index_end = end;
33
34 // The first dense row we search.
35 const Tindices orig_dense_index_begin = indices_mat(begin, 0);
36 // Early exit if no next dense row index.
37 if (orig_dense_index_begin == static_cast<int64>(indices_mat(end - 1, 0))) {
38 return orig_sparse_index_end;
39 }
40
41 Tindices increment = 1;
42 while (begin + increment < end &&
43 indices_mat(begin + increment, 0) == orig_dense_index_begin) {
44 increment *= 2;
45 }
46 // Narrow the search space as an optimization.
47 if (begin + increment < end) {
48 end = begin + increment;
49 }
50 begin += increment / 2;
51
52 // Perform a binary search on the interval [begin, end) for
53 // dense_row_index_to_find.
54 const Tindices dense_row_index_to_find = orig_dense_index_begin;
55 while (begin < end) {
56 const Tindices m = begin + (end - begin) / 2;
57 const Tindices m_dense_row_index = static_cast<Tindices>(indices_mat(m, 0));
58 if (m_dense_row_index == dense_row_index_to_find &&
59 (m + 1 == orig_sparse_index_end ||
60 static_cast<Tindices>(indices_mat(m + 1, 0)) !=
61 dense_row_index_to_find)) {
62 return m + 1;
63 } else if (m_dense_row_index <= dense_row_index_to_find) {
64 begin = m + 1;
65 } else {
66 end = m;
67 }
68 }
69
70 // No next dense row index.
71 return orig_sparse_index_end;
72 }
73
74 template <typename Tindices>
GetStartIndicesOfEachDenseRow(const typename TTypes<Tindices>::ConstMatrix & indices_mat,bool * contains_empty_rows)75 std::vector<Tindices> GetStartIndicesOfEachDenseRow(
76 const typename TTypes<Tindices>::ConstMatrix& indices_mat,
77 bool* contains_empty_rows) {
78 int64_t start_sparse_index_of_cur_dense_row = 0;
79 std::vector<Tindices> segment_indices;
80 const Tindices num_entries_in_sparse_tensor = indices_mat.dimension(0);
81 const Tindices num_dense_rows_in_sparse_tensor =
82 1 + indices_mat(num_entries_in_sparse_tensor - 1, 0);
83 // Reserve an extra slot for the 0 we store in the first entry by convention.
84 segment_indices.reserve(1 + num_dense_rows_in_sparse_tensor);
85 segment_indices.push_back(0);
86 for (Tindices i = 0; i < indices_mat(0, 0); ++i) {
87 segment_indices.push_back(0);
88 }
89 *contains_empty_rows = indices_mat(0, 0) > 0;
90 while (true) {
91 const Tindices start_sparse_index_of_next_dense_row =
92 FindNextDenseRowStartIndex<Tindices>(
93 start_sparse_index_of_cur_dense_row, indices_mat);
94 if (start_sparse_index_of_next_dense_row == num_entries_in_sparse_tensor) {
95 segment_indices.push_back(start_sparse_index_of_next_dense_row);
96 break;
97 }
98 // Encode the length of the current dense row as well as the lengths of all
99 // the empty rows until the next dense row,
100 for (Tindices i = 0;
101 i < indices_mat(start_sparse_index_of_next_dense_row, 0) -
102 indices_mat(start_sparse_index_of_cur_dense_row, 0);
103 ++i) {
104 segment_indices.push_back(start_sparse_index_of_next_dense_row);
105 }
106 // If there is more than one row between the current and next non-empty
107 // rows then those rows are empty.
108 *contains_empty_rows |=
109 indices_mat(start_sparse_index_of_next_dense_row, 0) -
110 indices_mat(start_sparse_index_of_cur_dense_row, 0) >
111 1;
112 start_sparse_index_of_cur_dense_row = start_sparse_index_of_next_dense_row;
113 }
114 return segment_indices;
115 }
116
117 template <typename Tindices>
ParseRowStartIndices(const tensorflow::Tensor & tensor,const Tindices num_nonzero_entries_in_sparse_mat)118 std::vector<Tindices> ParseRowStartIndices(
119 const tensorflow::Tensor& tensor,
120 const Tindices num_nonzero_entries_in_sparse_mat) {
121 std::vector<Tindices> out;
122 auto vec = tensor.vec<Tindices>();
123 out.reserve(vec.size() + 1);
124 for (size_t i = 0; i < vec.dimension(0); ++i) {
125 out.push_back(vec(i));
126 }
127 out.push_back(num_nonzero_entries_in_sparse_mat);
128 return out;
129 }
130
// Returns true if two consecutive elements of `row_start_indices` are equal,
// i.e. if some dense row has zero length, ignoring the final pair of
// elements.
//
// The final pair (the last dense row) is skipped on purpose since that row is
// always non-empty. As a consequence the result is false for any input with
// fewer than three elements, including an empty vector.
template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
  // The bound is written as `i + 1 < size()` instead of `i < size() - 1`:
  // size() is unsigned, so the latter wraps around for an empty vector and
  // the loop would index past the end.
  for (size_t i = 1; i + 1 < row_start_indices.size(); ++i) {
    if (row_start_indices[i] == row_start_indices[i - 1]) {
      return true;
    }
  }
  return false;
}
142
143 #define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex) \
144 template TypeIndex FindNextDenseRowStartIndex<TypeIndex>( \
145 const TypeIndex sparse_index_begin, \
146 const TTypes<TypeIndex>::ConstMatrix& indices_mat); \
147 template std::vector<TypeIndex> GetStartIndicesOfEachDenseRow<TypeIndex>( \
148 const TTypes<TypeIndex>::ConstMatrix& indices_mat, \
149 bool* contains_empty_rows); \
150 template bool ContainsEmptyRows<TypeIndex>( \
151 const std::vector<TypeIndex>& row_start_indices); \
152 template std::vector<TypeIndex> ParseRowStartIndices<TypeIndex>( \
153 const tensorflow::Tensor& tensor, \
154 const TypeIndex num_nonzero_entries_in_sparse_mat);
155
156 REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
157 REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
158 REGISTER_SPARSE_UTIL_FUNCTIONS(uint8);
159 REGISTER_SPARSE_UTIL_FUNCTIONS(uint16);
160 REGISTER_SPARSE_UTIL_FUNCTIONS(uint32);
161 REGISTER_SPARSE_UTIL_FUNCTIONS(uint64);
162
163 } // namespace sparse_utils
164 } // namespace tensorflow
165