/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp16/conv_fp16.h"
#include <string.h>
#include "nnacl/fp16/pack_fp16.h"
#include "nnacl/fp16/winograd_transform_fp16.h"
#include "nnacl/fp16/matmul_fp16.h"

// fp16 convolution common (im2col+gemm)
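// Output spatial positions (output_h_ * output_w_) are split into tiles of tile_n rows and each
// task_id owns a contiguous block of tiles. packed_input and col_major_input are shared scratch
// buffers; each task offsets into them by task_id * deep * tile_n to get its own slice.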
void ConvFp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight,
              const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id,
              const ConvParameter *conv_param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
  NNACL_CHECK_ZERO_RETURN(tile_n);
  int output_hw = conv_param->output_h_ * conv_param->output_w_;
  int block_per_thread = UP_DIV(UP_DIV(output_hw, tile_n), conv_param->thread_num_);
  int start_block = block_per_thread * task_id;
  int start_hw = start_block * tile_n;
  int end_hw = MSMIN(output_hw, (start_block + block_per_thread) * tile_n);
  if (start_hw >= end_hw) {
    return;
  }
  int out_stride = conv_param->output_channel_ * tile_n;
  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  packed_input += task_id * deep * tile_n;
  col_major_input += task_id * deep * tile_n;
  size_t input_size = deep * tile_n * sizeof(float16_t);

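  // For each batch: im2col-pack one tile of output positions, repack it into the GEMM's
  // column-major tile layout, then multiply by the packed weights into NHWC output.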
  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    int out_offset = b * conv_param->output_channel_ * output_hw + start_hw * conv_param->output_channel_;
    for (int i = start_hw; i < end_hw; i += tile_n, out_offset += out_stride) {
      int real_cal_row = MSMIN(output_hw - i, tile_n);
      memset(packed_input, 0, input_size);
      Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_cal_row, i);
#ifdef ENABLE_ARM64
      RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#else
      RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#endif
      MatMulFp16(col_major_input, packed_weight, output_data + out_offset, bias_data, conv_param->act_type_, deep,
                 real_cal_row, conv_param->output_channel_, conv_param->output_channel_, OutType_Nhwc);
    }
  }
}

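// Same im2col+gemm scheme as ConvFp16, but the output is produced in NC8HW8 layout: the GEMM is
// issued once per 8-channel block of output channels, and block j writes into its own slab
// starting at j * output_hw * C8NUM.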
void ConvOutNc8hw8Fp16(const float16_t *input_data, float16_t *packed_input, const float16_t *packed_weight,
                       const float16_t *bias_data, float16_t *col_major_input, float16_t *output_data, int task_id,
                       const ConvParameter *conv_param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(conv_param->op_parameter_.thread_num_);
  NNACL_CHECK_ZERO_RETURN(tile_n);
  int output_hw = conv_param->output_h_ * conv_param->output_w_;
  int input_block = UP_DIV(output_hw, tile_n);
  int block_per_thread = UP_DIV(input_block, conv_param->thread_num_);
  int start_block = block_per_thread * task_id;
  int end_block = MSMIN(start_block + block_per_thread, input_block);
  if (start_block >= end_block) {
    return;
  }
  int weight_block = UP_DIV(conv_param->output_channel_, C8NUM);
  int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
  packed_input += deep * tile_n * task_id;
  col_major_input += deep * tile_n * task_id;
  size_t input_size = deep * tile_n * sizeof(float16_t);

  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
    for (int i = start_block; i < end_block; i++) {
      int real_in_row = (i != input_block - 1) ? tile_n : output_hw - i * tile_n;
      memset(packed_input, 0, input_size);
      Im2ColPackUnitFp16(input_data + in_offset, conv_param, packed_input, real_in_row, i * tile_n);
#ifdef ENABLE_ARM64
      RowMajor2Col16MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#else
      RowMajor2Col12MajorFp16Opt(packed_input, col_major_input, tile_n, deep);
#endif
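      // One GEMM per 8-channel block of output channels; the last block may cover fewer than C8NUM channels.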
      const float16_t *cur_weight = packed_weight;
      const float16_t *cur_bias = bias_data;
      for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * deep, cur_bias += C8NUM) {
        int real_weight_row = (j != weight_block - 1) ? C8NUM : conv_param->output_channel_ - j * C8NUM;
        int out_offset = j * output_hw * C8NUM + i * tile_n * real_weight_row;
        MatMulFp16(col_major_input, cur_weight, output_data + out_offset, cur_bias, conv_param->act_type_, deep,
                   real_in_row, real_weight_row, real_weight_row, OutType_Nhwc);
      }
    }
  }
}

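// 1x1 convolution producing NC8HW8 output, parallelized over input rows: each task packs its own
// contiguous range of row tiles once, then multiplies it against every 8-channel weight block.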
void Conv1x1OutNc8hw8MultiThreadByInputFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight,
                                            const float16_t *bias, float16_t *output, int task_id,
                                            const MatMulParameter *param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(tile_n);
  NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_);
  int input_block = UP_DIV(param->row_, tile_n);
  int weight_block = UP_DIV(param->col_, C8NUM);

  int block_per_thread = UP_DIV(input_block, param->op_parameter_.thread_num_);
  int input_start_block = block_per_thread * task_id;
  int input_end_block = MSMIN(input_start_block + block_per_thread, input_block);
  if (input_start_block >= input_end_block) {
    return;
  }
  input += input_start_block * tile_n * param->deep_;
  pack_input += input_start_block * tile_n * param->deep_;

  int cur_row_cnt = MSMIN(block_per_thread * tile_n, param->row_ - input_start_block * tile_n);
#ifdef ENABLE_ARM64
  RowMajor2Col16MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_);
#else
  RowMajor2Col12MajorFp16Opt(input, pack_input, cur_row_cnt, param->deep_);
#endif
  for (int i = input_start_block; i < input_end_block; i++) {
    int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n;
    const float16_t *cur_weight = weight;
    const float16_t *cur_bias = bias;
    for (int j = 0; j < weight_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) {
      int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM;
      int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row;
      MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row,
                 real_weight_row, real_weight_row, OutType_Nhwc);
    }
    pack_input += real_in_row * param->deep_;
  }
}

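// 1x1 convolution producing NC8HW8 output, parallelized over output channels: each task owns a
// contiguous range of 8-channel weight blocks and walks all input row tiles; pack_input is
// expected to have been packed into the GEMM layout by the caller.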
void Conv1x1OutNc8hw8MultiThreadByWeightFp16(const float16_t *input, float16_t *pack_input, const float16_t *weight,
                                             const float16_t *bias, float16_t *output, int task_id,
                                             const MatMulParameter *param) {
#ifdef ENABLE_ARM64
  const int tile_n = 16;
#else
  const int tile_n = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(tile_n);
  NNACL_CHECK_ZERO_RETURN(param->op_parameter_.thread_num_);
  int input_block = UP_DIV(param->row_, tile_n);
  int weight_block = UP_DIV(param->col_, C8NUM);

  int block_per_thread = UP_DIV(weight_block, param->op_parameter_.thread_num_);
  int weight_start_block = block_per_thread * task_id;
  int weight_end_block = MSMIN(weight_start_block + block_per_thread, weight_block);
  if (weight_start_block >= weight_end_block) {
    return;
  }
  for (int i = 0; i < input_block; i++) {
    int real_in_row = (i != input_block - 1) ? tile_n : param->row_ - i * tile_n;
    const float16_t *cur_weight = weight + weight_start_block * C8NUM * param->deep_;
    const float16_t *cur_bias = bias + weight_start_block * C8NUM;
    for (int j = weight_start_block; j < weight_end_block; j++, cur_weight += C8NUM * param->deep_, cur_bias += C8NUM) {
      int real_weight_row = (j != weight_block - 1) ? C8NUM : param->col_ - j * C8NUM;
      int out_offset = j * param->row_ * C8NUM + i * tile_n * real_weight_row;
      MatMulFp16(pack_input, cur_weight, output + out_offset, cur_bias, param->act_type_, param->deep_, real_in_row,
                 real_weight_row, real_weight_row, OutType_Nhwc);
    }
    pack_input += real_in_row * param->deep_;
  }
}

// fp16 convolution winograd
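// Per tile of output units: Winograd input transform, one GEMM for each of the
// input_unit_ * input_unit_ positions in the transform domain, then the output transform back to
// NHWC or NC8HW8. buffer_list supplies shared scratch (trans_input, gemm_out, tmp_data,
// col_buffer) that each task indexes by task_id.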
void ConvWinogardFp16(const float16_t *input_data, const float16_t *trans_weight, const float16_t *bias_data,
                      float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id,
                      const ConvParameter *conv_param, InputTransFp16Func in_func, OutputTransFp16Func out_func) {
#ifdef ENABLE_ARM64
  const int tile_num = 16;
#else
  const int tile_num = 12;
#endif
  NNACL_CHECK_ZERO_RETURN(conv_param->output_unit_);
  NNACL_CHECK_ZERO_RETURN(conv_param->thread_num_);
  int in_channel = conv_param->input_channel_;
  int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
  int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
  int output_count = out_w_block * out_h_block;
  int per_thread_num = UP_DIV(output_count, conv_param->thread_num_);
  int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num;
  NNACL_CHECK_ZERO_RETURN(real_tile);
  int output_tile_count = UP_DIV(output_count, real_tile);
  int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
  int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;

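  // Views into the shared scratch buffers; each task offsets by task_id times the corresponding *_offset below.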
  float16_t *trans_input = buffer_list[0];
  float16_t *gemm_out = buffer_list[1];
  float16_t *tmp_data = buffer_list[2];
  float16_t *col_buffer = buffer_list[3];
  int trans_input_offset = tile_num * input_unit_square * in_channel;
  int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
  int tmp_data_offset = input_unit_square * C8NUM;
  int col_buffer_offset = tile_num * in_channel;
  // step 1 : filter transform (pre-processed offline)
  // step 2 : input transform (online)
  for (int b = 0; b < conv_param->input_batch_; b++) {
    int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
    int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
    for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
      int out_tile_index = thread_id * real_tile;
      int cal_num = output_count - thread_id * real_tile;
      cal_num = cal_num > real_tile ? real_tile : cal_num;
      if (cal_num <= 0) {
        return;
      }
      WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
                                 tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
                                 in_func);
      // step 3 : gemm
      float16_t *src_ptr = trans_input + task_id * trans_input_offset;
      float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
      float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
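      // One GEMM per position of the transformed tile; trans_weight holds, for each position,
      // an in_channel x (oc8 * C8NUM) block, and results are written to gemm_out in TileC8 layout.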
      for (int i = 0; i < input_unit_square; ++i) {
#ifdef ENABLE_ARM64
        RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#else
        RowMajor2Col12MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel);
#endif
        MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
                   cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8);
      }

      // step 4 : output transform
      if (conv_param->out_format_ != NNACL_NC4HW4) {  // not NC4HW4: NHWC output
        WinogradOutputNHWCTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data,
                                        cal_num, out_tile_index, out_w_block, conv_param, out_func);
      } else {
        WinogradOutputNC8HW8TransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset,
                                          bias_data, cal_num, out_tile_index, out_w_block, conv_param, out_func);
      }
    }
  }
}