• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
17 #define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
18 
19 #include <limits>
20 #include <numeric>
21 #include <vector>
22 
23 #include "absl/base/macros.h"
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/core/framework/bounds_check.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/framework/types.pb.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/status.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36 #include "tensorflow/core/util/sparse/dim_comparator.h"
37 #include "tensorflow/core/util/sparse/group_iterator.h"
38 
39 namespace tensorflow {
40 namespace sparse {
41 
42 class SparseTensor {
43  public:
44   typedef typename gtl::ArraySlice<int64> VarDimArray;
45   typedef typename gtl::InlinedVector<int64, 8> ShapeArray;
46 
Create(Tensor ix,Tensor vals,const VarDimArray shape,const VarDimArray order,SparseTensor * result)47   static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
48                        const VarDimArray order, SparseTensor* result) {
49     if (ix.dtype() != DT_INT64) {
50       return Status(
51           error::INVALID_ARGUMENT,
52           strings::StrCat("indices must be type int64 but got: ", ix.dtype()));
53     }
54     if (!TensorShapeUtils::IsVector(vals.shape())) {
55       return Status(error::INVALID_ARGUMENT,
56                     strings::StrCat("vals must be a vec, but got: ",
57                                     vals.shape().DebugString()));
58     }
59     if (ix.shape().dim_size(0) != vals.shape().dim_size(0)) {
60       return Status(error::INVALID_ARGUMENT,
61                     strings::StrCat("indices and values rows (indexing "
62                                     "dimension) must match. (indices = ",
63                                     ix.shape().dim_size(0), ", values = ",
64                                     vals.shape().dim_size(0), ")"));
65     }
66     int dims = 0;
67     TF_RETURN_IF_ERROR(GetDimsFromIx(ix, &dims));
68     if (order.size() != dims) {
69       return Status(error::INVALID_ARGUMENT,
70                     "Order length must be SparseTensor rank.");
71     }
72     if (shape.size() != dims) {
73       return Status(error::INVALID_ARGUMENT,
74                     "Shape rank must be SparseTensor rank.");
75     }
76 
77     *result = SparseTensor(ix, vals, shape, order);
78     return Status();
79   }
80 
Create(Tensor ix,Tensor vals,const TensorShape & shape,SparseTensor * result)81   static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
82                        SparseTensor* result) {
83     return Create(ix, vals, TensorShapeToVector(shape),
84                   UndefinedOrder(TensorShapeToVector(shape)), result);
85   }
86 
Create(Tensor ix,Tensor vals,const VarDimArray shape,SparseTensor * result)87   static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
88                        SparseTensor* result) {
89     return Create(ix, vals, shape, UndefinedOrder(shape), result);
90   }
91 
Create(Tensor ix,Tensor vals,const TensorShape & shape,const VarDimArray order,SparseTensor * result)92   static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
93                        const VarDimArray order, SparseTensor* result) {
94     return Create(ix, vals, TensorShapeToVector(shape), order, result);
95   }
96 
SparseTensor()97   SparseTensor() : dims_(0) {}
98 
99   ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix,Tensor vals,const TensorShape & shape)100   SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
101       : SparseTensor(ix, vals, TensorShapeToVector(shape),
102                      UndefinedOrder(TensorShapeToVector(shape))) {}
103 
104   ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix,Tensor vals,const VarDimArray shape)105   SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
106       : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
107 
108   ABSL_DEPRECATED("use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix,Tensor vals,const TensorShape & shape,const VarDimArray order)109   SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
110                const VarDimArray order)
111       : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
112 
113   ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix,Tensor vals,const VarDimArray shape,const VarDimArray order)114   SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
115                const VarDimArray order)
116       : ix_(ix),
117         vals_(vals),
118         shape_(shape.begin(), shape.end()),
119         order_(order.begin(), order.end()),
120         dims_(UnsafeGetDimsFromIx(ix)) {
121     DCHECK_EQ(ix.dtype(), DT_INT64)
122         << "indices must be type int64 but got: " << ix.dtype();
123     DCHECK(TensorShapeUtils::IsVector(vals.shape()))
124         << "vals must be a vec, but got: " << vals.shape().DebugString();
125     DCHECK_EQ(ix.shape().dim_size(0), vals.shape().dim_size(0))
126         << "indices and values rows (indexing dimension) must match.";
127     DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank.";
128     DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
129   }
130 
SparseTensor(const SparseTensor & other)131   SparseTensor(const SparseTensor& other)
132       : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}
133 
SparseTensor(SparseTensor && other)134   SparseTensor(SparseTensor&& other)
135       : SparseTensor(std::move(other.ix_), std::move(other.vals_),
136                      std::move(other.shape_), std::move(other.order_)) {}
137 
138   SparseTensor& operator=(const SparseTensor& other) {
139     ix_ = other.ix_;
140     vals_ = other.vals_;
141     shape_ = other.shape_;
142     order_ = other.order_;
143     dims_ = other.dims_;
144     return *this;
145   }
146 
147   SparseTensor& operator=(SparseTensor&& other) {
148     ix_ = std::move(other.ix_);
149     vals_ = std::move(other.vals_);
150     shape_ = std::move(other.shape_);
151     order_ = std::move(other.order_);
152     dims_ = std::move(other.dims_);
153     return *this;
154   }
155 
num_entries()156   std::size_t num_entries() const { return ix_.dim_size(0); }
157 
dims()158   int dims() const { return shape_.size(); }
159 
indices()160   const Tensor& indices() const { return ix_; }
161 
values()162   const Tensor& values() const { return vals_; }
163 
dtype()164   DataType dtype() const { return vals_.dtype(); }
165 
IndicesValid()166   Status IndicesValid() const {
167     const auto ix_t = ix_.matrix<int64>();
168     for (int64 ord : order_) {
169       if (ord < 0) {
170         return errors::FailedPrecondition(
171             "Order was not provided.  Provide an order at "
172             "construction time or run ReorderInPlace");
173       }
174     }
175 
176     for (std::size_t n = 0; n < num_entries(); ++n) {
177       TF_RETURN_IF_ERROR(IndexValid(ix_t, n));
178     }
179 
180     return Status::OK();
181   }
182 
shape()183   VarDimArray shape() const { return shape_; }
184 
order()185   VarDimArray order() const { return order_; }
186 
187   // Resorts the indices and values according to the dimensions in order.
188   template <typename T>
189   void Reorder(const VarDimArray& order);
190 
191   // Returns a group iterable that can be used for clumping indices
192   // and values according to the group indices of interest.
193   //
194   // Precondition: order()[0..group_ix.size()] == group_ix.
195   //
196   // See the README.md in this directory for more usage information.
group(const VarDimArray & group_ix)197   GroupIterable group(const VarDimArray& group_ix) const {
198     DCHECK_LE(group_ix.size(), dims_);
199     for (std::size_t di = 0; di < group_ix.size(); ++di) {
200       DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
201       DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
202       DCHECK_EQ(group_ix[di], order_[di])
203           << "Group dimension does not match sorted order";
204     }
205     return GroupIterable(ix_, vals_, dims_, group_ix);
206   }
207 
208   // Stores the sparse indices into the dense tensor out.
209   // Preconditions:
210   //   out->shape().dims() == shape().dims()
211   //   out->shape().dim_size(d) >= shape(d) for all d
212   //
213   // Returns true on success.  False on failure (mismatched dimensions
214   // or out-of-bounds indices).
215   //
216   // If initialize==True, ToDense first overwrites all coefficients in out to 0.
217   //
218   template <typename T>
219   bool ToDense(Tensor* out, bool initialize = true);
220 
221   // Concat() will concatenate all the tensors according to their first order
222   // dimension.  All tensors must have identical shape except for
223   // the first order dimension.  All tensors orders' first dimension
224   // must match.
225   //
226   // If all of the tensors have identical ordering, then the output
227   // will have this ordering.  Otherwise the output is set as not
228   // having any order and a Reorder<T>() should be called on it before
229   // performing any subsequent operations.
230   template <typename T>
231   static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
232 
233   // Split() will split the input SparseTensor into a list of num_split
234   // SparseTensor given a splitting dimension. If the input dimension range
235   // isn't an integer multiple of split_dim, we add one extra dimension for
236   // each slice.
237   template <typename T>
238   static Status Split(const SparseTensor& tensor, const int split_dim,
239                       const int num_split, std::vector<SparseTensor>* result);
240 
241   // Slice() will slice the input SparseTensor into a SparseTensor based on
242   // specified start and size. Both start and size are 1-D array with each
243   // element of the array representing one dimension. The start is the start
244   // index at each dimension and the size is the size at each dimension.
245   template <typename T>
246   static SparseTensor Slice(const SparseTensor& tensor,
247                             const gtl::ArraySlice<int64>& start,
248                             const gtl::ArraySlice<int64>& size);
249 
250   // Picks out the dimensions according to `dim_indices`.
PickDims(gtl::ArraySlice<int64> dim_indices)251   std::vector<int64> PickDims(gtl::ArraySlice<int64> dim_indices) const {
252     std::vector<int64> res(dim_indices.size());
253     for (size_t i = 0; i < dim_indices.size(); ++i) {
254       res[i] = shape_[dim_indices[i]];
255     }
256     return res;
257   }
258 
259  private:
GetDimsFromIx(const Tensor & ix,int * result)260   static Status GetDimsFromIx(const Tensor& ix, int* result) {
261     if (!TensorShapeUtils::IsMatrix(ix.shape())) {
262       return Status(error::INVALID_ARGUMENT,
263                     strings::StrCat("indices must be a matrix, but got: ",
264                                     ix.shape().DebugString()));
265     }
266     *result = UnsafeGetDimsFromIx(ix);
267     return Status();
268   }
269 
UnsafeGetDimsFromIx(const Tensor & ix)270   static int UnsafeGetDimsFromIx(const Tensor& ix) {
271     DCHECK(TensorShapeUtils::IsMatrix(ix.shape()));
272     return ix.dim_size(1);
273   }
274 
UndefinedOrder(const VarDimArray shape)275   static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
276     return ShapeArray(shape.size(), -1);
277   }
278 
TensorShapeToVector(const TensorShape & shape)279   static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
280     ShapeArray vec(shape.dims());
281     for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
282     return vec;
283   }
284 
285   // Helper for IndicesValid()
IndexValid(const TTypes<int64>::ConstMatrix & ix_t,int n)286   inline Status IndexValid(const TTypes<int64>::ConstMatrix& ix_t,
287                            int n) const {
288     bool valid = true;
289     bool different = false;
290     bool increasing = true;
291     if (n == 0) {
292       for (int di = 0; di < dims_; ++di) {
293         if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
294       }
295       different = true;
296     } else {
297       for (int di = 0; di < dims_; ++di) {
298         if (ix_t(n, di) < 0 || ix_t(n, di) >= shape_[di]) valid = false;
299         int64 diff = ix_t(n, order_[di]) - ix_t(n - 1, order_[di]);
300         if (diff > 0) different = true;
301         if (!different && diff < 0) increasing = false;
302       }
303     }
304     if (TF_PREDICT_FALSE(!valid || !increasing || !different)) {
305       string index = strings::StrCat("indices[", n, "] = [");
306       for (int di = 0; di < dims_; ++di) {
307         strings::StrAppend(&index, ix_t(n, di), di < dims_ - 1 ? "," : "]");
308       }
309       if (!valid) {
310         return errors::InvalidArgument(index,
311                                        " is out of bounds: need 0 <= index < [",
312                                        str_util::Join(shape_, ","), "]");
313       }
314       if (!increasing) {
315         return errors::InvalidArgument(index, " is out of order");
316       }
317       if (!different) {
318         return errors::InvalidArgument(index, " is repeated");
319       }
320     }
321     return Status::OK();
322   }
323 
324   // Helper for ToDense<T>()
325   template <typename T>
326   bool ValidateAndInitializeToDense(Tensor* out, bool initialize);
327 
328   // Helper for Split() that returns the slice index.
GetSliceIndex(const int dim,const int split_size,const int residual)329   static inline int GetSliceIndex(const int dim, const int split_size,
330                                   const int residual) {
331     DCHECK_GT(split_size, 0);
332     DCHECK_GE(dim, 0);
333     if (residual == 0) return dim / split_size;
334     const int offset = residual * (split_size + 1);
335     if (dim < offset) {
336       return dim / (split_size + 1);
337     } else {
338       return residual + ((dim - offset) / split_size);
339     }
340   }
341 
342   // Helper for Split() that returns the dimension in the slice.
GetDimensionInSlice(const int dim,const int split_size,const int residual)343   static inline int GetDimensionInSlice(const int dim, const int split_size,
344                                         const int residual) {
345     DCHECK_GT(split_size, 0);
346     DCHECK_GE(dim, 0);
347     if (residual == 0) return dim % split_size;
348     const int offset = residual * (split_size + 1);
349     if (dim < offset) {
350       return dim % (split_size + 1);
351     } else {
352       return (dim - offset) % split_size;
353     }
354   }
355 
356   // Helper for Split() that returns the shape given a slice index.
GetSliceShape(const int slice_index,const int split_size,const int residual)357   static inline int GetSliceShape(const int slice_index, const int split_size,
358                                   const int residual) {
359     DCHECK_GT(split_size, 0);
360     DCHECK_GE(slice_index, 0);
361     if (residual == 0) return split_size;
362     if (slice_index < residual) {
363       return split_size + 1;
364     } else {
365       return split_size;
366     }
367   }
368 
369   Tensor ix_;
370   Tensor vals_;
371   ShapeArray shape_;
372   ShapeArray order_;
373   int dims_;
374 };
375 
376 // This operation updates the indices and values Tensor rows, so it is
377 // an in-place algorithm.  It requires O(N log N) time and O(N)
378 // temporary space.
379 template <typename T>
Reorder(const VarDimArray & order)380 void SparseTensor::Reorder(const VarDimArray& order) {
381   DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
382       << "Reorder requested with the wrong datatype";
383   DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
384   auto ix_t = ix_.matrix<int64>();
385   auto vals_t = vals_.vec<T>();
386 
387   std::vector<int64> reorder(num_entries());
388   std::iota(reorder.begin(), reorder.end(), 0);
389 
390   // Sort to get order of indices
391   switch (order.size()) {
392 #define CASE_SORT(ORDER_SIZE)                                    \
393   case ORDER_SIZE: {                                             \
394     FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \
395     std::sort(reorder.begin(), reorder.end(), sorter);           \
396     break;                                                       \
397   }
398     CASE_SORT(0);
399     CASE_SORT(1);
400     CASE_SORT(2);
401     CASE_SORT(3);
402     CASE_SORT(4);
403     CASE_SORT(5);
404 #undef CASE_SORT
405     default: {
406       DimComparator sorter(ix_t, order, shape());
407       std::sort(reorder.begin(), reorder.end(), sorter);
408     }
409   }
410 
411   // We have a forward reordering, but what we'll need is a
412   // permutation (the inverse).  This can be calculated with O(1)
413   // additional
414   // and O(n) time (INVPERM) but we just do the simple thing here.
415   std::vector<size_t> permutation(reorder.size());
416   for (std::size_t n = 0; n < reorder.size(); ++n) {
417     permutation[reorder[n]] = n;
418   }
419 
420   // Update indices & values by converting the permutations to
421   // a product of transpositions.  Iterate over the cycles in the
422   // permutation, and convert each of those into a product of
423   // transpositions (swaps):
424   //   https://en.wikipedia.org/wiki/Cyclic_permutation
425   // This is N swaps, 2*N comparisons.
426   for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
427     while (n != permutation[n]) {
428       std::size_t r = permutation[n];
429       std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
430       std::swap(vals_t(n), vals_t(r));
431       std::swap(permutation[n], permutation[r]);
432     }
433   }
434 
435   order_ = ShapeArray(order.begin(), order.end());
436 }
437 
438 template <typename T>
ValidateAndInitializeToDense(Tensor * out,bool initialize)439 bool SparseTensor::ValidateAndInitializeToDense(Tensor* out, bool initialize) {
440   DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
441       << "ToDense requested with the wrong datatype";
442 
443   DCHECK_EQ(out->shape().dims(), dims_)
444       << "Incompatible dimensions between SparseTensor and output";
445 
446   DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
447       << "Output must be type: " << DataTypeToEnum<T>::v()
448       << " but got: " << out->dtype();
449 
450   // Make sure the dense output is the same rank and has room
451   // to hold the SparseTensor.
452   const auto& out_shape = out->shape();
453   if (shape_.size() != out_shape.dims()) return false;
454   for (int d = 0; d < shape_.size(); ++d) {
455     if (shape_[d] > out_shape.dim_size(d)) return false;
456   }
457 
458   if (initialize) {
459     auto out_t = out->flat<T>();
460     out_t.setConstant(T());
461   }
462 
463   return true;
464 }
465 
466 template <typename T>
ToDense(Tensor * out,bool initialize)467 bool SparseTensor::ToDense(Tensor* out, bool initialize) {
468   if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;
469 
470   auto out_t = out->flat<T>();
471   auto ix_t = ix_.matrix<int64>();
472   auto vals_t = vals_.vec<T>();
473 
474   std::vector<int64> strides(dims_);
475   const auto& out_shape = out->shape();
476   if (dims_ > 0) {
477     strides[dims_ - 1] = 1;
478   }
479   for (int d = dims_ - 2; d >= 0; --d) {
480     strides[d] = strides[d + 1] * out_shape.dim_size(d + 1);
481   }
482 
483   for (int n = 0; n < vals_t.dimension(0); ++n) {
484     bool invalid_dims = false;
485     int64 ix = 0;
486     for (int d = 0; d < dims_; ++d) {
487       const int64 ix_n_d = internal::SubtleMustCopy(ix_t(n, d));
488       if (!FastBoundsCheck(ix_n_d, out_shape.dim_size(d))) {
489         invalid_dims = true;
490       }
491       ix += strides[d] * ix_n_d;
492     }
493     if (invalid_dims) return false;
494     out_t(ix) = vals_t(n);
495   }
496   return true;
497 }
498 
499 template <typename T>
Concat(const gtl::ArraySlice<SparseTensor> & tensors)500 SparseTensor SparseTensor::Concat(
501     const gtl::ArraySlice<SparseTensor>& tensors) {
502   DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
503   const int dims = tensors[0].dims_;
504   DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
505   auto order_0 = tensors[0].order();
506   const int primary_dim = order_0[0];
507   ShapeArray final_order(order_0.begin(), order_0.end());
508   ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
509   final_shape[primary_dim] = 0;  // We'll build this up as we go along.
510   int num_entries = 0;
511 
512   bool fully_ordered = true;
513   for (const SparseTensor& st : tensors) {
514     DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
515     DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
516         << "Concat requested with the wrong data type";
517     DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
518     DCHECK_EQ(st.order()[0], primary_dim)
519         << "All SparseTensors' order[0] must match.  This is the concat dim.";
520     if (st.order() != final_order) fully_ordered = false;
521     const VarDimArray& st_shape = st.shape();
522     for (int d = 0; d < dims - 1; ++d) {
523       const int cdim = (d < primary_dim) ? d : d + 1;
524       DCHECK_EQ(final_shape[cdim], st_shape[cdim])
525           << "All SparseTensors' shapes must match except on the concat dim.  "
526           << "Concat dim: " << primary_dim
527           << ", mismatched shape at dim: " << cdim
528           << ".  Expecting shape like: [" << str_util::Join(final_shape, ",")
529           << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
530     }
531 
532     // Update dimension of final shape
533     final_shape[primary_dim] =
534         (final_shape[primary_dim] + st_shape[primary_dim]);
535 
536     num_entries += st.num_entries();  // Update number of entries
537   }
538 
539   // If nonconsistent ordering among inputs, set final order to -1s.
540   if (!fully_ordered) {
541     final_order = UndefinedOrder(final_shape);
542   }
543 
544   Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
545   Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));
546 
547   TTypes<int64>::Matrix ix_t = output_ix.matrix<int64>();
548   typename TTypes<T>::Vec vals_t = output_vals.vec<T>();
549 
550   Eigen::DenseIndex offset = 0;
551   int64 shape_offset = 0;
552   for (const SparseTensor& st : tensors) {
553     const int st_num_entries = st.num_entries();
554 
555     // Fill in indices & values.
556     std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));
557 
558     const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
559     auto* ix_out = &ix_t(offset, 0);
560     for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
561       *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
562     }
563 
564     offset += st_num_entries;
565     shape_offset += st.shape()[primary_dim];
566   }
567 
568   return SparseTensor(output_ix, output_vals, final_shape, final_order);
569 }
570 
571 template <typename T>
Split(const SparseTensor & input_tensor,const int split_dim,const int num_split,std::vector<SparseTensor> * result)572 Status SparseTensor::Split(const SparseTensor& input_tensor,
573                            const int split_dim, const int num_split,
574                            std::vector<SparseTensor>* result) {
575   std::vector<Tensor> output_indices;
576   std::vector<Tensor> output_values;
577   std::vector<TensorShape> output_shapes;
578   output_indices.reserve(num_split);
579   output_values.reserve(num_split);
580   output_shapes.reserve(num_split);
581 
582   std::vector<typename TTypes<int64>::Matrix> output_indices_t;
583   std::vector<typename TTypes<T>::Vec> output_values_t;
584   output_indices_t.reserve(num_split);
585   output_values_t.reserve(num_split);
586   auto input_values_t = input_tensor.values().vec<T>();
587   auto input_indices_t = input_tensor.indices().matrix<int64>();
588 
589   std::vector<int> num_values(num_split, 0);
590   const int num_dim = input_tensor.shape().size();
591   const int split_dim_size = input_tensor.shape()[split_dim];
592   const int split_size = split_dim_size / num_split;
593 
594   if (!(num_split > 0 && num_split <= split_dim_size)) {
595     return Status(error::INVALID_ARGUMENT,
596                   strings::StrCat("num_split must be in the interval (0, ",
597                                   split_dim_size, "]"));
598   }
599   if (!(split_dim >= 0 && split_dim < num_dim)) {
600     return Status(
601         error::INVALID_ARGUMENT,
602         strings::StrCat("num_dim must be in the interval [0, ", num_dim, ")"));
603   }
604 
605   const int residual = split_dim_size % num_split;
606   for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
607     const int dim = input_tensor.indices().matrix<int64>()(i, split_dim);
608     int slice_index = GetSliceIndex(dim, split_size, residual);
609     num_values[slice_index]++;
610   }
611 
612   for (int i = 0; i < num_split; ++i) {
613     // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
614     output_indices.emplace_back(DT_INT64,
615                                 TensorShape({num_values[i], num_dim}));
616     output_values.emplace_back(DataTypeToEnum<T>::v(),
617                                TensorShape({num_values[i]}));
618     output_shapes.emplace_back(input_tensor.shape());
619     output_indices_t.emplace_back(output_indices[i].matrix<int64>());
620     output_values_t.emplace_back(output_values[i].vec<T>());
621     const int size = GetSliceShape(i, split_size, residual);
622     output_shapes[i].set_dim(split_dim, size);
623   }
624 
625   std::vector<int> values_inserted_in_slice(num_split, 0);
626   for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
627     const int dim = input_indices_t(i, split_dim);
628     const int slice_index = GetSliceIndex(dim, split_size, residual);
629     const int slice_dim = values_inserted_in_slice[slice_index]++;
630     output_values_t[slice_index](slice_dim) = input_values_t(i);
631     for (int j = 0; j < num_dim; ++j) {
632       const int64 original_dim = input_indices_t(i, j);
633       output_indices_t[slice_index](slice_dim, j) =
634           (j == split_dim)
635               ? GetDimensionInSlice(original_dim, split_size, residual)
636               : original_dim;
637     }
638   }
639 
640   result->clear();
641   result->reserve(num_split);
642   for (int i = 0; i < num_split; ++i) {
643     SparseTensor tensor;
644     Status create_status =
645         Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
646     if (!create_status.ok()) {
647       return create_status;
648     }
649     result->push_back(std::move(tensor));
650   }
651   return Status::OK();
652 }
653 
654 template <typename T>
Slice(const SparseTensor & input_tensor,const gtl::ArraySlice<int64> & start,const gtl::ArraySlice<int64> & size)655 SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
656                                  const gtl::ArraySlice<int64>& start,
657                                  const gtl::ArraySlice<int64>& size) {
658   TensorShape output_shape(input_tensor.shape());
659 
660   const int dims = input_tensor.dims();
661   for (int dim = 0; dim < dims; dim++) {
662     int64 dim_size = start[dim] + size[dim] < output_shape.dim_size(dim)
663                          ? size[dim]
664                          : output_shape.dim_size(dim) - start[dim];
665     output_shape.set_dim(dim, dim_size);
666   }
667 
668   auto input_indices_t = input_tensor.indices().matrix<int64>();
669   auto input_values_t = input_tensor.values().vec<T>();
670 
671   // Find the number of indices that fall inside start and size.
672   int count = 0;
673   for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
674     // The following will check to see if an input is within the
675     // range specified by start and size.
676     // The for loop below iterates through all dimensions. In case
677     // the index falls outside of the start and size at any dimension,
678     // it will be considered as a "no hit" (hit = false). In this
679     // case, it will not be counted as the index that fall inside
680     // the range specified by start and size.
681     bool hit = true;
682     for (int dim = 0; dim < dims; dim++) {
683       if (!(start[dim] <= input_indices_t(i, dim) &&
684             input_indices_t(i, dim) < start[dim] + size[dim])) {
685         hit = false;
686         break;
687       }
688     }
689     if (!hit) {
690       continue;
691     }
692     count++;
693   }
694 
695   Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
696   Tensor output_indices(DT_INT64, TensorShape({count, dims}));
697 
698   auto output_values_t = output_values.vec<T>();
699   auto output_indices_t = output_indices.matrix<int64>();
700 
701   // Obtain the output indices that fall inside start and size.
702   int index = 0;
703   for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
704        i++) {
705     // The logic here is similar as the above except that the above
706     // only count the number of indices while here we actually generate
707     // the output.
708     bool hit = true;
709     for (int dim = 0; dim < dims; dim++) {
710       if (!(start[dim] <= input_indices_t(i, dim) &&
711             input_indices_t(i, dim) < start[dim] + size[dim])) {
712         hit = false;
713         break;
714       }
715     }
716     if (!hit) {
717       continue;
718     }
719     output_values_t(index) = input_values_t(i);
720     for (int dim = 0; dim < dims; dim++) {
721       output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
722     }
723     index++;
724   }
725 
726   return SparseTensor(output_indices, output_values, output_shape);
727 }
728 
729 }  // namespace sparse
730 }  // namespace tensorflow
731 
732 #endif  // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
733