/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/sparse/sparse_tensor.h"

#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace sparse {

namespace {

int UnsafeGetDimsFromIx(const Tensor& ix) {
  DCHECK(TensorShapeUtils::IsMatrix(ix.shape()));
  return ix.dim_size(1);
}

Status GetDimsFromIx(const Tensor& ix, int* result) {
  if (!TensorShapeUtils::IsMatrix(ix.shape())) {
    return errors::InvalidArgument("indices must be a matrix, but got: ",
                                   ix.shape().DebugString());
  }
  *result = UnsafeGetDimsFromIx(ix);
  return Status::OK();
}

}  // namespace

/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const VarDimArray shape,
                                         const VarDimArray order,
                                         SparseTensor* result) {
  if (ix.dtype() != DT_INT64) {
    return errors::InvalidArgument("indices must be type int64 but got: ",
                                   ix.dtype());
  }
  if (!TensorShapeUtils::IsVector(vals.shape())) {
    return errors::InvalidArgument("vals must be a vec, but got: ",
                                   vals.shape().DebugString());
  }
  if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "indices and values rows (indexing "
        "dimension) must match. (indices = ",
        ix.shape().dim_size(0), ", values = ", vals.shape().dim_size(0), ")");
  }
  int dims = 0;
  TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims));
  if (order.size() != dims) {
    return errors::InvalidArgument("Order length must be SparseTensor rank.");
  }
  if (shape.size() != dims) {
    return errors::InvalidArgument("Shape rank must be SparseTensor rank.");
  }

  result->ix_ = std::move(ix);
  result->vals_ = std::move(vals);
  result->shape_.assign(shape.begin(), shape.end());
  result->order_.assign(order.begin(), order.end());
  result->dims_ = dims;
  return Status::OK();
}

/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const TensorShape& shape,
                                         SparseTensor* result) {
  return Create(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                UndefinedOrder(TensorShapeToVector(shape)), result);
}

/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const VarDimArray shape,
                                         SparseTensor* result) {
  return Create(std::move(ix), std::move(vals), shape, UndefinedOrder(shape),
                result);
}

/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const TensorShape& shape,
                                         const VarDimArray order,
                                         SparseTensor* result) {
  return Create(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                order, result);
}

SparseTensor::SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
                           const VarDimArray order)
    : ix_(std::move(ix)),
      vals_(std::move(vals)),
      shape_(shape.begin(), shape.end()),
      order_(order.begin(), order.end()),
      dims_(UnsafeGetDimsFromIx(ix_)) {
  DCHECK_EQ(ix_.dtype(), DT_INT64)
      << "indices must be type int64 but got: " << ix_.dtype();
  DCHECK(TensorShapeUtils::IsVector(vals_.shape()))
      << "vals must be a vec, but got: " << vals_.shape().DebugString();
  DCHECK_EQ(ix_.shape().dim_size(0), vals_.shape().dim_size(0))
      << "indices and values rows (indexing dimension) must match.";
  DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
  DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
}

// Optimized version of `IndicesValid()` with the following requirements:
// * The sparse tensor is one-dimensional.
//
// Returns true if the indices are valid, otherwise false.
// NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
// to obtain a meaningful error message.
bool SparseTensor::IndicesValidVectorFastPath() const {
  DCHECK_EQ(shape_.size(), 1);
  DCHECK_EQ(order_[0], 0);

  const int64 max_index = shape_[0];

  // We maintain separate bools for each validation predicate to enable
  // vectorization across loop iterations.
  bool index_in_range_valid = true;
  bool order_valid = true;

  int64 prev_index = -1;
  const auto ix_t = ix_.matrix<int64>();
  const int64* const index_base_ptr = ix_t.data();

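  // Illustrative example (hypothetical values): for indices {0, 2, 2} with
  // max_index = 4, the repeated 2 fails the strictly-increasing check, so
  // this fast path returns false and IndicesValid() falls back to
  // IndicesValidHelper() to produce a detailed error message.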
  for (std::size_t n = 0; n < ix_t.dimension(0); ++n) {
    const int64 index = index_base_ptr[n];
    index_in_range_valid = index_in_range_valid & (index < max_index);
    order_valid = order_valid & (index > prev_index);
    prev_index = index;
  }

  return index_in_range_valid & order_valid;
}

// Optimized version of `IndicesValid()` with the following requirements:
// * The sparse tensor is two-dimensional.
// * The tensor's indices are in the "standard" (lexicographic) order.
// * All of the tensor's indices fit within the range of a signed int32.
//
// Returns true if the indices are valid, otherwise false.
// NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
// to obtain a meaningful error message.
bool SparseTensor::IndicesValidMatrix32BitFastPath() const {
  const auto ix_t = ix_.matrix<int64>();
  const int64* const shape_ptr = shape_.data();

  DCHECK_EQ(shape_.size(), 2);
  DCHECK_EQ(order_[0], 0);
  DCHECK_EQ(order_[1], 1);
  DCHECK_LE(shape_ptr[0], std::numeric_limits<int32>::max());
  DCHECK_LE(shape_ptr[1], std::numeric_limits<int32>::max());

  const int32 max_rows = static_cast<int32>(shape_ptr[0]);
  const int32 max_cols = static_cast<int32>(shape_ptr[1]);

  // We maintain separate bools for each validation predicate to enable
  // vectorization across loop iterations.
  bool row_zeros_valid = true;
  bool row_in_range_valid = true;
  bool col_zeros_valid = true;
  bool col_in_range_valid = true;
  bool order_valid = true;

  int64 prev_index = -1;

  // Points to the beginning of the current row of the indices matrix.
  // Each row has two int64 elements, but we use an int32 pointer to access
  // the low and high 32 bits of each element separately. This means that our
  // stride per row is 4 elements.
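  // For example (illustrative, assuming little-endian layout): the index row
  // {row=3, col=5} is stored as two int64 words {3, 5}, which this int32 view
  // reads as {3, 0, 5, 0}; the zero high words are exactly what the
  // `row_zeros`/`col_zeros` checks below verify.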
  const int32* const index_base_ptr =
      reinterpret_cast<const int32*>(ix_t.data());
  const size_t kInt32ElementsPerRow = 4;

  for (std::size_t n = 0; n < ix_t.dimension(0); ++n) {
    const int32* const index_ptr = index_base_ptr + n * kInt32ElementsPerRow;

    // Unpack the values on the current row of the indices matrix.
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    const int32 row_zeros = index_ptr[0];
    const int32 row_32 = index_ptr[1];
    const int32 col_zeros = index_ptr[2];
    const int32 col_32 = index_ptr[3];
#else
    const int32 row_32 = index_ptr[0];
    const int32 row_zeros = index_ptr[1];
    const int32 col_32 = index_ptr[2];
    const int32 col_zeros = index_ptr[3];
#endif

    // Validate that the high 32 bits of the row and column indices are zero.
    row_zeros_valid = row_zeros_valid & (row_zeros == 0);
    col_zeros_valid = col_zeros_valid & (col_zeros == 0);

    // Validate that the low 32 bits of the row and column indices are within
    // range of the shape.
    row_in_range_valid =
        row_in_range_valid & (row_32 >= 0) & (row_32 < max_rows);
    col_in_range_valid =
        col_in_range_valid & (col_32 >= 0) & (col_32 < max_cols);

    // Interpret the row and column as a concatenated 64-bit integer, and
    // validate that the concatenated indices are in strictly increasing order.
    const int64 concatenated_index =
        (static_cast<int64>(row_32) << 32) + col_32;
    order_valid = order_valid & (concatenated_index > prev_index);
    prev_index = concatenated_index;
  }

  return row_zeros_valid & row_in_range_valid & col_zeros_valid &
         col_in_range_valid & order_valid;
}

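// The generic (slow-path) validation below walks each row of the indices
// matrix and checks that (a) every coordinate lies in [0, shape[d]), (b) rows
// are strictly increasing in the (possibly permuted) dimension order, and
// (c) no row is repeated. As an illustrative example, for indices
// {{0, 1}, {0, 1}} the second row has diff == 0 in every ordered dimension,
// so `different` stays false and an "is repeated" error is returned.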
template <bool standard_order>
Status SparseTensor::IndicesValidHelper() const {
  const auto ix_t = ix_.matrix<int64>();
  const int64* const shape_ptr = shape_.data();

  for (std::size_t n = 0; n < num_entries(); ++n) {
    bool valid = true;
    bool different = false;
    bool increasing = true;
    if (n == 0) {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_ptr[di]) valid = false;
      }
      different = true;
    } else {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_ptr[di]) valid = false;
        int ordered_dim;
        if (standard_order) {
          ordered_dim = di;
        } else {
          ordered_dim = order_[di];
        }
        int64 diff = ix_t(n, ordered_dim) - ix_t(n - 1, ordered_dim);
        if (diff > 0) different = true;
        if (!different && diff < 0) increasing = false;
      }
    }
    if (TF_PREDICT_FALSE(!valid || !increasing || !different)) {
      string index = strings::StrCat("indices[", n, "] = [");
      for (int di = 0; di < dims_; ++di) {
        strings::StrAppend(&index, ix_t(n, di), di < dims_ - 1 ? "," : "]");
      }
      if (!valid) {
        return errors::InvalidArgument(index,
                                       " is out of bounds: need 0 <= index < [",
                                       str_util::Join(shape_, ","), "]");
      }
      if (!increasing) {
        return errors::InvalidArgument(
            index,
            " is out of order. Many sparse ops require sorted indices.\n"
            " Use `tf.sparse.reorder` to create a correctly ordered copy."
            "\n\n");
      }
      if (!different) {
        return errors::InvalidArgument(index, " is repeated");
      }
    }
  }

  return Status::OK();
}

Status SparseTensor::IndicesValid() const {
  if (shape_.size() == 1 && IndicesValidVectorFastPath()) {
    return Status::OK();
  }

  bool standard_order = true;
  for (size_t i = 0; i < order_.size(); ++i) {
    if (order_[i] < 0) {
      return errors::FailedPrecondition(
          "Order was not provided. Provide an order at "
          "construction time or run ReorderInPlace");
    }
    standard_order = standard_order && order_[i] == i;
  }

  if (standard_order) {
    if (shape_.size() == 1) {
      if (IndicesValidVectorFastPath()) {
        return Status::OK();
      }
    } else if (shape_.size() == 2 &&
               shape_[0] <= std::numeric_limits<int32>::max() &&
               shape_[1] <= std::numeric_limits<int32>::max()) {
      if (IndicesValidMatrix32BitFastPath()) {
        return Status::OK();
      }
    }
    return IndicesValidHelper<true>();
  } else {
    return IndicesValidHelper<false>();
  }
}

}  // namespace sparse
}  // namespace tensorflow