R"(
2
3/*
4 * Copyright (c) 2016-2018 Arm Limited.
5 *
6 * SPDX-License-Identifier: MIT
7 *
8 * Permission is hereby granted, free of charge, to any person obtaining a copy
9 * of this software and associated documentation files (the "Software"), to
10 * deal in the Software without restriction, including without limitation the
11 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
12 * sell copies of the Software, and to permit persons to whom the Software is
13 * furnished to do so, subject to the following conditions:
14 *
15 * The above copyright notice and this permission notice shall be included in all
16 * copies or substantial portions of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 * SOFTWARE.
25 */
26/*
27 * Copyright (c) 2016-2020 Arm Limited.
28 *
29 * SPDX-License-Identifier: MIT
30 *
31 * Permission is hereby granted, free of charge, to any person obtaining a copy
32 * of this software and associated documentation files (the "Software"), to
33 * deal in the Software without restriction, including without limitation the
34 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
35 * sell copies of the Software, and to permit persons to whom the Software is
36 * furnished to do so, subject to the following conditions:
37 *
38 * The above copyright notice and this permission notice shall be included in all
39 * copies or substantial portions of the Software.
40 *
41 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
44 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
45 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
46 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
47 * SOFTWARE.
48 */
49#ifndef ARM_COMPUTE_HELPER_H
50#define ARM_COMPUTE_HELPER_H
51
52/*
53 * Copyright (c) 2020 Arm Limited.
54 *
55 * SPDX-License-Identifier: MIT
56 *
57 * Permission is hereby granted, free of charge, to any person obtaining a copy
58 * of this software and associated documentation files (the "Software"), to
59 * deal in the Software without restriction, including without limitation the
60 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
61 * sell copies of the Software, and to permit persons to whom the Software is
62 * furnished to do so, subject to the following conditions:
63 *
64 * The above copyright notice and this permission notice shall be included in all
65 * copies or substantial portions of the Software.
66 *
67 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
68 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
69 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
70 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
71 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
72 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
73 * SOFTWARE.
74 */
75
/** Store rows 0 to (n-1) of a set of consecutively named vector variables.
 * @name STORE_ROW_n
 *
 * Each STORE_ROW_n delegates rows 0..(n-2) to STORE_ROW_(n-1) and then stores
 * row (n-1) itself. Row suffixes beyond 9 use the hex-style letters A..F.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n
166
/** Convert and store the 0th to (n-1)th rows of the given variables
 * @name CONVERT_STORE_ROW_n
 *
 * Each row is saturating-converted to a vector of @p DATA_TYPE before being
 * stored. Each CONVERT_STORE_ROW_n delegates rows 0..(n-2) to the (n-1) variant.
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

/* BUGFIX: the second parameter was previously named DATA while the body used
 * DATA_TYPE, so DATA_TYPE was never substituted and any expansion of
 * CONVERT_STORE_ROW_10..16 produced the unresolved token DATA_TYPE. */
#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));

/** @} */ // end of group CONVERT_STORE_ROW_n
258
/** Store a block of the given size M0xN0
 * @name STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
/* Two-level expansion so that M0 is fully evaluated before token pasting. */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
280
/** Convert and store a block of the given size M0xN0
 * @name CONVERT_STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
/* Two-level expansion so that M0 is fully evaluated before token pasting. */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK
302
/** Partially store rows 0 to (n-1) of the given variables.
 * @name STORE_ROW_PARTIAL_n
 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
 *
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n
397
/** Partially store a block of the given size STORE_M0xSTORE_N0
 * @name STORE_BLOCK_PARTIAL
 *
 * @note The vector width @p N0 is also required for correct partial storing behaviour.
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
/* Two-level expansion so that STORE_M0 is fully evaluated before token pasting. */
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** Store a block that can be partial in both x and y dimensions
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
/* Dispatch on the four combinations of (PARTIAL_COND_Y, PARTIAL_COND_X):
 * full block, partial rows only, partial columns only, or partial in both. */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** Store a block that can only be partial in x but not y.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** Store a block that can only be partial in y but not x.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 */
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
    if(!(PARTIAL_COND_Y)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** @} */ // end of group STORE_BLOCK_PARTIAL
516
#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

/** Boundary-aware GEMM block store
 * @name STORE_BLOCK_BOUNDARY_AWARE
 * This macro assumes the following schemes to achieve boundary-awareness:
 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 *  - Non-Overlapping(normal) load from rhs tensor. This implies rhs can have paddings.
 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
 * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
 *
 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
 * blocks **at the end**.
 * Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/
 * "boundary block" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and its various parameters:
 *
 *  *--x-->                         x == 0                        x == 1
 *  |                  |<------------------------------N-------------------------->|
 *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
 *  |     -------------#############################################################
 *  *     |          | |...............................|...........................|
 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
 *        |          | |...............................|...........................|
 *        M          --#############################################################
 *        |          | |                               |...........................|
 * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
 *        |          | |                               |...........................|
 *        |------------#############################################################
 *
 * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either X or Y dimension,
 * and selects corresponding store methods such that the boundary detection logic is only added when needed.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 * @{
 */
// The four cases below are resolved at compile time from the PARTIAL_STORE_M0
// and PARTIAL_STORE_N0 build options, so the runtime boundary checks are only
// compiled in when a partial block can actually occur.
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)

#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)

#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)

#else // PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 > 0
// Case4: Partial blocks in both x and y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)

#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0

#endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
594
#if defined(PARTIAL_STORE_M0)
/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
 * @name COMPUTE_M0_START_ROW
 * If there're any partial blocks in y dimension, they are placed at the beginning of the rows.
 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
 * blocks in the y dimension to avoid any padding.
 * EG: M0=4, PARTIAL_STORE_M0=1:
 *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
 * block 0 (partial)| start row = 0   | start row = 0
 * block 1 (full)   | start row = 4   | start row = 1
 * block 2 (full)   | start row = 8   | start row = 5
 *
 * @param[in] y                Global id of current block in y.
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @{
 */
// Shift = (M0 - PARTIAL_STORE_M0) % M0: zero when there is no partial block
// (PARTIAL_STORE_M0 == 0), otherwise the amount every block after the first is
// pulled up by. max(0, ...) clamps block 0 to row 0.
// NOTE(review): `y` is expanded unparenthesised as (y * M0); all visible call
// sites pass a plain id, but a compound expression would mis-expand.
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
#else // defined(PARTIAL_STORE_M0)
// Without PARTIAL_STORE_M0 there are no partial blocks: start row is simply block index * M0.
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(y * M0))
#endif    // defined(PARTIAL_STORE_M0)
/** @} */ // end of group COMPUTE_M0_START_ROW
619
/** Store a vector that can only be partial in x.
 *
 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to end in a 0.
 * E.g., for basename=c, the expected name is c0.
 *
 * @param[in] basename  The name of the variable without trailing 0
 * @param[in] data_type The data type of the vector
 * @param[in] ptr       The base pointer
 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
 * @param[in] cond      Condition to select either @p vec_size or @p leftover
 * @{
 */
// Implemented as a single-row (M0 = 1) partial-in-x block store with zero
// y-stride and zero z-offset.
#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
/** @} */ // end of group STORE_VECTOR_SELECT
638
639#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
640#pragma OPENCL EXTENSION cl_khr_fp16 : enable
641#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
642
643#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
644#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
645#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
646
647#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
648#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
649#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
650
651#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
652#pragma OPENCL EXTENSION cl_arm_printf : enable
653#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
654
// Numeric identifiers for GPU architectures (Midgard / Bifrost).
#define GPU_ARCH_MIDGARD 0x100
#define GPU_ARCH_BIFROST 0x200

/** Concatenate two inputs.
 *
 * @param[in] a The first input to be concatenated
 * @param[in] b The second input to be concatenated
 *
 * @return The concatenated output
 */
// Note: a and b are pasted without prior expansion; wrap in another macro
// level if expanded arguments are needed.
#define CONCAT(a, b) a##b

/** Expand the given vector
 *
 * @param[in] x The vector to be expanded
 *
 * @return The expanded output
 */
// Identity macro; forces one extra round of macro expansion on its argument.
#define EXPAND(x) x

/** Clamp the given value between an upper and lower bound.
 *
 * @param[in] x       The value to be clamped
 * @param[in] min_val The lower bound
 * @param[in] max_val The upper bound
 *
 * @return The clamped value.
 */
#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
684
/** REVn reverses the given vector whose size is n.
 * @name REVn
 *
 * @param[in] x The vector to be reversed
 *
 * @return The reversed vector
 * @{
 */
// Implemented as compile-time swizzles; REV1 is the identity (a scalar has
// nothing to reverse).
#define REV1(x) ((x))
#define REV2(x) ((x).s10)
#define REV3(x) ((x).s210)
#define REV4(x) ((x).s3210)
#define REV8(x) ((x).s76543210)
#define REV16(x) ((x).sFEDCBA9876543210)
/** @} */ // end of group REVn

/** Reverse the given vector.
 * @name REVERSE
 *
 * @param[in] x The vector to be reversed
 * @param[in] s The size of the vector. Supported: 1, 2, 3, 4, 8, 16
 *
 * @return The reversed vector
 * @{
 */
// Two-level expansion so that s is macro-expanded before being pasted onto REV.
#define REVERSE_STR(x, s) REV##s((x))
#define REVERSE(x, s) REVERSE_STR(x, s)
/** @} */ // end of group REVERSE
713
/** Circular-right-shift (rotate-right) the vector of size s by the amount of n.
 * @name ROTs_n
 *
 * @param[in] x The vector to be shifted
 *
 * @return The shifted vector
 * @{
 */
// Each ROTs_n is a compile-time swizzle: result element i = x[(i - n) mod s].
// E.g. ROT4_1(x) == (x).s3012 moves x.s3 into lane 0.
#define ROT1_0(x) ((x))

#define ROT2_0(x) ((x))
#define ROT2_1(x) ((x).s10)

#define ROT3_0(x) ((x))
#define ROT3_1(x) ((x).s201)
#define ROT3_2(x) ((x).s120)

#define ROT4_0(x) ((x))
#define ROT4_1(x) ((x).s3012)
#define ROT4_2(x) ((x).s2301)
#define ROT4_3(x) ((x).s1230)

#define ROT8_0(x) ((x))
#define ROT8_1(x) ((x).s70123456)
#define ROT8_2(x) ((x).s67012345)
#define ROT8_3(x) ((x).s56701234)
#define ROT8_4(x) ((x).s45670123)
#define ROT8_5(x) ((x).s34567012)
#define ROT8_6(x) ((x).s23456701)
#define ROT8_7(x) ((x).s12345670)

#define ROT16_0(x) ((x))
#define ROT16_1(x) ((x).sF0123456789ABCDE)
#define ROT16_2(x) ((x).sEF0123456789ABCD)
#define ROT16_3(x) ((x).sDEF0123456789ABC)
#define ROT16_4(x) ((x).sCDEF0123456789AB)
#define ROT16_5(x) ((x).sBCDEF0123456789A)
#define ROT16_6(x) ((x).sABCDEF0123456789)
#define ROT16_7(x) ((x).s9ABCDEF012345678)
#define ROT16_8(x) ((x).s89ABCDEF01234567)
#define ROT16_9(x) ((x).s789ABCDEF0123456)
#define ROT16_10(x) ((x).s6789ABCDEF012345)
#define ROT16_11(x) ((x).s56789ABCDEF01234)
#define ROT16_12(x) ((x).s456789ABCDEF0123)
#define ROT16_13(x) ((x).s3456789ABCDEF012)
#define ROT16_14(x) ((x).s23456789ABCDEF01)
#define ROT16_15(x) ((x).s123456789ABCDEF0)
/** @} */ // end of group ROTs_n

/** Circular-right-shift (rotate-right) the given vector by the given amount.
 * @name ROTATE
 *
 * @param[in] x The vector to be shifted
 * @param[in] s The size of the vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] n The amount to be shifted. Supported range: [0, @p s)
 *
 * @return The shifted vector
 * @{
 */
// Two-level expansion so s and n are macro-expanded before token pasting.
#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
/** @} */ // end of group ROTATE
776
/** Creates a vector of size n filled with offset values corresponding to the location of each element.
 * @name V_OFFSn
 *
 * @param[in] dt The data type of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
// Each V_OFFSn expands to a vector literal (dt##n)(0, 1, ..., n-1).
#define V_OFFS1(dt) (dt##1)(0)
#define V_OFFS2(dt) (dt##2)(0, 1)
#define V_OFFS3(dt) (dt##3)(0, 1, 2)
#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
/** @} */ // end of group V_OFFSn

/** Create a vector filled with offset values corresponding to the location of each element.
 * @name VEC_OFFS
 *
 * @param[in] dt The data type of the output vector
 * @param[in] s  The size of the output vector. Supported: 1, 2, 3, 4, 8, 16
 *
 * @return The vector filled with offset values
 * @{
 */
// Two-level expansion so s is macro-expanded before being pasted.
#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
/** @} */ // end of group VEC_OFFS
805
// VLOAD(size) resolves to the OpenCL built-in vload<size> (vload1 is defined
// below for the scalar case). Two-level expansion so size is macro-expanded first.
#define VLOAD_STR(size) vload##size
#define VLOAD(size) VLOAD_STR(size)

// Number of image pixels covered by a vector of the given element count;
// each read_imagef/read_imageh call below yields one 4-element pixel.
#define PIXEL_UNIT4 1
#define PIXEL_UNIT8 2
#define PIXEL_UNIT16 4

/** Utility macro to convert a vector size in pixel unit.
 *
 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
 *
 * @param[in] vec_size Vector size. Only 4,8 and 16 is supported
 *
 * @return The pixel unit (number of pixels)
 * @{
 */
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT

// Read 1, 2 or 4 horizontally adjacent texels starting at (x_coord, y_coord)
// and widen the result to float4/float8/float16. Coordinates are not normalized.
// NOTE(review): the expansions end with a trailing ';', so the macros rely on
// the caller's own ';' forming a harmless empty statement -- confirm before
// using them in non-statement context.
#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));
829
// Half-precision counterparts of the read_image2d_float* macros; only compiled
// when the fp16 extension is enabled, since read_imageh returns half4.
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

/** Utility macro to read a 2D OpenCL image object.
 *
 * @note Coordinates are not normalized
 *
 * @param[in] data_type Data type. Supported: float, half (half only when fp16 is enabled)
 * @param[in] n0        Number of pixel to read. Only 1,2 and 4 is supported
 * @param[in] img       OpenCL image object
 * @param[in] x_coord   The x coordinate for the top-left pixel
 * @param[in] y_coord   The y coordinate for the top-left pixel
 *
 * @return Pixels from the 2D OpenCL image object
 * @{
 */
// Dispatches to read_image2d_<data_type>x<n0> via token pasting.
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
851
// VSTORE(size) resolves to the OpenCL built-in vstore<size> (vstore1 is
// defined below for the scalar case).
#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)

// Size-1 "vector" type aliases so that VEC_DATA_TYPE(type, 1) and similar
// token-pasting macros degenerate to the plain scalar type.
#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double
/** Scalar counterparts of the vloadn/vstoren built-ins (OpenCL provides no
 * vload1/vstore1), so size-1 code paths can go through VLOAD/VSTORE unchanged.
 *
 * @param[in] DATA   The scalar value to store (vstore1 only)
 * @param[in] OFFSET Element offset from @p PTR
 * @param[in] PTR    The base pointer
 */
// Arguments and the address expression are fully parenthesised so the macros
// stay hygienic when invoked with compound expressions (e.g. OFFSET == a + b).
#define vload1(OFFSET, PTR) (*((PTR) + (OFFSET)))
#define vstore1(DATA, OFFSET, PTR) (*((PTR) + (OFFSET)) = (DATA))
869
/** Extended partial vstore that correctly handles scalar values as well.
 * Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name VSTORE_PARTIAL
 *
 * @note With this macro, the passed data can be both a vector and a scalar
 * @note @p store_size needs to be <= @p size
 * eg 1: Valid
 * VSTORE_PARTIAL(16, 15) ...;
 * eg 2: Invalid
 * VSTORE_PARTIAL(4, 7) ...;
 *
 * @param[in] size       The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
 * @param[in] store_size The number of lower elements to store. Supported values: 0-16 (0 stores nothing), but has to be <= @p size
 * @{
 */
// Resolved via the vstore_partial_<size>_<store_size> dispatch tables below.
#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)

// No-op store: used for store_size == 0 and for unsupported (size, store_size)
// combinations so they still expand to valid (empty) code.
#define NO_STORE(data, offs, ptr) \
    {                             \
    }

// Size == 1 (scalar): only storing exactly one element does anything.
#define vstore_partial_1_0 NO_STORE
#define vstore_partial_1_1 vstore1
#define vstore_partial_1_2 NO_STORE
#define vstore_partial_1_3 NO_STORE
#define vstore_partial_1_4 NO_STORE
#define vstore_partial_1_5 NO_STORE
#define vstore_partial_1_6 NO_STORE
#define vstore_partial_1_7 NO_STORE
#define vstore_partial_1_8 NO_STORE
#define vstore_partial_1_9 NO_STORE
#define vstore_partial_1_10 NO_STORE
#define vstore_partial_1_11 NO_STORE
#define vstore_partial_1_12 NO_STORE
#define vstore_partial_1_13 NO_STORE
#define vstore_partial_1_14 NO_STORE
#define vstore_partial_1_15 NO_STORE
#define vstore_partial_1_16 NO_STORE
// Dispatch tables: vstore_partial_<size>_<n> maps to the n-element partial
// store implementation; entries with n > size degenerate to NO_STORE.
// Size == 2
#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE
// Size == 3
#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE
// Size == 4
#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE
// Size == 8
#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE
// Size == 16
#define vstore_partial_16_0 NO_STORE
#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16
1000
/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name vstore_partial_n
 *
 * @note @p DATA needs to be a vector not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * eg 1: Valid
 * vstore_partial_15(var:float16, 0, 0xabcd);
 * eg 2: Invalid
 * vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset in n
 * @param[in] PTR    The base pointer
 * @{
 */
// Native vstore widths (1, 2, 3, 4, 8, 16) map straight onto one store.
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

// Non-native widths are split into a 4- or 8-wide store plus a smaller
// remainder store at an advanced pointer (e.g. 5 = 4 + 1, 13 = 8 + 5).
#define vstore_partial_5(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

#define vstore_partial_9(DATA, OFFSET, PTR)        \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

// For 13-15 the upper half (8 elements) is passed down and the inner macro
// takes the lower lanes of it, e.g. vstore_partial_5(DATA.s89abcdef, ...)
// stores lanes 8..12.
#define vstore_partial_13(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
1077
// Convert built-in functions with _sat modifier are not supported in floating point so we create defines
// without _sat to overcome this issue
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
// NOTE(review): convert_half_sat maps to convert_float while convert_half1_sat
// maps to convert_half -- this asymmetry looks unintentional; confirm before
// relying on the scalar half path.
#define convert_half_sat convert_float
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

// Size-1 convert aliases so CONVERT(x, type##1) degenerates to the scalar convert.
#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

// Size-1 saturating convert aliases for the integer types.
#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat
1116
// VEC_DATA_TYPE(float, 4) -> float4; size 1 degenerates to the scalar type via
// the <type>1 aliases above. Two-level expansion so arguments are expanded first.
#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

// CONVERT(x, int4) -> convert_int4(x): non-saturating type conversion.
#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

// CONVERT_SAT(x, int4) -> convert_int4_sat(x): saturating conversion
// (floating-point targets fall back to plain converts via the aliases above).
#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

// CONVERT_SAT_ROUND(x, int4, rte) -> convert_int4_sat_rte(x): saturating
// conversion with an explicit rounding mode suffix.
#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)
1128
// Mask types for the OpenCL select() built-in, which requires an integral
// vector whose element width matches the data type: hence half -> short and
// float -> int (same width, integral).
#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size

// SELECT_VEC_DATA_TYPE(float, 4) -> int4; SELECT_DATA_TYPE is the scalar case.
#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)
1143
// Horizontal sum of an n-element vector, built up recursively by halving
// (size 3 is handled as 2 + 1). sum_reduce_1 is the scalar identity.
#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) ((x).s0) + ((x).s1)
#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2)
#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23)
#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567)
#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF)

// SUM_REDUCE(x, size): dispatch on size. Supported: 1, 2, 3, 4, 8, 16.
#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

// Horizontal max of an n-element vector, same recursive-halving scheme.
#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

// MAX_REDUCE(x, size): dispatch on size. Supported: 1, 2, 3, 4, 8, 16.
#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
1163
// Kernel-parameter list for a 1D tensor: base pointer, x stride/step (in
// bytes) and byte offset of the first element. The name prefix keeps multiple
// tensors distinct within one kernel signature.
#define VECTOR_DECLARATION(name)     \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_offset_first_element_in_bytes

// Kernel-parameter list for a 2D tensor (adds y stride/step).
#define IMAGE_DECLARATION(name)      \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_offset_first_element_in_bytes

// Kernel-parameter list for a 3D tensor (adds z stride/step).
#define TENSOR3D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_offset_first_element_in_bytes

// Kernel-parameter list for a 4D tensor (adds w stride/step).
#define TENSOR4D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_stride_w, \
    uint        name##_step_w,   \
    uint        name##_offset_first_element_in_bytes
1199
// Build a per-work-item Vector struct from the kernel parameters declared by
// VECTOR_DECLARATION(name). The _NO_STEP variants pass 0 for the step(s) so
// the pointer is NOT advanced by the work-item id along those dimensions.
#define CONVERT_TO_VECTOR_STRUCT(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)

#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)

// Build a per-work-item Image struct from IMAGE_DECLARATION(name) parameters.
#define CONVERT_TO_IMAGE_STRUCT(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)

#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)

// Build an Image struct from TENSOR3D_DECLARATION(name) parameters (the z
// dimension is folded into the work-item pointer). Note the _NO_STEP variant
// still applies the z step.
#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
1217
1218#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
1219    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)
1220
1221#define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
1222    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
1223                                 name##_stride_z, name##_step_z)
1224
1225#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
1226    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)
1227
1228#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
1229    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
1230                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)
1231
1232#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
1233    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)
1234
1235#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                                                                                       \
1236    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
1237                           name##_stride_z, name##_step_z)
1238
/** Structure to hold Vector information */
typedef struct Vector
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source vector */
    int             stride_x;                      /**< Stride of the vector in X dimension (in bytes) */
} Vector;
1246
/** Structure to hold Image information */
typedef struct Image
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
} Image;
1255
/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
} Tensor3D;
1265
/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
    int             stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
} Tensor4D;
1276
/** Wrap vector information into a Vector structure, and advance the pointer to
 * this work-item's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
 *
 * @return A Vector object
 */
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
    Vector vec;
    vec.offset_first_element_in_bytes = offset_first_element_in_bytes;
    vec.stride_x                      = stride_x;
    // Base pointer -> first element -> this work-item's chunk along X.
    vec.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x;
    return vec;
}
1297
/** Wrap image information into an Image structure, and advance the pointer to
 * this work-item's data (dimensions X and Y).
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
 *
 * @return An Image object
 */
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
    Image img;
    img.offset_first_element_in_bytes = offset_first_element_in_bytes;
    img.stride_x                      = stride_x;
    img.stride_y                      = stride_y;
    // Base pointer -> first element -> this work-item's (x, y) position.
    img.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
    return img;
}
1321
/** Flatten 3D tensor information into an Image structure, and advance the
 * pointer to this work-item's data. The X, Y and Z displacements are all
 * applied to the pointer, but only the X/Y strides are kept in the struct.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
 *
 * @return An Image object
 */
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Image img;
    img.offset_first_element_in_bytes = offset_first_element_in_bytes;
    img.stride_x                      = stride_x;
    img.stride_y                      = stride_y;
    // Base pointer -> first element -> this work-item's (x, y, z) position.
    img.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return img;
}
1347
/** Wrap 3D tensor information into a Tensor3D structure, and advance the
 * pointer to this work-item's data (dimensions X, Y and Z).
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
 *
 * @return A Tensor3D object
 */
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D t;
    t.offset_first_element_in_bytes = offset_first_element_in_bytes;
    t.stride_x                      = stride_x;
    t.stride_y                      = stride_y;
    t.stride_z                      = stride_z;
    // Base pointer -> first element -> this work-item's (x, y, z) position.
    t.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return t;
}
1374
/** Wrap 3D tensor information into a Tensor3D structure without moving the
 * pointer: unlike update_tensor3D_workitem_ptr(), the pointer stays at the
 * start of the buffer. The step parameters are accepted for interface parity
 * but are unused.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
 *
 * @return A Tensor3D object
 */
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D t;
    t.ptr                           = ptr;
    t.offset_first_element_in_bytes = offset_first_element_in_bytes;
    t.stride_x                      = stride_x;
    t.stride_y                      = stride_y;
    t.stride_z                      = stride_z;
    return t;
}
1400
/** Wrap 4D tensor information into a Tensor4D structure, and advance the
 * pointer to this work-item's data. Z and W are both derived from the third
 * global id: (id2 % mod_size) selects Z and (id2 / mod_size) selects W.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] stride_w                      Stride of the tensor in W dimension (in bytes)
 * @param[in] step_w                        stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in] mod_size                      Number of Z slices folded into the third global id
 *
 * @return A Tensor4D object
 */
inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
                                             uint step_w,
                                             uint mod_size)
{
    Tensor4D t;
    t.offset_first_element_in_bytes = offset_first_element_in_bytes;
    t.stride_x                      = stride_x;
    t.stride_y                      = stride_y;
    t.stride_z                      = stride_z;
    t.stride_w                      = stride_w;
    // Base pointer -> first element -> (x, y) position -> unpacked (z, w) position.
    t.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
    return t;
}
1418
/** Compute the address of the element at relative position @p x in a Vector.
 *
 * @param[in] vec Pointer to the vector descriptor
 * @param[in] x   Relative X position
 */
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
    const int byte_offset = x * vec->stride_x;
    return vec->ptr + byte_offset;
}
1428
/** Compute the address of the element at relative position (@p x, @p y) in an
 * Image.
 *
 * @param[in] img Pointer to the image descriptor
 * @param[in] x   Relative X position
 * @param[in] y   Relative Y position
 */
inline __global uchar *offset(const Image *img, int x, int y)
{
    __global uchar *addr = img->ptr;
    addr += x * img->stride_x;
    addr += y * img->stride_y;
    return addr;
}
1439
/** Compute the address of the element at relative position (@p x, @p y, @p z)
 * in a Tensor3D.
 *
 * @param[in] tensor Pointer to the tensor descriptor
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 */
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
    const int byte_offset = x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
    return tensor->ptr + byte_offset;
}
1451
/** Compute the address of the element at relative position
 * (@p x, @p y, @p z, @p w) in a Tensor4D.
 *
 * @param[in] tensor Pointer to the tensor descriptor
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 * @param[in] w      Relative W position
 */
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
    const int byte_offset = x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
    return tensor->ptr + byte_offset;
}
1464
/** Translate a linear element index into the address of that element inside a
 * Tensor3D (index order: x fastest, then y, then z).
 *
 * @param[in] tensor Pointer to the tensor descriptor
 * @param[in] width  Width of the input tensor
 * @param[in] height Height of the input tensor
 * @param[in] depth  Depth of the input tensor (unused; kept for interface parity)
 * @param[in] index  Linear index
 */
inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
{
    // Peel off the Z slice first, then the row, then the column.
    const uint slice_elems = width * height;
    const uint z           = index / slice_elems;
    const uint in_slice    = index - z * slice_elems;
    const uint y           = in_slice / width;
    const uint x           = in_slice - y * width;

    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
}
1489
#endif // _HELPER_H

// Re-defined below for the non-quantized path, so drop any earlier definition.
#undef CONVERT_SAT

// Elementwise arithmetic helpers used by the convolution macros below.
#define ADD_OP(a, b) ((a) + (b))
#define MUL_OP(a, b) ((a) * (b))
// Saturating conversion is a pass-through here: the target type (b) is ignored.
#define CONVERT_SAT(a, b) ((a))
1497
#if defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)

// Select the 1x3 convolution implementation for the compile-time STRIDE_X.
// Only strides 1 and 2 are implemented.
#if STRIDE_X == 1
#define CONVOLUTION1x3(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)
#elif STRIDE_X == 2 /* STRIDE_X == 1 */
#define CONVOLUTION1x3(acc, src_row_ptr, weights_row_ptr) CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr)
#else /* STRIDE_X not equals 1 or 2 */
#error "STRIDE_X larger than 2 is not supported"
#endif /* STRIDE_X == 2 */
1507
// 1x3 row convolution, stride 1: loads one kernel row (3 weights) and 10
// consecutive input elements (vload8 + vload2), then multiply-accumulates 8
// outputs into 'acc' using three shifted 8-wide views of the input.
#define CONVOLUTION1x3_STRIDE1(acc, src_row_ptr, weights_row_ptr)                                                                                  \
    ({                                                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                                                \
        weights_values0 = vload3(0, weights_row_ptr);                                                                                              \
        VEC_DATA_TYPE(DATA_TYPE, 8)                                                                                                                \
        src0 = vload8(0, src_row_ptr);                                                                                                             \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                                                                \
        src1 = vload2(0, src_row_ptr + 8);                                                                                                         \
        \
        acc = ADD_OP(acc, MUL_OP(src0, (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0));                                                          \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1)); \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2)); \
    })

// 1x3 row convolution, stride 2: loads 17 input elements (vload16 + one
// scalar) and picks every other element (.even / odd and shifted selections)
// to produce 8 outputs accumulated into 'acc'.
#define CONVOLUTION1x3_STRIDE2(acc, src_row_ptr, weights_row_ptr)                                                                               \
    ({                                                                                                                                          \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                                             \
        weights_values0 = vload3(0, weights_row_ptr);                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, 16)                                                                                                            \
        src0           = vload16(0, src_row_ptr);                                                                                               \
        DATA_TYPE src1 = *(src_row_ptr + 16);                                                                                                   \
        \
        acc = ADD_OP(acc, MUL_OP(src0.even, (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s0));                                                  \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s1));      \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1), (VEC_DATA_TYPE(DATA_TYPE, 8))weights_values0.s2)); \
    })
1534
#if defined(DATA_LAYOUT_NHWC)

// Reinterpret a byte pointer and load a single DATA_TYPE value from it.
#define PTR_TO_VALUE(PTR, DATA_TYPE) *((__global DATA_TYPE *)(PTR))

// Select the NHWC 1x3 convolution implementation for the compile-time STRIDE_X.
#if STRIDE_X == 1
#define CONVOLUTION1x3_NHWC(acc, row_ptr, weights_ptr) CONVOLUTION1x3_STRIDE_NHWC_STRIDE1(acc, row_ptr, weights_ptr)
#elif STRIDE_X == 2 /* STRIDE_X == 1 */
#define CONVOLUTION1x3_NHWC(acc, row_ptr, weights_ptr) CONVOLUTION1x3_STRIDE_NHWC_STRIDE2(acc, row_ptr, weights_ptr)
#else /* STRIDE_X not equals 1 or 2 */
#error "STRIDE_X larger than 2 is not supported"
#endif /* STRIDE_X == 2 */
1546
// NHWC 1x3 row convolution, stride 1. In NHWC the spatial neighbours are not
// contiguous, so elements are gathered one by one at multiples of
// src_stride_y / weights_stride_y (both captured from the enclosing kernel's
// scope). Accumulates 8 outputs into 'acc'.
#define CONVOLUTION1x3_STRIDE_NHWC_STRIDE1(acc, row_ptr, weights_ptr)                                                                      \
    {                                                                                                                                      \
        VEC_DATA_TYPE(DATA_TYPE, 8)                                                                                                        \
        src0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(                                                                                              \
                PTR_TO_VALUE(row_ptr + 0 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 1 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 2 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 3 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 4 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 5 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 6 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 7 * src_stride_y, DATA_TYPE));                                                                          \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                                                        \
        src1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(                                                                                              \
                PTR_TO_VALUE(row_ptr + 8 * src_stride_y, DATA_TYPE),                                                                           \
                PTR_TO_VALUE(row_ptr + 9 * src_stride_y, DATA_TYPE));                                                                          \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                                        \
        weights = (VEC_DATA_TYPE(DATA_TYPE, 3))(                                                                                           \
                  PTR_TO_VALUE((weights_ptr) + 0 * weights_stride_y, DATA_TYPE),                                                                 \
                  PTR_TO_VALUE((weights_ptr) + 1 * weights_stride_y, DATA_TYPE),                                                                 \
                  PTR_TO_VALUE((weights_ptr) + 2 * weights_stride_y, DATA_TYPE));                                                                \
        acc = ADD_OP(acc, MUL_OP(src0, (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s0));                                                          \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1234, src0.s567, src1.s0), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s1)); \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s234, src0.s567, src1.s01), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s2)); \
    }
1572
// NHWC 1x3 row convolution, stride 2. Gathers 17 strided input elements and
// one kernel row, then selects every other element (and the two shifted
// selections) to accumulate 8 outputs into 'acc'. Relies on src_stride_y and
// weights_stride_y from the enclosing kernel's scope.
#define CONVOLUTION1x3_STRIDE_NHWC_STRIDE2(acc, row_ptr, weights_ptr)                                                                   \
    {                                                                                                                                   \
        VEC_DATA_TYPE(DATA_TYPE, 16)                                                                                                    \
        src0 = (VEC_DATA_TYPE(DATA_TYPE, 16))(                                                                                          \
                PTR_TO_VALUE(row_ptr + 0 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 1 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 2 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 3 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 4 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 5 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 6 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 7 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 8 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 9 * src_stride_y, DATA_TYPE),                                                                        \
                PTR_TO_VALUE(row_ptr + 10 * src_stride_y, DATA_TYPE),                                                                       \
                PTR_TO_VALUE(row_ptr + 11 * src_stride_y, DATA_TYPE),                                                                       \
                PTR_TO_VALUE(row_ptr + 12 * src_stride_y, DATA_TYPE),                                                                       \
                PTR_TO_VALUE(row_ptr + 13 * src_stride_y, DATA_TYPE),                                                                       \
                PTR_TO_VALUE(row_ptr + 14 * src_stride_y, DATA_TYPE),                                                                       \
                PTR_TO_VALUE(row_ptr + 15 * src_stride_y, DATA_TYPE));                                                                      \
        DATA_TYPE src1 = PTR_TO_VALUE(row_ptr + 16 * src_stride_y, DATA_TYPE);                                                          \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                                                     \
        weights = (VEC_DATA_TYPE(DATA_TYPE, 3))(                                                                                        \
                  PTR_TO_VALUE((weights_ptr) + 0 * weights_stride_y, DATA_TYPE),                                                              \
                  PTR_TO_VALUE((weights_ptr) + 1 * weights_stride_y, DATA_TYPE),                                                              \
                  PTR_TO_VALUE((weights_ptr) + 2 * weights_stride_y, DATA_TYPE));                                                             \
        \
        acc = ADD_OP(acc, MUL_OP(src0.s02468ACE, (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s0));                                             \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s1357, src0.s9BDF), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s1));      \
        acc = ADD_OP(acc, MUL_OP((VEC_DATA_TYPE(DATA_TYPE, 8))(src0.s2468, src0.sACE, src1), (VEC_DATA_TYPE(DATA_TYPE, 8))weights.s2)); \
    }
1604
1605/** This kernel performs a direct convolution to convolve the low three dimensions.
1606 *
1607 * @note This OpenCL kernel works with stride_x = 1 and 2
1608 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
1609 * @note The third dimensions of the weights tensors must be passed at compile time using -DWEIGHTS_DEPTH
1610 * @note If biases are used then -DHAS_BIAS has to be passed at compile time
1611 *
1612 * @param[in]  src_ptr                               Pointer to the source tensor. Supported data types: QS8/QS16/F16/F32
1613 * @param[in]  src_stride_x                          Stride of the source tensor in X dimension (in bytes)
1614 * @param[in]  src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
1615 * @param[in]  src_stride_y                          Stride of the source tensor in Y dimension (in bytes)
1616 * @param[in]  src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
1617 * @param[in]  src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
1618 * @param[in]  src_step_z                            src_stride_z * number of elements along Z processed per workitem(in bytes)
1619 * @param[in]  src_offset_first_element_in_bytes     The offset of the first element in the source tensor
1620 * @param[out] dst_ptr                               Pointer to the destination tensor. Supported data types: same as @p src_ptr
1621 * @param[in]  dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
1622 * @param[in]  dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
1623 * @param[in]  dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
1625 * @param[in]  dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
1626 * @param[in]  dst_step_z                            dst_stride_z * number of elements along Z processed per workitem(in bytes)
1627 * @param[in]  dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
1628 * @param[in]  weights_ptr                           Pointer to the weights tensor. Supported data types: same as @p src_ptr
1629 * @param[in]  weights_stride_x                      Stride of the weights tensor in X dimension (in bytes)
1630 * @param[in]  weights_step_x                        weights_stride_x * number of elements along X processed per workitem(in bytes)
1631 * @param[in]  weights_stride_y                      Stride of the weights tensor in Y dimension (in bytes)
1632 * @param[in]  weights_step_y                        weights_stride_y * number of elements along y processed per workitem(in bytes)
1633 * @param[in]  weights_stride_z                      Stride of the weights tensor in Z dimension (in bytes)
1634 * @param[in]  weights_step_z                        weights_stride_z * number of elements along Z processed per workitem(in bytes)
1635 * @param[in]  weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
1636 * @param[in]  biases_ptr                            Pointer to the biases tensor. Same as @p src_ptr
1637 * @param[in]  biases_stride_x                       Stride of the biases tensor in X dimension (in bytes)
1638 * @param[in]  biases_step_x                         biases_stride_x * number of elements along X processed per workitem(in bytes)
1639 * @param[in]  biases_offset_first_element_in_bytes  The offset of the first element in the biases tensor
1640 * @param[in]  weights_stride_w                      Stride of the weights tensor in the 4th dimension
1641 */
__kernel void direct_convolution3x3_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    TENSOR3D_DECLARATION(weights),
#ifdef HAS_BIAS
    VECTOR_DECLARATION(biases),
#endif /* defined(HAS_BIAS) */
    unsigned int weights_stride_w)
{
    Image    src     = CONVERT_TO_IMAGE_STRUCT(src);
    Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
    Tensor3D dst     = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Accumulator for 8 output values, kept in the promoted data type to preserve precision/avoid overflow.
    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
    values0       = 0;
    // id0 selects the output channel, id2 the output row. The former id1 (get_global_id(1)) was never
    // read anywhere in this kernel, so the dead call has been removed.
    const int id0 = get_global_id(0);
    const int id2 = get_global_id(2);

    __global uchar *weights_addr = (__global uchar *)tensor3D_offset(&weights, 0, 0, 0);
    // Subtract id0 * src_stride_x (presumably undoing the per-work-item X step applied by
    // CONVERT_TO_IMAGE_STRUCT - TODO confirm) and position on the first, possibly padded, input row.
    __global uchar *src_addr     = (__global uchar *)offset(&src, 0, 0) - src_stride_x * id0 + ((id2 * STRIDE_Y) - PAD_TOP) * (int)src_stride_z;

    weights_addr += id0 * weights_stride_w;

    // Vertical input coordinate of the first filter row (negative while inside the top padding).
    const int coordy = ((id2 * STRIDE_Y) - PAD_TOP);
    // NOTE(review): 'volatile' on the loop counter looks like a deliberate compiler workaround - keep it.
    for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
    {
#if PAD_TOP > 0
        if(coordy < 0) // special case: input row -1 does not exist
        {
            // Skip the first filter row and accumulate only the two rows that fall inside the image
            CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
            CONVOLUTION1x3_NHWC(values0, src_addr + 2 * (int)src_stride_z, (weights_addr + 2 * (int)weights_stride_z));
        }
        else if(coordy == (SRC_HEIGHT - PAD_TOP - 1))
        {
            // Special case when computing the last row of the output: we must read the last three rows from the
            // input buffer (including padding) but the Z axis has no padding at all, so only the first two
            // filter rows are applied.
            CONVOLUTION1x3_NHWC(values0, src_addr, (weights_addr + 0 * (int)weights_stride_z));
            CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
        }
        else
        {
            // Interior rows: apply all three filter rows
            CONVOLUTION1x3_NHWC(values0, src_addr, (weights_addr + 0 * (int)weights_stride_z));
            CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
            CONVOLUTION1x3_NHWC(values0, src_addr + 2 * (int)src_stride_z, (weights_addr + 2 * (int)weights_stride_z));
        }
#else  // PAD_TOP > 0
        CONVOLUTION1x3_NHWC(values0, src_addr, (weights_addr + 0 * (int)weights_stride_z));
        CONVOLUTION1x3_NHWC(values0, src_addr + 1 * (int)src_stride_z, (weights_addr + 1 * (int)weights_stride_z));
        CONVOLUTION1x3_NHWC(values0, src_addr + 2 * (int)src_stride_z, (weights_addr + 2 * (int)weights_stride_z));
#endif // PAD_TOP > 0
        // Advance to the next input/weights channel (channels live on the X axis in NHWC)
        src_addr += src_stride_x;
        weights_addr += weights_stride_x;
    }

#ifdef HAS_BIAS
    Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
    // '(vec8) *ptr' dereferences the per-channel bias scalar and broadcasts it over the 8 accumulated values
    values0       = ADD_OP(values0, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&biases, id0))));
#endif /* defined(HAS_BIAS) */

    // Scatter the 8 results along the Y dimension of the destination tensor
    *((__global DATA_TYPE *)(dst.ptr + 0 * dst_stride_y)) = values0.s0;
    *((__global DATA_TYPE *)(dst.ptr + 1 * dst_stride_y)) = values0.s1;
    *((__global DATA_TYPE *)(dst.ptr + 2 * dst_stride_y)) = values0.s2;
    *((__global DATA_TYPE *)(dst.ptr + 3 * dst_stride_y)) = values0.s3;
    *((__global DATA_TYPE *)(dst.ptr + 4 * dst_stride_y)) = values0.s4;
    *((__global DATA_TYPE *)(dst.ptr + 5 * dst_stride_y)) = values0.s5;
    *((__global DATA_TYPE *)(dst.ptr + 6 * dst_stride_y)) = values0.s6;
    *((__global DATA_TYPE *)(dst.ptr + 7 * dst_stride_y)) = values0.s7;
}
1712#endif // defined(DATA_LAYOUT_NHWC)
1713
1714/** This kernel performs a direct convolution to convolve the low three dimensions.
1715 *
1716 * @note This OpenCL kernel works with stride_x = 1 and 2
1717 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
1718 * @note The third dimensions of the weights tensors must be passed at compile time using -DWEIGHTS_DEPTH
1719 * @note If biases are used then -DHAS_BIAS has to be passed at compile time
1720 *
1721 * @param[in]  src_ptr                               Pointer to the source tensor. Supported data types: F16/F32
1722 * @param[in]  src_stride_x                          Stride of the source tensor in X dimension (in bytes)
1723 * @param[in]  src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
1724 * @param[in]  src_stride_y                          Stride of the source tensor in Y dimension (in bytes)
1725 * @param[in]  src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
1726 * @param[in]  src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
1727 * @param[in]  src_step_z                            src_stride_z * number of elements along Z processed per workitem(in bytes)
1728 * @param[in]  src_offset_first_element_in_bytes     The offset of the first element in the source tensor
1729 * @param[out] dst_ptr                               Pointer to the destination tensor. Supported data types: same as @p src_ptr
1730 * @param[in]  dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
1731 * @param[in]  dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
1732 * @param[in]  dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
1734 * @param[in]  dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
1735 * @param[in]  dst_step_z                            dst_stride_z * number of elements along Z processed per workitem(in bytes)
1736 * @param[in]  dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
1737 * @param[in]  weights_ptr                           Pointer to the weights tensor. Supported data types: same as @p src_ptr
1738 * @param[in]  weights_stride_x                      Stride of the weights tensor in X dimension (in bytes)
1739 * @param[in]  weights_step_x                        weights_stride_x * number of elements along X processed per workitem(in bytes)
1740 * @param[in]  weights_stride_y                      Stride of the weights tensor in Y dimension (in bytes)
1741 * @param[in]  weights_step_y                        weights_stride_y * number of elements along y processed per workitem(in bytes)
1742 * @param[in]  weights_stride_z                      Stride of the weights tensor in Z dimension (in bytes)
1743 * @param[in]  weights_step_z                        weights_stride_z * number of elements along Z processed per workitem(in bytes)
1744 * @param[in]  weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
1745 * @param[in]  biases_ptr                            Pointer to the biases tensor. Same as @p src_ptr
1746 * @param[in]  biases_stride_x                       Stride of the biases tensor in X dimension (in bytes)
1747 * @param[in]  biases_step_x                         biases_stride_x * number of elements along X processed per workitem(in bytes)
1748 * @param[in]  biases_offset_first_element_in_bytes  The offset of the first element in the biases tensor
1749 * @param[in]  weights_stride_w                      Stride of the weights tensor in the 4th dimension
1750 */
__kernel void direct_convolution3x3(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    TENSOR3D_DECLARATION(weights),
#ifdef HAS_BIAS
    VECTOR_DECLARATION(biases),
#endif /* defined(HAS_BIAS) */
    unsigned int weights_stride_w)
{
    // Output feature map computed by this work-item
    const int ofm_index = get_global_id(2);

    Image    in_plane = CONVERT_TO_IMAGE_STRUCT(src);
    Tensor3D w_tensor = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(weights);
    Tensor3D out      = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Accumulate 8 adjacent output values in the promoted data type
    VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)
    acc = 0;

    // Base input address for this work-item and start of the selected 3x3xWEIGHTS_DEPTH filter
    __global uchar *in_addr = (__global uchar *)offset(&in_plane, 0, 0);
    __global uchar *w_addr  = (__global uchar *)tensor3D_offset(&w_tensor, 0, 0, 0) + ofm_index * weights_stride_w;

    // NOTE(review): 'volatile' on the counter looks like a deliberate compiler workaround - keep it.
    for(volatile int d = 0; d < WEIGHTS_DEPTH; ++d)
    {
        // Accumulate the three filter rows against the three corresponding input rows
        CONVOLUTION1x3(acc, (__global DATA_TYPE *)(in_addr + 0 * src_stride_y), (__global DATA_TYPE *)(w_addr + 0 * weights_stride_y));
        CONVOLUTION1x3(acc, (__global DATA_TYPE *)(in_addr + 1 * src_stride_y), (__global DATA_TYPE *)(w_addr + 1 * weights_stride_y));
        CONVOLUTION1x3(acc, (__global DATA_TYPE *)(in_addr + 2 * src_stride_y), (__global DATA_TYPE *)(w_addr + 2 * weights_stride_y));

        // Step to the next input channel / weights channel
        in_addr += src_stride_z;
        w_addr += weights_stride_z;
    }

#ifdef HAS_BIAS
    Vector bias_vec = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);

    // '(vec8) *ptr' broadcasts the per-OFM bias scalar over the 8 accumulated values
    acc = ADD_OP(acc, (VEC_DATA_TYPE(DATA_TYPE_PROMOTED, 8)) * ((__global DATA_TYPE *)(vector_offset(&bias_vec, ofm_index))));
#endif /* defined(HAS_BIAS) */

    // Convert back to the output data type with saturation and store the 8 results
    vstore8(CONVERT_SAT(acc, VEC_DATA_TYPE(DATA_TYPE, 8)), 0, (__global DATA_TYPE *)out.ptr);
}
1791#endif //defined(DATA_TYPE) && defined(STRIDE_X) && defined(WEIGHTS_DEPTH)
1792
1793#if defined(WEIGHTS_DEPTH)
1794
/** Accumulate one 3-tap filter row over 4 horizontally adjacent output pixels using mad ops.
 *
 * src0 supplies input pixels x+0..x+3 (float4), src1 supplies pixels x+4..x+5 (only .s0/.s1 are read),
 * weights_row0 holds the three taps of one filter row (float3). acc (float4) is updated in place.
 * NOTE(review): the mad ordering affects floating-point rounding - do not reorder.
 */
#define CONVOLUTION1x3_BIFROST(acc, src0, src1, weights_row0) \
    ({                                                        \
        acc.s0 = mad(src0.s0, weights_row0.s0, acc.s0);       \
        acc.s1 = mad(src0.s1, weights_row0.s0, acc.s1);       \
        acc.s2 = mad(src0.s2, weights_row0.s0, acc.s2);       \
        acc.s3 = mad(src0.s3, weights_row0.s0, acc.s3);       \
        acc.s0 = mad(src0.s1, weights_row0.s1, acc.s0);       \
        acc.s1 = mad(src0.s2, weights_row0.s1, acc.s1);       \
        acc.s2 = mad(src0.s3, weights_row0.s1, acc.s2);       \
        acc.s3 = mad(src1.s0, weights_row0.s1, acc.s3);       \
        acc.s0 = mad(src0.s2, weights_row0.s2, acc.s0);       \
        acc.s1 = mad(src0.s3, weights_row0.s2, acc.s1);       \
        acc.s2 = mad(src1.s0, weights_row0.s2, acc.s2);       \
        acc.s3 = mad(src1.s1, weights_row0.s2, acc.s3);       \
    })
1810
1811/** An optimized direct convolution 3x3 OpenCL kernel for Bifrost architectures when the data type is F32
1812 *
1813 * @note This OpenCL kernel works only with stride_x and stride_y equal to 1
1814 * @note The third dimensions of the weights tensors must be passed at compile time using -DWEIGHTS_DEPTH
 * @note If biases are used then -DHAS_BIAS has to be passed at compile time
1816 *
1817 * @param[in]  src_ptr                               Pointer to the source tensor. Supported data types: F32
1818 * @param[in]  src_stride_x                          Stride of the source tensor in X dimension (in bytes)
1819 * @param[in]  src_step_x                            src_stride_x * number of elements along X processed per workitem(in bytes)
1820 * @param[in]  src_stride_y                          Stride of the source tensor in Y dimension (in bytes)
1821 * @param[in]  src_step_y                            src_stride_y * number of elements along Y processed per workitem(in bytes)
1822 * @param[in]  src_stride_z                          Stride of the source tensor in Z dimension (in bytes)
1823 * @param[in]  src_step_z                            src_stride_z * number of elements along Z processed per workitem(in bytes)
1824 * @param[in]  src_offset_first_element_in_bytes     The offset of the first element in the source tensor
1825 * @param[out] dst_ptr                               Pointer to the destination tensor. Supported data types: same as @p src_ptr
1826 * @param[in]  dst_stride_x                          Stride of the destination tensor in X dimension (in bytes)
1827 * @param[in]  dst_step_x                            dst_stride_x * number of elements along X processed per workitem(in bytes)
1828 * @param[in]  dst_stride_y                          Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                            dst_stride_y * number of elements along Y processed per workitem(in bytes)
1830 * @param[in]  dst_stride_z                          Stride of the destination tensor in Z dimension (in bytes)
1831 * @param[in]  dst_step_z                            dst_stride_z * number of elements along Z processed per workitem(in bytes)
1832 * @param[in]  dst_offset_first_element_in_bytes     The offset of the first element in the destination tensor
1833 * @param[in]  weights_ptr                           Pointer to the weights tensor. Supported data types: same as @p src_ptr
1834 * @param[in]  weights_stride_x                      Stride of the weights tensor in X dimension (in bytes)
1835 * @param[in]  weights_step_x                        weights_stride_x * number of elements along X processed per workitem(in bytes)
1836 * @param[in]  weights_stride_y                      Stride of the weights tensor in Y dimension (in bytes)
1837 * @param[in]  weights_step_y                        weights_stride_y * number of elements along y processed per workitem(in bytes)
1838 * @param[in]  weights_stride_z                      Stride of the weights tensor in Z dimension (in bytes)
1839 * @param[in]  weights_step_z                        weights_stride_z * number of elements along Z processed per workitem(in bytes)
1840 * @param[in]  weights_offset_first_element_in_bytes The offset of the first element in the weights tensor
1841 * @param[in]  biases_ptr                            Pointer to the biases tensor. Same as @p src_ptr
1842 * @param[in]  biases_stride_x                       Stride of the biases tensor in X dimension (in bytes)
1843 * @param[in]  biases_step_x                         biases_stride_x * number of elements along X processed per workitem(in bytes)
1844 * @param[in]  biases_offset_first_element_in_bytes  The offset of the first element in the biases tensor
1845 * @param[in]  weights_stride_w                      Stride of the weights tensor in the 4th dimension
1846 */
__kernel void direct_convolution3x3_f32_bifrost(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    TENSOR3D_DECLARATION(weights),
#ifdef HAS_BIAS
    VECTOR_DECLARATION(biases),
#endif /* defined(HAS_BIAS) */
    unsigned int weights_stride_w)
{
    // Get the kernel index (the output feature map computed by this work-item)
    const int kernel_index = get_global_id(2);

    Image    src = CONVERT_TO_IMAGE_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // One float4 accumulator per output row: this work-item produces a 4 (wide) x 3 (tall) output tile
    float4 values0 = 0;
    float4 values1 = 0;
    float4 values2 = 0;

    // Start of the selected 3x3xWEIGHTS_DEPTH filter; base input address for this work-item
    __global uchar *weights_addr = (__global uchar *)(weights_ptr + weights_offset_first_element_in_bytes + kernel_index * weights_stride_w);
    __global uchar *src_addr     = (__global uchar *)offset(&src, 0, 0);

    // Note: Since each work-item computes 4x3 elements, we need to load 5 rows from the input tensor

    for(ushort d = 0; d < (ushort)WEIGHTS_DEPTH; ++d)
    {
        // Load the 3x3 weights of the current input channel
        float3 weights_row0 = vload3(0, (__global float *)(weights_addr + 0 * weights_stride_y));
        float3 weights_row1 = vload3(0, (__global float *)(weights_addr + 1 * weights_stride_y));
        float3 weights_row2 = vload3(0, (__global float *)(weights_addr + 2 * weights_stride_y));
        // src0 holds input pixels x+0..x+3, src1 the two extra pixels x+4..x+5 needed by the 3-tap filter
        float4 src0;
        float2 src1;

        // Load values from row0 of input tensor
        src0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
        src1 = vload2(0, (__global float *)(src_addr + 0 * src_stride_y) + 4);

        // Input row0 only contributes to output row0 (via filter row0)
        CONVOLUTION1x3_BIFROST(values0, src0, src1, weights_row0);

        // Load values from row1 of input tensor
        src0 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
        src1 = vload2(0, (__global float *)(src_addr + 1 * src_stride_y) + 4);

        // Accumulate: input row1 feeds output row0 (filter row1) and output row1 (filter row0)
        CONVOLUTION1x3_BIFROST(values0, src0, src1, weights_row1);
        CONVOLUTION1x3_BIFROST(values1, src0, src1, weights_row0);

        // Load values from row2 of input tensor
        src0 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
        src1 = vload2(0, (__global float *)(src_addr + 2 * src_stride_y) + 4);

        // Accumulate: input row2 feeds all three output rows
        CONVOLUTION1x3_BIFROST(values0, src0, src1, weights_row2);
        CONVOLUTION1x3_BIFROST(values1, src0, src1, weights_row1);
        CONVOLUTION1x3_BIFROST(values2, src0, src1, weights_row0);

        // Load values from row3 of input tensor
        src0 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
        src1 = vload2(0, (__global float *)(src_addr + 3 * src_stride_y) + 4);

        // Accumulate: input row3 feeds output row1 (filter row2) and output row2 (filter row1)
        CONVOLUTION1x3_BIFROST(values1, src0, src1, weights_row2);
        CONVOLUTION1x3_BIFROST(values2, src0, src1, weights_row1);

        // Row4: last input row, contributes only to output row2 (via filter row2)
        src0 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
        src1 = vload2(0, (__global float *)(src_addr + 4 * src_stride_y) + 4);

        // Accumulate
        CONVOLUTION1x3_BIFROST(values2, src0, src1, weights_row2);

        // Step to the next input channel / weights channel
        src_addr += src_stride_z;
        weights_addr += weights_stride_z;
    }

#ifdef HAS_BIAS
    Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);

    // Dereference the per-OFM bias scalar and broadcast it over the whole 4x3 output tile
    float bias = (float) * ((__global float *)(vector_offset(&biases, kernel_index)));

    values0 += (float4)bias;
    values1 += (float4)bias;
    values2 += (float4)bias;
#endif /* defined(HAS_BIAS) */

    // Store the three output rows of the 4x3 tile
    vstore4(values0, 0, (__global float *)(dst.ptr + 0 * dst_stride_y));
    vstore4(values1, 0, (__global float *)(dst.ptr + 1 * dst_stride_y));
    vstore4(values2, 0, (__global float *)(dst.ptr + 2 * dst_stride_y));
}
1936#endif // defined(WEIGHTS_DEPTH)
1937
1938)"