/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_
#define TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_

#include <limits>
#include <numeric>
#include <vector>

#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/dim_comparator.h"
#include "tensorflow/core/util/sparse/group_iterator.h"

namespace tensorflow {
namespace sparse {

class SparseTensor {
 public:
  typedef typename gtl::ArraySlice<int64> VarDimArray;
  typedef typename gtl::InlinedVector<int64, 8> ShapeArray;

  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       const VarDimArray order, SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const VarDimArray shape,
                       SparseTensor* result);

  static Status Create(Tensor ix, Tensor vals, const TensorShape& shape,
                       const VarDimArray order, SparseTensor* result);
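  //
  // Example (illustrative sketch): creating a 2-D float SparseTensor with
  // `nnz` nonzero values. The names `indices`, `values`, `nnz`, and `st` are
  // placeholders chosen for this example, not part of the API.
  //
  //   Tensor indices(DT_INT64, TensorShape({nnz, 2}));
  //   Tensor values(DT_FLOAT, TensorShape({nnz}));
  //   // ... fill `indices` and `values` ...
  //   SparseTensor st;
  //   TF_RETURN_IF_ERROR(
  //       SparseTensor::Create(indices, values, TensorShape({3, 4}), &st));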

  SparseTensor() : dims_(0) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
      : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                     UndefinedOrder(TensorShapeToVector(shape))) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
      : SparseTensor(std::move(ix), std::move(vals), shape,
                     UndefinedOrder(shape)) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
               const VarDimArray order)
      : SparseTensor(std::move(ix), std::move(vals), TensorShapeToVector(shape),
                     order) {}

  ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
  SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
               const VarDimArray order);

  SparseTensor(const SparseTensor& other)
      : SparseTensor(other.ix_, other.vals_, other.shape_, other.order_) {}

  SparseTensor(SparseTensor&& other)
      : SparseTensor(std::move(other.ix_), std::move(other.vals_),
                     std::move(other.shape_), std::move(other.order_)) {}

  SparseTensor& operator=(const SparseTensor& other) {
    ix_ = other.ix_;
    vals_ = other.vals_;
    shape_ = other.shape_;
    order_ = other.order_;
    dims_ = other.dims_;
    return *this;
  }

  SparseTensor& operator=(SparseTensor&& other) {
    ix_ = std::move(other.ix_);
    vals_ = std::move(other.vals_);
    shape_ = std::move(other.shape_);
    order_ = std::move(other.order_);
    dims_ = std::move(other.dims_);
    return *this;
  }

  std::size_t num_entries() const { return ix_.dim_size(0); }

  int dims() const { return shape_.size(); }

  const Tensor& indices() const { return ix_; }

  const Tensor& values() const { return vals_; }

  DataType dtype() const { return vals_.dtype(); }

  Status IndicesValid() const;

  VarDimArray shape() const { return shape_; }

  VarDimArray order() const { return order_; }

  // Re-sorts the indices and values according to the dimensions in `order`.
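  //
  // Example (illustrative sketch): re-sorting a rank-2 float SparseTensor
  // into row-major order. `st` is a placeholder name for this example.
  //
  //   st.Reorder<float>({0, 1});  // sort by dim 0, then by dim 1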
  template <typename T>
  void Reorder(const VarDimArray& order);

  // Returns a group iterable that can be used for clumping indices
  // and values according to the group indices of interest.
  //
  // Precondition: order()[0..group_ix.size()] == group_ix.
  //
  // See the README.md in this directory for more usage information.
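  //
  // Example (illustrative sketch): iterating over the groups of a rank-3
  // float SparseTensor that has already been Reorder<float>()'d to order
  // {0, 1, 2}. `st` and `g` are placeholder names for this example.
  //
  //   for (const auto& g : st.group({0, 1})) {
  //     auto group_indices = g.group();   // the shared {dim0, dim1} prefix
  //     auto group_values = g.values<float>();  // values within this group
  //     // ... process `group_values` ...
  //   }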
  GroupIterable group(const VarDimArray& group_ix) const {
    DCHECK_LE(group_ix.size(), dims_);
    for (std::size_t di = 0; di < group_ix.size(); ++di) {
      DCHECK_GE(group_ix[di], 0) << "Group dimension out of range";
      DCHECK_LT(group_ix[di], dims_) << "Group dimension out of range";
      DCHECK_EQ(group_ix[di], order_[di])
          << "Group dimension does not match sorted order";
    }
    return GroupIterable(ix_, vals_, dims_, group_ix);
  }

  // Stores the sparse indices into the dense tensor out.
  // Preconditions:
  //   out->shape().dims() == shape().dims()
  //   out->shape().dim_size(d) >= shape()[d] for all d
  //
  // Returns true on success. False on failure (mismatched dimensions
  // or out-of-bounds indices).
  //
  // If initialize is true, ToDense first sets all coefficients in out to 0.
  //
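  // Example (illustrative sketch): densifying a float SparseTensor of shape
  // {3, 4}. `st` and `dense` are placeholder names for this example.
  //
  //   Tensor dense(DT_FLOAT, TensorShape({3, 4}));
  //   if (!st.ToDense<float>(&dense)) {  // initialize defaults to true
  //     // Mismatched dimensions or out-of-bounds indices.
  //   }
  //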
  template <typename T>
  bool ToDense(Tensor* out, bool initialize = true);

  // Concat() concatenates all the tensors along their primary (first-order)
  // dimension. All tensors must have identical shapes except along that
  // dimension, and the first entry of every tensor's order must match: it is
  // the concatenation dimension.
  //
  // If all of the tensors have identical ordering, then the output
  // will have this ordering. Otherwise the output is set as not
  // having any order, and Reorder<T>() should be called on it before
  // performing any subsequent operations.
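  //
  // Example (illustrative sketch): concatenating two rank-2 float
  // SparseTensors whose order[0] is the same. The names `a`, `b`, `parts`,
  // and `concat` are placeholders for this example.
  //
  //   std::vector<SparseTensor> parts = {a, b};
  //   SparseTensor concat = SparseTensor::Concat<float>(parts);
  //   // If a and b were not identically ordered, re-establish an order:
  //   // concat.Reorder<float>({0, 1});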
  template <typename T>
  static SparseTensor Concat(const gtl::ArraySlice<SparseTensor>& tensors);

  // Split() will split the input SparseTensor into a list of num_split
  // SparseTensors along the given split_dim. If the size of that dimension
  // is not an integer multiple of num_split, the first (size % num_split)
  // slices are each one element larger along split_dim.
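  //
  // Example (illustrative sketch): splitting a float SparseTensor of shape
  // {10, 4} into 3 pieces along dimension 0 (piece sizes 4, 3, 3). `st` and
  // `pieces` are placeholder names for this example.
  //
  //   std::vector<SparseTensor> pieces;
  //   TF_RETURN_IF_ERROR(SparseTensor::Split<float>(
  //       st, /*split_dim=*/0, /*num_split=*/3, &pieces));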
  template <typename T>
  static Status Split(const SparseTensor& tensor, const int split_dim,
                      const int num_split, std::vector<SparseTensor>* result);

  // Slice() will slice the input SparseTensor into a SparseTensor based on
  // the specified start and size. Both start and size are 1-D arrays with
  // one element per dimension: start gives the starting index and size gives
  // the slice size along each dimension.
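  //
  // Example (illustrative sketch): taking a 2x3 slice of a rank-2 float
  // SparseTensor starting at row 1, column 0. `st` and `sliced` are
  // placeholder names for this example.
  //
  //   SparseTensor sliced =
  //       SparseTensor::Slice<float>(st, /*start=*/{1, 0}, /*size=*/{2, 3});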
  template <typename T>
  static SparseTensor Slice(const SparseTensor& tensor,
                            const gtl::ArraySlice<int64>& start,
                            const gtl::ArraySlice<int64>& size);

  // Picks out the dimensions according to `dim_indices`.
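  // For example, for shape [4, 5, 6], PickDims({0, 2}) returns {4, 6}.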
  std::vector<int64> PickDims(gtl::ArraySlice<int64> dim_indices) const {
    std::vector<int64> res(dim_indices.size());
    for (size_t i = 0; i < dim_indices.size(); ++i) {
      res[i] = shape_[dim_indices[i]];
    }
    return res;
  }

 private:
  static inline ShapeArray UndefinedOrder(const VarDimArray shape) {
    return ShapeArray(shape.size(), -1);
  }

  static inline ShapeArray TensorShapeToVector(const TensorShape& shape) {
    ShapeArray vec(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) vec[i] = shape.dim_size(i);
    return vec;
  }

  // Optimized implementation of `IndicesValid` for 1-D sparse tensors.
  // REQUIRES: `shape_.size() == 1`.
  bool IndicesValidVectorFastPath() const;

  // Optimized implementation of `IndicesValid` for 2-D sparse tensors whose
  // indices fit within the range of an `int32`.
  // REQUIRES: `shape_.size() == 2`.
  bool IndicesValidMatrix32BitFastPath() const;

  template <bool standard_order>
  Status IndicesValidHelper() const;

  // Helper for ToDense<T>()
  template <typename T>
  bool ValidateAndInitializeToDense(Tensor* out, bool initialize);

  // Helper for Split() that returns the slice index.
  static inline int GetSliceIndex(const int dim, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim / split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim / (split_size + 1);
    } else {
      return residual + ((dim - offset) / split_size);
    }
  }

  // Helper for Split() that returns the dimension in the slice.
  static inline int GetDimensionInSlice(const int dim, const int split_size,
                                        const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(dim, 0);
    if (residual == 0) return dim % split_size;
    const int offset = residual * (split_size + 1);
    if (dim < offset) {
      return dim % (split_size + 1);
    } else {
      return (dim - offset) % split_size;
    }
  }

  // Helper for Split() that returns the shape given a slice index.
  static inline int GetSliceShape(const int slice_index, const int split_size,
                                  const int residual) {
    DCHECK_GT(split_size, 0);
    DCHECK_GE(slice_index, 0);
    if (residual == 0) return split_size;
    if (slice_index < residual) {
      return split_size + 1;
    } else {
      return split_size;
    }
  }
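
  // Worked example of the three helpers above: splitting a dimension of size
  // 10 into num_split = 3 pieces gives split_size = 3 and residual = 1, so
  // the slice sizes are GetSliceShape(0..2, 3, 1) = {4, 3, 3}. An input
  // coordinate dim = 5 then lands in slice GetSliceIndex(5, 3, 1) = 1 at
  // offset GetDimensionInSlice(5, 3, 1) = 1 within that slice.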

  Tensor ix_;
  Tensor vals_;
  ShapeArray shape_;
  ShapeArray order_;
  int dims_;
};

// This operation updates the indices and values Tensor rows, so it is
// an in-place algorithm. It requires O(N log N) time and O(N)
// temporary space.
template <typename T>
inline void SparseTensor::Reorder(const VarDimArray& order) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "Reorder requested with the wrong datatype";
  DCHECK_EQ(order.size(), dims_) << "Order length must be SparseTensor rank";
  auto ix_t = ix_.matrix<int64>();
  auto vals_t = vals_.vec<T>();

  std::vector<int64> reorder(num_entries());
  std::iota(reorder.begin(), reorder.end(), 0);

  // Sort to get order of indices
  switch (order.size()) {
#define CASE_SORT(ORDER_SIZE)                                    \
  case ORDER_SIZE: {                                             \
    FixedDimComparator<ORDER_SIZE> sorter(ix_t, order, shape()); \
    std::sort(reorder.begin(), reorder.end(), sorter);           \
    break;                                                       \
  }
    CASE_SORT(0);
    CASE_SORT(1);
    CASE_SORT(2);
    CASE_SORT(3);
    CASE_SORT(4);
    CASE_SORT(5);
#undef CASE_SORT
    default: {
      DimComparator sorter(ix_t, order, shape());
      std::sort(reorder.begin(), reorder.end(), sorter);
    }
  }

  // We have a forward reordering, but what we'll need is a
  // permutation (the inverse). This can be calculated with O(1)
  // additional space and O(n) time (INVPERM), but we just do the
  // simple thing here.
  std::vector<size_t> permutation(reorder.size());
  for (std::size_t n = 0; n < reorder.size(); ++n) {
    permutation[reorder[n]] = n;
  }

  // Update indices & values by converting the permutations to
  // a product of transpositions. Iterate over the cycles in the
  // permutation, and convert each of those into a product of
  // transpositions (swaps):
  //   https://en.wikipedia.org/wiki/Cyclic_permutation
  // This is N swaps, 2*N comparisons.
  for (std::size_t n = 0; n + 1 < permutation.size(); ++n) {
    while (n != permutation[n]) {
      std::size_t r = permutation[n];
      std::swap_ranges(&(ix_t(n, 0)), &(ix_t(n + 1, 0)), &(ix_t(r, 0)));
      std::swap(vals_t(n), vals_t(r));
      std::swap(permutation[n], permutation[r]);
    }
  }

  order_ = ShapeArray(order.begin(), order.end());
}

template <typename T>
inline bool SparseTensor::ValidateAndInitializeToDense(Tensor* out,
                                                       bool initialize) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), dtype())
      << "ToDense requested with the wrong datatype";

  DCHECK_EQ(out->shape().dims(), dims_)
      << "Incompatible dimensions between SparseTensor and output";

  DCHECK_EQ(out->dtype(), DataTypeToEnum<T>::v())
      << "Output must be type: " << DataTypeToEnum<T>::v()
      << " but got: " << out->dtype();

  // Make sure the dense output is the same rank and has room
  // to hold the SparseTensor.
  const auto& out_shape = out->shape();
  if (shape_.size() != out_shape.dims()) return false;
  for (int d = 0; d < shape_.size(); ++d) {
    if (shape_[d] > out_shape.dim_size(d)) return false;
  }

  if (initialize) {
    auto out_t = out->flat<T>();
    out_t.setConstant(T());
  }

  return true;
}

template <typename T>
inline bool SparseTensor::ToDense(Tensor* out, bool initialize) {
  if (!ValidateAndInitializeToDense<T>(out, initialize)) return false;

  auto out_t = out->flat<T>();
  auto vals_t = vals_.vec<T>();
  auto ix_t = ix_.matrix<int64>();
  const int64* const ix_ptr = ix_t.data();

  if (dims_ == 1) {
    // Fast path for sparse vectors.
    const int64 out_length = out->shape().dim_size(0);
    for (int n = 0; n < vals_t.dimension(0); ++n) {
      const int64 index = internal::SubtleMustCopy(ix_ptr[n]);
      if (!FastBoundsCheck(index, out_length)) return false;
      out_t(index) = vals_t(n);
    }
    return true;
  } else if (dims_ == 2) {
    // Fast path for sparse matrices.
    const auto& out_shape = out->shape();
    const int64 out_rows = out_shape.dim_size(0);
    const int64 out_cols = out_shape.dim_size(1);
    for (int n = 0; n < vals_t.dimension(0); ++n) {
      const int64 row_index = internal::SubtleMustCopy(ix_ptr[n * 2]);
      const int64 col_index = internal::SubtleMustCopy(ix_ptr[n * 2 + 1]);
      if (!(FastBoundsCheck(row_index, out_rows) &&
            FastBoundsCheck(col_index, out_cols))) {
        return false;
      }
      out_t(row_index * out_cols + col_index) = vals_t(n);
    }
    return true;
  } else {
    // General path for N-dimensional sparse tensors.
    gtl::InlinedVector<int64, 4> strides(dims_);
    const auto& out_shape = out->shape().dim_sizes();
    if (dims_ > 0) {
      strides[dims_ - 1] = 1;
    }
    for (int d = dims_ - 2; d >= 0; --d) {
      strides[d] = strides[d + 1] * out_shape[d + 1];
    }
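    // For example, out_shape = {2, 3, 4} gives row-major strides = {12, 4, 1}.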

    for (int n = 0; n < vals_t.dimension(0); ++n) {
      bool invalid_dims = false;
      int64 ix = 0;
      for (int d = 0; d < dims_; ++d) {
        const int64 ix_n_d = internal::SubtleMustCopy(ix_ptr[n * dims_ + d]);
        if (!FastBoundsCheck(ix_n_d, out_shape[d])) {
          invalid_dims = true;
        }
        ix += strides[d] * ix_n_d;
      }
      if (invalid_dims) return false;
      out_t(ix) = vals_t(n);
    }
    return true;
  }
}

template <typename T>
inline SparseTensor SparseTensor::Concat(
    const gtl::ArraySlice<SparseTensor>& tensors) {
  DCHECK_GE(tensors.size(), size_t{1}) << "Cannot concat 0 SparseTensors";
  const int dims = tensors[0].dims_;
  DCHECK_GE(dims, 1) << "Cannot concat 0-dimensional SparseTensors";
  auto order_0 = tensors[0].order();
  const int primary_dim = order_0[0];
  ShapeArray final_order(order_0.begin(), order_0.end());
  ShapeArray final_shape(tensors[0].shape().begin(), tensors[0].shape().end());
  final_shape[primary_dim] = 0;  // We'll build this up as we go along.
  int num_entries = 0;

  bool fully_ordered = true;
  for (const SparseTensor& st : tensors) {
    DCHECK_EQ(st.dims_, dims) << "All SparseTensors must have the same rank.";
    DCHECK_EQ(DataTypeToEnum<T>::v(), st.dtype())
        << "Concat requested with the wrong data type";
    DCHECK_GE(st.order()[0], 0) << "SparseTensor must be ordered";
    DCHECK_EQ(st.order()[0], primary_dim)
        << "All SparseTensors' order[0] must match. This is the concat dim.";
    if (st.order() != final_order) fully_ordered = false;
    const VarDimArray& st_shape = st.shape();
    for (int d = 0; d < dims - 1; ++d) {
      const int cdim = (d < primary_dim) ? d : d + 1;
      DCHECK_EQ(final_shape[cdim], st_shape[cdim])
          << "All SparseTensors' shapes must match except on the concat dim. "
          << "Concat dim: " << primary_dim
          << ", mismatched shape at dim: " << cdim
          << ". Expecting shape like: [" << str_util::Join(final_shape, ",")
          << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]";
    }

    // Update dimension of final shape
    final_shape[primary_dim] =
        (final_shape[primary_dim] + st_shape[primary_dim]);

    num_entries += st.num_entries();  // Update number of entries
  }

  // If the ordering is inconsistent across inputs, mark the final order as
  // undefined (-1s).
  if (!fully_ordered) {
    final_order = UndefinedOrder(final_shape);
  }

  Tensor output_ix(DT_INT64, TensorShape({num_entries, dims}));
  Tensor output_vals(DataTypeToEnum<T>::v(), TensorShape({num_entries}));

  TTypes<int64>::Matrix ix_t = output_ix.matrix<int64>();
  typename TTypes<T>::Vec vals_t = output_vals.vec<T>();

  Eigen::DenseIndex offset = 0;
  int64 shape_offset = 0;
  for (const SparseTensor& st : tensors) {
    const int st_num_entries = st.num_entries();

    // Fill in indices & values.
    if (st_num_entries > 0) {
      std::copy_n(&st.vals_.vec<T>()(0), st_num_entries, &vals_t(offset));

      const auto* st_ix = &st.ix_.matrix<int64>()(0, 0);
      auto* ix_out = &ix_t(offset, 0);
      for (std::size_t i = 0; i < st_num_entries * dims; ++i) {
        *ix_out++ = *st_ix++ + ((i % dims == primary_dim) ? shape_offset : 0);
      }
    }

    offset += st_num_entries;
    shape_offset += st.shape()[primary_dim];
  }

  return SparseTensor(output_ix, output_vals, final_shape, final_order);
}

template <typename T>
inline Status SparseTensor::Split(const SparseTensor& input_tensor,
                                  const int split_dim, const int num_split,
                                  std::vector<SparseTensor>* result) {
  std::vector<Tensor> output_indices;
  std::vector<Tensor> output_values;
  std::vector<TensorShape> output_shapes;
  output_indices.reserve(num_split);
  output_values.reserve(num_split);
  output_shapes.reserve(num_split);

  std::vector<typename TTypes<int64>::Matrix> output_indices_t;
  std::vector<typename TTypes<T>::Vec> output_values_t;
  output_indices_t.reserve(num_split);
  output_values_t.reserve(num_split);
  auto input_values_t = input_tensor.values().vec<T>();
  auto input_indices_t = input_tensor.indices().matrix<int64>();

  std::vector<int> num_values(num_split, 0);
  const int num_dim = input_tensor.shape().size();
  const int split_dim_size = input_tensor.shape()[split_dim];
  const int split_size = split_dim_size / num_split;

  if (!(num_split > 0 && num_split <= split_dim_size)) {
    return errors::InvalidArgument("num_split must be in the interval (0, ",
                                   split_dim_size, "]");
  }
  if (!(split_dim >= 0 && split_dim < num_dim)) {
    return errors::InvalidArgument("split_dim must be in the interval [0, ",
                                   num_dim, ")");
  }

  const int residual = split_dim_size % num_split;
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_tensor.indices().matrix<int64>()(i, split_dim);
    int slice_index = GetSliceIndex(dim, split_size, residual);
    num_values[slice_index]++;
  }

  for (int i = 0; i < num_split; ++i) {
    // TODO(ataei): Pass an allocator to avoid allocating large memory buffer.
    output_indices.emplace_back(DT_INT64,
                                TensorShape({num_values[i], num_dim}));
    output_values.emplace_back(DataTypeToEnum<T>::v(),
                               TensorShape({num_values[i]}));
    output_shapes.emplace_back(input_tensor.shape());
    output_indices_t.emplace_back(output_indices[i].matrix<int64>());
    output_values_t.emplace_back(output_values[i].vec<T>());
    const int size = GetSliceShape(i, split_size, residual);
    output_shapes[i].set_dim(split_dim, size);
  }

  std::vector<int> values_inserted_in_slice(num_split, 0);
  for (int i = 0; i < input_tensor.indices().dim_size(0); ++i) {
    const int dim = input_indices_t(i, split_dim);
    const int slice_index = GetSliceIndex(dim, split_size, residual);
    const int slice_dim = values_inserted_in_slice[slice_index]++;
    output_values_t[slice_index](slice_dim) = input_values_t(i);
    for (int j = 0; j < num_dim; ++j) {
      const int64 original_dim = input_indices_t(i, j);
      output_indices_t[slice_index](slice_dim, j) =
          (j == split_dim)
              ? GetDimensionInSlice(original_dim, split_size, residual)
              : original_dim;
    }
  }

  result->clear();
  result->reserve(num_split);
  for (int i = 0; i < num_split; ++i) {
    SparseTensor tensor;
    Status create_status =
        Create(output_indices[i], output_values[i], output_shapes[i], &tensor);
    if (!create_status.ok()) {
      return create_status;
    }
    result->push_back(std::move(tensor));
  }
  return Status::OK();
}

template <typename T>
inline SparseTensor SparseTensor::Slice(const SparseTensor& input_tensor,
                                        const gtl::ArraySlice<int64>& start,
                                        const gtl::ArraySlice<int64>& size) {
  TensorShape output_shape(input_tensor.shape());

  const int dims = input_tensor.dims();
  for (int dim = 0; dim < dims; dim++) {
    // Determine the size of the result; if the selected slice goes beyond the
    // input boundary, the result will correspond to the size of the overlap
    // between the input and the selected slice.
    const int64 input_size = output_shape.dim_size(dim);
    const int64 start_index = start[dim];
    const int64 slice_size = size[dim];
    if (start_index + slice_size < input_size) {
      // The entire selection is within input boundaries.
      output_shape.set_dim(dim, slice_size);
    } else if (start_index < input_size) {
      // The selection starts within input boundaries, but goes beyond them.
      output_shape.set_dim(dim, input_size - start_index);
    } else {
      // The selection is entirely out of input boundaries.
      output_shape.set_dim(dim, 0);
    }
  }

  auto input_indices_t = input_tensor.indices().matrix<int64>();
  auto input_values_t = input_tensor.values().vec<T>();

  // Find the number of indices that fall inside start and size.
  int count = 0;
  for (int i = 0; i < input_tensor.indices().dim_size(0); i++) {
    // Check whether the index at row i lies within the range specified by
    // start and size. The loop below iterates over all dimensions; if the
    // index falls outside [start, start + size) in any dimension, it is
    // treated as a "no hit" (hit = false) and is not counted.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    count++;
  }

627
628 Tensor output_values(DataTypeToEnum<T>::v(), TensorShape({count}));
629 Tensor output_indices(DT_INT64, TensorShape({count, dims}));
630
631 auto output_values_t = output_values.vec<T>();
632 auto output_indices_t = output_indices.matrix<int64>();
633
634 // Obtain the output indices that fall inside start and size.
635 int index = 0;
636 for (int i = 0; i < input_tensor.indices().dim_size(0) && index < count;
637 i++) {
    // The logic here is similar to the loop above, except that here we
    // actually populate the output values and indices instead of just
    // counting them.
    bool hit = true;
    for (int dim = 0; dim < dims; dim++) {
      if (!(start[dim] <= input_indices_t(i, dim) &&
            input_indices_t(i, dim) < start[dim] + size[dim])) {
        hit = false;
        break;
      }
    }
    if (!hit) {
      continue;
    }
    output_values_t(index) = input_values_t(i);
    for (int dim = 0; dim < dims; dim++) {
      output_indices_t(index, dim) = input_indices_t(i, dim) - start[dim];
    }
    index++;
  }

  return SparseTensor(output_indices, output_values, output_shape);
}

}  // namespace sparse
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_UTIL_SPARSE_SPARSE_TENSOR_H_