• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_AVX
18 #include "nnacl/kernel/convolution_im2col_avx.h"
19 #include "nnacl/fp32/pack_fp32.h"
20 #include "nnacl/fp32/conv_common_fp32.h"
21 
/* Configure the AVX-specific tiling parameters for the im2col convolution:
 * weights are packed in 16-wide output-channel tiles and the input matrix is
 * processed 6 rows at a time, matching the 16-major column packing routine. */
void ConvIm2ColAVXInitGlobalVariable(ConvolutionBaseStruct *conv) {
  ConvolutionIm2ColBaseStruct *im2col = (ConvolutionIm2ColBaseStruct *)conv;
  im2col->row_major_to_col_nmajor_ = RowMajor2Col16Major;
  im2col->row_tile_ = C6NUM;
  im2col->oc_tile_ = C16NUM;
}
28 
/* Allocate (or reallocate) the per-run scratch buffers for the AVX im2col
 * convolution: a packed-input buffer and a column-major buffer, each sized for
 * row_tile_ rows of the im2col matrix per worker thread, plus an aligned
 * temporary output when the channel count is not a multiple of oc_tile_ and
 * the output layout is NC4HW4.
 * Returns NNACL_OK on success, NNACL_ERR on integer overflow, or the malloc
 * error code from the NNACL allocation-check macros. */
int ConvIm2ColAVXInitTmpBuffer(ConvolutionIm2ColBaseStruct *conv_im2col) {
  /* Fix: guard the first multiplication as well — kernel_hw_ * in_c_ can
   * overflow a signed int (UB) just like the subsequent products. */
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(conv_im2col->conv_.compute_.kernel_hw_, conv_im2col->conv_.compute_.in_c_,
                                   NNACL_ERR);
  int kernel_chw = conv_im2col->conv_.compute_.kernel_hw_ * conv_im2col->conv_.compute_.in_c_;
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR);
  int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_;
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR);
  int unit_size = total_kernel_chw * conv_im2col->row_tile_;

  ExecEnv *env = conv_im2col->conv_.base_.env_;
  NNACL_CHECK_NULL_RETURN_ERR(env);

  /* Release any buffer left over from a previous Resize before reallocating. */
  if (conv_im2col->packed_input_ != NULL) {
    env->Free(env->allocator_, conv_im2col->packed_input_);
    conv_im2col->packed_input_ = NULL;
  }
  conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_);

  if (conv_im2col->col_major_input_ != NULL) {
    env->Free(env->allocator_, conv_im2col->col_major_input_);
    conv_im2col->col_major_input_ = NULL;
  }
  conv_im2col->col_major_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->col_major_input_);

  /* An aligned staging buffer is only needed when the output channels do not
   * fill a whole oc tile AND the consumer expects the NC4HW4 layout. */
  conv_im2col->output_need_align_ =
    conv_im2col->conv_.compute_.out_c_ % conv_im2col->oc_tile_ != 0 && conv_im2col->conv_.out_format_ == Format_NC4HW4;
  if (conv_im2col->output_need_align_) {
    int oc_align = UP_DIV(conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_);
    int output_bhw = conv_im2col->conv_.compute_.out_n_ * conv_im2col->conv_.compute_.out_hw_;
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, conv_im2col->oc_tile_ * oc_align, NNACL_ERR);
    int pack_output_size = output_bhw * conv_im2col->oc_tile_ * oc_align;

    if (conv_im2col->tmp_output_ != NULL) {
      env->Free(env->allocator_, conv_im2col->tmp_output_);
      conv_im2col->tmp_output_ = NULL;
    }
    conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float));
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_);
  }
  return NNACL_OK;
}
70 
ConvIm2ColAVXRunImpl(struct ConvolutionBaseStruct * conv,int task_id)71 int ConvIm2ColAVXRunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
72   ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
73   NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
74   ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
75   NNACL_CHECK_NULL_RETURN_ERR(conv_param);
76   float *ori_input_data = conv->base_.in_[FIRST_INPUT]->data_;
77   NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
78 
79   if (conv->out_format_ != Format_NC4HW4) {
80     if (conv->use_batch_cut_flag_) {
81       ConvFp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
82                          (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
83                          conv_param);
84     } else {
85       ConvFp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_, (float *)conv->bias_data_,
86                conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id, conv_param);
87     }
88   } else {
89     ConvFp32OutNC4HW4(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
90                       (float *)conv->bias_data_, conv_im2col->col_major_input_, conv_im2col->tmp_output_, task_id,
91                       conv_param);
92   }
93   return NNACL_OK;
94 }
95 
/* Run one im2col AVX convolution: (re)build scratch buffers, repack weights
 * if needed, launch the threaded kernel, then strip channel-alignment padding
 * from the packed output when it was staged in a temporary buffer.
 * NOTE(review): statement order around tmp_output_ is behavior-critical —
 * when no alignment is needed, tmp_output_ aliases the output tensor's data
 * and is reset to NULL before the final FreeTmpBuffer so the tensor memory is
 * not freed. */
int ConvolutionIm2colAvxCompute(KernelBase *self) {
  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);

  int ret = conv_im2col->init_tmp_buffer_(conv_im2col);
  if (ret != NNACL_OK) {
    ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
    return ret;
  }

  float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(output_addr);
  if (!conv_im2col->output_need_align_) {
    /* Write directly into the output tensor; tmp_output_ now ALIASES tensor
     * memory. Presumably FreeTmpBuffer skips tmp_output_ unless
     * output_need_align_ is set — TODO confirm against its definition,
     * since the early-return path below reaches it with the alias live. */
    conv_im2col->tmp_output_ = output_addr;
  }

  ret = ConvBaseRepackWeight(&conv_im2col->conv_);
  if (ret != NNACL_OK) {
    ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
    return ret;
  }

  /* Fan the work out across the thread pool; each task calls run_impl_. */
  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_);

  if (conv_im2col->output_need_align_) {
    /* Copy from the oc-aligned staging buffer into the tightly-packed output,
     * dropping the channel padding added for the 8-wide pack. */
    PackNC8HW8AlignedToNC8HW8NotAlignedFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_,
                                            conv_im2col->conv_.compute_.out_w_ * conv_im2col->conv_.compute_.out_h_,
                                            conv_im2col->conv_.compute_.out_c_);
  } else {
    /* Drop the alias so FreeTmpBuffer cannot free the output tensor's data. */
    conv_im2col->tmp_output_ = NULL;
  }

  ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
  return ret;
}
131 
CreateConvIm2ColAVX(ConvParameter * conv_param)132 ConvolutionBaseStruct *CreateConvIm2ColAVX(ConvParameter *conv_param) {
133   ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
134   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
135   memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
136 
137   conv_im2col->init_tmp_buffer_ = ConvIm2ColAVXInitTmpBuffer;
138 
139   conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
140   conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVXInitGlobalVariable;
141   conv_im2col->conv_.run_impl_ = ConvIm2ColAVXRunImpl;
142   conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
143 
144   conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvxCompute;
145   conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
146   conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
147   conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
148 
149   return (ConvolutionBaseStruct *)conv_im2col;
150 }
151 #endif
152