/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_

#include <limits>
#include <numeric>
#include <vector>

#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
#include "tensorflow/core/util/sparse/group_iterator.h"

namespace tensorflow {
namespace sparse {

class SparseTensor {
 public:
  typedef typename gtl::ArraySlice<int64> VarDimArray;
  typedef typename gtl::InlinedVector<int64, 8> ShapeArray;

  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       const VarDimArray order, SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       const VarDimArray order, SparseTensor* result);
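  //
  // A rough usage sketch (the `indices` and `values` Tensors and the shape
  // below are illustrative placeholders, not defined in this header):
  //
  //   SparseTensor st;
  //   TF_RETURN_IF_ERROR(
  //       SparseTensor::Create(indices, values, TensorShape({10, 10}), &st));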

  SparseTensor() : dims_(0) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
      : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                     UndefinedOrder(TensorShapeToVector(shape))) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
      : SparseTensor(std::move(ix), std::move(vals), shape,
                     UndefinedOrder(shape)) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
               const VarDimArray order)
      : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                     order) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
               const VarDimArray order);

  SparseTensor(const SparseTensor& other)
      : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}

  SparseTensor(SparseTensor&& other)
      : SparseTensor(std::move(other.ix_), std::move(other.vals_),
                     std::move(other.shape_), std::move(other.order_)) {}

  SparseTensor& operator=(const SparseTensor& other) {
    ix_ = other.ix_;
    vals_ = other.vals_;
    shape_ = other.shape_;
    order_ = other.order_;
    dims_ = other.dims_;
    return *this;
  }

  SparseTensor& operator=(SparseTensor&& other) {
    ix_ = std::move(other.ix_);
    vals_ = std::move(other.vals_);
    shape_ = std::move(other.shape_);
    order_ = std::move(other.order_);
    dims_ = std::move(other.dims_);
    return *this;
  }

  std::size_t num_entries() const { return ix_.dim_size(0); }

  int dims() const { return shape_.size(); }

  const Tensor& indices() const { return ix_; }

  const Tensor& values() const { return vals_; }

  DataType dtype() const { return vals_.dtype(); }

  Status IndicesValid() const;

  VarDimArray shape() const { return shape_; }

  VarDimArray order() const { return order_; }

  // Resorts the indices and values according to the dimensions in order.
  template <typename T>
  void Reorder(const VarDimArray& order);
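  //
  // For example, a minimal sketch of sorting a rank-2 float SparseTensor `st`
  // into canonical row-major order (`st` and the dtype are assumptions for
  // illustration):
  //
  //   st.Reorder<float>({0, 1});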

  // Returns a group iterable that can be used for clumping indices
  // and values according to the group indices of interest.
  //
  // Precondition: order()[0..group_ix.size()] == group_ix.
  //
  // See the README.md in this directory for more usage information.
  GroupIterable group(const VarDimArray& group_ix) const {
    DCHECK_LE(group_ix.size(), dims_);
    for (std::size_t di = 0; di < group_ix.size(); ++di) {
      DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
      DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
      DCHECK_EQ(group_ix[di], order_[di])
          << "Group dimension does not match sorted order";
    }
    return GroupIterable(ix_, vals_, dims_, group_ix);
  }
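  // A rough sketch of grouping by the first dimension after a row-major
  // Reorder (the int32 dtype and the `st` name are assumptions for
  // illustration; see group_iterator.h and the README.md for the Group API):
  //
  //   st.Reorder<int32>({0, 1});
  //   for (const auto& g : st.group({0})) {
  //     // g.group()[0] is the shared value of index column 0 for this group;
  //     // g.indices() and g.values<int32>() view the group's rows.
  //   }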

  // Stores the sparse indices into the dense tensor out.
  // Preconditions:
  //   out->shape().dims() == dims()
  //   out->shape().dim_size(d) >= shape()[d] for all d
  //
  // Returns true on success.  False on failure (mismatched dimensions
  // or out-of-bounds indices).
  //
  // If initialize is true, ToDense first overwrites all coefficients in out
  // with 0.
  //
  template <typename T>
  bool ToDense(Tensor* out, bool initialize = true);
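  //
  // A minimal sketch, assuming `st` is a rank-2 float SparseTensor with shape
  // {3, 4} (the name and sizes are illustrative):
  //
  //   Tensor dense(DT_FLOAT, TensorShape({3, 4}));
  //   if (!st.ToDense<float>(&dense)) { /* handle out-of-bounds indices */ }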

  // Concat() concatenates the input tensors along their first-order
  // dimension.  All tensors must have identical shapes except for the
  // first-order dimension, and all tensors' order()[0] must match.
  //
  // If all of the tensors have identical ordering, then the output
  // will have this ordering.  Otherwise the output is set as not
  // having any order and a Reorder<T>() should be called on it before
  // performing any subsequent operations.
  template <typename T>
  static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);
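  //
  // A rough sketch, assuming `st1` and `st2` are float SparseTensors already
  // ordered with order()[0] == 0 (the concat dimension); names are
  // illustrative:
  //
  //   SparseTensor concatenated = SparseTensor::Concat<float>({st1, st2});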

  // Split() splits the input SparseTensor into a list of num_split
  // SparseTensors along the given split dimension. If the size of that
  // dimension is not an integer multiple of num_split, the first
  // (size % num_split) slices each get one extra element.
  template <typename T>
  static Status Split(const SparseTensor& tensor, const int split_dim,
                      const int num_split, std::vector<SparseTensor>* result);
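  //
  // A rough sketch, splitting a float SparseTensor `st` into two pieces along
  // dimension 0 (the name and dtype are illustrative):
  //
  //   std::vector<SparseTensor> pieces;
  //   TF_RETURN_IF_ERROR(SparseTensor::Split<float>(st, /*split_dim=*/0,
  //                                                 /*num_split=*/2, &pieces));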

  // Slice() slices the input SparseTensor into a SparseTensor based on the
  // specified start and size. Both start and size are 1-D arrays with one
  // element per dimension: start gives the start index and size gives the
  // extent of the slice in each dimension.
  template <typename T>
  static SparseTensor Slice(const SparseTensor& tensor,
                            const gtl::ArraySlice<int64>& start,
                            const gtl::ArraySlice<int64>& size);
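  //
  // A rough sketch, taking a 2x3 window starting at row 1, column 0 of a
  // rank-2 float SparseTensor `st` (names and sizes are illustrative):
  //
  //   SparseTensor window = SparseTensor::Slice<float>(st, {1, 0}, {2, 3});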

  // Picks out the dimensions according to `dim_indices`.
  std::vector<int64> PickDims(gtl::ArraySlice<int64> dim_indices) const {
    std::vector<int64> res(dim_indices.size());
    for (size_t i = 0; i < dim_indices.size(); ++i) {
      res[i] = shape_[dim_indices[i]];
    }
    return res;
  }

 private:
  static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
    return ShapeArray(shape.size(), -1);
  }

  static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
    ShapeArray vec(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
    return vec;
  }

  // Optimized implementation of `IndicesValid` for 1-D sparse tensors.
  // REQUIRES: `shape_.size() == 1`.
  bool IndicesValidVectorFastPath() const;

  // Optimized implementation of `IndicesValid` for 2-D sparse tensors whose
  // indices fit within the range of an `int32`.
  // REQUIRES: `shape_.size() == 2`.
  bool IndicesValidMatrix32BitFastPath() const;

  template <bool standard_order>
  Status IndicesValidHelper() const;

  // Helper for ToDense<T>()
  template <typename T>
  bool ValidateAndInitializeToDense(Tensor* out, bool initialize);

  // Helper for Split() that returns the slice index.
  static inline int GetSliceIndex(const int dim, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim / split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim / (split_size + 1);
    } else {
      return residual + ((dim - offset) / split_size);
    }
  }
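  // Worked example for these Split() helpers (illustrative numbers): splitting
  // a dimension of size 7 into num_split = 3 pieces gives split_size = 2 and
  // residual = 1, so the slices cover indices {0, 1, 2}, {3, 4} and {5, 6}.
  // Then GetSliceIndex(4, 2, 1) == 1 and GetDimensionInSlice(4, 2, 1) == 1,
  // i.e. index 4 lands in slice 1 at position 1, and GetSliceShape(0, 2, 1)
  // == 3.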

  // Helper for Split() that returns the dimension in the slice.
  static inline int GetDimensionInSlice(const int dim, const int split_size,
                                        const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim % split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim % (split_size + 1);
    } else {
      return (dim - offset) % split_size;
    }
  }

  // Helper for Split() that returns the shape given a slice index.
  static inline int GetSliceShape(const int slice_index, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(slice_index, 0);
    if (residual == 0) return split_size;
    if (slice_index < residual) {
      return split_size + 1;
    } else {
      return split_size;
    }
  }

  Tensor ix_;
  Tensor vals_;
  ShapeArray shape_;
  ShapeArray order_;
  int dims_;
};

// This operation updates the indices and values Tensor rows, so it is
// an in-place algorithm.  It requires O(N log N) time and O(N)
// temporary space.
template <typename T>
inline void SparseTensor::Reorder(const VarDimArray& order) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "Reorder requested with the wrong datatype";
  DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64> reorder(num_entries());
  std::iota(reorder.begin(), reorder.end(), 0);

  // Sort to get order of indices
  switch (order.size()) {
#define CASE_SORT(ORDER_SIZE)                                    \
  case ORDER_SIZE: {                                             \
    FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \
    std::sort(reorder.begin(), reorder.end(), sorter);           \
    break;                                                       \
  }
    CASE_SORT(0);
    CASE_SORT(1);
    CASE_SORT(2);
    CASE_SORT(3);
    CASE_SORT(4);
    CASE_SORT(5);
#undef CASE_SORT
    default: {
      DimComparator sorter(ix_t, order, shape());
      std::sort(reorder.begin(), reorder.end(), sorter);
    }
  }

  // We have a forward reordering, but what we'll need is a
  // permutation (the inverse).  This can be calculated with O(1)
  // additional space and O(n) time (INVPERM), but we just do the
  // simple thing here.
  std::vector<size_t> permutation(reorder.size());
  for (std::size_t n = 0; n < reorder.size(); ++n) {
    permutation[reorder[n]] = n;
  }

  // Update indices & values by converting the permutations to
  // a product of transpositions.  Iterate over the cycles in the
  // permutation, and convert each of those into a product of
  // transpositions (swaps):
  //   https://en.wikipedia.org/wiki/Cyclic_permutation
  // This is N swaps, 2*N comparisons.
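  // (For instance, the permutation {1, 2, 0} is the single 3-cycle
  // 0 -> 1 -> 2 -> 0 and is resolved by the loop below with two swaps.)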
  for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
    while (n != permutation[n]) {
      std::size_t r = permutation[n];
      std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
      std::swap(vals_t(n), vals_t(r));
      std::swap(permutation[n], permutation[r]);
    }
  }

  order_ = ShapeArray(order.begin(), order.end());
}

template <typename T>
inline bool SparseTensor::ValidateAndInitializeToDense(Tensor* out,
                                                       bool initialize) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "ToDense requested with the wrong datatype";

  DCHECK_EQ(out->shape().dims(), dims_)
      << "Incompatible dimensions between SparseTensor and output";

  DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
      << "Output must be type: " << DataTypeToEnum<T>::v()
      << " but got: " << out->dtype();

  // Make sure the dense output is the same rank and has room
  // to hold the SparseTensor.
  const auto& out_shape = out->shape();
  if (shape_.size() != out_shape.dims()) return false;
  for (int d = 0; d < shape_.size(); ++d) {
    if (shape_[d] > out_shape.dim_size(d)) return false;
  }

  if (initialize) {
    auto out_t = out->flat<T>();
    out_t.setConstant(T());
  }

  return true;
}

template <typename T>
inline bool SparseTensor::ToDense(Tensor* out, bool initialize) {
  if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;

  auto out_t = out->flat<T>();
  auto vals_t = vals_.vec<T>();
  auto ix_t = ix_.matrix<int64>();
  const int64* const ix_ptr = ix_t.data();

  if (dims_ == 1) {
    // Fast path for sparse vectors.
    const int64 out_length = out->shape().dim_size(0);
    for (int n = 0; n < vals_t.dimension(0); ++n) {
      const int64 index = internal::SubtleMustCopy(ix_ptr[n]);
      if (!FastBoundsCheck(index, out_length)) return false;
      out_t(index) = vals_t(n);
    }
    return true;
  } else if (dims_ == 2) {
    // Fast path for sparse matrices.
    const auto& out_shape = out->shape();
    const int64 out_rows = out_shape.dim_size(0);
    const int64 out_cols = out_shape.dim_size(1);
    for (int n = 0; n < vals_t.dimension(0); ++n) {
      const int64 row_index = internal::SubtleMustCopy(ix_ptr[n * 2]);
      const int64 col_index = internal::SubtleMustCopy(ix_ptr[n * 2 + 1]);
      if (!(FastBoundsCheck(row_index, out_rows) &&
            FastBoundsCheck(col_index, out_cols))) {
        return false;
      }
      out_t(row_index * out_cols + col_index) = vals_t(n);
    }
    return true;
  } else {
    // General path for N-dimensional sparse tensors.
    gtl::InlinedVector<int64, 4> strides(dims_);
    const auto& out_shape = out->shape().dim_sizes();
    if (dims_ > 0) {
      strides[dims_ - 1] = 1;
    }
    for (int d = dims_ - 2; d >= 0; --d) {
      strides[d] = strides[d + 1] * out_shape[d + 1];
    }

    for (int n = 0; n < vals_t.dimension(0); ++n) {
      bool invalid_dims = false;
      int64 ix = 0;
      for (int d = 0; d < dims_; ++d) {
        const int64 ix_n_d = internal::SubtleMustCopy(ix_ptr[n * dims_ + d]);
        if (!FastBoundsCheck(ix_n_d, out_shape[d])) {
          invalid_dims = true;
        }
        ix += strides[d] * ix_n_d;
      }
      if (invalid_dims) return false;
      out_t(ix) = vals_t(n);
    }
    return true;
  }
}

template <typename T>
inline SparseTensor SparseTensor::Concat(
    const gtl::ArraySlice<SparseTensor>& tensors) {
  DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
  const int dims = tensors[0].dims_;
  DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
  auto order_0 = tensors[0].order();
  const int primary_dim = order_0[0];
  ShapeArray final_order(order_0.begin(), order_0.end());
  ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
  final_shape[primary_dim] = 0;  // We'll build this up as we go along.
  int num_entries = 0;

  bool fully_ordered = true;
  for (const SparseTensor& st : tensors) {
    DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
    DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
        << "Concat requested with the wrong data type";
    DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
    DCHECK_EQ(st.order()[0], primary_dim)
        << "All SparseTensors' order[0] must match.  This is the concat dim.";
    if (st.order() != final_order) fully_ordered = false;
    const VarDimArray& st_shape = st.shape();
    for (int d = 0; d < dims - 1; ++d) {
      const int cdim = (d < primary_dim) ? d : d + 1;
      DCHECK_EQ(final_shape[cdim], st_shape[cdim])
          << "All SparseTensors' shapes must match except on the concat dim.  "
          << "Concat dim: " << primary_dim
          << ", mismatched shape at dim: " << cdim
          << ".  Expecting shape like: [" << str_util::Join(final_shape, ",")
          << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
    }

    // Update dimension of final shape
    final_shape[primary_dim] =
        (final_shape[primary_dim] + st_shape[primary_dim]);

    num_entries += st.num_entries();  // Update number of entries
  }

  // If the ordering is inconsistent among the inputs, set the final order to
  // -1s.
  if (!fully_ordered) {
    final_order = UndefinedOrder(final_shape);
  }

  Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
  Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));

  TTypes<int64>::Matrix ix_t = output_ix.matrix<int64>();
  typename TTypes<T>::Vec vals_t = output_vals.vec<T>();

  Eigen::DenseIndex offset = 0;
  int64 shape_offset = 0;
  for (const SparseTensor& st : tensors) {
    const int st_num_entries = st.num_entries();

    // Fill in indices & values.
    if (st_num_entries > 0) {
      std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));

      const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
      auto* ix_out = &ix_t(offset, 0);
      for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
        *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
      }
    }

    offset += st_num_entries;
    shape_offset += st.shape()[primary_dim];
  }

  return SparseTensor(output_ix, output_vals, final_shape, final_order);
}

template <typename T>
inline Status SparseTensor::Split(const SparseTensor& input_tensor,
                                  const int split_dim, const int num_split,
                                  std::vector<SparseTensor>* result) {
  std::vector<Tensor> output_indices;
  std::vector<Tensor> output_values;
  std::vector<TensorShape> output_shapes;
  output_indices.reserve(num_split);
  output_values.reserve(num_split);
  output_shapes.reserve(num_split);

  std::vector<typename TTypes<int64>::Matrix> output_indices_t;
  std::vector<typename TTypes<T>::Vec> output_values_t;
  output_indices_t.reserve(num_split);
  output_values_t.reserve(num_split);
  auto input_values_t = input_tensor.values().vec<T>();
  auto input_indices_t = input_tensor.indices().matrix<int64>();

  std::vector<int> num_values(num_split, 0);
  const int num_dim = input_tensor.shape().size();
  const int split_dim_size = input_tensor.shape()[split_dim];
  const int split_size = split_dim_size / num_split;

  if (!(num_split > 0 && num_split <= split_dim_size)) {
    return errors::InvalidArgument("num_split must be in the interval (0, ",
                                   split_dim_size, "]");
  }
  if (!(split_dim >= 0 && split_dim < num_dim)) {
    return errors::InvalidArgument("split_dim must be in the interval [0, ",
                                   num_dim, ")");
  }

  const int residual = split_dim_size % num_split;
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_tensor.indices().matrix<int64>()(i, split_dim);
    int slice_index = GetSliceIndex(dim, split_size, residual);
    num_values[slice_index]++;
  }

  for (int i = 0; i < num_split; ++i) {
    // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
    output_indices.emplace_back(DT_INT64,
                                TensorShape({num_values[i], num_dim}));
    output_values.emplace_back(DataTypeToEnum<T>::v(),
                               TensorShape({num_values[i]}));
    output_shapes.emplace_back(input_tensor.shape());
    output_indices_t.emplace_back(output_indices[i].matrix<int64>());
    output_values_t.emplace_back(output_values[i].vec<T>());
    const int size = GetSliceShape(i, split_size, residual);
    output_shapes[i].set_dim(split_dim, size);
  }

  std::vector<int> values_inserted_in_slice(num_split, 0);
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_indices_t(i, split_dim);
    const int slice_index = GetSliceIndex(dim, split_size, residual);
    const int slice_dim = values_inserted_in_slice[slice_index]++;
    output_values_t[slice_index](slice_dim) = input_values_t(i);
    for (int j = 0; j < num_dim; ++j) {
      const int64 original_dim = input_indices_t(i, j);
      output_indices_t[slice_index](slice_dim, j) =
          (j == split_dim)
              ? GetDimensionInSlice(original_dim, split_size, residual)
              : original_dim;
    }
  }

  result->clear();
  result->reserve(num_split);
  for (int i = 0; i < num_split; ++i) {
    SparseTensor tensor;
    Status create_status =
        Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
    if (!create_status.ok()) {
      return create_status;
    }
    result->push_back(std::move(tensor));
  }
  return Status::OK();
}

template <typename T>
inline SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
                                        const gtl::ArraySlice<int64>& start,
                                        const gtl::ArraySlice<int64>& size) {
  TensorShape output_shape(input_tensor.shape());

  const int dims = input_tensor.dims();
  for (int dim = 0; dim < dims; dim++) {
    // Determine the size of the result; if the selected slice goes beyond the
    // input boundary, the result will correspond to the size of the overlap
    // between the input and the selected slice.
    const int64 input_size = output_shape.dim_size(dim);
    const int64 start_index = start[dim];
    const int64 slice_size = size[dim];
    if (start_index + slice_size < input_size) {
      // The entire selection is within input boundaries.
      output_shape.set_dim(dim, slice_size);
    } else if (start_index < input_size) {
      // The selection starts within input boundaries, but goes beyond them.
      output_shape.set_dim(dim, input_size - start_index);
    } else {
      // The selection is entirely out of input boundaries.
      output_shape.set_dim(dim, 0);
    }
  }

  auto input_indices_t = input_tensor.indices().matrix<int64>();
  auto input_values_t = input_tensor.values().vec<T>();

  // Find the number of indices that fall inside start and size.
  int count = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
    // Check whether index i falls within the range specified by start and
    // size.  The loop below iterates over all dimensions; if the index is
    // outside the range in any dimension it is treated as a miss
    // (hit = false) and is not counted.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    count++;
  }

  Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
  Tensor output_indices(DT_INT64, TensorShape({count, dims}));

  auto output_values_t = output_values.vec<T>();
  auto output_indices_t = output_indices.matrix<int64>();

  // Obtain the output indices that fall inside start and size.
  int index = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
       i++) {
    // The logic here mirrors the counting loop above, except that here we
    // actually write the output indices and values.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    output_values_t(index) = input_values_t(i);
    for (int dim = 0; dim < dims; dim++) {
      output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
    }
    index++;
  }

  return SparseTensor(output_indices, output_values, output_shape);
}

}  // namespace sparse
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_