• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
17 
18 #include <math.h>
19 #include <stdint.h>
20 
21 #include <algorithm>
22 #include <vector>
23 
24 #include "tensorflow/lite/c/builtin_op_data.h"
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/kernels/internal/compatibility.h"
27 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
28 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
29 #include "tensorflow/lite/kernels/internal/tensor.h"
30 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
31 #include "tensorflow/lite/kernels/internal/types.h"
32 #include "tensorflow/lite/kernels/kernel_util.h"
33 
34 namespace tflite {
35 namespace ops {
36 namespace builtin {
37 namespace strided_slice {
38 
// Selects which StridedSlice implementation Eval() dispatches to.
enum KernelType {
  kReference = 0,         // reference_ops::StridedSlice
  kGenericOptimized = 1,  // optimized_ops::StridedSlice
};

// Positions of the op's tensors in the node's input list.
constexpr int kInputTensor = 0;
constexpr int kBeginTensor = 1;
constexpr int kEndTensor = 2;
constexpr int kStridesTensor = 3;
// Position of the single result tensor in the node's output list.
constexpr int kOutputTensor = 0;
49 
50 struct StridedSliceContext {
StridedSliceContexttflite::ops::builtin::strided_slice::StridedSliceContext51   StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
52     params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
53     input = GetInput(context, node, kInputTensor);
54     begin = GetInput(context, node, kBeginTensor);
55     end = GetInput(context, node, kEndTensor);
56     strides = GetInput(context, node, kStridesTensor);
57     output = GetOutput(context, node, kOutputTensor);
58     input_dims = NumDimensions(input);
59   }
60   const TfLiteStridedSliceParams* params;
61   const TfLiteTensor* input;
62   const TfLiteTensor* begin;
63   const TfLiteTensor* end;
64   const TfLiteTensor* strides;
65   TfLiteTensor* output;
66 
67   // Equivalent input shape after adding axis according to new_axis_mask.
68   RuntimeShape effective_input_shape;
69   int input_dims;
70 };
71 
BuildStridedSliceParams(StridedSliceContext * op_context)72 StridedSliceParams BuildStridedSliceParams(StridedSliceContext* op_context) {
73   StridedSliceParams op_params;
74 
75   // The ellipsis_mask and new_axis_mask in op_params are not used. Those masks
76   // are processed here to update begin_mask, end_mask and the index range.
77   op_params.begin_mask = 0;
78   op_params.ellipsis_mask = 0;
79   op_params.end_mask = 0;
80   op_params.new_axis_mask = 0;
81   op_params.shrink_axis_mask = 0;
82 
83   // Count indexes where the new_axis_mask is set but the ellipsis_mask is not.
84   const int begin_count = GetTensorShape(op_context->begin).Dims(0);
85   int num_add_axis = 0;
86   for (int i = 0; i < begin_count; ++i) {
87     if (!((1 << i) & op_context->params->ellipsis_mask) &&
88         ((1 << i) & op_context->params->new_axis_mask)) {
89       num_add_axis++;
90     }
91   }
92 
93   // Calculate the dims of input after adding new axises.
94   const int effective_dims = op_context->input_dims + num_add_axis;
95 
96   // If begin, end and strides are not fully provided, it means Ellipsis should
97   // be expanded to multiple dimensions (Ex: for spec [Ellipsis, 2] on a 3D
98   // input, the Ellipsis should be applied for the first 2 dimensions). Besides,
99   // If the new_axis_mask and the ellipsis_mask are set at the same index, the
100   // new_axis_mask will have no effect.
101   int effective_ellipsis_mask = 0, effective_new_axis_mask = 0;
102   int ellipsis_start_idx = effective_dims, expanded_ellipsis = 0;
103   for (int i = 0; i < effective_dims;) {
104     if ((1 << i) & op_context->params->ellipsis_mask) {
105       ellipsis_start_idx = i;
106       int ellipsis_end_idx = std::max(
107           i + 1,
108           std::min(i + 1 + num_add_axis + op_context->input_dims - begin_count,
109                    effective_dims));
110       expanded_ellipsis = ellipsis_end_idx - ellipsis_start_idx - 1;
111 
112       // Set bit for effective_ellipsis_mask.
113       for (; i < ellipsis_end_idx; ++i) {
114         effective_ellipsis_mask |= (1 << i);
115       }
116       continue;
117     }
118 
119     if ((1 << (i - expanded_ellipsis)) & op_context->params->new_axis_mask) {
120       effective_new_axis_mask |= (1 << i);
121     }
122     ++i;
123   }
124 
125   // Calculate effective_input_shape and its corresponding begin, end, strides.
126   const int32_t* begin_data = GetTensorData<int32_t>(op_context->begin);
127   const int32_t* end_data = GetTensorData<int32_t>(op_context->end);
128   const int32_t* strides_data = GetTensorData<int32_t>(op_context->strides);
129   const RuntimeShape input_shape = GetTensorShape(op_context->input);
130   int added_ellipsis = 0, added_axises = 0;
131   op_context->effective_input_shape.Resize(effective_dims);
132 
133   for (int i = 0; i < effective_dims; ++i) {
134     if ((1 << i) & effective_ellipsis_mask) {
135       // If ellipsis_mask, set the begin_mask and end_mask at that index.
136       added_ellipsis = std::max(0, i - ellipsis_start_idx);
137       op_params.begin_mask |= (1 << i);
138       op_params.end_mask |= (1 << i);
139       op_params.strides[i] = 1;
140       op_context->effective_input_shape.SetDim(
141           i, input_shape.Dims(i - added_axises));
142     } else if ((1 << i) & effective_new_axis_mask) {
143       // If new_axis_mask is set, it is equivalent to adding a new dim of 1 to
144       // input tensor. Store added shape to effective_input_shape.
145       op_params.start_indices[i] = 0;
146       op_params.stop_indices[i] = 1;
147       op_params.strides[i] = 1;
148       op_context->effective_input_shape.SetDim(i, 1);
149       added_axises++;
150     } else if (i >= begin_count + expanded_ellipsis) {
151       op_params.start_indices[i] = 0;
152       op_params.stop_indices[i] = 0;
153       op_params.strides[i] = 1;
154       op_params.begin_mask |= (1 << i);
155       op_params.end_mask |= (1 << i);
156       op_context->effective_input_shape.SetDim(
157           i, input_shape.Dims(i - added_axises));
158     } else {
159       const int orig_idx = i - added_ellipsis;
160       op_params.start_indices[i] = begin_data[orig_idx];
161       op_params.stop_indices[i] = end_data[orig_idx];
162       op_params.strides[i] = strides_data[orig_idx];
163       if (op_context->params->begin_mask & (1 << orig_idx)) {
164         op_params.begin_mask |= (1 << i);
165       }
166       if (op_context->params->end_mask & (1 << orig_idx)) {
167         op_params.end_mask |= (1 << i);
168       }
169       if (op_context->params->shrink_axis_mask & (1 << orig_idx)) {
170         op_params.shrink_axis_mask |= (1 << i);
171       }
172       op_context->effective_input_shape.SetDim(
173           i, input_shape.Dims(i - added_axises));
174     }
175   }
176   op_params.start_indices_count = effective_dims;
177   op_params.stop_indices_count = effective_dims;
178   op_params.strides_count = effective_dims;
179 
180   return op_params;
181 }
182 
183 // Processes the indexing tensors (begin, end and strides) to resize the
184 // output tensor. This function is callable from both Prepare() and Eval() as
185 // long as the caller ensures the indexing tensors are present.
ResizeOutputTensor(TfLiteContext * context,StridedSliceContext * op_context)186 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
187                                 StridedSliceContext* op_context) {
188   std::vector<int> output_shape_vector;
189   StridedSliceParams op_params = BuildStridedSliceParams(op_context);
190   const RuntimeShape effective_input_shape = op_context->effective_input_shape;
191   TF_LITE_ENSURE_MSG(
192       context, effective_input_shape.DimensionsCount() <= 5,
193       "StridedSlice op only supports up to 5D output including added axis.");
194 
195   for (int idx = effective_input_shape.DimensionsCount() - 1; idx >= 0; --idx) {
196     int32_t stride = op_params.strides[idx];
197     TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
198 
199     int32_t begin = ::tflite::strided_slice::StartForAxis(
200         op_params, effective_input_shape, idx);
201     int32_t end = ::tflite::strided_slice::StopForAxis(
202         op_params, effective_input_shape, idx, begin);
203 
204     // When shrinking an axis, the end position does not matter (and can be
205     // incorrect when negative indexing is used, see Issue #19260). Always use
206     // begin + 1 to generate a length 1 slice, since begin has
207     // already been adjusted for negative indices by GetBeginValueAtIndex.
208     const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
209     if (shrink_axis) {
210       end = begin + 1;
211     }
212 
213     // This is valid for both positive and negative strides
214     int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
215     dim_shape = dim_shape < 0 ? 0 : dim_shape;
216     if (!shrink_axis) {
217       output_shape_vector.push_back(dim_shape);
218     }
219   }
220 
221   TfLiteIntArray* output_shape =
222       TfLiteIntArrayCreate(output_shape_vector.size());
223 
224   std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
225                     output_shape->data);
226 
227   TF_LITE_ENSURE_STATUS(
228       context->ResizeTensor(context, op_context->output, output_shape));
229 
230   return kTfLiteOk;
231 }
232 
Prepare(TfLiteContext * context,TfLiteNode * node)233 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
234   TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
235   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
236 
237   StridedSliceContext op_context(context, node);
238 
239   // Ensure validity of input tensor and its dimension
240   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
241   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
242   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
243   TF_LITE_ENSURE_EQ(context, NumElements(op_context.begin),
244                     NumElements(op_context.end));
245   TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
246 
247   // Only INT32 begin/end/strides are supported
248   // TODO(b/175642009): add support for INT64
249   TF_LITE_ENSURE_TYPES_EQ(context, op_context.begin->type, kTfLiteInt32);
250   TF_LITE_ENSURE_TYPES_EQ(context, op_context.end->type, kTfLiteInt32);
251   TF_LITE_ENSURE_TYPES_EQ(context, op_context.strides->type, kTfLiteInt32);
252   TF_LITE_ENSURE_MSG(context, op_context.input_dims <= 5,
253                      "StridedSlice op only supports 1D-5D input arrays.");
254 
255   // Postpone allocation of output if any of the indexing tensors is not
256   // constant
257   if (!(IsConstantTensor(op_context.begin) &&
258         IsConstantTensor(op_context.end) &&
259         IsConstantTensor(op_context.strides))) {
260     SetTensorToDynamic(op_context.output);
261     return kTfLiteOk;
262   }
263   return ResizeOutputTensor(context, &op_context);
264 }
265 
266 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)267 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
268   StridedSliceContext op_context(context, node);
269 
270   if (IsDynamicTensor(op_context.output)) {
271     TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
272   }
273   StridedSliceParams op_params = BuildStridedSliceParams(&op_context);
274 
275 #define TF_LITE_STRIDED_SLICE(data_type)                                 \
276   {                                                                      \
277     if (kernel_type == kGenericOptimized) {                              \
278       optimized_ops::StridedSlice<data_type>(                            \
279           op_params, op_context.effective_input_shape, op_context.input, \
280           GetTensorShape(op_context.output), op_context.output);         \
281     } else {                                                             \
282       reference_ops::StridedSlice<data_type>(                            \
283           op_params, op_context.effective_input_shape, op_context.input, \
284           GetTensorShape(op_context.output), op_context.output);         \
285     }                                                                    \
286   }
287 
288   switch (op_context.input->type) {
289     case kTfLiteFloat32:
290       TF_LITE_STRIDED_SLICE(float);
291       break;
292     case kTfLiteInt32:
293       TF_LITE_STRIDED_SLICE(int32_t);
294       break;
295     case kTfLiteInt64:
296       TF_LITE_STRIDED_SLICE(int64_t);
297       break;
298     case kTfLiteUInt8:
299       TF_LITE_STRIDED_SLICE(uint8_t);
300       break;
301     case kTfLiteInt8:
302       TF_LITE_STRIDED_SLICE(int8_t);
303       break;
304     case kTfLiteInt16:
305       TF_LITE_STRIDED_SLICE(int16_t);
306       break;
307     case kTfLiteBool:
308       TF_LITE_STRIDED_SLICE(bool);
309       break;
310     case kTfLiteString:
311       TF_LITE_STRIDED_SLICE(string);
312       break;
313     default:
314       TF_LITE_KERNEL_LOG(context,
315                          "Type %s is currently not supported "
316                          "by StridedSlice.",
317                          TfLiteTypeGetName(op_context.input->type));
318       return kTfLiteError;
319   }
320 #undef TF_LITE_STRIDED_SLICE
321   return kTfLiteOk;
322 }
323 
324 }  // namespace strided_slice
325 
Register_STRIDED_SLICE_REF()326 TfLiteRegistration* Register_STRIDED_SLICE_REF() {
327   static TfLiteRegistration r = {
328       nullptr, nullptr, strided_slice::Prepare,
329       strided_slice::Eval<strided_slice::kReference>};
330   return &r;
331 }
332 
Register_STRIDED_SLICE()333 TfLiteRegistration* Register_STRIDED_SLICE() {
334   static TfLiteRegistration r = {
335       nullptr, nullptr, strided_slice::Prepare,
336       strided_slice::Eval<strided_slice::kGenericOptimized>};
337   return &r;
338 }
339 
340 }  // namespace builtin
341 }  // namespace ops
342 }  // namespace tflite
343