• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/common/log_adapter.h"
17 #include "src/common/ops/populate/populate_register.h"
18 #include "nnacl/conv_parameter.h"
19 using mindspore::schema::PrimitiveType_Conv2dTransposeFusion;
20 
21 namespace mindspore {
22 namespace lite {
SetPadAndAct(schema::PadMode pad_mode,schema::ActivationType act_type,ConvParameter * param)23 int SetPadAndAct(schema::PadMode pad_mode, schema::ActivationType act_type, ConvParameter *param) {
24   switch (pad_mode) {
25     case schema::PadMode_SAME:
26       param->pad_mode_ = Pad_same;
27       break;
28     case schema::PadMode_VALID:
29       param->pad_mode_ = Pad_valid;
30       break;
31     case schema::PadMode_PAD:
32       param->pad_mode_ = Pad_pad;
33       break;
34     default:
35       MS_LOG(ERROR) << "pad mode does not support, " << pad_mode;
36       return RET_NOT_SUPPORT;
37   }
38 
39   switch (act_type) {
40     case schema::ActivationType_RELU:
41       param->act_type_ = ActType_Relu;
42       break;
43     case schema::ActivationType_RELU6:
44       param->act_type_ = ActType_Relu6;
45       break;
46     default:
47       if (act_type != schema::ActivationType_NO_ACTIVATION) {
48         MS_LOG(ERROR) << "activation type does not support, " << act_type;
49         return RET_NOT_SUPPORT;
50       }
51       param->act_type_ = ActType_No;
52       break;
53   }
54   return RET_OK;
55 }
56 
PopulateDeconvParameter(const void * prim)57 OpParameter *PopulateDeconvParameter(const void *prim) {
58   MS_CHECK_TRUE_RET(prim != nullptr, nullptr);
59 
60   auto primitive = static_cast<const schema::Primitive *>(prim);
61   auto value = primitive->value_as_Conv2dTransposeFusion();
62   if (value == nullptr) {
63     MS_LOG(ERROR) << "value is nullptr";
64     return nullptr;
65   }
66 
67   auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
68   if (param == nullptr) {
69     MS_LOG(ERROR) << "malloc ConvParameter failed.";
70     return nullptr;
71   }
72   memset(param, 0, sizeof(ConvParameter));
73 
74   param->op_parameter_.type_ = primitive->value_type();
75   auto kernel_size = value->kernel_size();
76   auto stride = value->stride();
77   auto pad_list = value->pad_list();
78   auto dilation = value->dilation();
79   auto output_paddings = value->output_paddings();
80   param->kernel_h_ = -1;
81   param->kernel_w_ = -1;
82   if (kernel_size != nullptr) {
83     if (kernel_size->size() < kMinShapeSizeTwo) {
84       MS_LOG(ERROR) << "kernel size is invalid.";
85       free(param);
86       return nullptr;
87     }
88     CHECK_LESS_RETURN_RET(INT32_MAX, *(kernel_size->begin()), nullptr, param);
89     param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
90     CHECK_LESS_RETURN_RET(INT32_MAX, *(kernel_size->begin() + 1), nullptr, param);
91     param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
92   }
93   param->output_padding_h_ = 0;
94   param->output_padding_w_ = 0;
95   if (output_paddings != nullptr) {
96     if (output_paddings->size() < kMinShapeSizeTwo) {
97       MS_LOG(ERROR) << "output_paddings size is invalid.";
98       free(param);
99       return nullptr;
100     }
101     CHECK_LESS_RETURN_RET(INT32_MAX, *(output_paddings->begin()), nullptr, param);
102     param->output_padding_h_ = static_cast<int>(*(output_paddings->begin()));
103     CHECK_LESS_RETURN_RET(INT32_MAX, *(output_paddings->begin() + 1), nullptr, param);
104     param->output_padding_w_ = static_cast<int>(*(output_paddings->begin() + 1));
105   }
106   if (param->output_padding_h_ < 0 || param->output_padding_w_ < 0) {
107     MS_LOG(ERROR) << "invalid output padding";
108     free(param);
109     return nullptr;
110   }
111 
112   if (stride == nullptr || dilation == nullptr) {
113     MS_LOG(ERROR) << "nullptr";
114     free(param);
115     return nullptr;
116   }
117   if (stride->size() < kMinShapeSizeTwo || dilation->size() < kMinShapeSizeTwo) {
118     MS_LOG(ERROR) << "stride size: " << stride->size() << ", dilation size: " << dilation->size();
119     free(param);
120     return nullptr;
121   }
122 
123   CHECK_LESS_RETURN_RET(INT32_MAX, value->group(), nullptr, param);
124   param->group_ = static_cast<int>(value->group());
125   CHECK_LESS_RETURN_RET(INT32_MAX, *(stride->begin()), nullptr, param);
126   param->stride_h_ = static_cast<int>(*(stride->begin()));
127   CHECK_LESS_RETURN_RET(INT32_MAX, *(stride->begin() + 1), nullptr, param);
128   param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
129 
130   if (pad_list == nullptr || pad_list->size() < kMinShapeSizeFour) {
131     param->pad_u_ = 0;
132     param->pad_d_ = 0;
133     param->pad_l_ = 0;
134     param->pad_r_ = 0;
135   } else {
136     CHECK_LESS_RETURN_RET(INT32_MAX, *(pad_list->begin()), nullptr, param);
137     param->pad_u_ = static_cast<int>(*(pad_list->begin()));
138     CHECK_LESS_RETURN_RET(INT32_MAX, *(pad_list->begin() + 1), nullptr, param);
139     param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
140     CHECK_LESS_RETURN_RET(INT32_MAX, *(pad_list->begin() + kOffsetTwo), nullptr, param);
141     param->pad_l_ = static_cast<int>(*(pad_list->begin() + kOffsetTwo));
142     CHECK_LESS_RETURN_RET(INT32_MAX, *(pad_list->begin() + kOffsetThree), nullptr, param);
143     param->pad_r_ = static_cast<int>(*(pad_list->begin() + kOffsetThree));
144   }
145   CHECK_LESS_RETURN_RET(INT32_MAX, *(dilation->begin()), nullptr, param);
146   param->dilation_h_ = static_cast<int>(*(dilation->begin()));
147 
148   CHECK_LESS_RETURN_RET(INT32_MAX, *(dilation->begin() + 1), nullptr, param);
149   param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
150 
151   CHECK_LESS_RETURN_RET(INT32_MAX, value->in_channel(), nullptr, param);
152   param->input_channel_ = static_cast<int>(value->in_channel());
153 
154   CHECK_LESS_RETURN_RET(INT32_MAX, value->out_channel(), nullptr, param);
155   param->output_channel_ = static_cast<int>(value->out_channel());
156 
157   auto act_type = value->activation_type();
158   auto pad_mode = value->pad_mode();
159   if (SetPadAndAct(pad_mode, act_type, param) != RET_OK) {
160     MS_LOG(ERROR) << "SetPadAndAct failed.";
161     free(param);
162     return nullptr;
163   }
164 
165   return reinterpret_cast<OpParameter *>(param);
166 }
// Hooks PopulateDeconvParameter up to PrimitiveType_Conv2dTransposeFusion through the
// REG_POPULATE macro from populate_register.h; SCHEMA_CUR presumably selects the
// current schema version — confirm against populate_register.h.
167 REG_POPULATE(PrimitiveType_Conv2dTransposeFusion, PopulateDeconvParameter, SCHEMA_CUR)
168 }  // namespace lite
169 }  // namespace mindspore
170