• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_ARM64
18 #include "nnacl/kernel/convolution_im2col_arm64.h"
19 #include "nnacl/fp32/pack_fp32.h"
20 #include "nnacl/fp32/conv_common_fp32.h"
21 
/*
 * Install the ARM64 tiling configuration on an im2col convolution kernel.
 *
 * `conv` is reinterpreted as a ConvolutionIm2ColBaseStruct and three of its
 * fields are overwritten:
 *   - row_major_to_col_nmajor_ <- RowMajor2Col8Major
 *   - row_tile_                <- C12NUM
 *   - oc_tile_                 <- C8NUM
 *
 * No null check is performed on `conv`; the caller must pass a valid kernel.
 */
void ConvIm2ColARM64InitGlobalVariable(ConvolutionBaseStruct *conv) {
  ConvolutionIm2ColBaseStruct *im2col = (ConvolutionIm2ColBaseStruct *)conv;

  im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major;
  im2col->row_tile_ = C12NUM;
  im2col->oc_tile_ = C8NUM;
}
28 
ConvIm2ColARM64RunImpl(struct ConvolutionBaseStruct * conv,int task_id)29 int ConvIm2ColARM64RunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
30   ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
31   NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
32   float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_;
33   NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
34   ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
35   NNACL_CHECK_NULL_RETURN_ERR(conv_param);
36 
37   if (conv->out_format_ != Format_NC4HW4) {
38     if (conv->use_batch_cut_flag_) {
39       ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
40                          (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
41                          conv_param);
42     } else {
43       ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_,
44                conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param);
45     }
46   } else {
47     ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
48                       (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
49                       conv_param);
50   }
51   return NNACL_OK;
52 }
53 
CreateConvIm2ColARM64(ConvParameter * conv_param)54 ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param) {
55   ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
56   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
57   memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
58 
59   conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer;
60   conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
61   conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM64InitGlobalVariable;
62   conv_im2col->conv_.run_impl_ = ConvIm2ColARM64RunImpl;
63   conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
64 
65   conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute;
66   conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
67   conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
68   conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
69 
70   return (ConvolutionBaseStruct *)conv_im2col;
71 }
72 #endif
73