/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "nnacl/fp16/winograd_transform_fp16.h"
18 
// fp16 Winograd convolution helpers: input transform, output transforms (NHWC and NC8HW8) and weight transform.
/*
 * Copies `count` (1..C8NUM) consecutive fp16 values from `src` to `dst`.
 * With NEON, a full group of C8NUM values is moved with one 128-bit load/store, and 4..7 values use one 64-bit
 * load/store for the first C4NUM values followed by scalar copies; everything else is copied element by element.
 */
static inline void WinogradInputCopyChannelsFp16(float16_t *dst, const float16_t *src, int count) {
  int k = 0;
#ifdef ENABLE_NEON
  if (count == C8NUM) {
    vst1q_f16(dst, vld1q_f16(src));
    return;
  }
  if (count >= C4NUM) {
    vst1_f16(dst, vld1_f16(src));
    k = C4NUM;
  }
#endif
  for (; k < count; k++) {
    dst[k] = src[k];
  }
}

/*
 * Winograd input transform (fp16) for `cal_num` consecutive output tiles, starting at tile `out_tile_index`.
 *
 * For every tile, an input_unit x input_unit window of the input is gathered C8NUM channels at a time into
 * `tmp_data`. The window is anchored at the tile's top-left output position (tile column/row times output_unit),
 * shifted back by the left/top padding. `tmp_data` is laid out input_unit x input_unit x C8NUM and is zero-filled
 * first, so window positions outside the input read as zero. The gathered block is then handed to `func`, which
 * writes the transformed values into `trans_input`.
 *
 * input_data:      read with in_channel values per pixel and input_w pixels per row.
 * trans_input:     destination; tile c / channel block ic starts at c * in_channel + ic * C8NUM, and `func` receives
 *                  in_channel * tile_num (16 on ARM64, 12 otherwise) as its destination step.
 * tmp_data:        scratch of at least input_unit * input_unit * C8NUM values.
 * out_w_block_num: number of tiles per output row; 0 makes the call a no-op.
 */
void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num,
                                int out_tile_index, int out_w_block_num, const ConvParameter *conv_param,
                                InputTransFp16Func func) {
#ifdef ENABLE_ARM64
  const int tile_num = 16;
#else
  const int tile_num = 12;
#endif
  const int input_unit = conv_param->input_unit_;
  const int output_unit = conv_param->output_unit_;
  const int in_channel = conv_param->input_channel_;
  const int ic8 = UP_DIV(in_channel, C8NUM);
  const int pad_h = conv_param->pad_u_;
  const int pad_w = conv_param->pad_l_;
  const int input_h = conv_param->input_h_;
  const int input_w = conv_param->input_w_;
  if (out_w_block_num == 0) {
    return;
  }
  // Step handed to `func` between consecutive transformed points of one tile/channel block.
  const size_t dst_step = in_channel * tile_num;
  const size_t tmp_bytes = input_unit * input_unit * C8NUM * sizeof(float16_t);

  for (int tile = 0; tile < cal_num; tile++, out_tile_index++) {
    // Top-left input coordinate of this tile's window (negative when it starts inside the padding).
    const int origin_x = (out_tile_index % out_w_block_num) * output_unit - pad_w;
    const int origin_y = (out_tile_index / out_w_block_num) * output_unit - pad_h;
    // Window-relative [begin, end) ranges that actually lie inside the input.
    const int x_begin = (origin_x > 0) ? 0 : -origin_x;
    const int y_begin = (origin_y > 0) ? 0 : -origin_y;
    const int x_end = (origin_x + input_unit < input_w) ? input_unit : (input_w - origin_x);
    const int y_end = (origin_y + input_unit < input_h) ? input_unit : (input_h - origin_y);
    const int cols = x_end - x_begin;
    const int src_tile_offset = in_channel * (origin_y * input_w + origin_x);
    const int dst_tile_offset = tile * in_channel;

    for (int block = 0; block < ic8; block++) {
      memset(tmp_data, 0, tmp_bytes);
      const int remaining = in_channel - block * C8NUM;
      const int channels = (remaining > C8NUM) ? C8NUM : remaining;
      const int src_block_offset = src_tile_offset + block * C8NUM;

      // Gather the in-bounds part of the window; everything else stays zero.
      for (int row = y_begin; row < y_end; row++) {
        const int src_row_offset = src_block_offset + (row * input_w + x_begin) * in_channel;
        const int dst_row_offset = (row * input_unit + x_begin) * C8NUM;
        for (int col = 0; col < cols; col++) {
          WinogradInputCopyChannelsFp16(tmp_data + dst_row_offset + col * C8NUM,
                                        input_data + src_row_offset + col * in_channel, channels);
        }
      }

      func(tmp_data, trans_input + dst_tile_offset + block * C8NUM, C8NUM, dst_step, channels);
    }
  }
}
127 
/*
 * Winograd output transform (fp16) for `cal_num` consecutive tiles, starting at tile `out_tile_index`.
 *
 * The result is written into `tmp_out_data`, addressed as output_channel values per pixel and output_w pixels per
 * row. `gemm_out` holds, per tile, oc8 channel blocks of input_unit * input_unit * C8NUM values.
 *
 * Each (tile, channel block) pair is passed to `func` together with:
 *   - the number of output columns and rows the tile really covers, clipped at the right/bottom border;
 *   - the number of valid channels, which is smaller than C8NUM only for the last block.
 *
 * output_unit_num is the number of tiles per output row; 0 makes the call a no-op.
 */
void WinogradOutputNHWCTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
                                     int cal_num, int out_tile_index, int output_unit_num,
                                     const ConvParameter *conv_param, OutputTransFp16Func func) {
  const int output_unit = conv_param->output_unit_;
  const int output_w = conv_param->output_w_;
  const int output_h = conv_param->output_h_;
  const int output_channel = conv_param->output_channel_;
  const int oc8 = UP_DIV(output_channel, C8NUM);
  const int input_unit = conv_param->input_unit_;
  NNACL_CHECK_ZERO_RETURN(output_unit_num);
  // Size of one channel block of one tile inside gemm_out.
  const int src_block_size = input_unit * input_unit * C8NUM;

  for (int tile = 0; tile < cal_num; tile++, out_tile_index++) {
    // Top-left output pixel of this tile, before clamping.
    const int x0 = (out_tile_index % output_unit_num) * output_unit;
    const int y0 = (out_tile_index / output_unit_num) * output_unit;
    // Output columns/rows this tile really produces (fewer on the right/bottom border).
    const int valid_w = (output_w - x0 > output_unit) ? output_unit : (output_w - x0);
    const int valid_h = (output_h - y0 > output_unit) ? output_unit : (output_h - y0);
    const int dst_x = (x0 > output_w) ? output_w : x0;
    const int dst_y = (y0 > output_h) ? output_h : y0;
    const int src_tile_offset = tile * oc8 * src_block_size;
    const int dst_tile_offset = output_channel * (dst_x + dst_y * output_w);

    for (int block = 0; block < oc8; block++) {
      const int c_begin = block * C8NUM;
      const int valid_c = (output_channel - c_begin > C8NUM) ? C8NUM : (output_channel - c_begin);
      func(gemm_out + src_tile_offset + block * src_block_size, tmp_out_data + dst_tile_offset + c_begin,
           bias_data + c_begin, C8NUM, output_w, output_channel, valid_w, valid_h, valid_c);
    }
  }
}
166 
/*
 * Winograd output transform (fp16) for `cal_num` consecutive tiles, starting at tile `out_tile_index`.
 *
 * Output layout (`tmp_out_data`, channel-blocked):
 *   - Channel block j of the plane (plane = output_h * output_w pixels) starts at j * plane * C8NUM.
 *   - `func` is told, through its out_c argument, to advance r_c values per pixel.
 *   - Full blocks (r_c == C8NUM) therefore hold C8NUM channels per pixel.
 *   - A trailing partial block (output_channel % C8NUM != 0) is stored compactly, with r_c channels per pixel.
 *
 * NOTE(review): this assumes consumers read that compact-tail layout. If a padded tail is intended instead, pass
 * C8NUM as out_c and keep the C8NUM per-pixel offset.
 *
 * `gemm_out` holds, per tile, oc8 blocks of input_unit * input_unit * C8NUM values. output_unit_num is the number of
 * tiles per output row; 0 makes the call a no-op.
 */
void WinogradOutputNC8HW8TransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data,
                                       int cal_num, int out_tile_index, int output_unit_num,
                                       const ConvParameter *conv_param, OutputTransFp16Func func) {
  int output_unit = conv_param->output_unit_;
  int output_w = conv_param->output_w_;
  int output_h = conv_param->output_h_;
  int plane = output_w * output_h;
  int output_channel = conv_param->output_channel_;
  int oc8 = UP_DIV(output_channel, C8NUM);
  int input_unit = conv_param->input_unit_;
  NNACL_CHECK_ZERO_RETURN(output_unit_num);
  for (int i = 0; i < cal_num; i++) {
    int dst_x_s = out_tile_index % output_unit_num;
    int dst_y_s = out_tile_index / output_unit_num;
    // Valid output columns/rows for this tile (clipped on the right/bottom border).
    int r_w = output_w - dst_x_s * output_unit;
    r_w = r_w > output_unit ? output_unit : r_w;
    int r_h = output_h - dst_y_s * output_unit;
    r_h = r_h > output_unit ? output_unit : r_h;
    // Convert tile column/row into the tile's top-left output pixel.
    int tmp_ix = dst_x_s * output_unit;
    dst_x_s = tmp_ix > output_w ? output_w : tmp_ix;
    int tmp_iy = dst_y_s * output_unit;
    dst_y_s = tmp_iy > output_h ? output_h : tmp_iy;

    int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit;
    // Pixel index of the tile's top-left corner inside one channel block.
    int dst_tile_offset = dst_x_s + dst_y_s * output_w;

    for (int j = 0; j < oc8; j++) {
      int r_c = output_channel - j * C8NUM;
      r_c = r_c > C8NUM ? C8NUM : r_c;
      int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM;
      // The per-pixel stride here must match the one `func` uses internally (out_c = r_c below).
      // Scaling by C8NUM misplaced the tail block (r_c < C8NUM) and could write past plane * output_channel.
      // For full blocks r_c == C8NUM, so their placement is unchanged.
      int dst_oc8_offset = plane * j * C8NUM + dst_tile_offset * r_c;
      const float16_t *src_ptr = gemm_out + src_oc8_offset;
      const float16_t *bias_ptr = bias_data + j * C8NUM;
      float16_t *dst_ptr = tmp_out_data + dst_oc8_offset;
      func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, r_c, r_w, r_h, r_c);
    }
    out_tile_index++;
  }
}
206 
/*
 * Transforms convolution weights into the Winograd domain (fp16).
 *
 * For every filter g, computes trans = G * g * G^T in two passes, each multiplying by G^T (see comments in the loop).
 * Each filter g holds kernel_unit * kernel_unit * filter_channel values of `weight_data`, which is in OHWI order.
 * The result is an input_unit x input_unit x filter_channel tile, stored into `winograd_data`.
 *
 * @param weight_data    source weights, filter_batch filters back to back.
 * @param winograd_data  destination buffer.
 * @param matrix_g       unused; kept for interface compatibility (only G^T is needed).
 * @param matrix_gt      G^T as fp32, input_unit * kernel_unit values; converted to fp16 once.
 * @param oc_block       number of filters interleaved per block when `pack` is true; must be > 0.
 * @param pack           true: scatter into [input_unit * input_unit][oc_block_num][filter_channel][oc_block];
 *                       false: store each filter's transformed tile contiguously.
 * @return NNACL_OK on success, NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR on invalid arguments or allocation failure.
 */
int WinogradWeightTransformFp16(const float16_t *weight_data, float16_t *winograd_data, const float *matrix_g,
                                const float *matrix_gt, int oc_block, int input_unit, int kernel_unit,
                                int filter_channel, int filter_batch, bool pack) {
  (void)matrix_g;
  // oc_block is used as a divisor (UP_DIV, i / oc_block, i % oc_block); reject it before dividing.
  if (weight_data == NULL || winograd_data == NULL || matrix_gt == NULL || oc_block <= 0) {
    return NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
  }

  int ret = NNACL_ERRCODE_OP_FP16_WINOGRAD_GENERATOR;
  // original weight format : ohwi
  const int oc_block_num = UP_DIV(filter_batch, oc_block);
  const int block_stride = filter_channel * oc_block;
  const int block_num_stride = block_stride * oc_block_num;
  const int input_oz_offset = kernel_unit * kernel_unit * filter_channel;
  const int trans_tile_size = filter_channel * input_unit * input_unit;
  // Byte counts are computed in size_t so large shapes do not overflow int before reaching malloc.
  const size_t gt_bytes = (size_t)input_unit * (size_t)kernel_unit * sizeof(float16_t);
  const size_t tmp_bytes = (size_t)filter_channel * gt_bytes;
  const size_t trans_bytes = (size_t)filter_channel * (size_t)input_unit * (size_t)input_unit * sizeof(float16_t);

  // All buffers start as NULL so the single cleanup path can free them unconditionally (free(NULL) is a no-op).
  float16_t *matrix_gt_data_fp16 = NULL;
  float16_t *tmp_data = NULL;
  float16_t *trans_out_data = NULL;
#ifndef ENABLE_ARM64
  float16_t *tmp_data1 = NULL;
  float16_t *trans_out_data1 = NULL;
#endif

  matrix_gt_data_fp16 = (float16_t *)(malloc(gt_bytes));
  if (matrix_gt_data_fp16 == NULL) {
    goto cleanup;
  }
  Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit * kernel_unit);

  // trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T
  // separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T
  tmp_data = (float16_t *)(malloc(tmp_bytes));
  trans_out_data = (float16_t *)(malloc(trans_bytes));
  if (tmp_data == NULL || trans_out_data == NULL) {
    goto cleanup;
  }
#ifndef ENABLE_ARM64
  // Without ARM64 the transposes are explicit packs, which need separate buffers.
  tmp_data1 = (float16_t *)(malloc(tmp_bytes));
  trans_out_data1 = (float16_t *)(malloc(trans_bytes));
  if (tmp_data1 == NULL || trans_out_data1 == NULL) {
    goto cleanup;
  }
#endif

  for (int i = 0; i < filter_batch; i++) {
    int out_c_block = i / oc_block;
    int out_c_res = i % oc_block;
    int output_oz_offset = out_c_block * block_stride + out_c_res;

#ifndef ENABLE_ARM64
    // tmp_data = g * GT
    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit,
                               kernel_unit, input_unit, filter_channel);
    // tmp_data1 = (tmp_data)T
    PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit, input_unit, filter_channel);
    // trans_out_data1 = tmp * GT
    MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit, kernel_unit, input_unit,
                               filter_channel);
    // trans_out_data = (trans_out_data1)T
    PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit, input_unit, filter_channel);
#else
    // tmp = (g * GT)T
    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit,
                               kernel_unit, input_unit, filter_channel);
    // trans = (tmp * GT)T
    MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit, kernel_unit, input_unit,
                               filter_channel);
#endif

    if (pack) {
      // Scatter: each of the input_unit^2 points advances by one full [oc_block_num][filter_channel][oc_block] slab.
      int in_offset = 0;
      for (int j = 0; j < input_unit; ++j) {
        for (int k = 0; k < input_unit; ++k) {
          for (int c = 0; c < filter_channel; ++c) {
            winograd_data[output_oz_offset + c * oc_block] = trans_out_data[in_offset + c];
          }
          in_offset += filter_channel;
          output_oz_offset += block_num_stride;
        }
      }
    } else {
      memcpy(winograd_data + i * trans_tile_size, trans_out_data, trans_bytes);
    }
  }
  ret = NNACL_OK;

cleanup:
#ifndef ENABLE_ARM64
  free(tmp_data1);
  free(trans_out_data1);
#endif
  free(tmp_data);
  free(trans_out_data);
  free(matrix_gt_data_fp16);
  return ret;
}
305