1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/strided_slice_op.h"
17
18 #include <array>
19 #include <iterator>
20
21 #include "tensorflow/core/framework/bounds_check.h"
22 #include "tensorflow/core/lib/core/status.h"
23
24 namespace tensorflow {
25 namespace {
26
/// Constants
// Sentinel values stored in `final_shape_gather_indices`:
//   kShrinkAxis marks a dimension removed via shrink_axis_mask,
//   kNewAxis marks a size-1 dimension inserted via new_axis_mask.
// Any non-negative entry is a real dense dimension index.
constexpr int32 kShrinkAxis = -1, kNewAxis = -2;
29
// Sparse slicing specification
// if one does foo[3:5, ..., -3], this will have 3 length tensors
struct StridedSliceSparseSpec {
  // Number of entries in the user-supplied slice spec (the length of the
  // begin/end/strides tensors); may be incremented by one if an implicit
  // trailing ellipsis is added by the caller.
  int64 dims;
  // How many new_axis entries appear after the ellipsis; used when
  // computing how many dense dimensions the ellipsis expands into.
  int32 num_add_axis_after_ellipsis;
  // May be null when the caller did not supply begin values.
  const Tensor* begin_tensor;
  // May be null when the caller did not supply end values.
  const Tensor* end_tensor;
  const Tensor& strides_tensor;
  // Bit i set means the i-th sparse begin/end entry is implicit (masked).
  const int32 begin_mask, end_mask;
  // Bit i set means the i-th sparse entry is an ellipsis ("...").
  int32 ellipsis_mask;
  // Bit i set means the i-th sparse entry inserts (new_axis) or removes
  // (shrink_axis) a dimension, respectively.
  const int32 new_axis_mask, shrink_axis_mask;
};
42
// Dense slicing specification
// all ellipses and newaxis' are expanded out. So if
// foo[3:5, ..., -3] where foo is 10 dimensional,
// each InlinedVector will have 10 entries whereas the
// sparse had 3 length tensors.
struct StridedSliceDenseSpec {
  // Rank of the input tensor; every vector below is sized/indexed by it.
  const int64 dims;
  // Bit i set means dimension i's begin/end is implicit and should be
  // replaced by the full valid range for that dimension.
  int32 begin_mask;
  int32 end_mask;
  // Whether begin/end values were supplied at all (false when the
  // corresponding sparse tensor was null).
  bool begin_valid;
  bool end_valid;
  // References to caller-owned output vectors that receive the expanded
  // per-dimension begin/end/stride values.
  gtl::InlinedVector<int64, 4>& begin;
  gtl::InlinedVector<int64, 4>& end;
  gtl::InlinedVector<int64, 4>& strides;
  // This vector helps construct the final shape of the slice.
  // The final tensor is reduced in rank whenever a single index e.g. foo[3]
  // is called for. The final tensor increases in rank with tf.newaxis
  // entries. If an index in this array is positive, the size of the dimension
  // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
  // it will be 1. A shrunk dimension is skipped.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices;
  // This vector has the same size as final_shape_gather_indices, but it
  // remembers the sparse index that a dimension comes from, instead of dense
  // index. A -1 in this vector means the index is not from the sparse
  // input.
  gtl::InlinedVector<int32, 4> final_shape_gather_indices_sparse;
  // For each dense input dimension, the sparse spec entry it came from.
  gtl::InlinedVector<int32, 4> input_shape_gather_indices_sparse;
  // The dense indexed shrink mask is which processing dimensions
  // should be shrunk. For example, if foo.shape = (10,10,10,10)
  // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
  // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
  int32 shrink_axis_mask;
};
76
77 } // namespace
78
// Expands a sparse slicing spec (one entry per user-supplied index, with
// ellipsis/new_axis entries still present) into a dense spec holding exactly
// `dense->dims` entries, one per input dimension.  T is the element type of
// the begin/end/strides tensors (int32 or int64).
//
// NOTE(review): relies on the caller having guaranteed at most one ellipsis
// entry (ValidateStridedSliceOp checks this before calling) — the ellipsis
// expansion below is only correct under that assumption.
template <class T>
static Status TF_MUST_USE_RESULT BuildDenseSpec(
    const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
  // Build expanded begin, end, strides, begin_mask, end_mask
  // to remove any ellipsis
  dense->begin.resize(dense->dims);
  dense->end.resize(dense->dims);
  dense->strides.resize(dense->dims);
  dense->input_shape_gather_indices_sparse.resize(dense->dims);
  // What indices to get the final shape from.
  dense->begin_mask = 0;
  dense->end_mask = 0;
  dense->shrink_axis_mask = 0;
  {
    // Next dense (input) dimension to be filled in.
    int full_index = 0;

    const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
    dense->begin_valid = sparse.begin_tensor != nullptr;
    dense->end_valid = sparse.end_tensor != nullptr;

    const T* const begin_flat = sparse.begin_tensor != nullptr
                                    ? sparse.begin_tensor->vec<T>().data()
                                    : nullptr;
    const T* const end_flat = sparse.end_tensor != nullptr
                                  ? sparse.end_tensor->vec<T>().data()
                                  : nullptr;

    for (int i = 0; i < sparse.dims; i++) {
      if ((1 << i) & sparse.ellipsis_mask) {
        // Expand the ellipsis into the appropriate indices
        // NOTE: this only works because we guaranteed one ellipsis
        int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
                                        sparse.num_add_axis_after_ellipsis,
                                    dense->dims);
        // Each dimension covered by the ellipsis takes its full range:
        // begin/end are masked out and stride is 1.
        for (; full_index < next_index; full_index++) {
          // new_axis' aren't real axis so you have to skip
          dense->begin[full_index] = dense->end[full_index] = 0;
          dense->strides[full_index] = 1;
          dense->begin_mask |= (1 << full_index);
          dense->end_mask |= (1 << full_index);
          dense->final_shape_gather_indices.push_back(full_index);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->input_shape_gather_indices_sparse[full_index] = i;
        }
      } else if ((1 << i) & sparse.new_axis_mask) {
        // new_axis consumes no input dimension: record the synthetic
        // size-1 output dimension but do not advance full_index.
        dense->final_shape_gather_indices.push_back(kNewAxis);
        dense->final_shape_gather_indices_sparse.push_back(-1);
      } else {
        // More real (non-new_axis) sparse entries than input dimensions.
        if (full_index == dense->begin.size()) {
          return errors::InvalidArgument("Index out of range using input dim ",
                                         full_index, "; input has only ",
                                         dense->dims, " dims");
        }

        // Gather slicing spec into appropriate index
        if (begin_flat != nullptr) {
          dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
        }
        if (end_flat != nullptr) {
          dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
        }
        dense->strides[full_index] =
            internal::SubtleMustCopy<T>(strides_flat[i]);
        if (sparse.begin_mask & (1 << i)) {
          dense->begin_mask |= (1 << full_index);
        }
        if (sparse.end_mask & (1 << i)) {
          dense->end_mask |= (1 << full_index);
        }
        // If shrink, record where to get the dimensionality from (i.e.
        // new_axis creates a fake 1 size dimension. Also remember shrink
        // axis (now in dense form) so we can ignore dense->end below.
        if (sparse.shrink_axis_mask & (1 << i)) {
          dense->final_shape_gather_indices.push_back(kShrinkAxis);
          dense->final_shape_gather_indices_sparse.push_back(-1);
          dense->shrink_axis_mask |= (1 << full_index);
        } else {
          dense->final_shape_gather_indices.push_back(full_index);
          // Remember that where in the sparse shape the dense dim comes
          // from.
          dense->final_shape_gather_indices_sparse.push_back(i);
        }
        dense->input_shape_gather_indices_sparse[full_index] = i;
        full_index++;
      }
    }
  }
  return Status::OK();
}
168
// Canonicalizes a strided-slice specification against `input_shape`.
//
// Validates the begin/end/strides tensors, expands ellipsis/new_axis
// entries into per-input-dimension (dense) form, makes masked begin/end
// values explicit, bounds-checks indices, and produces:
//   * processing_shape — the intermediate shape the slicing kernel
//     produces (shrunk dims kept as size 1; unknown dims become -1),
//   * final_shape — the user-visible result shape (shrunk dims removed,
//     new_axis dims of size 1 inserted),
//   * begin/end/strides — canonicalized per-dimension values,
//   * is_identity / is_simple_slice / slice_dim0 — fast-path flags,
//   * shape_spec (optional, may be null) — dense<->sparse index mappings
//     used by shape inference.
// Returns InvalidArgument on malformed or out-of-range slice specs.
Status ValidateStridedSliceOp(
    const Tensor* begin_tensor, const Tensor* end_tensor,
    const Tensor& strides_tensor, const PartialTensorShape& input_shape,
    int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
    int32 new_axis_mask, int32 shrink_axis_mask,
    PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
    bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
    gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
    gtl::InlinedVector<int64, 4>* strides, StridedSliceShapeSpec* shape_spec) {
  // begin (when supplied) must be a vector of the same length as strides;
  // its length must also fit in the 32-bit dimension masks used below.
  const bool begin_is_wrong =
      begin_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
        begin_tensor->NumElements() == strides_tensor.NumElements() &&
        begin_tensor->NumElements() < 32 /* using 32 bit masks */);
  const bool end_is_wrong =
      end_tensor != nullptr &&
      !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
        end_tensor->NumElements() == strides_tensor.NumElements());
  if (begin_is_wrong || end_is_wrong ||
      !TensorShapeUtils::IsVector(strides_tensor.shape())) {
    if (begin_tensor != nullptr && end_tensor != nullptr) {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shapes ", begin_tensor->shape().DebugString(), ", ",
          end_tensor->shape().DebugString(), ", and ",
          strides_tensor.shape().DebugString(), " instead.");
    } else {
      return errors::InvalidArgument(
          "Expected begin, end, and strides to be 1D equal size tensors, ",
          "but got shape ", strides_tensor.shape().DebugString(),
          " for strides.");
    }
  }
  // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
  // i.e. there exists only no more than one ellipsis
  if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
    return errors::InvalidArgument(
        "Multiple ellipses in slice spec not allowed");
  }

  // Step 1: Account for ellipsis and new axis
  //
  // Check for ellipses and count how many non-newaxis' there are after
  // TODO(aselle): Convert this to do a fast log2 followed by iteration
  // counting ones in next guys
  bool ellipsis_seen = false;

  StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
                                        0,
                                        begin_tensor,
                                        end_tensor,
                                        strides_tensor,
                                        begin_mask_spec,
                                        end_mask_spec,
                                        ellipsis_mask,
                                        new_axis_mask,
                                        shrink_axis_mask};

  for (int32 i = 0; i < sparse_spec.dims; i++) {
    if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
      sparse_spec.num_add_axis_after_ellipsis++;
    }
    if ((1 << i) & ellipsis_mask) {
      ellipsis_seen = true;
    }
  }
  // If no ellipsis insert one at the end
  if (!ellipsis_seen) {
    sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
    sparse_spec.dims++;  // this affects loop iteration below
  }

  // Step 2: Make a sparse spec into a full index spec
  //
  // The sparse spec does not correspond to the number of dimensions
  // Make a dense spec that corresponds to the number of dimensions
  //
  // For example suppose foo[...,3:] on foo.shape=(2,2,3) then
  // we need to produce the missing begin_mask for the first two
  // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
  // we achieve begin_mask=6, end_mask=7
  StridedSliceDenseSpec dense_spec = {input_shape.dims(),
                                      0 /* begin_mask */,
                                      0 /* end_mask */,
                                      false /* begin_valid */,
                                      false /* end_valid */,
                                      *begin,
                                      *end,
                                      *strides};

  // Dispatch on the index element type of the slice-spec tensors.
  if (strides_tensor.dtype() == DT_INT32) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
  } else if (strides_tensor.dtype() == DT_INT64) {
    TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
  } else {
    LOG(FATAL) << "begin must be either int32 or int64";
  }

  // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
  // and bounds check!
  *is_identity = true;
  *slice_dim0 = true;
  *is_simple_slice = true;
  processing_shape->Clear();
  for (int i = 0; i < input_shape.dims(); ++i) {
    int64& begin_i = (*begin)[i];
    int64& end_i = (*end)[i];
    int64& stride_i = (*strides)[i];
    int64 dim_i = input_shape.dim_size(i);
    if (stride_i == 0) {
      return errors::InvalidArgument("strides[", i, "] must be non-zero");
    }
    bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
    // Unknown dimension size: nothing can be bounds checked; a shrunk
    // dimension still contributes a known size of 1.
    if (dim_i == -1) {
      processing_shape->AddDim(shrink_i ? 1 : -1);
      continue;
    }

    const std::array<int64, 2> masks = {
        {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
    // The half-open valid index range depends on slicing direction:
    // [0, dim) for positive strides, [-1, dim-1) reversed for negative.
    const std::array<int64, 2> valid_range = {
        {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};

    // Clamp index x into the valid range; c selects begin (0) or end (1).
    // A masked index snaps to the extreme appropriate for the direction.
    auto canonical = [stride_i, dim_i, masks, valid_range](int64 x, int c) {
      if (masks[c]) {
        return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
      } else {
        int64 x_fwd = x < 0 ? dim_i + x : x;  // make negative indices positive
        return x_fwd < valid_range[0]
                   ? valid_range[0]
                   : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
      }
    };
    if (shrink_i && stride_i <= 0) {
      return errors::InvalidArgument(
          "only stride 1 allowed on non-range indexing.");
    }
    (*is_simple_slice) &= stride_i == 1;

    const bool begin_and_end_masked =
        (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      if (shrink_i) {
        // If we are shrinking, the end index is now possibly incorrect. In
        // particular foo[-1] produces sparse_begin = -1, sparse_end = 0.
        // and canonical puts these to n-1 and 0, which implies a degenerate
        // interval. Fortunately, it is now safe to re-create end as begin+1.
        int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
        begin_i = x_fwd;
        end_i = begin_i + 1;
        if (x_fwd < 0 || x_fwd >= dim_i) {
          return errors::InvalidArgument(
              "slice index ", begin_i, " of dimension ", i, " out of bounds.");
        }
      } else {
        begin_i = canonical(begin_i, 0);
        end_i = canonical(end_i, 1);
      }
      // Update optimization values
      bool take_all_in_dimension =
          stride_i == 1 && begin_i == 0 && end_i == dim_i;
      (*is_identity) &= take_all_in_dimension;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
    } else {
      (*is_identity) &= stride_i == 1 && begin_and_end_masked;
      (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
    }
    // Compute the processing shape (the intermediate Eigen will produce)
    int64 interval_length;
    bool known_interval = false;
    if (dense_spec.begin_valid && dense_spec.end_valid) {
      interval_length = end_i - begin_i;
      known_interval = true;
    } else if (shrink_i) {
      // The dimension is still known as 1 for the processing_shape, but will be
      // discarded for the final shape.
      interval_length = 1;
      known_interval = true;
    } else if (begin_and_end_masked) {
      // Even if we don't have values for begin or end, we do know that this
      // dimension covers the whole interval. If we have shape information for
      // this dimension, that tells us the interval length.
      if (dim_i >= 0) {
        if (stride_i < 0) {
          interval_length = -dim_i;
        } else {
          interval_length = dim_i;
        }
        known_interval = true;
      }
    }
    if (known_interval) {
      int64 size_i;
      // Hold zero if the interval is degenerate, otherwise account for
      // remainder
      if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
        size_i = 0;
      } else {
        size_i = interval_length / stride_i +
                 (interval_length % stride_i != 0 ? 1 : 0);
      }
      processing_shape->AddDim(size_i);
    } else {
      processing_shape->AddDim(-1);
    }
  }

  // Step 4: Compute the final shape
  //
  // new_axis will increase dimension by 1 (with a one-size dimension)
  // slices like foo[3,...] will reduce dimension by 1.
  // This cannot be done earlier, because it depends on Step 3.
  final_shape->Clear();
  if (shape_spec != nullptr) {
    shape_spec->output_to_sparse_mapping.clear();
    shape_spec->output_to_processing_mapping.clear();
    shape_spec->processing_to_sparse_mapping.assign(
        dense_spec.input_shape_gather_indices_sparse.begin(),
        dense_spec.input_shape_gather_indices_sparse.end());

    shape_spec->begin_dense_mask = dense_spec.begin_mask;
    shape_spec->end_dense_mask = dense_spec.end_mask;
    shape_spec->shrink_axis_dense_mask = dense_spec.shrink_axis_mask;
  }

  for (int64 dense_dim = 0;
       dense_dim < dense_spec.final_shape_gather_indices.size(); ++dense_dim) {
    int64 gather_index = dense_spec.final_shape_gather_indices[dense_dim];
    int64 sparse_index =
        dense_spec.final_shape_gather_indices_sparse[dense_dim];
    if (gather_index >= 0) {
      // Real dimension: copy its processed size into the final shape.
      final_shape->AddDim(processing_shape->dim_size(gather_index));
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(sparse_index);
        shape_spec->output_to_processing_mapping.push_back(gather_index);
      }
    } else if (gather_index == kNewAxis) {
      // Synthetic new_axis dimension of size 1; shrunk dims (kShrinkAxis)
      // are simply skipped here.
      final_shape->AddDim(1);
      if (shape_spec != nullptr) {
        shape_spec->output_to_sparse_mapping.push_back(-1);
        shape_spec->output_to_processing_mapping.push_back(-1);
      }
    }
  }

  return Status::OK();
}
416
ValidateStridedSliceOp(const Tensor * begin_tensor,const Tensor * end_tensor,const Tensor & strides_tensor,const PartialTensorShape & input_shape,int32 begin_mask_spec,int32 end_mask_spec,const int32 ellipsis_mask,int32 new_axis_mask,int32 shrink_axis_mask,TensorShape * processing_shape,TensorShape * final_shape,bool * is_identity,bool * is_simple_slice,bool * slice_dim0,gtl::InlinedVector<int64,4> * begin,gtl::InlinedVector<int64,4> * end,gtl::InlinedVector<int64,4> * strides,StridedSliceShapeSpec * shape_spec)417 Status ValidateStridedSliceOp(
418 const Tensor* begin_tensor, const Tensor* end_tensor,
419 const Tensor& strides_tensor, const PartialTensorShape& input_shape,
420 int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
421 int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
422 TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
423 bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
424 gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides,
425 StridedSliceShapeSpec* shape_spec) {
426 // Validate with PartialTensorShape output
427 PartialTensorShape partial_processing_shape, partial_final_shape;
428 TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
429 begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
430 end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
431 &partial_processing_shape, &partial_final_shape, is_identity,
432 is_simple_slice, slice_dim0, begin, end, strides, shape_spec));
433
434 // Verify that the output shapes are fully known
435 if (!partial_processing_shape.AsTensorShape(processing_shape) ||
436 !partial_final_shape.AsTensorShape(final_shape)) {
437 return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
438 partial_processing_shape.DebugString(), " and ",
439 partial_final_shape.DebugString());
440 }
441 return Status::OK();
442 }
443
444 } // namespace tensorflow
445