• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "nnacl/kernel/convolution_base.h"
18 #include "nnacl/conv_parameter.h"
19 #include "nnacl/tensor_c_utils.h"
20 
/* Copy the convolution geometry (stride, dilation, padding, channels, kernel)
 * from the operator parameter into the compute descriptor.
 * Returns NNACL_ERR if kernel_h * kernel_w would overflow int, NNACL_OK otherwise. */
int ConvBaseUpdateParamInfo(ConvComputeParam *compute, ConvParameter *conv_param) {
  /* Stride and dilation. */
  compute->stride_h_ = conv_param->stride_h_;
  compute->stride_w_ = conv_param->stride_w_;
  compute->dilation_h_ = conv_param->dilation_h_;
  compute->dilation_w_ = conv_param->dilation_w_;

  /* Padding: up / down / left / right. */
  compute->pad_u_ = conv_param->pad_u_;
  compute->pad_d_ = conv_param->pad_d_;
  compute->pad_l_ = conv_param->pad_l_;
  compute->pad_r_ = conv_param->pad_r_;

  /* Channel counts. */
  compute->in_c_ = conv_param->input_channel_;
  compute->out_c_ = conv_param->output_channel_;

  /* Kernel plane size, guarded against int overflow. */
  compute->kernel_h_ = conv_param->kernel_h_;
  compute->kernel_w_ = conv_param->kernel_w_;
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_h_, compute->kernel_w_, NNACL_ERR);
  compute->kernel_hw_ = compute->kernel_h_ * compute->kernel_w_;

  return NNACL_OK;
}
41 
/* Refresh both the ConvParameter and the ConvComputeParam from the current
 * input/output tensor shapes (NHWC accessors), validating channel consistency
 * and guarding every element-count product against int overflow. */
int ConvBaseUpdateComputeInfo(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_ERR(conv);
  ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
  NNACL_CHECK_NULL_RETURN_ERR(conv_param);
  TensorC *in_tensor = conv->base_.in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
  TensorC *out_tensor = conv->base_.out_[OUTPUT_INDEX];
  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);

  /* Mirror the tensor shapes into the legacy parameter struct. */
  conv_param->input_batch_ = GetBatch(in_tensor);
  conv_param->input_h_ = GetHeight(in_tensor);
  conv_param->input_w_ = GetWidth(in_tensor);
  conv_param->input_channel_ = GetChannel(in_tensor);
  conv_param->output_batch_ = GetBatch(out_tensor);
  conv_param->output_h_ = GetHeight(out_tensor);
  conv_param->output_w_ = GetWidth(out_tensor);
  conv_param->output_channel_ = GetChannel(out_tensor);

  /* Input side of the compute descriptor. */
  ConvComputeParam *compute = &conv->compute_;
  compute->in_n_ = GetBatch(in_tensor);
  compute->in_h_ = GetHeight(in_tensor);
  compute->in_w_ = GetWidth(in_tensor);
  compute->in_c_ = GetChannel(in_tensor);
  NNACL_CHECK_FALSE(compute->in_c_ != conv_param->input_channel_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_h_, compute->in_w_, NNACL_ERR);
  compute->in_hw_ = compute->in_h_ * compute->in_w_;
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_, compute->in_n_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->in_hw_ * compute->in_n_, compute->in_c_, NNACL_ERR);

  /* Output side of the compute descriptor. */
  compute->out_n_ = GetBatch(out_tensor);
  compute->out_h_ = GetHeight(out_tensor);
  compute->out_w_ = GetWidth(out_tensor);
  compute->out_c_ = GetChannel(out_tensor);
  NNACL_CHECK_FALSE(compute->out_c_ != conv_param->output_channel_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_h_, compute->out_w_, NNACL_ERR);
  compute->out_hw_ = compute->out_h_ * compute->out_w_;
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_, compute->out_n_, NNACL_ERR);
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_hw_ * compute->out_n_, compute->out_c_, NNACL_ERR);

  return ConvBaseUpdateParamInfo(compute, conv_param);
}
83 
/* Free the packed weight (unless it lives in the training workspace, which is
 * managed elsewhere) and the bias buffer, nulling both pointers afterwards. */
void ConvBaseRelease(ConvolutionBaseStruct *conv) {
  if (!conv->base_.train_session_) {
    if (conv->is_sharing_pack_) {
      /* Shared packed weights are returned to the sharing manager, not freed here. */
      conv->free_sharing_weight_(conv->shaing_manager_, conv->packed_weight_);
    } else {
      conv->base_.env_->Free(conv->base_.env_->allocator_, conv->packed_weight_);
    }
    conv->packed_weight_ = NULL;
  }

  if (conv->bias_data_ != NULL) {
    conv->base_.env_->Free(conv->base_.env_->allocator_, conv->bias_data_);
    conv->bias_data_ = NULL;
  }
}
99 
/* Validate tensor counts (at least input + weight, and one output), record the
 * output format, then derive the compute info from the current shapes. */
int ConvBasePrepare(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_FALSE(conv->base_.in_size_ < TWO_TENSOR, NNACL_INPUT_TENSOR_ERROR);
  NNACL_CHECK_FALSE(conv->base_.out_size_ < ONE_TENSOR, NNACL_OUTPUT_TENSOR_ERROR);

  conv->out_format_ = conv->base_.out_[OUTPUT_INDEX]->format_;
  return ConvBaseUpdateComputeInfo(conv);
}
107 
/* Re-capture the origin weight/bias pointers from the input tensors.
 * Pointers are only overwritten when the tensor actually carries data,
 * so previously captured values survive a data-less update. */
void ConvBaseUpdateOriginWeightAndBias(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_VOID(conv);

  TensorC *weight = conv->base_.in_[SECOND_INPUT];
  if (weight->data_ != NULL) {
    conv->origin_weight_ = weight->data_;
  }

  if (conv->base_.in_size_ == THREE_TENSOR) {
    TensorC *bias = conv->base_.in_[THIRD_INPUT];
    if (bias->data_ != NULL) {
      conv->origin_bias_ = bias->data_;
    }
  }
}
119 
/* Allocate and initialize the packed weight and bias buffers.
 * Returns NNACL_OK early (deferring the work) when the weight shape is not
 * yet inferred; otherwise allocates via the kernel's malloc_weight_bias_
 * callback, copies the bias, and packs the weight when data is available.
 * Fix: guard the bias memcpy — origin_bias_ is only set when the third
 * tensor carries data, and bias_data_ comes from a kernel callback; a NULL
 * on either side made the memcpy undefined behavior. */
int ConvBaseInitConvWeightBias(ConvolutionBaseStruct *conv) {
  if (conv->base_.train_session_) {
    ConvBaseUpdateOriginWeightAndBias(conv);
  }

  /* Weight shape still unknown (infer-shape not done): postpone initialization. */
  if (!CheckInferShapeDone(&conv->base_.in_[SECOND_INPUT], ONE_TENSOR, NULL, 0)) {
    return NNACL_OK;
  }

  int ret = conv->malloc_weight_bias_(conv);
  if (ret != NNACL_OK) {
    return ret;
  }

  if (conv->base_.in_size_ == THREE_TENSOR) {
    /* Both pointers must be valid before copying the bias. */
    NNACL_CHECK_NULL_RETURN_ERR(conv->bias_data_);
    NNACL_CHECK_NULL_RETURN_ERR(conv->origin_bias_);
    memcpy(conv->bias_data_, conv->origin_bias_, GetSize(conv->base_.in_[THIRD_INPUT]));
  }

  if (!conv->base_.train_session_) {
    if (conv->weight_is_packed_) {
      /* Sharing manager already delivered a packed weight — nothing to do. */
      return NNACL_OK;
    }
    if (conv->origin_weight_ != NULL) {
      conv->pack_weight_(conv);
    } else {
      /* Weight data not present yet; pack lazily at repack time. */
      conv->is_repack_ = true;
    }
  }
  return NNACL_OK;
}
151 
/* After a resize, the input tensor's channel count must still match the
 * filter's input-channel dimension; otherwise the cached packed weight
 * would be invalid. */
int ConvBaseCheckResizeValid(ConvolutionBaseStruct *conv) {
  TensorC *input_tensor = conv->base_.in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
  TensorC *filter_tensor = conv->base_.in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(filter_tensor);

  if (GetChannel(filter_tensor) != GetChannel(input_tensor)) {
    return NNACL_CONVOLUTION_INPUT_CHANNEL_UNMATCH;
  }
  return NNACL_OK;
}
165 
/* Obtain a buffer for the packed weight. Weight sharing is usable only when a
 * sharing callback exists, the weight is a const tensor, and the convolution
 * is not grouped; otherwise a private buffer is allocated (NULL for a
 * non-positive size). Sets weight_is_packed_/is_sharing_pack_ accordingly. */
void *ConvBaseGetConvPackWeightData(ConvolutionBaseStruct *conv, int data_size) {
  TensorC *weight_tensor = conv->base_.in_[SECOND_INPUT];
  bool weight_not_const = weight_tensor->category_ != ConstTensor && weight_tensor->category_ != ConstScalar;
  bool grouped = ((ConvParameter *)conv->base_.param_)->group_ > 1;
  bool no_sharing = conv->get_sharing_weight_ == NULL;

  if (no_sharing || weight_not_const || grouped) {
    if (data_size <= 0) {
      return NULL;
    }
    void *data = conv->base_.env_->Alloc(conv->base_.env_->allocator_, data_size);
    conv->weight_is_packed_ = false;
    conv->is_sharing_pack_ = false;
    return data;
  }

  /* Sharing path: the manager may hand back an already-packed weight. */
  return conv->get_sharing_weight_(conv->shaing_manager_, weight_tensor->data_, data_size, &conv->weight_is_packed_);
}
185 
/* (Re)pack the weight: lazily resolve origin_weight_ from the weight tensor,
 * run first-time initialization if the packed buffer is missing, then repack
 * when flagged or on every call for training sessions. */
int ConvBaseRepackWeight(ConvolutionBaseStruct *conv) {
  NNACL_CHECK_NULL_RETURN_ERR(conv);

  if (conv->origin_weight_ == NULL) {
    conv->origin_weight_ = conv->base_.in_[SECOND_INPUT]->data_;
  }
  NNACL_CHECK_NULL_RETURN_ERR(conv->origin_weight_);

  if (conv->packed_weight_ == NULL) {
    int ret = ConvBaseInitConvWeightBias(conv);
    if (ret != NNACL_OK) {
      return ret;
    }
  }

  if (conv->is_repack_ || conv->base_.train_session_) {
    if (conv->base_.train_session_) {
      /* Training packs into the (zeroed) kernel workspace on every call. */
      conv->packed_weight_ = (float *)conv->base_.workspace_;
      memset(conv->packed_weight_, 0, conv->base_.work_size_);
    } else {
      conv->is_repack_ = false;
    }
    conv->pack_weight_(conv);
  }
  return NNACL_OK;
}
210