/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_

#include <limits>
#include <numeric>
#include <vector>

#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
#include "tensorflow/core/util/sparse/group_iterator.h"

namespace tensorflow {
namespace sparse {

class SparseTensor {
 public:
  typedef typename gtl::ArraySlice<int64> VarDimArray;
  typedef typename gtl::InlinedVector<int64, 8> ShapeArray;

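  // Example usage sketch (the tensor contents are illustrative): indices are
  // an int64 matrix of shape [N, rank] and values are a length-N vector.
  //
  //   Tensor ix(DT_INT64, TensorShape({2, 2}));  // 2 entries, rank 2
  //   auto ix_m = ix.matrix<int64>();
  //   ix_m(0, 0) = 0; ix_m(0, 1) = 1;
  //   ix_m(1, 0) = 2; ix_m(1, 1) = 3;
  //   Tensor vals(DT_FLOAT, TensorShape({2}));
  //   vals.vec<float>()(0) = 1.5f;
  //   vals.vec<float>()(1) = 2.5f;
  //   SparseTensor st;
  //   Status s = SparseTensor::Create(ix, vals, TensorShape({4, 4}), &st);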
  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       const VarDimArray order, SparseTensor* result) {
    if (ix.dtype() != DT_INT64) {
      return Status(
          error::INVALID_ARGUMENT,
          strings::StrCat("indices must be type int64 but got: ", ix.dtype()));
    }
    if (!TensorShapeUtils::IsVector(vals.shape())) {
      return Status(error::INVALID_ARGUMENT,
                    strings::StrCat("vals must be a vec, but got: ",
                                    vals.shape().DebugString()));
    }
    if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) {
      return Status(error::INVALID_ARGUMENT,
                    strings::StrCat("indices and values rows (indexing "
                                    "dimension) must match. (indices = ",
                                    ix.shape().dim_size(0), ", values = ",
                                    vals.shape().dim_size(0), ")"));
    }
    int dims = 0;
    TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims));
    if (order.size() != dims) {
      return Status(error::INVALID_ARGUMENT,
                    "Order length must be SparseTensor rank.");
    }
    if (shape.size() != dims) {
      return Status(error::INVALID_ARGUMENT,
                    "Shape rank must be SparseTensor rank.");
    }

    *result = SparseTensor(ix, vals, shape, order);
    return Status::OK();
  }

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       SparseTensor* result) {
    return Create(ix, vals, TensorShapeToVector(shape),
                  UndefinedOrder(TensorShapeToVector(shape)), result);
  }

  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       SparseTensor* result) {
    return Create(ix, vals, shape, UndefinedOrder(shape), result);
  }

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       const VarDimArray order, SparseTensor* result) {
    return Create(ix, vals, TensorShapeToVector(shape), order, result);
  }

  SparseTensor() : dims_(0) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
      : SparseTensor(ix, vals, TensorShapeToVector(shape),
                     UndefinedOrder(TensorShapeToVector(shape))) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
      : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
               const VarDimArray order)
      : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
               const VarDimArray order)
      : ix_(ix),
        vals_(vals),
        shape_(shape.begin(), shape.end()),
        order_(order.begin(), order.end()),
        dims_(UnsafeGetDimsFromIx(ix)) {
    DCHECK_EQ(ix.dtype(), DT_INT64)
        << "indices must be type int64 but got: " << ix.dtype();
    DCHECK(TensorShapeUtils::IsVector(vals.shape()))
        << "vals must be a vec, but got: " << vals.shape().DebugString();
    DCHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
        << "indices and values rows (indexing dimension) must match.";
    DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
    DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
  }

  SparseTensor(const SparseTensor& other)
      : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}

  SparseTensor(SparseTensor&& other)
      : SparseTensor(std::move(other.ix_), std::move(other.vals_),
                     std::move(other.shape_), std::move(other.order_)) {}

  SparseTensor& operator=(const SparseTensor& other) {
    ix_ = other.ix_;
    vals_ = other.vals_;
    shape_ = other.shape_;
    order_ = other.order_;
    dims_ = other.dims_;
    return *this;
  }

  SparseTensor& operator=(SparseTensor&& other) {
    ix_ = std::move(other.ix_);
    vals_ = std::move(other.vals_);
    shape_ = std::move(other.shape_);
    order_ = std::move(other.order_);
    dims_ = std::move(other.dims_);
    return *this;
  }

  std::size_t num_entries() const { return ix_.dim_size(0); }

  int dims() const { return shape_.size(); }

  const Tensor& indices() const { return ix_; }

  const Tensor& values() const { return vals_; }

  DataType dtype() const { return vals_.dtype(); }

  Status IndicesValid() const {
    const auto ix_t = ix_.matrix<int64>();
    for (int64 ord : order_) {
      if (ord < 0) {
        return errors::FailedPrecondition(
            "Order was not provided. Provide an order at "
            "construction time or run ReorderInPlace");
      }
    }

    for (std::size_t n = 0; n < num_entries(); ++n) {
      TF_RETURN_IF_ERROR(IndexValid(ix_t, n));
    }

    return Status::OK();
  }

  VarDimArray shape() const { return shape_; }

  VarDimArray order() const { return order_; }

  // Re-sorts the indices and values according to the dimensions in order.
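  //
  // A minimal usage sketch (assuming a rank-2 SparseTensor `st` holding float
  // values; the template parameter must match values().dtype()):
  //
  //   std::vector<int64> std_order = {0, 1};
  //   st.Reorder<float>(std_order);  // restore lexicographic (row-major) order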
  template <typename T>
  void Reorder(const VarDimArray& order);

  // Returns a group iterable that can be used for clumping indices
  // and values according to the group indices of interest.
  //
  // Precondition: order()[0..group_ix.size()] == group_ix.
  //
  // See the README.md in this directory for more usage information.
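  //
  // A minimal sketch (assuming a rank-2 SparseTensor `st` of floats that has
  // been reordered so that order()[0] == 0, grouping by dimension 0):
  //
  //   for (const auto& g : st.group({0})) {
  //     // g.group()[0] is the index along dimension 0 shared by this group;
  //     // g.values<float>() are the corresponding values.
  //   }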
  GroupIterable group(const VarDimArray& group_ix) const {
    DCHECK_LE(group_ix.size(), dims_);
    for (std::size_t di = 0; di < group_ix.size(); ++di) {
      DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
      DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
      DCHECK_EQ(group_ix[di], order_[di])
          << "Group dimension does not match sorted order";
    }
    return GroupIterable(ix_, vals_, dims_, group_ix);
  }

  // Stores the sparse indices into the dense tensor out.
  // Preconditions:
  //   out->shape().dims() == shape().dims()
  //   out->shape().dim_size(d) >= shape()[d] for all d
  //
  // Returns true on success. Returns false on failure (mismatched dimensions
  // or out-of-bounds indices).
  //
  // If initialize == true, ToDense first overwrites all coefficients in out
  // with 0.
  //
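  // A minimal sketch (assuming a rank-2 SparseTensor `st` of floats whose
  // dense shape is [4, 4]; the output tensor must be allocated by the caller):
  //
  //   Tensor dense(DT_FLOAT, TensorShape({4, 4}));
  //   if (!st.ToDense<float>(&dense)) {
  //     // mismatched dimensions or out-of-bounds indices
  //   }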
  template <typename T>
  bool ToDense(Tensor* out, bool initialize = true);

  // Concat() concatenates all the tensors along their primary (first-order)
  // dimension. All tensors must have identical shapes except along that
  // dimension, and all tensors' order[0] must match: it is the concat
  // dimension.
  //
  // If all of the tensors have identical ordering, then the output
  // will have this ordering. Otherwise the output is marked as having no
  // order, and Reorder<T>() should be called on it before performing any
  // subsequent operations.
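  //
  // A minimal sketch (assuming two rank-2 SparseTensors `a` and `b` of floats,
  // both sorted with order()[0] == 0, so dimension 0 is the concat dimension):
  //
  //   SparseTensor c = SparseTensor::Concat<float>({a, b});
  //   // c.shape()[0] == a.shape()[0] + b.shape()[0]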
  template <typename T>
  static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);

  // Split() splits the input SparseTensor into a list of num_split
  // SparseTensors along the given splitting dimension. If the size of that
  // dimension isn't an integer multiple of num_split, the first
  // (size % num_split) slices each receive one extra element.
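  //
  // A minimal sketch (assuming a SparseTensor `st` of floats whose dimension 0
  // has size 10; the result is three slices of sizes 4, 3 and 3):
  //
  //   std::vector<SparseTensor> pieces;
  //   Status s = SparseTensor::Split<float>(st, /*split_dim=*/0,
  //                                         /*num_split=*/3, &pieces);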
  template <typename T>
  static Status Split(const SparseTensor& tensor, const int split_dim,
                      const int num_split, std::vector<SparseTensor>* result);

  // Slice() slices the input SparseTensor into a SparseTensor based on the
  // specified start and size. Both start and size are 1-D arrays with one
  // element per dimension: start holds the start index and size holds the
  // slice size along each dimension.
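  //
  // A minimal sketch (assuming a rank-2 SparseTensor `st` of floats; keep the
  // 2x3 window whose top-left corner is at [1, 0]):
  //
  //   SparseTensor window =
  //       SparseTensor::Slice<float>(st, /*start=*/{1, 0}, /*size=*/{2, 3});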
  template <typename T>
  static SparseTensor Slice(const SparseTensor& tensor,
                            const gtl::ArraySlice<int64>& start,
                            const gtl::ArraySlice<int64>& size);

  // Picks out the dimensions according to `dim_indices`.
  std::vector<int64> PickDims(gtl::ArraySlice<int64> dim_indices) const {
    std::vector<int64> res(dim_indices.size());
    for (size_t i = 0; i < dim_indices.size(); ++i) {
      res[i] = shape_[dim_indices[i]];
    }
    return res;
  }

 private:
  static Status GetDimsFromIx(const Tensor& ix, int* result) {
    if (!TensorShapeUtils::IsMatrix(ix.shape())) {
      return Status(error::INVALID_ARGUMENT,
                    strings::StrCat("indices must be a matrix, but got: ",
                                    ix.shape().DebugString()));
    }
    *result = UnsafeGetDimsFromIx(ix);
    return Status::OK();
  }

  static int UnsafeGetDimsFromIx(const Tensor& ix) {
    DCHECK(TensorShapeUtils::IsMatrix(ix.shape()));
    return ix.dim_size(1);
  }

  static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
    return ShapeArray(shape.size(), -1);
  }

  static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
    ShapeArray vec(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
    return vec;
  }

  // Helper for IndicesValid()
  inline Status IndexValid(const TTypes<int64>::ConstMatrix& ix_t,
                           int n) const {
    bool valid = true;
    bool different = false;
    bool increasing = true;
    if (n == 0) {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
      }
      different = true;
    } else {
      for (int di = 0; di < dims_; ++di) {
        if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
        int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
        if (diff > 0) different = true;
        if (!different && diff < 0) increasing = false;
      }
    }
    if (TF_PREDICT_FALSE(!valid || !increasing || !different)) {
      string index = strings::StrCat("indices[", n, "] = [");
      for (int di = 0; di < dims_; ++di) {
        strings::StrAppend(&index, ix_t(n, di), di < dims_ - 1 ? "," : "]");
      }
      if (!valid) {
        return errors::InvalidArgument(index,
                                       " is out of bounds: need 0 <= index < [",
                                       str_util::Join(shape_, ","), "]");
      }
      if (!increasing) {
        return errors::InvalidArgument(index, " is out of order");
      }
      if (!different) {
        return errors::InvalidArgument(index, " is repeated");
      }
    }
    return Status::OK();
  }

  // Helper for ToDense<T>()
  template <typename T>
  bool ValidateAndInitializeToDense(Tensor* out, bool initialize);

  // Helper for Split() that returns the slice index.
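  // For example: splitting a dimension of size 10 into num_split = 3 pieces
  // gives split_size = 3 and residual = 1, so dims 0..3 map to slice 0 while
  // dims 4..6 and 7..9 map to slices 1 and 2.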
  static inline int GetSliceIndex(const int dim, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim / split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim / (split_size + 1);
    } else {
      return residual + ((dim - offset) / split_size);
    }
  }

  // Helper for Split() that returns the dimension in the slice.
  static inline int GetDimensionInSlice(const int dim, const int split_size,
                                        const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim % split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim % (split_size + 1);
    } else {
      return (dim - offset) % split_size;
    }
  }

  // Helper for Split() that returns the shape given a slice index.
  static inline int GetSliceShape(const int slice_index, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(slice_index, 0);
    if (residual == 0) return split_size;
    if (slice_index < residual) {
      return split_size + 1;
    } else {
      return split_size;
    }
  }

  Tensor ix_;
  Tensor vals_;
  ShapeArray shape_;
  ShapeArray order_;
  int dims_;
};

// This operation updates the indices and values Tensor rows, so it is
// an in-place algorithm. It requires O(N log N) time and O(N)
// temporary space.
template <typename T>
void SparseTensor::Reorder(const VarDimArray& order) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "Reorder requested with the wrong datatype";
  DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64> reorder(num_entries());
  std::iota(reorder.begin(), reorder.end(), 0);

  // Sort to get order of indices
  switch (order.size()) {
#define CASE_SORT(ORDER_SIZE)                                     \
  case ORDER_SIZE: {                                              \
    FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape());  \
    std::sort(reorder.begin(), reorder.end(), sorter);            \
    break;                                                        \
  }
    CASE_SORT(0);
    CASE_SORT(1);
    CASE_SORT(2);
    CASE_SORT(3);
    CASE_SORT(4);
    CASE_SORT(5);
#undef CASE_SORT
    default: {
      DimComparator sorter(ix_t, order, shape());
      std::sort(reorder.begin(), reorder.end(), sorter);
    }
  }

  // We have a forward reordering, but what we'll need is a
  // permutation (the inverse). This can be calculated with O(1) additional
  // space and O(N) time (INVPERM), but we just do the simple thing here.
  std::vector<size_t> permutation(reorder.size());
  for (std::size_t n = 0; n < reorder.size(); ++n) {
    permutation[reorder[n]] = n;
  }

  // Update indices & values by converting the permutations to
  // a product of transpositions. Iterate over the cycles in the
  // permutation, and convert each of those into a product of
  // transpositions (swaps):
  //   https://en.wikipedia.org/wiki/Cyclic_permutation
  // This is N swaps, 2*N comparisons.
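  // For example, with permutation = {2, 0, 1}: entry 0 belongs at position 2,
  // so rows 0 and 2 are swapped (permutation becomes {1, 0, 2}), then rows 0
  // and 1 (permutation becomes {0, 1, 2}); the 3-cycle settles in two swaps.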
  for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
    while (n != permutation[n]) {
      std::size_t r = permutation[n];
      std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
      std::swap(vals_t(n), vals_t(r));
      std::swap(permutation[n], permutation[r]);
    }
  }

  order_ = ShapeArray(order.begin(), order.end());
}

template <typename T>
bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "ToDense requested with the wrong datatype";

  DCHECK_EQ(out->shape().dims(), dims_)
      << "Incompatible dimensions between SparseTensor and output";

  DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
      << "Output must be type: " << DataTypeToEnum<T>::v()
      << " but got: " << out->dtype();

  // Make sure the dense output is the same rank and has room
  // to hold the SparseTensor.
  const auto& out_shape = out->shape();
  if (shape_.size() != out_shape.dims()) return false;
  for (int d = 0; d < shape_.size(); ++d) {
    if (shape_[d] > out_shape.dim_size(d)) return false;
  }

  if (initialize) {
    auto out_t = out->flat<T>();
    out_t.setConstant(T());
  }

  return true;
}

template <typename T>
bool SparseTensor::ToDense(Tensor* out, bool initialize) {
  if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;

  auto out_t = out->flat<T>();
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64> strides(dims_);
  const auto& out_shape = out->shape();
  if (dims_ > 0) {
    strides[dims_ - 1] = 1;
  }
  for (int d = dims_ - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
  }

  for (int n = 0; n < vals_t.dimension(0); ++n) {
    bool invalid_dims = false;
    int64 ix = 0;
    for (int d = 0; d < dims_; ++d) {
      const int64 ix_n_d = internal::SubtleMustCopy(ix_t(n, d));
      if (!FastBoundsCheck(ix_n_d, out_shape.dim_size(d))) {
        invalid_dims = true;
      }
      ix += strides[d] * ix_n_d;
    }
    if (invalid_dims) return false;
    out_t(ix) = vals_t(n);
  }
  return true;
}

template <typename T>
SparseTensor SparseTensor::Concat(
    const gtl::ArraySlice<SparseTensor>& tensors) {
  DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
  const int dims = tensors[0].dims_;
  DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
  auto order_0 = tensors[0].order();
  const int primary_dim = order_0[0];
  ShapeArray final_order(order_0.begin(), order_0.end());
  ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
  final_shape[primary_dim] = 0;  // We'll build this up as we go along.
  int num_entries = 0;

  bool fully_ordered = true;
  for (const SparseTensor& st : tensors) {
    DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
    DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
        << "Concat requested with the wrong data type";
    DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
    DCHECK_EQ(st.order()[0], primary_dim)
        << "All SparseTensors' order[0] must match. This is the concat dim.";
    if (st.order() != final_order) fully_ordered = false;
    const VarDimArray& st_shape = st.shape();
    for (int d = 0; d < dims - 1; ++d) {
      const int cdim = (d < primary_dim) ? d : d + 1;
      DCHECK_EQ(final_shape[cdim], st_shape[cdim])
          << "All SparseTensors' shapes must match except on the concat dim. "
          << "Concat dim: " << primary_dim
          << ", mismatched shape at dim: " << cdim
          << ". Expecting shape like: [" << str_util::Join(final_shape, ",")
          << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
    }

    // Update dimension of final shape
    final_shape[primary_dim] =
        (final_shape[primary_dim] + st_shape[primary_dim]);

    num_entries += st.num_entries();  // Update number of entries
  }

  // If the input orderings are inconsistent, set the final order to -1s.
  if (!fully_ordered) {
    final_order = UndefinedOrder(final_shape);
  }

  Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
  Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));

  TTypes<int64>::Matrix ix_t = output_ix.matrix<int64>();
  typename TTypes<T>::Vec vals_t = output_vals.vec<T>();

  Eigen::DenseIndex offset = 0;
  int64 shape_offset = 0;
  for (const SparseTensor& st : tensors) {
    const int st_num_entries = st.num_entries();

    // Fill in indices & values.
    std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));

    const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
    auto* ix_out = &ix_t(offset, 0);
    for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
      *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
    }

    offset += st_num_entries;
    shape_offset += st.shape()[primary_dim];
  }

  return SparseTensor(output_ix, output_vals, final_shape, final_order);
}

template <typename T>
Status SparseTensor::Split(const SparseTensor& input_tensor,
                           const int split_dim, const int num_split,
                           std::vector<SparseTensor>* result) {
  std::vector<Tensor> output_indices;
  std::vector<Tensor> output_values;
  std::vector<TensorShape> output_shapes;
  output_indices.reserve(num_split);
  output_values.reserve(num_split);
  output_shapes.reserve(num_split);

  std::vector<typename TTypes<int64>::Matrix> output_indices_t;
  std::vector<typename TTypes<T>::Vec> output_values_t;
  output_indices_t.reserve(num_split);
  output_values_t.reserve(num_split);
  auto input_values_t = input_tensor.values().vec<T>();
  auto input_indices_t = input_tensor.indices().matrix<int64>();

  std::vector<int> num_values(num_split, 0);
  const int num_dim = input_tensor.shape().size();
  const int split_dim_size = input_tensor.shape()[split_dim];
  const int split_size = split_dim_size / num_split;

  if (!(num_split > 0 && num_split <= split_dim_size)) {
    return Status(error::INVALID_ARGUMENT,
                  strings::StrCat("num_split must be in the interval (0, ",
                                  split_dim_size, "]"));
  }
  if (!(split_dim >= 0 && split_dim < num_dim)) {
    return Status(
        error::INVALID_ARGUMENT,
        strings::StrCat("split_dim must be in the interval [0, ", num_dim,
                        ")"));
  }

  const int residual = split_dim_size % num_split;
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_tensor.indices().matrix<int64>()(i, split_dim);
    int slice_index = GetSliceIndex(dim, split_size, residual);
    num_values[slice_index]++;
  }

  for (int i = 0; i < num_split; ++i) {
    // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
    output_indices.emplace_back(DT_INT64,
                                TensorShape({num_values[i], num_dim}));
    output_values.emplace_back(DataTypeToEnum<T>::v(),
                               TensorShape({num_values[i]}));
    output_shapes.emplace_back(input_tensor.shape());
    output_indices_t.emplace_back(output_indices[i].matrix<int64>());
    output_values_t.emplace_back(output_values[i].vec<T>());
    const int size = GetSliceShape(i, split_size, residual);
    output_shapes[i].set_dim(split_dim, size);
  }

  std::vector<int> values_inserted_in_slice(num_split, 0);
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_indices_t(i, split_dim);
    const int slice_index = GetSliceIndex(dim, split_size, residual);
    const int slice_dim = values_inserted_in_slice[slice_index]++;
    output_values_t[slice_index](slice_dim) = input_values_t(i);
    for (int j = 0; j < num_dim; ++j) {
      const int64 original_dim = input_indices_t(i, j);
      output_indices_t[slice_index](slice_dim, j) =
          (j == split_dim)
              ? GetDimensionInSlice(original_dim, split_size, residual)
              : original_dim;
    }
  }

  result->clear();
  result->reserve(num_split);
  for (int i = 0; i < num_split; ++i) {
    SparseTensor tensor;
    Status create_status =
        Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
    if (!create_status.ok()) {
      return create_status;
    }
    result->push_back(std::move(tensor));
  }
  return Status::OK();
}

template <typename T>
SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
                                 const gtl::ArraySlice<int64>& start,
                                 const gtl::ArraySlice<int64>& size) {
  TensorShape output_shape(input_tensor.shape());

  const int dims = input_tensor.dims();
  for (int dim = 0; dim < dims; dim++) {
    int64 dim_size = start[dim] + size[dim] < output_shape.dim_size(dim)
                         ? size[dim]
                         : output_shape.dim_size(dim) - start[dim];
    output_shape.set_dim(dim, dim_size);
  }

  auto input_indices_t = input_tensor.indices().matrix<int64>();
  auto input_values_t = input_tensor.values().vec<T>();

  // Find the number of indices that fall inside start and size.
  int count = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
    // Check whether the index of this entry falls within the range specified
    // by start and size. The loop below walks all dimensions; if the index
    // falls outside [start, start + size) in any dimension, the entry is a
    // "no hit" (hit = false) and is not counted.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    count++;
  }

  Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
  Tensor output_indices(DT_INT64, TensorShape({count, dims}));

  auto output_values_t = output_values.vec<T>();
  auto output_indices_t = output_indices.matrix<int64>();

  // Obtain the output indices that fall inside start and size.
  int index = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
       i++) {
    // The logic here is similar to the above, except that above we only
    // counted the matching indices while here we actually generate the
    // output.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    output_values_t(index) = input_values_t(i);
    for (int dim = 0; dim < dims; dim++) {
      output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
    }
    index++;
  }

  return SparseTensor(output_indices, output_values, output_shape);
}

}  // namespace sparse
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_