1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/strided_slice_op.h"
17
18 #include <array>
19 #include "tensorflow/core/framework/bounds_check.h"
20 #include "tensorflow/core/lib/core/status.h"
21
22 namespace tensorflow {
23 namespace {
24
25 /// Constants
26 constexpr int32 kShrinkAxis = -1, kNewAxis = -2;
27
28 // Sparse slicing specification
29 // if one does foo[3:5, ..., -3], this will have 3 length tensors
30 struct StridedSliceSparseSpec {
31 int64 dims;
32 int32 num_add_axis_after_ellipsis;
33 const Tensor* begin_tensor;
34 const Tensor* end_tensor;
35 const Tensor& strides_tensor;
36 const int32 begin_mask, end_mask;
37 int32 ellipsis_mask;
38 const int32 new_axis_mask, shrink_axis_mask;
39 };
40
41 // Dense slicing specification
42 // all ellipses and newaxis' are expanded out. So if
43 // foo[3:5, ..., -3] where foo is 10 dimensional,
44 // each inlinedVector will have 10 entries whereas the
45 // sparse had 3 length tensors.
46 struct StridedSliceDenseSpec {
47 const int64 dims;
48 int32 begin_mask;
49 int32 end_mask;
50 bool begin_valid;
51 bool end_valid;
52 gtl::InlinedVector<int64, 4>& begin;
53 gtl::InlinedVector<int64, 4>& end;
54 gtl::InlinedVector<int64, 4>& strides;
55 // This vector helps construct the final shape of the slice.
56 // The final tensor is reduced in rank whenever a single index e.g. foo[3]
57 // is called for. The final tensor increases in rank with tf.newaxis
58 // entries. If an index in this array is positive, the size of the dimension
59 // is obtained from canonical end-begin. Otherwise, if it is a kNewAxis,
60 // it will be 1. A shrunk dimension is skipped.
61 gtl::InlinedVector<int32, 4> final_shape_gather_indices;
62 // The dense indexed shrink mask is which processing dimensions
63 // should be shrunk. For example, if foo.shape = (10,10,10,10)
64 // foo[3, ..., 5] has sparse_shrink_axis_mask of 0x5 and
65 // dense_shrink_axis_mask of 0x9, yielding a final shape (10,10).
66 int32 shrink_axis_mask;
67 };
68
69 } // namespace
70
71 template <class T>
BuildDenseSpec(const StridedSliceSparseSpec & sparse,StridedSliceDenseSpec * dense)72 static Status TF_MUST_USE_RESULT BuildDenseSpec(
73 const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
74 // Build expanded begin, end, strides, begin_mask, end_mask
75 // to remove any ellipsis
76 dense->begin.resize(dense->dims);
77 dense->end.resize(dense->dims);
78 dense->strides.resize(dense->dims);
79 // What indices to get the final shape from.
80 dense->begin_mask = 0;
81 dense->end_mask = 0;
82 dense->shrink_axis_mask = 0;
83 {
84 int full_index = 0;
85
86 const T* const strides_flat = sparse.strides_tensor.vec<T>().data();
87 dense->begin_valid = sparse.begin_tensor != nullptr;
88 dense->end_valid = sparse.end_tensor != nullptr;
89
90 const T* const begin_flat = sparse.begin_tensor != nullptr
91 ? sparse.begin_tensor->vec<T>().data()
92 : nullptr;
93 const T* const end_flat = sparse.end_tensor != nullptr
94 ? sparse.end_tensor->vec<T>().data()
95 : nullptr;
96
97 for (int i = 0; i < sparse.dims; i++) {
98 if ((1 << i) & sparse.ellipsis_mask) {
99 // Expand the ellipsis into the appropriate indices
100 // NOTE: this only works because we guaranteed one ellipsis
101 int32 next_index = std::min(dense->dims - (sparse.dims - i) + 1 +
102 sparse.num_add_axis_after_ellipsis,
103 dense->dims);
104 for (; full_index < next_index; full_index++) {
105 // new_axis' aren't real axis so you have to skip
106 dense->begin[full_index] = dense->end[full_index] = 0;
107 dense->strides[full_index] = 1;
108 dense->begin_mask |= (1 << full_index);
109 dense->end_mask |= (1 << full_index);
110 dense->final_shape_gather_indices.push_back(full_index);
111 }
112 } else if ((1 << i) & sparse.new_axis_mask) {
113 dense->final_shape_gather_indices.push_back(kNewAxis);
114 } else {
115 if (full_index == dense->begin.size()) {
116 return errors::InvalidArgument("Index out of range using input dim ",
117 full_index, "; input has only ",
118 dense->dims, " dims");
119 }
120
121 // Gather slicing spec into appropriate index
122 if (begin_flat != nullptr) {
123 dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat[i]);
124 }
125 if (end_flat != nullptr) {
126 dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat[i]);
127 }
128 dense->strides[full_index] =
129 internal::SubtleMustCopy<T>(strides_flat[i]);
130 if (sparse.begin_mask & (1 << i)) {
131 dense->begin_mask |= (1 << full_index);
132 }
133 if (sparse.end_mask & (1 << i)) {
134 dense->end_mask |= (1 << full_index);
135 }
136 // If shrink, record where to get the dimensionality from (i.e.
137 // new_axis creates a fake 1 size dimension. Also remember shrink
138 // axis (now in dense form) so we can ignore dense->end below.
139 if (sparse.shrink_axis_mask & (1 << i)) {
140 dense->final_shape_gather_indices.push_back(kShrinkAxis);
141 dense->shrink_axis_mask |= (1 << full_index);
142 } else {
143 dense->final_shape_gather_indices.push_back(full_index);
144 }
145 full_index++;
146 }
147 }
148 }
149 return Status::OK();
150 }
151
ValidateStridedSliceOp(const Tensor * begin_tensor,const Tensor * end_tensor,const Tensor & strides_tensor,const PartialTensorShape & input_shape,int32 begin_mask_spec,int32 end_mask_spec,const int32 ellipsis_mask,int32 new_axis_mask,int32 shrink_axis_mask,PartialTensorShape * processing_shape,PartialTensorShape * final_shape,bool * is_identity,bool * is_simple_slice,bool * slice_dim0,gtl::InlinedVector<int64,4> * begin,gtl::InlinedVector<int64,4> * end,gtl::InlinedVector<int64,4> * strides)152 Status ValidateStridedSliceOp(
153 const Tensor* begin_tensor, const Tensor* end_tensor,
154 const Tensor& strides_tensor, const PartialTensorShape& input_shape,
155 int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
156 int32 new_axis_mask, int32 shrink_axis_mask,
157 PartialTensorShape* processing_shape, PartialTensorShape* final_shape,
158 bool* is_identity, bool* is_simple_slice, bool* slice_dim0,
159 gtl::InlinedVector<int64, 4>* begin, gtl::InlinedVector<int64, 4>* end,
160 gtl::InlinedVector<int64, 4>* strides) {
161 const bool begin_is_wrong =
162 begin_tensor != nullptr &&
163 !(TensorShapeUtils::IsVector(begin_tensor->shape()) &&
164 begin_tensor->NumElements() == strides_tensor.NumElements() &&
165 begin_tensor->NumElements() < 32 /* using 32 bit masks */);
166 const bool end_is_wrong =
167 end_tensor != nullptr &&
168 !(TensorShapeUtils::IsVector(end_tensor->shape()) &&
169 end_tensor->NumElements() == strides_tensor.NumElements());
170 if (begin_is_wrong || end_is_wrong ||
171 !TensorShapeUtils::IsVector(strides_tensor.shape())) {
172 if (begin_tensor != nullptr && end_tensor != nullptr) {
173 return errors::InvalidArgument(
174 "Expected begin, end, and strides to be 1D equal size tensors, ",
175 "but got shapes ", begin_tensor->shape().DebugString(), ", ",
176 end_tensor->shape().DebugString(), ", and ",
177 strides_tensor.shape().DebugString(), " instead.");
178 } else {
179 return errors::InvalidArgument(
180 "Expected begin, end, and strides to be 1D equal size tensors, ",
181 "but got shape ", strides_tensor.shape().DebugString(),
182 " for strides.");
183 }
184 }
185 // Use bit compares to ensure ellipsis_mask is 0 or a power of 2
186 // i.e. there exists only no more than one ellipsis
187 if (ellipsis_mask && ((ellipsis_mask & (ellipsis_mask - 1)) != 0)) {
188 return errors::InvalidArgument(
189 "Multiple ellipses in slice spec not allowed");
190 }
191
192 // Step 1: Account for ellipsis and new axis
193 //
194 // Check for ellipses and count how many non-newaxis' there are after
195 // TODO(aselle): Convert this to do a fast log2 followed by iteration
196 // counting ones in next guys
197 bool ellipsis_seen = false;
198
199 StridedSliceSparseSpec sparse_spec = {strides_tensor.NumElements(),
200 0,
201 begin_tensor,
202 end_tensor,
203 strides_tensor,
204 begin_mask_spec,
205 end_mask_spec,
206 ellipsis_mask,
207 new_axis_mask,
208 shrink_axis_mask};
209
210 for (int32 i = 0; i < sparse_spec.dims; i++) {
211 if (ellipsis_seen && ((1 << i) & new_axis_mask) != 0) {
212 sparse_spec.num_add_axis_after_ellipsis++;
213 }
214 if ((1 << i) & ellipsis_mask) {
215 ellipsis_seen = true;
216 }
217 }
218 // If no ellipsis insert one at the end
219 if (!ellipsis_seen) {
220 sparse_spec.ellipsis_mask |= (1 << sparse_spec.dims);
221 sparse_spec.dims++; // this effects loop iteration below
222 }
223
224 // Step 2: Make a sparse spec into a full index spec
225 //
226 // The sparse spec does not correspond to the number of dimensions
227 // Make a dense spec that corresponds to the number of dimensions
228 //
229 // For example suppose foo[...,3:] on foo.shape=(2,2,3) then
230 // we need to produce the missing begin_mask for the first two
231 // dimensions i.e. from begin_mask_spec=0, end_mask_spec=2
232 // we achieve begin_mask=6, end_mask=7
233 StridedSliceDenseSpec dense_spec = {input_shape.dims(),
234 0 /* begin_mask */,
235 0 /* end_mask */,
236 false /* begin_valid */,
237 false /* end_valid */,
238 *begin,
239 *end,
240 *strides};
241
242 if (strides_tensor.dtype() == DT_INT32) {
243 TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
244 } else if (strides_tensor.dtype() == DT_INT64) {
245 TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
246 } else {
247 LOG(FATAL) << "begin must be either int32 or int64";
248 }
249
250 // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
251 // and bounds check!
252 *is_identity = true;
253 *slice_dim0 = true;
254 *is_simple_slice = true;
255 processing_shape->Clear();
256 for (int i = 0; i < input_shape.dims(); ++i) {
257 int64& begin_i = (*begin)[i];
258 int64& end_i = (*end)[i];
259 int64& stride_i = (*strides)[i];
260 int64 dim_i = input_shape.dim_size(i);
261 if (stride_i == 0) {
262 return errors::InvalidArgument("strides[", i, "] must be non-zero");
263 }
264 bool shrink_i = (dense_spec.shrink_axis_mask & (1 << i));
265 if (dim_i == -1) {
266 processing_shape->AddDim(shrink_i ? 1 : -1);
267 continue;
268 }
269
270 const std::array<int64, 2> masks = {
271 {dense_spec.begin_mask & (1 << i), dense_spec.end_mask & (1 << i)}};
272 const std::array<int64, 2> valid_range = {
273 {stride_i > 0 ? 0 : -1, stride_i > 0 ? dim_i : dim_i - 1}};
274
275 auto canonical = [stride_i, dim_i, masks, valid_range](int64 x, int c) {
276 if (masks[c]) {
277 return stride_i > 0 ? valid_range[c] : valid_range[(c + 1) & 1];
278 } else {
279 int64 x_fwd = x < 0 ? dim_i + x : x; // make negative indices positive
280 return x_fwd < valid_range[0]
281 ? valid_range[0]
282 : x_fwd > valid_range[1] ? valid_range[1] : x_fwd;
283 }
284 };
285 if (shrink_i && stride_i <= 0) {
286 return errors::InvalidArgument(
287 "only stride 1 allowed on non-range indexing.");
288 }
289 (*is_simple_slice) &= stride_i == 1;
290
291 const bool begin_and_end_masked =
292 (dense_spec.begin_mask & (1 << i)) && (dense_spec.end_mask & (1 << i));
293 if (dense_spec.begin_valid && dense_spec.end_valid) {
294 if (shrink_i) {
295 // If we are shrinking, the end index is now possibly incorrect. In
296 // particular foo[-1] produces sparse_begin = -1, sparse_end = 0.
297 // and canonical puts these to n-1 and 0, which implies a degenerate
298 // interval. Fortunately, it is now safe to re-create end as begin+1.
299 int64 x_fwd = begin_i < 0 ? dim_i + begin_i : begin_i;
300 begin_i = x_fwd;
301 end_i = begin_i + 1;
302 if (x_fwd < 0 || x_fwd >= dim_i) {
303 return errors::InvalidArgument(
304 "slice index ", begin_i, " of dimension ", i, " out of bounds.");
305 }
306 } else {
307 begin_i = canonical(begin_i, 0);
308 end_i = canonical(end_i, 1);
309 }
310 // Update optimization values
311 bool take_all_in_dimension =
312 stride_i == 1 && begin_i == 0 && end_i == dim_i;
313 (*is_identity) &= take_all_in_dimension;
314 (*slice_dim0) &= (i == 0 && stride_i == 1) || take_all_in_dimension;
315 } else {
316 (*is_identity) &= stride_i == 1 && begin_and_end_masked;
317 (*slice_dim0) &= (i == 0 && stride_i == 1) || begin_and_end_masked;
318 }
319 // Compute the processing shape (the intermediate Eigen will produce)
320 int64 interval_length;
321 bool known_interval = false;
322 if (dense_spec.begin_valid && dense_spec.end_valid) {
323 interval_length = end_i - begin_i;
324 known_interval = true;
325 } else if (shrink_i) {
326 // The dimension is still known as 1 for the processing_shape, but will be
327 // discarded for the final shape.
328 interval_length = 1;
329 known_interval = true;
330 } else if (begin_and_end_masked) {
331 // Even if we don't have values for begin or end, we do know that this
332 // dimension covers the whole interval. If we have shape information for
333 // this dimension, that tells us the interval length.
334 if (dim_i >= 0) {
335 if (stride_i < 0) {
336 interval_length = -dim_i;
337 } else {
338 interval_length = dim_i;
339 }
340 known_interval = true;
341 }
342 }
343 if (known_interval) {
344 int64 size_i;
345 // Hold zero if the interval is degenerate, otherwise account for
346 // remainder
347 if (interval_length == 0 || ((interval_length < 0) != (stride_i < 0))) {
348 size_i = 0;
349 } else {
350 size_i = interval_length / stride_i +
351 (interval_length % stride_i != 0 ? 1 : 0);
352 }
353 processing_shape->AddDim(size_i);
354 } else {
355 processing_shape->AddDim(-1);
356 }
357 }
358
359 // Step 4: Compute the final shape
360 //
361 // new_axis will increase dimension by 1 (with a one-size dimension)
362 // slices like foo[3,...] will reduce dimension by 1.
363 // This cannot be done earlier, because it depends on Step 3.
364 final_shape->Clear();
365 for (auto gather_index : dense_spec.final_shape_gather_indices) {
366 if (gather_index >= 0) {
367 final_shape->AddDim(processing_shape->dim_size(gather_index));
368 } else if (gather_index == kNewAxis) {
369 final_shape->AddDim(1);
370 }
371 }
372 return Status::OK();
373 }
374
ValidateStridedSliceOp(const Tensor * begin_tensor,const Tensor * end_tensor,const Tensor & strides_tensor,const PartialTensorShape & input_shape,int32 begin_mask_spec,int32 end_mask_spec,const int32 ellipsis_mask,int32 new_axis_mask,int32 shrink_axis_mask,TensorShape * processing_shape,TensorShape * final_shape,bool * is_identity,bool * is_simple_slice,bool * slice_dim0,gtl::InlinedVector<int64,4> * begin,gtl::InlinedVector<int64,4> * end,gtl::InlinedVector<int64,4> * strides)375 Status ValidateStridedSliceOp(
376 const Tensor* begin_tensor, const Tensor* end_tensor,
377 const Tensor& strides_tensor, const PartialTensorShape& input_shape,
378 int32 begin_mask_spec, int32 end_mask_spec, const int32 ellipsis_mask,
379 int32 new_axis_mask, int32 shrink_axis_mask, TensorShape* processing_shape,
380 TensorShape* final_shape, bool* is_identity, bool* is_simple_slice,
381 bool* slice_dim0, gtl::InlinedVector<int64, 4>* begin,
382 gtl::InlinedVector<int64, 4>* end, gtl::InlinedVector<int64, 4>* strides) {
383 // Validate with PartialTensorShape output
384 PartialTensorShape partial_processing_shape, partial_final_shape;
385 TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
386 begin_tensor, end_tensor, strides_tensor, input_shape, begin_mask_spec,
387 end_mask_spec, ellipsis_mask, new_axis_mask, shrink_axis_mask,
388 &partial_processing_shape, &partial_final_shape, is_identity,
389 is_simple_slice, slice_dim0, begin, end, strides));
390
391 // Verify that the output shapes are fully known
392 if (!partial_processing_shape.AsTensorShape(processing_shape) ||
393 !partial_final_shape.AsTensorShape(final_shape)) {
394 return errors::Internal("ValidateStridedSliceOp returned partial shapes ",
395 partial_processing_shape.DebugString(), " and ",
396 partial_final_shape.DebugString());
397 }
398 return Status::OK();
399 }
400
401 } // namespace tensorflow
402