• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Helpers for writing OpKernels for sparse tensors.
17 #ifndef TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_
18 #define TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_
19 
20 #include <vector>
21 
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor_types.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 namespace tensorflow {
27 namespace sparse_utils {
28 
29 // Find the index i of the first element for which
30 // indices_mat(sparse_index_begin, 0) < indices_mat(i, 0).
31 // The search is conducted in the open interval
32 // [sparse_index_begin, indices_mat.dimension(0)) and when no such i is found,
33 // indices_mat.dimension(0) is returned.
34 // indices_mat(k, 0) should be non-decreasing over the interval
35 // [begin, indices_mat.dimension(0)).
36 // Requires 0 <= sparse_index_begin < indices_mat.dimension(0).
37 template <typename Tindices>
38 Tindices FindNextDenseRowStartIndex(
39     const Tindices sparse_index_begin,
40     const typename TTypes<Tindices>::ConstMatrix& indices_mat);
41 
42 // Returns the vector v of indices in indices_mat at which new dense matrix
43 // rows begin.
44 // v.front() = 0, v.back() = indices_mat.dimension(0), and for i > 0,
45 // v[i] - v[i-1] is the length of the ith dense row in indices_mat.
46 // *contains_empty_rows = true if and only if indices_mat contains empty rows
47 // (rows without values) between row 0 and the last row.
48 template <typename Tindices>
49 std::vector<Tindices> GetStartIndicesOfEachDenseRow(
50     const typename TTypes<Tindices>::ConstMatrix& indices_mat,
51     bool* contains_empty_rows);
52 
53 // Converts tensor.vec<Tindices> to an std::vector<Tindices> object, appends
54 // the value num_nonzero_entries_in_sparse_mat, and returns the result.
55 template <typename Tindices>
56 std::vector<Tindices> ParseRowStartIndices(
57     const tensorflow::Tensor& tensor,
58     const Tindices num_nonzero_entries_in_sparse_mat);
59 
60 // Returns true if and only if the sparse matrix indices_mat whose row start
61 // indices are represented by row_start_indices has empty dense rows
62 // (between its first and last dense rows).
63 // This function satisfies the identity row_start_indices ==
64 // GetStartIndicesOfEachDenseRow(indices_mat, &return_value).
65 template <typename Tindices>
66 bool ContainsEmptyRows(const std::vector<Tindices>& row_start_indices);
67 
68 }  // namespace sparse_utils
69 }  // namespace tensorflow
70 
71 #endif  // TENSORFLOW_CORE_KERNELS_SPARSE_UTILS_H_
72