1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifdef ENABLE_AVX512
18 #include "nnacl/kernel/convolution_im2col_avx512.h"
19 #include "nnacl/fp32/conv_im2col_avx512_fp32.h"
20 #include "nnacl/fp32/pack_fp32.h"
21 #include "nnacl/tensor_c.h"
22
/**
 * Configure the AVX512-specific tiling of an im2col convolution.
 *
 * - oc_tile_: output channels are processed in groups of C16NUM.
 * - row_tile_: the per-thread share of output pixels (out_hw_ divided by
 *   thread_nr_, rounded up), capped at C150NUM.
 * - row_major_to_col_nmajor_: set to RowMajor2Col64Major.
 */
void ConvIm2ColAVX512InitGlobalVariable(ConvolutionBaseStruct *conv) {
  ConvolutionIm2ColBaseStruct *im2col = (ConvolutionIm2ColBaseStruct *)conv;
  int rows_per_thread = UP_DIV(im2col->conv_.compute_.out_hw_, im2col->conv_.base_.thread_nr_);

  im2col->oc_tile_ = C16NUM;
  im2col->row_tile_ = MSMIN(rows_per_thread, C150NUM);
  im2col->row_major_to_col_nmajor_ = RowMajor2Col64Major;
}
30
ConvIm2ColAVX512InitTmpBuffer(struct ConvolutionIm2ColBaseStruct * conv_im2col)31 int ConvIm2ColAVX512InitTmpBuffer(struct ConvolutionIm2ColBaseStruct *conv_im2col) {
32 ExecEnv *env = conv_im2col->conv_.base_.env_;
33 NNACL_CHECK_NULL_RETURN_ERR(env);
34 ConvComputeParam *compute = &conv_im2col->conv_.compute_;
35 NNACL_CHECK_NULL_RETURN_ERR(compute);
36
37 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->kernel_hw_, compute->in_c_, NNACL_ERR);
38 int kernel_chw = compute->kernel_hw_ * compute->in_c_;
39 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(kernel_chw, conv_im2col->conv_.base_.thread_nr_, NNACL_ERR);
40 int total_kernel_chw = kernel_chw * conv_im2col->conv_.base_.thread_nr_;
41 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(total_kernel_chw, conv_im2col->row_tile_, NNACL_ERR);
42 size_t unit_size = total_kernel_chw * conv_im2col->row_tile_;
43
44 if (conv_im2col->packed_input_ != NULL) {
45 env->Free(env->allocator_, conv_im2col->packed_input_);
46 conv_im2col->packed_input_ = NULL;
47 }
48 conv_im2col->packed_input_ = env->Alloc(env->allocator_, unit_size * sizeof(float));
49 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->packed_input_);
50
51 conv_im2col->output_need_align_ = compute->out_c_ % conv_im2col->oc_tile_ != 0;
52 if (conv_im2col->output_need_align_) {
53 if (conv_im2col->tmp_output_ != NULL) {
54 env->Free(env->allocator_, conv_im2col->tmp_output_);
55 conv_im2col->tmp_output_ = NULL;
56 }
57
58 // avx512 need to malloc dst aligned to C16NUM
59 int oc_algin = UP_ROUND(compute->out_c_, conv_im2col->oc_tile_);
60 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(compute->out_n_, compute->out_hw_, NNACL_ERR);
61 int output_bhw = compute->out_n_ * compute->out_hw_;
62 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(output_bhw, oc_algin, NNACL_ERR);
63 size_t pack_output_size = output_bhw * compute->out_w_ * oc_algin;
64
65 conv_im2col->tmp_output_ = env->Alloc(env->allocator_, pack_output_size * sizeof(float));
66 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(conv_im2col->tmp_output_);
67 }
68
69 return NNACL_OK;
70 }
71
ConvIm2ColAVX512RunImpl(struct ConvolutionBaseStruct * conv,int task_id)72 int ConvIm2ColAVX512RunImpl(struct ConvolutionBaseStruct *conv, int task_id) {
73 if (conv->out_format_ == Format_NC4HW4) {
74 return NNACL_CONVOLUTION_AVX512_UNSUPPORT_FORMAT;
75 }
76
77 ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)conv;
78 NNACL_CHECK_NULL_RETURN_ERR(conv_im2col);
79 ConvParameter *conv_param = (ConvParameter *)conv->base_.param_;
80 NNACL_CHECK_NULL_RETURN_ERR(conv_param);
81 float *ori_input_data = (float *)conv->base_.in_[FIRST_INPUT]->data_;
82 NNACL_CHECK_NULL_RETURN_ERR(ori_input_data);
83
84 if (conv->use_batch_cut_flag_) {
85 ConvIm2ColAVX512Fp32CutByBatch(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
86 (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
87 conv_im2col->row_tile_);
88 } else {
89 ConvIm2ColAVX512Fp32(ori_input_data, conv_im2col->packed_input_, (float *)conv->packed_weight_,
90 (float *)conv->bias_data_, conv_im2col->tmp_output_, task_id, conv_param,
91 conv_im2col->row_tile_);
92 }
93 return NNACL_OK;
94 }
95
/**
 * Compute entry point of the AVX512 im2col convolution kernel.
 *
 * 1. Allocate scratch buffers through init_tmp_buffer_.
 * 2. If no channel padding is needed, borrow the output tensor data as
 *    tmp_output_ so the kernel writes into it directly.
 * 3. Repack weights, then run ConvIm2ColBaseImpl on thread_nr_ tasks.
 * 4. If padding was needed and the run succeeded, unpack tmp_output_
 *    (channels padded to oc_tile_) into the output tensor.
 *
 * Scratch buffers are released through ConvIm2ColBaseFreeTmpBuffer on every path.
 *
 * @return NNACL_OK on success, otherwise the first failing step's error code.
 */
int ConvolutionIm2colAvx512Compute(KernelBase *self) {
  ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)self;
  int ret = conv_im2col->init_tmp_buffer_(conv_im2col);
  if (ret != NNACL_OK) {
    ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
    return ret;
  }

  float *output_addr = (float *)self->out_[OUTPUT_INDEX]->data_;
  if (!conv_im2col->output_need_align_) {
    // tmp_output_ only borrows the output tensor's data here; conv_im2col does not own it.
    conv_im2col->tmp_output_ = output_addr;
  }

  ret = ConvBaseRepackWeight(&conv_im2col->conv_);
  if (ret == NNACL_OK) {
    ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConvIm2ColBaseImpl, self, self->thread_nr_);
  }

  // Unpack the padded staging buffer only when the convolution actually ran.
  if (ret == NNACL_OK && conv_im2col->output_need_align_) {
    PackNHWCXToNHWCFp32(conv_im2col->tmp_output_, output_addr, conv_im2col->conv_.compute_.out_n_,
                        conv_im2col->conv_.compute_.out_hw_, conv_im2col->conv_.compute_.out_c_, conv_im2col->oc_tile_);
  }
  if (!conv_im2col->output_need_align_) {
    // Drop the borrowed pointer on every path, including a failed repack, so it is
    // never left in tmp_output_ when ConvIm2ColBaseFreeTmpBuffer runs.
    conv_im2col->tmp_output_ = NULL;
  }

  ConvIm2ColBaseFreeTmpBuffer(conv_im2col);
  return ret;
}
127
CreateConvIm2ColAVX512(ConvParameter * conv_param)128 ConvolutionBaseStruct *CreateConvIm2ColAVX512(ConvParameter *conv_param) {
129 ConvolutionIm2ColBaseStruct *conv_im2col = (ConvolutionIm2ColBaseStruct *)malloc(sizeof(ConvolutionIm2ColBaseStruct));
130 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(conv_im2col);
131 memset(conv_im2col, 0, sizeof(ConvolutionIm2ColBaseStruct));
132
133 conv_im2col->init_tmp_buffer_ = ConvIm2ColAVX512InitTmpBuffer;
134 conv_im2col->conv_.malloc_weight_bias_ = ConvIm2ColBaseMallocWeightBiasData;
135 conv_im2col->conv_.init_global_variable_ = ConvIm2ColAVX512InitGlobalVariable;
136 conv_im2col->conv_.run_impl_ = ConvIm2ColAVX512RunImpl;
137 conv_im2col->conv_.pack_weight_ = ConvIm2ColBasePackWeight;
138
139 conv_im2col->conv_.base_.Compute = ConvolutionIm2colAvx512Compute;
140 conv_im2col->conv_.base_.Prepare = ConvolutionIm2colBasePrepare;
141 conv_im2col->conv_.base_.Resize = ConvolutionIm2colBaseResize;
142 conv_im2col->conv_.base_.Release = ConvolutionIm2colBaseRelease;
143
144 return (ConvolutionBaseStruct *)conv_im2col;
145 }
146 #endif
147