1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/common/ops/populate/populate_register.h"
17 #include "nnacl/pooling_parameter.h"
18 using mindspore::schema::PrimitiveType_AvgPoolFusion;
19 using mindspore::schema::PrimitiveType_MaxPoolFusion;
20
21 namespace mindspore {
22 namespace lite {
23 namespace {
CheckPoolingParam(const PoolingParameter * param)24 int CheckPoolingParam(const PoolingParameter *param) {
25 const int max_pooling_pad = 50;
26 if (param->pad_u_ > max_pooling_pad || param->pad_d_ > max_pooling_pad || param->pad_l_ > max_pooling_pad ||
27 param->pad_r_ > max_pooling_pad) {
28 return RET_ERROR;
29 }
30 return RET_OK;
31 }
32
UpdateRoundMode(enum schema::RoundMode round_mode,PoolingParameter * param)33 void UpdateRoundMode(enum schema::RoundMode round_mode, PoolingParameter *param) {
34 switch (round_mode) {
35 case schema::RoundMode_FLOOR:
36 param->round_type_ = RoundType_Floor;
37 break;
38 case schema::RoundMode_CEIL:
39 param->round_type_ = RoundType_Ceil;
40 break;
41 default:
42 param->round_type_ = RoundType_No;
43 break;
44 }
45 }
46
UpdateActivationType(enum schema::ActivationType type,PoolingParameter * param)47 void UpdateActivationType(enum schema::ActivationType type, PoolingParameter *param) {
48 if (type == schema::ActivationType_RELU) {
49 param->act_type_ = ActType_Relu;
50 } else if (type == schema::ActivationType_RELU6) {
51 param->act_type_ = ActType_Relu6;
52 } else {
53 param->act_type_ = ActType_No;
54 }
55 }
56
UpdatePadMode(enum schema::PadMode pad_mode,PoolingParameter * param)57 void UpdatePadMode(enum schema::PadMode pad_mode, PoolingParameter *param) {
58 switch (pad_mode) {
59 case schema::PadMode_SAME:
60 param->pad_mode_ = Pad_same;
61 break;
62 case schema::PadMode_VALID:
63 param->pad_mode_ = Pad_valid;
64 break;
65 default:
66 param->pad_mode_ = Pad_pad;
67 break;
68 }
69 }
70 } // namespace
PopulateAvgPoolParameter(const void * primitive)71 OpParameter *PopulateAvgPoolParameter(const void *primitive) {
72 MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
73 auto pooling_prim = static_cast<const schema::Primitive *>(primitive);
74 auto value = pooling_prim->value_as_AvgPoolFusion();
75 if (value == nullptr) {
76 MS_LOG(ERROR) << "value is nullptr";
77 return nullptr;
78 }
79
80 auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
81 if (param == nullptr) {
82 MS_LOG(ERROR) << "malloc PoolingParameter failed.";
83 return nullptr;
84 }
85 memset(param, 0, sizeof(PoolingParameter));
86
87 param->op_parameter_.type_ = pooling_prim->value_type();
88 param->pool_mode_ = PoolMode_AvgPool;
89 param->global_ = value->global();
90 auto strides = value->strides();
91 if (strides == nullptr || strides->size() < kMinShapeSizeTwo) {
92 MS_LOG(ERROR) << "strides is invalid!";
93 free(param);
94 return nullptr;
95 }
96 param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
97 param->stride_h_ = static_cast<int>(*(strides->begin()));
98 auto pad = value->pad();
99 if (pad != nullptr && pad->size() >= kMinShapeSizeFour) {
100 param->pad_u_ = static_cast<int>(*(pad->begin()));
101 param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
102 param->pad_l_ = static_cast<int>(*(pad->begin() + kOffsetTwo));
103 param->pad_r_ = static_cast<int>(*(pad->begin() + kOffsetThree));
104 }
105 if (!param->global_) {
106 auto kernel_size = value->kernel_size();
107 if (kernel_size == nullptr || kernel_size->size() < kMinShapeSizeTwo) {
108 MS_LOG(ERROR) << "kernel_size is invalid";
109 free(param);
110 return nullptr;
111 }
112 param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
113 param->window_h_ = static_cast<int>(*(kernel_size->begin()));
114 }
115
116 UpdateRoundMode(value->round_mode(), param);
117 UpdateActivationType(value->activation_type(), param);
118 UpdatePadMode(value->pad_mode(), param);
119
120 if (CheckPoolingParam(param) != RET_OK) {
121 MS_LOG(ERROR) << "param is invalid!";
122 free(param);
123 return nullptr;
124 }
125 return reinterpret_cast<OpParameter *>(param);
126 }
127
PopulateMaxPoolParameter(const void * primitive)128 OpParameter *PopulateMaxPoolParameter(const void *primitive) {
129 auto pooling_prim = static_cast<const schema::Primitive *>(primitive);
130 MS_ASSERT(pooling_prim != nullptr);
131 auto value = pooling_prim->value_as_MaxPoolFusion();
132 if (value == nullptr) {
133 MS_LOG(ERROR) << "value is nullptr";
134 return nullptr;
135 }
136
137 auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
138 if (param == nullptr) {
139 MS_LOG(ERROR) << "malloc PoolingParameter failed.";
140 return nullptr;
141 }
142 memset(param, 0, sizeof(PoolingParameter));
143
144 param->op_parameter_.type_ = pooling_prim->value_type();
145 param->pool_mode_ = PoolMode_MaxPool;
146 param->global_ = value->global();
147 if (!param->global_) {
148 auto kernel_size = value->kernel_size();
149 auto strides = value->strides();
150 if (kernel_size == nullptr || strides == nullptr || kernel_size->size() < kMinShapeSizeTwo ||
151 strides->size() < kMinShapeSizeTwo) {
152 MS_LOG(ERROR) << "kernel_size or strides is invalid";
153 free(param);
154 return nullptr;
155 }
156 param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
157 param->window_h_ = static_cast<int>(*(kernel_size->begin()));
158 param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
159 param->stride_h_ = static_cast<int>(*(strides->begin()));
160 auto pad = value->pad();
161 if (pad != nullptr && pad->size() >= kMinShapeSizeFour) {
162 param->pad_u_ = static_cast<int>(*(pad->begin()));
163 param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
164 param->pad_l_ = static_cast<int>(*(pad->begin() + kOffsetTwo));
165 param->pad_r_ = static_cast<int>(*(pad->begin() + kOffsetThree));
166 }
167 }
168
169 UpdateRoundMode(value->round_mode(), param);
170 UpdateActivationType(value->activation_type(), param);
171 UpdatePadMode(value->pad_mode(), param);
172
173 if (CheckPoolingParam(param) != RET_OK) {
174 MS_LOG(ERROR) << "param is invalid!";
175 free(param);
176 return nullptr;
177 }
178 return reinterpret_cast<OpParameter *>(param);
179 }
180
// Register the populate functions above for the AvgPoolFusion and MaxPoolFusion primitive types under
// SCHEMA_CUR. NOTE(review): REG_POPULATE comes from populate_register.h and presumably records the
// function in the populate registry at static-initialization time — confirm there.
REG_POPULATE(PrimitiveType_AvgPoolFusion, PopulateAvgPoolParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_MaxPoolFusion, PopulateMaxPoolParameter, SCHEMA_CUR)
183 } // namespace lite
184 } // namespace mindspore
185