1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string.h>
16 #include <cmath>
17 #include <vector>
18 #include "tensorflow/lite/c/builtin_op_data.h"
19 #include "tensorflow/lite/c/c_api_internal.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 #include "tensorflow/lite/kernels/op_macros.h"
24
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace strided_slice {
29
// Kernel flavours this op can be built with. Only the reference kernel exists
// today; it is the template argument selecting the code path in Eval<>().
enum KernelType {
  kReference = 0,
  // TODO(soroosh): add kGenericOptimized
};

// Positions of the op's tensors in the node's input / output lists.
constexpr int kInputTensor = 0;    // Tensor being sliced.
constexpr int kBeginTensor = 1;    // Per-axis start indices (int32).
constexpr int kEndTensor = 2;      // Per-axis stop indices (int32).
constexpr int kStridesTensor = 3;  // Per-axis step sizes (int32).
constexpr int kOutputTensor = 0;   // Sliced result.
40
41 struct StridedSliceContext {
StridedSliceContexttflite::ops::builtin::strided_slice::StridedSliceContext42 StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
43 params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
44 input = GetInput(context, node, kInputTensor);
45 begin = GetInput(context, node, kBeginTensor);
46 end = GetInput(context, node, kEndTensor);
47 strides = GetInput(context, node, kStridesTensor);
48 output = GetOutput(context, node, kOutputTensor);
49 dims = NumDimensions(input);
50 }
51 const TfLiteStridedSliceParams* params;
52 const TfLiteTensor* input;
53 const TfLiteTensor* begin;
54 const TfLiteTensor* end;
55 const TfLiteTensor* strides;
56 TfLiteTensor* output;
57 int dims;
58 };
59
// The reference kernel works on exactly 4 dimensions. For inputs of rank 1-3,
// Eval() prepends a trivial 0:1:1 slice for each missing leading axis so the
// request can be expressed in 4D. This is also the largest rank accepted.
constexpr int kMaxDim = 4;
63
// Returns `dividend` modulo `divisor`, mapped into [0, divisor). Unlike the
// built-in `%`, a negative dividend yields a non-negative result
// (e.g. PositiveRemainder(-1, 4) == 3). `divisor` must be positive.
inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
  return (divisor + (dividend % divisor)) % divisor;
}

// Resolves a possibly-negative begin/end `index` against an axis of size
// `dim` (negative values count from the end) and clamps it to the positions a
// slice moving in the given stride direction may use:
//   * pos_stride:  result in [0, dim]. Indices >= dim map to dim (one past the
//                  last element); indices < -dim map to 0.
//   * !pos_stride: result in [-1, dim - 1]. Indices < -dim map to -1 (one
//                  before the first element); indices >= dim map to dim - 1.
// A zero-sized axis is handled up front, because normalizing through
// PositiveRemainder would otherwise compute `x % 0` (division by zero).
inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
  if (dim == 0) {
    // Empty axis: return the same endpoints a masked begin/end would produce
    // (0 for forward slices, -1 for backward ones) so the slice length is 0.
    return pos_stride ? 0 : -1;
  }
  if (pos_stride) {
    if (index >= dim) return dim;
    return PositiveRemainder(std::min(std::max(index, -dim), dim), dim);
  }
  if (index < -dim) return -1;
  return PositiveRemainder(std::min(std::max(index, -dim), dim - 1), dim);
}
78
79 // TODO(b/77971377) this logic should be removed, as it's a duplication of
80 // StartForAxis() & StopForAxis() in kernels/internal/reference/reference_ops.h
GetBeginValueAtIndex(StridedSliceContext * op_context,int idx)81 inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) {
82 const int dim = op_context->input->dims->data[idx];
83 const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
84 return op_context->params->begin_mask & (1 << idx)
85 ? pos_stride ? 0 : dim - 1
86 : ClampedIndex(GetTensorData<int32_t>(op_context->begin)[idx], dim,
87 pos_stride);
88 }
89
GetEndValueAtIndex(StridedSliceContext * op_context,int idx)90 inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) {
91 const int dim = op_context->input->dims->data[idx];
92 const bool pos_stride = GetTensorData<int32_t>(op_context->strides)[idx] > 0;
93 return op_context->params->end_mask & (1 << idx)
94 ? pos_stride ? dim : -1
95 : ClampedIndex(GetTensorData<int32_t>(op_context->end)[idx], dim,
96 pos_stride);
97 }
98
99 // Processes the indexing tensors (begin, end and strides) to resize the
100 // output tensor. This function is callable from both Prepare() and Eval() as
101 // long as the caller ensures the indexing tensors are present.
ResizeOutputTensor(TfLiteContext * context,StridedSliceContext * op_context)102 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
103 StridedSliceContext* op_context) {
104 std::vector<int> output_shape_vector;
105
106 for (int idx = op_context->dims - 1; idx >= 0; --idx) {
107 int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
108 TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
109
110 int32_t begin = GetBeginValueAtIndex(op_context, idx);
111 int32_t end = GetEndValueAtIndex(op_context, idx);
112
113 // When shrinking an axis, the end position does not matter (and can be
114 // incorrect when negative indexing is used, see Issue #19260). Always use
115 // begin + 1 to generate a length 1 slice, since begin has
116 // already been adjusted for negative indices by GetBeginValueAtIndex.
117 const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
118 if (shrink_axis) {
119 end = begin + 1;
120 }
121
122 // This is valid for both positive and negative strides
123 int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
124 dim_shape = dim_shape < 0 ? 0 : dim_shape;
125 if (!shrink_axis) {
126 output_shape_vector.push_back(dim_shape);
127 }
128 }
129
130 TfLiteIntArray* output_shape =
131 TfLiteIntArrayCreate(output_shape_vector.size());
132
133 std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(),
134 output_shape->data);
135
136 TF_LITE_ENSURE_STATUS(
137 context->ResizeTensor(context, op_context->output, output_shape));
138
139 return kTfLiteOk;
140 }
141
Prepare(TfLiteContext * context,TfLiteNode * node)142 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
143 TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
144 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
145
146 StridedSliceContext op_context(context, node);
147
148 // Ensure validity of input tensor and its dimension
149 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1);
150 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1);
151 TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1);
152 TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
153 // Only INT32 begin/end/strides are supported
154 // TODO(soroosh) add support for INT64
155 TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32);
156 TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32);
157 TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32);
158 TF_LITE_ENSURE_MSG(context, op_context.dims <= 4,
159 "StridedSlice op only supports 1D-4D input arrays.");
160
161 // TODO(soroosh): add the following missing functionalities
162 TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0,
163 "ellipsis_mask is not implemented yet.");
164 TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0,
165 "new_axis_mask is not implemented yet.");
166
167 // Postpone allocation of output if any of the indexing tensors is not
168 // constant
169 if (!(IsConstantTensor(op_context.begin) &&
170 IsConstantTensor(op_context.end) &&
171 IsConstantTensor(op_context.strides))) {
172 SetTensorToDynamic(op_context.output);
173 return kTfLiteOk;
174 }
175 return ResizeOutputTensor(context, &op_context);
176 }
177
178 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)179 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
180 StridedSliceContext op_context(context, node);
181
182 if (IsDynamicTensor(op_context.output)) {
183 TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
184 }
185
186 std::vector<int32_t> starts;
187 std::vector<int32_t> stops;
188 std::vector<int32_t> strides;
189
190 for (int i = op_context.dims; i < kMaxDim; i++) {
191 starts.emplace_back(0);
192 stops.emplace_back(1);
193 strides.emplace_back(1);
194 }
195
196 for (int idx = 0; idx < op_context.dims; ++idx) {
197 starts.emplace_back(GetTensorData<int32_t>(op_context.begin)[idx]);
198 stops.emplace_back(GetTensorData<int32_t>(op_context.end)[idx]);
199 strides.emplace_back(GetTensorData<int32_t>(op_context.strides)[idx]);
200 }
201
202 int begin_mask = op_context.params->begin_mask << (4 - op_context.dims);
203 int end_mask = op_context.params->end_mask << (4 - op_context.dims);
204 int shrink_axis_mask = op_context.params->shrink_axis_mask
205 << (4 - op_context.dims);
206 TF_LITE_ENSURE_EQ(context, starts.size(), 4);
207 auto op_params = ::tflite::strided_slice::BuildStridedSliceParams(
208 begin_mask, end_mask, shrink_axis_mask, starts, stops, strides);
209
210 #define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
211 kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
212 GetTensorData<data_type>(op_context.input), \
213 GetTensorShape(op_context.output), \
214 GetTensorData<data_type>(op_context.output))
215
216 switch (op_context.input->type) {
217 case kTfLiteFloat32:
218 if (kernel_type == kReference) {
219 TF_LITE_STRIDED_SLICE(reference_ops, float);
220 }
221 break;
222 case kTfLiteInt32:
223 if (kernel_type == kReference) {
224 TF_LITE_STRIDED_SLICE(reference_ops, int32_t);
225 }
226 break;
227 case kTfLiteInt64:
228 if (kernel_type == kReference) {
229 TF_LITE_STRIDED_SLICE(reference_ops, int64_t);
230 }
231 break;
232 case kTfLiteUInt8:
233 if (kernel_type == kReference) {
234 TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
235 }
236 break;
237 case kTfLiteInt8:
238 if (kernel_type == kReference) {
239 TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
240 }
241 break;
242 default:
243 context->ReportError(context,
244 "Type %d is currently not supported "
245 "by StridedSlice.",
246 op_context.input->type);
247 return kTfLiteError;
248 }
249 #undef TF_LITE_STRIDED_SLICE
250 return kTfLiteOk;
251 }
252
253 } // namespace strided_slice
254
Register_STRIDED_SLICE_REF()255 TfLiteRegistration* Register_STRIDED_SLICE_REF() {
256 static TfLiteRegistration r = {
257 nullptr, nullptr, strided_slice::Prepare,
258 strided_slice::Eval<strided_slice::kReference>};
259 return &r;
260 }
261
262 // TODO(soroosh): add optimized
Register_STRIDED_SLICE()263 TfLiteRegistration* Register_STRIDED_SLICE() {
264 return Register_STRIDED_SLICE_REF();
265 }
266
267 } // namespace builtin
268 } // namespace ops
269 } // namespace tflite
270