/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <limits>
#include <string>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"

namespace tflite {
namespace ops {
namespace custom {
namespace max_pool_with_argmax {
namespace {
// TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
// this op to a builtin op.
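// Computes a 2-D max pool over `input_data` and, for each output element,
// records the position of the maximum in `indices_data`. Since
// include_batch_in_index is not yet supported (see Prepare), the index is
// flattened within each batch as (y * input_width + x) * depth + channel.
// For example, in a 4x4 single-channel input, the element at (y=2, x=3)
// is encoded as (2 * 4 + 3) * 1 + 0 = 11.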
template <typename T>
inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
                    const RuntimeShape& output_shape, const T* input_data,
                    T* output_data, int32_t* indices_data) {
  TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);

  const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
  const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
  const int32_t input_height = input_shape.Dims(1);
  const int32_t input_width = input_shape.Dims(2);
  const int32_t output_height = output_shape.Dims(1);
  const int32_t output_width = output_shape.Dims(2);
  const int32_t stride_height = params.stride_height;
  const int32_t stride_width = params.stride_width;
  for (int32_t batch = 0; batch < batches; ++batch) {
    for (int32_t out_y = 0; out_y < output_height; ++out_y) {
      for (int32_t out_x = 0; out_x < output_width; ++out_x) {
        for (int32_t channel = 0; channel < depth; ++channel) {
          const int32_t in_x_origin =
              (out_x * stride_width) - params.padding_values.width;
          const int32_t in_y_origin =
              (out_y * stride_height) - params.padding_values.height;
          // Compute the boundaries of the filter region clamped so as to
          // ensure that the filter window fits in the input array.
          const int32_t filter_x_start = std::max(0, -in_x_origin);
          const int32_t filter_x_end =
              std::min(params.filter_width, input_width - in_x_origin);
          const int32_t filter_y_start = std::max(0, -in_y_origin);
          const int32_t filter_y_end =
              std::min(params.filter_height, input_height - in_y_origin);
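          // For example, with SAME padding the first window can start at
          // in_x_origin = -1; then filter_x_start = 1, so the loops below
          // skip the out-of-bounds column rather than reading padded values.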
          float max = std::numeric_limits<float>::lowest();
          int32_t max_x = 0;
          int32_t max_y = 0;

          for (int32_t filter_y = filter_y_start; filter_y < filter_y_end;
               ++filter_y) {
            for (int32_t filter_x = filter_x_start; filter_x < filter_x_end;
                 ++filter_x) {
              const int32_t in_x = in_x_origin + filter_x;
              const int32_t in_y = in_y_origin + filter_y;
              float cur =
                  input_data[Offset(input_shape, batch, in_y, in_x, channel)];
              if (cur > max) {
                max = cur;
                max_x = in_x;
                max_y = in_y;
              }
            }
          }
          int32_t output_idx =
              Offset(output_shape, batch, out_y, out_x, channel);
          output_data[output_idx] = ActivationFunctionWithMinMax(
              max, params.float_activation_min, params.float_activation_max);
          indices_data[output_idx] =
              (max_y * input_width + max_x) * depth + channel;
        }
      }
    }
  }
}

}  // namespace

constexpr int kDataInputTensor = 0;
constexpr int kDataOutputTensor = 0;
constexpr int kIndicesOutputTensor = 1;

constexpr const char kIncludeBatchStr[] = "include_batch_in_index";
constexpr const char kPoolSizeStr[] = "ksize";
constexpr const char kStridesStr[] = "strides";
constexpr const char kPaddingStr[] = "padding";
constexpr const char kPaddingSameStr[] = "SAME";
constexpr const char kPaddingValidStr[] = "VALID";
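
// Example of the custom options this op consumes, written as the JSON
// equivalent of the flexbuffer map parsed in Init below (the concrete
// values are illustrative):
//
//   {
//     "ksize": [1, 2, 2, 1],
//     "strides": [1, 2, 2, 1],
//     "padding": "SAME",
//     "include_batch_in_index": false
//   }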

struct OpData {
  TfLitePoolParams params;
  bool include_batch_in_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const flexbuffers::Map& m =
      flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
          .AsMap();

  OpData* op_data = new OpData;
  op_data->params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
  op_data->include_batch_in_index = m[kIncludeBatchStr].AsBool();
  op_data->params.activation = kTfLiteActNone;

  const std::string padding = m[kPaddingStr].AsString().str();
  if (padding == kPaddingValidStr) {
    op_data->params.padding = kTfLitePaddingValid;
  } else if (padding == kPaddingSameStr) {
    op_data->params.padding = kTfLitePaddingSame;
  } else {
    op_data->params.padding = kTfLitePaddingUnknown;
  }

  // The first and last elements of pool_size are always 1.
  const auto pool_size = m[kPoolSizeStr].AsTypedVector();
  TFLITE_CHECK_EQ(pool_size.size(), 4);
  TFLITE_CHECK_EQ(pool_size[0].AsInt32(), 1);
  TFLITE_CHECK_EQ(pool_size[3].AsInt32(), 1);
  op_data->params.filter_height = pool_size[1].AsInt32();
  op_data->params.filter_width = pool_size[2].AsInt32();

  // The first and last elements of strides are always 1.
  const auto strides = m[kStridesStr].AsTypedVector();
  TFLITE_CHECK_EQ(strides.size(), 4);
  TFLITE_CHECK_EQ(strides[0].AsInt32(), 1);
  TFLITE_CHECK_EQ(strides[3].AsInt32(), 1);
  op_data->params.stride_height = strides[1].AsInt32();
  op_data->params.stride_width = strides[2].AsInt32();

  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
  TfLiteTensor *output, *indices;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kDataOutputTensor, &output));
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kDataInputTensor, &input));
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE(context, indices->type == kTfLiteInt32);
  TF_LITE_ENSURE(context, op_data->params.padding != kTfLitePaddingUnknown);
  TF_LITE_ENSURE_MSG(
      context, !op_data->include_batch_in_index,
      "Including the batch dimension in the flattened index is not yet "
      "supported.");

  int batches = input->dims->data[0];
  int height = input->dims->data[1];
  int width = input->dims->data[2];
  int channels_out = input->dims->data[3];

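  // Output spatial dimensions follow TensorFlow's convention:
  //   SAME:  out = ceil(in / stride)
  //   VALID: out = ceil((in - filter + 1) / stride)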
  // Matching GetWindowedOutputSize in TensorFlow.
  int out_width, out_height;
  op_data->params.computed.padding = ComputePaddingHeightWidth(
      op_data->params.stride_height, op_data->params.stride_width, 1, 1,
      height, width, op_data->params.filter_height,
      op_data->params.filter_width, op_data->params.padding, &out_height,
      &out_width);

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
  output_size->data[0] = batches;
  output_size->data[1] = out_height;
  output_size->data[2] = out_width;
  output_size->data[3] = channels_out;
  TfLiteIntArray* indices_size = TfLiteIntArrayCopy(output_size);

  TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, indices, indices_size));
  return context->ResizeTensor(context, output, output_size);
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  float activation_min, activation_max;
  CalculateActivationRange(op_data->params.activation, &activation_min,
                           &activation_max);
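  // The activation is fixed to kTfLiteActNone in Init, so the computed range
  // spans the full float range and the clamp inside MaxPool is effectively a
  // no-op.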

  tflite::PoolParams op_params;
  op_params.stride_height = op_data->params.stride_height;
  op_params.stride_width = op_data->params.stride_width;
  op_params.filter_height = op_data->params.filter_height;
  op_params.filter_width = op_data->params.filter_width;
  op_params.padding_values.height = op_data->params.computed.padding.height;
  op_params.padding_values.width = op_data->params.computed.padding.width;
  op_params.float_activation_min = activation_min;
  op_params.float_activation_max = activation_max;

  TfLiteTensor *output, *indices;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kDataOutputTensor, &output));
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kDataInputTensor, &input));

  switch (input->type) {
    case kTfLiteFloat32:
      MaxPool<float>(op_params, GetTensorShape(input), GetTensorShape(output),
                     GetTensorData<float>(input), GetTensorData<float>(output),
                     GetTensorData<int32_t>(indices));
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}
}  // namespace max_pool_with_argmax

TfLiteRegistration* RegisterMaxPoolWithArgmax() {
  static TfLiteRegistration r = {
      max_pool_with_argmax::Init, max_pool_with_argmax::Free,
      max_pool_with_argmax::Prepare, max_pool_with_argmax::Eval};
  return &r;
}
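
// A minimal registration sketch for clients resolving this op from a model.
// The custom-op name "MaxPoolWithArgmax" is an assumption here; use whatever
// name the converter wrote into the model:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("MaxPoolWithArgmax",
//                      tflite::ops::custom::RegisterMaxPoolWithArgmax());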

// Alias for selective build.
TfLiteRegistration* Register_MAX_POOL_WITH_ARGMAX() {
  return RegisterMaxPoolWithArgmax();
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite