/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <algorithm>
#include <cmath>
#include <vector>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace strided_slice {

enum KernelType {
  kReference,
  // TODO(soroosh): add kGenericOptimized
};

constexpr int kInputTensor = 0;
constexpr int kBeginTensor = 1;
constexpr int kEndTensor = 2;
constexpr int kStridesTensor = 3;
constexpr int kOutputTensor = 0;

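// Gathers the op's parameters, its input/indexing/output tensors, and the
// input rank into one place for convenient access.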
struct StridedSliceContext {
  StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
    input = GetInput(context, node, kInputTensor);
    begin = GetInput(context, node, kBeginTensor);
    end = GetInput(context, node, kEndTensor);
    strides = GetInput(context, node, kStridesTensor);
    output = GetOutput(context, node, kOutputTensor);
    dims = NumDimensions(input);
  }
  const TfLiteStridedSliceParams* params;
  const TfLiteTensor* input;
  const TfLiteTensor* begin;
  const TfLiteTensor* end;
  const TfLiteTensor* strides;
  TfLiteTensor* output;
  int dims;
};

// This op only supports 1-4D cases. Since we use the reference 4-D
// implementation, 1-3D tensors are mapped to 4-D.
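// For example, a 2-D input of shape [2, 3] is treated as [1, 1, 2, 3]:
// Eval() below pads the missing leading axes with start = 0, stop = 1,
// stride = 1.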
const int kMaxDim = 4;

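// Returns dividend % divisor, folded into [0, divisor). For example,
// PositiveRemainder(-1, 4) == 3, whereas C++'s -1 % 4 == -1.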
inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
  return (divisor + (dividend % divisor)) % divisor;
}

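// Clamps index to the valid range for an axis of extent dim and resolves
// negative (from-the-end) values. For a positive stride the result lies in
// [0, dim], where dim means "one past the end"; for a negative stride it
// lies in [-1, dim - 1], where -1 means "one before the beginning". For
// example, with dim = 4: ClampedIndex(-1, 4, true) == 3 and
// ClampedIndex(10, 4, true) == 4.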
inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
  return pos_stride
             ? (index >= dim ? dim
                             : PositiveRemainder(
                                   std::min(std::max(index, -dim), dim), dim))
             : (index < -dim
                    ? -1
                    : PositiveRemainder(
                          std::min(std::max(index, -dim), dim - 1), dim));
}

// TODO(b/77971377) this logic should be removed, as it's a duplication of
// StartForAxis() & StopForAxis() in kernels/internal/reference/reference_ops.h
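// If the begin_mask (end_mask) bit for axis idx is set, the slice starts at
// the beginning (runs to the end) of that axis for the given stride
// direction; otherwise the user-supplied index is clamped via ClampedIndex.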
inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) {
  const int dim = op_context->input->dims->data[idx];
  const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
  return op_context->params->begin_mask & (1 << idx)
             ? pos_stride ? 0 : dim - 1
             : ClampedIndex(GetTensorData<int32_t>(op_context->begin)[idx], dim,
                            pos_stride);
}

inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) {
  const int dim = op_context->input->dims->data[idx];
  const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
  return op_context->params->end_mask & (1 << idx)
             ? pos_stride ? dim : -1
             : ClampedIndex(GetTensorData<int32_t>(op_context->end)[idx], dim,
                            pos_stride);
}

// Processes the indexing tensors (begin, end and strides) to resize the
// output tensor. This function is callable from both Prepare() and Eval() as
// long as the caller ensures the indexing tensors are present.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                StridedSliceContext* op_context) {
  std::vector<int> output_shape_vector;

  for (int idx = op_context->dims - 1; idx >= 0; --idx) {
    int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
    TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");

    int32_t begin = GetBeginValueAtIndex(op_context, idx);
    int32_t end = GetEndValueAtIndex(op_context, idx);

    // When shrinking an axis, the end position does not matter (and can be
    // incorrect when negative indexing is used, see Issue #19260). Always use
    // begin + 1 to generate a length 1 slice, since begin has
    // already been adjusted for negative indices by GetBeginValueAtIndex.
    const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
    if (shrink_axis) {
      end = begin + 1;
    }

    // This is valid for both positive and negative strides
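    // (e.g. begin = 0, end = 5, stride = 2 yields ceil(5 / 2) = 3 elements,
    // and begin = 4, end = -1, stride = -1 yields ceil(-5 / -1) = 5).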
    int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
    dim_shape = dim_shape < 0 ? 0 : dim_shape;
    if (!shrink_axis) {
      output_shape_vector.push_back(dim_shape);
    }
  }

  TfLiteIntArray* output_shape =
      TfLiteIntArrayCreate(output_shape_vector.size());

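  // The loop above walked the axes from last to first, so restore the
  // original axis order while copying into the TfLiteIntArray.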
  std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
                    output_shape->data);

  TF_LITE_ENSURE_STATUS(
      context->ResizeTensor(context, op_context->output, output_shape));

  return kTfLiteOk;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  StridedSliceContext op_context(context, node);

  // Ensure the indexing tensors are 1-D and the input/output types match.
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
  // Only INT32 begin/end/strides are supported.
  // TODO(soroosh) add support for INT64
  TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
  TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
  TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
                     "StridedSlice op only supports 1D-4D input arrays.");

  // TODO(soroosh): add the following missing functionalities
  TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
                     "ellipsis_mask is not implemented yet.");
  TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
                     "new_axis_mask is not implemented yet.");

  // Postpone allocation of the output if any of the indexing tensors is not
  // constant.
  if (!(IsConstantTensor(op_context.begin) &&
        IsConstantTensor(op_context.end) &&
        IsConstantTensor(op_context.strides))) {
    SetTensorToDynamic(op_context.output);
    return kTfLiteOk;
  }
  return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  StridedSliceContext op_context(context, node);

  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  }

  std::vector<int32_t> starts;
  std::vector<int32_t> stops;
  std::vector<int32_t> strides;

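  // Pad the input up to kMaxDim dimensions by prepending size-1 axes: each
  // missing leading axis gets the full slice [0, 1) with stride 1.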
  for (int i = op_context.dims; i < kMaxDim; i++) {
    starts.emplace_back(0);
    stops.emplace_back(1);
    strides.emplace_back(1);
  }

  for (int idx = 0; idx < op_context.dims; ++idx) {
    starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
    stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
    strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
  }

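  // The real axes were shifted up by (4 - dims) positions when padding to
  // 4-D, so the per-axis mask bits must be shifted by the same amount.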
  int begin_mask = op_context.params->begin_mask << (4 - op_context.dims);
  int end_mask = op_context.params->end_mask << (4 - op_context.dims);
  int shrink_axis_mask = op_context.params->shrink_axis_mask
                         << (4 - op_context.dims);
  TF_LITE_ENSURE_EQ(context, starts.size(), 4);
  auto op_params = ::tflite::strided_slice::BuildStridedSliceParams(
      begin_mask, end_mask, shrink_axis_mask, starts, stops, strides);

#define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                    \
  kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
                            GetTensorData<data_type>(op_context.input),  \
                            GetTensorShape(op_context.output),           \
                            GetTensorData<data_type>(op_context.output))

  switch (op_context.input->type) {
    case kTfLiteFloat32:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, float);
      }
      break;
    case kTfLiteInt32:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
      }
      break;
    case kTfLiteInt64:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
      }
      break;
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
      }
      break;
    default:
      context->ReportError(context,
                           "Type %d is currently not supported "
                           "by StridedSlice.",
                           op_context.input->type);
      return kTfLiteError;
  }
#undef TF_LITE_STRIDED_SLICE
  return kTfLiteOk;
}

}  // namespace strided_slice

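// The first two TfLiteRegistration entries (init, free) are nullptr because
// this op keeps no per-instance state.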
TfLiteRegistration* Register_STRIDED_SLICE_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, strided_slice::Prepare,
      strided_slice::Eval<strided_slice::kReference>};
  return &r;
}

// TODO(soroosh): add optimized
TfLiteRegistration* Register_STRIDED_SLICE() {
  return Register_STRIDED_SLICE_REF();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite