1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/fp16/winograd_transform_fp16.h"
18
19 // fp16 common winograd
/**
 * fp16 common winograd: gather and transform the input patches of consecutive tiles.
 *
 * For each of the cal_num tiles starting at out_tile_index, the input_unit x input_unit
 * patch that feeds the tile is copied into tmp_data one 8-channel block at a time.
 * Positions of the patch that fall outside the input plane, and channel lanes beyond the
 * remaining channel count, stay zero because tmp_data is cleared before every gather.
 * The gathered block is then handed to func.
 *
 * input_data is indexed as (row * input_w + col) * input_channel + channel.
 *
 * @param input_data      source activations.
 * @param trans_input     destination handed to func; tile t, channel block b starts at
 *                        element t * input_channel + b * C8NUM.
 * @param tmp_data        scratch of at least input_unit * input_unit * C8NUM elements.
 * @param cal_num         number of tiles processed by this call.
 * @param out_tile_index  linear index of the first tile.
 * @param out_w_block_num number of tiles per output row; the call is a no-op when it is 0.
 * @param conv_param      supplies input_unit_, output_unit_, input_channel_, pad_u_, pad_l_,
 *                        input_h_ and input_w_.
 * @param func            per-block transform, called as
 *                        func(tmp_data, dst, C8NUM, input_channel * tile_num, real_c).
 */
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
                                int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
                                InputTransFp16Func func) {
#ifdef ENABLE_ARM64
  const int tile_num = 16;
#else
  const int tile_num = 12;
#endif
  const int input_unit = conv_param->input_unit_;
  const int output_unit = conv_param->output_unit_;
  const int in_channel = conv_param->input_channel_;
  const int channel_blocks = UP_DIV(in_channel, C8NUM);
  const int pad_top = conv_param->pad_u_;
  const int pad_left = conv_param->pad_l_;
  const int input_h = conv_param->input_h_;
  const int input_w = conv_param->input_w_;
  if (out_w_block_num == 0) {
    return;  // tile coordinates below divide by out_w_block_num
  }

  // Both values are identical for every tile and channel block.
  const size_t dst_step = in_channel * tile_num;
  const size_t tmp_bytes = input_unit * input_unit * C8NUM * sizeof(float16_t);

  for (int t = 0; t < cal_num; ++t) {  // actual tiled number
    const int tile = out_tile_index + t;
    // Top-left corner of the patch in input coordinates; negative inside the padding.
    const int x0 = (tile % out_w_block_num) * output_unit - pad_left;
    const int y0 = (tile / out_w_block_num) * output_unit - pad_top;
    // Half-open range [begin, end) of patch columns/rows that lie inside the input plane.
    const int col_begin = x0 > 0 ? 0 : -x0;
    const int row_begin = y0 > 0 ? 0 : -y0;
    const int col_end = (x0 + input_unit < input_w) ? input_unit : (input_w - x0);
    const int row_end = (y0 + input_unit < input_h) ? input_unit : (input_h - y0);
    const int col_count = col_end - col_begin;

    // May be negative; it is only turned into a pointer after the in-plane offset is added.
    const int src_tile_offset = in_channel * (y0 * input_w + x0);
    const int dst_tile_offset = t * in_channel;

    for (int blk = 0; blk < channel_blocks; ++blk) {
      const int channel_base = blk * C8NUM;
      const int channels_left = in_channel - channel_base;
      const int real_c = channels_left > C8NUM ? C8NUM : channels_left;

      // clear tmp buffer so padding positions and unused lanes read as zero
      memset(tmp_data, 0, tmp_bytes);

      // get real input block with padding
      for (int row = row_begin; row < row_end; ++row) {
        const int src_row_offset = src_tile_offset + channel_base + (row * input_w + col_begin) * in_channel;
        const int dst_row_offset = (row * input_unit + col_begin) * C8NUM;
        for (int col = 0; col < col_count; ++col) {
          const float16_t *src_px = input_data + (src_row_offset + col * in_channel);
          float16_t *dst_px = tmp_data + (dst_row_offset + col * C8NUM);
#ifdef ENABLE_NEON
          if (real_c == C8NUM) {
            // full block: one 8-lane load/store
            vst1q_f16(dst_px, vld1q_f16(src_px));
          } else if (real_c >= C4NUM) {
            // 4..7 channels: one 4-lane load/store, then the scalar remainder
            vst1_f16(dst_px, vld1_f16(src_px));
            for (int k = C4NUM; k < real_c; ++k) {
              dst_px[k] = src_px[k];
            }
          } else {
            for (int k = 0; k < real_c; ++k) {
              dst_px[k] = src_px[k];
            }
          }
#else
          for (int k = 0; k < real_c; ++k) {
            dst_px[k] = src_px[k];
          }
#endif
        }
      }

      // input transform
      float16_t *trans_input_ptr = trans_input + (dst_tile_offset + channel_base);
      func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c);
    }
  }  // cal_tile_num loop
}
127
/**
 * Runs the output transform for cal_num consecutive tiles, writing into a buffer that is
 * indexed as (row * output_w + col) * output_channel + channel.
 *
 * For every tile and every 8-channel block, func receives the block's slice of gemm_out,
 * the destination address of the tile's top-left pixel for that block, the matching bias
 * slice, and the number of columns, rows and channels that are still inside the output.
 *
 * @param gemm_out        source; tile i, block j starts at element
 *                        (i * oc8 + j) * input_unit * input_unit * C8NUM.
 * @param tmp_out_data    destination buffer.
 * @param bias_data       bias values, read at offset j * C8NUM for block j.
 * @param cal_num         number of tiles processed by this call.
 * @param out_tile_index  linear index of the first tile.
 * @param output_unit_num number of tiles per output row; checked with NNACL_CHECK_ZERO_RETURN.
 * @param conv_param      supplies output_unit_, output_w_, output_h_, output_channel_ and
 *                        input_unit_.
 * @param func            per-block transform, called as
 *                        func(src, dst, bias, C8NUM, output_w, output_channel, r_w, r_h, r_c).
 */
void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
                                     int cal_num, int out_tile_index, int output_unit_num,
                                     const ConvParameter *conv_param, OutputTransFp16Func func) {
  const int output_unit = conv_param->output_unit_;
  const int output_w = conv_param->output_w_;
  const int output_h = conv_param->output_h_;
  const int output_channel = conv_param->output_channel_;
  const int oc8 = UP_DIV(output_channel, C8NUM);
  const int input_unit = conv_param->input_unit_;
  NNACL_CHECK_ZERO_RETURN(output_unit_num);

  // Elements occupied by one 8-channel block of one tile in gemm_out.
  const int block_elems = input_unit * input_unit * C8NUM;

  for (int i = 0; i < cal_num; ++i) {
    const int tile = out_tile_index + i;
    const int col_start = (tile % output_unit_num) * output_unit;
    const int row_start = (tile / output_unit_num) * output_unit;

    // Columns/rows of this tile that are still inside the output, capped at output_unit.
    int r_w = output_w - col_start;
    if (r_w > output_unit) {
      r_w = output_unit;
    }
    int r_h = output_h - row_start;
    if (r_h > output_unit) {
      r_h = output_unit;
    }

    // Clamp the tile origin to the output extent.
    const int dst_x = col_start > output_w ? output_w : col_start;
    const int dst_y = row_start > output_h ? output_h : row_start;

    const int src_tile_offset = i * oc8 * block_elems;
    const int dst_tile_offset = output_channel * (dst_x + dst_y * output_w);

    for (int j = 0; j < oc8; ++j) {
      const int channel_base = j * C8NUM;
      const int channels_left = output_channel - channel_base;
      const int r_c = channels_left > C8NUM ? C8NUM : channels_left;

      const float16_t *src_ptr = gemm_out + (src_tile_offset + j * block_elems);
      const float16_t *bias_ptr = bias_data + channel_base;
      float16_t *dst_ptr = tmp_out_data + (dst_tile_offset + channel_base);
      func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c);
    }
  }
}
166
/**
 * Runs the output transform for cal_num consecutive tiles, writing into a buffer that is
 * indexed as ((block * output_h + row) * output_w + col) * C8NUM + lane, i.e. one full
 * output plane per 8-channel block.
 *
 * For every tile and every 8-channel block, func receives the block's slice of gemm_out,
 * the destination address of the tile's top-left pixel inside that block's plane, the
 * matching bias slice, and the number of columns, rows and channels still inside the output.
 *
 * @param gemm_out        source; tile i, block j starts at element
 *                        (i * oc8 + j) * input_unit * input_unit * C8NUM.
 * @param tmp_out_data    destination buffer.
 * @param bias_data       bias values, read at offset j * C8NUM for block j.
 * @param cal_num         number of tiles processed by this call.
 * @param out_tile_index  linear index of the first tile.
 * @param output_unit_num number of tiles per output row; checked with NNACL_CHECK_ZERO_RETURN.
 * @param conv_param      supplies output_unit_, output_w_, output_h_, output_channel_ and
 *                        input_unit_.
 * @param func            per-block transform, called as
 *                        func(src, dst, bias, C8NUM, output_w, r_c, r_w, r_h, r_c).
 *                        Note that the block's channel count r_c is passed in the sixth
 *                        position as well as the last.
 */
void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
                                       int cal_num, int out_tile_index, int output_unit_num,
                                       const ConvParameter *conv_param, OutputTransFp16Func func) {
  const int output_unit = conv_param->output_unit_;
  const int output_w = conv_param->output_w_;
  const int output_h = conv_param->output_h_;
  const int plane = output_w * output_h;
  const int output_channel = conv_param->output_channel_;
  const int oc8 = UP_DIV(output_channel, C8NUM);
  const int input_unit = conv_param->input_unit_;
  NNACL_CHECK_ZERO_RETURN(output_unit_num);

  // Elements occupied by one 8-channel block of one tile in gemm_out.
  const int block_elems = input_unit * input_unit * C8NUM;

  for (int i = 0; i < cal_num; ++i) {
    const int tile = out_tile_index + i;
    const int col_start = (tile % output_unit_num) * output_unit;
    const int row_start = (tile / output_unit_num) * output_unit;

    // Columns/rows of this tile that are still inside the output, capped at output_unit.
    int r_w = output_w - col_start;
    if (r_w > output_unit) {
      r_w = output_unit;
    }
    int r_h = output_h - row_start;
    if (r_h > output_unit) {
      r_h = output_unit;
    }

    // Clamp the tile origin to the output extent.
    const int dst_x = col_start > output_w ? output_w : col_start;
    const int dst_y = row_start > output_h ? output_h : row_start;

    const int src_tile_offset = i * oc8 * block_elems;
    const int dst_pixel_index = dst_x + dst_y * output_w;

    for (int j = 0; j < oc8; ++j) {
      const int channel_base = j * C8NUM;
      const int channels_left = output_channel - channel_base;
      const int r_c = channels_left > C8NUM ? C8NUM : channels_left;

      const float16_t *src_ptr = gemm_out + (src_tile_offset + j * block_elems);
      const float16_t *bias_ptr = bias_data + channel_base;
      // Each channel block owns a whole plane of C8NUM-wide pixels.
      float16_t *dst_ptr = tmp_out_data + (dst_pixel_index + plane * j) * C8NUM;
      func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c);
    }
  }
}
206
/**
 * Transforms convolution weights into the Winograd domain, one output filter at a time.
 *
 * For each of the filter_batch filters (kernel_unit x kernel_unit x filter_channel elements,
 * original weight format: ohwi) the function computes an input_unit x input_unit x
 * filter_channel result via two calls to MatrixMultiplyWinogradFp16 (plus two
 * PackHWCToWHCFp16 transposes when ENABLE_ARM64 is not defined) and stores it in
 * winograd_data.
 *
 * @param weight_data    source weights; filter i starts at element
 *                       i * kernel_unit * kernel_unit * filter_channel.
 * @param winograd_data  destination. With pack == false, filter i is copied contiguously
 *                       to offset i * filter_channel * input_unit * input_unit. With
 *                       pack == true, element (position p, channel c) of filter i goes to
 *                       p * block_num_stride + (i / oc_block) * block_stride + c * oc_block
 *                       + i % oc_block, where block_stride = filter_channel * oc_block and
 *                       block_num_stride = block_stride * UP_DIV(filter_batch, oc_block).
 * @param matrix_g       unused by this implementation.
 * @param matrix_gt      fp32 matrix of input_unit * kernel_unit elements; converted to fp16
 *                       and used as the right-hand operand of both multiplications.
 * @param oc_block       output-channel block size for the packed layout; must be > 0.
 * @param input_unit     Winograd input tile size; must be > 0.
 * @param kernel_unit    kernel size; must be > 0.
 * @param filter_channel channels per filter; must be > 0.
 * @param filter_batch   number of filters.
 * @param pack           selects the packed or the contiguous destination layout.
 * @return NNACL_OK on success; NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR when a dimension is
 *         not positive or a scratch allocation fails.
 */
int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g,
                                const float *matrix_gt, int oc_block, int input_unit, int kernel_unit,
                                int filter_channel, int filter_batch, bool pack) {
  (void)matrix_g;  // kept for interface compatibility; only matrix_gt is needed below

  // oc_block is used as a divisor and the other three size every scratch buffer, so reject
  // non-positive values up front instead of dividing by zero or requesting a bogus size.
  if (oc_block <= 0 || input_unit <= 0 || kernel_unit <= 0 || filter_channel <= 0) {
    return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
  }

  // original weight format : ohwi
  const int oc_block_num = UP_DIV(filter_batch, oc_block);
  const int block_stride = filter_channel * oc_block;
  const int block_num_stride = block_stride * oc_block_num;

  // Element counts are computed in size_t so the products cannot overflow a signed int
  // before they are scaled by sizeof(float16_t).
  const size_t gt_elems = (size_t)input_unit * (size_t)kernel_unit;
  const size_t tmp_elems = (size_t)filter_channel * gt_elems;
  const size_t trans_elems = (size_t)filter_channel * (size_t)input_unit * (size_t)input_unit;

  int ret = NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
  float16_t *matrix_gt_data_fp16 = NULL;
  float16_t *tmp_data = NULL;
  float16_t *trans_out_data = NULL;
#ifndef ENABLE_ARM64
  float16_t *tmp_data1 = NULL;
  float16_t *trans_out_data1 = NULL;
#endif

  matrix_gt_data_fp16 = (float16_t *)(malloc(gt_elems * sizeof(float16_t)));
  if (matrix_gt_data_fp16 == NULL) {
    goto cleanup;
  }
  Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit * kernel_unit);

  // trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T
  // separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T
  tmp_data = (float16_t *)(malloc(tmp_elems * sizeof(float16_t)));
  if (tmp_data == NULL) {
    goto cleanup;
  }
  trans_out_data = (float16_t *)(malloc(trans_elems * sizeof(float16_t)));
  if (trans_out_data == NULL) {
    goto cleanup;
  }

#ifndef ENABLE_ARM64
  tmp_data1 = (float16_t *)(malloc(tmp_elems * sizeof(float16_t)));
  if (tmp_data1 == NULL) {
    goto cleanup;
  }
  trans_out_data1 = (float16_t *)(malloc(trans_elems * sizeof(float16_t)));
  if (trans_out_data1 == NULL) {
    goto cleanup;
  }
#endif

  const int input_oz_offset = kernel_unit * kernel_unit * filter_channel;
  for (int i = 0; i < filter_batch; i++) {
    const int out_c_block = i / oc_block;
    const int out_c_res = i % oc_block;
    int output_oz_offset = out_c_block * block_stride + out_c_res;

#ifndef ENABLE_ARM64
    // tmp_data = g * GT
    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit,
                               kernel_unit, input_unit, filter_channel);
    // tmp_data1 = (tmp_data)T
    PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit, input_unit, filter_channel);
    // trans_out_data1 = tmp * GT
    MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit, kernel_unit, input_unit,
                               filter_channel);
    // trans_out_data = (trans_out_data1)T
    PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit, input_unit, filter_channel);
#else
    // tmp = (g * GT)T
    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit,
                               kernel_unit, input_unit, filter_channel);
    // trans = (tmp * GT)T
    MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit,
                               filter_channel);
#endif

    if (pack) {
      // Scatter: each of the input_unit * input_unit positions advances by one full
      // block_num_stride; channels within a position are oc_block apart.
      int in_offset = 0;
      for (int j = 0; j < input_unit; ++j) {
        for (int k = 0; k < input_unit; ++k) {
          for (int c = 0; c < filter_channel; ++c) {
            *(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
          }
          in_offset += filter_channel;
          output_oz_offset += block_num_stride;
        }
      }
    } else {
      memcpy(winograd_data + (size_t)i * trans_elems, trans_out_data, trans_elems * sizeof(float16_t));
    }
  }
  ret = NNACL_OK;

cleanup:
  // free(NULL) is a no-op, so one exit path covers every partial-allocation state.
#ifndef ENABLE_ARM64
  free(tmp_data1);
  free(trans_out_data1);
#endif
  free(tmp_data);
  free(trans_out_data);
  free(matrix_gt_data_fp16);
  return ret;
}
305