/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/util/strided_slice_op.h"

#include <array>
#include <iterator>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace {

/// Constants
constexpr int32 kShrinkAxis = -1, kNewAxis = -2;

// Sparse slicing specification
// If one does foo[3:5, ..., -3], the begin, end, and strides tensors will
// each have length 3 (one entry per term in the slice spec).
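// Illustrative sketch (values assumed from the mask conventions used in this
// file, not taken from any particular caller): for foo[3:5, ..., -3] the
// spec would carry dims = 3, an ellipsis_mask with bit 1 set (the second
// term is the ellipsis), and a shrink_axis_mask with bit 2 set (the single
// index -3 drops that dimension from the result).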
struct StridedSliceSparseSpec {
  int64 dims;
  int32 num_add_axis_after_ellipsis;
  const Tensor* begin_tensor;
  const Tensor* end_tensor;
  const Tensor& strides_tensor;
  const int32 begin_mask, end_mask;
  int32 ellipsis_mask;
  const int32 new_axis_mask, shrink_axis_mask;
};

// Dense slicing specification
// All ellipses and newaxis entries are expanded out, so for
// foo[3:5, ..., -3] where foo is 10-dimensional,
// each InlinedVector will have 10 entries, whereas the
// sparse spec had length-3 tensors.
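// As a rough sketch of what BuildDenseSpec below produces for that example
// (worked through by hand, so treat the exact values as illustrative):
// dense index 0 comes from the 3:5 term, dense indices 1..8 come from the
// expanded ellipsis and get both begin_mask and end_mask bits set, and dense
// index 9 comes from the shrinking index -3.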
struct StridedSliceDenseSpec {
  const int64 dims;
  int32 begin_mask;
  int32 end_mask;
  bool begin_valid;
  bool end_valid;
  gtl::InlinedVector<int64, 4>& begin;
  gtl::InlinedVector<int64, 4>& end;
  gtl::InlinedVector<int64, 4>& strides;
  // This vector helps construct the final shape of the slice.
  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
  // is called for. The final tensor increases in rank with tf.newaxis
  // entries. If an index in this array is non-negative, the size of the
  // dimension is obtained from canonical end-begin. Otherwise, if it is a
  // kNewAxis, it will be 1. A shrunk dimension is skipped.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
  // This vector has the same size as final_shape_gather_indices, but it
  // remembers the sparse index that a dimension comes from, instead of the
  // dense index. A -1 in this vector means the index is not from the sparse
  // input.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
  // The dense-indexed shrink mask indicates which processing dimensions
  // should be shrunk. For example, if foo.shape = (10,10,10,10),
  // foo[3, ..., 5] has a sparse_shrink_axis_mask of 0x5 and a
  // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
  int32 shrink_axis_mask;
};

}  // namespace

template <class T>
static Status TF_MUST_USE_RESULT BuildDenseSpec(
    const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
  // Build expanded begin, end, strides, begin_mask, end_mask
  // to remove any ellipsis
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  dense->input_shape_gather_indices_sparse.resize(dense->dims);
  // What indices to get the final shape from.
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;
  {
    int full_index = 0;

    const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
    dense->begin_valid = sparse.begin_tensor != nullptr;
    dense->end_valid = sparse.end_tensor != nullptr;

    const T* const begin_flat = sparse.begin_tensor != nullptr
                                    ? sparse.begin_tensor->vec<T>().data()
                                    : nullptr;
    const T* const end_flat = sparse.end_tensor != nullptr
                                  ? sparse.end_tensor->vec<T>().data()
                                  : nullptr;

    for (int i = 0; i < sparse.dims; i++) {
      if ((1 << i) & sparse.ellipsis_mask) {
        // Expand the ellipsis into the appropriate indices
        // NOTE: this only works because we guaranteed one ellipsis
        int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
                                        sparse.num_add_axis_after_ellipsis,
                                    dense->dims);
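        // Worked example (illustrative): for foo[3:5, ..., -3] on a rank-10
        // input, the ellipsis is reached with sparse.dims = 3, i = 1 and
        // full_index = 1, so next_index = min(10 - (3 - 1) + 1 + 0, 10) = 9
        // and dense dimensions 1..8 are filled in by the loop below.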
        for (; full_index < next_index; full_index++) {
          // new_axis entries aren't real axes, so skip them
          dense->begin[full_index] = dense->end[full_index] = 0;
          dense->strides[full_index] = 1;
          dense->begin_mask |= (1 << full_index);
          dense->end_mask |= (1 << full_index);
          dense->final_shape_gather_indices.push_back(full_index);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->input_shape_gather_indices_sparse[full_index] = i;
        }
      } else if ((1 << i) & sparse.new_axis_mask) {
        dense->final_shape_gather_indices.push_back(kNewAxis);
        dense->final_shape_gather_indices_sparse.push_back(-1);
      } else {
        if (full_index == dense->begin.size()) {
          return errors::InvalidArgument("Index out of range using input dim ",
                                         full_index, "; input has only ",
                                         dense->dims, " dims");
        }

        // Gather slicing spec into appropriate index
        if (begin_flat != nullptr) {
          dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
        }
        if (end_flat != nullptr) {
          dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
        }
        dense->strides[full_index] =
            internal::SubtleMustCopy<T>(strides_flat[i]);
        if (sparse.begin_mask & (1 << i)) {
          dense->begin_mask |= (1 << full_index);
        }
        if (sparse.end_mask & (1 << i)) {
          dense->end_mask |= (1 << full_index);
        }
        // If shrink, record where to get the dimensionality from (i.e.
        // new_axis creates a fake 1-size dimension). Also remember the shrink
        // axis (now in dense form) so we can ignore dense->end below.
        if (sparse.shrink_axis_mask & (1 << i)) {
          dense->final_shape_gather_indices.push_back(kShrinkAxis);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->shrink_axis_mask |= (1 << full_index);
        } else {
          dense->final_shape_gather_indices.push_back(full_index);
          // Remember where in the sparse shape the dense dim comes from.
          dense->final_shape_gather_indices_sparse.push_back(i);
        }
        dense->input_shape_gather_indices_sparse[full_index] = i;
        full_index++;
      }
    }
  }
  return Status::OK();
}

Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask,
    PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    gtl::InlinedVector<int64, 4>* strides, StridedSliceShapeSpec* shape_spec) {
  const bool begin_is_wrong =
      begin_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
        begin_tensor->NumElements() == strides_tensor.NumElements() &&
        begin_tensor->NumElements() < 32 /* using 32 bit masks */);
  const bool end_is_wrong =
      end_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
        end_tensor->NumElements() == strides_tensor.NumElements());
  if (begin_is_wrong || end_is_wrong ||
      !TensorShapeUtils::IsVector(strides_tensor.shape())) {
    if (begin_tensor != nullptr && end_tensor != nullptr) {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shapes ", begin_tensor->shape().DebugString(), ", ",
          end_tensor->shape().DebugString(), ", and ",
          strides_tensor.shape().DebugString(), " instead.");
    } else {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shape ", strides_tensor.shape().DebugString(),
          " for strides.");
    }
  }
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2,
  // i.e. there is at most one ellipsis.
  if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    return errors::InvalidArgument(
        "Multiple ellipses in slice spec not allowed");
  }
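  // For example (illustrative): ellipsis_mask = 0b0100 passes because
  // 0b0100 & 0b0011 == 0 (one ellipsis), while ellipsis_mask = 0b0101 fails
  // because 0b0101 & 0b0100 != 0 (two ellipses).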

  // Step 1: Account for ellipsis and new axis
  //
  // Check for an ellipsis and count how many new_axis entries come after it.
  // TODO(aselle): Convert this to do a fast log2 followed by iteration
  //               counting ones in the bits that follow
  bool ellipsis_seen = false;

  StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
                                        0,
                                        begin_tensor,
                                        end_tensor,
                                        strides_tensor,
                                        begin_mask_spec,
                                        end_mask_spec,
                                        ellipsis_mask,
                                        new_axis_mask,
                                        shrink_axis_mask};

  for (int32 i = 0; i < sparse_spec.dims; i++) {
    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
      sparse_spec.num_add_axis_after_ellipsis++;
    }
    if ((1 << i) & ellipsis_mask) {
      ellipsis_seen = true;
    }
  }
  // If no ellipsis was given, insert one at the end.
  if (!ellipsis_seen) {
    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    sparse_spec.dims++;  // this affects the loop iteration below
  }
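  // For example (illustrative): foo[3:5] on a rank-2 input arrives with
  // sparse_spec.dims = 1; the implicit trailing ellipsis makes dims = 2 and
  // sets ellipsis_mask bit 1, so the second input dimension is taken whole.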

  // Step 2: Make the sparse spec into a full index spec
  //
  // The sparse spec does not correspond to the number of dimensions of the
  // input; make a dense spec that does.
  //
  // For example, suppose foo[..., 3:] on foo.shape = (2,2,3). Then we need
  // to produce the missing begin_mask for the first two
  // dimensions, i.e. from begin_mask_spec=0, end_mask_spec=2
  // we obtain the dense begin_mask=3 (011) and end_mask=7 (111).
  StridedSliceDenseSpec dense_spec = {input_shape.dims(),
                                      0 /* begin_mask */,
                                      0 /* end_mask */,
                                      false /* begin_valid */,
                                      false /* end_valid */,
                                      *begin,
                                      *end,
                                      *strides};

  if (strides_tensor.dtype() == DT_INT32) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT64) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
  } else {
    LOG(FATAL) << "begin must be either int32 or int64";
  }

  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
  //         and bounds check!
  *is_identity = true;
  *slice_dim0 = true;
  *is_simple_slice = true;
  processing_shape->Clear();
  for (int i = 0; i < input_shape.dims(); ++i) {
    int64& begin_i = (*begin)[i];
    int64& end_i = (*end)[i];
    int64& stride_i = (*strides)[i];
    int64 dim_i = input_shape.dim_size(i);
    if (stride_i == 0) {
      return errors::InvalidArgument("strides[", i, "] must be non-zero");
    }
    bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
    if (dim_i == -1) {
      processing_shape->AddDim(shrink_i ? 1 : -1);
      continue;
    }

    const std::array<int64, 2> masks = {
        {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
    const std::array<int64, 2> valid_range = {
        {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};

    auto canonical = [stride_i, dim_i, masks, valid_range](int64 x, int c) {
      if (masks[c]) {
        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
      } else {
        int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
        return x_fwd < valid_range[0]
                   ? valid_range[0]
                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
      }
    };
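    // Worked example (illustrative), with dim_i = 10 and stride_i = 1:
    // canonical(-3, 0) forwards -3 to 7 and keeps it, since 0 <= 7 <= 10,
    // while a masked begin (masks[0] != 0) maps straight to valid_range[0] = 0.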
    if (shrink_i && stride_i <= 0) {
      return errors::InvalidArgument(
          "only stride 1 allowed on non-range indexing.");
    }
    (*is_simple_slice) &= stride_i == 1;

    const bool begin_and_end_masked =
        (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      if (shrink_i) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular, foo[-1] produces sparse_begin = -1 and sparse_end = 0,
        // and canonical maps these to n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin + 1.
        int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
        begin_i = x_fwd;
        end_i = begin_i + 1;
        if (x_fwd < 0 || x_fwd >= dim_i) {
          return errors::InvalidArgument(
              "slice index ", begin_i, " of dimension ", i, " out of bounds.");
        }
      } else {
        begin_i = canonical(begin_i, 0);
        end_i = canonical(end_i, 1);
      }
      // Update optimization values
      bool take_all_in_dimension =
          stride_i == 1 && begin_i == 0 && end_i == dim_i;
      (*is_identity) &= take_all_in_dimension;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
    } else {
      (*is_identity) &= stride_i == 1 && begin_and_end_masked;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
    }
    // Compute the processing shape (the intermediate shape Eigen will produce)
    int64 interval_length;
    bool known_interval = false;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      interval_length = end_i - begin_i;
      known_interval = true;
    } else if (shrink_i) {
      // The dimension is still known as 1 for the processing_shape, but will be
      // discarded for the final shape.
      interval_length = 1;
      known_interval = true;
    } else if (begin_and_end_masked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dim_i >= 0) {
        if (stride_i < 0) {
          interval_length = -dim_i;
        } else {
          interval_length = dim_i;
        }
        known_interval = true;
      }
    }
    if (known_interval) {
      int64 size_i;
      // Use zero if the interval is degenerate; otherwise account for the
      // remainder when the stride does not evenly divide the interval.
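      // e.g. (illustrative) begin=1, end=7, stride=2 gives interval_length=6
      // and size_i = 6/2 + 0 = 3 (indices 1, 3, 5), while begin=0, end=5,
      // stride=3 gives size_i = 5/3 + 1 = 2 (indices 0, 3).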
      if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
        size_i = 0;
      } else {
        size_i = interval_length / stride_i +
                 (interval_length % stride_i != 0 ? 1 : 0);
      }
      processing_shape->AddDim(size_i);
    } else {
      processing_shape->AddDim(-1);
    }
  }

  // Step 4: Compute the final shape
  //
  // new_axis will increase the rank by 1 (with a size-one dimension), while
  // slices like foo[3, ...] will reduce the rank by 1.
  // This cannot be done earlier, because it depends on Step 3.
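  // For example (illustrative), foo[3, ...] on a rank-3 input produces
  // final_shape_gather_indices = {kShrinkAxis, 1, 2}; the shrink entry is
  // skipped below, so the final shape has rank 2.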
  final_shape->Clear();
  if (shape_spec != nullptr) {
    shape_spec->output_to_sparse_mapping.clear();
    shape_spec->output_to_processing_mapping.clear();
    shape_spec->processing_to_sparse_mapping.assign(
        dense_spec.input_shape_gather_indices_sparse.begin(),
        dense_spec.input_shape_gather_indices_sparse.end());

    shape_spec->begin_dense_mask = dense_spec.begin_mask;
    shape_spec->end_dense_mask = dense_spec.end_mask;
    shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
  }

  for (int64 dense_dim = 0;
       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
    int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
    int64 sparse_index =
        dense_spec.final_shape_gather_indices_sparse[dense_dim];
    if (gather_index >= 0) {
      final_shape->AddDim(processing_shape->dim_size(gather_index));
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(sparse_index);
        shape_spec->output_to_processing_mapping.push_back(gather_index);
      }
    } else if (gather_index == kNewAxis) {
      final_shape->AddDim(1);
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(-1);
        shape_spec->output_to_processing_mapping.push_back(-1);
      }
    }
  }

  return Status::OK();
}

Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
    TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
    bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
    gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
    StridedSliceShapeSpec* shape_spec) {
  // Validate with PartialTensorShape output
  PartialTensorShape partial_processing_shape, partial_final_shape;
  TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
      begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
      end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
      &partial_processing_shape, &partial_final_shape, is_identity,
      is_simple_slice, slice_dim0, begin, end, strides, shape_spec));

  // Verify that the output shapes are fully known
  if (!partial_processing_shape.AsTensorShape(processing_shape) ||
      !partial_final_shape.AsTensorShape(final_shape)) {
    return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
                            partial_processing_shape.DebugString(), " and ",
                            partial_final_shape.DebugString());
  }
  return Status::OK();
}

}  // namespace tensorflow