/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/sparse/sparse_tensor.h"

#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace sparse {

namespace {

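// Returns the number of dimensions encoded by the indices matrix `ix` (its
// second dimension).  The caller is responsible for ensuring that `ix` is a
// matrix; that precondition is only DCHECKed here.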
int UnsafeGetDimsFromIx(const Tensor& ix) {
  DCHECK(TensorShapeUtils::IsMatrix(ix.shape()));
  return ix.dim_size(1);
}

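// Checked variant of the above: verifies that `ix` is a matrix and, if so,
// stores the SparseTensor rank in `*result`.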
Status GetDimsFromIx(const Tensor& ix, int* result) {
  if (!TensorShapeUtils::IsMatrix(ix.shape())) {
    return errors::InvalidArgument("indices must be a matrix, but got: ",
                                   ix.shape().DebugString());
  }
  *result = UnsafeGetDimsFromIx(ix);
  return Status::OK();
}

}  // namespace

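// Validates the indices, values, shape, and order, and on success moves the
// components into `*result`.
//
// A minimal usage sketch (the tensor shapes and dense shape below are
// illustrative assumptions, not taken from this file):
//
//   Tensor indices(DT_INT64, TensorShape({3, 2}));  // 3 nonzeros, rank 2
//   Tensor values(DT_FLOAT, TensorShape({3}));      // one value per nonzero
//   SparseTensor st;
//   Status s =
//       SparseTensor::Create(indices, values, TensorShape({4, 5}), &st);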
/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const VarDimArray shape,
                                         const VarDimArray order,
                                         SparseTensor* result) {
  if (ix.dtype() != DT_INT64) {
    return errors::InvalidArgument("indices must be type int64 but got: ",
                                   ix.dtype());
  }
  if (!TensorShapeUtils::IsVector(vals.shape())) {
    return errors::InvalidArgument("vals must be a vec, but got: ",
                                   vals.shape().DebugString());
  }
  if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) {
    return errors::InvalidArgument(
        "indices and values rows (indexing "
        "dimension) must match. (indices = ",
        ix.shape().dim_size(0), ", values = ", vals.shape().dim_size(0), ")");
  }
  int dims = 0;
  TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims));
  if (order.size() != dims) {
    return errors::InvalidArgument("Order length must be SparseTensor rank.");
  }
  if (shape.size() != dims) {
    return errors::InvalidArgument("Shape rank must be SparseTensor rank.");
  }

  result->ix_ = std::move(ix);
  result->vals_ = std::move(vals);
  result->shape_.assign(shape.begin(), shape.end());
  result->order_.assign(order.begin(), order.end());
  result->dims_ = dims;
  return Status::OK();
}

/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const TensorShape& shape,
                                         SparseTensor* result) {
  return Create(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                UndefinedOrder(TensorShapeToVector(shape)), result);
}

/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const VarDimArray shape,
                                         SparseTensor* result) {
  return Create(std::move(ix), std::move(vals), shape, UndefinedOrder(shape),
                result);
}

/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals,
                                         const TensorShape& shape,
                                         const VarDimArray order,
                                         SparseTensor* result) {
  return Create(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                order, result);
}

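// Unchecked counterpart of Create(): argument errors are caught only by
// DCHECKs in debug builds rather than returned as a Status.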
SparseTensor::SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
                           const VarDimArray order)
    : ix_(std::move(ix)),
      vals_(std::move(vals)),
      shape_(shape.begin(), shape.end()),
      order_(order.begin(), order.end()),
      dims_(UnsafeGetDimsFromIx(ix_)) {
  DCHECK_EQ(ix_.dtype(), DT_INT64)
      << "indices must be type int64 but got: " << ix_.dtype();
  DCHECK(TensorShapeUtils::IsVector(vals_.shape()))
      << "vals must be a vec, but got: " << vals_.shape().DebugString();
  DCHECK_EQ(ix_.shape().dim_size(0), vals_.shape().dim_size(0))
      << "indices and values rows (indexing dimension) must match.";
  DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
  DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
}

// Optimized version of `IndicesValid()` with the following requirements:
// * The sparse tensor is one-dimensional.
//
// Returns true if the indices are valid, otherwise false.
// NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
// to obtain a meaningful error message.
bool SparseTensor::IndicesValidVectorFastPath() const {
  DCHECK_EQ(shape_.size(), 1);
  DCHECK_EQ(order_[0], 0);

  const int64 max_index = shape_[0];

  // We maintain separate bools for each validation predicate to enable
  // vectorization across loop iterations.
  bool index_in_range_valid = true;
  bool order_valid = true;

  int64 prev_index = -1;
  const auto ix_t = ix_.matrix<int64>();
  const int64* const index_base_ptr = ix_t.data();

  for (std::size_t n = 0; n < ix_t.dimension(0); ++n) {
    const int64 index = index_base_ptr[n];
    index_in_range_valid = index_in_range_valid & (index < max_index);
    order_valid = order_valid & (index > prev_index);
    prev_index = index;
  }

  return index_in_range_valid & order_valid;
}

// Optimized version of `IndicesValid()` with the following requirements:
// * The sparse tensor is two-dimensional.
// * The tensor's indices are in the "standard" (lexicographic) order.
// * All of the tensor's indices fit within the range of a signed int32.
//
// Returns true if the indices are valid, otherwise false.
// NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
// to obtain a meaningful error message.
bool SparseTensor::IndicesValidMatrix32BitFastPath() const {
  const auto ix_t = ix_.matrix<int64>();
  const int64* const shape_ptr = shape_.data();

  DCHECK_EQ(shape_.size(), 2);
  DCHECK_EQ(order_[0], 0);
  DCHECK_EQ(order_[1], 1);
  DCHECK_LE(shape_ptr[0], std::numeric_limits<int32>::max());
  DCHECK_LE(shape_ptr[1], std::numeric_limits<int32>::max());

  const int32 max_rows = static_cast<int32>(shape_ptr[0]);
  const int32 max_cols = static_cast<int32>(shape_ptr[1]);

  // We maintain separate bools for each validation predicate to enable
  // vectorization across loop iterations.
  bool row_zeros_valid = true;
  bool row_in_range_valid = true;
  bool col_zeros_valid = true;
  bool col_in_range_valid = true;
  bool order_valid = true;

  int64 prev_index = -1;

  // Points to the beginning of the current row of the indices matrix.
  // Each row has two int64 elements, but we use an int32 pointer to access
  // the low and high 32 bits of each element separately. This means that our
  // stride per row is 4 elements.
  const int32* const index_base_ptr =
      reinterpret_cast<const int32*>(ix_t.data());
  const size_t kInt32ElementsPerRow = 4;

  for (std::size_t n = 0; n < ix_t.dimension(0); ++n) {
    const int32* const index_ptr = index_base_ptr + n * kInt32ElementsPerRow;

    // Unpack the values on the current row of the indices matrix.
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
    const int32 row_zeros = index_ptr[0];
    const int32 row_32 = index_ptr[1];
    const int32 col_zeros = index_ptr[2];
    const int32 col_32 = index_ptr[3];
#else
    const int32 row_32 = index_ptr[0];
    const int32 row_zeros = index_ptr[1];
    const int32 col_32 = index_ptr[2];
    const int32 col_zeros = index_ptr[3];
#endif

    // Validate that the high 32 bits of the row and column indices are zero.
    row_zeros_valid = row_zeros_valid & (row_zeros == 0);
    col_zeros_valid = col_zeros_valid & (col_zeros == 0);

    // Validate that the low 32 bits of the row and column indices are within
    // range of the shape.
    row_in_range_valid =
        row_in_range_valid & (row_32 >= 0) & (row_32 < max_rows);
    col_in_range_valid =
        col_in_range_valid & (col_32 >= 0) & (col_32 < max_cols);

    // Interpret the row and column as a concatenated 64-bit integer, and
    // validate that the concatenated indices are in strictly increasing order.
    const int64 concatenated_index =
        (static_cast<int64>(row_32) << 32) + col_32;
    order_valid = order_valid & (concatenated_index > prev_index);
    prev_index = concatenated_index;
  }

  return row_zeros_valid & row_in_range_valid & col_zeros_valid &
         col_in_range_valid & order_valid;
}

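// General validation path.  Checks that every index is within the bounds of
// `shape_`, that consecutive indices are strictly increasing in the tensor's
// order (the dimension order itself when `standard_order` is true, otherwise
// `order_`), and that no index is repeated; returns a descriptive
// InvalidArgument error for the first offending entry.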
template <bool standard_order>
Status SparseTensor::IndicesValidHelper() const {
  const auto ix_t = ix_.matrix<int64>();
  const int64* const shape_ptr = shape_.data();

  for (std::size_t n = 0; n < num_entries(); ++n) {
    bool valid = true;
    bool different = false;
    bool increasing = true;
    if (n == 0) {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_ptr[di]) valid = false;
      }
      different = true;
    } else {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_ptr[di]) valid = false;
        int ordered_dim;
        if (standard_order) {
          ordered_dim = di;
        } else {
          ordered_dim = order_[di];
        }
        int64 diff = ix_t(n, ordered_dim) - ix_t(n - 1, ordered_dim);
        if (diff > 0) different = true;
        if (!different && diff < 0) increasing = false;
      }
    }
    if (TF_PREDICT_FALSE(!valid || !increasing || !different)) {
      string index = strings::StrCat("indices[", n, "] = [");
      for (int di = 0; di < dims_; ++di) {
        strings::StrAppend(&index, ix_t(n, di), di < dims_ - 1 ? "," : "]");
      }
      if (!valid) {
        return errors::InvalidArgument(index,
                                       " is out of bounds: need 0 <= index < [",
                                       str_util::Join(shape_, ","), "]");
      }
      if (!increasing) {
        return errors::InvalidArgument(
            index,
            " is out of order. Many sparse ops require sorted indices.\n"
            "    Use `tf.sparse.reorder` to create a correctly ordered copy."
            "\n\n");
      }
      if (!different) {
        return errors::InvalidArgument(index, " is repeated");
      }
    }
  }

  return Status::OK();
}

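// Entry point for index validation.  Uses the 1-D or 32-bit 2-D fast path when
// the indices are in standard (lexicographic) order and fit the fast path's
// limits, and otherwise falls back to IndicesValidHelper<>.
//
// Typical use (a sketch; call sites in the codebase vary):
//   TF_RETURN_IF_ERROR(st.IndicesValid());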
Status SparseTensor::IndicesValid() const {
  if (shape_.size() == 1 && IndicesValidVectorFastPath()) {
    return Status::OK();
  }

  bool standard_order = true;
  for (size_t i = 0; i < order_.size(); ++i) {
    if (order_[i] < 0) {
      return errors::FailedPrecondition(
          "Order was not provided.  Provide an order at "
          "construction time or run ReorderInPlace");
    }
    standard_order = standard_order && order_[i] == i;
  }

  if (standard_order) {
    if (shape_.size() == 1) {
      if (IndicesValidVectorFastPath()) {
        return Status::OK();
      }
    } else if (shape_.size() == 2 &&
               shape_[0] <= std::numeric_limits<int32>::max() &&
               shape_[1] <= std::numeric_limits<int32>::max()) {
      if (IndicesValidMatrix32BitFastPath()) {
        return Status::OK();
      }
    }
    return IndicesValidHelper<true>();
  } else {
    return IndicesValidHelper<false>();
  }
}

}  // namespace sparse
}  // namespace tensorflow