/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <vector>

#include "OperationResolver.h"
#include "Tracing.h"
#include "nnapi/Validation.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/pooling.h>

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace pooling {

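// Operand indices. Only the fixed positions are named here; the remaining
// inputs vary with the signature (7, 8, 10, or 11 inputs) and are decoded in
// PoolingParam::initialize() below.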
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

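// Pooling hyperparameters (explicit padding, strides, filter size, fused
// activation, and data layout) decoded from the operation's inputs.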
struct PoolingParam {
    int32_t padding_left, padding_right;
    int32_t padding_top, padding_bottom;
    int32_t stride_width, stride_height;
    int32_t filter_width, filter_height;
    int32_t activation;
    bool useNchw = false;

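    // Decodes the operation inputs. Two signatures are supported: explicit
    // padding (10 or 11 inputs: the four padding amounts, strides, filter
    // size, activation, and an optional NCHW layout flag) and implicit
    // padding (7 or 8 inputs: a padding scheme that is expanded to explicit
    // amounts with calculateExplicitPadding).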
    bool initialize(const IOperationExecutionContext* context) {
        uint32_t inCount = context->getNumInputs();
        int32_t padding_implicit = 0;
        if (inCount >= 10) {
            padding_left = context->getInputValue<int32_t>(1);
            padding_right = context->getInputValue<int32_t>(2);
            padding_top = context->getInputValue<int32_t>(3);
            padding_bottom = context->getInputValue<int32_t>(4);
            stride_width = context->getInputValue<int32_t>(5);
            stride_height = context->getInputValue<int32_t>(6);
            filter_width = context->getInputValue<int32_t>(7);
            filter_height = context->getInputValue<int32_t>(8);
            activation = context->getInputValue<int32_t>(9);
            if (inCount == 11) {
                useNchw = context->getInputValue<bool>(10);
            }
        } else {
            padding_implicit = context->getInputValue<int32_t>(1);
            stride_width = context->getInputValue<int32_t>(2);
            stride_height = context->getInputValue<int32_t>(3);
            filter_width = context->getInputValue<int32_t>(4);
            filter_height = context->getInputValue<int32_t>(5);
            activation = context->getInputValue<int32_t>(6);
            if (inCount == 8) {
                useNchw = context->getInputValue<bool>(7);
            }
        }
        if (inCount <= 8) {
            Shape inputShape = context->getInputShape(kInputTensor);
            int32_t input_height = getSizeOfDimension(inputShape, useNchw ? 2 : 1);
            int32_t input_width = getSizeOfDimension(inputShape, useNchw ? 3 : 2);
            calculateExplicitPadding(input_width, stride_width, filter_width, padding_implicit,
                                     &padding_left, &padding_right);
            calculateExplicitPadding(input_height, stride_height, filter_height, padding_implicit,
                                     &padding_top, &padding_bottom);
        }
        NN_RET_CHECK_GE(padding_left, 0);
        NN_RET_CHECK_GE(padding_right, 0);
        NN_RET_CHECK_GE(padding_top, 0);
        NN_RET_CHECK_GE(padding_bottom, 0);
        NN_RET_CHECK_GT(stride_width, 0);
        NN_RET_CHECK_GT(stride_height, 0);
        NN_RET_CHECK_GT(filter_width, 0);
        NN_RET_CHECK_GT(filter_height, 0);
        NN_RET_CHECK_GE(activation, 0);
        NN_RET_CHECK_GT(filter_width, padding_left);
        NN_RET_CHECK_GT(filter_width, padding_right);
        NN_RET_CHECK_GT(filter_height, padding_top);
        NN_RET_CHECK_GT(filter_height, padding_bottom);
        return true;
    }

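    // Converts to TFLite's PoolParams, computing the output activation clamp
    // range from the fused activation and the output type (quantized or
    // float).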
    tflite::PoolParams toTfliteParam(const Shape& output) const {
        tflite::PoolParams params = {
                .padding_values = {.width = static_cast<int16_t>(padding_left),
                                   .height = static_cast<int16_t>(padding_top),
                                   .width_offset = 0,
                                   .height_offset = 0},
                .stride_height = stride_height,
                .stride_width = stride_width,
                .filter_height = filter_height,
                .filter_width = filter_width,
        };
        if (output.type == OperandType::TENSOR_QUANT8_ASYMM) {
            int32_t output_activation_min = 0;
            int32_t output_activation_max = 0;
            CalculateActivationRangeUint8(activation, output, &output_activation_min,
                                          &output_activation_max);
            params.quantized_activation_min = output_activation_min;
            params.quantized_activation_max = output_activation_max;
        } else if (output.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
            int32_t output_activation_min = 0;
            int32_t output_activation_max = 0;
            CalculateActivationRangeInt8(activation, output, &output_activation_min,
                                         &output_activation_max);
            params.quantized_activation_min = output_activation_min;
            params.quantized_activation_max = output_activation_max;
        } else {
            float output_activation_min, output_activation_max;
            CalculateActivationRangeFloat(activation, &output_activation_min,
                                          &output_activation_max);
            params.float_activation_min = output_activation_min;
            params.float_activation_max = output_activation_max;
        }
        return params;
    }
};

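// NHWC kernels. Each overload forwards to the matching TFLite kernel; the
// _Float16 overloads compute in float and convert the result back to fp16.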
bool averagePoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
                     float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("averagePoolFloat32");
    auto op_params = param.toTfliteParam(outputShape);
    NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
    tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
                                       convertShapeToTflshape(outputShape), outputData);
    return true;
}

bool averagePoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
                     _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("averagePoolFloat16");
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));

    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    averagePoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(),
                    outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);
    return true;
}

bool averagePoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
                     uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("averagePoolQuant8");
    auto op_params = param.toTfliteParam(outputShape);
    NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
    tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
                                       convertShapeToTflshape(outputShape), outputData);
    return true;
}

bool averagePoolNhwc(const int8_t* inputData, const Shape& inputShape, const PoolingParam& param,
                     int8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("averagePoolQuant8Signed");
    auto op_params = param.toTfliteParam(outputShape);
178 NNTRACE_COMP_SWITCH("optimized_integer_ops::AveragePool");
    // We are using the reference implementation of the AveragePool op because
    // the optimized version fails to pass some of the quantization coupling
    // tests.
    tflite::reference_integer_ops::AveragePool(op_params, convertShapeToTflshape(inputShape),
                                               inputData, convertShapeToTflshape(outputShape),
                                               outputData);
    return true;
}

bool l2PoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
                float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("l2PoolFloat32");
    auto op_params = param.toTfliteParam(outputShape);
    NNTRACE_COMP_SWITCH("optimized_ops::L2Pool");
    tflite::optimized_ops::L2Pool(op_params, convertShapeToTflshape(inputShape), inputData,
                                  convertShapeToTflshape(outputShape), outputData);
    return true;
}

bool l2PoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
                _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("l2PoolFloat16");
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));

    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    l2PoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(), outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);
    return true;
}

bool maxPoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
                 float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("maxPoolFloat32");
    auto op_params = param.toTfliteParam(outputShape);
    NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
    tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
                                   convertShapeToTflshape(outputShape), outputData);
    return true;
}

bool maxPoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
                 uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("maxPoolQuant8");
    auto op_params = param.toTfliteParam(outputShape);
    NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
    tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
                                   convertShapeToTflshape(outputShape), outputData);
    return true;
}

bool maxPoolNhwc(const int8_t* inputData, const Shape& inputShape, const PoolingParam& param,
                 int8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("maxPoolQuant8Signed");
    auto op_params = param.toTfliteParam(outputShape);
233 NNTRACE_COMP_SWITCH("optimized_integer_ops::MaxPool");
    // We are using the reference implementation of the MaxPool op because the
    // optimized version fails to pass some of the quantization coupling tests.
    tflite::reference_integer_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
                                           convertShapeToTflshape(outputShape), outputData);
    return true;
}

bool maxPoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
                 _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("maxPoolFloat16");
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));

    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    maxPoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(),
                outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);
    return true;
}

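// Layout-aware entry points. InputWithLayout/OutputWithLayout (from
// CpuOperationUtils.h) present the data to the NHWC kernels above in NHWC
// order, converting from NCHW when needed; commit() writes the result back
// in the caller's layout.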
template <typename T>
bool averagePool(const T* inputData, const Shape& inputShape, const PoolingParam& param,
                 T* outputData, const Shape& outputShape) {
    InputWithLayout<T> input(param.useNchw);
    OutputWithLayout<T> output(param.useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(averagePoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
                                 output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

template <typename T>
bool l2Pool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
            const Shape& outputShape) {
    InputWithLayout<T> input(param.useNchw);
    OutputWithLayout<T> output(param.useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(l2PoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
                            output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

template <typename T>
bool maxPool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
             const Shape& outputShape) {
    InputWithLayout<T> input(param.useNchw);
    OutputWithLayout<T> output(param.useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(maxPoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
                             output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

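// Validates the operand types for the given pooling operation and determines
// the minimum NNAPI version that supports this signature. Note that
// L2_POOL_2D has no quantized variant.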
Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputCount = context->getNumInputs();
    NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    auto minSupportedVersion = Version::ANDROID_OC_MR1;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        minSupportedVersion = Version::ANDROID_OC_MR1;
        inExpectedTypes = {
                inputType,          OperandType::INT32, OperandType::INT32, OperandType::INT32,
                OperandType::INT32, OperandType::INT32, OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        minSupportedVersion = Version::ANDROID_Q;
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32,
                OperandType::INT32,          OperandType::INT32, OperandType::INT32,
                OperandType::INT32,
        };
    } else if (opType != OperationType::L2_POOL_2D &&
               inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        minSupportedVersion = Version::ANDROID_OC_MR1;
        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
        };
    } else if (opType != OperationType::L2_POOL_2D &&
               inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minSupportedVersion = Version::ANDROID_R;
        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << opType;
    }

    if (inputCount >= 10) {
        std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
        inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
                               explicitScalarTypes.end());
    }
    if (inputCount == 11 || inputCount == 8) {
        inExpectedTypes.push_back(OperandType::BOOL);
        minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
    } else {
        minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return minSupportedVersion;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
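// Computes and sets the output shape from the input shape and the decoded
// pooling parameters.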
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);

    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));

    // Only batches can be zero.
    uint32_t batches = getSizeOfDimension(input, 0);
    uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
    uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
    uint32_t channels = getSizeOfDimension(input, param.useNchw ? 1 : 3);
    NN_RET_CHECK_GT(height, 0);
    NN_RET_CHECK_GT(width, 0);
    NN_RET_CHECK_GT(channels, 0);

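    // computeOutSize applies the usual pooling output arithmetic:
    //   out = (in - filter + padHead + padTail) / stride + 1
    // e.g. in = 4, filter = 2, stride = 2, no padding -> out = 2.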
    uint32_t outWidth = computeOutSize(width, param.filter_width, param.stride_width,
                                       param.padding_left, param.padding_right);
    uint32_t outHeight = computeOutSize(height, param.filter_height, param.stride_height,
                                        param.padding_top, param.padding_bottom);

    Shape output = input;
    if (param.useNchw) {
        output.dimensions = {batches, channels, outHeight, outWidth};
    } else {
        output.dimensions = {batches, outHeight, outWidth, channels};
    }
    return context->setOutputShape(kOutputTensor, output);
}

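// Expands to a switch case that dispatches to the type-specialized pooling
// function for the given operand type.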
#define POOLING_DISPATCH_INPUT_TYPE(name, type, cppType)                  \
    case OperandType::type:                                               \
        return name(context->getInputBuffer<cppType>(kInputTensor),      \
                    context->getInputShape(kInputTensor), param,          \
                    context->getOutputBuffer<cppType>(kOutputTensor),     \
                    context->getOutputShape(kOutputTensor))

bool executeAveragePool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_QUANT8_ASYMM, uint8_t);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_QUANT8_ASYMM_SIGNED, int8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation AVERAGE_POOL_2D";
    }
}

bool executeL2Pool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT16, _Float16);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation L2_POOL_2D";
    }
}

bool executeMaxPool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_QUANT8_ASYMM, uint8_t);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_QUANT8_ASYMM_SIGNED, int8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MAX_POOL_2D";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

#undef POOLING_DISPATCH_INPUT_TYPE

}  // namespace pooling

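// Each operation registers the shared prepare/execute entry points, with
// validate bound to its OperationType. allowZeroSizedInput permits a
// zero-sized (batch == 0) output, which the execute functions short-circuit
// above.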
using std::placeholders::_1;
NN_REGISTER_OPERATION(AVERAGE_POOL_2D, "AVERAGE_POOL_2D",
                      std::bind(pooling::validate, OperationType::AVERAGE_POOL_2D, _1),
                      pooling::prepare, pooling::executeAveragePool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(L2_POOL_2D, "L2_POOL_2D",
                      std::bind(pooling::validate, OperationType::L2_POOL_2D, _1), pooling::prepare,
                      pooling::executeL2Pool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MAX_POOL_2D, "MAX_POOL_2D",
                      std::bind(pooling::validate, OperationType::MAX_POOL_2D, _1),
                      pooling::prepare, pooling::executeMaxPool, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android