1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/convolution_base.h"
18 #include "nnacl/conv_parameter.h"
19 #include "nnacl/tensor_c_utils.h"
20
/*
 * Mirror the static convolution attributes stored in conv_param into compute
 * and derive kernel_hw_ as kernel_h_ * kernel_w_.
 *
 * compute:    destination compute descriptor; written field by field.
 * conv_param: source parameter block; only read.
 *
 * Returns NNACL_OK once every field has been copied. The kernel area product is
 * guarded by NNACL_CHECK_INT_MUL_NOT_OVERFLOW (passed NNACL_ERR as its error code);
 * all plain copies happen before that guard.
 */
int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param) {
  /* kernel window */
  compute->kernel_w_ = conv_param->kernel_w_;
  compute->kernel_h_ = conv_param->kernel_h_;

  /* channel counts */
  compute->out_c_ = conv_param->output_channel_;
  compute->in_c_ = conv_param->input_channel_;

  /* stride and dilation */
  compute->stride_w_ = conv_param->stride_w_;
  compute->stride_h_ = conv_param->stride_h_;
  compute->dilation_w_ = conv_param->dilation_w_;
  compute->dilation_h_ = conv_param->dilation_h_;

  /* padding: up / down / left / right */
  compute->pad_r_ = conv_param->pad_r_;
  compute->pad_l_ = conv_param->pad_l_;
  compute->pad_d_ = conv_param->pad_d_;
  compute->pad_u_ = conv_param->pad_u_;

  /* kernel area, computed only after the overflow guard */
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_h_, compute->kernel_w_, NNACL_ERR);
  compute->kernel_hw_ = compute->kernel_h_ * compute->kernel_w_;

  return NNACL_OK;
}
41
/*
 * Refresh the shape-dependent state of a convolution kernel from its first input
 * tensor and its output tensor.
 *
 * The batch/height/width/channel values of both tensors are written into the
 * ConvParameter attached to conv->base_.param_ and into conv->compute_. The
 * plane sizes in_hw_ and out_hw_ are derived, with the element-count products
 * guarded by NNACL_CHECK_INT_MUL_NOT_OVERFLOW. The remaining attributes are then
 * filled in by ConvBaseUpdateParamInfo, whose result is returned.
 *
 * Statement order is significant: an early exit from one of the guards leaves
 * everything written before it in place.
 */
int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_ERR(conv);
  ConvParameter *param = (ConvParameter *)conv->base_.param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);
  TensorC *in_tensor = conv->base_.in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
  TensorC *out_tensor = conv->base_.out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);

  /* Shapes recorded on the parameter block: input first, then output. */
  param->input_batch_ = GetBatch(in_tensor);
  param->input_h_ = GetHeight(in_tensor);
  param->input_w_ = GetWidth(in_tensor);
  param->input_channel_ = GetChannel(in_tensor);
  param->output_batch_ = GetBatch(out_tensor);
  param->output_h_ = GetHeight(out_tensor);
  param->output_w_ = GetWidth(out_tensor);
  param->output_channel_ = GetChannel(out_tensor);

  ConvComputeParam *info = &conv->compute_;

  /* Input side of the compute descriptor. */
  info->in_n_ = GetBatch(in_tensor);
  info->in_h_ = GetHeight(in_tensor);
  info->in_w_ = GetWidth(in_tensor);
  info->in_c_ = GetChannel(in_tensor);
  NNACL_CHECK_FALSE(info->in_c_ != param->input_channel_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(info->in_h_, info->in_w_, NNACL_ERR);
  info->in_hw_ = info->in_h_ * info->in_w_;
  /* Guard hw * n first so that the (hw * n) * c operand below is itself safe to form. */
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(info->in_hw_, info->in_n_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(info->in_hw_ * info->in_n_, info->in_c_, NNACL_ERR);

  /* Output side of the compute descriptor. */
  info->out_n_ = GetBatch(out_tensor);
  info->out_h_ = GetHeight(out_tensor);
  info->out_w_ = GetWidth(out_tensor);
  info->out_c_ = GetChannel(out_tensor);
  NNACL_CHECK_FALSE(info->out_c_ != param->output_channel_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(info->out_h_, info->out_w_, NNACL_ERR);
  info->out_hw_ = info->out_h_ * info->out_w_;
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(info->out_hw_, info->out_n_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(info->out_hw_ * info->out_n_, info->out_c_, NNACL_ERR);

  return ConvBaseUpdateParamInfo(info, param);
}
83
/*
 * Release the buffers owned by a convolution kernel.
 *
 * Outside of a train session the packed weight is handed back: through
 * free_sharing_weight_ when is_sharing_pack_ is set, otherwise through the
 * environment allocator. The pointer is cleared afterwards. In a train session
 * the packed weight is left untouched.
 *
 * The bias buffer, when present, is always returned to the environment
 * allocator and its pointer cleared.
 */
void ConvBaseRelease(ConvolutionBaseStruct *conv) {
  if (!conv->base_.train_session_) {
    if (conv->is_sharing_pack_) {
      conv->free_sharing_weight_(conv->shaing_manager_, conv->packed_weight_);
    } else {
      conv->base_.env_->Free(conv->base_.env_->allocator_, conv->packed_weight_);
    }
    conv->packed_weight_ = NULL;
  }

  if (conv->bias_data_ == NULL) {
    return;
  }
  conv->base_.env_->Free(conv->base_.env_->allocator_, conv->bias_data_);
  conv->bias_data_ = NULL;
}
99
/*
 * Validate the tensor counts of a convolution kernel, remember the output
 * tensor format in conv->out_format_, and refresh the shape-dependent state
 * through ConvBaseUpdateComputeInfo.
 *
 * Returns NNACL_INPUT_TENSOR_ERROR when fewer than TWO_TENSOR inputs are
 * attached, NNACL_OUTPUT_TENSOR_ERROR when fewer than ONE_TENSOR outputs are
 * attached, the null-pointer error of NNACL_CHECK_NULL_RETURN_ERR when conv or
 * the output tensor is missing, and otherwise the result of
 * ConvBaseUpdateComputeInfo.
 */
int ConvBasePrepare(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_ERR(conv);
  NNACL_CHECK_FALSE(conv->base_.in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(conv->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);

  /* The output tensor must be validated before its format_ field is read. */
  NNACL_CHECK_NULL_RETURN_ERR(conv->base_.out_[OUTPUT_INDEX]);
  conv->out_format_ = conv->base_.out_[OUTPUT_INDEX]->format_;
  return ConvBaseUpdateComputeInfo(conv);
}
107
/*
 * Re-point origin_weight_ and origin_bias_ at the data buffers currently held by
 * the weight (second) and bias (third) input tensors.
 *
 * A pointer is only updated when its tensor exists and carries a non-NULL
 * data_; otherwise the previous value is kept. The bias is only considered
 * when exactly THREE_TENSOR inputs are attached. A NULL conv is ignored.
 */
void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_VOID(conv);

  /* Guard the tensor pointer itself, not just its data_, before dereferencing. */
  if (conv->base_.in_[SECOND_INPUT] != NULL && conv->base_.in_[SECOND_INPUT]->data_ != NULL) {
    conv->origin_weight_ = conv->base_.in_[SECOND_INPUT]->data_;
  }

  if (conv->base_.in_size_ == THREE_TENSOR && conv->base_.in_[THIRD_INPUT] != NULL &&
      conv->base_.in_[THIRD_INPUT]->data_ != NULL) {
    conv->origin_bias_ = conv->base_.in_[THIRD_INPUT]->data_;
  }
}
119
/*
 * Allocate the packed weight / bias buffers of a convolution kernel and fill
 * them from the original tensors.
 *
 * Steps:
 *  1. In a train session, refresh origin_weight_ / origin_bias_ from the input tensors.
 *  2. If shape inference for the weight tensor is not done yet, return NNACL_OK
 *     without allocating anything.
 *  3. Call malloc_weight_bias_; its error code is propagated on failure.
 *  4. With exactly THREE_TENSOR inputs, copy GetSize(bias tensor) bytes from
 *     origin_bias_ into bias_data_.
 *  5. Outside of a train session, pack the weight unless it is already packed;
 *     when no original weight is available yet, set is_repack_ so packing
 *     happens later.
 *
 * Returns NNACL_OK on success, the error from malloc_weight_bias_, or the
 * null-pointer error of NNACL_CHECK_NULL_RETURN_ERR when conv, the bias tensor,
 * bias_data_ or origin_bias_ is missing.
 */
int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_ERR(conv);

  if (conv->base_.train_session_) {
    ConvBaseUpdateOriginWeightAndBias(conv);
  }

  /* check weight shape done */
  if (!CheckInferShapeDone(&conv->base_.in_[SECOND_INPUT], ONE_TENSOR, NULL, 0)) {
    return NNACL_OK;
  }

  int ret = conv->malloc_weight_bias_(conv);
  if (ret != NNACL_OK) {
    return ret;
  }

  if (conv->base_.in_size_ == THREE_TENSOR) {
    /* memcpy with a NULL source or destination is undefined behaviour; report an error instead. */
    NNACL_CHECK_NULL_RETURN_ERR(conv->base_.in_[THIRD_INPUT]);
    NNACL_CHECK_NULL_RETURN_ERR(conv->bias_data_);
    NNACL_CHECK_NULL_RETURN_ERR(conv->origin_bias_);
    memcpy(conv->bias_data_, conv->origin_bias_, GetSize(conv->base_.in_[THIRD_INPUT]));
  }

  if (!conv->base_.train_session_) {
    if (conv->weight_is_packed_) {
      return NNACL_OK;
    }
    if (conv->origin_weight_ != NULL) {
      conv->pack_weight_(conv);
    } else {
      /* No weight data yet: defer packing until it becomes available. */
      conv->is_repack_ = true;
    }
  }
  return NNACL_OK;
}
151
/*
 * Check that the channel count of the first input tensor matches the channel
 * count reported for the filter (second input) tensor.
 *
 * Returns NNACL_OK when both channel counts agree and
 * NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH otherwise. A missing input or filter
 * tensor is reported through NNACL_CHECK_NULL_RETURN_ERR.
 */
int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv) {
  /* channel count of the (possibly resized) input */
  TensorC *in_tensor = conv->base_.in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
  int in_channel = GetChannel(in_tensor);

  /* channel count the filter expects */
  TensorC *weight = conv->base_.in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(weight);
  int weight_channel = GetChannel(weight);

  return (weight_channel == in_channel) ? NNACL_OK : NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH;
}
165
/*
 * Obtain the buffer that will hold the packed weight of a convolution kernel.
 *
 * A private buffer of data_size bytes is taken from the environment allocator
 * when any of the following holds:
 *   - no get_sharing_weight_ callback is installed,
 *   - the weight tensor is neither a ConstTensor nor a ConstScalar,
 *   - the convolution group_ is greater than 1.
 * In that case weight_is_packed_ and is_sharing_pack_ are both cleared.
 * Otherwise the buffer is requested from get_sharing_weight_, which also
 * reports through weight_is_packed_ whether the data is already packed.
 *
 * conv:      kernel whose weight buffer is requested.
 * data_size: size in bytes of the packed weight.
 *
 * Returns the buffer, or NULL when conv, its parameter block or its weight
 * tensor is missing, when data_size is not positive on the private path, or
 * when the underlying allocator / sharing callback yields NULL.
 */
void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size) {
  /* NULL is already this function's failure value, so reuse it for missing inputs. */
  if (conv == NULL || conv->base_.param_ == NULL) {
    return NULL;
  }
  TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT];
  if (weight_tensor == NULL) {
    return NULL;
  }

  /* Each flag is true when it forces a private (non-shared) allocation. */
  bool const_fit = weight_tensor->category_ != ConstTensor && weight_tensor->category_ != ConstScalar;
  bool group_fit = ((ConvParameter *)conv->base_.param_)->group_ > 1;
  bool sharing_fit = conv->get_sharing_weight_ == NULL;

  void *data = NULL;
  if (sharing_fit || const_fit || group_fit) {
    if (data_size <= 0) {
      return NULL;
    }
    data = conv->base_.env_->Alloc(conv->base_.env_->allocator_, data_size);
    conv->weight_is_packed_ = false;
    conv->is_sharing_pack_ = false;
  } else {
    data = conv->get_sharing_weight_(conv->shaing_manager_, weight_tensor->data_, data_size, &conv->weight_is_packed_);
  }
  return data;
}
185
/*
 * Make sure the packed weight of a convolution kernel reflects the current
 * original weight.
 *
 * Steps:
 *  1. If origin_weight_ is unset, take it from the weight (second input) tensor.
 *  2. If no packed weight exists yet, run ConvBaseInitConvWeightBias.
 *  3. In a train session, point packed_weight_ at the kernel workspace, zero
 *     work_size_ bytes of it and pack again. Outside of a train session, pack
 *     only when is_repack_ is set, clearing the flag first.
 *
 * Returns NNACL_OK on success, the error from ConvBaseInitConvWeightBias, or
 * the null-pointer error of NNACL_CHECK_NULL_RETURN_ERR when conv, the weight
 * tensor, the weight data or (train session only) the workspace is missing.
 */
int ConvBaseRepackWeight(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_ERR(conv);

  if (conv->origin_weight_ == NULL) {
    /* Validate the tensor before reading its data_ pointer. */
    NNACL_CHECK_NULL_RETURN_ERR(conv->base_.in_[SECOND_INPUT]);
    conv->origin_weight_ = conv->base_.in_[SECOND_INPUT]->data_;
  }
  NNACL_CHECK_NULL_RETURN_ERR(conv->origin_weight_);

  if (conv->packed_weight_ == NULL) {
    int ret = ConvBaseInitConvWeightBias(conv);
    if (ret != NNACL_OK) {
      return ret;
    }
  }

  if (conv->is_repack_ || conv->base_.train_session_) {
    if (conv->base_.train_session_) {
      /* The workspace doubles as the packed weight buffer; memset on NULL would be undefined. */
      NNACL_CHECK_NULL_RETURN_ERR(conv->base_.workspace_);
      conv->packed_weight_ = (float *)conv->base_.workspace_;
      memset(conv->packed_weight_, 0, conv->base_.work_size_);
    } else {
      conv->is_repack_ = false;
    }
    conv->pack_weight_(conv);
  }
  return NNACL_OK;
}
210