1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifdef ENABLE_ARM64
18 #include "nnacl/kernel/convolution_im2col_arm64.h"
19 #include "nnacl/fp32/pack_fp32.h"
20 #include "nnacl/fp32/conv_common_fp32.h"
21
/*
 * Configure the ARM64-specific im2col tiling parameters on `conv`.
 *
 * Output channels are tiled by C8NUM and rows by C12NUM. RowMajor2Col8Major is
 * installed as the row_major_to_col_nmajor_ packing routine.
 *
 * conv: must point to a ConvolutionIm2ColBaseStruct (it is cast unconditionally);
 *       no NULL check is performed here.
 */
void ConvIm2ColARM64InitGlobalVariable(ConvolutionBaseStruct *conv) {
  ConvolutionIm2ColBaseStruct *im2col = (ConvolutionIm2ColBaseStruct *)conv;

  im2col->row_major_to_col_nmajor_ = RowMajor2Col8Major;
  im2col->row_tile_ = C12NUM;
  im2col->oc_tile_ = C8NUM;
}
28
ConvIm2ColARM64RunImpl(struct ConvolutionBaseStruct * conv,int task_id)29 int ConvIm2ColARM64RunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
30 ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
31 NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
32 float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_;
33 NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
34 ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
35 NNACL_CHECK_NULL_RETURN_ERR(conv_param);
36
37 if (conv->out_format_ != Format_NC4HW4) {
38 if (conv->use_batch_cut_flag_) {
39 ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
40 (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
41 conv_param);
42 } else {
43 ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_,
44 conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param);
45 }
46 } else {
47 ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
48 (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
49 conv_param);
50 }
51 return NNACL_OK;
52 }
53
CreateConvIm2ColARM64(ConvParameter * conv_param)54 ConvolutionBaseStruct *CreateConvIm2ColARM64(ConvParameter *conv_param) {
55 ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
56 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
57 memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
58
59 conv_im2col->init_tmp_buffer_ = ConvIm2ColBaseInitTmpBuffer;
60 conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
61 conv_im2col->conv_.init_global_variable_ = ConvIm2ColARM64InitGlobalVariable;
62 conv_im2col->conv_.run_impl_ = ConvIm2ColARM64RunImpl;
63 conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
64
65 conv_im2col->conv_.base_.Compute = ConvolutionIm2colBaseCompute;
66 conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
67 conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
68 conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
69
70 return (ConvolutionBaseStruct *)conv_im2col;
71 }
72 #endif
73