• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/strided_slice_op.h"
17 
18 #include <array>
19 #include "tensorflow/core/framework/bounds_check.h"
20 #include "tensorflow/core/lib/core/status.h"
21 
22 namespace tensorflow {
23 namespace {
24 
25 /// Constants
26 constexpr int32 kShrinkAxis = -1, kNewAxis = -2;
27 
28 // Sparse slicing specification
29 // if one does foo[3:5, ..., -3], this will have 3 length tensors
30 struct StridedSliceSparseSpec {
31   int64 dims;
32   int32 num_add_axis_after_ellipsis;
33   const Tensor* begin_tensor;
34   const Tensor* end_tensor;
35   const Tensor& strides_tensor;
36   const int32 begin_mask, end_mask;
37   int32 ellipsis_mask;
38   const int32 new_axis_mask, shrink_axis_mask;
39 };
40 
41 // Dense slicing specification
42 // all ellipses and newaxis' are expanded out. So if
43 // foo[3:5, ..., -3] where foo is 10 dimensional,
44 // each inlinedVector will have 10 entries whereas the
45 // sparse had 3 length tensors.
46 struct StridedSliceDenseSpec {
47   const int64 dims;
48   int32 begin_mask;
49   int32 end_mask;
50   bool begin_valid;
51   bool end_valid;
52   gtl::InlinedVector<int64, 4>& begin;
53   gtl::InlinedVector<int64, 4>& end;
54   gtl::InlinedVector<int64, 4>& strides;
55   // This vector helps construct the final shape of the slice.
56   // The final tensor is reduced in rank whenever a single index e.g. foo[3]
57   // is called for. The final tensor increases in rank with tf.newaxis
58   // entries. If an index in this array is positive, the size of the dimension
59   // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
60   // it will be 1. A shrunk dimension is skipped.
61   gtl::InlinedVector<int32, 4> final_shape_gather_indices;
62   // The dense indexed shrink mask is which processing dimensions
63   // should be shrunk. For example, if foo.shape = (10,10,10,10)
64   // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
65   // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
66   int32 shrink_axis_mask;
67 };
68 
69 }  // namespace
70 
71 template <class T>
BuildDenseSpec(const StridedSliceSparseSpec & sparse,StridedSliceDenseSpec * dense)72 static Status TF_MUST_USE_RESULT BuildDenseSpec(
73     const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
74   // Build expanded begin, end, strides, begin_mask, end_mask
75   // to remove any ellipsis
76   dense->begin.resize(dense->dims);
77   dense->end.resize(dense->dims);
78   dense->strides.resize(dense->dims);
79   // What indices to get the final shape from.
80   dense->begin_mask = 0;
81   dense->end_mask = 0;
82   dense->shrink_axis_mask = 0;
83   {
84     int full_index = 0;
85 
86     const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
87     dense->begin_valid = sparse.begin_tensor != nullptr;
88     dense->end_valid = sparse.end_tensor != nullptr;
89 
90     const T* const begin_flat = sparse.begin_tensor != nullptr
91                                     ? sparse.begin_tensor->vec<T>().data()
92                                     : nullptr;
93     const T* const end_flat = sparse.end_tensor != nullptr
94                                   ? sparse.end_tensor->vec<T>().data()
95                                   : nullptr;
96 
97     for (int i = 0; i < sparse.dims; i++) {
98       if ((1 << i) & sparse.ellipsis_mask) {
99         // Expand the ellipsis into the appropriate indices
100         // NOTE: this only works because we guaranteed one ellipsis
101         int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
102                                         sparse.num_add_axis_after_ellipsis,
103                                     dense->dims);
104         for (; full_index < next_index; full_index++) {
105           // new_axis' aren't real axis so you have to skip
106           dense->begin[full_index] = dense->end[full_index] = 0;
107           dense->strides[full_index] = 1;
108           dense->begin_mask |= (1 << full_index);
109           dense->end_mask |= (1 << full_index);
110           dense->final_shape_gather_indices.push_back(full_index);
111         }
112       } else if ((1 << i) & sparse.new_axis_mask) {
113         dense->final_shape_gather_indices.push_back(kNewAxis);
114       } else {
115         if (full_index == dense->begin.size()) {
116           return errors::InvalidArgument("Index out of range using input dim ",
117                                          full_index, "; input has only ",
118                                          dense->dims, " dims");
119         }
120 
121         // Gather slicing spec into appropriate index
122         if (begin_flat != nullptr) {
123           dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
124         }
125         if (end_flat != nullptr) {
126           dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
127         }
128         dense->strides[full_index] =
129             internal::SubtleMustCopy<T>(strides_flat[i]);
130         if (sparse.begin_mask & (1 << i)) {
131           dense->begin_mask |= (1 << full_index);
132         }
133         if (sparse.end_mask & (1 << i)) {
134           dense->end_mask |= (1 << full_index);
135         }
136         // If shrink, record where to get the dimensionality from (i.e.
137         // new_axis creates a fake 1 size dimension. Also remember shrink
138         // axis (now in dense form) so we can ignore dense->end below.
139         if (sparse.shrink_axis_mask & (1 << i)) {
140           dense->final_shape_gather_indices.push_back(kShrinkAxis);
141           dense->shrink_axis_mask |= (1 << full_index);
142         } else {
143           dense->final_shape_gather_indices.push_back(full_index);
144         }
145         full_index++;
146       }
147     }
148   }
149   return Status::OK();
150 }
151 
ValidateStridedSliceOp(const Tensor * begin_tensor,const Tensor * end_tensor,const Tensor & strides_tensor,const PartialTensorShape & input_shape,int32 begin_mask_spec,int32 end_mask_spec,const int32 ellipsis_mask,int32 new_axis_mask,int32 shrink_axis_mask,PartialTensorShape * processing_shape,PartialTensorShape * final_shape,bool * is_identity,bool * is_simple_slice,bool * slice_dim0,gtl::InlinedVector<int64,4> * begin,gtl::InlinedVector<int64,4> * end,gtl::InlinedVector<int64,4> * strides)152 Status ValidateStridedSliceOp(
153     const Tensor* begin_tensor, const Tensor* end_tensor,
154     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
155     int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
156     int32 new_axis_mask, int32 shrink_axis_mask,
157     PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
158     bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
159     gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
160     gtl::InlinedVector<int64, 4>* strides) {
161   const bool begin_is_wrong =
162       begin_tensor != nullptr &&
163       !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
164         begin_tensor->NumElements() == strides_tensor.NumElements() &&
165         begin_tensor->NumElements() < 32 /* using 32 bit masks */);
166   const bool end_is_wrong =
167       end_tensor != nullptr &&
168       !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
169         end_tensor->NumElements() == strides_tensor.NumElements());
170   if (begin_is_wrong || end_is_wrong ||
171       !TensorShapeUtils::IsVector(strides_tensor.shape())) {
172     if (begin_tensor != nullptr && end_tensor != nullptr) {
173       return errors::InvalidArgument(
174           "Expected begin, end, and strides to be 1D equal size tensors, ",
175           "but got shapes ", begin_tensor->shape().DebugString(), ", ",
176           end_tensor->shape().DebugString(), ", and ",
177           strides_tensor.shape().DebugString(), " instead.");
178     } else {
179       return errors::InvalidArgument(
180           "Expected begin, end, and strides to be 1D equal size tensors, ",
181           "but got shape ", strides_tensor.shape().DebugString(),
182           " for strides.");
183     }
184   }
185   // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
186   // i.e. there exists only no more than one ellipsis
187   if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
188     return errors::InvalidArgument(
189         "Multiple ellipses in slice spec not allowed");
190   }
191 
192   // Step 1: Account for ellipsis and new axis
193   //
194   // Check for ellipses and count how many non-newaxis' there are after
195   // TODO(aselle): Convert this to do a fast log2 followed by iteration
196   //               counting ones in next guys
197   bool ellipsis_seen = false;
198 
199   StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
200                                         0,
201                                         begin_tensor,
202                                         end_tensor,
203                                         strides_tensor,
204                                         begin_mask_spec,
205                                         end_mask_spec,
206                                         ellipsis_mask,
207                                         new_axis_mask,
208                                         shrink_axis_mask};
209 
210   for (int32 i = 0; i < sparse_spec.dims; i++) {
211     if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
212       sparse_spec.num_add_axis_after_ellipsis++;
213     }
214     if ((1 << i) & ellipsis_mask) {
215       ellipsis_seen = true;
216     }
217   }
218   // If no ellipsis insert one at the end
219   if (!ellipsis_seen) {
220     sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
221     sparse_spec.dims++;  // this effects loop iteration below
222   }
223 
224   // Step 2: Make a sparse spec into a full index spec
225   //
226   // The sparse spec does not correspond to the number of dimensions
227   // Make a dense spec that corresponds to the number of dimensions
228   //
229   // For example suppose foo[...,3:] on foo.shape=(2,2,3) then
230   // we need to produce the missing begin_mask for the first two
231   // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
232   // we achieve begin_mask=6, end_mask=7
233   StridedSliceDenseSpec dense_spec = {input_shape.dims(),
234                                       0 /* begin_mask */,
235                                       0 /* end_mask */,
236                                       false /* begin_valid */,
237                                       false /* end_valid */,
238                                       *begin,
239                                       *end,
240                                       *strides};
241 
242   if (strides_tensor.dtype() == DT_INT32) {
243     TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
244   } else if (strides_tensor.dtype() == DT_INT64) {
245     TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
246   } else {
247     LOG(FATAL) << "begin must be either int32 or int64";
248   }
249 
250   // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
251   //         and bounds check!
252   *is_identity = true;
253   *slice_dim0 = true;
254   *is_simple_slice = true;
255   processing_shape->Clear();
256   for (int i = 0; i < input_shape.dims(); ++i) {
257     int64& begin_i = (*begin)[i];
258     int64& end_i = (*end)[i];
259     int64& stride_i = (*strides)[i];
260     int64 dim_i = input_shape.dim_size(i);
261     if (stride_i == 0) {
262       return errors::InvalidArgument("strides[", i, "] must be non-zero");
263     }
264     bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
265     if (dim_i == -1) {
266       processing_shape->AddDim(shrink_i ? 1 : -1);
267       continue;
268     }
269 
270     const std::array<int64, 2> masks = {
271         {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
272     const std::array<int64, 2> valid_range = {
273         {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};
274 
275     auto canonical = [stride_i, dim_i, masks, valid_range](int64 x, int c) {
276       if (masks[c]) {
277         return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
278       } else {
279         int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
280         return x_fwd < valid_range[0]
281                    ? valid_range[0]
282                    : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
283       }
284     };
285     if (shrink_i && stride_i <= 0) {
286       return errors::InvalidArgument(
287           "only stride 1 allowed on non-range indexing.");
288     }
289     (*is_simple_slice) &= stride_i == 1;
290 
291     const bool begin_and_end_masked =
292         (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
293     if (dense_spec.begin_valid && dense_spec.end_valid) {
294       if (shrink_i) {
295         // If we are shrinking, the end index is now possibly incorrect. In
296         // particular foo[-1] produces sparse_begin = -1, sparse_end = 0.
297         // and canonical puts these to n-1 and 0, which implies a degenerate
298         // interval. Fortunately, it is now safe to re-create end as begin+1.
299         int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
300         begin_i = x_fwd;
301         end_i = begin_i + 1;
302         if (x_fwd < 0 || x_fwd >= dim_i) {
303           return errors::InvalidArgument(
304               "slice index ", begin_i, " of dimension ", i, " out of bounds.");
305         }
306       } else {
307         begin_i = canonical(begin_i, 0);
308         end_i = canonical(end_i, 1);
309       }
310       // Update optimization values
311       bool take_all_in_dimension =
312           stride_i == 1 && begin_i == 0 && end_i == dim_i;
313       (*is_identity) &= take_all_in_dimension;
314       (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
315     } else {
316       (*is_identity) &= stride_i == 1 && begin_and_end_masked;
317       (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
318     }
319     // Compute the processing shape (the intermediate Eigen will produce)
320     int64 interval_length;
321     bool known_interval = false;
322     if (dense_spec.begin_valid && dense_spec.end_valid) {
323       interval_length = end_i - begin_i;
324       known_interval = true;
325     } else if (shrink_i) {
326       // The dimension is still known as 1 for the processing_shape, but will be
327       // discarded for the final shape.
328       interval_length = 1;
329       known_interval = true;
330     } else if (begin_and_end_masked) {
331       // Even if we don't have values for begin or end, we do know that this
332       // dimension covers the whole interval. If we have shape information for
333       // this dimension, that tells us the interval length.
334       if (dim_i >= 0) {
335         if (stride_i < 0) {
336           interval_length = -dim_i;
337         } else {
338           interval_length = dim_i;
339         }
340         known_interval = true;
341       }
342     }
343     if (known_interval) {
344       int64 size_i;
345       // Hold zero if the interval is degenerate, otherwise account for
346       // remainder
347       if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
348         size_i = 0;
349       } else {
350         size_i = interval_length / stride_i +
351                  (interval_length % stride_i != 0 ? 1 : 0);
352       }
353       processing_shape->AddDim(size_i);
354     } else {
355       processing_shape->AddDim(-1);
356     }
357   }
358 
359   // Step 4: Compute the final shape
360   //
361   // new_axis will increase dimension by 1 (with a one-size dimension)
362   // slices like foo[3,...] will reduce dimension by 1.
363   // This cannot be done earlier, because it depends on Step 3.
364   final_shape->Clear();
365   for (auto gather_index : dense_spec.final_shape_gather_indices) {
366     if (gather_index >= 0) {
367       final_shape->AddDim(processing_shape->dim_size(gather_index));
368     } else if (gather_index == kNewAxis) {
369       final_shape->AddDim(1);
370     }
371   }
372   return Status::OK();
373 }
374 
ValidateStridedSliceOp(const Tensor * begin_tensor,const Tensor * end_tensor,const Tensor & strides_tensor,const PartialTensorShape & input_shape,int32 begin_mask_spec,int32 end_mask_spec,const int32 ellipsis_mask,int32 new_axis_mask,int32 shrink_axis_mask,TensorShape * processing_shape,TensorShape * final_shape,bool * is_identity,bool * is_simple_slice,bool * slice_dim0,gtl::InlinedVector<int64,4> * begin,gtl::InlinedVector<int64,4> * end,gtl::InlinedVector<int64,4> * strides)375 Status ValidateStridedSliceOp(
376     const Tensor* begin_tensor, const Tensor* end_tensor,
377     const Tensor& strides_tensor, const PartialTensorShape& input_shape,
378     int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
379     int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
380     TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
381     bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
382     gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
383   // Validate with PartialTensorShape output
384   PartialTensorShape partial_processing_shape, partial_final_shape;
385   TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
386       begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
387       end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
388       &partial_processing_shape, &partial_final_shape, is_identity,
389       is_simple_slice, slice_dim0, begin, end, strides));
390 
391   // Verify that the output shapes are fully known
392   if (!partial_processing_shape.AsTensorShape(processing_shape) ||
393       !partial_final_shape.AsTensorShape(final_shape)) {
394     return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
395                             partial_processing_shape.DebugString(), " and ",
396                             partial_final_shape.DebugString());
397   }
398   return Status::OK();
399 }
400 
401 }  // namespace tensorflow
402