1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifdef ENABLE_AVX
18 #include "nnacl/kernel/convolution_im2col_avx.h"
19 #include "nnacl/fp32/pack_fp32.h"
20 #include "nnacl/fp32/conv_common_fp32.h"
21
/*
 * Installs the AVX-specific tiling configuration on an im2col convolution:
 * oc_tile_ = C16NUM, row_tile_ = C6NUM, and RowMajor2Col16Major as the
 * row-major -> column-major packing routine.
 * NOTE(review): assumes `conv` really is a ConvolutionIm2ColBaseStruct (the
 * cast below is unchecked) — true for objects built by CreateConvIm2ColAVX.
 */
void ConvIm2ColAVXInitGlobalVariable(ConvolutionBaseStruct *conv) {
  ConvolutionIm2ColBaseStruct *im2col = (ConvolutionIm2ColBaseStruct *)conv;
  im2col->row_major_to_col_nmajor_ = RowMajor2Col16Major;
  im2col->row_tile_ = C6NUM;
  im2col->oc_tile_ = C16NUM;
}
28
ConvIm2ColAVXInitTmpBuffer(ConvolutionIm2ColBaseStruct * conv_im2col)29 int ConvIm2ColAVXInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) {
30 int kernel_chw = conv_im2col->conv_.compute_.kernel_hw_ * conv_im2col->conv_.compute_.in_c_;
31 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR);
32 int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_;
33 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR);
34 int unit_size = total_kernel_chw * conv_im2col->row_tile_;
35
36 ExecEnv *env = conv_im2col->conv_.base_.env_;
37 NNACL_CHECK_NULL_RETURN_ERR(env);
38
39 if (conv_im2col->packed_input_ != NULL) {
40 env->Free(env->allocator_, conv_im2col->packed_input_);
41 conv_im2col->packed_input_ = NULL;
42 }
43 conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
44 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_);
45
46 if (conv_im2col->col_major_input_ != NULL) {
47 env->Free(env->allocator_, conv_im2col->col_major_input_);
48 conv_im2col->col_major_input_ = NULL;
49 }
50 conv_im2col->col_major_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
51 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_);
52
53 conv_im2col->output_need_align_ =
54 conv_im2col->conv_.compute_.out_c_ % conv_im2col->oc_tile_ != 0 && conv_im2col->conv_.out_format_ == Format_NC4HW4;
55 if (conv_im2col->output_need_align_) {
56 int oc_algin = UP_DIV(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_);
57 int output_bhw = conv_im2col->conv_.compute_.out_n_ * conv_im2col->conv_.compute_.out_hw_;
58 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_im2col->oc_tile_ * oc_algin, NNACL_ERR);
59 int pack_output_size = output_bhw * conv_im2col->oc_tile_ * oc_algin;
60
61 if (conv_im2col->tmp_output_ != NULL) {
62 env->Free(env->allocator_, conv_im2col->tmp_output_);
63 conv_im2col->tmp_output_ = NULL;
64 }
65 conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float));
66 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_);
67 }
68 return NNACL_OK;
69 }
70
ConvIm2ColAVXRunImpl(struct ConvolutionBaseStruct * conv,int task_id)71 int ConvIm2ColAVXRunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
72 ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
73 NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
74 ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
75 NNACL_CHECK_NULL_RETURN_ERR(conv_param);
76 float *ori_input_data = conv->base_.in_[FIRST_INPUT]->data_;
77 NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
78
79 if (conv->out_format_ != Format_NC4HW4) {
80 if (conv->use_batch_cut_flag_) {
81 ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
82 (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
83 conv_param);
84 } else {
85 ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_,
86 conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param);
87 }
88 } else {
89 ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
90 (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
91 conv_param);
92 }
93 return NNACL_OK;
94 }
95
ConvolutionIm2colAvxCompute(KernelBase * self)96 int ConvolutionIm2colAvxCompute(KernelBase *self) {
97 ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self;
98 NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
99
100 int ret = conv_im2col->init_tmp_buffer_(conv_im2col);
101 if (ret != NNACL_OK) {
102 ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
103 return ret;
104 }
105
106 float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_;
107 NNACL_CHECK_NULL_RETURN_ERR(output_addr);
108 if (!conv_im2col->output_need_align_) {
109 conv_im2col->tmp_output_ = output_addr;
110 }
111
112 ret = ConvBaseRepackWeight(&conv_im2col->conv_);
113 if (ret != NNACL_OK) {
114 ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
115 return ret;
116 }
117
118 ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_);
119
120 if (conv_im2col->output_need_align_) {
121 PackNC8HW8AlignedToNC8HW8NotAlignedFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_,
122 conv_im2col->conv_.compute_.out_w_ * conv_im2col->conv_.compute_.out_h_,
123 conv_im2col->conv_.compute_.out_c_);
124 } else {
125 conv_im2col->tmp_output_ = NULL;
126 }
127
128 ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
129 return ret;
130 }
131
CreateConvIm2ColAVX(ConvParameter * conv_param)132 ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param) {
133 ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
134 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
135 memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
136
137 conv_im2col->init_tmp_buffer_ = ConvIm2ColAVXInitTmpBuffer;
138
139 conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
140 conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVXInitGlobalVariable;
141 conv_im2col->conv_.run_impl_ = ConvIm2ColAVXRunImpl;
142 conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
143
144 conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvxCompute;
145 conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
146 conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
147 conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
148
149 return (ConvolutionBaseStruct *)conv_im2col;
150 }
151 #endif
152