1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Utility class for managing sparse array indices.
17
18 #ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
19 #define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
20
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
28
29 namespace xla {
30
31 // Encapsulates the array of indices for a sparse array. A SparseIndexArray
32 // contain indices for up to `max_indices` elements of a sparse array. Each
33 // sparse index is an array of `rank` int64 value that gives the location of a
34 // value within a sparse array. Note that the dimensions of the array are not
35 // checked (except for the rank). To avoid confusion, we refer to the position
36 // of an index within a SparseIndexArray as a sparse index number.
37 class SparseIndexArray {
38 public:
39 SparseIndexArray();
40 SparseIndexArray(const SparseIndexArray&) = default;
41 SparseIndexArray(SparseIndexArray&&) = default;
42 SparseIndexArray& operator=(const SparseIndexArray&) = default;
43 SparseIndexArray& operator=(SparseIndexArray&&) = default;
44
45 // Constructs a SparseIndexArray that can hold up to `max_indices` sparse
46 // indices, with an initial contents obtained from the given array. The rank
47 // is taken from the minor dimension of the array. The major dimension of the
48 // array must not exceed `max_indices`.
49 SparseIndexArray(int64 max_indices, const Array2D<int64>& indices);
50
51 // Like above, but the array is flattened. For example, the following are
52 // equivalent:
53 //
54 // SparseIndexArray(10, 3,
55 // Array2D{
56 // {0, 1, 2},
57 // {3, 4, 5},
58 // {6, 7, 8},
59 // {9, 10, 11},
60 // })
61 //
62 // SparseIndexArray(10, 3,
63 // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})
64 //
65 SparseIndexArray(int64 max_indices, int64 rank,
66 std::vector<int64> indices = {});
67 SparseIndexArray(int64 max_indices, int64 rank,
68 absl::Span<const int64> indices);
69
70 // Returns the number of elements represented by the indices stored in the
71 // array.
72 int64 index_count() const;
73
74 // Returns a slice that refers to the given sparse index number. The argument
75 // must be in the range [0, element_count()).
76 absl::Span<const int64> At(int64 sparse_element_number) const;
77 absl::Span<int64> At(int64 sparse_element_number);
78
79 // Adds the given index at the end of the array. The new size of the
80 // SparseIndexArray must not exceed `max_indices`.
81 void Append(absl::Span<const int64> index);
82
83 // Removes all indices from the array.
84 void Clear();
85
86 // Resizes the array to contain the given number of sparse indices. The new
87 // size must be smaller than `max_indices`. If the new size is larger than
88 // the old size, the value of the new indices is not specified.
89 void Resize(int64 num_indices);
90
91 // Returns true iff all indices are unique and occur in sorted order, and are
92 // valid for the given shape.
93 bool Validate(const Shape& shape) const;
94
rank()95 int64 rank() const { return rank_; }
max_indices()96 int64 max_indices() const { return max_indices_; }
97
98 // Returns a pointer to the int64 array that holds the sparse indices.
mutable_data()99 absl::Span<int64> mutable_data() { return absl::MakeSpan(indices_); }
data()100 absl::Span<const int64> data() const { return indices_; }
101
102 // Sorts this sparse index array along with the set of corresponding values.
103 // The indices and values are sorted in the lexicographic order of the
104 // indices, from smallest to largest.
105 //
106 // For example:
107 //
108 // std::vector<float> v{10.0, 11.0, 12.0};
109 // SparseIndexArray a(10, 3,
110 // {{3, 4, 5},
111 // {1, 2, 3},
112 // {2, 3, 4}});
113 // a.SortWithValues(&v);
114 // // Prints "11.0, 12.0, 10.0":
115 // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl;
116 //
117 template <typename NativeT>
118 void SortWithValues(absl::Span<NativeT> values);
119
120 private:
121 std::vector<int64> indices_;
122 int64 rank_;
123 int64 max_indices_;
124 };
125
126 template <typename NativeT>
SortWithValues(absl::Span<NativeT> values)127 void SparseIndexArray::SortWithValues(absl::Span<NativeT> values) {
128 int64 num_elements = index_count();
129 CHECK_EQ(values.size(), num_elements);
130 std::vector<int64> sort_order;
131 sort_order.reserve(num_elements);
132 for (int64 i = 0; i < num_elements; ++i) {
133 sort_order.push_back(i);
134 }
135 auto sort_order_less = [this](int64 lhs, int64 rhs) {
136 return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0;
137 };
138 absl::c_sort(sort_order, sort_order_less);
139
140 // Reorder the array elements according to sort_order. Work through the array
141 // and follow cycles so we can do the reorder in-place.
142 absl::InlinedVector<int64, 8> saved_index(rank());
143 for (int64 i = 0; i < num_elements; ++i) {
144 // sort_order[i] == -1 indicates the element has already been copied.
145 if (sort_order[i] < 0) {
146 continue;
147 } else if (i == sort_order[i]) {
148 // The element is already in sorted order.
149 sort_order[i] = -1;
150 continue;
151 }
152
153 std::copy_n(At(i).begin(), rank(), saved_index.begin());
154 NativeT saved_value = values[i];
155 int64 j = i;
156 for (;;) {
157 if (sort_order[j] == i) {
158 std::copy_n(saved_index.begin(), rank(), At(j).begin());
159 values[j] = saved_value;
160 sort_order[j] = -1;
161 break;
162 }
163
164 std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin());
165 values[j] = values[sort_order[j]];
166
167 int64 k = sort_order[j];
168 sort_order[j] = -1;
169 j = k;
170 }
171 }
172 }
173
174 } // namespace xla
175
176 #endif // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_
177