/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

#include <tensorflow/lite/kernels/internal/common.h>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <type_traits>
#include <vector>
30 
31 namespace android {
32 namespace nn {
33 namespace roi_align {
34 
// Name used when registering the operation and when reporting unsupported types.
constexpr char kOperationName[] = "ROI_ALIGN";

// Input operands. The indices below are positions in the operation's input list.
constexpr uint32_t kNumInputs = 10;
// Feature map tensor (float32, float16, or 8-bit quantized).
constexpr uint32_t kInputTensor = 0;
// ROI tensor: one row of four coordinates [x1, y1, x2, y2] per region.
constexpr uint32_t kRoiTensor = 1;
// int32 tensor holding, for every ROI, the index of the batch it is taken from.
constexpr uint32_t kBatchSplitTensor = 2;
// int32 scalars giving the output height and width.
constexpr uint32_t kOutputHeightScalar = 3;
constexpr uint32_t kOutputWidthScalar = 4;
// Floating-point scalars giving the height and width strides; ROI coordinates are divided
// by them before sampling.
// NOTE: "Salar" is a historical misspelling of "Scalar"; the name is kept because the rest
// of the file refers to it.
constexpr uint32_t kHeightStrideSalar = 5;
constexpr uint32_t kWidthStrideScalar = 6;
// int32 scalars giving the number of sampling points per output cell along each axis;
// 0 selects an adaptive value.
constexpr uint32_t kHeightSamplingRatioScalar = 7;
constexpr uint32_t kWidthSamplingRatioScalar = 8;
// bool scalar: true if the input/output tensors use NCHW layout, false for NHWC.
constexpr uint32_t kLayoutScalar = 9;

// Output operands.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
51 
52 namespace {
53 
54 using namespace hal;
55 
56 template <typename T_Input, typename T_Roi>
roiAlignNhwc(const T_Input * inputData,const Shape & inputShape,const T_Roi * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,int32_t heightSamplingRatio,int32_t widthSamplingRatio,T_Input * outputData,const Shape & outputShape)57 inline bool roiAlignNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
58                          const Shape& roiShape, const int32_t* batchSplitData,
59                          const Shape& batchSplitShape, float heightStride, float widthStride,
60                          int32_t heightSamplingRatio, int32_t widthSamplingRatio,
61                          T_Input* outputData, const Shape& outputShape) {
62     NNTRACE_TRANS("RoiAlign");
63 
64     const uint32_t kRoiDim = 4;
65     const T_Roi heightScale = 1.0f / heightStride;
66     const T_Roi widthScale = 1.0f / widthStride;
67 
68     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
69     uint32_t inHeight = getSizeOfDimension(inputShape, 1);
70     uint32_t inWidth = getSizeOfDimension(inputShape, 2);
71     uint32_t inDepth = getSizeOfDimension(inputShape, 3);
72     uint32_t outHeight = getSizeOfDimension(outputShape, 1);
73     uint32_t outWidth = getSizeOfDimension(outputShape, 2);
74     uint32_t numRois = getSizeOfDimension(roiShape, 0);
75     uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
76 
77     T_Input* outPtr = outputData;
78     const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
79     uint32_t roiIndex = 0;
80     for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
81         uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
82         // Check for malformed data
83         // 1. invalid batch id
84         // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
85         // 3. Invalid region: x2 < x1 || y2 < y1
86         NN_RET_CHECK_GE(batchId, 0);
87         NN_RET_CHECK_LT(batchId, numBatches);
88         NN_RET_CHECK(roiInfo[0] >= 0);
89         NN_RET_CHECK(roiInfo[1] >= 0);
90         NN_RET_CHECK(roiInfo[2] >= 0);
91         NN_RET_CHECK(roiInfo[3] >= 0);
92         NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
93         NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
94         NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
95         NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
96         NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
97         NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);
98 
99         T_Roi wRoiStart = roiInfo[0] * widthScale;
100         T_Roi hRoiStart = roiInfo[1] * heightScale;
101         T_Roi wRoiEnd = roiInfo[2] * widthScale;
102         T_Roi hRoiEnd = roiInfo[3] * heightScale;
103 
104         T_Roi roiWidth = std::max(static_cast<float>(wRoiEnd - wRoiStart), 1.0f);
105         T_Roi roiHeight = std::max(static_cast<float>(hRoiEnd - hRoiStart), 1.0f);
106         T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
107         T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);
108 
109         // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
110         uint32_t wSamplingRatio = widthSamplingRatio > 0 ? widthSamplingRatio
111                                                          : std::ceil(static_cast<float>(wStepSize));
112         uint32_t hSamplingRatio = heightSamplingRatio > 0
113                                           ? heightSamplingRatio
114                                           : std::ceil(static_cast<float>(hStepSize));
115         int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
116         T_Roi wBinSize = wStepSize / static_cast<T_Roi>(wSamplingRatio);
117         T_Roi hBinSize = hStepSize / static_cast<T_Roi>(hSamplingRatio);
118 
119         const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
120         for (uint32_t i = 0; i < outHeight; i++) {
121             for (uint32_t j = 0; j < outWidth; j++) {
122                 T_Roi wStart = wStepSize * j + wRoiStart;
123                 T_Roi wEnd = wStepSize * (j + 1) + wRoiStart;
124                 T_Roi hStart = hStepSize * i + hRoiStart;
125                 T_Roi hEnd = hStepSize * (i + 1) + hRoiStart;
126 
127                 // initialize output to zero
128                 for (uint32_t k = 0; k < inDepth; k++) outPtr[k] = 0;
129 
130                 // calculate the sum of the sampling points
131                 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
132                     for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
133                         T_Roi y = hStart + hBinSize / 2 + hBinSize * yInd;
134                         T_Roi x = wStart + wBinSize / 2 + wBinSize * xInd;
135 
136                         // bilinear interpolation of point (x,y)
137                         // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
138                         uint32_t x1 = std::floor(static_cast<float>(x));
139                         uint32_t y1 = std::floor(static_cast<float>(y));
140                         uint32_t x2 = x1 + 1, y2 = y1 + 1;
141                         T_Roi dx1 = x - static_cast<T_Roi>(x1);
142                         T_Roi dy1 = y - static_cast<T_Roi>(y1);
143 
144                         // dealing with out of bound samples
145                         if (x1 >= inWidth - 1) {
146                             x1 = x2 = inWidth - 1;
147                             dx1 = 0;
148                         }
149                         if (y1 >= inHeight - 1) {
150                             y1 = y2 = inHeight - 1;
151                             dy1 = 0;
152                         }
153 
154                         T_Roi dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
155                         T_Roi ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
156                         uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
157                                               y1 * inWidth * inDepth + x2 * inDepth,
158                                               y2 * inWidth * inDepth + x1 * inDepth,
159                                               y2 * inWidth * inDepth + x2 * inDepth};
160 
161                         for (uint32_t k = 0; k < inDepth; k++) {
162                             T_Input interpolation = 0;
163                             for (uint32_t c = 0; c < 4; c++) {
164                                 interpolation += ws[c] * batchBase[offsets[c] + k];
165                             }
166                             outPtr[k] += interpolation;
167                         }
168                     }
169                 }
170 
171                 // take average
172                 for (uint32_t k = 0; k < inDepth; k++)
173                     outPtr[k] /= static_cast<T_Input>(numSamplingPoints);
174                 outPtr += inDepth;
175             }
176         }
177     }
178     return true;
179 }
180 
181 template <typename T_Input>
roiAlignQuantNhwc(const T_Input * inputData,const Shape & inputShape,const uint16_t * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,int32_t heightSamplingRatio,int32_t widthSamplingRatio,T_Input * outputData,const Shape & outputShape)182 inline bool roiAlignQuantNhwc(const T_Input* inputData, const Shape& inputShape,
183                               const uint16_t* roiData, const Shape& roiShape,
184                               const int32_t* batchSplitData, const Shape& batchSplitShape,
185                               float heightStride, float widthStride, int32_t heightSamplingRatio,
186                               int32_t widthSamplingRatio, T_Input* outputData,
187                               const Shape& outputShape) {
188     NNTRACE_TRANS("RoiAlignQuant8");
189 
190     constexpr float wScale = 1.0f / 255.0f;
191     constexpr uint32_t kRoiDim = 4;
192     const float heightScale = 1.0f / heightStride;
193     const float widthScale = 1.0f / widthStride;
194 
195     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
196     uint32_t inHeight = getSizeOfDimension(inputShape, 1);
197     uint32_t inWidth = getSizeOfDimension(inputShape, 2);
198     uint32_t inDepth = getSizeOfDimension(inputShape, 3);
199     uint32_t outHeight = getSizeOfDimension(outputShape, 1);
200     uint32_t outWidth = getSizeOfDimension(outputShape, 2);
201     uint32_t numRois = getSizeOfDimension(roiShape, 0);
202     uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
203 
204     T_Input* outPtr = outputData;
205     const uint16_t* roiDataEnd = roiData + numRois * roiInfoLength;
206     uint32_t roiIndex = 0;
207     for (const uint16_t* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
208         uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
209         float wRoiStart = static_cast<float>(roiInfo[0]) * widthScale * 0.125f;
210         float hRoiStart = static_cast<float>(roiInfo[1]) * heightScale * 0.125f;
211         float wRoiEnd = static_cast<float>(roiInfo[2]) * widthScale * 0.125f;
212         float hRoiEnd = static_cast<float>(roiInfo[3]) * heightScale * 0.125f;
213 
214         // Check for malformed data
215         // 1. invalid batch id
216         // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
217         // 3. Invalid region: x2 < x1 || y2 < y1
218         NN_RET_CHECK_GE(batchId, 0);
219         NN_RET_CHECK_LT(batchId, numBatches);
220         NN_RET_CHECK(wRoiStart <= inWidth);
221         NN_RET_CHECK(hRoiStart <= inHeight);
222         NN_RET_CHECK(wRoiEnd <= inWidth);
223         NN_RET_CHECK(hRoiEnd <= inHeight);
224         NN_RET_CHECK_LE(wRoiStart, wRoiEnd);
225         NN_RET_CHECK_LE(hRoiStart, hRoiEnd);
226 
227         float roiWidth = std::max(wRoiEnd - wRoiStart, 1.0f);
228         float roiHeight = std::max(hRoiEnd - hRoiStart, 1.0f);
229         float wStepSize = roiWidth / static_cast<float>(outWidth);
230         float hStepSize = roiHeight / static_cast<float>(outHeight);
231 
232         // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
233         uint32_t wSamplingRatio =
234                 widthSamplingRatio > 0 ? widthSamplingRatio : std::ceil(wStepSize);
235         uint32_t hSamplingRatio =
236                 heightSamplingRatio > 0 ? heightSamplingRatio : std::ceil(hStepSize);
237         int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
238         float wBinSize = wStepSize / static_cast<float>(wSamplingRatio);
239         float hBinSize = hStepSize / static_cast<float>(hSamplingRatio);
240 
241         float realMultiplier = inputShape.scale * wScale / outputShape.scale / numSamplingPoints;
242         int32_t outputMultiplier = 0;
243         int32_t outputShift = 0;
244         if (!QuantizeMultiplierSmallerThanOne(realMultiplier, &outputMultiplier, &outputShift)) {
245             return false;
246         }
247 
248         const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
249         for (uint32_t i = 0; i < outHeight; i++) {
250             for (uint32_t j = 0; j < outWidth; j++) {
251                 float wStart = wStepSize * j + wRoiStart;
252                 float wEnd = wStepSize * (j + 1) + wRoiStart;
253                 float hStart = hStepSize * i + hRoiStart;
254                 float hEnd = hStepSize * (i + 1) + hRoiStart;
255 
256                 std::vector<int32_t> outTemp(inDepth, 0);
257                 // calculate the sum of the sampling points
258                 for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
259                     for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
260                         float y = hStart + hBinSize / 2 + hBinSize * yInd;
261                         float x = wStart + wBinSize / 2 + wBinSize * xInd;
262 
263                         // bilinear interpolation of point (x,y)
264                         // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
265                         uint32_t x1 = std::floor(x), y1 = std::floor(y);
266                         uint32_t x2 = x1 + 1, y2 = y1 + 1;
267                         float dx1 = x - static_cast<float>(x1);
268                         float dy1 = y - static_cast<float>(y1);
269 
270                         // dealing with out of bound samples
271                         if (x1 >= inWidth - 1) {
272                             x1 = x2 = inWidth - 1;
273                             dx1 = 0;
274                         }
275                         if (y1 >= inHeight - 1) {
276                             y1 = y2 = inHeight - 1;
277                             dy1 = 0;
278                         }
279 
280                         float dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
281                         float ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
282                         uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
283                                               y1 * inWidth * inDepth + x2 * inDepth,
284                                               y2 * inWidth * inDepth + x1 * inDepth,
285                                               y2 * inWidth * inDepth + x2 * inDepth};
286 
287                         for (uint32_t k = 0; k < inDepth; k++) {
288                             int32_t interpolation = 0;
289                             for (uint32_t c = 0; c < 4; c++) {
290                                 int32_t wQuant = static_cast<int32_t>(std::round(ws[c] / wScale));
291                                 interpolation +=
292                                         wQuant * (static_cast<int32_t>(batchBase[offsets[c] + k]) -
293                                                   inputShape.offset);
294                             }
295                             outTemp[k] += interpolation;
296                         }
297                     }
298                 }
299 
300                 // take average and cast to output quantization
301                 for (uint32_t k = 0; k < inDepth; k++) {
302                     int32_t raw_out = tflite::MultiplyByQuantizedMultiplier(
303                                               outTemp[k], outputMultiplier, -outputShift) +
304                                       outputShape.offset;
305                     outPtr[k] = saturateCast<T_Input>(raw_out);
306                 }
307                 outPtr += inDepth;
308             }
309         }
310     }
311     return true;
312 }
313 
314 template <typename T_Input, typename T_Roi>
roiAlign(const T_Input * inputData,const Shape & inputShape,const T_Roi * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,int32_t heightSamplingRatio,int32_t widthSamplingRatio,bool useNchw,T_Input * outputData,const Shape & outputShape)315 inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
316                      const Shape& roiShape, const int32_t* batchSplitData,
317                      const Shape& batchSplitShape, float heightStride, float widthStride,
318                      int32_t heightSamplingRatio, int32_t widthSamplingRatio, bool useNchw,
319                      T_Input* outputData, const Shape& outputShape) {
320     InputWithLayout<T_Input> input(useNchw);
321     OutputWithLayout<T_Input> output(useNchw);
322     NN_RET_CHECK(input.initialize(inputData, inputShape));
323     NN_RET_CHECK(output.initialize(outputData, outputShape));
324     if constexpr (std::is_same_v<T_Roi, uint16_t> &&
325                   (std::is_same_v<T_Input, uint8_t> || std::is_same_v<T_Input, int8_t>)) {
326         NN_RET_CHECK(roiAlignQuantNhwc<T_Input>(
327                 input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape, batchSplitData,
328                 batchSplitShape, heightStride, widthStride, heightSamplingRatio, widthSamplingRatio,
329                 output.getNhwcBuffer(), output.getNhwcShape()));
330     } else {
331         NN_RET_CHECK(roiAlignNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
332                                   batchSplitData, batchSplitShape, heightStride, widthStride,
333                                   heightSamplingRatio, widthSamplingRatio, output.getNhwcBuffer(),
334                                   output.getNhwcShape()));
335     }
336     NN_RET_CHECK(output.commit());
337     return true;
338 }
339 
340 }  // namespace
341 
validate(const IOperationValidationContext * context)342 bool validate(const IOperationValidationContext* context) {
343     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
344     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
345     std::vector<OperandType> inExpectedTypes;
346     auto inputType = context->getInputType(kInputTensor);
347     if (inputType == OperandType::TENSOR_FLOAT32) {
348         inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
349                            OperandType::TENSOR_INT32,   OperandType::INT32,
350                            OperandType::INT32,          OperandType::FLOAT32,
351                            OperandType::FLOAT32,        OperandType::INT32,
352                            OperandType::INT32,          OperandType::BOOL};
353     } else if (inputType == OperandType::TENSOR_FLOAT16) {
354         inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
355                            OperandType::TENSOR_INT32,   OperandType::INT32,
356                            OperandType::INT32,          OperandType::FLOAT16,
357                            OperandType::FLOAT16,        OperandType::INT32,
358                            OperandType::INT32,          OperandType::BOOL};
359     } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
360                inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
361         inExpectedTypes = {inputType,
362                            OperandType::TENSOR_QUANT16_ASYMM,
363                            OperandType::TENSOR_INT32,
364                            OperandType::INT32,
365                            OperandType::INT32,
366                            OperandType::FLOAT32,
367                            OperandType::FLOAT32,
368                            OperandType::INT32,
369                            OperandType::INT32,
370                            OperandType::BOOL};
371     } else {
372         LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
373         return false;
374     }
375     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
376     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
377     if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
378         return validateHalVersion(context, HalVersion::V1_3);
379     } else {
380         return validateHalVersion(context, HalVersion::V1_2);
381     }
382 }
383 
prepare(IOperationExecutionContext * context)384 bool prepare(IOperationExecutionContext* context) {
385     bool useNchw = context->getInputValue<bool>(kLayoutScalar);
386     Shape input = context->getInputShape(kInputTensor);
387     Shape roiShape = context->getInputShape(kRoiTensor);
388     Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
389     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
390     NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);
391 
392     uint32_t numBatches = getSizeOfDimension(input, 0);
393     uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
394     uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
395     uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
396     uint32_t numRois = getSizeOfDimension(roiShape, 0);
397     // Every dimension must be positive except for numRois.
398     NN_RET_CHECK_GT(numBatches, 0);
399     NN_RET_CHECK_GT(inHeight, 0);
400     NN_RET_CHECK_GT(inWidth, 0);
401     NN_RET_CHECK_GT(inDepth, 0);
402     NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
403     NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);
404 
405     int32_t outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
406     int32_t outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
407     int32_t heightSamplingRatio = context->getInputValue<int32_t>(kHeightSamplingRatioScalar);
408     int32_t widthSamplingRatio = context->getInputValue<int32_t>(kWidthSamplingRatioScalar);
409     float heightScale, widthScale;
410     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
411         heightScale = context->getInputValue<_Float16>(kHeightStrideSalar);
412         widthScale = context->getInputValue<_Float16>(kWidthStrideScalar);
413     } else {
414         heightScale = context->getInputValue<float>(kHeightStrideSalar);
415         widthScale = context->getInputValue<float>(kWidthStrideScalar);
416     }
417     NN_RET_CHECK_GT(outputHeight, 0);
418     NN_RET_CHECK_GT(outputWidth, 0);
419     NN_RET_CHECK_GT(heightScale, 0);
420     NN_RET_CHECK_GT(widthScale, 0);
421     // Sampling ratio can set to 0 for adaptive value.
422     NN_RET_CHECK_GE(heightSamplingRatio, 0);
423     NN_RET_CHECK_GE(widthSamplingRatio, 0);
424 
425     if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
426         NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
427         NN_RET_CHECK_EQ(roiShape.offset, 0);
428     }
429 
430     Shape output = context->getOutputShape(kOutputTensor);
431     output.type = input.type;
432     if (useNchw) {
433         output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
434                              static_cast<uint32_t>(outputWidth)};
435     } else {
436         output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
437                              static_cast<uint32_t>(outputWidth), inDepth};
438     }
439     return context->setOutputShape(kOutputTensor, output);
440 }
441 
execute(IOperationExecutionContext * context)442 bool execute(IOperationExecutionContext* context) {
443     // Bypass execution in the case of zero-sized input.
444     if (getNumberOfElements(context->getInputShape(kRoiTensor)) == 0) return true;
445     switch (context->getInputType(kInputTensor)) {
446         case OperandType::TENSOR_FLOAT16:
447             return roiAlign(context->getInputBuffer<_Float16>(kInputTensor),
448                             context->getInputShape(kInputTensor),
449                             context->getInputBuffer<_Float16>(kRoiTensor),
450                             context->getInputShape(kRoiTensor),
451                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
452                             context->getInputShape(kBatchSplitTensor),
453                             context->getInputValue<_Float16>(kHeightStrideSalar),
454                             context->getInputValue<_Float16>(kWidthStrideScalar),
455                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
456                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
457                             context->getInputValue<bool>(kLayoutScalar),
458                             context->getOutputBuffer<_Float16>(kOutputTensor),
459                             context->getOutputShape(kOutputTensor));
460         case OperandType::TENSOR_FLOAT32:
461             return roiAlign(context->getInputBuffer<float>(kInputTensor),
462                             context->getInputShape(kInputTensor),
463                             context->getInputBuffer<float>(kRoiTensor),
464                             context->getInputShape(kRoiTensor),
465                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
466                             context->getInputShape(kBatchSplitTensor),
467                             context->getInputValue<float>(kHeightStrideSalar),
468                             context->getInputValue<float>(kWidthStrideScalar),
469                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
470                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
471                             context->getInputValue<bool>(kLayoutScalar),
472                             context->getOutputBuffer<float>(kOutputTensor),
473                             context->getOutputShape(kOutputTensor));
474         case OperandType::TENSOR_QUANT8_ASYMM:
475             return roiAlign(context->getInputBuffer<uint8_t>(kInputTensor),
476                             context->getInputShape(kInputTensor),
477                             context->getInputBuffer<uint16_t>(kRoiTensor),
478                             context->getInputShape(kRoiTensor),
479                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
480                             context->getInputShape(kBatchSplitTensor),
481                             context->getInputValue<float>(kHeightStrideSalar),
482                             context->getInputValue<float>(kWidthStrideScalar),
483                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
484                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
485                             context->getInputValue<bool>(kLayoutScalar),
486                             context->getOutputBuffer<uint8_t>(kOutputTensor),
487                             context->getOutputShape(kOutputTensor));
488         case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
489             return roiAlign(context->getInputBuffer<int8_t>(kInputTensor),
490                             context->getInputShape(kInputTensor),
491                             context->getInputBuffer<uint16_t>(kRoiTensor),
492                             context->getInputShape(kRoiTensor),
493                             context->getInputBuffer<int32_t>(kBatchSplitTensor),
494                             context->getInputShape(kBatchSplitTensor),
495                             context->getInputValue<float>(kHeightStrideSalar),
496                             context->getInputValue<float>(kWidthStrideScalar),
497                             context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
498                             context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
499                             context->getInputValue<bool>(kLayoutScalar),
500                             context->getOutputBuffer<int8_t>(kOutputTensor),
501                             context->getOutputShape(kOutputTensor));
502         default:
503             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
504     }
505 }
506 
507 }  // namespace roi_align
508 
509 NN_REGISTER_OPERATION(ROI_ALIGN, roi_align::kOperationName, roi_align::validate, roi_align::prepare,
510                       roi_align::execute, .allowZeroSizedInput = true);
511 
512 }  // namespace nn
513 }  // namespace android
514