• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_NNACL_WINOGRAD_UTILS_H_
18 #define MINDSPORE_NNACL_WINOGRAD_UTILS_H_
19 
20 #ifdef ENABLE_ARM
21 #include <arm_neon.h>
22 #endif
23 #include "nnacl/conv_parameter.h"
24 #include "nnacl/op_base.h"
25 
26 #ifdef __cplusplus
27 extern "C" {
28 #endif
29 typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
30 
31 typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step,
32                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
33 
34 #define Load16Data                                \
35   src[0] = MS_LDQ_F32(src_data + 0 * src_step);   \
36   src[1] = MS_LDQ_F32(src_data + 1 * src_step);   \
37   src[2] = MS_LDQ_F32(src_data + 2 * src_step);   \
38   src[3] = MS_LDQ_F32(src_data + 3 * src_step);   \
39   src[4] = MS_LDQ_F32(src_data + 4 * src_step);   \
40   src[5] = MS_LDQ_F32(src_data + 5 * src_step);   \
41   src[6] = MS_LDQ_F32(src_data + 6 * src_step);   \
42   src[7] = MS_LDQ_F32(src_data + 7 * src_step);   \
43   src[8] = MS_LDQ_F32(src_data + 8 * src_step);   \
44   src[9] = MS_LDQ_F32(src_data + 9 * src_step);   \
45   src[10] = MS_LDQ_F32(src_data + 10 * src_step); \
46   src[11] = MS_LDQ_F32(src_data + 11 * src_step); \
47   src[12] = MS_LDQ_F32(src_data + 12 * src_step); \
48   src[13] = MS_LDQ_F32(src_data + 13 * src_step); \
49   src[14] = MS_LDQ_F32(src_data + 14 * src_step); \
50   src[15] = MS_LDQ_F32(src_data + 15 * src_step);
51 
52 #define Load36Data                                \
53   src[0] = MS_LDQ_F32(src_data + 0 * src_step);   \
54   src[1] = MS_LDQ_F32(src_data + 1 * src_step);   \
55   src[2] = MS_LDQ_F32(src_data + 2 * src_step);   \
56   src[3] = MS_LDQ_F32(src_data + 3 * src_step);   \
57   src[4] = MS_LDQ_F32(src_data + 4 * src_step);   \
58   src[5] = MS_LDQ_F32(src_data + 5 * src_step);   \
59   src[6] = MS_LDQ_F32(src_data + 6 * src_step);   \
60   src[7] = MS_LDQ_F32(src_data + 7 * src_step);   \
61   src[8] = MS_LDQ_F32(src_data + 8 * src_step);   \
62   src[9] = MS_LDQ_F32(src_data + 9 * src_step);   \
63   src[10] = MS_LDQ_F32(src_data + 10 * src_step); \
64   src[11] = MS_LDQ_F32(src_data + 11 * src_step); \
65   src[12] = MS_LDQ_F32(src_data + 12 * src_step); \
66   src[13] = MS_LDQ_F32(src_data + 13 * src_step); \
67   src[14] = MS_LDQ_F32(src_data + 14 * src_step); \
68   src[15] = MS_LDQ_F32(src_data + 15 * src_step); \
69   src[16] = MS_LDQ_F32(src_data + 16 * src_step); \
70   src[17] = MS_LDQ_F32(src_data + 17 * src_step); \
71   src[18] = MS_LDQ_F32(src_data + 18 * src_step); \
72   src[19] = MS_LDQ_F32(src_data + 19 * src_step); \
73   src[20] = MS_LDQ_F32(src_data + 20 * src_step); \
74   src[21] = MS_LDQ_F32(src_data + 21 * src_step); \
75   src[22] = MS_LDQ_F32(src_data + 22 * src_step); \
76   src[23] = MS_LDQ_F32(src_data + 23 * src_step); \
77   src[24] = MS_LDQ_F32(src_data + 24 * src_step); \
78   src[25] = MS_LDQ_F32(src_data + 25 * src_step); \
79   src[26] = MS_LDQ_F32(src_data + 26 * src_step); \
80   src[27] = MS_LDQ_F32(src_data + 27 * src_step); \
81   src[28] = MS_LDQ_F32(src_data + 28 * src_step); \
82   src[29] = MS_LDQ_F32(src_data + 29 * src_step); \
83   src[30] = MS_LDQ_F32(src_data + 30 * src_step); \
84   src[31] = MS_LDQ_F32(src_data + 31 * src_step); \
85   src[32] = MS_LDQ_F32(src_data + 32 * src_step); \
86   src[33] = MS_LDQ_F32(src_data + 33 * src_step); \
87   src[34] = MS_LDQ_F32(src_data + 34 * src_step); \
88   src[35] = MS_LDQ_F32(src_data + 35 * src_step);
89 
90 #define Load64Data                                \
91   src[0] = MS_LDQ_F32(src_data + 0 * src_step);   \
92   src[1] = MS_LDQ_F32(src_data + 1 * src_step);   \
93   src[2] = MS_LDQ_F32(src_data + 2 * src_step);   \
94   src[3] = MS_LDQ_F32(src_data + 3 * src_step);   \
95   src[4] = MS_LDQ_F32(src_data + 4 * src_step);   \
96   src[5] = MS_LDQ_F32(src_data + 5 * src_step);   \
97   src[6] = MS_LDQ_F32(src_data + 6 * src_step);   \
98   src[7] = MS_LDQ_F32(src_data + 7 * src_step);   \
99   src[8] = MS_LDQ_F32(src_data + 8 * src_step);   \
100   src[9] = MS_LDQ_F32(src_data + 9 * src_step);   \
101   src[10] = MS_LDQ_F32(src_data + 10 * src_step); \
102   src[11] = MS_LDQ_F32(src_data + 11 * src_step); \
103   src[12] = MS_LDQ_F32(src_data + 12 * src_step); \
104   src[13] = MS_LDQ_F32(src_data + 13 * src_step); \
105   src[14] = MS_LDQ_F32(src_data + 14 * src_step); \
106   src[15] = MS_LDQ_F32(src_data + 15 * src_step); \
107   src[16] = MS_LDQ_F32(src_data + 16 * src_step); \
108   src[17] = MS_LDQ_F32(src_data + 17 * src_step); \
109   src[18] = MS_LDQ_F32(src_data + 18 * src_step); \
110   src[19] = MS_LDQ_F32(src_data + 19 * src_step); \
111   src[20] = MS_LDQ_F32(src_data + 20 * src_step); \
112   src[21] = MS_LDQ_F32(src_data + 21 * src_step); \
113   src[22] = MS_LDQ_F32(src_data + 22 * src_step); \
114   src[23] = MS_LDQ_F32(src_data + 23 * src_step); \
115   src[24] = MS_LDQ_F32(src_data + 24 * src_step); \
116   src[25] = MS_LDQ_F32(src_data + 25 * src_step); \
117   src[26] = MS_LDQ_F32(src_data + 26 * src_step); \
118   src[27] = MS_LDQ_F32(src_data + 27 * src_step); \
119   src[28] = MS_LDQ_F32(src_data + 28 * src_step); \
120   src[29] = MS_LDQ_F32(src_data + 29 * src_step); \
121   src[30] = MS_LDQ_F32(src_data + 30 * src_step); \
122   src[31] = MS_LDQ_F32(src_data + 31 * src_step); \
123   src[32] = MS_LDQ_F32(src_data + 32 * src_step); \
124   src[33] = MS_LDQ_F32(src_data + 33 * src_step); \
125   src[34] = MS_LDQ_F32(src_data + 34 * src_step); \
126   src[35] = MS_LDQ_F32(src_data + 35 * src_step); \
127   src[36] = MS_LDQ_F32(src_data + 36 * src_step); \
128   src[37] = MS_LDQ_F32(src_data + 37 * src_step); \
129   src[38] = MS_LDQ_F32(src_data + 38 * src_step); \
130   src[39] = MS_LDQ_F32(src_data + 39 * src_step); \
131   src[40] = MS_LDQ_F32(src_data + 40 * src_step); \
132   src[41] = MS_LDQ_F32(src_data + 41 * src_step); \
133   src[42] = MS_LDQ_F32(src_data + 42 * src_step); \
134   src[43] = MS_LDQ_F32(src_data + 43 * src_step); \
135   src[44] = MS_LDQ_F32(src_data + 44 * src_step); \
136   src[45] = MS_LDQ_F32(src_data + 45 * src_step); \
137   src[46] = MS_LDQ_F32(src_data + 46 * src_step); \
138   src[47] = MS_LDQ_F32(src_data + 47 * src_step); \
139   src[48] = MS_LDQ_F32(src_data + 48 * src_step); \
140   src[49] = MS_LDQ_F32(src_data + 49 * src_step); \
141   src[50] = MS_LDQ_F32(src_data + 50 * src_step); \
142   src[51] = MS_LDQ_F32(src_data + 51 * src_step); \
143   src[52] = MS_LDQ_F32(src_data + 52 * src_step); \
144   src[53] = MS_LDQ_F32(src_data + 53 * src_step); \
145   src[54] = MS_LDQ_F32(src_data + 54 * src_step); \
146   src[55] = MS_LDQ_F32(src_data + 55 * src_step); \
147   src[56] = MS_LDQ_F32(src_data + 56 * src_step); \
148   src[57] = MS_LDQ_F32(src_data + 57 * src_step); \
149   src[58] = MS_LDQ_F32(src_data + 58 * src_step); \
150   src[59] = MS_LDQ_F32(src_data + 59 * src_step); \
151   src[60] = MS_LDQ_F32(src_data + 60 * src_step); \
152   src[61] = MS_LDQ_F32(src_data + 61 * src_step); \
153   src[62] = MS_LDQ_F32(src_data + 62 * src_step); \
154   src[63] = MS_LDQ_F32(src_data + 63 * src_step);
155 
156 InputTransFunc GetInputTransFunc(int input_unit);
157 
158 void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
159 
160 void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
161 
162 void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c);
163 
164 OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type);
165 
166 #define Store4Data                               \
167   MS_STQ_F32(dst_data, m[0]);                    \
168   MS_STQ_F32(dst_data + out_c, m[1]);            \
169   MS_STQ_F32(dst_data + dst_step * out_c, m[2]); \
170   MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[3]);
171 
172 #define Store9Data                                           \
173   MS_STQ_F32(dst_data, m[0]);                                \
174   MS_STQ_F32(dst_data + out_c, m[1]);                        \
175   MS_STQ_F32(dst_data + 2 * out_c, m[2]);                    \
176   MS_STQ_F32(dst_data + dst_step * out_c, m[3]);             \
177   MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[4]);     \
178   MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[5]); \
179   MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[6]);         \
180   MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
181   MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
182 
183 #define Store16Data                                               \
184   MS_STQ_F32(dst_data, m[0]);                                     \
185   MS_STQ_F32(dst_data + out_c, m[1]);                             \
186   MS_STQ_F32(dst_data + 2 * out_c, m[2]);                         \
187   MS_STQ_F32(dst_data + 3 * out_c, m[3]);                         \
188   MS_STQ_F32(dst_data + dst_step * out_c, m[4]);                  \
189   MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[5]);          \
190   MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[6]);      \
191   MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[7]);      \
192   MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[8]);              \
193   MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[9]);      \
194   MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \
195   MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \
196   MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[12]);             \
197   MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[13]);     \
198   MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
199   MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
200 
201 #define Store25Data                                               \
202   MS_STQ_F32(dst_data, m[0]);                                     \
203   MS_STQ_F32(dst_data + out_c, m[1]);                             \
204   MS_STQ_F32(dst_data + 2 * out_c, m[2]);                         \
205   MS_STQ_F32(dst_data + 3 * out_c, m[3]);                         \
206   MS_STQ_F32(dst_data + 4 * out_c, m[4]);                         \
207   MS_STQ_F32(dst_data + dst_step * out_c, m[5]);                  \
208   MS_STQ_F32(dst_data + dst_step * out_c + out_c, m[6]);          \
209   MS_STQ_F32(dst_data + dst_step * out_c + 2 * out_c, m[7]);      \
210   MS_STQ_F32(dst_data + dst_step * out_c + 3 * out_c, m[8]);      \
211   MS_STQ_F32(dst_data + dst_step * out_c + 4 * out_c, m[9]);      \
212   MS_STQ_F32(dst_data + 2 * dst_step * out_c, m[10]);             \
213   MS_STQ_F32(dst_data + 2 * dst_step * out_c + out_c, m[11]);     \
214   MS_STQ_F32(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \
215   MS_STQ_F32(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \
216   MS_STQ_F32(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \
217   MS_STQ_F32(dst_data + 3 * dst_step * out_c, m[15]);             \
218   MS_STQ_F32(dst_data + 3 * dst_step * out_c + out_c, m[16]);     \
219   MS_STQ_F32(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \
220   MS_STQ_F32(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \
221   MS_STQ_F32(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \
222   MS_STQ_F32(dst_data + 4 * dst_step * out_c, m[20]);             \
223   MS_STQ_F32(dst_data + 4 * dst_step * out_c + out_c, m[21]);     \
224   MS_STQ_F32(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \
225   MS_STQ_F32(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
226   MS_STQ_F32(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
227 
228 void OutputTransform4x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
229                             int out_c, int r_w, int r_h, int r_c);
230 void OutputTransform4x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
231                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
232 void OutputTransform4x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
233                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
234 void OutputTransform4x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
235                             int out_c, int r_w, int r_h, int r_c);
236 void OutputTransform4x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
237                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
238 void OutputTransform4x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
239                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
240 
241 void OutputTransform6x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
242                             int out_c, int r_w, int r_h, int r_c);
243 void OutputTransform6x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
244                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
245 void OutputTransform6x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
246                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
247 void OutputTransform6x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
248                             int out_c, int r_w, int r_h, int r_c);
249 void OutputTransform6x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
250                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
251 void OutputTransform6x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
252                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
253 void OutputTransform6x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
254                             int out_c, int r_w, int r_h, int r_c);
255 void OutputTransform6x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
256                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
257 void OutputTransform6x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
258                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
259 void OutputTransform6x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
260                             int out_c, int r_w, int r_h, int r_c);
261 void OutputTransform6x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
262                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
263 void OutputTransform6x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
264                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
265 
266 void OutputTransform8x2Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
267                             int out_c, int r_w, int r_h, int r_c);
268 void OutputTransform8x2ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
269                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
270 void OutputTransform8x2Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
271                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
272 void OutputTransform8x3Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
273                             int out_c, int r_w, int r_h, int r_c);
274 void OutputTransform8x3ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
275                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
276 void OutputTransform8x3Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
277                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
278 void OutputTransform8x4Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
279                             int out_c, int r_w, int r_h, int r_c);
280 void OutputTransform8x4ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
281                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
282 void OutputTransform8x4Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
283                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
284 void OutputTransform8x5Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
285                             int out_c, int r_w, int r_h, int r_c);
286 void OutputTransform8x5ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
287                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
288 void OutputTransform8x5Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
289                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
290 void OutputTransform8x6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
291                             int out_c, int r_w, int r_h, int r_c);
292 void OutputTransform8x6ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
293                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
294 void OutputTransform8x6Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
295                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
296 void OutputTransform8x7Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step,
297                             int out_c, int r_w, int r_h, int r_c);
298 void OutputTransform8x7ReluUnit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
299                                 int dst_step, int out_c, int r_w, int r_h, int r_c);
300 void OutputTransform8x7Relu6Unit(const float *src_data, float *dst_data, const float *bias_data, int src_step,
301                                  int dst_step, int out_c, int r_w, int r_h, int r_c);
302 
303 int SelectOutputUnit(const ConvParameter *conv_param);
304 
305 #ifdef __cplusplus
306 }
307 #endif
308 
309 #endif  // MINDSPORE_NNACL_WINOGRAD_UTILS_H_
310