/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

#include <tensorflow/lite/kernels/internal/common.h>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

namespace android {
namespace nn {
namespace roi_align {

constexpr char kOperationName[] = "ROI_ALIGN";

constexpr uint32_t kNumInputs = 10;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kRoiTensor = 1;
constexpr uint32_t kBatchSplitTensor = 2;
constexpr uint32_t kOutputHeightScalar = 3;
constexpr uint32_t kOutputWidthScalar = 4;
constexpr uint32_t kHeightStrideScalar = 5;
constexpr uint32_t kWidthStrideScalar = 6;
constexpr uint32_t kHeightSamplingRatioScalar = 7;
constexpr uint32_t kWidthSamplingRatioScalar = 8;
constexpr uint32_t kLayoutScalar = 9;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
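
// ROI_ALIGN takes a 4-D feature map, a set of regions of interest given in the coordinate
// space of the original image, and a batch index per ROI, and produces a fixed-size crop of
// outputHeight x outputWidth x depth per ROI via bilinear sampling. The height/width stride
// scalars relate image coordinates to feature-map coordinates, and the sampling-ratio scalars
// control how many samples are averaged per output cell (0 selects an adaptive value).
//
// Illustrative (hypothetical) sizes: a 1 x 64 x 64 x 256 NHWC feature map with
// heightStride = widthStride = 16 corresponds to a 1024 x 1024 image; with
// outputHeight = outputWidth = 7, the output shape is numRois x 7 x 7 x 256.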

namespace {

using namespace hal;

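// Reference NHWC implementation for floating-point types. Each ROI is mapped into
// feature-map coordinates (dividing by the strides), split into an outHeight x outWidth grid
// of bins, and each bin is filled with the average of hSamplingRatio x wSamplingRatio
// bilinearly interpolated samples.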
template <typename T_Input, typename T_Roi>
inline bool roiAlignNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                         const Shape& roiShape, const int32_t* batchSplitData,
                         const Shape& batchSplitShape, float heightStride, float widthStride,
                         int32_t heightSamplingRatio, int32_t widthSamplingRatio,
                         T_Input* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("RoiAlign");

    const uint32_t kRoiDim = 4;
    const T_Roi heightScale = 1.0f / heightStride;
    const T_Roi widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
        // Check for malformed data
        // 1. Invalid batch id
        // 2. Region out of bounds: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(roiInfo[0] >= 0);
        NN_RET_CHECK(roiInfo[1] >= 0);
        NN_RET_CHECK(roiInfo[2] >= 0);
        NN_RET_CHECK(roiInfo[3] >= 0);
        NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
        NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);

        T_Roi wRoiStart = roiInfo[0] * widthScale;
        T_Roi hRoiStart = roiInfo[1] * heightScale;
        T_Roi wRoiEnd = roiInfo[2] * widthScale;
        T_Roi hRoiEnd = roiInfo[3] * heightScale;

        T_Roi roiWidth = std::max(static_cast<float>(wRoiEnd - wRoiStart), 1.0f);
        T_Roi roiHeight = std::max(static_cast<float>(hRoiEnd - hRoiStart), 1.0f);
        T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
        T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);

        // if samplingRatio = 0, use an adaptive value of ceil(roiWidth/outWidth), same for height
        uint32_t wSamplingRatio = widthSamplingRatio > 0
                                          ? widthSamplingRatio
                                          : std::ceil(static_cast<float>(wStepSize));
        uint32_t hSamplingRatio = heightSamplingRatio > 0
                                          ? heightSamplingRatio
                                          : std::ceil(static_cast<float>(hStepSize));
        int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
        T_Roi wBinSize = wStepSize / static_cast<T_Roi>(wSamplingRatio);
        T_Roi hBinSize = hStepSize / static_cast<T_Roi>(hSamplingRatio);

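        // Each output cell (i, j) covers a bin of size hStepSize x wStepSize inside the ROI.
        // The bin is subdivided into hSamplingRatio x wSamplingRatio sub-bins, one sample is
        // taken at the center of each sub-bin, and the samples are summed in the loops below
        // and averaged at the end of the cell.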
        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                T_Roi wStart = wStepSize * j + wRoiStart;
                T_Roi wEnd = wStepSize * (j + 1) + wRoiStart;
                T_Roi hStart = hStepSize * i + hRoiStart;
                T_Roi hEnd = hStepSize * (i + 1) + hRoiStart;

                // initialize output to zero
                for (uint32_t k = 0; k < inDepth; k++) outPtr[k] = 0;

                // calculate the sum of the sampling points
                for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
                    for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
                        T_Roi y = hStart + hBinSize / 2 + hBinSize * yInd;
                        T_Roi x = wStart + wBinSize / 2 + wBinSize * xInd;

                        // bilinear interpolation of point (x,y)
                        // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
                        uint32_t x1 = std::floor(static_cast<float>(x));
                        uint32_t y1 = std::floor(static_cast<float>(y));
                        uint32_t x2 = x1 + 1, y2 = y1 + 1;
                        T_Roi dx1 = x - static_cast<T_Roi>(x1);
                        T_Roi dy1 = y - static_cast<T_Roi>(y1);

                        // handle out-of-bounds samples
                        if (x1 >= inWidth - 1) {
                            x1 = x2 = inWidth - 1;
                            dx1 = 0;
                        }
                        if (y1 >= inHeight - 1) {
                            y1 = y2 = inHeight - 1;
                            dy1 = 0;
                        }

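                        // The interpolated value at (x, y) is the weighted sum of the four
                        // neighboring pixels:
                        //   v = (1-dx1)(1-dy1)*I(y1,x1) + dx1*(1-dy1)*I(y1,x2)
                        //     + (1-dx1)*dy1*I(y2,x1)    + dx1*dy1*I(y2,x2)
                        // which is exactly the ws[] / offsets[] pairing below.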
                        T_Roi dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
                        T_Roi ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
                        uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
                                              y1 * inWidth * inDepth + x2 * inDepth,
                                              y2 * inWidth * inDepth + x1 * inDepth,
                                              y2 * inWidth * inDepth + x2 * inDepth};

                        for (uint32_t k = 0; k < inDepth; k++) {
                            T_Input interpolation = 0;
                            for (uint32_t c = 0; c < 4; c++) {
                                interpolation += ws[c] * batchBase[offsets[c] + k];
                            }
                            outPtr[k] += interpolation;
                        }
                    }
                }

                // take average
                for (uint32_t k = 0; k < inDepth; k++)
                    outPtr[k] /= static_cast<T_Input>(numSamplingPoints);
                outPtr += inDepth;
            }
        }
    }
    return true;
}

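// Quantized NHWC implementation. The ROI tensor is TENSOR_QUANT16_ASYMM with scale 0.125 and
// zero offset (enforced in prepare()), so ROI coordinates are stored in units of 1/8 pixel;
// the 0.125f factor below converts them back to pixels. The bilinear weights are quantized to
// steps of wScale = 1/255 so the accumulation can be done in integer arithmetic, and the
// result is rescaled to the output quantization with MultiplyByQuantizedMultiplier.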
template <typename T_Input>
inline bool roiAlignQuantNhwc(const T_Input* inputData, const Shape& inputShape,
                              const uint16_t* roiData, const Shape& roiShape,
                              const int32_t* batchSplitData, const Shape& batchSplitShape,
                              float heightStride, float widthStride, int32_t heightSamplingRatio,
                              int32_t widthSamplingRatio, T_Input* outputData,
                              const Shape& outputShape) {
    NNTRACE_TRANS("RoiAlignQuant8");

    constexpr float wScale = 1.0f / 255.0f;
    constexpr uint32_t kRoiDim = 4;
    const float heightScale = 1.0f / heightStride;
    const float widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const uint16_t* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const uint16_t* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
        float wRoiStart = static_cast<float>(roiInfo[0]) * widthScale * 0.125f;
        float hRoiStart = static_cast<float>(roiInfo[1]) * heightScale * 0.125f;
        float wRoiEnd = static_cast<float>(roiInfo[2]) * widthScale * 0.125f;
        float hRoiEnd = static_cast<float>(roiInfo[3]) * heightScale * 0.125f;

        // Check for malformed data
        // 1. Invalid batch id
        // 2. Region out of bounds: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(wRoiStart <= inWidth);
        NN_RET_CHECK(hRoiStart <= inHeight);
        NN_RET_CHECK(wRoiEnd <= inWidth);
        NN_RET_CHECK(hRoiEnd <= inHeight);
        NN_RET_CHECK_LE(wRoiStart, wRoiEnd);
        NN_RET_CHECK_LE(hRoiStart, hRoiEnd);

        float roiWidth = std::max(wRoiEnd - wRoiStart, 1.0f);
        float roiHeight = std::max(hRoiEnd - hRoiStart, 1.0f);
        float wStepSize = roiWidth / static_cast<float>(outWidth);
        float hStepSize = roiHeight / static_cast<float>(outHeight);

        // if samplingRatio = 0, use an adaptive value of ceil(roiWidth/outWidth), same for height
        uint32_t wSamplingRatio =
                widthSamplingRatio > 0 ? widthSamplingRatio : std::ceil(wStepSize);
        uint32_t hSamplingRatio =
                heightSamplingRatio > 0 ? heightSamplingRatio : std::ceil(hStepSize);
        int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
        float wBinSize = wStepSize / static_cast<float>(wSamplingRatio);
        float hBinSize = hStepSize / static_cast<float>(hSamplingRatio);

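        // The integer accumulation below computes
        //   sum_over_samples( round(ws/wScale) * (input - inputOffset) ),
        // so mapping the sum back to real values and then to the output quantization requires
        // multiplying by inputScale * wScale / outputScale and dividing by the number of
        // sampling points; that combined factor is folded into one fixed-point multiplier here.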
        float realMultiplier = inputShape.scale * wScale / outputShape.scale / numSamplingPoints;
        int32_t outputMultiplier = 0;
        int32_t outputShift = 0;
        if (!QuantizeMultiplierSmallerThanOne(realMultiplier, &outputMultiplier, &outputShift)) {
            return false;
        }

        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                float wStart = wStepSize * j + wRoiStart;
                float wEnd = wStepSize * (j + 1) + wRoiStart;
                float hStart = hStepSize * i + hRoiStart;
                float hEnd = hStepSize * (i + 1) + hRoiStart;

                std::vector<int32_t> outTemp(inDepth, 0);
                // calculate the sum of the sampling points
                for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
                    for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
                        float y = hStart + hBinSize / 2 + hBinSize * yInd;
                        float x = wStart + wBinSize / 2 + wBinSize * xInd;

                        // bilinear interpolation of point (x,y)
                        // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
                        uint32_t x1 = std::floor(x), y1 = std::floor(y);
                        uint32_t x2 = x1 + 1, y2 = y1 + 1;
                        float dx1 = x - static_cast<float>(x1);
                        float dy1 = y - static_cast<float>(y1);

                        // handle out-of-bounds samples
                        if (x1 >= inWidth - 1) {
                            x1 = x2 = inWidth - 1;
                            dx1 = 0;
                        }
                        if (y1 >= inHeight - 1) {
                            y1 = y2 = inHeight - 1;
                            dy1 = 0;
                        }

                        float dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
                        float ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
                        uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
                                              y1 * inWidth * inDepth + x2 * inDepth,
                                              y2 * inWidth * inDepth + x1 * inDepth,
                                              y2 * inWidth * inDepth + x2 * inDepth};

                        for (uint32_t k = 0; k < inDepth; k++) {
                            int32_t interpolation = 0;
                            for (uint32_t c = 0; c < 4; c++) {
                                int32_t wQuant = static_cast<int32_t>(std::round(ws[c] / wScale));
                                interpolation +=
                                        wQuant * (static_cast<int32_t>(batchBase[offsets[c] + k]) -
                                                  inputShape.offset);
                            }
                            outTemp[k] += interpolation;
                        }
                    }
                }

                // take average and cast to output quantization
                for (uint32_t k = 0; k < inDepth; k++) {
                    int32_t raw_out = tflite::MultiplyByQuantizedMultiplier(
                                              outTemp[k], outputMultiplier, -outputShift) +
                                      outputShape.offset;
                    outPtr[k] = saturateCast<T_Input>(raw_out);
                }
                outPtr += inDepth;
            }
        }
    }
    return true;
}

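// Converts the input to NHWC if necessary, dispatches to the quantized or floating-point
// kernel based on the template types, and converts the result back to the requested layout.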
template <typename T_Input, typename T_Roi>
inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                     const Shape& roiShape, const int32_t* batchSplitData,
                     const Shape& batchSplitShape, float heightStride, float widthStride,
                     int32_t heightSamplingRatio, int32_t widthSamplingRatio, bool useNchw,
                     T_Input* outputData, const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    if constexpr (std::is_same_v<T_Roi, uint16_t> &&
                  (std::is_same_v<T_Input, uint8_t> || std::is_same_v<T_Input, int8_t>)) {
        NN_RET_CHECK(roiAlignQuantNhwc<T_Input>(
                input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape, batchSplitData,
                batchSplitShape, heightStride, widthStride, heightSamplingRatio, widthSamplingRatio,
                output.getNhwcBuffer(), output.getNhwcShape()));
    } else {
        NN_RET_CHECK(roiAlignNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
                                  batchSplitData, batchSplitShape, heightStride, widthStride,
                                  heightSamplingRatio, widthSamplingRatio, output.getNhwcBuffer(),
                                  output.getNhwcShape()));
    }
    NN_RET_CHECK(output.commit());
    return true;
}

} // namespace

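// Validates operand counts and types. The ROI tensor must be TENSOR_FLOAT32 or TENSOR_FLOAT16
// to match a float input, or TENSOR_QUANT16_ASYMM when the input is TENSOR_QUANT8_ASYMM or
// TENSOR_QUANT8_ASYMM_SIGNED; the signed-quantized variant requires HAL version 1.3.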
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    std::vector<OperandType> inExpectedTypes;
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_INT32, OperandType::INT32,
                           OperandType::INT32, OperandType::FLOAT32,
                           OperandType::FLOAT32, OperandType::INT32,
                           OperandType::INT32, OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_INT32, OperandType::INT32,
                           OperandType::INT32, OperandType::FLOAT16,
                           OperandType::FLOAT16, OperandType::INT32,
                           OperandType::INT32, OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
               inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        inExpectedTypes = {inputType,
                           OperandType::TENSOR_QUANT16_ASYMM,
                           OperandType::TENSOR_INT32,
                           OperandType::INT32,
                           OperandType::INT32,
                           OperandType::FLOAT32,
                           OperandType::FLOAT32,
                           OperandType::INT32,
                           OperandType::INT32,
                           OperandType::BOOL};
    } else {
        LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        return validateHalVersion(context, HalVersion::V1_3);
    } else {
        return validateHalVersion(context, HalVersion::V1_2);
    }
}


bool prepare(IOperationExecutionContext* context) {
    bool useNchw = context->getInputValue<bool>(kLayoutScalar);
    Shape input = context->getInputShape(kInputTensor);
    Shape roiShape = context->getInputShape(kRoiTensor);
    Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);

    uint32_t numBatches = getSizeOfDimension(input, 0);
    uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
    uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
    uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    // Every dimension must be positive except for numRois.
    NN_RET_CHECK_GT(numBatches, 0);
    NN_RET_CHECK_GT(inHeight, 0);
    NN_RET_CHECK_GT(inWidth, 0);
    NN_RET_CHECK_GT(inDepth, 0);
    NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
    NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);

    int32_t outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
    int32_t outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
    int32_t heightSamplingRatio = context->getInputValue<int32_t>(kHeightSamplingRatioScalar);
    int32_t widthSamplingRatio = context->getInputValue<int32_t>(kWidthSamplingRatioScalar);
    float heightScale, widthScale;
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        heightScale = context->getInputValue<_Float16>(kHeightStrideScalar);
        widthScale = context->getInputValue<_Float16>(kWidthStrideScalar);
    } else {
        heightScale = context->getInputValue<float>(kHeightStrideScalar);
        widthScale = context->getInputValue<float>(kWidthStrideScalar);
    }
    NN_RET_CHECK_GT(outputHeight, 0);
    NN_RET_CHECK_GT(outputWidth, 0);
    NN_RET_CHECK_GT(heightScale, 0);
    NN_RET_CHECK_GT(widthScale, 0);
    // The sampling ratio can be set to 0 to select the adaptive value.
    NN_RET_CHECK_GE(heightSamplingRatio, 0);
    NN_RET_CHECK_GE(widthSamplingRatio, 0);

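    // A quantized ROI tensor must use scale 1/8 and zero offset, matching the fixed 0.125f
    // factor applied in roiAlignQuantNhwc.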
    if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
        NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
        NN_RET_CHECK_EQ(roiShape.offset, 0);
    }

    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    if (useNchw) {
        output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth)};
    } else {
        output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth), inDepth};
    }
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getInputShape(kRoiTensor)) == 0) return true;
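    // Dispatch on the input tensor type; for both quantized input types the ROI buffer is read
    // as uint16_t, since the ROI operand is TENSOR_QUANT16_ASYMM.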
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return roiAlign(context->getInputBuffer<_Float16>(kInputTensor),
                            context->getInputShape(kInputTensor),
                            context->getInputBuffer<_Float16>(kRoiTensor),
                            context->getInputShape(kRoiTensor),
                            context->getInputBuffer<int32_t>(kBatchSplitTensor),
                            context->getInputShape(kBatchSplitTensor),
                            context->getInputValue<_Float16>(kHeightStrideScalar),
                            context->getInputValue<_Float16>(kWidthStrideScalar),
                            context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
                            context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
                            context->getInputValue<bool>(kLayoutScalar),
                            context->getOutputBuffer<_Float16>(kOutputTensor),
                            context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return roiAlign(context->getInputBuffer<float>(kInputTensor),
                            context->getInputShape(kInputTensor),
                            context->getInputBuffer<float>(kRoiTensor),
                            context->getInputShape(kRoiTensor),
                            context->getInputBuffer<int32_t>(kBatchSplitTensor),
                            context->getInputShape(kBatchSplitTensor),
                            context->getInputValue<float>(kHeightStrideScalar),
                            context->getInputValue<float>(kWidthStrideScalar),
                            context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
                            context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
                            context->getInputValue<bool>(kLayoutScalar),
                            context->getOutputBuffer<float>(kOutputTensor),
                            context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return roiAlign(context->getInputBuffer<uint8_t>(kInputTensor),
                            context->getInputShape(kInputTensor),
                            context->getInputBuffer<uint16_t>(kRoiTensor),
                            context->getInputShape(kRoiTensor),
                            context->getInputBuffer<int32_t>(kBatchSplitTensor),
                            context->getInputShape(kBatchSplitTensor),
                            context->getInputValue<float>(kHeightStrideScalar),
                            context->getInputValue<float>(kWidthStrideScalar),
                            context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
                            context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
                            context->getInputValue<bool>(kLayoutScalar),
                            context->getOutputBuffer<uint8_t>(kOutputTensor),
                            context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return roiAlign(context->getInputBuffer<int8_t>(kInputTensor),
                            context->getInputShape(kInputTensor),
                            context->getInputBuffer<uint16_t>(kRoiTensor),
                            context->getInputShape(kRoiTensor),
                            context->getInputBuffer<int32_t>(kBatchSplitTensor),
                            context->getInputShape(kBatchSplitTensor),
                            context->getInputValue<float>(kHeightStrideScalar),
                            context->getInputValue<float>(kWidthStrideScalar),
                            context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
                            context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
                            context->getInputValue<bool>(kLayoutScalar),
                            context->getOutputBuffer<int8_t>(kOutputTensor),
                            context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

} // namespace roi_align

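// allowZeroSizedInput lets the framework invoke this operation with numRois == 0; execute()
// handles that case by returning early without touching the output buffer.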
NN_REGISTER_OPERATION(ROI_ALIGN, roi_align::kOperationName, roi_align::validate, roi_align::prepare,
                      roi_align::execute, .allowZeroSizedInput = true);

} // namespace nn
} // namespace android