• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/kernels/sparse_utils.h"
17 
18 #include <cstddef>
19 
20 #include "tensorflow/core/framework/tensor_shape.h"
21 
22 namespace tensorflow {
23 namespace sparse_utils {
24 
25 template <typename Tindices>
FindNextDenseRowStartIndex(const Tindices sparse_index_begin,const typename TTypes<Tindices>::ConstMatrix & indices_mat)26 Tindices FindNextDenseRowStartIndex(
27     const Tindices sparse_index_begin,
28     const typename TTypes<Tindices>::ConstMatrix& indices_mat) {
29   // Search in the index range [begin, end) of indices_mat.
30   Tindices begin = sparse_index_begin;
31   Tindices end = indices_mat.dimension(0);
32   const Tindices orig_sparse_index_end = end;
33 
34   // The first dense row we search.
35   const Tindices orig_dense_index_begin = indices_mat(begin, 0);
36   // Early exit if no next dense row index.
37   if (orig_dense_index_begin == static_cast<int64>(indices_mat(end - 1, 0))) {
38     return orig_sparse_index_end;
39   }
40 
41   Tindices increment = 1;
42   while (begin + increment < end &&
43          indices_mat(begin + increment, 0) == orig_dense_index_begin) {
44     increment *= 2;
45   }
46   // Narrow the search space as an optimization.
47   if (begin + increment < end) {
48     end = begin + increment;
49   }
50   begin += increment / 2;
51 
52   // Perform a binary search on the interval [begin, end) for
53   // dense_row_index_to_find.
54   const Tindices dense_row_index_to_find = orig_dense_index_begin;
55   while (begin < end) {
56     const Tindices m = begin + (end - begin) / 2;
57     const Tindices m_dense_row_index = static_cast<Tindices>(indices_mat(m, 0));
58     if (m_dense_row_index == dense_row_index_to_find &&
59         (m + 1 == orig_sparse_index_end ||
60          static_cast<Tindices>(indices_mat(m + 1, 0)) !=
61              dense_row_index_to_find)) {
62       return m + 1;
63     } else if (m_dense_row_index <= dense_row_index_to_find) {
64       begin = m + 1;
65     } else {
66       end = m;
67     }
68   }
69 
70   // No next dense row index.
71   return orig_sparse_index_end;
72 }
73 
74 template <typename Tindices>
GetStartIndicesOfEachDenseRow(const typename TTypes<Tindices>::ConstMatrix & indices_mat,bool * contains_empty_rows)75 std::vector<Tindices> GetStartIndicesOfEachDenseRow(
76     const typename TTypes<Tindices>::ConstMatrix& indices_mat,
77     bool* contains_empty_rows) {
78   int64_t start_sparse_index_of_cur_dense_row = 0;
79   std::vector<Tindices> segment_indices;
80   const Tindices num_entries_in_sparse_tensor = indices_mat.dimension(0);
81   const Tindices num_dense_rows_in_sparse_tensor =
82       1 + indices_mat(num_entries_in_sparse_tensor - 1, 0);
83   // Reserve an extra slot for the 0 we store in the first entry by convention.
84   segment_indices.reserve(1 + num_dense_rows_in_sparse_tensor);
85   segment_indices.push_back(0);
86   for (Tindices i = 0; i < indices_mat(0, 0); ++i) {
87     segment_indices.push_back(0);
88   }
89   *contains_empty_rows = indices_mat(0, 0) > 0;
90   while (true) {
91     const Tindices start_sparse_index_of_next_dense_row =
92         FindNextDenseRowStartIndex<Tindices>(
93             start_sparse_index_of_cur_dense_row, indices_mat);
94     if (start_sparse_index_of_next_dense_row == num_entries_in_sparse_tensor) {
95       segment_indices.push_back(start_sparse_index_of_next_dense_row);
96       break;
97     }
98     // Encode the length of the current dense row as well as the lengths of all
99     // the empty rows until the next dense row,
100     for (Tindices i = 0;
101          i < indices_mat(start_sparse_index_of_next_dense_row, 0) -
102                  indices_mat(start_sparse_index_of_cur_dense_row, 0);
103          ++i) {
104       segment_indices.push_back(start_sparse_index_of_next_dense_row);
105     }
106     // If there is more than one row between the current and next non-empty
107     // rows then those rows are empty.
108     *contains_empty_rows |=
109         indices_mat(start_sparse_index_of_next_dense_row, 0) -
110             indices_mat(start_sparse_index_of_cur_dense_row, 0) >
111         1;
112     start_sparse_index_of_cur_dense_row = start_sparse_index_of_next_dense_row;
113   }
114   return segment_indices;
115 }
116 
117 template <typename Tindices>
ParseRowStartIndices(const tensorflow::Tensor & tensor,const Tindices num_nonzero_entries_in_sparse_mat)118 std::vector<Tindices> ParseRowStartIndices(
119     const tensorflow::Tensor& tensor,
120     const Tindices num_nonzero_entries_in_sparse_mat) {
121   std::vector<Tindices> out;
122   auto vec = tensor.vec<Tindices>();
123   out.reserve(vec.size() + 1);
124   for (size_t i = 0; i < vec.dimension(0); ++i) {
125     out.push_back(vec(i));
126   }
127   out.push_back(num_nonzero_entries_in_sparse_mat);
128   return out;
129 }
130 
// Returns true if any two consecutive elements of `row_start_indices`, not
// counting the final pair, are equal, i.e. if some dense row other than the
// last one has zero length. Returns false for vectors with fewer than three
// elements, including an empty vector.
template <typename Tindices>
bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices) {
  const size_t num_indices = row_start_indices.size();
  // Skip checking the length of the last dense row since it is
  // always non-empty. The bound is written as `i + 1 < num_indices` rather
  // than `i < num_indices - 1` so that an empty vector does not wrap the
  // unsigned size around and index out of range.
  for (size_t i = 1; i + 1 < num_indices; ++i) {
    if (row_start_indices[i] == row_start_indices[i - 1]) {
      return true;
    }
  }
  return false;
}
142 
143 #define REGISTER_SPARSE_UTIL_FUNCTIONS(TypeIndex)                           \
144   template TypeIndex FindNextDenseRowStartIndex<TypeIndex>(                 \
145       const TypeIndex sparse_index_begin,                                   \
146       const TTypes<TypeIndex>::ConstMatrix& indices_mat);                   \
147   template std::vector<TypeIndex> GetStartIndicesOfEachDenseRow<TypeIndex>( \
148       const TTypes<TypeIndex>::ConstMatrix& indices_mat,                    \
149       bool* contains_empty_rows);                                           \
150   template bool ContainsEmptyRows<TypeIndex>(                               \
151       const std::vector<TypeIndex>& row_start_indices);                     \
152   template std::vector<TypeIndex> ParseRowStartIndices<TypeIndex>(          \
153       const tensorflow::Tensor& tensor,                                     \
154       const TypeIndex num_nonzero_entries_in_sparse_mat);
155 
156 REGISTER_SPARSE_UTIL_FUNCTIONS(int32);
157 REGISTER_SPARSE_UTIL_FUNCTIONS(int64);
158 REGISTER_SPARSE_UTIL_FUNCTIONS(uint8);
159 REGISTER_SPARSE_UTIL_FUNCTIONS(uint16);
160 REGISTER_SPARSE_UTIL_FUNCTIONS(uint32);
161 REGISTER_SPARSE_UTIL_FUNCTIONS(uint64);
162 
163 }  // namespace sparse_utils
164 }  // namespace tensorflow
165