1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/infer/pooling_infer.h"
18 #include <math.h>
19 #include "nnacl/infer/infer_register.h"
20
/*
 * Fills param->pad_u_/pad_d_/pad_l_/pad_r_ with the padding required so that
 * output_h x output_w windows of size window_h_ x window_w_, placed every
 * stride_h_/stride_w_ elements, cover an input of input_h x input_w.
 *
 * A negative requirement (the windows already fit) becomes zero padding. An odd
 * total is split with the extra row/column going to the bottom/right side.
 *
 * Returns NNACL_NULL_PTR when param is NULL, otherwise NNACL_OK.
 */
int ComputePadList(PoolingParameter *param, int input_h, int input_w, int output_h, int output_w) {
  if (param == NULL) {
    return NNACL_NULL_PTR;
  }

  /* Extent spanned by all windows minus the real input extent. */
  int needed_h = (output_h - 1) * param->stride_h_ + param->window_h_ - input_h;
  int needed_w = (output_w - 1) * param->stride_w_ + param->window_w_ - input_w;
  int total_h = needed_h > 0 ? needed_h : 0;
  int total_w = needed_w > 0 ? needed_w : 0;

  param->pad_u_ = total_h / 2;
  param->pad_d_ = total_h - param->pad_u_;
  param->pad_l_ = total_w / 2;
  param->pad_r_ = total_w - param->pad_l_;
  return NNACL_OK;
}
41
PoolingInferShape(const TensorC * const * inputs,size_t inputs_size,TensorC ** outputs,size_t outputs_size,OpParameter * parameter)42 int PoolingInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
43 OpParameter *parameter) {
44 int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 1, 1);
45 if (check_ret != NNACL_OK) {
46 return check_ret;
47 }
48
49 const TensorC *input = inputs[0];
50 NNACL_CHECK_TRUE_RET(input->format_ == Format_NHWC, NNACL_FORMAT_ERROR);
51 for (size_t i = 0; i < outputs_size; i++) {
52 TensorC *output = outputs[i];
53 SetDataTypeFormat(output, input);
54 }
55 PoolingParameter *param = (PoolingParameter *)parameter;
56 if (!InferFlag(inputs, inputs_size)) {
57 return NNACL_INFER_INVALID;
58 }
59 if (input->shape_size_ < 3 || input->shape_size_ > MAX_SHAPE_SIZE) {
60 return NNACL_INPUT_TENSOR_ERROR;
61 }
62 int input_h = input->shape_[1];
63 int input_w = input->shape_[2];
64
65 int window_h = param->window_h_;
66 int window_w = param->window_w_;
67 if (param->global_) {
68 param->window_h_ = window_h = input_h;
69 param->window_w_ = window_w = input_w;
70 }
71 int output_h = 0;
72 int output_w = 0;
73 if ((param->stride_h_ == 0 || param->stride_w_ == 0) && !param->global_) {
74 return NNACL_PARAM_INVALID;
75 }
76 if (param->pad_mode_ == Pad_same) {
77 output_w = ceil((float)(input_w) / (float)(param->stride_w_));
78 output_h = ceil((float)(input_h) / (float)(param->stride_h_));
79 if (ComputePadList(param, input_h, input_w, output_h, output_w) != NNACL_OK) {
80 return NNACL_NULL_PTR;
81 }
82 } else {
83 int round_mode = (RoundType)param->round_type_;
84 if (round_mode == RoundType_Floor) {
85 output_h = floor((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1;
86 output_w = floor((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1;
87 } else if (round_mode == RoundType_Ceil) {
88 output_h = ceil((float)(input_h + param->pad_u_ + param->pad_d_ - window_h) / param->stride_h_) + 1;
89 output_w = ceil((float)(input_w + param->pad_l_ + param->pad_r_ - window_w) / param->stride_w_) + 1;
90 } else {
91 return NNACL_ERR;
92 }
93 }
94 int input_shape[MAX_SHAPE_SIZE];
95 size_t input_shape_size = 0;
96 ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_);
97 input_shape[1] = output_h > 0 ? output_h : 1;
98 input_shape[2] = output_w > 0 ? output_w : 1;
99 for (size_t i = 0; i < outputs_size; i++) {
100 TensorC *output = outputs[i];
101 SetShapeArray(output, input_shape, input_shape_size);
102 }
103 return NNACL_OK;
104 }
105
106 REG_INFER(MaxPool, PrimType_MaxPoolFusion, PoolingInferShape)
107 REG_INFER(AvgPool, PrimType_AvgPoolFusion, PoolingInferShape)
108