/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <numeric>

#define FLATBUFFERS_LOCALE_INDEPENDENT 0
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_utils.h"

namespace tflite {
namespace {

/**
 * This version of detection_postprocess is specific to TFLite Micro. It
 * differs from the TFLite version in the following ways:
 *
 * 1.) Temporaries (temporary tensors) - Instead of temporary tensors, the
 * Micro version uses the scratch buffer API.
 * 2.) Output dimensions - Undefined output dimensions are not supported, so
 * the model must have static output dimensions.
 */

// Input tensors
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Output tensors
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

constexpr int kNumCoordBox = 4;
constexpr int kBatchSize = 1;

constexpr int kNumDetectionsPerClass = 100;

// Object Detection model produces axis-aligned boxes in two formats:
// BoxCorner represents the lower left corner (xmin, ymin) and
// the upper right corner (xmax, ymax).
// CenterSize represents the center (xcenter, ycenter), height and width.
// BoxCornerEncoding and CenterSizeEncoding are related as follows:
// ycenter = y / y_scale * anchor.h + anchor.y;
// xcenter = x / x_scale * anchor.w + anchor.x;
// half_h = 0.5 * exp(h / h_scale) * anchor.h;
// half_w = 0.5 * exp(w / w_scale) * anchor.w;
// ymin = ycenter - half_h
// ymax = ycenter + half_h
// xmin = xcenter - half_w
// xmax = xcenter + half_w
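//
// Illustrative worked example (not part of the original comment; all values
// are hypothetical): with scale_values (y_scale=10, x_scale=10, h_scale=5,
// w_scale=5), anchor (y=0.5, x=0.5, h=0.2, w=0.2) and an all-zero box
// encoding (y=x=h=w=0), the formulas above give
// ycenter = 0/10 * 0.2 + 0.5 = 0.5, xcenter = 0.5,
// half_h = 0.5 * exp(0/5) * 0.2 = 0.1 and half_w = 0.1, so the decoded
// corners are (ymin, xmin, ymax, xmax) = (0.4, 0.4, 0.6, 0.6), i.e. the
// anchor box itself, as expected for a zero encoding.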
struct BoxCornerEncoding {
  float ymin;
  float xmin;
  float ymax;
  float xmax;
};

struct CenterSizeEncoding {
  float y;
  float x;
  float h;
  float w;
};
// We use static_assert to make sure that these structs are laid out as four
// contiguous float values, so they can alias raw float buffers.
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
              "Size of BoxCornerEncoding is 4 float values");
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
              "Size of CenterSizeEncoding is 4 float values");

struct OpData {
  int max_detections;
  int max_classes_per_detection;  // Fast Non-Max-Suppression
  int detections_per_class;       // Regular Non-Max-Suppression
  float non_max_suppression_score_threshold;
  float intersection_over_union_threshold;
  int num_classes;
  bool use_regular_non_max_suppression;
  CenterSizeEncoding scale_values;

  // Scratch buffer indexes
  int active_candidate_idx;
  int decoded_boxes_idx;
  int scores_idx;
  int score_buffer_idx;
  int keep_scores_idx;
  int scores_after_regular_non_max_suppression_idx;
  int sorted_values_idx;
  int keep_indices_idx;
  int sorted_indices_idx;
  int buffer_idx;
  int selected_idx;

  // Cached tensor scale and zero point values for quantized operations
  TfLiteQuantizationParams input_box_encodings;
  TfLiteQuantizationParams input_class_predictions;
  TfLiteQuantizationParams input_anchors;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  OpData* op_data = nullptr;

  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();

  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
  op_data = reinterpret_cast<OpData*>(
      context->AllocatePersistentBuffer(context, sizeof(OpData)));

  op_data->max_detections = m["max_detections"].AsInt32();
  op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
  if (m["detections_per_class"].IsNull())
    op_data->detections_per_class = kNumDetectionsPerClass;
  else
    op_data->detections_per_class = m["detections_per_class"].AsInt32();
  if (m["use_regular_nms"].IsNull())
    op_data->use_regular_non_max_suppression = false;
  else
    op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();

  op_data->non_max_suppression_score_threshold =
      m["nms_score_threshold"].AsFloat();
  op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
  op_data->num_classes = m["num_classes"].AsInt32();
  op_data->scale_values.y = m["y_scale"].AsFloat();
  op_data->scale_values.x = m["x_scale"].AsFloat();
  op_data->scale_values.h = m["h_scale"].AsFloat();
  op_data->scale_values.w = m["w_scale"].AsFloat();

  return op_data;
}
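
// For illustration only (hypothetical values, not taken from any particular
// model): the custom-op options parsed above arrive as a FlexBuffers map that
// the TFLite converter attaches to the TFLite_Detection_PostProcess custom
// op, e.g.
//   {"max_detections": 10, "max_classes_per_detection": 1,
//    "detections_per_class": 100, "use_regular_nms": false,
//    "nms_score_threshold": 0.3, "nms_iou_threshold": 0.6,
//    "num_classes": 90, "y_scale": 10.0, "x_scale": 10.0,
//    "h_scale": 5.0, "w_scale": 5.0}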

void Free(TfLiteContext* context, void* buffer) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = static_cast<OpData*>(node->user_data);

  // Inputs: box_encodings, scores, anchors
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* input_class_predictions =
      GetInput(context, node, kInputTensorClassPredictions);
  const TfLiteTensor* input_anchors =
      GetInput(context, node, kInputTensorAnchors);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);

  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;

  op_data->input_box_encodings.scale = input_box_encodings->params.scale;
  op_data->input_box_encodings.zero_point =
      input_box_encodings->params.zero_point;
  op_data->input_class_predictions.scale =
      input_class_predictions->params.scale;
  op_data->input_class_predictions.zero_point =
      input_class_predictions->params.zero_point;
  op_data->input_anchors.scale = input_anchors->params.scale;
  op_data->input_anchors.zero_point = input_anchors->params.zero_point;

  // Scratch tensors
  context->RequestScratchBufferInArena(context, num_boxes,
                                       &op_data->active_candidate_idx);
  context->RequestScratchBufferInArena(context,
                                       num_boxes * kNumCoordBox * sizeof(float),
                                       &op_data->decoded_boxes_idx);
  context->RequestScratchBufferInArena(
      context,
      input_class_predictions->dims->data[1] *
          input_class_predictions->dims->data[2] * sizeof(float),
      &op_data->scores_idx);

  // Additional buffers
  context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
                                       &op_data->score_buffer_idx);
  context->RequestScratchBufferInArena(context, num_boxes * sizeof(float),
                                       &op_data->keep_scores_idx);
  context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(float),
      &op_data->scores_after_regular_non_max_suppression_idx);
  context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(float),
      &op_data->sorted_values_idx);
  context->RequestScratchBufferInArena(context, num_boxes * sizeof(int),
                                       &op_data->keep_indices_idx);
  context->RequestScratchBufferInArena(
      context, op_data->max_detections * num_boxes * sizeof(int),
      &op_data->sorted_indices_idx);
  int buffer_size = std::max(num_classes, op_data->max_detections);
  context->RequestScratchBufferInArena(
      context, buffer_size * num_boxes * sizeof(int), &op_data->buffer_idx);
  buffer_size = std::min(num_boxes, op_data->max_detections);
  context->RequestScratchBufferInArena(
      context, buffer_size * num_boxes * sizeof(int), &op_data->selected_idx);

  // Outputs: detection_boxes, detection_scores, detection_classes,
  // num_detections
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);

  return kTfLiteOk;
}

class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}
  float operator()(uint8_t x) {
    return (static_cast<float>(x) - zero_point_) * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};
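
// Worked example (illustrative only, hypothetical quantization parameters):
// with zero_point = 128 and scale = 0.1f, Dequantizer maps the quantized
// value 138 to (138 - 128) * 0.1 = 1.0f and maps 128 to 0.0f.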

void DequantizeBoxEncodings(const TfLiteEvalTensor* input_box_encodings,
                            int idx, float quant_zero_point, float quant_scale,
                            int length_box_encoding,
                            CenterSizeEncoding* box_centersize) {
  const uint8_t* boxes =
      tflite::micro::GetTensorData<uint8_t>(input_box_encodings) +
      length_box_encoding * idx;
  Dequantizer dequantize(quant_zero_point, quant_scale);
  // See definition of the KeypointBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/keypoint_box_coder.py
  // The first four elements are the box coordinates, which are the same as in
  // the FasterRcnnBoxCoder at
  // https://github.com/tensorflow/models/blob/master/research/object_detection/box_coders/faster_rcnn_box_coder.py
  box_centersize->y = dequantize(boxes[0]);
  box_centersize->x = dequantize(boxes[1]);
  box_centersize->h = dequantize(boxes[2]);
  box_centersize->w = dequantize(boxes[3]);
}

template <class T>
T ReInterpretTensor(const TfLiteEvalTensor* tensor) {
  const float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}

template <class T>
T ReInterpretTensor(TfLiteEvalTensor* tensor) {
  float* tensor_base = tflite::micro::GetTensorData<float>(tensor);
  return reinterpret_cast<T>(tensor_base);
}

TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
                                   OpData* op_data) {
  // Parse input tensor box encodings
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
  const int num_boxes = input_box_encodings->dims->data[1];
  TF_LITE_ENSURE(context, input_box_encodings->dims->data[2] >= kNumCoordBox);
  const TfLiteEvalTensor* input_anchors =
      tflite::micro::GetEvalInput(context, node, kInputTensorAnchors);

  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
  CenterSizeEncoding box_centersize;
  CenterSizeEncoding scale_values = op_data->scale_values;
  CenterSizeEncoding anchor;
  for (int idx = 0; idx < num_boxes; ++idx) {
    switch (input_box_encodings->type) {
        // Quantized
      case kTfLiteUInt8:
        DequantizeBoxEncodings(
            input_box_encodings, idx,
            static_cast<float>(op_data->input_box_encodings.zero_point),
            static_cast<float>(op_data->input_box_encodings.scale),
            input_box_encodings->dims->data[2], &box_centersize);
        DequantizeBoxEncodings(
            input_anchors, idx,
            static_cast<float>(op_data->input_anchors.zero_point),
            static_cast<float>(op_data->input_anchors.scale), kNumCoordBox,
            &anchor);
        break;
        // Float
      case kTfLiteFloat32: {
        // See the DequantizeBoxEncodings function for details on the
        // supported box encoding layout.
        const int box_encoding_idx = idx * input_box_encodings->dims->data[2];
        const float* boxes = &(tflite::micro::GetTensorData<float>(
            input_box_encodings)[box_encoding_idx]);
        box_centersize = *reinterpret_cast<const CenterSizeEncoding*>(boxes);
        anchor =
            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
        break;
      }
      default:
        // Unsupported type.
        return kTfLiteError;
    }

    float ycenter = static_cast<float>(static_cast<double>(box_centersize.y) /
                                           static_cast<double>(scale_values.y) *
                                           static_cast<double>(anchor.h) +
                                       static_cast<double>(anchor.y));

    float xcenter = static_cast<float>(static_cast<double>(box_centersize.x) /
                                           static_cast<double>(scale_values.x) *
                                           static_cast<double>(anchor.w) +
                                       static_cast<double>(anchor.x));

    float half_h =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.h) /
                                     static_cast<double>(scale_values.h))) *
                           static_cast<double>(anchor.h));
    float half_w =
        static_cast<float>(0.5 *
                           (std::exp(static_cast<double>(box_centersize.w) /
                                     static_cast<double>(scale_values.w))) *
                           static_cast<double>(anchor.w));

    float* decoded_boxes = reinterpret_cast<float*>(
        context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
    auto& box = reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[idx];
    box.ymin = ycenter - half_h;
    box.xmin = xcenter - half_w;
    box.ymax = ycenter + half_h;
    box.xmax = xcenter + half_w;
  }
  return kTfLiteOk;
}

void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  std::iota(indices, indices + num_values, 0);
  std::partial_sort(
      indices, indices + num_to_sort, indices + num_values,
      [&values](const int i, const int j) { return values[i] > values[j]; });
}
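
// Illustrative example (hypothetical values): for values = {0.2f, 0.9f, 0.5f}
// with num_values = 3 and num_to_sort = 2, indices starts as {0, 1, 2} and
// the partial sort leaves indices[0] = 1 and indices[1] = 2, i.e. the
// positions of the two largest scores in decreasing order.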

int SelectDetectionsAboveScoreThreshold(const float* values, int size,
                                        const float threshold,
                                        float* keep_values, int* keep_indices) {
  int counter = 0;
  for (int i = 0; i < size; i++) {
    if (values[i] >= threshold) {
      keep_values[counter] = values[i];
      keep_indices[counter] = i;
      counter++;
    }
  }
  return counter;
}

bool ValidateBoxes(const float* decoded_boxes, const int num_boxes) {
  for (int i = 0; i < num_boxes; ++i) {
    // ymax >= ymin, xmax >= xmin
    auto& box = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
    if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
      return false;
    }
  }
  return true;
}

float ComputeIntersectionOverUnion(const float* decoded_boxes, const int i,
                                   const int j) {
  auto& box_i = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[i];
  auto& box_j = reinterpret_cast<const BoxCornerEncoding*>(decoded_boxes)[j];
  const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
  const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
  if (area_i <= 0 || area_j <= 0) return 0.0;
  const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
  const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
  const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
  const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
  const float intersection_area =
      std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
      std::max<float>(intersection_xmax - intersection_xmin, 0.0);
  return intersection_area / (area_i + area_j - intersection_area);
}
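
// Worked example (illustrative, hypothetical boxes): for box_i =
// (ymin, xmin, ymax, xmax) = (0, 0, 1, 1) and box_j = (0, 0.5, 1, 1.5), each
// box has area 1, the intersection is 1 * 0.5 = 0.5 and the union is
// 1 + 1 - 0.5 = 1.5, so the IoU returned above is 0.5 / 1.5 = 1/3.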

// NonMaxSuppressionSingleClassHelper() prunes out box locations with high
// overlap before selecting the highest-scoring boxes (at most max_detections
// of them). It assumes all boxes are valid at the start and sorts them by
// score. If a lower-scoring box overlaps too much with a higher-scoring box,
// the lower-scoring box is discarded.
// Complexity is O(N^2): pairwise comparison between boxes.
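// Illustrative example (hypothetical scores, assuming an IoU threshold of
// 0.5): if box A has score 0.9, box B has score 0.8 and IoU(A, B) = 0.6, then
// A is selected first and B is suppressed because 0.6 > 0.5; a third box C
// with IoU below 0.5 against A would still be eligible for selection.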
TfLiteStatus NonMaxSuppressionSingleClassHelper(
    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
    const float* scores, int* selected, int* selected_size,
    int max_detections) {
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const int num_boxes = input_box_encodings->dims->data[1];
  const float non_max_suppression_score_threshold =
      op_data->non_max_suppression_score_threshold;
  const float intersection_over_union_threshold =
      op_data->intersection_over_union_threshold;
  // Maximum detections should be non-negative.
  TF_LITE_ENSURE(context, (max_detections >= 0));
  // intersection_over_union_threshold should be positive
  // and should be less than or equal to 1.
  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
                              (intersection_over_union_threshold <= 1.0f));
  // Validate boxes
  float* decoded_boxes = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->decoded_boxes_idx));

  TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));

  // Threshold scores
  int* keep_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->keep_indices_idx));
  float* keep_scores = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->keep_scores_idx));
  int num_scores_kept = SelectDetectionsAboveScoreThreshold(
      scores, num_boxes, non_max_suppression_score_threshold, keep_scores,
      keep_indices);
  int* sorted_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->sorted_indices_idx));

  DecreasingPartialArgSort(keep_scores, num_scores_kept, num_scores_kept,
                           sorted_indices);

  const int num_boxes_kept = num_scores_kept;
  const int output_size = std::min(num_boxes_kept, max_detections);
  *selected_size = 0;

  int num_active_candidate = num_boxes_kept;
  uint8_t* active_box_candidate = reinterpret_cast<uint8_t*>(
      context->GetScratchBuffer(context, op_data->active_candidate_idx));

  for (int row = 0; row < num_boxes_kept; row++) {
    active_box_candidate[row] = 1;
  }
  for (int i = 0; i < num_boxes_kept; ++i) {
    if (num_active_candidate == 0 || *selected_size >= output_size) break;
    if (active_box_candidate[i] == 1) {
      selected[(*selected_size)++] = keep_indices[sorted_indices[i]];
      active_box_candidate[i] = 0;
      num_active_candidate--;
    } else {
      continue;
    }
    for (int j = i + 1; j < num_boxes_kept; ++j) {
      if (active_box_candidate[j] == 1) {
        float intersection_over_union = ComputeIntersectionOverUnion(
            decoded_boxes, keep_indices[sorted_indices[i]],
            keep_indices[sorted_indices[j]]);

        if (intersection_over_union > intersection_over_union_threshold) {
          active_box_candidate[j] = 0;
          num_active_candidate--;
        }
      }
    }
  }

  return kTfLiteOk;
}

// This function implements the regular version of Non-Maximal Suppression
// (NMS) for multiple classes, where:
// 1) NMS is performed separately for each class across all anchors, and
// 2) only the highest anchor scores across all classes are kept.
// The worst-case runtime of the regular NMS is O(K*N^2), where N is the
// number of anchors and K is the number of classes.
TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
                                                      TfLiteNode* node,
                                                      OpData* op_data,
                                                      const float* scores) {
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const TfLiteEvalTensor* input_class_predictions =
      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
  TfLiteEvalTensor* detection_boxes =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);
  TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
      context, node, kOutputTensorDetectionClasses);
  TfLiteEvalTensor* detection_scores =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
  TfLiteEvalTensor* num_detections =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int num_detections_per_class = op_data->detections_per_class;
  const int max_detections = op_data->max_detections;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];
  // The row index offset is 1 if background class is included and 0 otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, num_detections_per_class > 0);

  // For each class, perform non-max suppression.
  float* class_scores = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->score_buffer_idx));
  int* box_indices_after_regular_non_max_suppression = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->buffer_idx));
  float* scores_after_regular_non_max_suppression =
      reinterpret_cast<float*>(context->GetScratchBuffer(
          context, op_data->scores_after_regular_non_max_suppression_idx));

  int size_of_sorted_indices = 0;
  int* sorted_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->sorted_indices_idx));
  float* sorted_values = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->sorted_values_idx));

  for (int col = 0; col < num_classes; col++) {
    for (int row = 0; row < num_boxes; row++) {
      // Get scores of boxes corresponding to all anchors for single class
      class_scores[row] =
          *(scores + row * num_classes_with_background + col + label_offset);
    }
    // Perform non-maximal suppression on single class
    int selected_size = 0;
    int* selected = reinterpret_cast<int*>(
        context->GetScratchBuffer(context, op_data->selected_idx));
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
        context, node, op_data, class_scores, selected, &selected_size,
        num_detections_per_class));
    // Add selected indices from non-max suppression of boxes in this class
    int output_index = size_of_sorted_indices;
    for (int i = 0; i < selected_size; i++) {
      int selected_index = selected[i];

      box_indices_after_regular_non_max_suppression[output_index] =
          (selected_index * num_classes_with_background + col + label_offset);
      scores_after_regular_non_max_suppression[output_index] =
          class_scores[selected_index];
      output_index++;
    }
    // Sort the max scores among the selected indices
    // Get the indices for top scores
    int num_indices_to_sort = std::min(output_index, max_detections);
    DecreasingPartialArgSort(scores_after_regular_non_max_suppression,
                             output_index, num_indices_to_sort, sorted_indices);

    // Copy values to temporary vectors
    for (int row = 0; row < num_indices_to_sort; row++) {
      int temp = sorted_indices[row];
      sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
      sorted_values[row] = scores_after_regular_non_max_suppression[temp];
    }
    // Copy scores and indices from temporary vectors
    for (int row = 0; row < num_indices_to_sort; row++) {
      box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
      scores_after_regular_non_max_suppression[row] = sorted_values[row];
    }
    size_of_sorted_indices = num_indices_to_sort;
  }

  // Allocate output tensors
  for (int output_box_index = 0; output_box_index < max_detections;
       output_box_index++) {
    if (output_box_index < size_of_sorted_indices) {
      const int anchor_index = floor(
          box_indices_after_regular_non_max_suppression[output_box_index] /
          num_classes_with_background);
      const int class_index =
          box_indices_after_regular_non_max_suppression[output_box_index] -
          anchor_index * num_classes_with_background - label_offset;
      const float selected_score =
          scores_after_regular_non_max_suppression[output_box_index];
      // detection_boxes
      float* decoded_boxes = reinterpret_cast<float*>(
          context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
          reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[anchor_index];
      // detection_classes
      tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
          class_index;
      // detection_scores
      tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
          selected_score;
    } else {
      ReInterpretTensor<BoxCornerEncoding*>(
          detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
      // detection_classes
      tflite::micro::GetTensorData<float>(detection_classes)[output_box_index] =
          0.0f;
      // detection_scores
      tflite::micro::GetTensorData<float>(detection_scores)[output_box_index] =
          0.0f;
    }
  }
  tflite::micro::GetTensorData<float>(num_detections)[0] =
      size_of_sorted_indices;

  return kTfLiteOk;
}

// This function implements a fast version of Non Maximal Suppression for
// multiple classes where
// 1) we keep the top-k scores for each anchor and
// 2) during NMS, each anchor only uses the highest class score for sorting.
// 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
// instead of O(KN^2) where N is the number of anchors and K the number of
// classes.
TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
                                                   TfLiteNode* node,
                                                   OpData* op_data,
                                                   const float* scores) {
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const TfLiteEvalTensor* input_class_predictions =
      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
  TfLiteEvalTensor* detection_boxes =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionBoxes);

  TfLiteEvalTensor* detection_classes = tflite::micro::GetEvalOutput(
      context, node, kOutputTensorDetectionClasses);
  TfLiteEvalTensor* detection_scores =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorDetectionScores);
  TfLiteEvalTensor* num_detections =
      tflite::micro::GetEvalOutput(context, node, kOutputTensorNumDetections);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int max_categories_per_anchor = op_data->max_classes_per_detection;
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  // The row index offset is 1 if background class is included and 0 otherwise.
  int label_offset = num_classes_with_background - num_classes;
  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
  const int num_categories_per_anchor =
      std::min(max_categories_per_anchor, num_classes);
  float* max_scores = reinterpret_cast<float*>(
      context->GetScratchBuffer(context, op_data->score_buffer_idx));
  int* sorted_class_indices = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->buffer_idx));

  for (int row = 0; row < num_boxes; row++) {
    const float* box_scores =
        scores + row * num_classes_with_background + label_offset;
    int* class_indices = sorted_class_indices + row * num_classes;
    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
                             class_indices);
    max_scores[row] = box_scores[class_indices[0]];
  }

  // Perform non-maximal suppression on max scores
  int selected_size = 0;
  int* selected = reinterpret_cast<int*>(
      context->GetScratchBuffer(context, op_data->selected_idx));
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
      context, node, op_data, max_scores, selected, &selected_size,
      op_data->max_detections));

  // Allocate output tensors
  int output_box_index = 0;

  for (int i = 0; i < selected_size; i++) {
    int selected_index = selected[i];

    const float* box_scores =
        scores + selected_index * num_classes_with_background + label_offset;
    const int* class_indices =
        sorted_class_indices + selected_index * num_classes;

    for (int col = 0; col < num_categories_per_anchor; ++col) {
      int box_offset = num_categories_per_anchor * output_box_index + col;

      // detection_boxes
      float* decoded_boxes = reinterpret_cast<float*>(
          context->GetScratchBuffer(context, op_data->decoded_boxes_idx));
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
          reinterpret_cast<BoxCornerEncoding*>(decoded_boxes)[selected_index];

      // detection_classes
      tflite::micro::GetTensorData<float>(detection_classes)[box_offset] =
          class_indices[col];

      // detection_scores
      tflite::micro::GetTensorData<float>(detection_scores)[box_offset] =
          box_scores[class_indices[col]];

      output_box_index++;
    }
  }

  tflite::micro::GetTensorData<float>(num_detections)[0] = output_box_index;
  return kTfLiteOk;
}

void DequantizeClassPredictions(const TfLiteEvalTensor* input_class_predictions,
                                const int num_boxes,
                                const int num_classes_with_background,
                                float* scores, OpData* op_data) {
  float quant_zero_point =
      static_cast<float>(op_data->input_class_predictions.zero_point);
  float quant_scale =
      static_cast<float>(op_data->input_class_predictions.scale);
  Dequantizer dequantize(quant_zero_point, quant_scale);
  const uint8_t* scores_quant =
      tflite::micro::GetTensorData<uint8_t>(input_class_predictions);
  for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
    scores[idx] = dequantize(scores_quant[idx]);
  }
}

TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                         TfLiteNode* node, OpData* op_data) {
  // Get the input tensors
  const TfLiteEvalTensor* input_box_encodings =
      tflite::micro::GetEvalInput(context, node, kInputTensorBoxEncodings);
  const TfLiteEvalTensor* input_class_predictions =
      tflite::micro::GetEvalInput(context, node, kInputTensorClassPredictions);
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;

  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
                    kBatchSize);
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  TF_LITE_ENSURE(context, (num_classes_with_background - num_classes <= 1));
  TF_LITE_ENSURE(context, (num_classes_with_background >= num_classes));

  const float* scores;
  switch (input_class_predictions->type) {
    case kTfLiteUInt8: {
      float* temporary_scores = reinterpret_cast<float*>(
          context->GetScratchBuffer(context, op_data->scores_idx));
      DequantizeClassPredictions(input_class_predictions, num_boxes,
                                 num_classes_with_background, temporary_scores,
                                 op_data);
      scores = temporary_scores;
    } break;
    case kTfLiteFloat32:
      scores = tflite::micro::GetTensorData<float>(input_class_predictions);
      break;
    default:
      // Unsupported type.
      return kTfLiteError;
  }

  if (op_data->use_regular_non_max_suppression) {
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
        context, node, op_data, scores));
  } else {
    TF_LITE_ENSURE_STATUS(
        NonMaxSuppressionMultiClassFastHelper(context, node, op_data, scores));
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE(context, (kBatchSize == 1));
  auto* op_data = static_cast<OpData*>(node->user_data);

  // These two functions correspond to two blocks in the Object Detection
  // model. In the future, we would like to split the custom op into two ops,
  // but that is currently not feasible because we want to feed in quantized
  // inputs while doing all calculations in float, and mixed quantized/float
  // calculations are currently not supported in TFLite.

  // This fills in the temporary decoded_boxes scratch buffer
  // by transforming input_box_encodings and input_anchors from
  // CenterSizeEncoding to BoxCornerEncoding.
  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));

  // This fills in the output tensors
  // by choosing an effective set of decoded boxes
  // based on Non Maximal Suppression, i.e. selecting the
  // highest-scoring non-overlapping boxes.
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));

  return kTfLiteOk;
}
}  // namespace

TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
  static TfLiteRegistration r = {/*init=*/Init,
                                 /*free=*/Free,
                                 /*prepare=*/Prepare,
                                 /*invoke=*/Eval,
                                 /*profiling_string=*/nullptr,
                                 /*builtin_code=*/0,
                                 /*custom_name=*/nullptr,
                                 /*version=*/0};
  return &r;
}
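
// Usage sketch (illustrative, not part of this kernel): applications
// typically register this kernel as a custom op before building the
// interpreter. The resolver size (4) below is a hypothetical placeholder;
// "TFLite_Detection_PostProcess" is the custom op name the TFLite converter
// gives the SSD postprocessing op.
//
//   tflite::MicroMutableOpResolver<4> resolver;
//   resolver.AddCustom("TFLite_Detection_PostProcess",
//                      tflite::Register_DETECTION_POSTPROCESS());
//   // ... add the other ops used by the model, then construct the
//   // tflite::MicroInterpreter with this resolver.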

}  // namespace tflite