1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string.h>
16 #include <numeric>
17 #include <vector>
18 #include "flatbuffers/flexbuffers.h" // TF:flatbuffers
19 #include "tensorflow/lite/c/builtin_op_data.h"
20 #include "tensorflow/lite/c/c_api_internal.h"
21 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
22 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
23 #include "tensorflow/lite/kernels/internal/tensor.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/kernels/op_macros.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace custom {
30 namespace detection_postprocess {
31
// Input tensors
constexpr int kInputTensorBoxEncodings = 0;
constexpr int kInputTensorClassPredictions = 1;
constexpr int kInputTensorAnchors = 2;

// Output tensors
constexpr int kOutputTensorDetectionBoxes = 0;
constexpr int kOutputTensorDetectionClasses = 1;
constexpr int kOutputTensorDetectionScores = 2;
constexpr int kOutputTensorNumDetections = 3;

// Each box is encoded by 4 float coordinates.
constexpr int kNumCoordBox = 4;
// Only batch size 1 is supported (enforced in Eval and DecodeCenterSizeBoxes).
constexpr int kBatchSize = 1;

// Default for OpData::detections_per_class when the flexbuffer option
// "detections_per_class" is absent (see Init).
constexpr int kNumDetectionsPerClass = 100;
47
48 // Object Detection model produces axis-aligned boxes in two formats:
49 // BoxCorner represents the lower left corner (xmin, ymin) and
50 // the upper right corner (xmax, ymax).
51 // CenterSize represents the center (xcenter, ycenter), height and width.
52 // BoxCornerEncoding and CenterSizeEncoding are related as follows:
53 // ycenter = y / y_scale * anchor.h + anchor.y;
54 // xcenter = x / x_scale * anchor.w + anchor.x;
55 // half_h = 0.5*exp(h/ h_scale)) * anchor.h;
56 // half_w = 0.5*exp(w / w_scale)) * anchor.w;
57 // ymin = ycenter - half_h
58 // ymax = ycenter + half_h
59 // xmin = xcenter - half_w
60 // xmax = xcenter + half_w
// Corner-form box: (ymin, xmin) lower-left, (ymax, xmax) upper-right.
// Field order matters: ReInterpretTensor casts raw float tensor data to
// arrays of this struct, so the layout must stay exactly 4 packed floats.
struct BoxCornerEncoding {
  float ymin;
  float xmin;
  float ymax;
  float xmax;
};

// Center-size-form box: center (y, x) plus height h and width w.
// Also reinterpreted directly from raw float tensor data; keep packed.
struct CenterSizeEncoding {
  float y;
  float x;
  float h;
  float w;
};
// We make sure that the memory allocations are contiguous with static assert.
static_assert(sizeof(BoxCornerEncoding) == sizeof(float) * kNumCoordBox,
              "Size of BoxCornerEncoding is 4 float values");
static_assert(sizeof(CenterSizeEncoding) == sizeof(float) * kNumCoordBox,
              "Size of CenterSizeEncoding is 4 float values");
79
// Per-node state, parsed from the custom-op flexbuffer options in Init()
// and consumed by Prepare()/Eval(). Owned by the node (freed in Free()).
struct OpData {
  int max_detections;
  int max_classes_per_detection;  // Fast Non-Max-Suppression
  int detections_per_class;       // Regular Non-Max-Suppression
  float non_max_suppression_score_threshold;
  float intersection_over_union_threshold;
  int num_classes;
  bool use_regular_non_max_suppression;
  // Scaling factors (y, x, h, w) applied when decoding CenterSizeEncoding
  // boxes relative to their anchors.
  CenterSizeEncoding scale_values;
  // Indices of Temporary tensors (allocated via context->AddTensors in Init).
  int decoded_boxes_index;
  int scores_index;
  int active_candidate_index;
};
94
Init(TfLiteContext * context,const char * buffer,size_t length)95 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
96 auto* op_data = new OpData;
97 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
98 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
99 op_data->max_detections = m["max_detections"].AsInt32();
100 op_data->max_classes_per_detection = m["max_classes_per_detection"].AsInt32();
101 if (m["detections_per_class"].IsNull())
102 op_data->detections_per_class = kNumDetectionsPerClass;
103 else
104 op_data->detections_per_class = m["detections_per_class"].AsInt32();
105 if (m["use_regular_nms"].IsNull())
106 op_data->use_regular_non_max_suppression = false;
107 else
108 op_data->use_regular_non_max_suppression = m["use_regular_nms"].AsBool();
109
110 op_data->non_max_suppression_score_threshold =
111 m["nms_score_threshold"].AsFloat();
112 op_data->intersection_over_union_threshold = m["nms_iou_threshold"].AsFloat();
113 op_data->num_classes = m["num_classes"].AsInt32();
114 op_data->scale_values.y = m["y_scale"].AsFloat();
115 op_data->scale_values.x = m["x_scale"].AsFloat();
116 op_data->scale_values.h = m["h_scale"].AsFloat();
117 op_data->scale_values.w = m["w_scale"].AsFloat();
118 context->AddTensors(context, 1, &op_data->decoded_boxes_index);
119 context->AddTensors(context, 1, &op_data->scores_index);
120 context->AddTensors(context, 1, &op_data->active_candidate_index);
121 return op_data;
122 }
123
Free(TfLiteContext * context,void * buffer)124 void Free(TfLiteContext* context, void* buffer) {
125 delete reinterpret_cast<OpData*>(buffer);
126 }
127
128 // TODO(chowdhery): Add to kernel_util.h
SetTensorSizes(TfLiteContext * context,TfLiteTensor * tensor,std::initializer_list<int> values)129 TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
130 std::initializer_list<int> values) {
131 TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
132 int index = 0;
133 for (int v : values) {
134 size->data[index] = v;
135 ++index;
136 }
137 return context->ResizeTensor(context, tensor, size);
138 }
139
Prepare(TfLiteContext * context,TfLiteNode * node)140 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
141 auto* op_data = reinterpret_cast<OpData*>(node->user_data);
142 // Inputs: box_encodings, scores, anchors
143 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
144 const TfLiteTensor* input_box_encodings =
145 GetInput(context, node, kInputTensorBoxEncodings);
146 const TfLiteTensor* input_class_predictions =
147 GetInput(context, node, kInputTensorClassPredictions);
148 const TfLiteTensor* input_anchors =
149 GetInput(context, node, kInputTensorAnchors);
150 TF_LITE_ENSURE_EQ(context, NumDimensions(input_box_encodings), 3);
151 TF_LITE_ENSURE_EQ(context, NumDimensions(input_class_predictions), 3);
152 TF_LITE_ENSURE_EQ(context, NumDimensions(input_anchors), 2);
153 // number of detected boxes
154 const int num_detected_boxes =
155 op_data->max_detections * op_data->max_classes_per_detection;
156
157 // Outputs: detection_boxes, detection_scores, detection_classes,
158 // num_detections
159 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 4);
160 // Output Tensor detection_boxes: size is set to (1, num_detected_boxes, 4)
161 TfLiteTensor* detection_boxes =
162 GetOutput(context, node, kOutputTensorDetectionBoxes);
163 detection_boxes->type = kTfLiteFloat32;
164 SetTensorSizes(context, detection_boxes,
165 {kBatchSize, num_detected_boxes, kNumCoordBox});
166
167 // Output Tensor detection_classes: size is set to (1, num_detected_boxes)
168 TfLiteTensor* detection_classes =
169 GetOutput(context, node, kOutputTensorDetectionClasses);
170 detection_classes->type = kTfLiteFloat32;
171 SetTensorSizes(context, detection_classes, {kBatchSize, num_detected_boxes});
172
173 // Output Tensor detection_scores: size is set to (1, num_detected_boxes)
174 TfLiteTensor* detection_scores =
175 GetOutput(context, node, kOutputTensorDetectionScores);
176 detection_scores->type = kTfLiteFloat32;
177 SetTensorSizes(context, detection_scores, {kBatchSize, num_detected_boxes});
178
179 // Output Tensor num_detections: size is set to 1
180 TfLiteTensor* num_detections =
181 GetOutput(context, node, kOutputTensorNumDetections);
182 num_detections->type = kTfLiteFloat32;
183 // TODO (chowdhery): Make it a scalar when available
184 SetTensorSizes(context, num_detections, {1});
185
186 // Temporary tensors
187 TfLiteIntArrayFree(node->temporaries);
188 node->temporaries = TfLiteIntArrayCreate(3);
189 node->temporaries->data[0] = op_data->decoded_boxes_index;
190 node->temporaries->data[1] = op_data->scores_index;
191 node->temporaries->data[2] = op_data->active_candidate_index;
192
193 // decoded_boxes
194 TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index];
195 decoded_boxes->type = kTfLiteFloat32;
196 decoded_boxes->allocation_type = kTfLiteArenaRw;
197 SetTensorSizes(context, decoded_boxes,
198 {input_box_encodings->dims->data[1], kNumCoordBox});
199
200 // scores
201 TfLiteTensor* scores = &context->tensors[op_data->scores_index];
202 scores->type = kTfLiteFloat32;
203 scores->allocation_type = kTfLiteArenaRw;
204 SetTensorSizes(context, scores,
205 {input_class_predictions->dims->data[1],
206 input_class_predictions->dims->data[2]});
207
208 // active_candidate
209 TfLiteTensor* active_candidate =
210 &context->tensors[op_data->active_candidate_index];
211 active_candidate->type = kTfLiteUInt8;
212 active_candidate->allocation_type = kTfLiteArenaRw;
213 SetTensorSizes(context, active_candidate,
214 {input_box_encodings->dims->data[1]});
215
216 return kTfLiteOk;
217 }
218
219 class Dequantizer {
220 public:
Dequantizer(int zero_point,float scale)221 Dequantizer(int zero_point, float scale)
222 : zero_point_(zero_point), scale_(scale) {}
operator ()(uint8 x)223 float operator()(uint8 x) {
224 return (static_cast<float>(x) - zero_point_) * scale_;
225 }
226
227 private:
228 int zero_point_;
229 float scale_;
230 };
231
DequantizeBoxEncodings(const TfLiteTensor * input_box_encodings,int idx,float quant_zero_point,float quant_scale,CenterSizeEncoding * box_centersize)232 void DequantizeBoxEncodings(const TfLiteTensor* input_box_encodings, int idx,
233 float quant_zero_point, float quant_scale,
234 CenterSizeEncoding* box_centersize) {
235 const uint8* boxes =
236 GetTensorData<uint8>(input_box_encodings) + kNumCoordBox * idx;
237 Dequantizer dequantize(quant_zero_point, quant_scale);
238 box_centersize->y = dequantize(boxes[0]);
239 box_centersize->x = dequantize(boxes[1]);
240 box_centersize->h = dequantize(boxes[2]);
241 box_centersize->w = dequantize(boxes[3]);
242 }
243
244 template <class T>
ReInterpretTensor(const TfLiteTensor * tensor)245 T ReInterpretTensor(const TfLiteTensor* tensor) {
246 // TODO (chowdhery): check float
247 const float* tensor_base = tensor->data.f;
248 return reinterpret_cast<T>(tensor_base);
249 }
250
251 template <class T>
ReInterpretTensor(TfLiteTensor * tensor)252 T ReInterpretTensor(TfLiteTensor* tensor) {
253 // TODO (chowdhery): check float
254 float* tensor_base = tensor->data.f;
255 return reinterpret_cast<T>(tensor_base);
256 }
257
DecodeCenterSizeBoxes(TfLiteContext * context,TfLiteNode * node,OpData * op_data)258 TfLiteStatus DecodeCenterSizeBoxes(TfLiteContext* context, TfLiteNode* node,
259 OpData* op_data) {
260 // Parse input tensor boxencodings
261 const TfLiteTensor* input_box_encodings =
262 GetInput(context, node, kInputTensorBoxEncodings);
263 TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[0], kBatchSize);
264 const int num_boxes = input_box_encodings->dims->data[1];
265 TF_LITE_ENSURE_EQ(context, input_box_encodings->dims->data[2], kNumCoordBox);
266 const TfLiteTensor* input_anchors =
267 GetInput(context, node, kInputTensorAnchors);
268
269 // Decode the boxes to get (ymin, xmin, ymax, xmax) based on the anchors
270 CenterSizeEncoding box_centersize;
271 CenterSizeEncoding scale_values = op_data->scale_values;
272 CenterSizeEncoding anchor;
273 for (int idx = 0; idx < num_boxes; ++idx) {
274 switch (input_box_encodings->type) {
275 // Quantized
276 case kTfLiteUInt8:
277 DequantizeBoxEncodings(
278 input_box_encodings, idx,
279 static_cast<float>(input_box_encodings->params.zero_point),
280 static_cast<float>(input_box_encodings->params.scale),
281 &box_centersize);
282 DequantizeBoxEncodings(
283 input_anchors, idx,
284 static_cast<float>(input_anchors->params.zero_point),
285 static_cast<float>(input_anchors->params.scale), &anchor);
286 break;
287 // Float
288 case kTfLiteFloat32:
289 box_centersize = ReInterpretTensor<const CenterSizeEncoding*>(
290 input_box_encodings)[idx];
291 anchor =
292 ReInterpretTensor<const CenterSizeEncoding*>(input_anchors)[idx];
293 break;
294 default:
295 // Unsupported type.
296 return kTfLiteError;
297 }
298
299 float ycenter = box_centersize.y / scale_values.y * anchor.h + anchor.y;
300 float xcenter = box_centersize.x / scale_values.x * anchor.w + anchor.x;
301 float half_h =
302 0.5f * static_cast<float>(std::exp(box_centersize.h / scale_values.h)) *
303 anchor.h;
304 float half_w =
305 0.5f * static_cast<float>(std::exp(box_centersize.w / scale_values.w)) *
306 anchor.w;
307 TfLiteTensor* decoded_boxes =
308 &context->tensors[op_data->decoded_boxes_index];
309 auto& box = ReInterpretTensor<BoxCornerEncoding*>(decoded_boxes)[idx];
310 box.ymin = ycenter - half_h;
311 box.xmin = xcenter - half_w;
312 box.ymax = ycenter + half_h;
313 box.xmax = xcenter + half_w;
314 }
315 return kTfLiteOk;
316 }
317
// Writes into indices[0..num_values) a permutation of [0, num_values) whose
// first num_to_sort entries index `values` in decreasing order; the remaining
// entries are in unspecified order.
void DecreasingPartialArgSort(const float* values, int num_values,
                              int num_to_sort, int* indices) {
  std::iota(indices, indices + num_values, 0);
  const auto by_descending_value = [values](const int lhs, const int rhs) {
    return values[lhs] > values[rhs];
  };
  std::partial_sort(indices, indices + num_to_sort, indices + num_values,
                    by_descending_value);
}
325
// Appends to keep_values/keep_indices the scores (and their positions) in
// `values` that are >= threshold, preserving input order.
// Fix: the loop index was a signed int compared against values.size()
// (size_t), a signed/unsigned mismatch; it is now size_t with an explicit
// narrowing cast when storing the index.
void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
                                         const float threshold,
                                         std::vector<float>* keep_values,
                                         std::vector<int>* keep_indices) {
  for (size_t i = 0; i < values.size(); i++) {
    if (values[i] >= threshold) {
      keep_values->emplace_back(values[i]);
      keep_indices->emplace_back(static_cast<int>(i));
    }
  }
}
337
ValidateBoxes(const TfLiteTensor * decoded_boxes,const int num_boxes)338 bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
339 for (int i = 0; i < num_boxes; ++i) {
340 // ymax>=ymin, xmax>=xmin
341 auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
342 if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
343 return false;
344 }
345 }
346 return true;
347 }
348
ComputeIntersectionOverUnion(const TfLiteTensor * decoded_boxes,const int i,const int j)349 float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
350 const int i, const int j) {
351 auto& box_i = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
352 auto& box_j = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[j];
353 const float area_i = (box_i.ymax - box_i.ymin) * (box_i.xmax - box_i.xmin);
354 const float area_j = (box_j.ymax - box_j.ymin) * (box_j.xmax - box_j.xmin);
355 if (area_i <= 0 || area_j <= 0) return 0.0;
356 const float intersection_ymin = std::max<float>(box_i.ymin, box_j.ymin);
357 const float intersection_xmin = std::max<float>(box_i.xmin, box_j.xmin);
358 const float intersection_ymax = std::min<float>(box_i.ymax, box_j.ymax);
359 const float intersection_xmax = std::min<float>(box_i.xmax, box_j.xmax);
360 const float intersection_area =
361 std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
362 std::max<float>(intersection_xmax - intersection_xmin, 0.0);
363 return intersection_area / (area_i + area_j - intersection_area);
364 }
365
366 // NonMaxSuppressionSingleClass() prunes out the box locations with high overlap
367 // before selecting the highest scoring boxes (max_detections in number)
368 // It assumes all boxes are good in beginning and sorts based on the scores.
369 // If lower-scoring box has too much overlap with a higher-scoring box,
370 // we get rid of the lower-scoring box.
371 // Complexity is O(N^2) pairwise comparison between boxes
TfLiteStatus NonMaxSuppressionSingleClassHelper(
    TfLiteContext* context, TfLiteNode* node, OpData* op_data,
    const std::vector<float>& scores, std::vector<int>* selected,
    int max_detections) {
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];
  const int num_boxes = input_box_encodings->dims->data[1];
  const float non_max_suppression_score_threshold =
      op_data->non_max_suppression_score_threshold;
  const float intersection_over_union_threshold =
      op_data->intersection_over_union_threshold;
  // Maximum detections should be positive.
  TF_LITE_ENSURE(context, (max_detections >= 0));
  // intersection_over_union_threshold should be positive
  // and should be less than 1.
  TF_LITE_ENSURE(context, (intersection_over_union_threshold > 0.0f) &&
                              (intersection_over_union_threshold <= 1.0f));
  // Validate boxes
  TF_LITE_ENSURE(context, ValidateBoxes(decoded_boxes, num_boxes));

  // threshold scores: keep only detections scoring at least
  // non_max_suppression_score_threshold; keep_indices maps back into `scores`.
  std::vector<int> keep_indices;
  // TODO (chowdhery): Remove the dynamic allocation and replace it
  // with temporaries, esp for std::vector<float>
  std::vector<float> keep_scores;
  SelectDetectionsAboveScoreThreshold(
      scores, non_max_suppression_score_threshold, &keep_scores, &keep_indices);

  // Rank the surviving detections by score, highest first.
  int num_scores_kept = keep_scores.size();
  std::vector<int> sorted_indices;
  sorted_indices.resize(num_scores_kept);
  DecreasingPartialArgSort(keep_scores.data(), num_scores_kept, num_scores_kept,
                           sorted_indices.data());
  const int num_boxes_kept = num_scores_kept;
  const int output_size = std::min(num_boxes_kept, max_detections);
  selected->clear();
  // active_candidate[i] == 1 means box i has not yet been selected or
  // suppressed; the scratch tensor was sized to num_boxes in Prepare().
  TfLiteTensor* active_candidate =
      &context->tensors[op_data->active_candidate_index];
  TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes);
  int num_active_candidate = num_boxes_kept;
  uint8_t* active_box_candidate = (active_candidate->data.uint8);
  for (int row = 0; row < num_boxes_kept; row++) {
    active_box_candidate[row] = 1;
  }

  // Greedy NMS in score order: take the best remaining box, then suppress
  // every lower-scoring box whose IoU with it exceeds the threshold.
  // O(N^2) pairwise comparisons in the worst case.
  for (int i = 0; i < num_boxes_kept; ++i) {
    // NOTE(review): selected->size() (size_t) is compared against output_size
    // (int); output_size >= 0 is guaranteed above, so the comparison is safe.
    if (num_active_candidate == 0 || selected->size() >= output_size) break;
    if (active_box_candidate[i] == 1) {
      selected->push_back(keep_indices[sorted_indices[i]]);
      active_box_candidate[i] = 0;
      num_active_candidate--;
    } else {
      continue;
    }
    for (int j = i + 1; j < num_boxes_kept; ++j) {
      if (active_box_candidate[j] == 1) {
        float intersection_over_union = ComputeIntersectionOverUnion(
            decoded_boxes, keep_indices[sorted_indices[i]],
            keep_indices[sorted_indices[j]]);

        if (intersection_over_union > intersection_over_union_threshold) {
          active_box_candidate[j] = 0;
          num_active_candidate--;
        }
      }
    }
  }
  return kTfLiteOk;
}
443
444 // This function implements a regular version of Non Maximal Suppression (NMS)
445 // for multiple classes where
446 // 1) we do NMS separately for each class across all anchors and
447 // 2) keep only the highest anchor scores across all classes
448 // 3) The worst runtime of the regular NMS is O(K*N^2)
449 // where N is the number of anchors and K the number of
450 // classes.
NonMaxSuppressionMultiClassRegularHelper(TfLiteContext * context,TfLiteNode * node,OpData * op_data,const float * scores)451 TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
452 TfLiteNode* node,
453 OpData* op_data,
454 const float* scores) {
455 const TfLiteTensor* input_box_encodings =
456 GetInput(context, node, kInputTensorBoxEncodings);
457 const TfLiteTensor* decoded_boxes =
458 &context->tensors[op_data->decoded_boxes_index];
459
460 TfLiteTensor* detection_boxes =
461 GetOutput(context, node, kOutputTensorDetectionBoxes);
462 TfLiteTensor* detection_classes =
463 GetOutput(context, node, kOutputTensorDetectionClasses);
464 TfLiteTensor* detection_scores =
465 GetOutput(context, node, kOutputTensorDetectionScores);
466 TfLiteTensor* num_detections =
467 GetOutput(context, node, kOutputTensorNumDetections);
468
469 const int num_boxes = input_box_encodings->dims->data[1];
470 const int num_classes = op_data->num_classes;
471 const int num_detections_per_class = op_data->detections_per_class;
472 const int max_detections = op_data->max_detections;
473 // The row index offset is 1 if background class is included and 0 otherwise.
474 const int label_offset = 1;
475 TF_LITE_ENSURE(context, label_offset != -1);
476 TF_LITE_ENSURE(context, num_detections_per_class > 0);
477 const int num_classes_with_background = num_classes + label_offset;
478
479 // For each class, perform non-max suppression.
480 std::vector<float> class_scores(num_boxes);
481
482 std::vector<int> box_indices_after_regular_non_max_suppression(
483 num_boxes + max_detections);
484 std::vector<float> scores_after_regular_non_max_suppression(num_boxes +
485 max_detections);
486
487 int size_of_sorted_indices = 0;
488 std::vector<int> sorted_indices;
489 sorted_indices.resize(max_detections);
490 std::vector<float> sorted_values;
491 sorted_values.resize(max_detections);
492
493 for (int col = 0; col < num_classes; col++) {
494 for (int row = 0; row < num_boxes; row++) {
495 // Get scores of boxes corresponding to all anchors for single class
496 class_scores[row] =
497 *(scores + row * num_classes_with_background + col + label_offset);
498 }
499 // Perform non-maximal suppression on single class
500 std::vector<int> selected;
501 TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
502 context, node, op_data, class_scores, &selected,
503 num_detections_per_class));
504 // Add selected indices from non-max suppression of boxes in this class
505 int output_index = size_of_sorted_indices;
506 for (int selected_index : selected) {
507 box_indices_after_regular_non_max_suppression[output_index] =
508 (selected_index * num_classes_with_background + col + label_offset);
509 scores_after_regular_non_max_suppression[output_index] =
510 class_scores[selected_index];
511 output_index++;
512 }
513 // Sort the max scores among the selected indices
514 // Get the indices for top scores
515 int num_indices_to_sort = std::min(output_index, max_detections);
516 DecreasingPartialArgSort(scores_after_regular_non_max_suppression.data(),
517 output_index, num_indices_to_sort,
518 sorted_indices.data());
519
520 // Copy values to temporary vectors
521 for (int row = 0; row < num_indices_to_sort; row++) {
522 int temp = sorted_indices[row];
523 sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
524 sorted_values[row] = scores_after_regular_non_max_suppression[temp];
525 }
526 // Copy scores and indices from temporary vectors
527 for (int row = 0; row < num_indices_to_sort; row++) {
528 box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
529 scores_after_regular_non_max_suppression[row] = sorted_values[row];
530 }
531 size_of_sorted_indices = num_indices_to_sort;
532 }
533
534 // Allocate output tensors
535 for (int output_box_index = 0; output_box_index < max_detections;
536 output_box_index++) {
537 if (output_box_index < size_of_sorted_indices) {
538 const int anchor_index = floor(
539 box_indices_after_regular_non_max_suppression[output_box_index] /
540 num_classes_with_background);
541 const int class_index =
542 box_indices_after_regular_non_max_suppression[output_box_index] -
543 anchor_index * num_classes_with_background - label_offset;
544 const float selected_score =
545 scores_after_regular_non_max_suppression[output_box_index];
546 // detection_boxes
547 ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[output_box_index] =
548 ReInterpretTensor<const BoxCornerEncoding*>(
549 decoded_boxes)[anchor_index];
550 // detection_classes
551 detection_classes->data.f[output_box_index] = class_index;
552 // detection_scores
553 detection_scores->data.f[output_box_index] = selected_score;
554 } else {
555 ReInterpretTensor<BoxCornerEncoding*>(
556 detection_boxes)[output_box_index] = {0.0f, 0.0f, 0.0f, 0.0f};
557 // detection_classes
558 detection_classes->data.f[output_box_index] = 0.0f;
559 // detection_scores
560 detection_scores->data.f[output_box_index] = 0.0f;
561 }
562 }
563 num_detections->data.f[0] = size_of_sorted_indices;
564 box_indices_after_regular_non_max_suppression.clear();
565 scores_after_regular_non_max_suppression.clear();
566 return kTfLiteOk;
567 }
568
569 // This function implements a fast version of Non Maximal Suppression for
570 // multiple classes where
571 // 1) we keep the top-k scores for each anchor and
572 // 2) during NMS, each anchor only uses the highest class score for sorting.
573 // 3) Compared to standard NMS, the worst runtime of this version is O(N^2)
574 // instead of O(KN^2) where N is the number of anchors and K the number of
575 // classes.
TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
                                                   TfLiteNode* node,
                                                   OpData* op_data,
                                                   const float* scores) {
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* decoded_boxes =
      &context->tensors[op_data->decoded_boxes_index];

  TfLiteTensor* detection_boxes =
      GetOutput(context, node, kOutputTensorDetectionBoxes);
  TfLiteTensor* detection_classes =
      GetOutput(context, node, kOutputTensorDetectionClasses);
  TfLiteTensor* detection_scores =
      GetOutput(context, node, kOutputTensorDetectionScores);
  TfLiteTensor* num_detections =
      GetOutput(context, node, kOutputTensorNumDetections);

  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  const int max_categories_per_anchor = op_data->max_classes_per_detection;
  // The row index offset is 1 if background class is included and 0 otherwise.
  const int label_offset = 1;
  TF_LITE_ENSURE(context, (label_offset != -1));
  TF_LITE_ENSURE(context, (max_categories_per_anchor > 0));
  const int num_classes_with_background = num_classes + label_offset;
  const int num_categories_per_anchor =
      std::min(max_categories_per_anchor, num_classes);
  // For every anchor, find its top num_categories_per_anchor classes and its
  // single best class score; that best score drives NMS ordering.
  std::vector<float> max_scores;
  max_scores.resize(num_boxes);
  std::vector<int> sorted_class_indices;
  sorted_class_indices.resize(num_boxes * num_classes);
  for (int row = 0; row < num_boxes; row++) {
    // Skip the background column via label_offset.
    const float* box_scores =
        scores + row * num_classes_with_background + label_offset;
    int* class_indices = sorted_class_indices.data() + row * num_classes;
    DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
                             class_indices);
    max_scores[row] = box_scores[class_indices[0]];
  }
  // Perform non-maximal suppression on max scores
  std::vector<int> selected;
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
      context, node, op_data, max_scores, &selected, op_data->max_detections));
  // Write one output entry per (selected anchor, top class) pair.
  // Unwritten trailing slots of the output tensors are left untouched.
  int output_box_index = 0;
  for (const auto& selected_index : selected) {
    const float* box_scores =
        scores + selected_index * num_classes_with_background + label_offset;
    const int* class_indices =
        sorted_class_indices.data() + selected_index * num_classes;

    for (int col = 0; col < num_categories_per_anchor; ++col) {
      int box_offset = num_categories_per_anchor * output_box_index + col;
      // detection_boxes: same decoded box repeated for each reported class.
      ReInterpretTensor<BoxCornerEncoding*>(detection_boxes)[box_offset] =
          ReInterpretTensor<const BoxCornerEncoding*>(
              decoded_boxes)[selected_index];
      // detection_classes
      detection_classes->data.f[box_offset] = class_indices[col];
      // detection_scores
      detection_scores->data.f[box_offset] = box_scores[class_indices[col]];
      // NOTE(review): incremented inside the inner loop, so box_offset grows
      // faster than one per anchor when num_categories_per_anchor > 1 —
      // presumably intentional; verify against consumers of the output.
      output_box_index++;
    }
  }
  num_detections->data.f[0] = output_box_index;
  return kTfLiteOk;
}
644
DequantizeClassPredictions(const TfLiteTensor * input_class_predictions,const int num_boxes,const int num_classes_with_background,const TfLiteTensor * scores)645 void DequantizeClassPredictions(const TfLiteTensor* input_class_predictions,
646 const int num_boxes,
647 const int num_classes_with_background,
648 const TfLiteTensor* scores) {
649 float quant_zero_point =
650 static_cast<float>(input_class_predictions->params.zero_point);
651 float quant_scale = static_cast<float>(input_class_predictions->params.scale);
652 Dequantizer dequantize(quant_zero_point, quant_scale);
653 const uint8* scores_quant = GetTensorData<uint8>(input_class_predictions);
654 for (int idx = 0; idx < num_boxes * num_classes_with_background; ++idx) {
655 scores->data.f[idx] = dequantize(scores_quant[idx]);
656 }
657 }
658
// Validates the class-prediction tensor, obtains float scores (dequantizing
// into the scratch tensor if the input is uint8), then dispatches to the
// regular or fast multi-class NMS depending on the op options.
TfLiteStatus NonMaxSuppressionMultiClass(TfLiteContext* context,
                                         TfLiteNode* node, OpData* op_data) {
  // Get the input tensors
  const TfLiteTensor* input_box_encodings =
      GetInput(context, node, kInputTensorBoxEncodings);
  const TfLiteTensor* input_class_predictions =
      GetInput(context, node, kInputTensorClassPredictions);
  const int num_boxes = input_box_encodings->dims->data[1];
  const int num_classes = op_data->num_classes;
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[0],
                    kBatchSize);
  TF_LITE_ENSURE_EQ(context, input_class_predictions->dims->data[1], num_boxes);
  const int num_classes_with_background =
      input_class_predictions->dims->data[2];

  // The class predictions must carry exactly one extra background column.
  TF_LITE_ENSURE(context, (num_classes_with_background == num_classes + 1));

  const TfLiteTensor* scores;
  switch (input_class_predictions->type) {
    case kTfLiteUInt8: {
      // Dequantize into the float scratch tensor sized in Prepare().
      TfLiteTensor* temporary_scores = &context->tensors[op_data->scores_index];
      DequantizeClassPredictions(input_class_predictions, num_boxes,
                                 num_classes_with_background, temporary_scores);
      scores = temporary_scores;
    } break;
    case kTfLiteFloat32:
      scores = input_class_predictions;
      break;
    default:
      // Unsupported type.
      return kTfLiteError;
  }
  if (op_data->use_regular_non_max_suppression)
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassRegularHelper(
        context, node, op_data, GetTensorData<float>(scores)));
  else
    TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClassFastHelper(
        context, node, op_data, GetTensorData<float>(scores)));

  return kTfLiteOk;
}
700
// Kernel entry point: decode boxes, then select final detections with NMS.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // TODO(chowdhery): Generalize for any batch size
  TF_LITE_ENSURE(context, (kBatchSize == 1));
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  // These two functions correspond to two blocks in the Object Detection model.
  // In future, we would like to break the custom op in two blocks, which is
  // currently not feasible because we would like to input quantized inputs
  // and do all calculations in float. Mixed quantized/float calculations are
  // currently not supported in TFLite.

  // This fills in temporary decoded_boxes
  // by transforming input_box_encodings and input_anchors from
  // CenterSizeEncodings to BoxCornerEncoding
  TF_LITE_ENSURE_STATUS(DecodeCenterSizeBoxes(context, node, op_data));
  // This fills in the output tensors
  // by choosing effective set of decoded boxes
  // based on Non Maximal Suppression, i.e. selecting
  // highest scoring non-overlapping boxes.
  TF_LITE_ENSURE_STATUS(NonMaxSuppressionMultiClass(context, node, op_data));

  return kTfLiteOk;
}
723 } // namespace detection_postprocess
724
// Returns the registration for the custom op; the TfLiteRegistration is a
// function-local static so the pointer stays valid for the process lifetime.
TfLiteRegistration* Register_DETECTION_POSTPROCESS() {
  static TfLiteRegistration r = {detection_postprocess::Init,
                                 detection_postprocess::Free,
                                 detection_postprocess::Prepare,
                                 detection_postprocess::Eval};
  return &r;
}
732
733 } // namespace custom
734 } // namespace ops
735 } // namespace tflite
736