/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp16/conv_fp16.h"
#include <string.h>
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"
#include "nnacl/fp16/matmul_fp16.h"

// fp16 convolution common (im2col+gemm)
void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight,
              const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id,
              const ConvParameter *conv_param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
  NNACL_CHECK_ZERO_RETURN(tile_n);
  int output_hw = conv_param->output_h_ * conv_param->output_w_;
  // Split the output rows (in units of tile_n) evenly across threads.
  int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_);
  int start_block = block_per_thread * task_id;
  int start_hw = start_block * tile_n;
  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
  if (start_hw >= end_hw) {
    return;
  }
  int out_stride = conv_param->output_channel_ * tile_n;
  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  packed_input += task_id * deep * tile_n;
  col_major_input += task_id * deep * tile_n;
  size_t input_size = deep * tile_n * sizeof(float16_t);

  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_;
    for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) {
      int real_cal_row = MSMIN(output_hw - i, tile_n);
      memset(packed_input, 0, input_size);
      // im2col: gather the receptive fields of the current tile into a row-major buffer.
      Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
#ifdef ENABLE_ARM64
      RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#else
      RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#endif
      // gemm: [real_cal_row x deep] * [deep x output_channel] -> NHWC output tile.
      MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
                 real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc);
    }
  }
}

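/*
 * Illustrative sketch (not part of the original source): how a caller might drive ConvFp16
 * from a per-task thread-pool callback. The ConvFp16Args struct, RunConvFp16Task wrapper and
 * the callback signature are assumptions for illustration only.
 *
 *   int RunConvFp16Task(void *cdata, int task_id) {
 *     ConvFp16Args *args = (ConvFp16Args *)cdata;  // hypothetical argument bundle
 *     ConvFp16(args->input, args->packed_input, args->packed_weight, args->bias,
 *              args->col_major_input, args->output, task_id, args->conv_param);
 *     return 0;
 *   }
 *
 * Note that packed_input and col_major_input must each hold at least
 * thread_num_ * tile_n * deep fp16 values, since every task offsets into them by
 * task_id * deep * tile_n.
 */
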
void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight,
                       const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id,
                       const ConvParameter *conv_param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(conv_param->op_parameter_.thread_num_);
  NNACL_CHECK_ZERO_RETURN(tile_n);
  int output_hw = conv_param->output_h_ * conv_param->output_w_;
  int input_block = UP_DIV(output_hw, tile_n);
  int block_per_thread = UP_DIV(input_block, conv_param->thread_num_);
  int start_block = block_per_thread * task_id;
  int end_block = MSMIN(start_block + block_per_thread, input_block);
  if (start_block >= end_block) {
    return;
  }
  int weight_block = UP_DIV(conv_param->output_channel_, C8NUM);
  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  packed_input += deep * tile_n * task_id;
  col_major_input += deep * tile_n * task_id;
  size_t input_size = deep * tile_n * sizeof(float16_t);

  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    for (int i = start_block; i < end_block; i++) {
      int real_in_row = (i != input_block - 1) ? tile_n : output_hw - i * tile_n;
      memset(packed_input, 0, input_size);
      Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_in_row, i * tile_n);
#ifdef ENABLE_ARM64
      RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#else
      RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#endif
      const float16_t *cur_weight = packed_weight;
      const float16_t *cur_bias = bias_data;
      // Emit the output in NC8HW8 blocks: one gemm per group of C8NUM output channels.
      for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * deep, cur_bias += C8NUM) {
        int real_weight_row = (j != weight_block - 1) ? C8NUM : conv_param->output_channel_ - j * C8NUM;
        int out_offset = j * output_hw * C8NUM + i * tile_n * real_weight_row;
        MatMulFp16(col_major_input, cur_weight, output_data + out_offset, cur_bias, conv_param->act_type_, deep,
                   real_in_row, real_weight_row, real_weight_row, OutType_Nhwc);
      }
    }
  }
}

void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight,
                                            const float16_t *bias, float16_t *output, int task_id,
                                            const MatMulParameter *param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(tile_n);
  NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_);
  int input_block = UP_DIV(param->row_, tile_n);
  int weight_block = UP_DIV(param->col_, C8NUM);

  // Parallelize over input row blocks; every thread multiplies its rows by the full weight.
  int block_per_thread = UP_DIV(input_block, param->op_parameter_.thread_num_);
  int input_start_block = block_per_thread * task_id;
  int input_end_block = MSMIN(input_start_block + block_per_thread, input_block);
  if (input_start_block >= input_end_block) {
    return;
  }
  input += input_start_block * tile_n * param->deep_;
  pack_input += input_start_block * tile_n * param->deep_;

  int cur_row_cnt = MSMIN(block_per_thread * tile_n, param->row_ - input_start_block * tile_n);
#ifdef ENABLE_ARM64
  RowMajor2Col16MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_);
#else
  RowMajor2Col12MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_);
#endif
  for (int i = input_start_block; i < input_end_block; i++) {
    int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n;
    const float16_t *cur_weight = weight;
    const float16_t *cur_bias = bias;
    for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) {
      int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM;
      int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row;
      MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row,
                 real_weight_row, real_weight_row, OutType_Nhwc);
    }
    pack_input += real_in_row * param->deep_;
  }
}

void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight,
                                             const float16_t *bias, float16_t *output, int task_id,
                                             const MatMulParameter *param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(tile_n);
  NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_);
  int input_block = UP_DIV(param->row_, tile_n);
  int weight_block = UP_DIV(param->col_, C8NUM);

  // Parallelize over 8-channel weight blocks; every thread walks all input row blocks.
  int block_per_thread = UP_DIV(weight_block, param->op_parameter_.thread_num_);
  int weight_start_block = block_per_thread * task_id;
  int weight_end_block = MSMIN(weight_start_block + block_per_thread, weight_block);
  if (weight_start_block >= weight_end_block) {
    return;
  }
  for (int i = 0; i < input_block; i++) {
    int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n;
    const float16_t *cur_weight = weight + weight_start_block * C8NUM * param->deep_;
    const float16_t *cur_bias = bias + weight_start_block * C8NUM;
    for (int j = weight_start_block; j < weight_end_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) {
      int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM;
      int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row;
      MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row,
                 real_weight_row, real_weight_row, OutType_Nhwc);
    }
    pack_input += real_in_row * param->deep_;
  }
}

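/*
 * Layout note for the two 1x1 variants above (illustrative; the numbers are chosen for this
 * comment only): with param->row_ = 64, param->col_ = 20 and tile_n = 16, the output holds
 * UP_DIV(20, C8NUM) = 3 channel blocks. Block j starts at j * row_ * C8NUM = j * 512; inside
 * block j, row tile i starts at i * tile_n * real_weight_row, where real_weight_row is 8 for
 * full blocks and 4 for the final partial block (20 - 2 * 8). MatMulFp16 then writes
 * real_in_row rows of real_weight_row channels contiguously, which yields the NC8HW8 ordering.
 */
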
// fp16 convolution winograd
void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
                      float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
                      const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func) {
#ifdef ENABLE_ARM64
  const int tile_num = 16;
#else
  const int tile_num = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_);
  NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
  int in_channel = conv_param->input_channel_;
  int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
  int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
  int output_count = out_w_block * out_h_block;
  int per_thread_num = UP_DIV(output_count, conv_param->thread_num_);
  int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num;
  NNACL_CHECK_ZERO_RETURN(real_tile);
  int output_tile_count = UP_DIV(output_count, real_tile);
  int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
  int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;

  // Per-thread scratch buffers: transformed input tiles, gemm output, transform scratch, packed columns.
  float16_t *trans_input = buffer_list[0];
  float16_t *gemm_out = buffer_list[1];
  float16_t *tmp_data = buffer_list[2];
  float16_t *col_buffer = buffer_list[3];
  int trans_input_offset = tile_num * input_unit_square * in_channel;
  int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
  int tmp_data_offset = input_unit_square * C8NUM;
  int col_buffer_offset = tile_num * in_channel;
  // step 1 : filter transform (pre-processed offline)
  // step 2 : input transform (online)
  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
    int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
      int out_tile_index = thread_id * real_tile;
      int cal_num = output_count - thread_id * real_tile;
      cal_num = cal_num > real_tile ? real_tile : cal_num;
      if (cal_num <= 0) {
        return;
      }
      WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
                                 tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
                                 in_func);
      // step 3 : gemm, one matmul per point of the transformed input unit
      float16_t *src_ptr = trans_input + task_id * trans_input_offset;
      float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
      float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
      for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM64
        RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#else
        RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#endif
        MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
                   cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
      }

      // step 4 : output transform
      if (conv_param->out_format_ != NNACL_NC4HW4) {  // NHWC output
        WinogradOutputNHWCTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data,
                                        cal_num, out_tile_index, out_w_block, conv_param, out_func);
      } else {  // NC8HW8 output
        WinogradOutputNC8HW8TransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset,
                                          bias_data, cal_num, out_tile_index, out_w_block, conv_param, out_func);
      }
    }
  }
}