• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifdef ENABLE_AVX512
18 #include "nnacl/kernel/convolution_im2col_avx512.h"
19 #include "nnacl/fp32/conv_im2col_avx512_fp32.h"
20 #include "nnacl/fp32/pack_fp32.h"
21 #include "nnacl/tensor_c.h"
22 
/**
 * Set up the AVX512-specific tiling parameters of an im2col convolution kernel.
 *
 * Writes three fields of the im2col struct:
 *   - oc_tile_  : C16NUM
 *   - row_tile_ : MSMIN(UP_DIV(out_hw_, thread_nr_), C150NUM)
 *   - row_major_to_col_nmajor_ : RowMajor2Col64Major
 *
 * @param conv convolution kernel; reinterpreted as a ConvolutionIm2ColBaseStruct.
 */
void ConvIm2ColAVX512InitGlobalVariable(ConvolutionBaseStruct *conv) {
  ConvolutionIm2ColBaseStruct *im2col = (ConvolutionIm2ColBaseStruct *)conv;

  /* Share of the output plane handled by each thread, before the C150NUM cap. */
  int rows_per_thread = UP_DIV(im2col->conv_.compute_.out_hw_, im2col->conv_.base_.thread_nr_);

  im2col->row_major_to_col_nmajor_ = RowMajor2Col64Major;
  im2col->row_tile_ = MSMIN(rows_per_thread, C150NUM);
  im2col->oc_tile_ = C16NUM;
}
30 
ConvIm2ColAVX512InitTmpBuffer(struct ConvolutionIm2ColBaseStruct * conv_im2col)31 int ConvIm2ColAVX512InitTmpBuffer(struct ConvolutionIm2ColBaseStruct *conv_im2col) {
32   ExecEnv *env = conv_im2col->conv_.base_.env_;
33   NNACL_CHECK_NULL_RETURN_ERR(env);
34   ConvComputeParam *compute = &conv_im2col->conv_.compute_;
35   NNACL_CHECK_NULL_RETURN_ERR(compute);
36 
37   NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_hw_, compute->in_c_, NNACL_ERR);
38   int kernel_chw = compute->kernel_hw_ * compute->in_c_;
39   NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR);
40   int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_;
41   NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR);
42   size_t unit_size = total_kernel_chw * conv_im2col->row_tile_;
43 
44   if (conv_im2col->packed_input_ != NULL) {
45     env->Free(env->allocator_, conv_im2col->packed_input_);
46     conv_im2col->packed_input_ = NULL;
47   }
48   conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
49   NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_);
50 
51   conv_im2col->output_need_align_ = compute->out_c_ % conv_im2col->oc_tile_ != 0;
52   if (conv_im2col->output_need_align_) {
53     if (conv_im2col->tmp_output_ != NULL) {
54       env->Free(env->allocator_, conv_im2col->tmp_output_);
55       conv_im2col->tmp_output_ = NULL;
56     }
57 
58     // avx512 need to malloc dst aligned to C16NUM
59     int oc_algin = UP_ROUND(compute->out_c_, conv_im2col->oc_tile_);
60     NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR);
61     int output_bhw = compute->out_n_ * compute->out_hw_;
62     NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_algin, NNACL_ERR);
63     size_t pack_output_size = output_bhw * compute->out_w_ * oc_algin;
64 
65     conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float));
66     NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_);
67   }
68 
69   return NNACL_OK;
70 }
71 
ConvIm2ColAVX512RunImpl(struct ConvolutionBaseStruct * conv,int task_id)72 int ConvIm2ColAVX512RunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
73   if (conv->out_format_ == Format_NC4HW4) {
74     return NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT;
75   }
76 
77   ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
78   NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
79   ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
80   NNACL_CHECK_NULL_RETURN_ERR(conv_param);
81   float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_;
82   NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
83 
84   if (conv->use_batch_cut_flag_) {
85     ConvIm2ColAVX512Fp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
86                                    (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
87                                    conv_im2col->row_tile_);
88   } else {
89     ConvIm2ColAVX512Fp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
90                          (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
91                          conv_im2col->row_tile_);
92   }
93   return NNACL_OK;
94 }
95 
/**
 * Compute entry point of the AVX512 im2col convolution kernel.
 *
 * Steps:
 *   1. Allocate scratch buffers through init_tmp_buffer_.
 *   2. When the output needs no channel alignment, let tmp_output_ alias the output
 *      tensor's data so the kernel writes to it directly.
 *   3. Repack the weights with ConvBaseRepackWeight.
 *   4. Run ConvIm2ColBaseImpl on thread_nr_ tasks through ParallelLaunch.
 *   5. When the output was channel-aligned, copy tmp_output_ into the output tensor
 *      with PackNHWCXToNHWCFp32.
 *   6. Drop the output alias (if any) and release the scratch buffers.
 *
 * tmp_output_ is reset to NULL on every exit path on which it aliases the output
 * tensor, so ConvIm2ColBaseFreeTmpBuffer never sees a pointer this kernel does not own.
 *
 * @param self kernel instance; reinterpreted as a ConvolutionIm2ColBaseStruct.
 * @return NNACL_OK on success, otherwise the first error code reported by buffer
 *         initialisation, weight repacking or the parallel launch.
 */
int ConvolutionIm2colAvx512Compute(KernelBase *self) {
  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self;
  int ret = conv_im2col->init_tmp_buffer_(conv_im2col);
  if (ret != NNACL_OK) {
    ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
    return ret;
  }

  float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_;
  if (!conv_im2col->output_need_align_) {
    // No padding of the channel dimension required: write straight into the output tensor.
    conv_im2col->tmp_output_ = output_addr;
  }

  ret = ConvBaseRepackWeight(&conv_im2col->conv_);
  if (ret == NNACL_OK) {
    ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_);

    if (conv_im2col->output_need_align_) {
      PackNHWCXToNHWCFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_,
                          conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.compute_.out_c_,
                          conv_im2col->oc_tile_);
    }
  }

  // Shared cleanup for success and repack failure: the aliased output buffer belongs to
  // the output tensor, so forget it before the scratch buffers are released.
  if (!conv_im2col->output_need_align_) {
    conv_im2col->tmp_output_ = NULL;
  }
  ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
  return ret;
}
127 
CreateConvIm2ColAVX512(ConvParameter * conv_param)128 ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param) {
129   ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
130   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
131   memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
132 
133   conv_im2col->init_tmp_buffer_ = ConvIm2ColAVX512InitTmpBuffer;
134   conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
135   conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVX512InitGlobalVariable;
136   conv_im2col->conv_.run_impl_ = ConvIm2ColAVX512RunImpl;
137   conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
138 
139   conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvx512Compute;
140   conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
141   conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
142   conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
143 
144   return (ConvolutionBaseStruct *)conv_im2col;
145 }
146 #endif
147