1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <numeric>
17
18 #define FLATBUFFERS_LOCALE_INDEPENDENT 0
19 #include "flatbuffers/flexbuffers.h"
20 #include "tensorflow/lite/c/builtin_op_data.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/common.h"
23 #include "tensorflow/lite/kernels/internal/quantization_util.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/kernels/kernel_util.h"
26 #include "tensorflow/lite/kernels/op_macros.h"
27 #include "tensorflow/lite/micro/kernels/kernel_util.h"
28 #include "tensorflow/lite/micro/micro_utils.h"
29
30 namespace tflite {
31 namespace {
32
/**
 * This version of detection_postprocess is specific to TFLite Micro. It
 * differs from the TFLite version in the following ways:
 *
 * 1.) Temporaries (temporary tensors) - Micro uses the scratch buffer API
 *     instead.
 * 2.) Output dimensions - this version does not support undefined output
 *     dimensions, so the model must have static output dimensions.
 */
41
// Positions of the op's input tensors.
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Positions of the op's output tensors.
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

// Number of coordinates that describe one box.
constexpr int kNumCoordBox = 4;
// The only batch size this kernel accepts.
constexpr int kBatchSize = 1;

// Value used for "detections_per_class" when the custom options omit it.
constexpr int kNumDetectionsPerClass = 100;
57
// The object detection model produces axis-aligned boxes in two formats:
//   * BoxCornerEncoding holds the lower left corner (xmin, ymin) and the
//     upper right corner (xmax, ymax).
//   * CenterSizeEncoding holds the center (xcenter, ycenter), the height and
//     the width.
//
// A CenterSizeEncoding is turned into a BoxCornerEncoding as follows:
//   ycenter = y / y_scale * anchor.h + anchor.y
//   xcenter = x / x_scale * anchor.w + anchor.x
//   half_h  = 0.5 * exp(h / h_scale) * anchor.h
//   half_w  = 0.5 * exp(w / w_scale) * anchor.w
//   ymin    = ycenter - half_h
//   ymax    = ycenter + half_h
//   xmin    = xcenter - half_w
//   xmax    = xcenter + half_w
struct BoxCornerEncoding {
  float ymin;  // Smallest y coordinate.
  float xmin;  // Smallest x coordinate.
  float ymax;  // Largest y coordinate.
  float xmax;  // Largest x coordinate.
};
77
// A box described by its center and its extent. The same layout is also used
// for anchors and for the per-coordinate scale factors.
struct CenterSizeEncoding {
  float y;  // Center, y coordinate.
  float x;  // Center, x coordinate.
  float h;  // Height.
  float w;  // Width.
};
84 // We make sure that the memory allocations are contiguous with static_assert.
85 static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
86 "Size of BoxCornerEncoding is 4 float values");
87 static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
88 "Size of CenterSizeEncoding is 4 float values");
89
90 struct OpData {
91 int max_detections;
92 int max_classes_per_detection; // Fast Non-Max-Suppression
93 int detections_per_class; // Regular Non-Max-Suppression
94 float non_max_suppression_score_threshold;
95 float intersection_over_union_threshold;
96 int num_classes;
97 bool use_regular_non_max_suppression;
98 CenterSizeEncoding scale_values;
99
100 // Scratch buffers indexes
101 int active_candidate_idx;
102 int decoded_boxes_idx;
103 int scores_idx;
104 int score_buffer_idx;
105 int keep_scores_idx;
106 int scores_after_regular_non_max_suppression_idx;
107 int sorted_values_idx;
108 int keep_indices_idx;
109 int sorted_indices_idx;
110 int buffer_idx;
111 int selected_idx;
112
113 // Cached tensor scale and zero point values for quantized operations
114 TfLiteQuantizationParams input_box_encodings;
115 TfLiteQuantizationParams input_class_predictions;
116 TfLiteQuantizationParams input_anchors;
117 };
118
Init(TfLiteContext * context,const char * buffer,size_t length)119 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
120 OpData* op_data = nullptr;
121
122 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
123 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
124
125 TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
126 op_data = reinterpret_cast<OpData*>(
127 context->AllocatePersistentBuffer(context, sizeof(OpData)));
128
129 op_data->max_detections = m["max_detections"].AsInt32();
130 op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
131 if (m["detections_per_class"].IsNull())
132 op_data->detections_per_class = kNumDetectionsPerClass;
133 else
134 op_data->detections_per_class = m["detections_per_class"].AsInt32();
135 if (m["use_regular_nms"].IsNull())
136 op_data->use_regular_non_max_suppression = false;
137 else
138 op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();
139
140 op_data->non_max_suppression_score_threshold =
141 m["nms_score_threshold"].AsFloat();
142 op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
143 op_data->num_classes = m["num_classes"].AsInt32();
144 op_data->scale_values.y = m["y_scale"].AsFloat();
145 op_data->scale_values.x = m["x_scale"].AsFloat();
146 op_data->scale_values.h = m["h_scale"].AsFloat();
147 op_data->scale_values.w = m["w_scale"].AsFloat();
148
149 return op_data;
150 }
151
Free(TfLiteContext * context,void * buffer)152 void Free(TfLiteContext* context, void* buffer) {}
153
Prepare(TfLiteContext * context,TfLiteNode * node)154 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
155 auto* op_data = static_cast<OpData*>(node->user_data);
156
157 // Inputs: box_encodings, scores, anchors
158 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
159 const TfLiteTensor* input_box_encodings =
160 GetInput(context, node, kInputTensorBoxEncodings);
161 const TfLiteTensor* input_class_predictions =
162 GetInput(context, node, kInputTensorClassPredictions);
163 const TfLiteTensor* input_anchors =
164 GetInput(context, node, kInputTensorAnchors);
165 TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
166 TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
167 TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
168
169 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
170 const int num_boxes = input_box_encodings->dims->data[1];
171 const int num_classes = op_data->num_classes;
172
173 op_data->input_box_encodings.scale = input_box_encodings->params.scale;
174 op_data->input_box_encodings.zero_point =
175 input_box_encodings->params.zero_point;
176 op_data->input_class_predictions.scale =
177 input_class_predictions->params.scale;
178 op_data->input_class_predictions.zero_point =
179 input_class_predictions->params.zero_point;
180 op_data->input_anchors.scale = input_anchors->params.scale;
181 op_data->input_anchors.zero_point = input_anchors->params.zero_point;
182
183 // Scratch tensors
184 context->RequestScratchBufferInArena(context, num_boxes,
185 &op_data->active_candidate_idx);
186 context->RequestScratchBufferInArena(context,
187 num_boxes * kNumCoordBox * sizeof(float),
188 &op_data->decoded_boxes_idx);
189 context->RequestScratchBufferInArena(
190 context,
191 input_class_predictions->dims->data[1] *
192 input_class_predictions->dims->data[2] * sizeof(float),
193 &op_data->scores_idx);
194
195 // Additional buffers
196 context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
197 &op_data->score_buffer_idx);
198 context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
199 &op_data->keep_scores_idx);
200 context->RequestScratchBufferInArena(
201 context, op_data->max_detections * num_boxes * sizeof(float),
202 &op_data->scores_after_regular_non_max_suppression_idx);
203 context->RequestScratchBufferInArena(
204 context, op_data->max_detections * num_boxes * sizeof(float),
205 &op_data->sorted_values_idx);
206 context->RequestScratchBufferInArena(context, num_boxes * sizeof(int),
207 &op_data->keep_indices_idx);
208 context->RequestScratchBufferInArena(
209 context, op_data->max_detections * num_boxes * sizeof(int),
210 &op_data->sorted_indices_idx);
211 int buffer_size = std::max(num_classes, op_data->max_detections);
212 context->RequestScratchBufferInArena(
213 context, buffer_size * num_boxes * sizeof(int), &op_data->buffer_idx);
214 buffer_size = std::min(num_boxes, op_data->max_detections);
215 context->RequestScratchBufferInArena(
216 context, buffer_size * num_boxes * sizeof(int), &op_data->selected_idx);
217
218 // Outputs: detection_boxes, detection_scores, detection_classes,
219 // num_detections
220 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
221
222 return kTfLiteOk;
223 }
224
// Function object that maps an 8-bit quantized value to its float value:
//   real = (quantized - zero_point) * scale
class Dequantizer {
 public:
  // `zero_point` is the quantized value that represents 0.0f and `scale` is
  // the float step between two consecutive quantized values.
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}

  // Returns the dequantized value of `x`. Const so that it can be invoked on
  // const instances; it does not modify the object.
  float operator()(uint8_t x) const {
    return (static_cast<float>(x) - zero_point_) * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};
237
DequantizeBoxEncodings(const TfLiteEvalTensor * input_box_encodings,int idx,float quant_zero_point,float quant_scale,int length_box_encoding,CenterSizeEncoding * box_centersize)238 void DequantizeBoxEncodings(const TfLiteEvalTensor* input_box_encodings,
239 int idx, float quant_zero_point, float quant_scale,
240 int length_box_encoding,
241 CenterSizeEncoding* box_centersize) {
242 const uint8_t* boxes =
243 tflite::micro::GetTensorData<uint8_t>(input_box_encodings) +
244 length_box_encoding * idx;
245 Dequantizer dequantize(quant_zero_point, quant_scale);
246 // See definition of the KeyPointBoxCoder at
247 // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
248 // The first four elements are the box coordinates, which is the same as the
249 // FastRnnBoxCoder at
250 // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
251 box_centersize->y = dequantize(boxes[0]);
252 box_centersize->x = dequantize(boxes[1]);
253 box_centersize->h = dequantize(boxes[2]);
254 box_centersize->w = dequantize(boxes[3]);
255 }
256
257 template <class T>
ReInterpretTensor(const TfLiteEvalTensor * tensor)258 T ReInterpretTensor(const TfLiteEvalTensor* tensor) {
259 const float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
260 return reinterpret_cast<T>(tensor_base);
261 }
262
263 template <class T>
ReInterpretTensor(TfLiteEvalTensor * tensor)264 T ReInterpretTensor(TfLiteEvalTensor* tensor) {
265 float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
266 return reinterpret_cast<T>(tensor_base);
267 }
268
DecodeCenterSizeBoxes(TfLiteContext * context,TfLiteNode * node,OpData * op_data)269 TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
270 OpData* op_data) {
271 // Parse input tensor boxencodings
272 const TfLiteEvalTensor* input_box_encodings =
273 tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
274 TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
275 const int num_boxes = input_box_encodings->dims->data[1];
276 TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
277 const TfLiteEvalTensor* input_anchors =
278 tflite::micro::GetEvalInput(context, node, kInputTensorAnchors);
279
280 // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
281 CenterSizeEncoding box_centersize;
282 CenterSizeEncoding scale_values = op_data->scale_values;
283 CenterSizeEncoding anchor;
284 for (int idx = 0; idx < num_boxes; ++idx) {
285 switch (input_box_encodings->type) {
286 // Quantized
287 case kTfLiteUInt8:
288 DequantizeBoxEncodings(
289 input_box_encodings, idx,
290 static_cast<float>(op_data->input_box_encodings.zero_point),
291 static_cast<float>(op_data->input_box_encodings.scale),
292 input_box_encodings->dims->data[2], &box_centersize);
293 DequantizeBoxEncodings(
294 input_anchors, idx,
295 static_cast<float>(op_data->input_anchors.zero_point),
296 static_cast<float>(op_data->input_anchors.scale), kNumCoordBox,
297 &anchor);
298 break;
299 // Float
300 case kTfLiteFloat32: {
301 // Please see DequantizeBoxEncodings function for the support detail.
302 const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
303 const float* boxes = &(tflite::micro::GetTensorData<float>(
304 input_box_encodings)[box_encoding_idx]);
305 box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
306 anchor =
307 ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
308 break;
309 }
310 default:
311 // Unsupported type.
312 return kTfLiteError;
313 }
314
315 float ycenter = static_cast<float>(static_cast<double>(box_centersize.y) /
316 static_cast<double>(scale_values.y) *
317 static_cast<double>(anchor.h) +
318 static_cast<double>(anchor.y));
319
320 float xcenter = static_cast<float>(static_cast<double>(box_centersize.x) /
321 static_cast<double>(scale_values.x) *
322 static_cast<double>(anchor.w) +
323 static_cast<double>(anchor.x));
324
325 float half_h =
326 static_cast<float>(0.5 *
327 (std::exp(static_cast<double>(box_centersize.h) /
328 static_cast<double>(scale_values.h))) *
329 static_cast<double>(anchor.h));
330 float half_w =
331 static_cast<float>(0.5 *
332 (std::exp(static_cast<double>(box_centersize.w) /
333 static_cast<double>(scale_values.w))) *
334 static_cast<double>(anchor.w));
335
336 float* decoded_boxes = reinterpret_cast<float*>(
337 context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
338 auto& box = reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[idx];
339 box.ymin = ycenter - half_h;
340 box.xmin = xcenter - half_w;
341 box.ymax = ycenter + half_h;
342 box.xmax = xcenter + half_w;
343 }
344 return kTfLiteOk;
345 }
346
// Fills `indices` with 0 .. num_values - 1 and rearranges them so that the
// first `num_to_sort` entries refer to the largest entries of `values`, in
// decreasing order of value. The order of the remaining entries is whatever
// std::partial_sort leaves behind.
void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  int* const first = indices;
  int* const sorted_end = indices + num_to_sort;
  int* const last = indices + num_values;

  std::iota(first, last, 0);

  const auto refers_to_larger_value = [values](const int lhs, const int rhs) {
    return values[lhs] > values[rhs];
  };
  std::partial_sort(first, sorted_end, last, refers_to_larger_value);
}
354
// Copies every entry of `values` that is greater than or equal to `threshold`
// to `keep_values`, and its position to `keep_indices`, preserving the input
// order. Returns the number of entries kept.
int SelectDetectionsAboveScoreThreshold(const float* values, int size,
                                        const float threshold,
                                        float* keep_values, int* keep_indices) {
  int num_kept = 0;
  for (int position = 0; position < size; ++position) {
    // Written as a negated >= so that NaN scores are dropped.
    if (!(values[position] >= threshold)) {
      continue;
    }
    keep_values[num_kept] = values[position];
    keep_indices[num_kept] = position;
    ++num_kept;
  }
  return num_kept;
}
368
ValidateBoxes(const float * decoded_boxes,const int num_boxes)369 bool ValidateBoxes(const float* decoded_boxes, const int num_boxes) {
370 for (int i = 0; i < num_boxes; ++i) {
371 // ymax>=ymin, xmax>=xmin
372 auto& box = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
373 if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
374 return false;
375 }
376 }
377 return true;
378 }
379
ComputeIntersectionOverUnion(const float * decoded_boxes,const int i,const int j)380 float ComputeIntersectionOverUnion(const float* decoded_boxes, const int i,
381 const int j) {
382 auto& box_i = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
383 auto& box_j = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[j];
384 const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
385 const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
386 if (area_i <= 0 || area_j <= 0) return 0.0;
387 const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
388 const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
389 const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
390 const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
391 const float intersection_area =
392 std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
393 std::max<float>(intersection_xmax - intersection_xmin, 0.0);
394 return intersection_area / (area_i + area_j - intersection_area);
395 }
396
397 // NonMaxSuppressionSingleClass() prunes out the box locations with high overlap
398 // before selecting the highest scoring boxes (max_detections in number)
399 // It assumes all boxes are good in beginning and sorts based on the scores.
400 // If lower-scoring box has too much overlap with a higher-scoring box,
401 // we get rid of the lower-scoring box.
402 // Complexity is O(N^2) pairwise comparison between boxes
NonMaxSuppressionSingleClassHelper(TfLiteContext * context,TfLiteNode * node,OpData * op_data,const float * scores,int * selected,int * selected_size,int max_detections)403 TfLiteStatus NonMaxSuppressionSingleClassHelper(
404 TfLiteContext* context, TfLiteNode* node, OpData* op_data,
405 const float* scores, int* selected, int* selected_size,
406 int max_detections) {
407 const TfLiteEvalTensor* input_box_encodings =
408 tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
409 const int num_boxes = input_box_encodings->dims->data[1];
410 const float non_max_suppression_score_threshold =
411 op_data->non_max_suppression_score_threshold;
412 const float intersection_over_union_threshold =
413 op_data->intersection_over_union_threshold;
414 // Maximum detections should be positive.
415 TF_LITE_ENSURE(context, (max_detections >= 0));
416 // intersection_over_union_threshold should be positive
417 // and should be less than 1.
418 TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
419 (intersection_over_union_threshold <= 1.0f));
420 // Validate boxes
421 float* decoded_boxes = reinterpret_cast<float*>(
422 context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
423
424 TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));
425
426 // threshold scores
427 int* keep_indices = reinterpret_cast<int*>(
428 context->GetScratchBuffer(context, op_data->keep_indices_idx));
429 float* keep_scores = reinterpret_cast<float*>(
430 context->GetScratchBuffer(context, op_data->keep_scores_idx));
431 int num_scores_kept = SelectDetectionsAboveScoreThreshold(
432 scores, num_boxes, non_max_suppression_score_threshold, keep_scores,
433 keep_indices);
434 int* sorted_indices = reinterpret_cast<int*>(
435 context->GetScratchBuffer(context, op_data->sorted_indices_idx));
436
437 DecreasingPartialArgSort(keep_scores, num_scores_kept, num_scores_kept,
438 sorted_indices);
439
440 const int num_boxes_kept = num_scores_kept;
441 const int output_size = std::min(num_boxes_kept, max_detections);
442 *selected_size = 0;
443
444 int num_active_candidate = num_boxes_kept;
445 uint8_t* active_box_candidate = reinterpret_cast<uint8_t*>(
446 context->GetScratchBuffer(context, op_data->active_candidate_idx));
447
448 for (int row = 0; row < num_boxes_kept; row++) {
449 active_box_candidate[row] = 1;
450 }
451 for (int i = 0; i < num_boxes_kept; ++i) {
452 if (num_active_candidate == 0 || *selected_size >= output_size) break;
453 if (active_box_candidate[i] == 1) {
454 selected[(*selected_size)++] = keep_indices[sorted_indices[i]];
455 active_box_candidate[i] = 0;
456 num_active_candidate--;
457 } else {
458 continue;
459 }
460 for (int j = i + 1; j < num_boxes_kept; ++j) {
461 if (active_box_candidate[j] == 1) {
462 float intersection_over_union = ComputeIntersectionOverUnion(
463 decoded_boxes, keep_indices[sorted_indices[i]],
464 keep_indices[sorted_indices[j]]);
465
466 if (intersection_over_union > intersection_over_union_threshold) {
467 active_box_candidate[j] = 0;
468 num_active_candidate--;
469 }
470 }
471 }
472 }
473
474 return kTfLiteOk;
475 }
476
477 // This function implements a regular version of Non Maximal Suppression (NMS)
478 // for multiple classes where
479 // 1) we do NMS separately for each class across all anchors and
480 // 2) keep only the highest anchor scores across all classes
481 // 3) The worst runtime of the regular NMS is O(K*N^2)
482 // where N is the number of anchors and K the number of
483 // classes.
NonMaxSuppressionMultiClassRegularHelper(TfLiteContext * context,TfLiteNode * node,OpData * op_data,const float * scores)484 TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
485 TfLiteNode* node,
486 OpData* op_data,
487 const float* scores) {
488 const TfLiteEvalTensor* input_box_encodings =
489 tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
490 const TfLiteEvalTensor* input_class_predictions =
491 tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
492 TfLiteEvalTensor* detection_boxes =
493 tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
494 TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
495 context, node, kOutputTensorDetectionClasses);
496 TfLiteEvalTensor* detection_scores =
497 tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
498 TfLiteEvalTensor* num_detections =
499 tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);
500
501 const int num_boxes = input_box_encodings->dims->data[1];
502 const int num_classes = op_data->num_classes;
503 const int num_detections_per_class = op_data->detections_per_class;
504 const int max_detections = op_data->max_detections;
505 const int num_classes_with_background =
506 input_class_predictions->dims->data[2];
507 // The row index offset is 1 if background class is included and 0 otherwise.
508 int label_offset = num_classes_with_background - num_classes;
509 TF_LITE_ENSURE(context, num_detections_per_class > 0);
510
511 // For each class, perform non-max suppression.
512 float* class_scores = reinterpret_cast<float*>(
513 context->GetScratchBuffer(context, op_data->score_buffer_idx));
514 int* box_indices_after_regular_non_max_suppression = reinterpret_cast<int*>(
515 context->GetScratchBuffer(context, op_data->buffer_idx));
516 float* scores_after_regular_non_max_suppression =
517 reinterpret_cast<float*>(context->GetScratchBuffer(
518 context, op_data->scores_after_regular_non_max_suppression_idx));
519
520 int size_of_sorted_indices = 0;
521 int* sorted_indices = reinterpret_cast<int*>(
522 context->GetScratchBuffer(context, op_data->sorted_indices_idx));
523 float* sorted_values = reinterpret_cast<float*>(
524 context->GetScratchBuffer(context, op_data->sorted_values_idx));
525
526 for (int col = 0; col < num_classes; col++) {
527 for (int row = 0; row < num_boxes; row++) {
528 // Get scores of boxes corresponding to all anchors for single class
529 class_scores[row] =
530 *(scores + row * num_classes_with_background + col + label_offset);
531 }
532 // Perform non-maximal suppression on single class
533 int selected_size = 0;
534 int* selected = reinterpret_cast<int*>(
535 context->GetScratchBuffer(context, op_data->selected_idx));
536 TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
537 context, node, op_data, class_scores, selected, &selected_size,
538 num_detections_per_class));
539 // Add selected indices from non-max suppression of boxes in this class
540 int output_index = size_of_sorted_indices;
541 for (int i = 0; i < selected_size; i++) {
542 int selected_index = selected[i];
543
544 box_indices_after_regular_non_max_suppression[output_index] =
545 (selected_index * num_classes_with_background + col + label_offset);
546 scores_after_regular_non_max_suppression[output_index] =
547 class_scores[selected_index];
548 output_index++;
549 }
550 // Sort the max scores among the selected indices
551 // Get the indices for top scores
552 int num_indices_to_sort = std::min(output_index, max_detections);
553 DecreasingPartialArgSort(scores_after_regular_non_max_suppression,
554 output_index, num_indices_to_sort, sorted_indices);
555
556 // Copy values to temporary vectors
557 for (int row = 0; row < num_indices_to_sort; row++) {
558 int temp = sorted_indices[row];
559 sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
560 sorted_values[row] = scores_after_regular_non_max_suppression[temp];
561 }
562 // Copy scores and indices from temporary vectors
563 for (int row = 0; row < num_indices_to_sort; row++) {
564 box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
565 scores_after_regular_non_max_suppression[row] = sorted_values[row];
566 }
567 size_of_sorted_indices = num_indices_to_sort;
568 }
569
570 // Allocate output tensors
571 for (int output_box_index = 0; output_box_index < max_detections;
572 output_box_index++) {
573 if (output_box_index < size_of_sorted_indices) {
574 const int anchor_index = floor(
575 box_indices_after_regular_non_max_suppression[output_box_index] /
576 num_classes_with_background);
577 const int class_index =
578 box_indices_after_regular_non_max_suppression[output_box_index] -
579 anchor_index * num_classes_with_background - label_offset;
580 const float selected_score =
581 scores_after_regular_non_max_suppression[output_box_index];
582 // detection_boxes
583 float* decoded_boxes = reinterpret_cast<float*>(
584 context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
585 ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
586 reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[anchor_index];
587 // detection_classes
588 tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
589 class_index;
590 // detection_scores
591 tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
592 selected_score;
593 } else {
594 ReInterpretTensor<BoxCornerEncoding*>(
595 detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
596 // detection_classes
597 tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
598 0.0f;
599 // detection_scores
600 tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
601 0.0f;
602 }
603 }
604 tflite::micro::GetTensorData<float>(num_detections)[0] =
605 size_of_sorted_indices;
606
607 return kTfLiteOk;
608 }
609
610 // This function implements a fast version of Non Maximal Suppression for
611 // multiple classes where
612 // 1) we keep the top-k scores for each anchor and
613 // 2) during NMS, each anchor only uses the highest class score for sorting.
614 // 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
615 // instead of O(KN^2) where N is the number of anchors and K the number of
616 // classes.
NonMaxSuppressionMultiClassFastHelper(TfLiteContext * context,TfLiteNode * node,OpData * op_data,const float * scores)617 TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
618 TfLiteNode* node,
619 OpData* op_data,
620 const float* scores) {
621 const TfLiteEvalTensor* input_box_encodings =
622 tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
623 const TfLiteEvalTensor* input_class_predictions =
624 tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
625 TfLiteEvalTensor* detection_boxes =
626 tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
627
628 TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
629 context, node, kOutputTensorDetectionClasses);
630 TfLiteEvalTensor* detection_scores =
631 tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
632 TfLiteEvalTensor* num_detections =
633 tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);
634
635 const int num_boxes = input_box_encodings->dims->data[1];
636 const int num_classes = op_data->num_classes;
637 const int max_categories_per_anchor = op_data->max_classes_per_detection;
638 const int num_classes_with_background =
639 input_class_predictions->dims->data[2];
640
641 // The row index offset is 1 if background class is included and 0 otherwise.
642 int label_offset = num_classes_with_background - num_classes;
643 TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
644 const int num_categories_per_anchor =
645 std::min(max_categories_per_anchor, num_classes);
646 float* max_scores = reinterpret_cast<float*>(
647 context->GetScratchBuffer(context, op_data->score_buffer_idx));
648 int* sorted_class_indices = reinterpret_cast<int*>(
649 context->GetScratchBuffer(context, op_data->buffer_idx));
650
651 for (int row = 0; row < num_boxes; row++) {
652 const float* box_scores =
653 scores + row * num_classes_with_background + label_offset;
654 int* class_indices = sorted_class_indices + row * num_classes;
655 DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
656 class_indices);
657 max_scores[row] = box_scores[class_indices[0]];
658 }
659
660 // Perform non-maximal suppression on max scores
661 int selected_size = 0;
662 int* selected = reinterpret_cast<int*>(
663 context->GetScratchBuffer(context, op_data->selected_idx));
664 TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
665 context, node, op_data, max_scores, selected, &selected_size,
666 op_data->max_detections));
667
668 // Allocate output tensors
669 int output_box_index = 0;
670
671 for (int i = 0; i < selected_size; i++) {
672 int selected_index = selected[i];
673
674 const float* box_scores =
675 scores + selected_index * num_classes_with_background + label_offset;
676 const int* class_indices =
677 sorted_class_indices + selected_index * num_classes;
678
679 for (int col = 0; col < num_categories_per_anchor; ++col) {
680 int box_offset = num_categories_per_anchor * output_box_index + col;
681
682 // detection_boxes
683 float* decoded_boxes = reinterpret_cast<float*>(
684 context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
685 ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
686 reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[selected_index];
687
688 // detection_classes
689 tflite::micro::GetTensorData<float>(detection_classes)[box_offset] =
690 class_indices[col];
691
692 // detection_scores
693 tflite::micro::GetTensorData<float>(detection_scores)[box_offset] =
694 box_scores[class_indices[col]];
695
696 output_box_index++;
697 }
698 }
699
700 tflite::micro::GetTensorData<float>(num_detections)[0] = output_box_index;
701 return kTfLiteOk;
702 }
703
DequantizeClassPredictions(const TfLiteEvalTensor * input_class_predictions,const int num_boxes,const int num_classes_with_background,float * scores,OpData * op_data)704 void DequantizeClassPredictions(const TfLiteEvalTensor* input_class_predictions,
705 const int num_boxes,
706 const int num_classes_with_background,
707 float* scores, OpData* op_data) {
708 float quant_zero_point =
709 static_cast<float>(op_data->input_class_predictions.zero_point);
710 float quant_scale =
711 static_cast<float>(op_data->input_class_predictions.scale);
712 Dequantizer dequantize(quant_zero_point, quant_scale);
713 const uint8_t* scores_quant =
714 tflite::micro::GetTensorData<uint8_t>(input_class_predictions);
715 for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
716 scores[idx] = dequantize(scores_quant[idx]);
717 }
718 }
719
NonMaxSuppressionMultiClass(TfLiteContext * context,TfLiteNode * node,OpData * op_data)720 TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
721 TfLiteNode* node, OpData* op_data) {
722 // Get the input tensors
723 const TfLiteEvalTensor* input_box_encodings =
724 tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
725 const TfLiteEvalTensor* input_class_predictions =
726 tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
727 const int num_boxes = input_box_encodings->dims->data[1];
728 const int num_classes = op_data->num_classes;
729
730 TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
731 kBatchSize);
732 TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
733 const int num_classes_with_background =
734 input_class_predictions->dims->data[2];
735
736 TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
737 TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));
738
739 const float* scores;
740 switch (input_class_predictions->type) {
741 case kTfLiteUInt8: {
742 float* temporary_scores = reinterpret_cast<float*>(
743 context->GetScratchBuffer(context, op_data->scores_idx));
744 DequantizeClassPredictions(input_class_predictions, num_boxes,
745 num_classes_with_background, temporary_scores,
746 op_data);
747 scores = temporary_scores;
748 } break;
749 case kTfLiteFloat32:
750 scores = tflite::micro::GetTensorData<float>(input_class_predictions);
751 break;
752 default:
753 // Unsupported type.
754 return kTfLiteError;
755 }
756
757 if (op_data->use_regular_non_max_suppression) {
758 TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
759 context, node, op_data, scores));
760 } else {
761 TF_LITE_ENSURE_STATUS(
762 NonMaxSuppressionMultiClassFastHelper(context, node, op_data, scores));
763 }
764
765 return kTfLiteOk;
766 }
767
Eval(TfLiteContext * context,TfLiteNode * node)768 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
769 TF_LITE_ENSURE(context, (kBatchSize == 1));
770 auto* op_data = static_cast<OpData*>(node->user_data);
771
772 // These two functions correspond to two blocks in the Object Detection model.
773 // In future, we would like to break the custom op in two blocks, which is
774 // currently not feasible because we would like to input quantized inputs
775 // and do all calculations in float. Mixed quantized/float calculations are
776 // currently not supported in TFLite.
777
778 // This fills in temporary decoded_boxes
779 // by transforming input_box_encodings and input_anchors from
780 // CenterSizeEncodings to BoxCornerEncoding
781 TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
782
783 // This fills in the output tensors
784 // by choosing effective set of decoded boxes
785 // based on Non Maximal Suppression, i.e. selecting
786 // highest scoring non-overlapping boxes.
787 TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));
788
789 return kTfLiteOk;
790 }
791 } // namespace
792
Register_DETECTION_POSTPROCESS()793 TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
794 static TfLiteRegistration r = {/*init=*/Init,
795 /*free=*/Free,
796 /*prepare=*/Prepare,
797 /*invoke=*/Eval,
798 /*profiling_string=*/nullptr,
799 /*builtin_code=*/0,
800 /*custom_name=*/nullptr,
801 /*version=*/0};
802 return &r;
803 }
804
805 } // namespace tflite
806