• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "GenerateProposals.h"
20 
21 #include <algorithm>
22 #include <cfloat>
23 #include <cmath>
24 #include <numeric>
25 #include <utility>
26 #include <vector>
27 
28 #include "OperationResolver.h"
29 #include "OperationsExecutionUtils.h"
30 #include "Tracing.h"
31 
32 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
33 #include "CpuOperationUtils.h"
34 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
35 
36 namespace android {
37 namespace nn {
38 namespace bbox_ops {
39 
40 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
41 namespace {
42 
// Axis-aligned box stored as two opposite corners, (x1, y1) and (x2, y2).
// Field order matters: the box is also built with positional/designated initializers.
struct BoxEncodingCorner {
    float x1;
    float y1;
    float x2;
    float y2;
};
// Axis-aligned box stored as its size (w, h) followed by its center point (x, y).
// Field order matters: the box is also built with positional/designated initializers.
struct BoxEncodingCenter {
    float w;
    float h;
    float x;
    float y;
};
toBoxEncodingCorner(const BoxEncodingCenter & ctr)49 BoxEncodingCorner toBoxEncodingCorner(const BoxEncodingCenter& ctr) {
50     return {.x1 = ctr.x - ctr.w / 2,
51             .y1 = ctr.y - ctr.h / 2,
52             .x2 = ctr.x + ctr.w / 2,
53             .y2 = ctr.y + ctr.h / 2};
54 }
toBoxEncodingCenter(const BoxEncodingCorner & cnr)55 BoxEncodingCenter toBoxEncodingCenter(const BoxEncodingCorner& cnr) {
56     return {.w = cnr.x2 - cnr.x1,
57             .h = cnr.y2 - cnr.y1,
58             .x = (cnr.x1 + cnr.x2) / 2,
59             .y = (cnr.y1 + cnr.y2) / 2};
60 }
61 
bboxTransformFloat32(const float * roiData,const Shape & roiShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape &,const float * imageInfoData,const Shape & imageInfoDataShape,float * outputData,const Shape &)62 inline bool bboxTransformFloat32(const float* roiData, const Shape& roiShape,
63                                  const float* bboxDeltasData, const Shape& bboxDeltasShape,
64                                  const int32_t* batchesData, const Shape& /*batchesShape*/,
65                                  const float* imageInfoData, const Shape& imageInfoDataShape,
66                                  float* outputData, const Shape& /*outputShape*/) {
67     const uint32_t roiLength = 4;
68     const uint32_t imageLength = 2;
69 
70     uint32_t numClasses = getSizeOfDimension(bboxDeltasShape, 1) / roiLength;
71     uint32_t numBatches = getSizeOfDimension(imageInfoDataShape, 0);
72 
73     const float* roiDataEnd = roiData + getNumberOfElements(roiShape);
74     const float* deltas = bboxDeltasData;
75     float* outPtr = outputData;
76     uint32_t roiIndex = 0;
77     for (const float* roiBase = roiData; roiBase < roiDataEnd; roiBase += roiLength, roiIndex++) {
78         uint32_t batchIndex = batchesData[roiIndex];
79         // Check for malformed data
80         // 1. Invalid batch id
81         // 2. Invalid region: x2 < x1 || y2 < y1
82         NN_RET_CHECK_GE(batchIndex, 0u);
83         NN_RET_CHECK_LT(batchIndex, numBatches);
84         NN_RET_CHECK_LE(roiBase[0], roiBase[2]);
85         NN_RET_CHECK_LE(roiBase[1], roiBase[3]);
86 
87         const float* imageInfoBase = imageInfoData + batchIndex * imageLength;
88         float imageHeight = imageInfoBase[0];
89         float imageWidth = imageInfoBase[1];
90         auto roiBefore = toBoxEncodingCenter(
91                 {.x1 = roiBase[0], .y1 = roiBase[1], .x2 = roiBase[2], .y2 = roiBase[3]});
92         for (uint32_t i = 0; i < numClasses; i++) {
93             auto roiAfter = toBoxEncodingCorner({.w = std::exp(deltas[2]) * roiBefore.w,
94                                                  .h = std::exp(deltas[3]) * roiBefore.h,
95                                                  .x = roiBefore.x + deltas[0] * roiBefore.w,
96                                                  .y = roiBefore.y + deltas[1] * roiBefore.h});
97             BoxEncodingCorner cliped = {.x1 = std::min(std::max(roiAfter.x1, 0.0f), imageWidth),
98                                         .y1 = std::min(std::max(roiAfter.y1, 0.0f), imageHeight),
99                                         .x2 = std::min(std::max(roiAfter.x2, 0.0f), imageWidth),
100                                         .y2 = std::min(std::max(roiAfter.y2, 0.0f), imageHeight)};
101             outPtr[0] = cliped.x1;
102             outPtr[1] = cliped.y1;
103             outPtr[2] = cliped.x2;
104             outPtr[3] = cliped.y2;
105             deltas += roiLength;
106             outPtr += roiLength;
107         }
108     }
109     return true;
110 }
111 
bboxTransformFloat16(const _Float16 * roiData,const Shape & roiShape,const _Float16 * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const _Float16 * imageInfoData,const Shape & imageInfoDataShape,_Float16 * outputData,const Shape & outputShape)112 inline bool bboxTransformFloat16(const _Float16* roiData, const Shape& roiShape,
113                                  const _Float16* bboxDeltasData, const Shape& bboxDeltasShape,
114                                  const int32_t* batchesData, const Shape& batchesShape,
115                                  const _Float16* imageInfoData, const Shape& imageInfoDataShape,
116                                  _Float16* outputData, const Shape& outputShape) {
117     std::vector<float> roi_float32(getNumberOfElements(roiShape));
118     convertFloat16ToFloat32(roiData, &roi_float32);
119     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
120     convertFloat16ToFloat32(bboxDeltasData, &delta_float32);
121     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
122     convertFloat16ToFloat32(imageInfoData, &imageInfo_float32);
123     std::vector<float> output_float32(getNumberOfElements(outputShape));
124     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
125                                       bboxDeltasShape, batchesData, batchesShape,
126                                       imageInfo_float32.data(), imageInfoDataShape,
127                                       output_float32.data(), outputShape));
128     convertFloat32ToFloat16(output_float32, outputData);
129     return true;
130 }
131 
bboxTransformQuant(const uint16_t * roiData,const Shape & roiShape,const uint8_t * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const uint16_t * imageInfoData,const Shape & imageInfoDataShape,uint16_t * outputData,const Shape & outputShape)132 inline bool bboxTransformQuant(const uint16_t* roiData, const Shape& roiShape,
133                                const uint8_t* bboxDeltasData, const Shape& bboxDeltasShape,
134                                const int32_t* batchesData, const Shape& batchesShape,
135                                const uint16_t* imageInfoData, const Shape& imageInfoDataShape,
136                                uint16_t* outputData, const Shape& outputShape) {
137     std::vector<float> roi_float32(getNumberOfElements(roiShape));
138     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
139     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
140     convertQuantToFloat32(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
141                           &delta_float32);
142     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
143     convertQuantToFloat32(imageInfoData, imageInfoDataShape.scale, imageInfoDataShape.offset,
144                           &imageInfo_float32);
145     std::vector<float> output_float32(getNumberOfElements(outputShape));
146     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
147                                       bboxDeltasShape, batchesData, batchesShape,
148                                       imageInfo_float32.data(), imageInfoDataShape,
149                                       output_float32.data(), outputShape));
150     convertFloat32ToQuant(output_float32, outputShape.scale, outputShape.offset, outputData);
151     return true;
152 }
153 
bboxTransformQuant(const uint16_t * roiData,const Shape & roiShape,const int8_t * bboxDeltasData,const Shape & bboxDeltasShape,const int32_t * batchesData,const Shape & batchesShape,const uint16_t * imageInfoData,const Shape & imageInfoDataShape,uint16_t * outputData,const Shape & outputShape)154 inline bool bboxTransformQuant(const uint16_t* roiData, const Shape& roiShape,
155                                const int8_t* bboxDeltasData, const Shape& bboxDeltasShape,
156                                const int32_t* batchesData, const Shape& batchesShape,
157                                const uint16_t* imageInfoData, const Shape& imageInfoDataShape,
158                                uint16_t* outputData, const Shape& outputShape) {
159     std::vector<float> roi_float32(getNumberOfElements(roiShape));
160     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
161     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
162     convertQuantToFloat32<int8_t>(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
163                                   &delta_float32);
164     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoDataShape));
165     convertQuantToFloat32(imageInfoData, imageInfoDataShape.scale, imageInfoDataShape.offset,
166                           &imageInfo_float32);
167     std::vector<float> output_float32(getNumberOfElements(outputShape));
168     NN_RET_CHECK(bboxTransformFloat32(roi_float32.data(), roiShape, delta_float32.data(),
169                                       bboxDeltasShape, batchesData, batchesShape,
170                                       imageInfo_float32.data(), imageInfoDataShape,
171                                       output_float32.data(), outputShape));
172     convertFloat32ToQuant(output_float32, outputShape.scale, outputShape.offset, outputData);
173     return true;
174 }
175 
// Returns the intersection-over-union (IoU) of two axis-aligned boxes.
//
// roi1 and roi2 each point at four floats laid out as x1, y1, x2, y2. The intersection extent is
// clamped at zero on each axis, so disjoint boxes yield 0.
//
// When the union has no positive area (e.g. both boxes are degenerate points or lines) the ratio
// would be 0/0; in that case 0 is returned instead of NaN so callers can compare the result
// against a threshold safely.
float getIoUAxisAligned(const float* roi1, const float* roi2) {
    const float area1 = (roi1[2] - roi1[0]) * (roi1[3] - roi1[1]);
    const float area2 = (roi2[2] - roi2[0]) * (roi2[3] - roi2[1]);
    const float x1 = std::max(roi1[0], roi2[0]);
    const float x2 = std::min(roi1[2], roi2[2]);
    const float y1 = std::max(roi1[1], roi2[1]);
    const float y2 = std::min(roi1[3], roi2[3]);
    const float w = std::max(x2 - x1, 0.0f);
    const float h = std::max(y2 - y1, 0.0f);
    const float areaIntersect = w * h;
    const float areaUnion = area1 + area2 - areaIntersect;
    // Guard against division by zero for boxes without area.
    if (areaUnion <= 0.0f) {
        return 0.0f;
    }
    return areaIntersect / areaUnion;
}
190 
191 }  // namespace
192 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
193 
194 namespace axis_aligned_bbox_transform {
195 
196 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
prepare(IOperationExecutionContext * context)197 bool prepare(IOperationExecutionContext* context) {
198     Shape roiShape = context->getInputShape(kRoiTensor);
199     Shape bboxDeltasShape = context->getInputShape(kDeltaTensor);
200     Shape batchesShape = context->getInputShape(kBatchesTensor);
201     Shape imageInfoShape = context->getInputShape(kImageInfoTensor);
202     Shape outputShape = context->getOutputShape(kOutputTensor);
203 
204     NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2u);
205     NN_RET_CHECK_EQ(getNumberOfDimensions(bboxDeltasShape), 2u);
206     NN_RET_CHECK_EQ(getNumberOfDimensions(batchesShape), 1u);
207     NN_RET_CHECK_EQ(getNumberOfDimensions(imageInfoShape), 2u);
208 
209     // Only numRois can be zero.
210     const uint32_t kRoiDim = 4;
211     uint32_t numRois = getSizeOfDimension(roiShape, 0);
212     uint32_t numClasses = getSizeOfDimension(bboxDeltasShape, 1) / kRoiDim;
213     uint32_t numBatches = getSizeOfDimension(imageInfoShape, 0);
214     NN_RET_CHECK_GT(numClasses, 0u);
215     NN_RET_CHECK_GT(numBatches, 0u);
216     NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), kRoiDim);
217     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 0), numRois);
218     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 1), kRoiDim * numClasses);
219     NN_RET_CHECK_EQ(getSizeOfDimension(batchesShape, 0), numRois);
220     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoShape, 1), 2u);
221 
222     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
223         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
224         NN_RET_CHECK_EQ(roiShape.offset, 0);
225         NN_RET_CHECK_EQ(imageInfoShape.scale, 0.125f);
226         NN_RET_CHECK_EQ(imageInfoShape.offset, 0);
227     }
228 
229     outputShape.type = roiShape.type;
230     outputShape.dimensions = {numRois, numClasses * kRoiDim};
231     outputShape.scale = 0.f;
232     outputShape.offset = 0;
233     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
234         outputShape.scale = 0.125f;
235     }
236     NN_RET_CHECK(context->setOutputShape(kOutputTensor, outputShape));
237     return true;
238 }
239 
execute(IOperationExecutionContext * context)240 bool execute(IOperationExecutionContext* context) {
241     NNTRACE_TRANS("axisAlignedBBoxTransform");
242     // Bypass execution in the case of zero-sized input.
243     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
244     switch (context->getInputType(kRoiTensor)) {
245         case OperandType::TENSOR_FLOAT16: {
246             return bboxTransformFloat16(context->getInputBuffer<_Float16>(kRoiTensor),
247                                         context->getInputShape(kRoiTensor),
248                                         context->getInputBuffer<_Float16>(kDeltaTensor),
249                                         context->getInputShape(kDeltaTensor),
250                                         context->getInputBuffer<int32_t>(kBatchesTensor),
251                                         context->getInputShape(kBatchesTensor),
252                                         context->getInputBuffer<_Float16>(kImageInfoTensor),
253                                         context->getInputShape(kImageInfoTensor),
254                                         context->getOutputBuffer<_Float16>(kOutputTensor),
255                                         context->getOutputShape(kOutputTensor));
256         }
257         case OperandType::TENSOR_FLOAT32: {
258             return bboxTransformFloat32(context->getInputBuffer<float>(kRoiTensor),
259                                         context->getInputShape(kRoiTensor),
260                                         context->getInputBuffer<float>(kDeltaTensor),
261                                         context->getInputShape(kDeltaTensor),
262                                         context->getInputBuffer<int32_t>(kBatchesTensor),
263                                         context->getInputShape(kBatchesTensor),
264                                         context->getInputBuffer<float>(kImageInfoTensor),
265                                         context->getInputShape(kImageInfoTensor),
266                                         context->getOutputBuffer<float>(kOutputTensor),
267                                         context->getOutputShape(kOutputTensor));
268         }
269         case OperandType::TENSOR_QUANT16_ASYMM: {
270             if (context->getInputType(kDeltaTensor) == OperandType::TENSOR_QUANT8_ASYMM) {
271                 return bboxTransformQuant(context->getInputBuffer<uint16_t>(kRoiTensor),
272                                           context->getInputShape(kRoiTensor),
273                                           context->getInputBuffer<uint8_t>(kDeltaTensor),
274                                           context->getInputShape(kDeltaTensor),
275                                           context->getInputBuffer<int32_t>(kBatchesTensor),
276                                           context->getInputShape(kBatchesTensor),
277                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
278                                           context->getInputShape(kImageInfoTensor),
279                                           context->getOutputBuffer<uint16_t>(kOutputTensor),
280                                           context->getOutputShape(kOutputTensor));
281             } else {
282                 return bboxTransformQuant(context->getInputBuffer<uint16_t>(kRoiTensor),
283                                           context->getInputShape(kRoiTensor),
284                                           context->getInputBuffer<int8_t>(kDeltaTensor),
285                                           context->getInputShape(kDeltaTensor),
286                                           context->getInputBuffer<int32_t>(kBatchesTensor),
287                                           context->getInputShape(kBatchesTensor),
288                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
289                                           context->getInputShape(kImageInfoTensor),
290                                           context->getOutputBuffer<uint16_t>(kOutputTensor),
291                                           context->getOutputShape(kOutputTensor));
292             }
293         }
294         default:
295             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
296     }
297 }
298 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
299 
300 }  // namespace axis_aligned_bbox_transform
301 
302 namespace box_with_nms_limit {
303 
304 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
305 namespace {
306 
307 // TODO(xusongw): Reduce code duplication with hard/soft nms path.
308 
309 // Inplace hard NMS within range [select, select + selectLength).
hardNmsSingleClass(const float * scoresData,float iouThreshold,int32_t maxNumDetections,std::function<const float * (uint32_t)> getRoiBase,uint32_t * select,uint32_t selectLength)310 uint32_t* hardNmsSingleClass(const float* scoresData, float iouThreshold, int32_t maxNumDetections,
311                              std::function<const float*(uint32_t)> getRoiBase, uint32_t* select,
312                              uint32_t selectLength) {
313     uint32_t *selectStart = select, *selectEnd = select + selectLength, numDetections = 0;
314     if (maxNumDetections < 0) {
315         maxNumDetections = selectLength;
316     }
317     while (selectStart < selectEnd && numDetections < static_cast<uint32_t>(maxNumDetections)) {
318         // find max score and swap to the front
319         auto& maxScore = *std::max_element(selectStart, selectEnd,
320                                            [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
321                                                return scoresData[lhs] < scoresData[rhs];
322                                            });
323         std::swap(maxScore, *selectStart);
324 
325         // Calculate IoU of the rest, swap to the end (disgard) if needed.
326         for (uint32_t* i = selectStart + 1; i < selectEnd; i++) {
327             float iou = getIoUAxisAligned(getRoiBase(*i), getRoiBase(*selectStart));
328             if (iou >= iouThreshold) {
329                 std::swap(*i--, *(--selectEnd));
330             }
331         }
332         selectStart++;
333         numDetections++;
334     }
335     return selectStart;
336 }
337 
hardNmsMultiClass(const float * scoresData,uint32_t numClasses,uint32_t numRois,float scoreThreshold,float iouThreshold,int32_t maxNumDetections,int32_t maxNumDetectionsPerClass,std::function<const float * (uint32_t)> getRoiBase,std::vector<uint32_t> * select)338 void hardNmsMultiClass(const float* scoresData, uint32_t numClasses, uint32_t numRois,
339                        float scoreThreshold, float iouThreshold, int32_t maxNumDetections,
340                        int32_t maxNumDetectionsPerClass,
341                        std::function<const float*(uint32_t)> getRoiBase,
342                        std::vector<uint32_t>* select) {
343     // Exclude class 0 (background)
344     for (uint32_t c = 1; c < numClasses; c++) {
345         uint32_t size = select->size();
346         for (uint32_t b = 0; b < numRois; b++) {
347             const uint32_t index = b * numClasses + c;
348             const float score = scoresData[index];
349             if (score > scoreThreshold) {
350                 select->push_back(index);
351             }
352         }
353         uint32_t* selectStart = select->data() + size;
354         uint32_t selectLength = select->size() - size;
355         uint32_t* selectEnd = hardNmsSingleClass(scoresData, iouThreshold, maxNumDetectionsPerClass,
356                                                  getRoiBase, selectStart, selectLength);
357         select->resize(selectEnd - select->data());
358     }
359 
360     // Take top maxNumDetections.
361     std::sort(select->begin(), select->end(),
362               [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
363                   return scoresData[lhs] > scoresData[rhs];
364               });
365     if (maxNumDetections < 0 || select->size() <= static_cast<size_t>(maxNumDetections)) {
366         return;
367     }
368     select->resize(maxNumDetections);
369 }
370 
371 // Inplace soft NMS within range [select, select + selectLength).
372 using SoftNmsKernel = std::function<float(float)>;
softNmsSingleClass(float * scoresData,float scoreThreshold,int32_t maxNumDetections,std::function<const float * (uint32_t)> getRoiBase,SoftNmsKernel kernel,uint32_t * select,uint32_t selectLength)373 uint32_t* softNmsSingleClass(float* scoresData, float scoreThreshold, int32_t maxNumDetections,
374                              std::function<const float*(uint32_t)> getRoiBase, SoftNmsKernel kernel,
375                              uint32_t* select, uint32_t selectLength) {
376     uint32_t *selectStart = select, *selectEnd = select + selectLength, numDetections = 0;
377     if (maxNumDetections < 0) {
378         maxNumDetections = selectLength;
379     }
380     while (selectStart < selectEnd && numDetections < static_cast<uint32_t>(maxNumDetections)) {
381         // find max score and swap to the front
382         auto& maxScore = *std::max_element(selectStart, selectEnd,
383                                            [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
384                                                return scoresData[lhs] < scoresData[rhs];
385                                            });
386         std::swap(maxScore, *selectStart);
387 
388         // Calculate IoU of the rest, swap to the end (disgard) if needed.
389         for (uint32_t* i = selectStart + 1; i < selectEnd; i++) {
390             float iou = getIoUAxisAligned(getRoiBase(*i), getRoiBase(*selectStart));
391             scoresData[*i] *= kernel(iou);
392             if (scoresData[*i] < scoreThreshold) {
393                 std::swap(*i--, *(--selectEnd));
394             }
395         }
396         selectStart++;
397         numDetections++;
398     }
399     return selectStart;
400 }
401 
softNmsMultiClass(float * scoresData,uint32_t numClasses,uint32_t numRois,float scoreThreshold,float nmsScoreThreshold,int32_t maxNumDetections,int32_t maxNumDetectionsPerClass,std::function<const float * (uint32_t)> getRoiBase,SoftNmsKernel kernel,std::vector<uint32_t> * select)402 void softNmsMultiClass(float* scoresData, uint32_t numClasses, uint32_t numRois,
403                        float scoreThreshold, float nmsScoreThreshold, int32_t maxNumDetections,
404                        int32_t maxNumDetectionsPerClass,
405                        std::function<const float*(uint32_t)> getRoiBase, SoftNmsKernel kernel,
406                        std::vector<uint32_t>* select) {
407     // Exclude class 0 (background)
408     for (uint32_t c = 1; c < numClasses; c++) {
409         uint32_t size = select->size();
410         for (uint32_t b = 0; b < numRois; b++) {
411             const uint32_t index = b * numClasses + c;
412             const float score = scoresData[index];
413             if (score > scoreThreshold) {
414                 select->push_back(index);
415             }
416         }
417         uint32_t* selectStart = select->data() + size;
418         uint32_t selectLength = select->size() - size;
419         uint32_t* selectEnd =
420                 softNmsSingleClass(scoresData, nmsScoreThreshold, maxNumDetectionsPerClass,
421                                    getRoiBase, kernel, selectStart, selectLength);
422         select->resize(selectEnd - select->data());
423     }
424 
425     // Take top maxNumDetections.
426     std::sort(select->begin(), select->end(),
427               [&scoresData](const uint32_t& lhs, const uint32_t& rhs) {
428                   return scoresData[lhs] > scoresData[rhs];
429               });
430     if (maxNumDetections < 0 || select->size() <= static_cast<size_t>(maxNumDetections)) {
431         return;
432     }
433     select->resize(maxNumDetections);
434 }
435 
boxWithNmsLimitFloat32Compute(float * scoresData,const Shape & scoresShape,const float * roiData,const Shape &,const int32_t * batchesData,const Shape &,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,std::vector<uint32_t> * batchSplitIn,std::vector<uint32_t> * batchSplitOut,std::vector<uint32_t> * selected)436 bool boxWithNmsLimitFloat32Compute(float* scoresData, const Shape& scoresShape,
437                                    const float* roiData, const Shape& /*roiShape*/,
438                                    const int32_t* batchesData, const Shape& /*batchesShape*/,
439                                    float scoreThreshold, int32_t maxNumDetections,
440                                    int32_t softNmsKernel, float iouThreshold, float sigma,
441                                    float nmsScoreThreshold, std::vector<uint32_t>* batchSplitIn,
442                                    std::vector<uint32_t>* batchSplitOut,
443                                    std::vector<uint32_t>* selected) {
444     SoftNmsKernel kernel = nullptr;
445     if (softNmsKernel == 0) {
446         kernel = [&iouThreshold](float iou) { return iou < iouThreshold ? 1.0f : 0.0f; };
447     } else if (softNmsKernel == 1) {
448         kernel = [&iouThreshold](float iou) { return iou < iouThreshold ? 1.0f : 1.0f - iou; };
449     } else if (softNmsKernel == 2) {
450         kernel = [&sigma](float iou) { return std::exp(-1.0f * iou * iou / sigma); };
451     } else {
452         NN_RET_CHECK_FAIL() << "Unsupported soft NMS kernel " << softNmsKernel;
453     }
454 
455     const uint32_t kRoiDim = 4;
456     uint32_t numRois = getSizeOfDimension(scoresShape, 0);
457     uint32_t numClasses = getSizeOfDimension(scoresShape, 1);
458 
459     // We assume boxes of the same batch are grouped together.
460     std::vector<uint32_t> batch;
461     int32_t ind = -1;
462     for (uint32_t i = 0; i < numRois; i++) {
463         if (batchesData[i] == ind) {
464             (batchSplitIn->back())++;
465         } else {
466             ind = batchesData[i];
467             batchSplitIn->push_back(1);
468         }
469     }
470 
471     float* scoresBase = scoresData;
472     const float* roiBase = roiData;
473     selected->clear();
474     for (uint32_t b = 0; b < batchSplitIn->size(); b++) {
475         for (uint32_t i = 0; i < batchSplitIn->at(b); i++) {
476             const float* roi = roiBase + i * kRoiDim;
477             // Check for malformed data: invalid region: x2 < x1 || y2 < y1
478             NN_RET_CHECK_LE(roi[0], roi[2]);
479             NN_RET_CHECK_LE(roi[1], roi[3]);
480         }
481         std::vector<uint32_t> result;
482         softNmsMultiClass(
483                 scoresBase, numClasses, batchSplitIn->at(b), scoreThreshold, nmsScoreThreshold,
484                 maxNumDetections, maxNumDetections,
485                 [&roiBase](uint32_t ind) { return roiBase + ind * kRoiDim; }, kernel, &result);
486         // Sort again by class.
487         std::sort(result.begin(), result.end(),
488                   [&scoresBase, numClasses](const uint32_t& lhs, const uint32_t& rhs) {
489                       uint32_t lhsClass = lhs % numClasses, rhsClass = rhs % numClasses;
490                       return lhsClass == rhsClass ? scoresBase[lhs] > scoresBase[rhs]
491                                                   : lhsClass < rhsClass;
492                   });
493         selected->insert(selected->end(), result.begin(), result.end());
494         batchSplitOut->push_back(result.size());
495         scoresBase += batchSplitIn->at(b) * numClasses;
496         roiBase += batchSplitIn->at(b) * numClasses * kRoiDim;
497     }
498     return true;
499 }
500 
501 template <typename T>
castTo(float val,const Shape &)502 T castTo(float val, const Shape&) {
503     return val;
504 }
505 template <>
castTo(float val,const Shape & shape)506 uint8_t castTo(float val, const Shape& shape) {
507     return saturateCast<uint8_t>(std::round(val / shape.scale + shape.offset));
508 }
509 
510 template <>
castTo(float val,const Shape & shape)511 int8_t castTo(float val, const Shape& shape) {
512     return saturateCast<int8_t>(std::round(val / shape.scale + shape.offset));
513 }
514 
515 template <typename T_Score, typename T_Roi>
boxWithNmsLimitWriteOutput(const std::vector<uint32_t> & selected,const std::vector<uint32_t> & batchSplitIn,const std::vector<uint32_t> & batchSplitOut,const std::vector<float> & scores,IOperationExecutionContext * context)516 bool boxWithNmsLimitWriteOutput(const std::vector<uint32_t>& selected,
517                                 const std::vector<uint32_t>& batchSplitIn,
518                                 const std::vector<uint32_t>& batchSplitOut,
519                                 const std::vector<float>& scores,
520                                 IOperationExecutionContext* context) {
521     const uint32_t kRoiDim = 4;
522     Shape scoresShape = context->getInputShape(kScoreTensor);
523     uint32_t numClasses = getSizeOfDimension(scoresShape, 1);
524 
525     // Set output dimensions.
526     uint32_t numOutRois = selected.size();
527     if (numOutRois == 0) return true;
528     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
529     scoresOutShape.dimensions = {numOutRois};
530     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
531 
532     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
533     roiOutShape.dimensions = {numOutRois, 4};
534     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
535 
536     Shape classesOutShape = context->getOutputShape(kOutputClassTensor);
537     classesOutShape.dimensions = {numOutRois};
538     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, classesOutShape));
539 
540     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
541     batchesOutShape.dimensions = {numOutRois};
542     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
543 
544     // Write outputs.
545     const float* scoresBase = scores.data();
546     const T_Roi* roiBase = context->getInputBuffer<T_Roi>(kRoiTensor);
547     const int32_t* batchesInPtr = context->getInputBuffer<int32_t>(kBatchesTensor);
548     T_Score* scoresOutPtr = context->getOutputBuffer<T_Score>(kOutputScoreTensor);
549     T_Roi* roiOutPtr = context->getOutputBuffer<T_Roi>(kOutputRoiTensor);
550     int32_t* classesOutPtr = context->getOutputBuffer<int32_t>(kOutputClassTensor);
551     int32_t* batchesOutPtr = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
552     uint32_t i = 0;
553     for (uint32_t b = 0; b < batchSplitOut.size(); b++) {
554         for (uint32_t j = 0; j < batchSplitOut[b]; j++) {
555             uint32_t index = selected[i++];
556             *scoresOutPtr++ = castTo<T_Score>(scoresBase[index], scoresOutShape);
557             memcpy(roiOutPtr, roiBase + index * kRoiDim, kRoiDim * sizeof(T_Roi));
558             roiOutPtr += kRoiDim;
559             *classesOutPtr++ = index % numClasses;
560             *batchesOutPtr++ = *batchesInPtr;
561         }
562         scoresBase += batchSplitIn[b] * numClasses;
563         roiBase += batchSplitIn[b] * numClasses * kRoiDim;
564         batchesInPtr += batchSplitIn[b];
565     }
566     return true;
567 }
568 
boxWithNmsLimitFloat32(const float * scoresData,const Shape & scoresShape,const float * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,float *,Shape,float *,Shape,int32_t *,Shape,int32_t *,const Shape &,IOperationExecutionContext * context)569 bool boxWithNmsLimitFloat32(const float* scoresData, const Shape& scoresShape, const float* roiData,
570                             const Shape& roiShape, const int32_t* batchesData,
571                             const Shape& batchesShape, float scoreThreshold,
572                             int32_t maxNumDetections, int32_t softNmsKernel, float iouThreshold,
573                             float sigma, float nmsScoreThreshold, float* /*scoresOutData*/,
574                             Shape /*scoresOutShape*/, float* /*roiOutData*/, Shape /*roiOutShape*/,
575                             int32_t* /*classesOutData*/, Shape /*classesOutShape*/,
576                             int32_t* /*batchesOutData*/, const Shape& /*batchSplitOutShape*/,
577                             IOperationExecutionContext* context) {
578     NNTRACE_TRANS("boxWithNmsLimit");
579     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
580     for (uint32_t i = 0; i < scores_float32.size(); i++) {
581         scores_float32[i] = scoresData[i];
582     }
583     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
584     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
585             scores_float32.data(), scoresShape, roiData, roiShape, batchesData, batchesShape,
586             scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma, nmsScoreThreshold,
587             &batchSplitIn, &batchSplitOut, &selected));
588     return boxWithNmsLimitWriteOutput<float, float>(selected, batchSplitIn, batchSplitOut,
589                                                     scores_float32, context);
590 }
591 
boxWithNmsLimitFloat16(const _Float16 * scoresData,const Shape & scoresShape,const _Float16 * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,_Float16 scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,_Float16 iouThreshold,_Float16 sigma,_Float16 nmsScoreThreshold,_Float16 *,const Shape &,_Float16 *,const Shape &,int32_t *,const Shape &,int32_t *,const Shape &,IOperationExecutionContext * context)592 bool boxWithNmsLimitFloat16(const _Float16* scoresData, const Shape& scoresShape,
593                             const _Float16* roiData, const Shape& roiShape,
594                             const int32_t* batchesData, const Shape& batchesShape,
595                             _Float16 scoreThreshold, int32_t maxNumDetections,
596                             int32_t softNmsKernel, _Float16 iouThreshold, _Float16 sigma,
597                             _Float16 nmsScoreThreshold, _Float16* /*scoresOutData*/,
598                             const Shape& /*scoresOutShape*/, _Float16* /*roiOutData*/,
599                             const Shape& /*roiOutShape*/, int32_t* /*classesOutData*/,
600                             const Shape& /*classesOutShape*/, int32_t* /*batchesOutData*/,
601                             const Shape& /*batchSplitOutShape*/,
602                             IOperationExecutionContext* context) {
603     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
604     convertFloat16ToFloat32(scoresData, &scores_float32);
605     std::vector<float> roi_float32(getNumberOfElements(roiShape));
606     convertFloat16ToFloat32(roiData, &roi_float32);
607     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
608     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
609             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
610             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
611             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
612     return boxWithNmsLimitWriteOutput<_Float16, _Float16>(selected, batchSplitIn, batchSplitOut,
613                                                           scores_float32, context);
614 }
615 
boxWithNmsLimitQuant(const uint8_t * scoresData,const Shape & scoresShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,uint8_t *,const Shape &,uint16_t *,const Shape &,int32_t *,const Shape &,int32_t *,const Shape &,IOperationExecutionContext * context)616 bool boxWithNmsLimitQuant(const uint8_t* scoresData, const Shape& scoresShape,
617                           const uint16_t* roiData, const Shape& roiShape,
618                           const int32_t* batchesData, const Shape& batchesShape,
619                           float scoreThreshold, int32_t maxNumDetections, int32_t softNmsKernel,
620                           float iouThreshold, float sigma, float nmsScoreThreshold,
621                           uint8_t* /*scoresOutData*/, const Shape& /*scoresOutShape*/,
622                           uint16_t* /*roiOutData*/, const Shape& /*roiOutShape*/,
623                           int32_t* /*classesOutData*/, const Shape& /*classesOutShape*/,
624                           int32_t* /*batchesOutData*/, const Shape& /*batchSplitOutShape*/,
625                           IOperationExecutionContext* context) {
626     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
627     convertQuantToFloat32(scoresData, scoresShape.scale, scoresShape.offset, &scores_float32);
628     std::vector<float> roi_float32(getNumberOfElements(roiShape));
629     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
630     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
631     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
632             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
633             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
634             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
635     return boxWithNmsLimitWriteOutput<uint8_t, uint16_t>(selected, batchSplitIn, batchSplitOut,
636                                                          scores_float32, context);
637 }
638 
boxWithNmsLimitQuant(const int8_t * scoresData,const Shape & scoresShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchesData,const Shape & batchesShape,float scoreThreshold,int32_t maxNumDetections,int32_t softNmsKernel,float iouThreshold,float sigma,float nmsScoreThreshold,int8_t *,const Shape &,uint16_t *,const Shape &,int32_t *,const Shape &,int32_t *,const Shape &,IOperationExecutionContext * context)639 bool boxWithNmsLimitQuant(const int8_t* scoresData, const Shape& scoresShape,
640                           const uint16_t* roiData, const Shape& roiShape,
641                           const int32_t* batchesData, const Shape& batchesShape,
642                           float scoreThreshold, int32_t maxNumDetections, int32_t softNmsKernel,
643                           float iouThreshold, float sigma, float nmsScoreThreshold,
644                           int8_t* /*scoresOutData*/, const Shape& /*scoresOutShape*/,
645                           uint16_t* /*roiOutData*/, const Shape& /*roiOutShape*/,
646                           int32_t* /*classesOutData*/, const Shape& /*classesOutShape*/,
647                           int32_t* /*batchesOutData*/, const Shape& /*batchSplitOutShape*/,
648                           IOperationExecutionContext* context) {
649     std::vector<float> scores_float32(getNumberOfElements(scoresShape));
650     convertQuantToFloat32<int8_t>(scoresData, scoresShape.scale, scoresShape.offset,
651                                   &scores_float32);
652     std::vector<float> roi_float32(getNumberOfElements(roiShape));
653     convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
654     std::vector<uint32_t> selected, batchSplitIn, batchSplitOut;
655     NN_RET_CHECK(boxWithNmsLimitFloat32Compute(
656             scores_float32.data(), scoresShape, roi_float32.data(), roiShape, batchesData,
657             batchesShape, scoreThreshold, maxNumDetections, softNmsKernel, iouThreshold, sigma,
658             nmsScoreThreshold, &batchSplitIn, &batchSplitOut, &selected));
659     return boxWithNmsLimitWriteOutput<int8_t, uint16_t>(selected, batchSplitIn, batchSplitOut,
660                                                         scores_float32, context);
661 }
662 
663 }  // namespace
664 
prepare(IOperationExecutionContext * context)665 bool prepare(IOperationExecutionContext* context) {
666     Shape scoreShape = context->getInputShape(kScoreTensor);
667     Shape roiShape = context->getInputShape(kRoiTensor);
668     Shape batchesShape = context->getInputShape(kBatchesTensor);
669     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
670     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
671     Shape outputClassShape = context->getOutputShape(kOutputClassTensor);
672     Shape outputBatchSplitShape = context->getOutputShape(kOutputBatchesTensor);
673 
674     NN_RET_CHECK(getNumberOfDimensions(scoreShape) == 2);
675     NN_RET_CHECK(getNumberOfDimensions(roiShape) == 2);
676     NN_RET_CHECK(getNumberOfDimensions(batchesShape) == 1);
677 
678     // Only numRois can be zero.
679     const uint32_t kRoiDim = 4;
680     uint32_t numRois = getSizeOfDimension(scoreShape, 0);
681     uint32_t numClasses = getSizeOfDimension(scoreShape, 1);
682     NN_RET_CHECK(getSizeOfDimension(roiShape, 0) == numRois);
683     NN_RET_CHECK(getSizeOfDimension(roiShape, 1) == kRoiDim * numClasses);
684     NN_RET_CHECK(getSizeOfDimension(batchesShape, 0) == numRois);
685     NN_RET_CHECK_GT(numClasses, 1u);
686 
687     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM ||
688         scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
689         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
690         NN_RET_CHECK_EQ(roiShape.offset, 0);
691     }
692 
693     outputScoreShape.type = scoreShape.type;
694     outputScoreShape.dimensions = {0};
695     outputScoreShape.scale = scoreShape.scale;
696     outputScoreShape.offset = scoreShape.offset;
697     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
698 
699     outputRoiShape.type = roiShape.type;
700     outputRoiShape.dimensions = {0, 4};
701     outputRoiShape.scale = 0.f;
702     outputRoiShape.offset = 0;
703     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM ||
704         scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
705         outputRoiShape.scale = 0.125f;
706     }
707     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
708 
709     outputClassShape.type = OperandType::TENSOR_INT32;
710     outputClassShape.dimensions = {0};
711     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, outputClassShape));
712 
713     outputBatchSplitShape.type = batchesShape.type;
714     outputBatchSplitShape.dimensions = {0};
715     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, outputBatchSplitShape));
716     return true;
717 }
718 
execute(IOperationExecutionContext * context)719 bool execute(IOperationExecutionContext* context) {
720     NNTRACE_TRANS("boxWithNMSLimit");
721     // Bypass execution in the case of zero numRois.
722     if (getSizeOfDimension(context->getInputShape(kScoreTensor), 0) == 0) return true;
723     switch (context->getInputType(kScoreTensor)) {
724         case OperandType::TENSOR_FLOAT16: {
725             return boxWithNmsLimitFloat16(
726                     context->getInputBuffer<_Float16>(kScoreTensor),
727                     context->getInputShape(kScoreTensor),
728                     context->getInputBuffer<_Float16>(kRoiTensor),
729                     context->getInputShape(kRoiTensor),
730                     context->getInputBuffer<int32_t>(kBatchesTensor),
731                     context->getInputShape(kBatchesTensor),
732                     context->getInputValue<_Float16>(kScoreThresholdScalar),
733                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
734                     context->getInputValue<int32_t>(kNmsKernelScalar),
735                     context->getInputValue<_Float16>(kIoUThresholdScalar),
736                     context->getInputValue<_Float16>(kSigmaScalar),
737                     context->getInputValue<_Float16>(kNmsScoreThresholdScalar),
738                     context->getOutputBuffer<_Float16>(kOutputScoreTensor),
739                     context->getOutputShape(kOutputScoreTensor),
740                     context->getOutputBuffer<_Float16>(kOutputRoiTensor),
741                     context->getOutputShape(kOutputRoiTensor),
742                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
743                     context->getOutputShape(kOutputClassTensor),
744                     context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
745                     context->getOutputShape(kOutputBatchesTensor), context);
746         }
747         case OperandType::TENSOR_FLOAT32: {
748             return boxWithNmsLimitFloat32(context->getInputBuffer<float>(kScoreTensor),
749                                           context->getInputShape(kScoreTensor),
750                                           context->getInputBuffer<float>(kRoiTensor),
751                                           context->getInputShape(kRoiTensor),
752                                           context->getInputBuffer<int32_t>(kBatchesTensor),
753                                           context->getInputShape(kBatchesTensor),
754                                           context->getInputValue<float>(kScoreThresholdScalar),
755                                           context->getInputValue<int32_t>(kMaxNumDetectionScalar),
756                                           context->getInputValue<int32_t>(kNmsKernelScalar),
757                                           context->getInputValue<float>(kIoUThresholdScalar),
758                                           context->getInputValue<float>(kSigmaScalar),
759                                           context->getInputValue<float>(kNmsScoreThresholdScalar),
760                                           context->getOutputBuffer<float>(kOutputScoreTensor),
761                                           context->getOutputShape(kOutputScoreTensor),
762                                           context->getOutputBuffer<float>(kOutputRoiTensor),
763                                           context->getOutputShape(kOutputRoiTensor),
764                                           context->getOutputBuffer<int32_t>(kOutputClassTensor),
765                                           context->getOutputShape(kOutputClassTensor),
766                                           context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
767                                           context->getOutputShape(kOutputBatchesTensor), context);
768         }
769         case OperandType::TENSOR_QUANT8_ASYMM: {
770             return boxWithNmsLimitQuant(context->getInputBuffer<uint8_t>(kScoreTensor),
771                                         context->getInputShape(kScoreTensor),
772                                         context->getInputBuffer<uint16_t>(kRoiTensor),
773                                         context->getInputShape(kRoiTensor),
774                                         context->getInputBuffer<int32_t>(kBatchesTensor),
775                                         context->getInputShape(kBatchesTensor),
776                                         context->getInputValue<float>(kScoreThresholdScalar),
777                                         context->getInputValue<int32_t>(kMaxNumDetectionScalar),
778                                         context->getInputValue<int32_t>(kNmsKernelScalar),
779                                         context->getInputValue<float>(kIoUThresholdScalar),
780                                         context->getInputValue<float>(kSigmaScalar),
781                                         context->getInputValue<float>(kNmsScoreThresholdScalar),
782                                         context->getOutputBuffer<uint8_t>(kOutputScoreTensor),
783                                         context->getOutputShape(kOutputScoreTensor),
784                                         context->getOutputBuffer<uint16_t>(kOutputRoiTensor),
785                                         context->getOutputShape(kOutputRoiTensor),
786                                         context->getOutputBuffer<int32_t>(kOutputClassTensor),
787                                         context->getOutputShape(kOutputClassTensor),
788                                         context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
789                                         context->getOutputShape(kOutputBatchesTensor), context);
790         }
791         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
792             return boxWithNmsLimitQuant(context->getInputBuffer<int8_t>(kScoreTensor),
793                                         context->getInputShape(kScoreTensor),
794                                         context->getInputBuffer<uint16_t>(kRoiTensor),
795                                         context->getInputShape(kRoiTensor),
796                                         context->getInputBuffer<int32_t>(kBatchesTensor),
797                                         context->getInputShape(kBatchesTensor),
798                                         context->getInputValue<float>(kScoreThresholdScalar),
799                                         context->getInputValue<int32_t>(kMaxNumDetectionScalar),
800                                         context->getInputValue<int32_t>(kNmsKernelScalar),
801                                         context->getInputValue<float>(kIoUThresholdScalar),
802                                         context->getInputValue<float>(kSigmaScalar),
803                                         context->getInputValue<float>(kNmsScoreThresholdScalar),
804                                         context->getOutputBuffer<int8_t>(kOutputScoreTensor),
805                                         context->getOutputShape(kOutputScoreTensor),
806                                         context->getOutputBuffer<uint16_t>(kOutputRoiTensor),
807                                         context->getOutputShape(kOutputRoiTensor),
808                                         context->getOutputBuffer<int32_t>(kOutputClassTensor),
809                                         context->getOutputShape(kOutputClassTensor),
810                                         context->getOutputBuffer<int32_t>(kOutputBatchesTensor),
811                                         context->getOutputShape(kOutputBatchesTensor), context);
812         }
813         default:
814             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
815     }
816 }
817 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
818 
819 }  // namespace box_with_nms_limit
820 
821 namespace generate_proposals {
822 
823 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
824 namespace {
825 
// Drops from *select every roi index whose box is too small or centred outside the image.
//
// roiBase points at boxes stored as four consecutive floats (x1, y1, x2, y2); an entry k of
// *select refers to the box starting at roiBase[4 * k]. A box is kept only if its width and its
// height are both strictly greater than minSize, its x centre is strictly less than
// imageInfoBase[1] and its y centre is strictly less than imageInfoBase[0]. The surviving
// indices keep their relative order and *select is shrunk to hold just them.
void filterBoxes(const float* roiBase, const float* imageInfoBase, float minSize,
                 std::vector<uint32_t>* select) {
    const auto shouldDrop = [roiBase, imageInfoBase, minSize](uint32_t roiIndex) {
        constexpr uint32_t kRoiDim = 4;
        const float* corners = roiBase + roiIndex * kRoiDim;
        const float boxWidth = corners[2] - corners[0];
        const float boxHeight = corners[3] - corners[1];
        // Negated "greater than" rather than "<=" so that NaN extents are dropped as well.
        if (!(boxWidth > minSize) || !(boxHeight > minSize)) {
            return true;
        }
        const float centerX = corners[0] + boxWidth / 2.0f;
        const float centerY = corners[1] + boxHeight / 2.0f;
        return !(centerX < imageInfoBase[1]) || !(centerY < imageInfoBase[0]);
    };
    select->erase(std::remove_if(select->begin(), select->end(), shouldDrop), select->end());
}
844 
generateProposalsNhwcFloat32Compute(const float * scoresData,const Shape & scoresShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const float * anchorsData,const Shape & anchorsShape,const float * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,std::vector<float> * scoresOutData,std::vector<float> * roiOutData,std::vector<int32_t> * batchesOutData)845 bool generateProposalsNhwcFloat32Compute(const float* scoresData, const Shape& scoresShape,
846                                          const float* bboxDeltasData, const Shape& bboxDeltasShape,
847                                          const float* anchorsData, const Shape& anchorsShape,
848                                          const float* imageInfoData, const Shape& imageInfoShape,
849                                          float heightStride, float widthStride, int32_t preNmsTopN,
850                                          int32_t postNmsTopN, float iouThreshold, float minSize,
851                                          std::vector<float>* scoresOutData,
852                                          std::vector<float>* roiOutData,
853                                          std::vector<int32_t>* batchesOutData) {
854     const uint32_t kRoiDim = 4;
855     uint32_t numBatches = getSizeOfDimension(scoresShape, 0);
856     uint32_t height = getSizeOfDimension(scoresShape, 1);
857     uint32_t width = getSizeOfDimension(scoresShape, 2);
858     uint32_t numAnchors = getSizeOfDimension(scoresShape, 3);
859     uint32_t imageInfoLength = getSizeOfDimension(imageInfoShape, 1);
860 
861     uint32_t batchSize = height * width * numAnchors;
862     uint32_t roiBufferSize = batchSize * kRoiDim;
863     std::vector<float> roiBuffer(roiBufferSize);
864     std::vector<float> roiTransformedBuffer(roiBufferSize);
865     scoresOutData->clear();
866     roiOutData->clear();
867     batchesOutData->clear();
868 
869     // Compute the roi region for each anchor.
870     float* roiBase = roiBuffer.data();
871     for (uint32_t h = 0; h < height; h++) {
872         float hShift = h * heightStride;
873         for (uint32_t w = 0; w < width; w++) {
874             const float* anchorsBase = anchorsData;
875             float wShift = w * widthStride;
876             for (uint32_t a = 0; a < numAnchors; a++, roiBase += kRoiDim, anchorsBase += kRoiDim) {
877                 roiBase[0] = anchorsBase[0] + wShift;
878                 roiBase[1] = anchorsBase[1] + hShift;
879                 roiBase[2] = anchorsBase[2] + wShift;
880                 roiBase[3] = anchorsBase[3] + hShift;
881             }
882         }
883     }
884 
885     const float* scoresBase = scoresData;
886     const float* bboxDeltasBase = bboxDeltasData;
887     const float* imageInfoBase = imageInfoData;
888     // Need to fake some data to satisfy bboxTransform.
889     Shape tempRoiShape = anchorsShape;
890     tempRoiShape.dimensions = {batchSize, kRoiDim};
891     Shape tempBBoxDeltasShape = bboxDeltasShape;
892     tempBBoxDeltasShape.dimensions = {batchSize, kRoiDim};
893     std::vector<int32_t> tempBatchSplitData(batchSize, 0);
894     Shape tempbatchSplitShape = {.dimensions = {batchSize}};
895     Shape tempImageInfoShape = imageInfoShape;
896     tempImageInfoShape.dimensions = {1, imageInfoLength};
897 
898     for (uint32_t b = 0; b < numBatches; b++) {
899         // Apply bboxDeltas to anchor locations.
900         float tempImageInfo[] = {imageInfoBase[0], imageInfoBase[1]};
901         if (!bboxTransformFloat32(roiBuffer.data(), tempRoiShape, bboxDeltasBase,
902                                   tempBBoxDeltasShape, tempBatchSplitData.data(),
903                                   tempbatchSplitShape, tempImageInfo, tempImageInfoShape,
904                                   roiTransformedBuffer.data(), tempRoiShape)) {
905             LOG(ERROR) << "BBoxTransform step failed in GENERATE_PROPOSALS op.";
906             return false;
907         }
908 
909         // Find the top preNmsTopN scores.
910         std::vector<uint32_t> select(batchSize);
911         std::iota(select.begin(), select.end(), 0);
912         if (preNmsTopN > 0 && static_cast<size_t>(preNmsTopN) < select.size()) {
913             std::sort(select.begin(), select.end(),
914                       [&scoresBase](const uint32_t lhs, const uint32_t rhs) {
915                           return scoresBase[lhs] > scoresBase[rhs];
916                       });
917             select.resize(preNmsTopN);
918         }
919 
920         // Filter boxes, disgard regions with height or width < minSize.
921         filterBoxes(roiTransformedBuffer.data(), imageInfoBase, minSize, &select);
922 
923         // Apply hard NMS.
924         uint32_t* selectEnd = box_with_nms_limit::hardNmsSingleClass(
925                 scoresBase, iouThreshold, postNmsTopN,
926                 [&roiTransformedBuffer](uint32_t ind) {
927                     return roiTransformedBuffer.data() + ind * kRoiDim;
928                 },
929                 select.data(), select.size());
930         uint32_t selectSize = selectEnd - select.data();
931         select.resize(selectSize);
932 
933         // Write output.
934         for (auto i : select) {
935             roiOutData->insert(roiOutData->end(), roiTransformedBuffer.begin() + i * kRoiDim,
936                                roiTransformedBuffer.begin() + (i + 1) * kRoiDim);
937             scoresOutData->push_back(scoresBase[i]);
938             batchesOutData->push_back(b);
939         }
940         scoresBase += batchSize;
941         bboxDeltasBase += roiBufferSize;
942         imageInfoBase += imageInfoLength;
943     }
944     return true;
945 }
946 
generateProposalsFloat32Compute(const float * scoresData,const Shape & scoresShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const float * anchorsData,const Shape & anchorsShape,const float * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,std::vector<float> * scoresOutData,std::vector<float> * roiOutData,std::vector<int32_t> * batchesOutData)947 bool generateProposalsFloat32Compute(const float* scoresData, const Shape& scoresShape,
948                                      const float* bboxDeltasData, const Shape& bboxDeltasShape,
949                                      const float* anchorsData, const Shape& anchorsShape,
950                                      const float* imageInfoData, const Shape& imageInfoShape,
951                                      float heightStride, float widthStride, int32_t preNmsTopN,
952                                      int32_t postNmsTopN, float iouThreshold, float minSize,
953                                      bool useNchw, std::vector<float>* scoresOutData,
954                                      std::vector<float>* roiOutData,
955                                      std::vector<int32_t>* batchesOutData) {
956     InputWithLayout<float> score_nhwc(useNchw), delta_nhwc(useNchw);
957     NN_RET_CHECK(score_nhwc.initialize(scoresData, scoresShape));
958     NN_RET_CHECK(delta_nhwc.initialize(bboxDeltasData, bboxDeltasShape));
959     return generateProposalsNhwcFloat32Compute(
960             score_nhwc.getNhwcBuffer(), score_nhwc.getNhwcShape(), delta_nhwc.getNhwcBuffer(),
961             delta_nhwc.getNhwcShape(), anchorsData, anchorsShape, imageInfoData, imageInfoShape,
962             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize,
963             scoresOutData, roiOutData, batchesOutData);
964 }
965 
generateProposalsFloat32(const float * scoresData,const Shape & scoresShape,const float * bboxDeltasData,const Shape & bboxDeltasShape,const float * anchorsData,const Shape & anchorsShape,const float * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)966 bool generateProposalsFloat32(const float* scoresData, const Shape& scoresShape,
967                               const float* bboxDeltasData, const Shape& bboxDeltasShape,
968                               const float* anchorsData, const Shape& anchorsShape,
969                               const float* imageInfoData, const Shape& imageInfoShape,
970                               float heightStride, float widthStride, int32_t preNmsTopN,
971                               int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
972                               IOperationExecutionContext* context) {
973     std::vector<float> scoresOut_float32, roiOut_float32;
974     std::vector<int32_t> batchesOut;
975     NN_RET_CHECK(generateProposalsFloat32Compute(
976             scoresData, scoresShape, bboxDeltasData, bboxDeltasShape, anchorsData, anchorsShape,
977             imageInfoData, imageInfoShape, heightStride, widthStride, preNmsTopN, postNmsTopN,
978             iouThreshold, minSize, useNchw, &scoresOut_float32, &roiOut_float32, &batchesOut));
979 
980     // Set output dimensions.
981     uint32_t numOutRois = scoresOut_float32.size();
982     if (numOutRois == 0) return true;
983     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
984     scoresOutShape.dimensions = {numOutRois};
985     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
986     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
987     roiOutShape.dimensions = {numOutRois, 4};
988     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
989     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
990     batchesOutShape.dimensions = {numOutRois};
991     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
992 
993     // Write outputs.
994     float* scoresOutData = context->getOutputBuffer<float>(kOutputScoreTensor);
995     for (uint32_t i = 0; i < scoresOut_float32.size(); i++) {
996         scoresOutData[i] = scoresOut_float32[i];
997     }
998     float* roiOutData = context->getOutputBuffer<float>(kOutputRoiTensor);
999     for (uint32_t i = 0; i < roiOut_float32.size(); i++) {
1000         roiOutData[i] = roiOut_float32[i];
1001     }
1002     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1003     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1004         batchesOutData[i] = batchesOut[i];
1005     }
1006     return true;
1007 }
1008 
generateProposalsFloat16(const _Float16 * scoresData,const Shape & scoresShape,const _Float16 * bboxDeltasData,const Shape & bboxDeltasShape,const _Float16 * anchorsData,const Shape & anchorsShape,const _Float16 * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)1009 bool generateProposalsFloat16(const _Float16* scoresData, const Shape& scoresShape,
1010                               const _Float16* bboxDeltasData, const Shape& bboxDeltasShape,
1011                               const _Float16* anchorsData, const Shape& anchorsShape,
1012                               const _Float16* imageInfoData, const Shape& imageInfoShape,
1013                               float heightStride, float widthStride, int32_t preNmsTopN,
1014                               int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
1015                               IOperationExecutionContext* context) {
1016     std::vector<float> score_float32(getNumberOfElements(scoresShape));
1017     convertFloat16ToFloat32(scoresData, &score_float32);
1018     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
1019     convertFloat16ToFloat32(bboxDeltasData, &delta_float32);
1020     std::vector<float> anchors_float32(getNumberOfElements(anchorsShape));
1021     convertFloat16ToFloat32(anchorsData, &anchors_float32);
1022     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoShape));
1023     convertFloat16ToFloat32(imageInfoData, &imageInfo_float32);
1024     std::vector<float> scoresOut_float32, roiOut_float32;
1025     std::vector<int32_t> batchesOut;
1026     NN_RET_CHECK(generateProposalsFloat32Compute(
1027             score_float32.data(), scoresShape, delta_float32.data(), bboxDeltasShape,
1028             anchors_float32.data(), anchorsShape, imageInfo_float32.data(), imageInfoShape,
1029             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize, useNchw,
1030             &scoresOut_float32, &roiOut_float32, &batchesOut));
1031 
1032     // Set output dimensions.
1033     uint32_t numOutRois = scoresOut_float32.size();
1034     if (numOutRois == 0) return true;
1035     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
1036     scoresOutShape.dimensions = {numOutRois};
1037     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
1038     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
1039     roiOutShape.dimensions = {numOutRois, 4};
1040     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
1041     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
1042     batchesOutShape.dimensions = {numOutRois};
1043     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
1044 
1045     // Write outputs.
1046     _Float16* scoresOutData = context->getOutputBuffer<_Float16>(kOutputScoreTensor);
1047     convertFloat32ToFloat16(scoresOut_float32, scoresOutData);
1048     _Float16* roiOutData = context->getOutputBuffer<_Float16>(kOutputRoiTensor);
1049     convertFloat32ToFloat16(roiOut_float32, roiOutData);
1050     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1051     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1052         batchesOutData[i] = batchesOut[i];
1053     }
1054     return true;
1055 }
1056 
1057 template <typename T_8QInput>
generateProposalsQuant(const T_8QInput * scoresData,const Shape & scoresShape,const T_8QInput * bboxDeltasData,const Shape & bboxDeltasShape,const int16_t * anchorsData,const Shape & anchorsShape,const uint16_t * imageInfoData,const Shape & imageInfoShape,float heightStride,float widthStride,int32_t preNmsTopN,int32_t postNmsTopN,float iouThreshold,float minSize,bool useNchw,IOperationExecutionContext * context)1058 bool generateProposalsQuant(const T_8QInput* scoresData, const Shape& scoresShape,
1059                             const T_8QInput* bboxDeltasData, const Shape& bboxDeltasShape,
1060                             const int16_t* anchorsData, const Shape& anchorsShape,
1061                             const uint16_t* imageInfoData, const Shape& imageInfoShape,
1062                             float heightStride, float widthStride, int32_t preNmsTopN,
1063                             int32_t postNmsTopN, float iouThreshold, float minSize, bool useNchw,
1064                             IOperationExecutionContext* context) {
1065     std::vector<float> score_float32(getNumberOfElements(scoresShape));
1066     convertQuantToFloat32<T_8QInput>(scoresData, scoresShape.scale, scoresShape.offset,
1067                                      &score_float32);
1068     std::vector<float> delta_float32(getNumberOfElements(bboxDeltasShape));
1069     convertQuantToFloat32<T_8QInput>(bboxDeltasData, bboxDeltasShape.scale, bboxDeltasShape.offset,
1070                                      &delta_float32);
1071     std::vector<float> anchors_float32(getNumberOfElements(anchorsShape));
1072     convertQuantToFloat32(anchorsData, anchorsShape.scale, anchorsShape.offset, &anchors_float32);
1073     std::vector<float> imageInfo_float32(getNumberOfElements(imageInfoShape));
1074     convertQuantToFloat32(imageInfoData, imageInfoShape.scale, imageInfoShape.offset,
1075                           &imageInfo_float32);
1076     std::vector<float> scoresOut_float32, roiOut_float32;
1077     std::vector<int32_t> batchesOut;
1078     NN_RET_CHECK(generateProposalsFloat32Compute(
1079             score_float32.data(), scoresShape, delta_float32.data(), bboxDeltasShape,
1080             anchors_float32.data(), anchorsShape, imageInfo_float32.data(), imageInfoShape,
1081             heightStride, widthStride, preNmsTopN, postNmsTopN, iouThreshold, minSize, useNchw,
1082             &scoresOut_float32, &roiOut_float32, &batchesOut));
1083 
1084     // Set output dimensions.
1085     uint32_t numOutRois = scoresOut_float32.size();
1086     if (numOutRois == 0) return true;
1087     Shape scoresOutShape = context->getOutputShape(kOutputScoreTensor);
1088     scoresOutShape.dimensions = {numOutRois};
1089     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, scoresOutShape));
1090     Shape roiOutShape = context->getOutputShape(kOutputRoiTensor);
1091     roiOutShape.dimensions = {numOutRois, 4};
1092     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, roiOutShape));
1093     Shape batchesOutShape = context->getOutputShape(kOutputBatchesTensor);
1094     batchesOutShape.dimensions = {numOutRois};
1095     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, batchesOutShape));
1096 
1097     // Write outputs.
1098     T_8QInput* scoresOutData = context->getOutputBuffer<T_8QInput>(kOutputScoreTensor);
1099     convertFloat32ToQuant<T_8QInput>(scoresOut_float32, scoresOutShape.scale, scoresOutShape.offset,
1100                                      scoresOutData);
1101     uint16_t* roiOutData = context->getOutputBuffer<uint16_t>(kOutputRoiTensor);
1102     convertFloat32ToQuant(roiOut_float32, roiOutShape.scale, roiOutShape.offset, roiOutData);
1103     int32_t* batchesOutData = context->getOutputBuffer<int32_t>(kOutputBatchesTensor);
1104     for (uint32_t i = 0; i < batchesOut.size(); i++) {
1105         batchesOutData[i] = batchesOut[i];
1106     }
1107     return true;
1108 }
1109 
1110 }  // namespace
1111 
prepare(IOperationExecutionContext * context)1112 bool prepare(IOperationExecutionContext* context) {
1113     bool useNchw = context->getInputValue<bool>(kLayoutScalar);
1114     Shape scoreShape = context->getInputShape(kScoreTensor);
1115     Shape bboxDeltasShape = context->getInputShape(kDeltaTensor);
1116     Shape anchorsShape = context->getInputShape(kAnchorTensor);
1117     Shape imageInfoDataShape = context->getInputShape(kImageInfoTensor);
1118     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
1119     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
1120     Shape outputBatchSplitShape = context->getOutputShape(kOutputBatchesTensor);
1121 
1122     NN_RET_CHECK_EQ(getNumberOfDimensions(scoreShape), 4u);
1123     NN_RET_CHECK_EQ(getNumberOfDimensions(bboxDeltasShape), 4u);
1124     NN_RET_CHECK_EQ(getNumberOfDimensions(anchorsShape), 2u);
1125     NN_RET_CHECK_EQ(getNumberOfDimensions(imageInfoDataShape), 2u);
1126 
1127     const uint32_t kRoiDim = 4;
1128     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1129     uint32_t height = getSizeOfDimension(scoreShape, useNchw ? 2 : 1);
1130     uint32_t width = getSizeOfDimension(scoreShape, useNchw ? 3 : 2);
1131     uint32_t numAnchors = getSizeOfDimension(scoreShape, useNchw ? 1 : 3);
1132 
1133     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, 0), numBatches);
1134     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 2 : 1), height);
1135     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 3 : 2), width);
1136     NN_RET_CHECK_EQ(getSizeOfDimension(bboxDeltasShape, useNchw ? 1 : 3), numAnchors * kRoiDim);
1137     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoDataShape, 0), numBatches);
1138     NN_RET_CHECK_EQ(getSizeOfDimension(imageInfoDataShape, 1), 2u);
1139     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 0), numAnchors);
1140     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 1), kRoiDim);
1141 
1142     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
1143         NN_RET_CHECK_EQ(anchorsShape.scale, 0.125f);
1144         NN_RET_CHECK_EQ(imageInfoDataShape.scale, 0.125f);
1145         NN_RET_CHECK_EQ(imageInfoDataShape.offset, 0);
1146     }
1147 
1148     outputScoreShape.type = scoreShape.type;
1149     outputScoreShape.dimensions = {0};
1150     outputScoreShape.scale = scoreShape.scale;
1151     outputScoreShape.offset = scoreShape.offset;
1152     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
1153 
1154     outputRoiShape.dimensions = {0, 4};
1155     if (scoreShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
1156         outputRoiShape.scale = 0.125f;
1157         outputRoiShape.offset = 0;
1158     }
1159     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
1160 
1161     outputBatchSplitShape.dimensions = {0};
1162     NN_RET_CHECK(context->setOutputShape(kOutputBatchesTensor, outputBatchSplitShape));
1163     return true;
1164 }
1165 
execute(IOperationExecutionContext * context)1166 bool execute(IOperationExecutionContext* context) {
1167     NNTRACE_TRANS("generateProposals");
1168     switch (context->getInputType(kScoreTensor)) {
1169         case OperandType::TENSOR_FLOAT16: {
1170             return generateProposalsFloat16(context->getInputBuffer<_Float16>(kScoreTensor),
1171                                             context->getInputShape(kScoreTensor),
1172                                             context->getInputBuffer<_Float16>(kDeltaTensor),
1173                                             context->getInputShape(kDeltaTensor),
1174                                             context->getInputBuffer<_Float16>(kAnchorTensor),
1175                                             context->getInputShape(kAnchorTensor),
1176                                             context->getInputBuffer<_Float16>(kImageInfoTensor),
1177                                             context->getInputShape(kImageInfoTensor),
1178                                             context->getInputValue<_Float16>(kHeightStrideSalar),
1179                                             context->getInputValue<_Float16>(kWidthStrideScalar),
1180                                             context->getInputValue<int32_t>(kPreNmsMaxScalar),
1181                                             context->getInputValue<int32_t>(kPostNmsMaxScalar),
1182                                             context->getInputValue<_Float16>(kIoUThresholdScalar),
1183                                             context->getInputValue<_Float16>(kMinSizeScalar),
1184                                             context->getInputValue<bool>(kLayoutScalar), context);
1185         }
1186         case OperandType::TENSOR_FLOAT32: {
1187             return generateProposalsFloat32(context->getInputBuffer<float>(kScoreTensor),
1188                                             context->getInputShape(kScoreTensor),
1189                                             context->getInputBuffer<float>(kDeltaTensor),
1190                                             context->getInputShape(kDeltaTensor),
1191                                             context->getInputBuffer<float>(kAnchorTensor),
1192                                             context->getInputShape(kAnchorTensor),
1193                                             context->getInputBuffer<float>(kImageInfoTensor),
1194                                             context->getInputShape(kImageInfoTensor),
1195                                             context->getInputValue<float>(kHeightStrideSalar),
1196                                             context->getInputValue<float>(kWidthStrideScalar),
1197                                             context->getInputValue<int32_t>(kPreNmsMaxScalar),
1198                                             context->getInputValue<int32_t>(kPostNmsMaxScalar),
1199                                             context->getInputValue<float>(kIoUThresholdScalar),
1200                                             context->getInputValue<float>(kMinSizeScalar),
1201                                             context->getInputValue<bool>(kLayoutScalar), context);
1202         }
1203         case OperandType::TENSOR_QUANT8_ASYMM: {
1204             return generateProposalsQuant(context->getInputBuffer<uint8_t>(kScoreTensor),
1205                                           context->getInputShape(kScoreTensor),
1206                                           context->getInputBuffer<uint8_t>(kDeltaTensor),
1207                                           context->getInputShape(kDeltaTensor),
1208                                           context->getInputBuffer<int16_t>(kAnchorTensor),
1209                                           context->getInputShape(kAnchorTensor),
1210                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
1211                                           context->getInputShape(kImageInfoTensor),
1212                                           context->getInputValue<float>(kHeightStrideSalar),
1213                                           context->getInputValue<float>(kWidthStrideScalar),
1214                                           context->getInputValue<int32_t>(kPreNmsMaxScalar),
1215                                           context->getInputValue<int32_t>(kPostNmsMaxScalar),
1216                                           context->getInputValue<float>(kIoUThresholdScalar),
1217                                           context->getInputValue<float>(kMinSizeScalar),
1218                                           context->getInputValue<bool>(kLayoutScalar), context);
1219         }
1220         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED: {
1221             return generateProposalsQuant(context->getInputBuffer<int8_t>(kScoreTensor),
1222                                           context->getInputShape(kScoreTensor),
1223                                           context->getInputBuffer<int8_t>(kDeltaTensor),
1224                                           context->getInputShape(kDeltaTensor),
1225                                           context->getInputBuffer<int16_t>(kAnchorTensor),
1226                                           context->getInputShape(kAnchorTensor),
1227                                           context->getInputBuffer<uint16_t>(kImageInfoTensor),
1228                                           context->getInputShape(kImageInfoTensor),
1229                                           context->getInputValue<float>(kHeightStrideSalar),
1230                                           context->getInputValue<float>(kWidthStrideScalar),
1231                                           context->getInputValue<int32_t>(kPreNmsMaxScalar),
1232                                           context->getInputValue<int32_t>(kPostNmsMaxScalar),
1233                                           context->getInputValue<float>(kIoUThresholdScalar),
1234                                           context->getInputValue<float>(kMinSizeScalar),
1235                                           context->getInputValue<bool>(kLayoutScalar), context);
1236         }
1237         default:
1238             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1239     }
1240 }
1241 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
1242 
1243 }  // namespace generate_proposals
1244 
1245 namespace detection_postprocess {
1246 
1247 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
1248 namespace {
1249 
detectionPostprocessFloat32(const float * scoreData,const Shape & scoreShape,const float * deltaData,const Shape & deltaShape,const float * anchorData,const Shape &,float scaleY,float scaleX,float scaleH,float scaleW,bool useRegularNms,int32_t maxNumDetections,int32_t maxClassesPerDetection,int32_t maxNumDetectionsPerClass,float iouThreshold,float scoreThreshold,bool isBGInLabel,float * scoreOutData,const Shape & scoreOutShape,float * roiOutData,const Shape & roiOutShape,int32_t * classOutData,const Shape & classOutShape,int32_t * detectionOutData,const Shape & detectionOutShape)1250 bool detectionPostprocessFloat32(const float* scoreData, const Shape& scoreShape,
1251                                  const float* deltaData, const Shape& deltaShape,
1252                                  const float* anchorData, const Shape& /*anchorShape*/,
1253                                  float scaleY, float scaleX, float scaleH, float scaleW,
1254                                  bool useRegularNms, int32_t maxNumDetections,
1255                                  int32_t maxClassesPerDetection, int32_t maxNumDetectionsPerClass,
1256                                  float iouThreshold, float scoreThreshold, bool isBGInLabel,
1257                                  float* scoreOutData, const Shape& scoreOutShape, float* roiOutData,
1258                                  const Shape& roiOutShape, int32_t* classOutData,
1259                                  const Shape& classOutShape, int32_t* detectionOutData,
1260                                  const Shape& detectionOutShape) {
1261     const uint32_t kRoiDim = 4;
1262     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1263     uint32_t numAnchors = getSizeOfDimension(scoreShape, 1);
1264     uint32_t numClasses = getSizeOfDimension(scoreShape, 2);
1265     uint32_t lengthBoxEncoding = getSizeOfDimension(deltaShape, 2);
1266     uint32_t numOutDetection = getSizeOfDimension(scoreOutShape, 1);
1267 
1268     memset(scoreOutData, 0, getNumberOfElements(scoreOutShape) * sizeof(float));
1269     memset(roiOutData, 0, getNumberOfElements(roiOutShape) * sizeof(float));
1270     memset(classOutData, 0, getNumberOfElements(classOutShape) * sizeof(int32_t));
1271     memset(detectionOutData, 0, getNumberOfElements(detectionOutShape) * sizeof(int32_t));
1272 
1273     const float* scoreBase = scoreData;
1274     const float* deltaBase = deltaData;
1275     float* scoreOutBase = scoreOutData;
1276     float* roiOutBase = roiOutData;
1277     int32_t* classOutBase = classOutData;
1278     std::vector<float> roiBuffer(numAnchors * kRoiDim);
1279     std::vector<float> scoreBuffer(numAnchors);
1280     for (uint32_t b = 0; b < numBatches; b++) {
1281         const float* anchorBase = anchorData;
1282         for (uint32_t a = 0; a < numAnchors; a++) {
1283             float yCtr = anchorBase[0] + anchorBase[2] * deltaBase[0] / scaleY;
1284             float xCtr = anchorBase[1] + anchorBase[3] * deltaBase[1] / scaleX;
1285             float hHalf = anchorBase[2] * std::exp(deltaBase[2] / scaleH) * 0.5f;
1286             float wHalf = anchorBase[3] * std::exp(deltaBase[3] / scaleW) * 0.5f;
1287             roiBuffer[a * kRoiDim] = yCtr - hHalf;
1288             roiBuffer[a * kRoiDim + 1] = xCtr - wHalf;
1289             roiBuffer[a * kRoiDim + 2] = yCtr + hHalf;
1290             roiBuffer[a * kRoiDim + 3] = xCtr + wHalf;
1291             anchorBase += kRoiDim;
1292             deltaBase += lengthBoxEncoding;
1293         }
1294 
1295         if (useRegularNms) {
1296             std::vector<uint32_t> select;
1297             box_with_nms_limit::hardNmsMultiClass(
1298                     scoreBase, numClasses, numAnchors, scoreThreshold, iouThreshold,
1299                     maxNumDetections, maxNumDetectionsPerClass,
1300                     [&roiBuffer, numClasses](uint32_t ind) {
1301                         return roiBuffer.data() + (ind / numClasses) * kRoiDim;
1302                     },
1303                     &select);
1304             for (uint32_t i = 0; i < select.size(); i++) {
1305                 uint32_t ind = select[i];
1306                 scoreOutBase[i] = scoreBase[ind];
1307                 memcpy(roiOutBase + i * kRoiDim, &roiBuffer[(ind / numClasses) * kRoiDim],
1308                        kRoiDim * sizeof(float));
1309                 classOutBase[i] = (ind % numClasses) - (isBGInLabel ? 0 : 1);
1310             }
1311             *detectionOutData++ = select.size();
1312         } else {
1313             uint32_t numOutClasses = std::min<uint32_t>(numClasses - 1, maxClassesPerDetection);
1314             std::vector<float> maxScores(numAnchors);
1315             for (uint32_t a = 0; a < numAnchors; a++) {
1316                 maxScores[a] = *std::max_element(scoreBase + a * numClasses + 1,
1317                                                  scoreBase + (a + 1) * numClasses);
1318             }
1319             std::vector<uint32_t> select;
1320             for (uint32_t a = 0; a < numAnchors; a++) {
1321                 if (maxScores[a] > scoreThreshold) {
1322                     select.push_back(a);
1323                 }
1324             }
1325             uint32_t* selectEnd = box_with_nms_limit::hardNmsSingleClass(
1326                     maxScores.data(), iouThreshold, maxNumDetections,
1327                     [&roiBuffer](uint32_t ind) { return roiBuffer.data() + ind * kRoiDim; },
1328                     select.data(), select.size());
1329             select.resize(selectEnd - select.data());
1330             float* scoreOutPtr = scoreOutBase;
1331             float* roiOutPtr = roiOutBase;
1332             int32_t* classOutPtr = classOutBase;
1333             for (auto i : select) {
1334                 const float* score = scoreBase + i * numClasses;
1335                 std::vector<uint32_t> scoreInds(numClasses - 1);
1336                 std::iota(scoreInds.begin(), scoreInds.end(), 1);
1337                 std::sort(scoreInds.begin(), scoreInds.end(),
1338                           [&score](const uint32_t lhs, const uint32_t rhs) {
1339                               return score[lhs] > score[rhs];
1340                           });
1341                 for (uint32_t c = 0; c < numOutClasses; c++) {
1342                     *scoreOutPtr++ = score[scoreInds[c]];
1343                     memcpy(roiOutPtr, &roiBuffer[i * kRoiDim], kRoiDim * sizeof(float));
1344                     roiOutPtr += kRoiDim;
1345                     *classOutPtr++ = scoreInds[c] - (isBGInLabel ? 0 : 1);
1346                 }
1347             }
1348             *detectionOutData++ = select.size() * numOutClasses;
1349         }
1350         scoreBase += numAnchors * numClasses;
1351         scoreOutBase += numOutDetection;
1352         roiOutBase += numOutDetection * kRoiDim;
1353         classOutBase += numOutDetection;
1354     }
1355     return true;
1356 }
1357 
detectionPostprocessFloat16(const _Float16 * scoreData,const Shape & scoreShape,const _Float16 * deltaData,const Shape & deltaShape,const _Float16 * anchorData,const Shape & anchorShape,float scaleY,float scaleX,float scaleH,float scaleW,bool useRegularNms,int32_t maxNumDetections,int32_t maxClassesPerDetection,int32_t maxNumDetectionsPerClass,float iouThreshold,float scoreThreshold,bool isBGInLabel,_Float16 * scoreOutData,const Shape & scoreOutShape,_Float16 * roiOutData,const Shape & roiOutShape,int32_t * classOutData,const Shape & classOutShape,int32_t * detectionOutData,const Shape & detectionOutShape)1358 bool detectionPostprocessFloat16(
1359         const _Float16* scoreData, const Shape& scoreShape, const _Float16* deltaData,
1360         const Shape& deltaShape, const _Float16* anchorData, const Shape& anchorShape, float scaleY,
1361         float scaleX, float scaleH, float scaleW, bool useRegularNms, int32_t maxNumDetections,
1362         int32_t maxClassesPerDetection, int32_t maxNumDetectionsPerClass, float iouThreshold,
1363         float scoreThreshold, bool isBGInLabel, _Float16* scoreOutData, const Shape& scoreOutShape,
1364         _Float16* roiOutData, const Shape& roiOutShape, int32_t* classOutData,
1365         const Shape& classOutShape, int32_t* detectionOutData, const Shape& detectionOutShape) {
1366     std::vector<float> scores_float32(getNumberOfElements(scoreShape));
1367     convertFloat16ToFloat32(scoreData, &scores_float32);
1368     std::vector<float> delta_float32(getNumberOfElements(deltaShape));
1369     convertFloat16ToFloat32(deltaData, &delta_float32);
1370     std::vector<float> anchor_float32(getNumberOfElements(anchorShape));
1371     convertFloat16ToFloat32(anchorData, &anchor_float32);
1372     std::vector<float> outputScore_float32(getNumberOfElements(scoreOutShape));
1373     std::vector<float> outputRoi_float32(getNumberOfElements(roiOutShape));
1374     NN_RET_CHECK(detectionPostprocessFloat32(
1375             scores_float32.data(), scoreShape, delta_float32.data(), deltaShape,
1376             anchor_float32.data(), anchorShape, scaleY, scaleX, scaleH, scaleW, useRegularNms,
1377             maxNumDetections, maxClassesPerDetection, maxNumDetectionsPerClass, iouThreshold,
1378             scoreThreshold, isBGInLabel, outputScore_float32.data(), scoreOutShape,
1379             outputRoi_float32.data(), roiOutShape, classOutData, classOutShape, detectionOutData,
1380             detectionOutShape));
1381     convertFloat32ToFloat16(outputScore_float32, scoreOutData);
1382     convertFloat32ToFloat16(outputRoi_float32, roiOutData);
1383     return true;
1384 }
1385 
1386 }  // namespace
1387 
prepare(IOperationExecutionContext * context)1388 bool prepare(IOperationExecutionContext* context) {
1389     Shape scoreShape = context->getInputShape(kScoreTensor);
1390     Shape deltasShape = context->getInputShape(kDeltaTensor);
1391     Shape anchorsShape = context->getInputShape(kAnchorTensor);
1392     Shape outputScoreShape = context->getOutputShape(kOutputScoreTensor);
1393     Shape outputRoiShape = context->getOutputShape(kOutputRoiTensor);
1394     Shape outputClassShape = context->getOutputShape(kOutputClassTensor);
1395     Shape outputDetectionShape = context->getOutputShape(kOutputDetectionTensor);
1396 
1397     NN_RET_CHECK_EQ(getNumberOfDimensions(scoreShape), 3u);
1398     NN_RET_CHECK_EQ(getNumberOfDimensions(deltasShape), 3u);
1399     NN_RET_CHECK_EQ(getNumberOfDimensions(anchorsShape), 2u);
1400 
1401     const uint32_t kRoiDim = 4;
1402     uint32_t numBatches = getSizeOfDimension(scoreShape, 0);
1403     uint32_t numAnchors = getSizeOfDimension(scoreShape, 1);
1404     uint32_t numClasses = getSizeOfDimension(scoreShape, 2);
1405     uint32_t lengthBoxEncoding = getSizeOfDimension(deltasShape, 2);
1406     uint32_t maxNumDetections = context->getInputValue<int32_t>(kMaxNumDetectionScalar);
1407     uint32_t maxClassesPerDetection =
1408             context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar);
1409     uint32_t numOutDetections = maxNumDetections;
1410 
1411     NN_RET_CHECK_EQ(getSizeOfDimension(deltasShape, 0), numBatches);
1412     NN_RET_CHECK_EQ(getSizeOfDimension(deltasShape, 1), numAnchors);
1413     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 0), numAnchors);
1414     NN_RET_CHECK_EQ(getSizeOfDimension(anchorsShape, 1), kRoiDim);
1415 
1416     if (scoreShape.type == OperandType::TENSOR_FLOAT32) {
1417         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleYScalar), 0);
1418         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleXScalar), 0);
1419         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleHScalar), 0);
1420         NN_RET_CHECK_GT(context->getInputValue<float>(kScaleWScalar), 0);
1421         NN_RET_CHECK_GE(context->getInputValue<float>(kScoreThresholdScalar), 0);
1422         NN_RET_CHECK_GE(context->getInputValue<float>(kIoUThresholdScalar), 0);
1423     } else if (scoreShape.type == OperandType::TENSOR_FLOAT16) {
1424         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleYScalar) > 0);
1425         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleXScalar) > 0);
1426         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleHScalar) > 0);
1427         NN_RET_CHECK(context->getInputValue<_Float16>(kScaleWScalar) > 0);
1428         NN_RET_CHECK(context->getInputValue<_Float16>(kScoreThresholdScalar) >= 0);
1429         NN_RET_CHECK(context->getInputValue<_Float16>(kIoUThresholdScalar) >= 0);
1430     } else {
1431         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1432     }
1433     NN_RET_CHECK_GT(numClasses, 1u);
1434     NN_RET_CHECK_GE(lengthBoxEncoding, 4u);
1435     NN_RET_CHECK_GT(maxNumDetections, 0u);
1436     if (context->getInputValue<bool>(kUseRegularNmsScalar)) {
1437         NN_RET_CHECK_GT(context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar), 0);
1438     } else {
1439         NN_RET_CHECK_GT(maxClassesPerDetection, 0u);
1440         numOutDetections *= maxClassesPerDetection;
1441     }
1442 
1443     outputScoreShape.type = scoreShape.type;
1444     outputScoreShape.dimensions = {numBatches, numOutDetections};
1445     NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScoreShape));
1446 
1447     outputRoiShape.type = anchorsShape.type;
1448     outputRoiShape.dimensions = {numBatches, numOutDetections, 4};
1449     NN_RET_CHECK(context->setOutputShape(kOutputRoiTensor, outputRoiShape));
1450 
1451     outputClassShape.type = OperandType::TENSOR_INT32;
1452     outputClassShape.dimensions = {numBatches, numOutDetections};
1453     NN_RET_CHECK(context->setOutputShape(kOutputClassTensor, outputClassShape));
1454 
1455     outputDetectionShape.type = OperandType::TENSOR_INT32;
1456     outputDetectionShape.dimensions = {numBatches};
1457     NN_RET_CHECK(context->setOutputShape(kOutputDetectionTensor, outputDetectionShape));
1458     return true;
1459 }
1460 
execute(IOperationExecutionContext * context)1461 bool execute(IOperationExecutionContext* context) {
1462     NNTRACE_TRANS("detectionPostProcess");
1463     switch (context->getInputType(kScoreTensor)) {
1464         case OperandType::TENSOR_FLOAT16: {
1465             return detectionPostprocessFloat16(
1466                     context->getInputBuffer<_Float16>(kScoreTensor),
1467                     context->getInputShape(kScoreTensor),
1468                     context->getInputBuffer<_Float16>(kDeltaTensor),
1469                     context->getInputShape(kDeltaTensor),
1470                     context->getInputBuffer<_Float16>(kAnchorTensor),
1471                     context->getInputShape(kAnchorTensor),
1472                     context->getInputValue<_Float16>(kScaleYScalar),
1473                     context->getInputValue<_Float16>(kScaleXScalar),
1474                     context->getInputValue<_Float16>(kScaleHScalar),
1475                     context->getInputValue<_Float16>(kScaleWScalar),
1476                     context->getInputValue<bool>(kUseRegularNmsScalar),
1477                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
1478                     context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar),
1479                     context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar),
1480                     context->getInputValue<_Float16>(kIoUThresholdScalar),
1481                     context->getInputValue<_Float16>(kScoreThresholdScalar),
1482                     context->getInputValue<bool>(kIsBGInLabelScalar),
1483                     context->getOutputBuffer<_Float16>(kOutputScoreTensor),
1484                     context->getOutputShape(kOutputScoreTensor),
1485                     context->getOutputBuffer<_Float16>(kOutputRoiTensor),
1486                     context->getOutputShape(kOutputRoiTensor),
1487                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
1488                     context->getOutputShape(kOutputClassTensor),
1489                     context->getOutputBuffer<int32_t>(kOutputDetectionTensor),
1490                     context->getOutputShape(kOutputDetectionTensor));
1491         }
1492         case OperandType::TENSOR_FLOAT32: {
1493             return detectionPostprocessFloat32(
1494                     context->getInputBuffer<float>(kScoreTensor),
1495                     context->getInputShape(kScoreTensor),
1496                     context->getInputBuffer<float>(kDeltaTensor),
1497                     context->getInputShape(kDeltaTensor),
1498                     context->getInputBuffer<float>(kAnchorTensor),
1499                     context->getInputShape(kAnchorTensor),
1500                     context->getInputValue<float>(kScaleYScalar),
1501                     context->getInputValue<float>(kScaleXScalar),
1502                     context->getInputValue<float>(kScaleHScalar),
1503                     context->getInputValue<float>(kScaleWScalar),
1504                     context->getInputValue<bool>(kUseRegularNmsScalar),
1505                     context->getInputValue<int32_t>(kMaxNumDetectionScalar),
1506                     context->getInputValue<int32_t>(kMaxClassesPerDetectionScalar),
1507                     context->getInputValue<int32_t>(kMaxNumDetectionPerClassScalar),
1508                     context->getInputValue<float>(kIoUThresholdScalar),
1509                     context->getInputValue<float>(kScoreThresholdScalar),
1510                     context->getInputValue<bool>(kIsBGInLabelScalar),
1511                     context->getOutputBuffer<float>(kOutputScoreTensor),
1512                     context->getOutputShape(kOutputScoreTensor),
1513                     context->getOutputBuffer<float>(kOutputRoiTensor),
1514                     context->getOutputShape(kOutputRoiTensor),
1515                     context->getOutputBuffer<int32_t>(kOutputClassTensor),
1516                     context->getOutputShape(kOutputClassTensor),
1517                     context->getOutputBuffer<int32_t>(kOutputDetectionTensor),
1518                     context->getOutputShape(kOutputDetectionTensor));
1519         }
1520         default:
1521             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
1522     }
1523 }
1524 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
1525 
1526 }  // namespace detection_postprocess
1527 
1528 }  // namespace bbox_ops
1529 
// Registers AXIS_ALIGNED_BBOX_TRANSFORM, wiring it to the prepare/execute pair
// from bbox_ops::axis_aligned_bbox_transform and passing
// .allowZeroSizedInput = true.
// NOTE(review): the macro is defined outside this file; the flag presumably
// lets the operation accept zero-sized input tensors — confirm against the
// macro's definition.
1530 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(AXIS_ALIGNED_BBOX_TRANSFORM,
1531                                          bbox_ops::axis_aligned_bbox_transform::prepare,
1532                                          bbox_ops::axis_aligned_bbox_transform::execute,
1533                                          .allowZeroSizedInput = true);
1534 
// Registers BOX_WITH_NMS_LIMIT, wiring it to the prepare/execute pair from
// bbox_ops::box_with_nms_limit and passing .allowZeroSizedInput = true.
// NOTE(review): the flag presumably lets the operation accept zero-sized input
// tensors — confirm against the macro's definition (not in this file).
1535 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BOX_WITH_NMS_LIMIT, bbox_ops::box_with_nms_limit::prepare,
1536                                          bbox_ops::box_with_nms_limit::execute,
1537                                          .allowZeroSizedInput = true);
1538 
// Registers GENERATE_PROPOSALS, wiring it to the prepare/execute pair from
// bbox_ops::generate_proposals. No extra registration options are passed here
// (unlike the two registrations that set .allowZeroSizedInput).
1539 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(GENERATE_PROPOSALS, bbox_ops::generate_proposals::prepare,
1540                                          bbox_ops::generate_proposals::execute);
1541 
// Registers DETECTION_POSTPROCESSING, wiring it to the prepare/execute pair
// from bbox_ops::detection_postprocess. No extra registration options are
// passed here.
1542 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(DETECTION_POSTPROCESSING,
1543                                          bbox_ops::detection_postprocess::prepare,
1544                                          bbox_ops::detection_postprocess::execute);
1545 }  // namespace nn
1546 }  // namespace android
1547