/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace custom {
namespace detection_postprocess {

// Input tensors
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Output tensors
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

constexpr int kNumCoordBox = 4;
constexpr int kBatchSize = 1;

constexpr int kNumDetectionsPerClass = 100;

// Object Detection model produces axis-aligned boxes in two formats:
// BoxCorner represents the lower left corner (xmin, ymin) and
// the upper right corner (xmax, ymax).
// CenterSize represents the center (xcenter, ycenter), height and width.
// BoxCornerEncoding and CenterSizeEncoding are related as follows:
// ycenter = y / y_scale * anchor.h + anchor.y;
// xcenter = x / x_scale * anchor.w + anchor.x;
// half_h = 0.5 * exp(h / h_scale) * anchor.h;
// half_w = 0.5 * exp(w / w_scale) * anchor.w;
// ymin = ycenter - half_h
// ymax = ycenter + half_h
// xmin = xcenter - half_w
// xmax = xcenter + half_w
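// Illustrative decode (hypothetical numbers, not taken from any model): with
// scale values (y_scale, x_scale, h_scale, w_scale) = (10, 10, 5, 5), an
// anchor (y, x, h, w) = (0.5, 0.5, 0.2, 0.2) and an all-zero box encoding,
// ycenter = xcenter = 0.5 and half_h = half_w = 0.5 * exp(0) * 0.2 = 0.1,
// i.e. the corner box (ymin, xmin, ymax, xmax) = (0.4, 0.4, 0.6, 0.6).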
struct BoxCornerEncoding {
  float ymin;
  float xmin;
  float ymax;
  float xmax;
};

struct CenterSizeEncoding {
  float y;
  float x;
  float h;
  float w;
};
// We make sure that the memory allocations are contiguous with static assert.
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
              "Size of BoxCornerEncoding is 4 float values");
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
              "Size of CenterSizeEncoding is 4 float values");

struct OpData {
  int max_detections;
  int max_classes_per_detection;  // Fast Non-Max-Suppression
  int detections_per_class;       // Regular Non-Max-Suppression
  float non_max_suppression_score_threshold;
  float intersection_over_union_threshold;
  int num_classes;
  bool use_regular_non_max_suppression;
  CenterSizeEncoding scale_values;
  // Indices of Temporary tensors
  int decoded_boxes_index;
  int scores_index;
  int active_candidate_index;
};

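// Init() parses the op's custom options from a FlexBuffers map. The keys read
// below are "max_detections", "max_classes_per_detection",
// "detections_per_class" (optional, defaults to kNumDetectionsPerClass),
// "use_regular_nms" (optional, defaults to false), "nms_score_threshold",
// "nms_iou_threshold", "num_classes", and "y_scale"/"x_scale"/"h_scale"/
// "w_scale". This map is typically produced by the converter when exporting
// an SSD-style detection model with this custom op.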
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;
  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  op_data->max_detections = m["max_detections"].AsInt32();
  op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
  if (m["detections_per_class"].IsNull())
    op_data->detections_per_class = kNumDetectionsPerClass;
  else
    op_data->detections_per_class = m["detections_per_class"].AsInt32();
  if (m["use_regular_nms"].IsNull())
    op_data->use_regular_non_max_suppression = false;
  else
    op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();

  op_data->non_max_suppression_score_threshold =
      m["nms_score_threshold"].AsFloat();
  op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
  op_data->num_classes = m["num_classes"].AsInt32();
  op_data->scale_values.y = m["y_scale"].AsFloat();
  op_data->scale_values.x = m["x_scale"].AsFloat();
  op_data->scale_values.h = m["h_scale"].AsFloat();
  op_data->scale_values.w = m["w_scale"].AsFloat();
  context->AddTensors(context, 1, &op_data->decoded_boxes_index);
  context->AddTensors(context, 1, &op_data->scores_index);
  context->AddTensors(context, 1, &op_data->active_candidate_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// TODO(chowdhery): Add to kernel_util.h
TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
                            std::initializer_list<int> values) {
  TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
  int index = 0;
  for (int v : values) {
    size->data[index] = v;
    ++index;
  }
  return context->ResizeTensor(context, tensor, size);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  // Inputs: box_encodings, scores, anchors
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* input_class_predictions =
      GetInput(context, node, kInputTensorClassPredictions);
  const TfLiteTensor* input_anchors =
      GetInput(context, node, kInputTensorAnchors);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
  // number of detected boxes
  const int num_detected_boxes =
      op_data->max_detections * op_data->max_classes_per_detection;

  // Outputs: detection_boxes, detection_scores, detection_classes,
  // num_detections
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
  // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
  TfLiteTensor* detection_boxes =
      GetOutput(context, node, kOutputTensorDetectionBoxes);
  detection_boxes->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_boxes,
                 {kBatchSize, num_detected_boxes, kNumCoordBox});

  // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_classes =
      GetOutput(context, node, kOutputTensorDetectionClasses);
  detection_classes->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});

  // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
  TfLiteTensor* detection_scores =
      GetOutput(context, node, kOutputTensorDetectionScores);
  detection_scores->type = kTfLiteFloat32;
  SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});

  // Output Tensor num_detections: size is set to 1
  TfLiteTensor* num_detections =
      GetOutput(context, node, kOutputTensorNumDetections);
  num_detections->type = kTfLiteFloat32;
  // TODO(chowdhery): Make it a scalar when available
  SetTensorSizes(context, num_detections, {1});

  // Temporary tensors
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(3);
  node->temporaries->data[0] = op_data->decoded_boxes_index;
  node->temporaries->data[1] = op_data->scores_index;
  node->temporaries->data[2] = op_data->active_candidate_index;

  // decoded_boxes
  TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index];
  decoded_boxes->type = kTfLiteFloat32;
  decoded_boxes->allocation_type = kTfLiteArenaRw;
  SetTensorSizes(context, decoded_boxes,
                 {input_box_encodings->dims->data[1], kNumCoordBox});

  // scores
  TfLiteTensor* scores = &context->tensors[op_data->scores_index];
  scores->type = kTfLiteFloat32;
  scores->allocation_type = kTfLiteArenaRw;
  SetTensorSizes(context, scores,
                 {input_class_predictions->dims->data[1],
                  input_class_predictions->dims->data[2]});

  // active_candidate
  TfLiteTensor* active_candidate =
      &context->tensors[op_data->active_candidate_index];
  active_candidate->type = kTfLiteUInt8;
  active_candidate->allocation_type = kTfLiteArenaRw;
  SetTensorSizes(context, active_candidate,
                 {input_box_encodings->dims->data[1]});

  return kTfLiteOk;
}

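// Dequantizer maps a quantized uint8 value back to float as
// (x - zero_point) * scale. Illustrative example (hypothetical parameters):
// with zero_point = 128 and scale = 0.0078125, the quantized value 192
// dequantizes to (192 - 128) * 0.0078125 = 0.5f.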
class Dequantizer {
 public:
  Dequantizer(int zero_point, float scale)
      : zero_point_(zero_point), scale_(scale) {}
  float operator()(uint8 x) {
    return (static_cast<float>(x) - zero_point_) * scale_;
  }

 private:
  int zero_point_;
  float scale_;
};

void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
                            float quant_zero_point, float quant_scale,
                            CenterSizeEncoding* box_centersize) {
  const uint8* boxes =
      GetTensorData<uint8>(input_box_encodings) + kNumCoordBox * idx;
  Dequantizer dequantize(quant_zero_point, quant_scale);
  box_centersize->y = dequantize(boxes[0]);
  box_centersize->x = dequantize(boxes[1]);
  box_centersize->h = dequantize(boxes[2]);
  box_centersize->w = dequantize(boxes[3]);
}

template <class T>
T ReInterpretTensor(const TfLiteTensor* tensor) {
  // TODO(chowdhery): check float
  const float* tensor_base = tensor->data.f;
  return reinterpret_cast<T>(tensor_base);
}

template <class T>
T ReInterpretTensor(TfLiteTensor* tensor) {
  // TODO(chowdhery): check float
  float* tensor_base = tensor->data.f;
  return reinterpret_cast<T>(tensor_base);
}

TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
                                   OpData* op_data) {
  // Parse input tensor box encodings
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
  const int num_boxes = input_box_encodings->dims->data[1];
  TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox);
  const TfLiteTensor* input_anchors =
      GetInput(context, node, kInputTensorAnchors);

  // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
  CenterSizeEncoding box_centersize;
  CenterSizeEncoding scale_values = op_data->scale_values;
  CenterSizeEncoding anchor;
  for (int idx = 0; idx < num_boxes; ++idx) {
    switch (input_box_encodings->type) {
        // Quantized
      case kTfLiteUInt8:
        DequantizeBoxEncodings(
            input_box_encodings, idx,
            static_cast<float>(input_box_encodings->params.zero_point),
            static_cast<float>(input_box_encodings->params.scale),
            &box_centersize);
        DequantizeBoxEncodings(
            input_anchors, idx,
            static_cast<float>(input_anchors->params.zero_point),
            static_cast<float>(input_anchors->params.scale), &anchor);
        break;
        // Float
      case kTfLiteFloat32:
        box_centersize = ReInterpretTensor<const CenterSizeEncoding*>(
            input_box_encodings)[idx];
        anchor =
            ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
        break;
      default:
        // Unsupported type.
        return kTfLiteError;
    }

    float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y;
    float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x;
    float half_h =
        0.5f * static_cast<float>(std::exp(box_centersize.h / scale_values.h)) *
        anchor.h;
    float half_w =
        0.5f * static_cast<float>(std::exp(box_centersize.w / scale_values.w)) *
        anchor.w;
    TfLiteTensor* decoded_boxes =
        &context->tensors[op_data->decoded_boxes_index];
    auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
    box.ymin = ycenter - half_h;
    box.xmin = xcenter - half_w;
    box.ymax = ycenter + half_h;
    box.xmax = xcenter + half_w;
  }
  return kTfLiteOk;
}

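// DecreasingPartialArgSort fills `indices` with the indices of `values`
// ordered by decreasing value; only the first num_to_sort positions are
// guaranteed to be sorted. Illustrative example (hypothetical values):
// values = {0.1, 0.9, 0.4} with num_to_sort = 2 yields indices beginning
// with {1, 2}.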
void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  std::iota(indices, indices + num_values, 0);
  std::partial_sort(
      indices, indices + num_to_sort, indices + num_values,
      [&values](const int i, const int j) { return values[i] > values[j]; });
}

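// SelectDetectionsAboveScoreThreshold keeps every score >= threshold together
// with its index. Illustrative example (hypothetical values):
// values = {0.2, 0.8, 0.5} with threshold = 0.5 keeps values {0.8, 0.5} at
// indices {1, 2}.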
void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
                                         const float threshold,
                                         std::vector<float>* keep_values,
                                         std::vector<int>* keep_indices) {
  for (int i = 0; i < values.size(); i++) {
    if (values[i] >= threshold) {
      keep_values->emplace_back(values[i]);
      keep_indices->emplace_back(i);
    }
  }
}

bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
  for (int i = 0; i < num_boxes; ++i) {
    // Requires strict inequality: ymax > ymin and xmax > xmin.
    auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
    if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
      return false;
    }
  }
  return true;
}

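// ComputeIntersectionOverUnion returns the intersection area divided by the
// union area for two corner-encoded boxes. Illustrative example (hypothetical
// boxes): two unit squares offset by half a side overlap in area 0.5, so
// IoU = 0.5 / (1 + 1 - 0.5) = 1/3.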
float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
                                   const int i, const int j) {
  auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
  auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j];
  const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
  const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
  if (area_i <= 0 || area_j <= 0) return 0.0;
  const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
  const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
  const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
  const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
  const float intersection_area =
      std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
      std::max<float>(intersection_xmax - intersection_xmin, 0.0);
  return intersection_area / (area_i + area_j - intersection_area);
}

// NonMaxSuppressionSingleClass() prunes out box locations with high overlap
// before selecting the highest-scoring boxes (at most max_detections of them).
// It assumes all boxes are valid at the beginning and sorts them by score.
// If a lower-scoring box overlaps a higher-scoring box by more than the IoU
// threshold, the lower-scoring box is discarded.
// Complexity is O(N^2): pairwise comparison between boxes.
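// Illustrative walk-through (hypothetical numbers): with sorted scores
// {0.9, 0.85, 0.3} where box 1 overlaps box 0 above the IoU threshold, the
// greedy pass selects box 0, suppresses box 1, and then selects box 2
// (provided max_detections allows it).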
TfLiteStatus NonMaxSuppressionSingleClassHelper(
    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
    const std::vector<float>& scores, std::vector<int>* selected,
    int max_detections) {
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];
  const int num_boxes = input_box_encodings->dims->data[1];
  const float non_max_suppression_score_threshold =
      op_data->non_max_suppression_score_threshold;
  const float intersection_over_union_threshold =
      op_data->intersection_over_union_threshold;
  // Maximum detections should be non-negative.
  TF_LITE_ENSURE(context, (max_detections >= 0));
  // intersection_over_union_threshold should be positive
  // and should be less than 1.
  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
                              (intersection_over_union_threshold <= 1.0f));
  // Validate boxes
  TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));

  // threshold scores
  std::vector<int> keep_indices;
  // TODO(chowdhery): Remove the dynamic allocation and replace it
  // with temporaries, esp for std::vector<float>
  std::vector<float> keep_scores;
  SelectDetectionsAboveScoreThreshold(
      scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);

  int num_scores_kept = keep_scores.size();
  std::vector<int> sorted_indices;
  sorted_indices.resize(num_scores_kept);
  DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept,
                           sorted_indices.data());
  const int num_boxes_kept = num_scores_kept;
  const int output_size = std::min(num_boxes_kept, max_detections);
  selected->clear();
  TfLiteTensor* active_candidate =
      &context->tensors[op_data->active_candidate_index];
  TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes);
  int num_active_candidate = num_boxes_kept;
  uint8_t* active_box_candidate = (active_candidate->data.uint8);
  for (int row = 0; row < num_boxes_kept; row++) {
    active_box_candidate[row] = 1;
  }

  for (int i = 0; i < num_boxes_kept; ++i) {
    if (num_active_candidate == 0 || selected->size() >= output_size) break;
    if (active_box_candidate[i] == 1) {
      selected->push_back(keep_indices[sorted_indices[i]]);
      active_box_candidate[i] = 0;
      num_active_candidate--;
    } else {
      continue;
    }
    for (int j = i + 1; j < num_boxes_kept; ++j) {
      if (active_box_candidate[j] == 1) {
        float intersection_over_union = ComputeIntersectionOverUnion(
            decoded_boxes, keep_indices[sorted_indices[i]],
            keep_indices[sorted_indices[j]]);

        if (intersection_over_union > intersection_over_union_threshold) {
          active_box_candidate[j] = 0;
          num_active_candidate--;
        }
      }
    }
  }
  return kTfLiteOk;
}

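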
// This function implements a regular version of Non Maximal Suppression (NMS)
// for multiple classes where
// 1) we do NMS separately for each class across all anchors and
// 2) we keep only the highest anchor scores across all classes.
// The worst-case runtime of the regular NMS is O(K*N^2), where N is the
// number of anchors and K is the number of classes.
TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
                                                      TfLiteNode* node,
                                                      OpData* op_data,
                                                      const float* scores) {
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];

  TfLiteTensor* detection_boxes =
      GetOutput(context, node, kOutputTensorDetectionBoxes);
  TfLiteTensor* detection_classes =
      GetOutput(context, node, kOutputTensorDetectionClasses);
  TfLiteTensor* detection_scores =
      GetOutput(context, node, kOutputTensorDetectionScores);
  TfLiteTensor* num_detections =
      GetOutput(context, node, kOutputTensorNumDetections);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int num_detections_per_class = op_data->detections_per_class;
  const int max_detections = op_data->max_detections;
  // The row index offset is 1 if background class is included and 0 otherwise.
  const int label_offset = 1;
  TF_LITE_ENSURE(context, label_offset != -1);
  TF_LITE_ENSURE(context, num_detections_per_class > 0);
  const int num_classes_with_background = num_classes + label_offset;

  // For each class, perform non-max suppression.
  std::vector<float> class_scores(num_boxes);

  std::vector<int> box_indices_after_regular_non_max_suppression(
      num_boxes + max_detections);
  std::vector<float> scores_after_regular_non_max_suppression(num_boxes +
                                                              max_detections);

  int size_of_sorted_indices = 0;
  std::vector<int> sorted_indices;
  sorted_indices.resize(max_detections);
  std::vector<float> sorted_values;
  sorted_values.resize(max_detections);

  for (int col = 0; col < num_classes; col++) {
    for (int row = 0; row < num_boxes; row++) {
      // Get scores of boxes corresponding to all anchors for single class
      class_scores[row] =
          *(scores + row * num_classes_with_background + col + label_offset);
    }
    // Perform non-maximal suppression on single class
    std::vector<int> selected;
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
        context, node, op_data, class_scores, &selected,
        num_detections_per_class));
    // Add selected indices from non-max suppression of boxes in this class
    int output_index = size_of_sorted_indices;
    for (int selected_index : selected) {
      box_indices_after_regular_non_max_suppression[output_index] =
          (selected_index * num_classes_with_background + col + label_offset);
      scores_after_regular_non_max_suppression[output_index] =
          class_scores[selected_index];
      output_index++;
    }
    // Sort the max scores among the selected indices
    // Get the indices for top scores
    int num_indices_to_sort = std::min(output_index, max_detections);
    DecreasingPartialArgSort(scores_after_regular_non_max_suppression.data(),
                             output_index, num_indices_to_sort,
                             sorted_indices.data());

    // Copy values to temporary vectors
    for (int row = 0; row < num_indices_to_sort; row++) {
      int temp = sorted_indices[row];
      sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
      sorted_values[row] = scores_after_regular_non_max_suppression[temp];
    }
    // Copy scores and indices from temporary vectors
    for (int row = 0; row < num_indices_to_sort; row++) {
      box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
      scores_after_regular_non_max_suppression[row] = sorted_values[row];
    }
    size_of_sorted_indices = num_indices_to_sort;
  }

  // Allocate output tensors
  for (int output_box_index = 0; output_box_index < max_detections;
       output_box_index++) {
    if (output_box_index < size_of_sorted_indices) {
      const int anchor_index = floor(
          box_indices_after_regular_non_max_suppression[output_box_index] /
          num_classes_with_background);
      const int class_index =
          box_indices_after_regular_non_max_suppression[output_box_index] -
          anchor_index * num_classes_with_background - label_offset;
      const float selected_score =
          scores_after_regular_non_max_suppression[output_box_index];
      // detection_boxes
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
          ReInterpretTensor<const BoxCornerEncoding*>(
              decoded_boxes)[anchor_index];
      // detection_classes
      detection_classes->data.f[output_box_index] = class_index;
      // detection_scores
      detection_scores->data.f[output_box_index] = selected_score;
    } else {
      ReInterpretTensor<BoxCornerEncoding*>(
          detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
      // detection_classes
      detection_classes->data.f[output_box_index] = 0.0f;
      // detection_scores
      detection_scores->data.f[output_box_index] = 0.0f;
    }
  }
  num_detections->data.f[0] = size_of_sorted_indices;
  box_indices_after_regular_non_max_suppression.clear();
  scores_after_regular_non_max_suppression.clear();
  return kTfLiteOk;
}

// This function implements a fast version of Non Maximal Suppression for
// multiple classes where
// 1) we keep the top-k scores for each anchor and
// 2) during NMS, each anchor only uses the highest class score for sorting.
// 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
// instead of O(KN^2) where N is the number of anchors and K the number of
// classes.
TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
                                                   TfLiteNode* node,
                                                   OpData* op_data,
                                                   const float* scores) {
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];

  TfLiteTensor* detection_boxes =
      GetOutput(context, node, kOutputTensorDetectionBoxes);
  TfLiteTensor* detection_classes =
      GetOutput(context, node, kOutputTensorDetectionClasses);
  TfLiteTensor* detection_scores =
      GetOutput(context, node, kOutputTensorDetectionScores);
  TfLiteTensor* num_detections =
      GetOutput(context, node, kOutputTensorNumDetections);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int max_categories_per_anchor = op_data->max_classes_per_detection;
  // The row index offset is 1 if background class is included and 0 otherwise.
  const int label_offset = 1;
  TF_LITE_ENSURE(context, (label_offset != -1));
  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
  const int num_classes_with_background = num_classes + label_offset;
  const int num_categories_per_anchor =
      std::min(max_categories_per_anchor, num_classes);
  std::vector<float> max_scores;
  max_scores.resize(num_boxes);
  std::vector<int> sorted_class_indices;
  sorted_class_indices.resize(num_boxes * num_classes);
  for (int row = 0; row < num_boxes; row++) {
    const float* box_scores =
        scores + row * num_classes_with_background + label_offset;
    int* class_indices = sorted_class_indices.data() + row * num_classes;
    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
                             class_indices);
    max_scores[row] = box_scores[class_indices[0]];
  }
  // Perform non-maximal suppression on max scores
  std::vector<int> selected;
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
      context, node, op_data, max_scores, &selected, op_data->max_detections));
  // Allocate output tensors
  int output_box_index = 0;
  for (const auto& selected_index : selected) {
    const float* box_scores =
        scores + selected_index * num_classes_with_background + label_offset;
    const int* class_indices =
        sorted_class_indices.data() + selected_index * num_classes;

    for (int col = 0; col < num_categories_per_anchor; ++col) {
      int box_offset = num_categories_per_anchor * output_box_index + col;
      // detection_boxes
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
          ReInterpretTensor<const BoxCornerEncoding*>(
              decoded_boxes)[selected_index];
      // detection_classes
      detection_classes->data.f[box_offset] = class_indices[col];
      // detection_scores
      detection_scores->data.f[box_offset] = box_scores[class_indices[col]];
      output_box_index++;
    }
  }
  num_detections->data.f[0] = output_box_index;
  return kTfLiteOk;
}

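// DequantizeClassPredictions expands the quantized class-score tensor into
// the float `scores` temporary using the tensor's quantization parameters.
// Illustrative example (hypothetical parameters): zero_point = 0 and
// scale = 1/255 map a quantized score of 255 to 1.0f.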
void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
                                const int num_boxes,
                                const int num_classes_with_background,
                                const TfLiteTensor* scores) {
  float quant_zero_point =
      static_cast<float>(input_class_predictions->params.zero_point);
  float quant_scale = static_cast<float>(input_class_predictions->params.scale);
  Dequantizer dequantize(quant_zero_point, quant_scale);
  const uint8* scores_quant = GetTensorData<uint8>(input_class_predictions);
  for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
    scores->data.f[idx] = dequantize(scores_quant[idx]);
  }
}

TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                         TfLiteNode* node, OpData* op_data) {
  // Get the input tensors
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* input_class_predictions =
      GetInput(context, node, kInputTensorClassPredictions);
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
                    kBatchSize);
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1));

  const TfLiteTensor* scores;
  switch (input_class_predictions->type) {
    case kTfLiteUInt8: {
      TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index];
      DequantizeClassPredictions(input_class_predictions, num_boxes,
                                 num_classes_with_background, temporary_scores);
      scores = temporary_scores;
    } break;
    case kTfLiteFloat32:
      scores = input_class_predictions;
      break;
    default:
      // Unsupported type.
      return kTfLiteError;
  }
  if (op_data->use_regular_non_max_suppression)
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
        context, node, op_data, GetTensorData<float>(scores)));
  else
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassFastHelper(
        context, node, op_data, GetTensorData<float>(scores)));

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // TODO(chowdhery): Generalize for any batch size
  TF_LITE_ENSURE(context, (kBatchSize == 1));
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  // These two functions correspond to two blocks in the Object Detection
  // model. In the future, we would like to break the custom op into two
  // blocks, which is currently not feasible because we would like to accept
  // quantized inputs and do all calculations in float. Mixed quantized/float
  // calculations are currently not supported in TFLite.

  // This fills in the temporary decoded_boxes tensor
  // by transforming input_box_encodings and input_anchors from
  // CenterSizeEncoding to BoxCornerEncoding.
  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
  // This fills in the output tensors
  // by choosing an effective set of decoded boxes
  // based on Non Maximal Suppression, i.e. selecting the
  // highest-scoring non-overlapping boxes.
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));

  return kTfLiteOk;
}
}  // namespace detection_postprocess

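// Typical client-side registration (a sketch, not part of this file's API
// surface): the custom op name string must match the one used when the model
// was converted, commonly "TFLite_Detection_PostProcess".
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("TFLite_Detection_PostProcess",
//                      tflite::ops::custom::Register_DETECTION_POSTPROCESS());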
TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
  static TfLiteRegistration r = {detection_postprocess::Init,
                                 detection_postprocess::Free,
                                 detection_postprocess::Prepare,
                                 detection_postprocess::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite