• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/common/ops/populate/populate_register.h"
17 #include "nnacl/pooling_parameter.h"
18 using mindspore::schema::PrimitiveType_AvgPoolFusion;
19 using mindspore::schema::PrimitiveType_MaxPoolFusion;
20 
21 namespace mindspore {
22 namespace lite {
23 namespace {
CheckPoolingParam(const PoolingParameter * param)24 int CheckPoolingParam(const PoolingParameter *param) {
25   const int max_pooling_pad = 50;
26   if (param->pad_u_ > max_pooling_pad || param->pad_d_ > max_pooling_pad || param->pad_l_ > max_pooling_pad ||
27       param->pad_r_ > max_pooling_pad) {
28     return RET_ERROR;
29   }
30   return RET_OK;
31 }
32 
UpdateRoundMode(enum schema::RoundMode round_mode,PoolingParameter * param)33 void UpdateRoundMode(enum schema::RoundMode round_mode, PoolingParameter *param) {
34   switch (round_mode) {
35     case schema::RoundMode_FLOOR:
36       param->round_type_ = RoundType_Floor;
37       break;
38     case schema::RoundMode_CEIL:
39       param->round_type_ = RoundType_Ceil;
40       break;
41     default:
42       param->round_type_ = RoundType_No;
43       break;
44   }
45 }
46 
UpdateActivationType(enum schema::ActivationType type,PoolingParameter * param)47 void UpdateActivationType(enum schema::ActivationType type, PoolingParameter *param) {
48   if (type == schema::ActivationType_RELU) {
49     param->act_type_ = ActType_Relu;
50   } else if (type == schema::ActivationType_RELU6) {
51     param->act_type_ = ActType_Relu6;
52   } else {
53     param->act_type_ = ActType_No;
54   }
55 }
56 
UpdatePadMode(enum schema::PadMode pad_mode,PoolingParameter * param)57 void UpdatePadMode(enum schema::PadMode pad_mode, PoolingParameter *param) {
58   switch (pad_mode) {
59     case schema::PadMode_SAME:
60       param->pad_mode_ = Pad_same;
61       break;
62     case schema::PadMode_VALID:
63       param->pad_mode_ = Pad_valid;
64       break;
65     default:
66       param->pad_mode_ = Pad_pad;
67       break;
68   }
69 }
70 }  // namespace
PopulateAvgPoolParameter(const void * primitive)71 OpParameter *PopulateAvgPoolParameter(const void *primitive) {
72   MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
73   auto pooling_prim = static_cast<const schema::Primitive *>(primitive);
74   auto value = pooling_prim->value_as_AvgPoolFusion();
75   if (value == nullptr) {
76     MS_LOG(ERROR) << "value is nullptr";
77     return nullptr;
78   }
79 
80   auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
81   if (param == nullptr) {
82     MS_LOG(ERROR) << "malloc PoolingParameter failed.";
83     return nullptr;
84   }
85   memset(param, 0, sizeof(PoolingParameter));
86 
87   param->op_parameter_.type_ = pooling_prim->value_type();
88   param->pool_mode_ = PoolMode_AvgPool;
89   param->global_ = value->global();
90   auto strides = value->strides();
91   if (strides == nullptr || strides->size() < kMinShapeSizeTwo) {
92     MS_LOG(ERROR) << "strides is invalid!";
93     free(param);
94     return nullptr;
95   }
96   param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
97   param->stride_h_ = static_cast<int>(*(strides->begin()));
98   auto pad = value->pad();
99   if (pad != nullptr && pad->size() >= kMinShapeSizeFour) {
100     param->pad_u_ = static_cast<int>(*(pad->begin()));
101     param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
102     param->pad_l_ = static_cast<int>(*(pad->begin() + kOffsetTwo));
103     param->pad_r_ = static_cast<int>(*(pad->begin() + kOffsetThree));
104   }
105   if (!param->global_) {
106     auto kernel_size = value->kernel_size();
107     if (kernel_size == nullptr || kernel_size->size() < kMinShapeSizeTwo) {
108       MS_LOG(ERROR) << "kernel_size is invalid";
109       free(param);
110       return nullptr;
111     }
112     param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
113     param->window_h_ = static_cast<int>(*(kernel_size->begin()));
114   }
115 
116   UpdateRoundMode(value->round_mode(), param);
117   UpdateActivationType(value->activation_type(), param);
118   UpdatePadMode(value->pad_mode(), param);
119 
120   if (CheckPoolingParam(param) != RET_OK) {
121     MS_LOG(ERROR) << "param is invalid!";
122     free(param);
123     return nullptr;
124   }
125   return reinterpret_cast<OpParameter *>(param);
126 }
127 
PopulateMaxPoolParameter(const void * primitive)128 OpParameter *PopulateMaxPoolParameter(const void *primitive) {
129   auto pooling_prim = static_cast<const schema::Primitive *>(primitive);
130   MS_ASSERT(pooling_prim != nullptr);
131   auto value = pooling_prim->value_as_MaxPoolFusion();
132   if (value == nullptr) {
133     MS_LOG(ERROR) << "value is nullptr";
134     return nullptr;
135   }
136 
137   auto *param = reinterpret_cast<PoolingParameter *>(malloc(sizeof(PoolingParameter)));
138   if (param == nullptr) {
139     MS_LOG(ERROR) << "malloc PoolingParameter failed.";
140     return nullptr;
141   }
142   memset(param, 0, sizeof(PoolingParameter));
143 
144   param->op_parameter_.type_ = pooling_prim->value_type();
145   param->pool_mode_ = PoolMode_MaxPool;
146   param->global_ = value->global();
147   if (!param->global_) {
148     auto kernel_size = value->kernel_size();
149     auto strides = value->strides();
150     if (kernel_size == nullptr || strides == nullptr || kernel_size->size() < kMinShapeSizeTwo ||
151         strides->size() < kMinShapeSizeTwo) {
152       MS_LOG(ERROR) << "kernel_size or strides is invalid";
153       free(param);
154       return nullptr;
155     }
156     param->window_w_ = static_cast<int>(*(kernel_size->begin() + 1));
157     param->window_h_ = static_cast<int>(*(kernel_size->begin()));
158     param->stride_w_ = static_cast<int>(*(strides->begin() + 1));
159     param->stride_h_ = static_cast<int>(*(strides->begin()));
160     auto pad = value->pad();
161     if (pad != nullptr && pad->size() >= kMinShapeSizeFour) {
162       param->pad_u_ = static_cast<int>(*(pad->begin()));
163       param->pad_d_ = static_cast<int>(*(pad->begin() + 1));
164       param->pad_l_ = static_cast<int>(*(pad->begin() + kOffsetTwo));
165       param->pad_r_ = static_cast<int>(*(pad->begin() + kOffsetThree));
166     }
167   }
168 
169   UpdateRoundMode(value->round_mode(), param);
170   UpdateActivationType(value->activation_type(), param);
171   UpdatePadMode(value->pad_mode(), param);
172 
173   if (CheckPoolingParam(param) != RET_OK) {
174     MS_LOG(ERROR) << "param is invalid!";
175     free(param);
176     return nullptr;
177   }
178   return reinterpret_cast<OpParameter *>(param);
179 }
180 
// Register the populate functions for the AvgPoolFusion / MaxPoolFusion primitive types under SCHEMA_CUR.
REG_POPULATE(PrimitiveType_AvgPoolFusion, PopulateAvgPoolParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_MaxPoolFusion, PopulateMaxPoolParameter, SCHEMA_CUR)
183 }  // namespace lite
184 }  // namespace mindspore
185