• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <limits>
#include <string>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace custom {
28 namespace max_pool_with_argmax {
29 namespace {
// TODO(b/175003241): Move this logic to lite/kernels/internal when promoting
// this op to a builtin op.
32 template <typename T>
MaxPool(const PoolParams & params,const RuntimeShape & input_shape,const RuntimeShape & output_shape,const T * input_data,T * output_data,int32_t * indices_data)33 inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
34                     const RuntimeShape& output_shape, const T* input_data,
35                     T* output_data, int32_t* indices_data) {
36   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
37   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
38 
39   const int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
40   const int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
41   const int32_t input_height = input_shape.Dims(1);
42   const int32_t input_width = input_shape.Dims(2);
43   const int32_t output_height = output_shape.Dims(1);
44   const int32_t output_width = output_shape.Dims(2);
45   const int32_t stride_height = params.stride_height;
46   const int32_t stride_width = params.stride_width;
47   for (int32_t batch = 0; batch < batches; ++batch) {
48     for (int32_t out_y = 0; out_y < output_height; ++out_y) {
49       for (int32_t out_x = 0; out_x < output_width; ++out_x) {
50         for (int32_t channel = 0; channel < depth; ++channel) {
51           const int32_t in_x_origin =
52               (out_x * stride_width) - params.padding_values.width;
53           const int32_t in_y_origin =
54               (out_y * stride_height) - params.padding_values.height;
55           // Compute the boundaries of the filter region clamped so as to
56           // ensure that the filter window fits in the input array.
57           const int32_t filter_x_start = std::max(0, -in_x_origin);
58           const int32_t filter_x_end =
59               std::min(params.filter_width, input_width - in_x_origin);
60           const int32_t filter_y_start = std::max(0, -in_y_origin);
61           const int32_t filter_y_end =
62               std::min(params.filter_height, input_height - in_y_origin);
63           float max = std::numeric_limits<float>::lowest();
64           int32_t max_x = 0;
65           int32_t max_y = 0;
66 
67           for (int32_t filter_y = filter_y_start; filter_y < filter_y_end;
68                ++filter_y) {
69             for (int32_t filter_x = filter_x_start; filter_x < filter_x_end;
70                  ++filter_x) {
71               const int32_t in_x = in_x_origin + filter_x;
72               const int32_t in_y = in_y_origin + filter_y;
73               float cur =
74                   input_data[Offset(input_shape, batch, in_y, in_x, channel)];
75               if (cur > max) {
76                 max = cur;
77                 max_x = in_x;
78                 max_y = in_y;
79               }
80             }
81           }
82           int32_t output_idx =
83               Offset(output_shape, batch, out_y, out_x, channel);
84           output_data[output_idx] = ActivationFunctionWithMinMax(
85               max, params.float_activation_min, params.float_activation_max);
86           indices_data[output_idx] =
87               (max_y * input_width + max_x) * depth + channel;
88         }
89       }
90     }
91   }
92 }
93 
94 }  // namespace
95 
// Tensor indices on the node: one data input and two outputs (pooled values
// and flattened argmax indices).
constexpr int kDataInputTensor = 0;
constexpr int kDataOutputTensor = 0;
constexpr int kIndicesOutputTensor = 1;

// Keys of the custom-op attributes in the flexbuffer map parsed by Init(),
// plus the accepted string values of the "padding" attribute.
constexpr const char kIncludeBatchStr[] = "include_batch_in_index";
constexpr const char kPoolSizeStr[] = "ksize";
constexpr const char kStridesStr[] = "strides";
constexpr const char kPaddingStr[] = "padding";
constexpr const char kPaddingSameStr[] = "SAME";
constexpr const char kPaddingValidStr[] = "VALID";

// Per-node state: allocated in Init(), read in Prepare()/Eval(), deleted in
// Free().
struct OpData {
  TfLitePoolParams params;       // filter size, strides, padding (computed in Prepare)
  bool include_batch_in_index;   // must be false — rejected in Prepare otherwise
};
111 
Init(TfLiteContext * context,const char * buffer,size_t length)112 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
113   const flexbuffers::Map& m =
114       flexbuffers::GetRoot(reinterpret_cast<const uint8_t*>(buffer), length)
115           .AsMap();
116 
117   OpData* op_data = new OpData;
118   op_data->params.computed.padding = TfLitePaddingValues{0, 0, 0, 0};
119   op_data->include_batch_in_index = m[kIncludeBatchStr].AsBool();
120   op_data->params.activation = kTfLiteActNone;
121 
122   const std::string padding = m[kPaddingStr].AsString().str();
123   if (padding == kPaddingValidStr) {
124     op_data->params.padding = kTfLitePaddingValid;
125   } else if (padding == kPaddingSameStr) {
126     op_data->params.padding = kTfLitePaddingSame;
127   } else {
128     op_data->params.padding = kTfLitePaddingUnknown;
129   }
130 
131   // The first and last element of pool_size are always 1.
132   const auto pool_size = m[kPoolSizeStr].AsTypedVector();
133   TFLITE_CHECK_EQ(pool_size.size(), 4);
134   TFLITE_CHECK_EQ(pool_size[0].AsInt32(), 1);
135   TFLITE_CHECK_EQ(pool_size[3].AsInt32(), 1);
136   op_data->params.filter_height = pool_size[1].AsInt32();
137   op_data->params.filter_width = pool_size[2].AsInt32();
138 
139   // The first and last element of strides are always 1.
140   const auto strides = m[kStridesStr].AsTypedVector();
141   TFLITE_CHECK_EQ(strides.size(), 4);
142   TFLITE_CHECK_EQ(strides[0].AsInt32(), 1);
143   TFLITE_CHECK_EQ(strides[3].AsInt32(), 1);
144   op_data->params.stride_height = strides[1].AsInt32();
145   op_data->params.stride_width = strides[2].AsInt32();
146 
147   return op_data;
148 }
149 
Free(TfLiteContext * context,void * buffer)150 void Free(TfLiteContext* context, void* buffer) {
151   delete reinterpret_cast<OpData*>(buffer);
152 }
153 
Prepare(TfLiteContext * context,TfLiteNode * node)154 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
155   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
156 
157   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
158   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
159   TfLiteTensor *output, *indices;
160   TF_LITE_ENSURE_OK(context,
161                     GetOutputSafe(context, node, kDataOutputTensor, &output));
162   TF_LITE_ENSURE_OK(
163       context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
164   const TfLiteTensor* input;
165   TF_LITE_ENSURE_OK(context,
166                     GetInputSafe(context, node, kDataInputTensor, &input));
167   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
168   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
169   TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
170   TF_LITE_ENSURE(context, indices->type == kTfLiteInt32);
171   TF_LITE_ENSURE(context, op_data->params.padding != kTfLitePaddingUnknown);
172   TF_LITE_ENSURE_MSG(
173       context, !op_data->include_batch_in_index,
174       "Include batch dimension in flattened index is not yet supported.");
175 
176   int batches = input->dims->data[0];
177   int height = input->dims->data[1];
178   int width = input->dims->data[2];
179   int channels_out = input->dims->data[3];
180 
181   // Matching GetWindowedOutputSize in TensorFlow.
182   int out_width, out_height;
183   op_data->params.computed.padding = ComputePaddingHeightWidth(
184       op_data->params.stride_height, op_data->params.stride_width, 1, 1, height,
185       width, op_data->params.filter_height, op_data->params.filter_width,
186       op_data->params.padding, &out_height, &out_width);
187 
188   TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
189   output_size->data[0] = batches;
190   output_size->data[1] = out_height;
191   output_size->data[2] = out_width;
192   output_size->data[3] = channels_out;
193   TfLiteIntArray* indices_size = TfLiteIntArrayCopy(output_size);
194 
195   TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, indices, indices_size));
196   return context->ResizeTensor(context, output, output_size);
197 }
198 
Eval(TfLiteContext * context,TfLiteNode * node)199 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
200   OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
201 
202   float activation_min, activation_max;
203   CalculateActivationRange(op_data->params.activation, &activation_min,
204                            &activation_max);
205 
206   tflite::PoolParams op_params;
207   op_params.stride_height = op_data->params.stride_height;
208   op_params.stride_width = op_data->params.stride_width;
209   op_params.filter_height = op_data->params.filter_height;
210   op_params.filter_width = op_data->params.filter_width;
211   op_params.padding_values.height = op_data->params.computed.padding.height;
212   op_params.padding_values.width = op_data->params.computed.padding.width;
213   op_params.float_activation_min = activation_min;
214   op_params.float_activation_max = activation_max;
215 
216   TfLiteTensor *output, *indices;
217   TF_LITE_ENSURE_OK(context,
218                     GetOutputSafe(context, node, kDataOutputTensor, &output));
219   TF_LITE_ENSURE_OK(
220       context, GetOutputSafe(context, node, kIndicesOutputTensor, &indices));
221   const TfLiteTensor* input;
222   TF_LITE_ENSURE_OK(context,
223                     GetInputSafe(context, node, kDataInputTensor, &input));
224 
225   switch (input->type) {
226     case kTfLiteFloat32:
227       MaxPool<float>(op_params, GetTensorShape(input), GetTensorShape(output),
228                      GetTensorData<float>(input), GetTensorData<float>(output),
229                      GetTensorData<int32_t>(indices));
230       break;
231     default:
232       TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
233                          TfLiteTypeGetName(input->type));
234       return kTfLiteError;
235   }
236   return kTfLiteOk;
237 }
238 }  // namespace max_pool_with_argmax
239 
RegisterMaxPoolWithArgmax()240 TfLiteRegistration* RegisterMaxPoolWithArgmax() {
241   static TfLiteRegistration r = {
242       max_pool_with_argmax::Init, max_pool_with_argmax::Free,
243       max_pool_with_argmax::Prepare, max_pool_with_argmax::Eval};
244   return &r;
245 }
246 
247 // Alias for selective build.
Register_MAX_POOL_WITH_ARGMAX()248 TfLiteRegistration* Register_MAX_POOL_WITH_ARGMAX() {
249   return RegisterMaxPoolWithArgmax();
250 }
251 
252 }  // namespace custom
253 }  // namespace ops
254 }  // namespace tflite
255