• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/conv_parameter.h"
18 #include "src/common/ops/populate/populate_register.h"
19 using mindspore::schema::PrimitiveType_Conv2DFusion;
20 
21 namespace mindspore {
22 namespace lite {
23 namespace {
SetPadAndAct(schema::PadMode pad_mode,schema::ActivationType act_type,ConvParameter * param)24 int SetPadAndAct(schema::PadMode pad_mode, schema::ActivationType act_type, ConvParameter *param) {
25   switch (pad_mode) {
26     case schema::PadMode_SAME:
27       param->pad_mode_ = Pad_same;
28       break;
29     case schema::PadMode_VALID:
30       param->pad_mode_ = Pad_valid;
31       break;
32     case schema::PadMode_PAD:
33       param->pad_mode_ = Pad_pad;
34       break;
35     default:
36       MS_LOG(ERROR) << "Pad mode does not support, " << pad_mode;
37       return RET_NOT_SUPPORT;
38   }
39 
40   switch (act_type) {
41     case schema::ActivationType_RELU:
42       param->act_type_ = ActType_Relu;
43       break;
44     case schema::ActivationType_RELU6:
45       param->act_type_ = ActType_Relu6;
46       break;
47     default:
48       if (act_type != schema::ActivationType_NO_ACTIVATION) {
49         MS_LOG(ERROR) << "activation type does not support, " << act_type;
50         return RET_NOT_SUPPORT;
51       }
52       param->act_type_ = ActType_No;
53       break;
54   }
55   return RET_OK;
56 }
57 }  // namespace
58 
PopulateConvParameter(const void * prim)59 OpParameter *PopulateConvParameter(const void *prim) {
60   auto primitive = static_cast<const schema::Primitive *>(prim);
61   MS_CHECK_TRUE_MSG(primitive != nullptr, nullptr, "primitive is nullptr.");
62   auto value = primitive->value_as_Conv2DFusion();
63   MS_CHECK_TRUE_MSG(value != nullptr, nullptr, "value is nullptr.");
64 
65   auto *param = reinterpret_cast<ConvParameter *>(malloc(sizeof(ConvParameter)));
66   if (param == nullptr) {
67     MS_LOG(ERROR) << "malloc ConvParameter failed.";
68     return nullptr;
69   }
70   memset(param, 0, sizeof(ConvParameter));
71 
72   param->op_parameter_.type_ = primitive->value_type();
73   auto kernel_size = value->kernel_size();
74   auto stride = value->stride();
75   auto pad_list = value->pad_list();
76   auto dilation = value->dilation();
77   if (kernel_size != nullptr) {
78     if (kernel_size->size() < kMinShapeSizeTwo) {
79       MS_LOG(ERROR) << "kernel size is invalid.";
80       free(param);
81       return nullptr;
82     }
83     param->kernel_h_ = static_cast<int>(*(kernel_size->begin()));
84     param->kernel_w_ = static_cast<int>(*(kernel_size->begin() + 1));
85   } else {
86     param->kernel_h_ = -1;
87     param->kernel_w_ = -1;
88   }
89   if (stride == nullptr || dilation == nullptr) {
90     MS_LOG(ERROR) << "kernel_size/stride/dilation is nullptr";
91     free(param);
92     return nullptr;
93   }
94   if (stride->size() < kMinShapeSizeTwo || dilation->size() < kMinShapeSizeTwo) {
95     MS_LOG(ERROR) << "stride size: " << stride->size() << ", dilation size: " << dilation->size();
96     free(param);
97     return nullptr;
98   }
99   for (size_t i = 0; i <= 1; i++) {
100     auto stride_item = *(stride->begin() + i);
101     if (stride_item < 0 || stride_item > static_cast<int64_t>(INT32_MAX)) {
102       MS_LOG(ERROR) << "strides has invalid num.";
103       free(param);
104       return nullptr;
105     }
106   }
107   param->group_ = static_cast<int>(value->group());
108   param->stride_h_ = static_cast<int>(*(stride->begin()));
109   param->stride_w_ = static_cast<int>(*(stride->begin() + 1));
110   if (pad_list == nullptr || pad_list->size() < kMinShapeSizeFour) {
111     param->pad_u_ = 0;
112     param->pad_d_ = 0;
113     param->pad_l_ = 0;
114     param->pad_r_ = 0;
115   } else {
116     for (size_t i = 0; i <= kOffsetThree; i++) {
117       auto pad_item = *(pad_list->begin() + i);
118       if (pad_item < 0 || pad_item > static_cast<int64_t>(INT32_MAX)) {
119         MS_LOG(ERROR) << "pad list has invalid num.";
120         free(param);
121         return nullptr;
122       }
123     }
124     param->pad_u_ = static_cast<int>(*(pad_list->begin()));
125     param->pad_d_ = static_cast<int>(*(pad_list->begin() + 1));
126     param->pad_l_ = static_cast<int>(*(pad_list->begin() + kOffsetTwo));
127     param->pad_r_ = static_cast<int>(*(pad_list->begin() + kOffsetThree));
128   }
129   param->dilation_h_ = static_cast<int>(*(dilation->begin()));
130   param->dilation_w_ = static_cast<int>(*(dilation->begin() + 1));
131   param->input_channel_ = static_cast<int>(value->in_channel());
132   param->output_channel_ = static_cast<int>(value->out_channel());
133   auto pad_mode = value->pad_mode();
134   auto act_type = value->activation_type();
135   if (SetPadAndAct(pad_mode, act_type, param) != RET_OK) {
136     MS_LOG(ERROR) << "SetPadAndAct failed.";
137     free(param);
138     return nullptr;
139   }
140   return reinterpret_cast<OpParameter *>(param);
141 }
142 
143 REG_POPULATE(PrimitiveType_Conv2DFusion, PopulateConvParameter, SCHEMA_CUR)
144 }  // namespace lite
145 }  // namespace mindspore
146