1R"(
2
3/*
4 * Copyright (c) 2020 Arm Limited.
5 *
6 * SPDX-License-Identifier: MIT
7 *
8 * Permission is hereby granted, free of charge, to any person obtaining a copy
9 * of this software and associated documentation files (the "Software"), to
10 * deal in the Software without restriction, including without limitation the
11 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
12 * sell copies of the Software, and to permit persons to whom the Software is
13 * furnished to do so, subject to the following conditions:
14 *
15 * The above copyright notice and this permission notice shall be included in all
16 * copies or substantial portions of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 * SOFTWARE.
25 */
26/*
27 * Copyright (c) 2019-2020 Arm Limited.
28 *
29 * SPDX-License-Identifier: MIT
30 *
31 * Permission is hereby granted, free of charge, to any person obtaining a copy
32 * of this software and associated documentation files (the "Software"), to
33 * deal in the Software without restriction, including without limitation the
34 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
35 * sell copies of the Software, and to permit persons to whom the Software is
36 * furnished to do so, subject to the following conditions:
37 *
38 * The above copyright notice and this permission notice shall be included in all
39 * copies or substantial portions of the Software.
40 *
41 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
44 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
45 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
46 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
47 * SOFTWARE.
48 */
49/*
50 * Copyright (c) 2019-2020 Arm Limited.
51 *
52 * SPDX-License-Identifier: MIT
53 *
54 * Permission is hereby granted, free of charge, to any person obtaining a copy
55 * of this software and associated documentation files (the "Software"), to
56 * deal in the Software without restriction, including without limitation the
57 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
58 * sell copies of the Software, and to permit persons to whom the Software is
59 * furnished to do so, subject to the following conditions:
60 *
61 * The above copyright notice and this permission notice shall be included in all
62 * copies or substantial portions of the Software.
63 *
64 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
65 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
66 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
67 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
68 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
69 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
70 * SOFTWARE.
71 */
72
/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_HELPER_H
#define ARM_COMPUTE_HELPER_H

/** Store the 0th to (n-1)th rows of the given variables
 * @name STORE_ROW_n
 *
 * @param[in] N0        The width of the passed-in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                 \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                  \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n

/** Convert and store the 0th to (n-1)th rows of the given variables
 * @name CONVERT_STORE_ROW_n
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));

/** @} */ // end of group CONVERT_STORE_ROW_n

/** Store a block of the given size M0xN0
 * @name STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
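
// Illustrative expansion (hypothetical names, sketched by the editor): assuming
// a kernel with float4 accumulators c0 and c1, z offsets zout0/zout1, and a
// destination pointer dst_addr with row stride dst_stride_y,
// STORE_BLOCK(2, 4, float, c, dst_addr, dst_stride_y, zout) expands to:
//   vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout0));
//   vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout1));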

/** Convert and store a block of the given size M0xN0
 * @name CONVERT_STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK
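
// Illustrative expansion (hypothetical names): with an int4 accumulator c0, and
// assuming CONVERT_SAT(x, type) maps to the usual convert_<type>_sat(x),
// CONVERT_STORE_BLOCK(1, 4, uchar, c, dst_addr, dst_stride_y, zout) saturates
// the accumulator down to uchar4 before storing, roughly:
//   vstore4(convert_uchar4_sat((c0)), 0, (__global uchar *)(dst_addr + 0 * dst_stride_y + zout0));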

/** Partially store the 0th to (n-1)th rows of the given variables
 * @name STORE_ROW_PARTIAL_n
 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
 *
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * @param[in] N0        The width of the passed-in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** number of elements to store. Supported: 1-16 and <= @p N0
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n

/** Partially store a block of the given size STORE_M0xSTORE_N0
 * @name STORE_BLOCK_PARTIAL
 *
 * @note The vector width @p N0 is also required for correct partial storing behaviour.
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** Store a block that can be partial in both x and y dimensions
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y))                                                                                                            \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                                           \
    }                                                                                                                                                     \
    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else                                                                                                                                                  \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                               \
    }
/** Store a block that can only be partial in x but not y.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X))                                                                                         \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
    }                                                                                                             \
    else                                                                                                          \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
    }
/** Store a block that can only be partial in y but not x.
 *
 * @note in case @p N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 */
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
    if(!(PARTIAL_COND_Y))                                                                                         \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
    }                                                                                                             \
    else                                                                                                          \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
    }
/** @} */ // end of group STORE_BLOCK_PARTIAL
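
// Illustrative use (hypothetical names): store a 4x8 float block that may be
// cut short in x only. cond_x should be true for work-items whose block sticks
// out past the tensor width, in which case only PARTIAL_STORE_N0 columns of
// each row are written:
//   STORE_BLOCK_PARTIAL_IN_X(4, 8, float, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_N0, cond_x);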

#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

/** Boundary-aware GEMM block store
 * @name STORE_BLOCK_BOUNDARY_AWARE
 * This macro assumes the following schemes to achieve boundary-awareness:
 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 *  - Non-overlapping (normal) load from rhs tensor. This implies rhs can have paddings.
 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
 * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
 *
 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
 * blocks **at the end**.
 * Say the dst tensor is of shape MxN and M0 and N0 are the block sizes; this is how we define "partial blocks"/
 * "boundary blocks" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and their various parameters:
 *
 *  *--x-->                         x == 0                        x == 1
 *  |                  |<------------------------------N-------------------------->|
 *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
 *  |     -------------#############################################################
 *  *     |          | |...............................|...........................|
 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
 *        |          | |...............................|...........................|
 *        M          --#############################################################
 *        |          | |                               |...........................|
 * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
 *        |          | |...............................|...........................|
 *        |------------#############################################################
 *
 * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
 * and selects the corresponding store method such that the boundary detection logic is only added when needed.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 * @{
 */
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)

#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)

#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)

#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case4: Partial blocks in both x and y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)

#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0

#endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE

#if defined(PARTIAL_STORE_M0)
/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
 * @name COMPUTE_M0_START_ROW
 * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
 * blocks in the y dimension to avoid any padding.
 * E.g., M0=4, PARTIAL_STORE_M0=1:
 *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
 * block 0 (partial)| start row = 0   | start row = 0
 * block 1 (full)   | start row = 4   | start row = 1
 * block 2 (full)   | start row = 8   | start row = 5
 *
 * @param[in] y                Global id of current block in y.
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @{
 */
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
#else // defined(PARTIAL_STORE_M0)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(y * M0))
#endif    // defined(PARTIAL_STORE_M0)
/** @} */ // end of group COMPUTE_M0_START_ROW
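
// Worked example for the table above: with M0 = 4 and PARTIAL_STORE_M0 = 1 the
// shift is (M0 - PARTIAL_STORE_M0) % M0 = 3, so:
//   COMPUTE_M0_START_ROW(0, 4, 1) == max(0, 0 - 3) == 0 // partial block
//   COMPUTE_M0_START_ROW(1, 4, 1) == max(0, 4 - 3) == 1 // full block
//   COMPUTE_M0_START_ROW(2, 4, 1) == max(0, 8 - 3) == 5 // full block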

/** Store a vector that can only be partial in x.
 *
 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to end in a 0.
 * E.g., for basename=c, the expected name is c0.
 *
 * @param[in] basename  The name of the variable without trailing 0
 * @param[in] data_type The data type of the vector
 * @param[in] ptr       The base pointer
 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
 * @param[in] cond      Condition to select either @p vec_size or @p leftover
 * @{
 */
#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
/** @} */ // end of group STORE_VECTOR_SELECT
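
// Illustrative use (hypothetical names): store a float16 accumulator acc0 to
// out_ptr, writing only the lower leftover_x elements for the work-item at the
// x boundary. Note that leftover must expand to an integer literal (e.g. a -D
// build option), since it is token-pasted into a vstore_partial_* name:
//   STORE_VECTOR_SELECT(acc, float, out_ptr, 16, leftover_x, is_last_x);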

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)

#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
#pragma OPENCL EXTENSION cl_arm_printf : enable
#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)

#define GPU_ARCH_MIDGARD 0x100
#define GPU_ARCH_BIFROST 0x200

/** Concatenate two inputs.
 *
 * @param[in] a The first input to be concatenated
 * @param[in] b The second input to be concatenated
 *
 * @return The concatenated output
 */
#define CONCAT(a, b) a##b

/** Expand the given vector
 *
 * @param[in] x The vector to be expanded
 *
 * @return The expanded output
 */
#define EXPAND(x) x

/** Clamp the given value between an upper and lower bound.
 *
 * @param[in] x       The value to be clamped
 * @param[in] min_val The lower bound
 * @param[in] max_val The upper bound
 *
 * @return The clamped value.
 */
#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)

/** REVn reverses the given vector whose size is n.
 * @name REVn
 *
 * @param[in] x The vector to be reversed
 *
 * @return The reversed vector
 * @{
 */
#define REV1(x) ((x))
#define REV2(x) ((x).s10)
#define REV3(x) ((x).s210)
#define REV4(x) ((x).s3210)
#define REV8(x) ((x).s76543210)
#define REV16(x) ((x).sFEDCBA9876543210)
/** @} */ // end of group REVn

/** Reverse the given vector.
 * @name REVERSE
 *
 * @param[in] x The vector to be reversed
 * @param[in] s The size of the vector
 *
 * @return The reversed vector
 * @{
 */
#define REVERSE_STR(x, s) REV##s((x))
#define REVERSE(x, s) REVERSE_STR(x, s)
/** @} */ // end of group REVERSE
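
// Illustrative expansion: REVERSE(x, 4) selects REV4, i.e. ((x).s3210), so for
// x = (int4)(1, 2, 3, 4) the result is (int4)(4, 3, 2, 1).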

/** Circular-right-shift (rotate-right) the vector of size s by the amount of n.
 * @name ROTs_n
 *
 * @param[in] x The vector to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROT1_0(x) ((x))

#define ROT2_0(x) ((x))
#define ROT2_1(x) ((x).s10)

#define ROT3_0(x) ((x))
#define ROT3_1(x) ((x).s201)
#define ROT3_2(x) ((x).s120)

#define ROT4_0(x) ((x))
#define ROT4_1(x) ((x).s3012)
#define ROT4_2(x) ((x).s2301)
#define ROT4_3(x) ((x).s1230)

#define ROT8_0(x) ((x))
#define ROT8_1(x) ((x).s70123456)
#define ROT8_2(x) ((x).s67012345)
#define ROT8_3(x) ((x).s56701234)
#define ROT8_4(x) ((x).s45670123)
#define ROT8_5(x) ((x).s34567012)
#define ROT8_6(x) ((x).s23456701)
#define ROT8_7(x) ((x).s12345670)

#define ROT16_0(x) ((x))
#define ROT16_1(x) ((x).sF0123456789ABCDE)
#define ROT16_2(x) ((x).sEF0123456789ABCD)
#define ROT16_3(x) ((x).sDEF0123456789ABC)
#define ROT16_4(x) ((x).sCDEF0123456789AB)
#define ROT16_5(x) ((x).sBCDEF0123456789A)
#define ROT16_6(x) ((x).sABCDEF0123456789)
#define ROT16_7(x) ((x).s9ABCDEF012345678)
#define ROT16_8(x) ((x).s89ABCDEF01234567)
#define ROT16_9(x) ((x).s789ABCDEF0123456)
#define ROT16_10(x) ((x).s6789ABCDEF012345)
#define ROT16_11(x) ((x).s56789ABCDEF01234)
#define ROT16_12(x) ((x).s456789ABCDEF0123)
#define ROT16_13(x) ((x).s3456789ABCDEF012)
#define ROT16_14(x) ((x).s23456789ABCDEF01)
#define ROT16_15(x) ((x).s123456789ABCDEF0)
/** @} */ // end of group ROTs_n

/** Circular-right-shift (rotate-right) the given vector by the given amount.
 * @name ROTATE
 *
 * @param[in] x The vector to be shifted
 * @param[in] s The size of the vector
 * @param[in] n The amount to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
/** @} */ // end of group ROTATE
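
// Illustrative expansion: ROTATE(x, 4, 1) selects ROT4_1, i.e. ((x).s3012), so
// for x = (int4)(1, 2, 3, 4) the rotate-right-by-1 result is (int4)(4, 1, 2, 3).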

/** Creates a vector of size n filled with offset values corresponding to the location of each element.
 * @name V_OFFSn
 *
 * @param[in] dt The data type of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define V_OFFS1(dt) (dt##1)(0)
#define V_OFFS2(dt) (dt##2)(0, 1)
#define V_OFFS3(dt) (dt##3)(0, 1, 2)
#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
/** @} */ // end of group V_OFFSn

/** Create a vector filled with offset values corresponding to the location of each element.
 * @name VEC_OFFS
 *
 * @param[in] dt The data type of the output vector
 * @param[in] s  The size of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
/** @} */ // end of group VEC_OFFS
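
// Illustrative expansion: VEC_OFFS(int, 4) selects V_OFFS4(int), i.e.
// (int4)(0, 1, 2, 3), handy as a per-lane index vector.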

#define VLOAD_STR(size) vload##size
#define VLOAD(size) VLOAD_STR(size)

#define PIXEL_UNIT4 1
#define PIXEL_UNIT8 2
#define PIXEL_UNIT16 4

/** Utility macro to convert a vector size to a pixel unit.
 *
 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
 *
 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
 *
 * @return The pixel unit (number of pixels)
 * @{
 */
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
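
// Illustrative expansion: CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(8) selects
// PIXEL_UNIT8, i.e. 2: a vector of 8 elements spans two 4-channel (RGBA) texels.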

#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

/** Utility macro to read a 2D OpenCL image object.
 *
 * @note Coordinates are not normalized
 *
 * @param[in] data_type Data type
 * @param[in] n0        Number of pixels to read. Only 1, 2 and 4 are supported
 * @param[in] img       OpenCL image object
 * @param[in] x_coord   The x coordinate for the top-left pixel
 * @param[in] y_coord   The y coordinate for the top-left pixel
 *
 * @return Pixels from the 2D OpenCL image object
 * @{
 */
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
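
// Illustrative use (hypothetical names): read 4 consecutive RGBA texels, i.e.
// 16 floats, starting at pixel coordinates (x, y) of a 2D image:
//   float16 vals = READ_IMAGE2D(float, 4, img, x, y);
// which resolves to read_image2d_floatx4(img, x, y).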

#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)

#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double

#define vload1(OFFSET, PTR) *(OFFSET + PTR)
#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA

/** Extended partial vstore that correctly handles scalar values as well.
 * Store the **lower** 0th to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name VSTORE_PARTIAL
 *
 * @note With this macro, the passed data can be both a vector and a scalar
 * @note @p store_size needs to be <= @p size
 * eg 1: Valid
 * VSTORE_PARTIAL(16, 15) ...;
 * eg 2: Invalid
 * VSTORE_PARTIAL(4, 7) ...;
 *
 * @param[in] size       The width of @p DATA. Supported values: 1 (scalar), 2, 3, 4, 8, 16
 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
 * @{
 */
#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)
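
// Illustrative expansion: VSTORE_PARTIAL(8, 5)(data, 0, ptr) selects
// vstore_partial_8_5, an alias of vstore_partial_5, which writes the lower 5
// elements with the fewest vstores, effectively:
//   vstore4(data.s0123, 0, ptr);
//   vstore1(data.s4, 0, ptr + 4);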
934
#define NO_STORE(data, offs, ptr) \
    {                             \
    }

// Size == 1 (scalar)
#define vstore_partial_1_0 NO_STORE
#define vstore_partial_1_1 vstore1
#define vstore_partial_1_2 NO_STORE
#define vstore_partial_1_3 NO_STORE
#define vstore_partial_1_4 NO_STORE
#define vstore_partial_1_5 NO_STORE
#define vstore_partial_1_6 NO_STORE
#define vstore_partial_1_7 NO_STORE
#define vstore_partial_1_8 NO_STORE
#define vstore_partial_1_9 NO_STORE
#define vstore_partial_1_10 NO_STORE
#define vstore_partial_1_11 NO_STORE
#define vstore_partial_1_12 NO_STORE
#define vstore_partial_1_13 NO_STORE
#define vstore_partial_1_14 NO_STORE
#define vstore_partial_1_15 NO_STORE
#define vstore_partial_1_16 NO_STORE
// Size == 2
#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE
// Size == 3
#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE
// Size == 4
#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE
// Size == 8
#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE
// Size == 16
#define vstore_partial_16_0 NO_STORE
#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16
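
// Note (illustrative sketch): these tables act as a two-level dispatch. A wrapper macro,
// assumed to be defined earlier in this file roughly as
//   #define VSTORE_PARTIAL(size, store_size) vstore_partial_##size##_##store_size
// token-pastes both sizes, so e.g. VSTORE_PARTIAL(4, 3) resolves to vstore_partial_4_3 and
// hence to vstore_partial_3 below, while any unsupported combination (store size larger
// than the vector width) resolves to NO_STORE, i.e. an empty statement block.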

/** Partial vstore. Store the **lower** n elements (the 0th to the (n-1)th) of the given vector while minimising the number of vstore ops
 * @name vstore_partial_n
 *
 * @note @p DATA needs to be a vector, not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * e.g. 1: Valid
 * vstore_partial_15(var:float16, 0, 0xabcd);
 * e.g. 2: Invalid
 * vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note In the cases n == 1, 2, 3, 4, 8, 16, a single native vstore is used, so there is no performance penalty.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset, in units of n elements (as in the corresponding vstore built-ins)
 * @param[in] PTR    The base pointer
 * @{
 */
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

#define vstore_partial_5(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

#define vstore_partial_9(DATA, OFFSET, PTR)        \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

#define vstore_partial_13(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
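
// Example (expansion sketch): non-power-of-two sizes are split into the largest supported
// vstores, e.g.
//   vstore_partial_7(data, 0, ptr);
// effectively expands to
//   vstore4(data.s0123, 0, ptr);
//   vstore3(data.s456, 0, ptr + 4);
// i.e. two vector stores instead of seven scalar ones.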

// The convert_* built-in functions do not support the _sat modifier for floating-point
// destination types, so we define _sat aliases that map to the plain conversions
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_half
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat
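
// Example (usage sketch): thanks to the aliases above, the CONVERT_SAT macro defined below
// can be used uniformly for integer and floating-point destinations, e.g.
//   CONVERT_SAT(x, uchar4) // -> convert_uchar4_sat(x), a genuine saturating conversion
//   CONVERT_SAT(x, float4) // -> convert_float4_sat(x) -> convert_float4(x)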

#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)

#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size

#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)
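
// Example (usage sketch): select() needs an integer mask type whose element width matches
// the data type, which is what SELECT_VEC_DATA_TYPE provides, e.g. for float data:
//   VEC_DATA_TYPE(float, 4) a, b;
//   SELECT_VEC_DATA_TYPE(float, 4) mask = isless(a, b); // int4, as per the table above
//   VEC_DATA_TYPE(float, 4) r = select(b, a, mask);     // per-element min(a, b)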

#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) ((x).s0) + ((x).s1)
#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2)
#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23)
#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567)
#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF)

#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
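
// Example (expansion sketch): both reductions unfold into pairwise trees over the vector
// components, e.g. for a float4 v:
//   SUM_REDUCE(v, 4) // -> v.s0 + v.s1 + v.s2 + v.s3 (after swizzle simplification)
//   MAX_REDUCE(v, 4) // -> max(max(v.s0, v.s1), max(v.s2, v.s3))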

#define VECTOR_DECLARATION(name)     \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_offset_first_element_in_bytes

#define IMAGE_DECLARATION(name)      \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR3D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR4D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_stride_w, \
    uint        name##_step_w,   \
    uint        name##_offset_first_element_in_bytes
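
// Example (usage sketch, hypothetical kernel name): each *_DECLARATION macro expands to the
// full OpenCL parameter list for one tensor, so a simple copy kernel could be written as
//   __kernel void example_copy(IMAGE_DECLARATION(src), IMAGE_DECLARATION(dst))
//   {
//       Image src = CONVERT_TO_IMAGE_STRUCT(src);
//       Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
//       *dst.ptr  = *src.ptr;
//   }
// where the conversion macros reference src_ptr, src_stride_x, ... generated by IMAGE_DECLARATION(src).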

#define CONVERT_TO_VECTOR_STRUCT(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)

#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)

#define CONVERT_TO_IMAGE_STRUCT(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)

#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)

#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)

#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                                                                                       \
    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                           name##_stride_z, name##_step_z)

/** Structure to hold Vector information */
typedef struct Vector
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source vector */
    int             stride_x;                      /**< Stride of the vector in X dimension (in bytes) */
} Vector;

/** Structure to hold Image information */
typedef struct Image
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
} Image;

/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
} Tensor3D;

/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
    int             stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
} Tensor4D;

/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 *
 * @return A vector object
 */
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
    Vector vector =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
    };
    vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
    return vector;
}

/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
    return img;
}

/** Wrap 3D tensor information into an Image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return img;
}

/** Wrap 3D tensor information into a Tensor3D structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return tensor;
}

/** Wrap 3D tensor information into a Tensor3D structure.
 *
 * @note The step arguments are accepted only to keep the interface uniform with the other wrappers; the pointer is not advanced per workitem.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    return tensor;
}

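/** Wrap 4D tensor information into a Tensor4D structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] stride_w                      Stride of the tensor in W dimension (in bytes)
 * @param[in] step_w                        stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in] mod_size                      The number of Z slices per W slice, used to split get_global_id(2) into z and w
 *
 * @return A 4D tensor object
 */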
inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
                                             uint step_w,
                                             uint mod_size)
{
    Tensor4D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z,
        .stride_w                      = stride_w
    };

    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
    return tensor;
}

/** Get the pointer position of a Vector
 *
 * @param[in] vec Pointer to the Vector structure
 * @param[in] x   Relative X position
 */
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
    return vec->ptr + x * vec->stride_x;
}

/** Get the pointer position of an Image
 *
 * @param[in] img Pointer to the Image structure
 * @param[in] x   Relative X position
 * @param[in] y   Relative Y position
 */
inline __global uchar *offset(const Image *img, int x, int y)
{
    return img->ptr + x * img->stride_x + y * img->stride_y;
}

/** Get the pointer position of a Tensor3D
 *
 * @param[in] tensor Pointer to the Tensor3D structure
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 */
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
}

/** Get the pointer position of a Tensor4D
 *
 * @param[in] tensor Pointer to the Tensor4D structure
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 * @param[in] w      Relative W position
 */
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
}

/** Get the offset for a given linear index of a Tensor3D
 *
 * @param[in] tensor Pointer to the Tensor3D structure
 * @param[in] width  Width of the input tensor
 * @param[in] height Height of the input tensor
 * @param[in] depth  Depth of the input tensor
 * @param[in] index  Linear index
 */
inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
{
    uint num_elements = width * height;

    const uint z = index / num_elements;

    index %= num_elements;

    const uint y = index / width;

    index %= width;

    const uint x = index;

    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
}
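
// Example (worked arithmetic): for width = 4, height = 3 and index = 17:
//   num_elements = 12, z = 17 / 12 = 1, index %= 12 -> 5,
//   y = 5 / 4 = 1, index %= 4 -> 1, x = 1,
// i.e. the returned pointer addresses element (x=1, y=1, z=1).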

#endif // _HELPER_H

#if GPU_ARCH == GPU_ARCH_BIFROST
#define MLA(a, b, c) (fma(c, b, a))
#else // GPU_ARCH == GPU_ARCH_BIFROST
#define MLA(a, b, c) ((b) * (c) + (a))
#endif // GPU_ARCH == GPU_ARCH_BIFROST
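
// Example (expansion sketch): both branches compute a + b * c; on Bifrost it maps to a
// single fused multiply-add:
//   MLA(acc, w, x) // -> fma(x, w, acc) on Bifrost, ((w) * (x) + (acc)) otherwise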

// Hard-Swish
#define hard_swish_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x * ((min(max((x + (DATA_TYPE)3.0), (DATA_TYPE)0.0), (DATA_TYPE)6.0)) * (DATA_TYPE)0.166666667))

// Logistic Activation
#define logistic_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((DATA_TYPE)1.0 / ((DATA_TYPE)1.0 + exp(-x)))

// Hyperbolic Tangent Activation
#define tanh_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((DATA_TYPE)A_VAL * tanh((DATA_TYPE)B_VAL * x))

// RELU Activation
#define relu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (max((DATA_TYPE)0.0, x))

// Bounded RELU Activation
#define brelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (min((DATA_TYPE)A_VAL, max((DATA_TYPE)0.0, x)))

// Lower Upper Bounded RELU Activation
#define lu_brelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (min(max(x, (DATA_TYPE)B_VAL), (DATA_TYPE)A_VAL))

// Leaky RELU Activation
#define lrelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((min(x, (DATA_TYPE)0.0) * (DATA_TYPE)A_VAL) + max(x, (DATA_TYPE)0.0))

// Soft RELU Activation
#define srelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (log((DATA_TYPE)1.0 + exp(x)))

// ELU Activation
#define elu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (select(((DATA_TYPE)A_VAL * (exp(x) - (DATA_TYPE)1.0)), x, (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))isgreaterequal(x, (DATA_TYPE)0.0)))

// Absolute Activation
#define abs_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (fabs(x))

// Square Activation
#define square_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x * x)

// Square-root Activation
#define sqrt_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (sqrt(x))

// Linear Activation
#define linear_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (MLA((DATA_TYPE)B_VAL, (DATA_TYPE)A_VAL, x))

// Identity Activation
#define identity_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x)

#define ACT_OP(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) op##_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL)

#define ACTIVATION(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ACT_OP(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL)
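
// Example (expansion sketch): ACTIVATION pastes the op name onto the _op suffix, e.g.
//   ACTIVATION(relu, float, 4, x, A_VAL, B_VAL)   // -> relu_op(...) -> (max((float)0.0, x))
//   ACTIVATION(lu_brelu, float, 4, x, 6.0f, 0.0f) // -> (min(max(x, (float)0.0f), (float)6.0f))
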
/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_HELPER_H
#define ARM_COMPUTE_HELPER_H

/*
 * Copyright (c) 2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

/** Store the 0 to (n-1)th rows of the given variables
 * @name STORE_ROW_n
 *
 * @param[in] N0        The width of the passed-in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                 \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                  \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n

/** Convert and store the 0th to (n-1)th rows of the given variables
 * @name CONVERT_STORE_ROW_n
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));

/** @} */ // end of group CONVERT_STORE_ROW_n

/** Store a block of the given size M0xN0
 * @name STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
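
// Example (expansion sketch, assuming VSTORE(4) resolves to vstore4 as defined earlier in
// this file):
//   STORE_BLOCK(2, 4, float, c, ptr, stride_y, zin)
// resolves to STORE_ROW_2(4, float, c, ptr, stride_y, zin), i.e.
//   vstore4(c0, 0, (__global float *)(ptr + 0 * stride_y + zin0));
//   vstore4(c1, 0, (__global float *)(ptr + 1 * stride_y + zin1));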

/** Convert and store a block of the given size M0xN0
 * @name CONVERT_STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK

/** Partially store the 0 to (n-1)th rows of the given variables
 * @name STORE_ROW_PARTIAL_n
 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
 *
 * @note In case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * @param[in] N0        The width of the passed-in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** number of elements to store. Supported: 1-16 and <= @p N0
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n

/** Partially store a block of the given size STORE_M0xSTORE_N0
 * @name STORE_BLOCK_PARTIAL
 *
 * @note The vector width @p N0 is also required for correct partial storing behaviour.
 * @note In case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** Store a block that can be partial in both x and y dimensions
 *
 * @note In case @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y))                                                                                                            \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                                           \
    }                                                                                                                                                     \
    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else                                                                                                                                                  \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                               \
    }
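
// Example (usage sketch, hypothetical values): callers typically derive the partial sizes
// and boundary conditions from the workitem coordinates before invoking the macro, e.g.
//   const bool at_x_edge = (get_global_id(0) == get_global_size(0) - 1) && (PARTIAL_N0 != 0);
//   const bool at_y_edge = (get_global_id(1) == get_global_size(1) - 1) && (PARTIAL_M0 != 0);
//   STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
//                                  PARTIAL_M0, PARTIAL_N0, at_y_edge, at_x_edge);
// where PARTIAL_M0 / PARTIAL_N0 would be the leftover block sizes (e.g. M % M0, N % N0).
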
2023/** Store a block that can only be partial in x but not y.
2024 *
2025 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
2026 *
2027 * The data to store is expected to have consecutive names for each row.
2028 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
2029 * The Z offset is expected to have consecutive names.
2030 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
2031 *
2032 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
2033 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
2034 * @param[in] DATA_TYPE        The data type of the vectors
2035 * @param[in] BASENAME         The basename of the variables
2036 * @param[in] PTR              The base pointer
2037 * @param[in] STRIDE_Y         The stride value in y-axis direction
2038 * @param[in] Z                The offset in z-axis direction
2039 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
2040 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
2041 */
2042#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
2043    if(!(PARTIAL_COND_X))                                                                                         \
2044    {                                                                                                             \
2045        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
2046    }                                                                                                             \
2047    else                                                                                                          \
2048    {                                                                                                             \
2049        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
2050    }
2051/** Store a block that can only be partial in y but not x.
2052 *
 * @note in case @p N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
2054 *
2055 * The data to store is expected to have consecutive names for each row.
2056 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
2057 * The Z offset is expected to have consecutive names.
2058 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
2059 *
2060 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
2061 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
2062 * @param[in] DATA_TYPE        The data type of the vectors
2063 * @param[in] BASENAME         The basename of the variables
2064 * @param[in] PTR              The base pointer
2065 * @param[in] STRIDE_Y         The stride value in y-axis direction
2066 * @param[in] Z                The offset in z-axis direction
2067 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
2068 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
2069 */
2070#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
2071    if(!(PARTIAL_COND_Y))                                                                                         \
2072    {                                                                                                             \
2073        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
2074    }                                                                                                             \
2075    else                                                                                                          \
2076    {                                                                                                             \
2077        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
2078    }
2079/** @} */ // end of group STORE_BLOCK_PARTIAL
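
// Usage sketch (editor's addition, not part of the original header). The kernel and the N0 and
// LEFTOVER names are hypothetical, e.g. set at build time via -DN0=4 -DLEFTOVER=3. It shows the
// intended pattern: full vstores everywhere except the right-most block along x, selected by a
// runtime condition. Kept under #if 0 so it is never compiled.
#if 0
__kernel void example_store_partial_in_x(__global uchar *dst_ptr, int dst_w)
{
    VEC_DATA_TYPE(float, N0) c0 = 0.0f; // accumulator row (M0 = 1), named <basename>0 as required
    const int  x         = get_global_id(0) * N0;
    const bool on_border = (x + N0) > dst_w; // true only for the last block along x
    __global uchar *ptr  = dst_ptr + x * sizeof(float);
    // Expands to a full N0-wide vstore, or to the LEFTOVER-element partial store at the border
    STORE_BLOCK_PARTIAL_IN_X(1, N0, float, c, ptr, 0, 0, LEFTOVER, on_border);
}
#endif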
2080
2081#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
2082
2083/** Boundary-aware GEMM block store
2084 * @name STORE_BLOCK_BOUNDARY_AWARE
2085 * This macro assumes the following schemes to achieve boundary-awareness:
2086 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 *  - Non-overlapping (normal) load from rhs tensor. This implies rhs can have padding.
 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
 * The macro then ensures that the dst tensor can be stored without any padding in both x and y dim.
2090 *
2091 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
2092 * blocks **at the end**.
 * Say the dst tensor is of shape MxN and we have M0 and N0 as the block size; this is how we define "partial blocks"/
 * "boundary blocks" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and their various parameters:
2095 *
2096 *  *--x-->                         x == 0                        x == 1
2097 *  |                  |<------------------------------N-------------------------->|
2098 *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
2099 *  |     -------------#############################################################
2100 *  *     |          | |...............................|...........................|
2101 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
2102 *        |          | |...............................|...........................|
2103 *        M          --#############################################################
2104 *        |          | |                               |...........................|
2105 * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
2106 *        |          | |                               |...........................|
2107 *        |------------#############################################################
2108 *
2109 * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
2110 *
 * @note in case @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
2112 *
 * It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
 * and selects the corresponding store method such that the boundary detection logic is only added when needed.
2115 *
2116 * The data to store is expected to have consecutive names for each row.
2117 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
2118 * The Z offset is expected to have consecutive names.
2119 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
2120 *
2121 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
2122 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
2123 * @param[in] DATA_TYPE        The data type of the vectors
2124 * @param[in] BASENAME         The basename of the variables
2125 * @param[in] PTR              The base pointer
2126 * @param[in] STRIDE_Y         The stride value in y-axis direction
2127 * @param[in] Z                The offset in z-axis direction
2128 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
2129 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
2130 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
2131 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
2132 * @{
2133 */
2134#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
2135// Case1: No partial blocks in either x or y
2136#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
2137    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
2138
2139#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
2140// Case2: Partial blocks in y
2141#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
2142    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)
2143
2144#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
2145// Case3: Partial blocks in x
2146#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
2147    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)
2148
2149#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
2150// Case4: Partial blocks in both x and y
2151#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
2152    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)
2153
2154#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
2155
2156#endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
2157/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
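
// Worked example (editor's addition): for M = 10, N = 9, M0 = 4, N0 = 4, the host would pass
// -DPARTIAL_STORE_M0=2 -DPARTIAL_STORE_N0=1 (i.e. M % M0 and N % N0), so the #elif chain above
// resolves STORE_BLOCK_BOUNDARY_AWARE to STORE_BLOCK_PARTIAL_IN_X_AND_Y, whose two runtime
// conditions PARTIAL_COND_Y / PARTIAL_COND_X then pick one of its four stores per block.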
2158
2159#if defined(PARTIAL_STORE_M0)
2160/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
2161 * @name COMPUTE_M0_START_ROW
 * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
2163 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
2164 * blocks in the y dimension to avoid any padding.
 * E.g., M0=4, PARTIAL_STORE_M0=1:
2166 *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
2167 * block 0 (partial)| start row = 0   | start row = 0
2168 * block 1 (full)   | start row = 4   | start row = 1
2169 * block 2 (full)   | start row = 8   | start row = 5
2170 *
2171 * @param[in] y                Global id of current block in y.
2172 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
2173 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
2174 * @{
2175 */
2176#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
2177    ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
2178#else // defined(PARTIAL_STORE_M0)
2179#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
2180    ((uint)(y * M0))
2181#endif    // defined(PARTIAL_STORE_M0)
2182/** @} */ // end of group COMPUTE_M0_START_ROW
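
// Worked check (editor's addition): with M0 = 4 and PARTIAL_STORE_M0 = 1 as in the table above, the
// shift is (M0 - PARTIAL_STORE_M0) % M0 = 3, so:
//   y = 0: max(0, 0 - 3) = 0   (partial block, kept at the beginning)
//   y = 1: max(0, 4 - 3) = 1   (full block, overlapping the partial one)
//   y = 2: max(0, 8 - 3) = 5
// matching the "+M0_ROW_SHIFT (Overlapping)" column documented above.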
2183
2184/** Store a vector that can only be partial in x.
2185 *
 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
2187 *
 * The name of the data to store is expected to end in 0.
2189 * E.g., for basename=c, the expected name is c0.
2190 *
2191 * @param[in] basename  The name of the variable without trailing 0
2192 * @param[in] data_type The data type of the vector
2193 * @param[in] ptr       The base pointer
2194 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
 * @param[in] cond      Condition to select the partial store. True to store @p leftover elements rather than @p vec_size.
2197 * @{
2198 */
2199#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
2200    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
2201/** @} */ // end of group STORE_VECTOR_SELECT
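
// Expansion sketch (editor's addition, hypothetical names 'acc', 'dst_addr', 'is_last'):
//   STORE_VECTOR_SELECT(acc, float, dst_addr, 16, 7, is_last)
// resolves to STORE_BLOCK_PARTIAL_IN_X(1, 16, float, acc, dst_addr, 0, 0, 7, is_last): a single
// vstore16 of acc0 when is_last is false, or vstore_partial_16_7 (a vstore4 followed by a vstore3
// covering the lower 7 elements) when it is true.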
2202
2203#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
2204#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2205#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
2206
2207#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2208#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
2209#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
2210
2211#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
2212#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
2213#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
2214
2215#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
2216#pragma OPENCL EXTENSION cl_arm_printf : enable
2217#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
2218
2219#define GPU_ARCH_MIDGARD 0x100
2220#define GPU_ARCH_BIFROST 0x200
2221
2222/** Concatenate two inputs.
2223 *
2224 * @param[in] a The first input to be concatenated
2225 * @param[in] b The second input to be concatenated
2226 *
2227 * @return The concatenated output
2228 */
2229#define CONCAT(a, b) a##b
2230
2231/** Expand the given vector
2232 *
2233 * @param[in] x The vector to be expanded
2234 *
2235 * @return The expanded output
2236 */
2237#define EXPAND(x) x
2238
2239/** Clamp the given value between an upper and lower bound.
2240 *
2241 * @param[in] x       The value to be clamped
2242 * @param[in] min_val The lower bound
2243 * @param[in] max_val The upper bound
2244 *
2245 * @return The clamped value.
2246 */
2247#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
2248
2249/** REVn reverses the given vector whose size is n.
2250 * @name REVn
2251 *
2252 * @param[in] x The vector to be reversed
2253 *
2254 * @return The reversed vector
2255 * @{
2256 */
2257#define REV1(x) ((x))
2258#define REV2(x) ((x).s10)
2259#define REV3(x) ((x).s210)
2260#define REV4(x) ((x).s3210)
2261#define REV8(x) ((x).s76543210)
2262#define REV16(x) ((x).sFEDCBA9876543210)
2263/** @} */ // end of group REVn
2264
2265/** Reverse the given vector.
2266 * @name REVERSE
2267 *
2268 * @param[in] x The vector to be reversed
2269 * @param[in] s The size of the vector
2270 *
2271 * @return The reversed vector
2272 * @{
2273 */
2274#define REVERSE_STR(x, s) REV##s((x))
2275#define REVERSE(x, s) REVERSE_STR(x, s)
2276/** @} */ // end of group REVERSE
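
// Example (editor's addition): REVERSE goes through REVERSE_STR so that a macro passed as 's' is
// expanded before pasting. For a hypothetical vector 'v':
//   uchar4 v = (uchar4)(1, 2, 3, 4);
//   uchar4 r = REVERSE(v, 4); // -> REV4(v) -> v.s3210 == (4, 3, 2, 1)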
2277
/** Circular-right-shift (rotate-right) the vector of size s by n elements.
2279 * @name ROTs_n
2280 *
2281 * @param[in] x The vector to be shifted
2282 *
2283 * @return The shifted vector
2284 * @{
2285 */
2286#define ROT1_0(x) ((x))
2287
2288#define ROT2_0(x) ((x))
2289#define ROT2_1(x) ((x).s10)
2290
2291#define ROT3_0(x) ((x))
2292#define ROT3_1(x) ((x).s201)
2293#define ROT3_2(x) ((x).s120)
2294
2295#define ROT4_0(x) ((x))
2296#define ROT4_1(x) ((x).s3012)
2297#define ROT4_2(x) ((x).s2301)
2298#define ROT4_3(x) ((x).s1230)
2299
2300#define ROT8_0(x) ((x))
2301#define ROT8_1(x) ((x).s70123456)
2302#define ROT8_2(x) ((x).s67012345)
2303#define ROT8_3(x) ((x).s56701234)
2304#define ROT8_4(x) ((x).s45670123)
2305#define ROT8_5(x) ((x).s34567012)
2306#define ROT8_6(x) ((x).s23456701)
2307#define ROT8_7(x) ((x).s12345670)
2308
2309#define ROT16_0(x) ((x))
2310#define ROT16_1(x) ((x).sF0123456789ABCDE)
2311#define ROT16_2(x) ((x).sEF0123456789ABCD)
2312#define ROT16_3(x) ((x).sDEF0123456789ABC)
2313#define ROT16_4(x) ((x).sCDEF0123456789AB)
2314#define ROT16_5(x) ((x).sBCDEF0123456789A)
2315#define ROT16_6(x) ((x).sABCDEF0123456789)
2316#define ROT16_7(x) ((x).s9ABCDEF012345678)
2317#define ROT16_8(x) ((x).s89ABCDEF01234567)
2318#define ROT16_9(x) ((x).s789ABCDEF0123456)
2319#define ROT16_10(x) ((x).s6789ABCDEF012345)
2320#define ROT16_11(x) ((x).s56789ABCDEF01234)
2321#define ROT16_12(x) ((x).s456789ABCDEF0123)
2322#define ROT16_13(x) ((x).s3456789ABCDEF012)
2323#define ROT16_14(x) ((x).s23456789ABCDEF01)
2324#define ROT16_15(x) ((x).s123456789ABCDEF0)
2325/** @} */ // end of group ROTs_n
2326
2327/** Circular-right-shift (rotate-right) the given vector by the given amount.
2328 * @name ROTATE
2329 *
2330 * @param[in] x The vector to be shifted
2331 * @param[in] s The size of the vector
2332 * @param[in] n The amount to be shifted
2333 *
2334 * @return The shifted vector
2335 * @{
2336 */
2337#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
2338#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
2339/** @} */ // end of group ROTATE
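
// Example (editor's addition): ROTATE(x, 4, 1) -> ROT4_1(x) -> x.s3012, i.e. every element moves one
// lane to the right and the last lane wraps around to lane 0:
//   ushort4 x = (ushort4)(10, 20, 30, 40);
//   ushort4 y = ROTATE(x, 4, 1); // == (40, 10, 20, 30)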
2340
2341/** Creates a vector of size n filled with offset values corresponding to the location of each element.
2342 * @name V_OFFSn
2343 *
2344 * @param[in] dt The data type of the output vector
2345 *
2346 * @return The vector filled with offset values
2347 * @{
2348 */
2349#define V_OFFS1(dt) (dt##1)(0)
2350#define V_OFFS2(dt) (dt##2)(0, 1)
2351#define V_OFFS3(dt) (dt##3)(0, 1, 2)
2352#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
2353#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
2354#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
2355/** @} */ // end of group V_OFFSn
2356
2357/** Create a vector filled with offset values corresponding to the location of each element.
2358 * @name VEC_OFFS
2359 *
2360 * @param[in] dt The data type of the output vector
2361 * @param[in] s  The size of the output vector
2362 *
2363 * @return The vector filled with offset values
2364 * @{
2365 */
2366#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
2367#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
2368/** @} */ // end of group VEC_OFFS
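
// Example (editor's addition): VEC_OFFS(int, 4) -> V_OFFS4(int) -> (int4)(0, 1, 2, 3). A typical use
// is building per-lane x coordinates for the block processed by this workitem (hypothetical sketch):
//   int4 x_coords = (int4)(get_global_id(0) * 4) + VEC_OFFS(int, 4);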
2369
2370#define VLOAD_STR(size) vload##size
2371#define VLOAD(size) VLOAD_STR(size)
2372
2373#define PIXEL_UNIT4 1
2374#define PIXEL_UNIT8 2
2375#define PIXEL_UNIT16 4
2376
/** Utility macro to convert a vector size to the corresponding number of pixel units.
2378 *
2379 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
2380 *
 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
2382 *
2383 * @return The pixel unit (number of pixels)
2384 * @{
2385 */
2386#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
2387#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
2388/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
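
// Example (editor's addition): one RGBA texel holds 4 channels, so a 16-wide vector spans
// CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(16) -> PIXEL_UNIT16 -> 4 pixels, which is exactly the number of
// read_image* calls issued by the read_image2d_* helpers below.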
2389
2390#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
2391#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
2392#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));
2393
2394#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
2395#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
2396#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
2397#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
2398#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
2399
2400/** Utility macro to read a 2D OpenCL image object.
2401 *
2402 * @note Coordinates are not normalized
2403 *
2404 * @param[in] data_type Data type
 * @param[in] n0        Number of pixels to read. Only 1, 2 and 4 are supported
2406 * @param[in] img       OpenCL image object
2407 * @param[in] x_coord   The x coordinate for the top-left pixel
2408 * @param[in] y_coord   The y coordinate for the top-left pixel
2409 *
2410 * @return Pixels from the 2D OpenCL image object
2411 * @{
2412 */
2413#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
2414#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
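
// Usage sketch (editor's addition, hypothetical names): reading 8 consecutive floats (2 RGBA texels)
// starting at the non-normalized coordinate (x, y) of an image 'lhs_img':
//   float8 vals = READ_IMAGE2D(float, 2, lhs_img, x, y); // -> read_image2d_floatx2 -> 2x read_imagef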
2415
2416#define VSTORE_STR(size) vstore##size
2417#define VSTORE(size) VSTORE_STR(size)
2418
2419#define float1 float
2420#define half1 half
2421#define char1 char
2422#define uchar1 uchar
2423#define short1 short
2424#define ushort1 ushort
2425#define int1 int
2426#define uint1 uint
2427#define long1 long
2428#define ulong1 ulong
2429#define double1 double
2430
2431#define vload1(OFFSET, PTR) *(OFFSET + PTR)
2432#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA
2433
2434/** Extended partial vstore that correctly handles scalar values as well.
 * Store the **lower** @p store_size elements (elements 0 to store_size-1) of the given vector while minimising the number of vstore ops
2436 * @name VSTORE_PARTIAL
2437 *
 * @note With this macro, the passed data can be either a vector or a scalar
2439 * @note @p store_size needs to be <= @p size
2440 * eg 1: Valid
2441 * VSTORE_PARTIAL(16, 15) ...;
2442 * eg 2: Invalid
2443 * VSTORE_PARTIAL(4, 7) ...;
2444 *
2445 * @param[in] size       The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
2446 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
2447 * @{
2448 */
2449#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
2450#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)
2451
2452#define NO_STORE(data, offs, ptr) \
2453    {                             \
2454    }
2455
2456// Size == 1 (scalar)
2457#define vstore_partial_1_0 NO_STORE
2458#define vstore_partial_1_1 vstore1
2459#define vstore_partial_1_2 NO_STORE
2460#define vstore_partial_1_3 NO_STORE
2461#define vstore_partial_1_4 NO_STORE
2462#define vstore_partial_1_5 NO_STORE
2463#define vstore_partial_1_6 NO_STORE
2464#define vstore_partial_1_7 NO_STORE
2465#define vstore_partial_1_8 NO_STORE
2466#define vstore_partial_1_9 NO_STORE
2467#define vstore_partial_1_10 NO_STORE
2468#define vstore_partial_1_11 NO_STORE
2469#define vstore_partial_1_12 NO_STORE
2470#define vstore_partial_1_13 NO_STORE
2471#define vstore_partial_1_14 NO_STORE
2472#define vstore_partial_1_15 NO_STORE
2473#define vstore_partial_1_16 NO_STORE
2474// Size == 2
2475#define vstore_partial_2_0 NO_STORE
2476#define vstore_partial_2_1 vstore_partial_1
2477#define vstore_partial_2_2 vstore_partial_2
2478#define vstore_partial_2_3 NO_STORE
2479#define vstore_partial_2_4 NO_STORE
2480#define vstore_partial_2_5 NO_STORE
2481#define vstore_partial_2_6 NO_STORE
2482#define vstore_partial_2_7 NO_STORE
2483#define vstore_partial_2_8 NO_STORE
2484#define vstore_partial_2_9 NO_STORE
2485#define vstore_partial_2_10 NO_STORE
2486#define vstore_partial_2_11 NO_STORE
2487#define vstore_partial_2_12 NO_STORE
2488#define vstore_partial_2_13 NO_STORE
2489#define vstore_partial_2_14 NO_STORE
2490#define vstore_partial_2_15 NO_STORE
2491#define vstore_partial_2_16 NO_STORE
2492// Size == 3
2493#define vstore_partial_3_0 NO_STORE
2494#define vstore_partial_3_1 vstore_partial_1
2495#define vstore_partial_3_2 vstore_partial_2
2496#define vstore_partial_3_3 vstore_partial_3
2497#define vstore_partial_3_4 NO_STORE
2498#define vstore_partial_3_5 NO_STORE
2499#define vstore_partial_3_6 NO_STORE
2500#define vstore_partial_3_7 NO_STORE
2501#define vstore_partial_3_8 NO_STORE
2502#define vstore_partial_3_9 NO_STORE
2503#define vstore_partial_3_10 NO_STORE
2504#define vstore_partial_3_11 NO_STORE
2505#define vstore_partial_3_12 NO_STORE
2506#define vstore_partial_3_13 NO_STORE
2507#define vstore_partial_3_14 NO_STORE
2508#define vstore_partial_3_15 NO_STORE
2509#define vstore_partial_3_16 NO_STORE
2510// Size == 4
2511#define vstore_partial_4_0 NO_STORE
2512#define vstore_partial_4_1 vstore_partial_1
2513#define vstore_partial_4_2 vstore_partial_2
2514#define vstore_partial_4_3 vstore_partial_3
2515#define vstore_partial_4_4 vstore_partial_4
2516#define vstore_partial_4_5 NO_STORE
2517#define vstore_partial_4_6 NO_STORE
2518#define vstore_partial_4_7 NO_STORE
2519#define vstore_partial_4_8 NO_STORE
2520#define vstore_partial_4_9 NO_STORE
2521#define vstore_partial_4_10 NO_STORE
2522#define vstore_partial_4_11 NO_STORE
2523#define vstore_partial_4_12 NO_STORE
2524#define vstore_partial_4_13 NO_STORE
2525#define vstore_partial_4_14 NO_STORE
2526#define vstore_partial_4_15 NO_STORE
2527#define vstore_partial_4_16 NO_STORE
2528// Size == 8
2529#define vstore_partial_8_0 NO_STORE
2530#define vstore_partial_8_1 vstore_partial_1
2531#define vstore_partial_8_2 vstore_partial_2
2532#define vstore_partial_8_3 vstore_partial_3
2533#define vstore_partial_8_4 vstore_partial_4
2534#define vstore_partial_8_5 vstore_partial_5
2535#define vstore_partial_8_6 vstore_partial_6
2536#define vstore_partial_8_7 vstore_partial_7
2537#define vstore_partial_8_8 vstore_partial_8
2538#define vstore_partial_8_9 NO_STORE
2539#define vstore_partial_8_10 NO_STORE
2540#define vstore_partial_8_11 NO_STORE
2541#define vstore_partial_8_12 NO_STORE
2542#define vstore_partial_8_13 NO_STORE
2543#define vstore_partial_8_14 NO_STORE
2544#define vstore_partial_8_15 NO_STORE
2545#define vstore_partial_8_16 NO_STORE
2546// Size == 16
2547#define vstore_partial_16_0 NO_STORE
2548#define vstore_partial_16_1 vstore_partial_1
2549#define vstore_partial_16_2 vstore_partial_2
2550#define vstore_partial_16_3 vstore_partial_3
2551#define vstore_partial_16_4 vstore_partial_4
2552#define vstore_partial_16_5 vstore_partial_5
2553#define vstore_partial_16_6 vstore_partial_6
2554#define vstore_partial_16_7 vstore_partial_7
2555#define vstore_partial_16_8 vstore_partial_8
2556#define vstore_partial_16_9 vstore_partial_9
2557#define vstore_partial_16_10 vstore_partial_10
2558#define vstore_partial_16_11 vstore_partial_11
2559#define vstore_partial_16_12 vstore_partial_12
2560#define vstore_partial_16_13 vstore_partial_13
2561#define vstore_partial_16_14 vstore_partial_14
2562#define vstore_partial_16_15 vstore_partial_15
2563#define vstore_partial_16_16 vstore_partial_16
2564
/** Partial vstore. Store the **lower** n elements (elements 0 to n-1) of the given vector while minimising the number of vstore ops
2566 * @name vstore_partial_n
2567 *
 * @note @p DATA needs to be a vector, not a scalar
2569 * @note n needs to be <= the vector width of the input variable @p DATA
2570 * eg 1: Valid
2571 * vstore_partial_15(var:float16, 0, 0xabcd);
2572 * eg 2: Invalid
2573 * vstore_partial_7(var:float4, 0, 0xabcd);
2574 *
2575 * @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
2576 *
2577 * @param[in] DATA   The name of the variable
2578 * @param[in] OFFSET Offset in n
2579 * @param[in] PTR    The base pointer
2580 * @{
2581 */
2582#define vstore_partial_1(DATA, OFFSET, PTR) \
2583    vstore1(DATA.s0, OFFSET, PTR);
2584
2585#define vstore_partial_2(DATA, OFFSET, PTR) \
2586    vstore2(DATA.s01, OFFSET, PTR);
2587
2588#define vstore_partial_3(DATA, OFFSET, PTR) \
2589    vstore3(DATA.s012, OFFSET, PTR);
2590
2591#define vstore_partial_4(DATA, OFFSET, PTR) \
2592    vstore4(DATA.s0123, OFFSET, PTR);
2593
2594#define vstore_partial_5(DATA, OFFSET, PTR)    \
2595    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
2596    vstore1(DATA.s4, OFFSET, PTR + 4);
2597
2598#define vstore_partial_6(DATA, OFFSET, PTR)    \
2599    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
2600    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);
2601
2602#define vstore_partial_7(DATA, OFFSET, PTR)    \
2603    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
2604    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);
2605
2606#define vstore_partial_8(DATA, OFFSET, PTR) \
2607    vstore8(DATA.s01234567, OFFSET, PTR);
2608
2609#define vstore_partial_9(DATA, OFFSET, PTR)        \
2610    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
2611    vstore1(DATA.s8, OFFSET, PTR + 8);
2612
2613#define vstore_partial_10(DATA, OFFSET, PTR)       \
2614    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
2615    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);
2616
2617#define vstore_partial_11(DATA, OFFSET, PTR)       \
2618    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
2619    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);
2620
2621#define vstore_partial_12(DATA, OFFSET, PTR)       \
2622    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
2623    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);
2624
2625#define vstore_partial_13(DATA, OFFSET, PTR)       \
2626    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
2627    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);
2628
2629#define vstore_partial_14(DATA, OFFSET, PTR)       \
2630    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
2631    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);
2632
2633#define vstore_partial_15(DATA, OFFSET, PTR)       \
2634    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
2635    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);
2636
2637#define vstore_partial_16(DATA, OFFSET, PTR) \
2638    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
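
// Expansion sketch (editor's addition): storing the lower 7 elements of a hypothetical float8 'v':
//   VSTORE_PARTIAL(8, 7)(v, 0, p) -> vstore_partial_8_7(v, 0, p) -> vstore_partial_7(v, 0, p)
//   -> vstore4(v.s0123, 0, p); vstore3(v.s456, 0, p + 4);
// i.e. two hardware stores instead of seven scalar ones.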
2641
// The convert built-in functions do not support the _sat modifier for floating-point types, so we create
// defines without _sat that map to the plain conversions to overcome this issue
2644#define convert_float_sat convert_float
2645#define convert_float1_sat convert_float
2646#define convert_float2_sat convert_float2
2647#define convert_float3_sat convert_float3
2648#define convert_float4_sat convert_float4
2649#define convert_float8_sat convert_float8
2650#define convert_float16_sat convert_float16
#define convert_half_sat convert_half
2652#define convert_half1_sat convert_half
2653#define convert_half2_sat convert_half2
2654#define convert_half3_sat convert_half3
2655#define convert_half4_sat convert_half4
2656#define convert_half8_sat convert_half8
2657#define convert_half16_sat convert_half16
2658
2659#define convert_float1 convert_float
2660#define convert_half1 convert_half
2661#define convert_char1 convert_char
2662#define convert_uchar1 convert_uchar
2663#define convert_short1 convert_short
2664#define convert_ushort1 convert_ushort
2665#define convert_int1 convert_int
2666#define convert_uint1 convert_uint
2667#define convert_long1 convert_long
2668#define convert_ulong1 convert_ulong
2669#define convert_double1 convert_double
2670
2671#define convert_char1_sat convert_char_sat
2672#define convert_uchar1_sat convert_uchar_sat
2673#define convert_short1_sat convert_short_sat
2674#define convert_ushort1_sat convert_ushort_sat
2675#define convert_int1_sat convert_int_sat
2676#define convert_uint1_sat convert_uint_sat
2677#define convert_long1_sat convert_long_sat
2678#define convert_ulong1_sat convert_ulong_sat
2679#define convert_double1_sat convert_double_sat
2680
2681#define VEC_DATA_TYPE_STR(type, size) type##size
2682#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)
2683
2684#define CONVERT_STR(x, type) (convert_##type((x)))
2685#define CONVERT(x, type) CONVERT_STR(x, type)
2686
2687#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
2688#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)
2689
2690#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
2691#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)
2692
2693#define select_vec_dt_uchar(size) uchar##size
2694#define select_vec_dt_char(size) char##size
2695#define select_vec_dt_ushort(size) ushort##size
2696#define select_vec_dt_short(size) short##size
2697#define select_vec_dt_half(size) short##size
2698#define select_vec_dt_uint(size) uint##size
2699#define select_vec_dt_int(size) int##size
2700#define select_vec_dt_float(size) int##size
2701#define select_vec_dt_ulong(size) ulong##size
2702#define select_vec_dt_long(size) long##size
2703
2704#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
2705#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
2706#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)
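
// Rationale sketch (editor's addition): the select() built-in expects an integer mask vector whose
// elements have the same bit width as the operands, which is why half maps to short and float maps
// to int above. Hypothetical use with float4 operands a, b and an int4 condition 'cond':
//   VEC_DATA_TYPE(float, 4) r = select(a, b, CONVERT(cond, SELECT_VEC_DATA_TYPE(float, 4)));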
2707
2708#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) (((x).s0) + ((x).s1))
#define sum_reduce_3(x) (sum_reduce_2((x).s01) + ((x).s2))
#define sum_reduce_4(x) (sum_reduce_2((x).s01) + sum_reduce_2((x).s23))
#define sum_reduce_8(x) (sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567))
#define sum_reduce_16(x) (sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF))
2714
2715#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
2716#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)
2717
2718#define max_reduce_1(x) (x)
2719#define max_reduce_2(x) max(((x).s0), ((x).s1))
2720#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
2721#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
2722#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
2723#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))
2724
2725#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
2726#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
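
// Expansion sketch (editor's addition): for a hypothetical float4 'v', SUM_REDUCE(v, 4) expands to
// sum_reduce_4(v) -> (v.s0 + v.s1) + (v.s2 + v.s3), a scalar; MAX_REDUCE(v, 4) nests max() calls the
// same way, yielding max(max(v.s0, v.s1), max(v.s2, v.s3)).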
2727
2728#define VECTOR_DECLARATION(name)     \
2729    __global uchar *name##_ptr,      \
2730    uint        name##_stride_x, \
2731    uint        name##_step_x,   \
2732    uint        name##_offset_first_element_in_bytes
2733
2734#define IMAGE_DECLARATION(name)      \
2735    __global uchar *name##_ptr,      \
2736    uint        name##_stride_x, \
2737    uint        name##_step_x,   \
2738    uint        name##_stride_y, \
2739    uint        name##_step_y,   \
2740    uint        name##_offset_first_element_in_bytes
2741
2742#define TENSOR3D_DECLARATION(name)   \
2743    __global uchar *name##_ptr,      \
2744    uint        name##_stride_x, \
2745    uint        name##_step_x,   \
2746    uint        name##_stride_y, \
2747    uint        name##_step_y,   \
2748    uint        name##_stride_z, \
2749    uint        name##_step_z,   \
2750    uint        name##_offset_first_element_in_bytes
2751
2752#define TENSOR4D_DECLARATION(name)   \
2753    __global uchar *name##_ptr,      \
2754    uint        name##_stride_x, \
2755    uint        name##_step_x,   \
2756    uint        name##_stride_y, \
2757    uint        name##_step_y,   \
2758    uint        name##_stride_z, \
2759    uint        name##_step_z,   \
2760    uint        name##_stride_w, \
2761    uint        name##_step_w,   \
2762    uint        name##_offset_first_element_in_bytes
2763
2764#define CONVERT_TO_VECTOR_STRUCT(name) \
2765    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)
2766
2767#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
2768    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)
2769
2770#define CONVERT_TO_IMAGE_STRUCT(name) \
2771    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)
2772
2773#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
2774    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)
2775
2776#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
2777    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)
2778
2779#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
2780    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
2781
2785#define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
2786    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
2787                                 name##_stride_z, name##_step_z)
2788
2789#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
2790    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)
2791
2792#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
2793    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
2794                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)
2795
2796#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
2797    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)
2798
2799#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                                                                                       \
2800    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
2801                           name##_stride_z, name##_step_z)
2802
2803/** Structure to hold Vector information */
2804typedef struct Vector
2805{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source vector */
    int             stride_x;                      /**< Stride of the vector in X dimension (in bytes) */
2809} Vector;
2810
2811/** Structure to hold Image information */
2812typedef struct Image
2813{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
2815    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
2816    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
2817    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
2818} Image;
2819
2820/** Structure to hold 3D tensor information */
2821typedef struct Tensor3D
2822{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
2828} Tensor3D;
2829
2830/** Structure to hold 4D tensor information */
2831typedef struct Tensor4D
2832{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
    int             stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
2839} Tensor4D;
2840
/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
2842 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
2844 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
2845 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
2846 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
2847 *
 * @return A vector object
2849 */
2850inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
2851{
2852    Vector vector =
2853    {
2854        .ptr                           = ptr,
2855        .offset_first_element_in_bytes = offset_first_element_in_bytes,
2856        .stride_x                      = stride_x,
2857    };
2858    vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
2859    return vector;
2860}
2861
2862/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
2863 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
2865 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
2866 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
2867 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
2868 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
2869 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
2870 *
2871 * @return An image object
2872 */
2873inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
2874{
2875    Image img =
2876    {
2877        .ptr                           = ptr,
2878        .offset_first_element_in_bytes = offset_first_element_in_bytes,
2879        .stride_x                      = stride_x,
2880        .stride_y                      = stride_y
2881    };
2882    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
2883    return img;
2884}
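
// Usage sketch (editor's addition, hypothetical kernel): IMAGE_DECLARATION expands to the standard
// six-argument list for one 2D tensor, and CONVERT_TO_IMAGE_STRUCT wraps those arguments into an
// Image whose ptr already points at this workitem's element. Kept under #if 0 so it is never compiled.
#if 0
__kernel void example_copy(IMAGE_DECLARATION(src), IMAGE_DECLARATION(dst))
{
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
    *dst.ptr = *src.ptr; // each workitem copies one uchar
}
#endif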
2885
2886/** Wrap 3D tensor information into an image structure, and make the pointer point at this workitem's data.
2887 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
2889 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
2890 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
2891 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
2892 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
2893 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
2894 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
2895 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
2896 *
 * @return An image object
2898 */
2899inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
2900{
2901    Image img =
2902    {
2903        .ptr                           = ptr,
2904        .offset_first_element_in_bytes = offset_first_element_in_bytes,
2905        .stride_x                      = stride_x,
2906        .stride_y                      = stride_y
2907    };
2908    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
2909    return img;
2910}
2911
/** Wrap 3D tensor information into a tensor structure, and make the pointer point at this workitem's data.
2913 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
2915 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
2916 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
2917 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
2918 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
2919 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
2920 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
2921 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
2922 *
2923 * @return A 3D tensor object
2924 */
2925inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
2926{
2927    Tensor3D tensor =
2928    {
2929        .ptr                           = ptr,
2930        .offset_first_element_in_bytes = offset_first_element_in_bytes,
2931        .stride_x                      = stride_x,
2932        .stride_y                      = stride_y,
2933        .stride_z                      = stride_z
2934    };
2935    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
2936    return tensor;
2937}
2938
/** Wrap 3D tensor information into a tensor structure.
2940 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
2942 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
2943 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
2944 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
2945 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
2946 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
2947 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
2948 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
2949 *
2950 * @return A 3D tensor object
2951 */
2952inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
2953{
2954    Tensor3D tensor =
2955    {
2956        .ptr                           = ptr,
2957        .offset_first_element_in_bytes = offset_first_element_in_bytes,
2958        .stride_x                      = stride_x,
2959        .stride_y                      = stride_y,
2960        .stride_z                      = stride_z
2961    };
2962    return tensor;
2963}
2964
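/** Wrap 4D tensor information into a tensor structure, and make the pointer point at this workitem's data.
 *
 * The z and w coordinates are folded into the third global id: z = get_global_id(2) % mod_size and
 * w = get_global_id(2) / mod_size.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] stride_w                      Stride of the tensor in W dimension (in bytes)
 * @param[in] step_w                        stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in] mod_size                      The number of z planes per w slice, used to split the third global id into z and w
 *
 * @return A 4D tensor object
 */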
2965inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
2966                                             uint step_w,
2967                                             uint mod_size)
2968{
2969    Tensor4D tensor =
2970    {
2971        .ptr                           = ptr,
2972        .offset_first_element_in_bytes = offset_first_element_in_bytes,
2973        .stride_x                      = stride_x,
2974        .stride_y                      = stride_y,
2975        .stride_z                      = stride_z,
2976        .stride_w                      = stride_w
2977    };
2978
2979    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
2980    return tensor;
2981}
2982
2983/** Get the pointer position of a Vector
2984 *
 * @param[in] vec Pointer to the Vector structure
2986 * @param[in] x   Relative X position
2987 */
2988inline __global const uchar *vector_offset(const Vector *vec, int x)
2989{
2990    return vec->ptr + x * vec->stride_x;
2991}
2992
/** Get the pointer position of an Image
2994 *
 * @param[in] img Pointer to the Image structure
2996 * @param[in] x   Relative X position
2997 * @param[in] y   Relative Y position
2998 */
2999inline __global uchar *offset(const Image *img, int x, int y)
3000{
3001    return img->ptr + x * img->stride_x + y * img->stride_y;
3002}
3003
3004/** Get the pointer position of a Tensor3D
3005 *
 * @param[in] tensor Pointer to the Tensor3D structure
3007 * @param[in] x      Relative X position
3008 * @param[in] y      Relative Y position
3009 * @param[in] z      Relative Z position
3010 */
3011inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
3012{
3013    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
3014}
3015
3016/** Get the pointer position of a Tensor4D
3017 *
 * @param[in] tensor Pointer to the Tensor4D structure
3019 * @param[in] x      Relative X position
3020 * @param[in] y      Relative Y position
3021 * @param[in] z      Relative Z position
3022 * @param[in] w      Relative W position
3023 */
3024inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
3025{
3026    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
3027}
3028
3029/** Get the offset for a given linear index of a Tensor3D
3030 *
 * @param[in] tensor Pointer to the Tensor3D structure
3032 * @param[in] width  Width of the input tensor
3033 * @param[in] height Height of the input tensor
3034 * @param[in] depth  Depth of the input tensor
3035 * @param[in] index  Linear index
3036 */
3037inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
3038{
3039    uint num_elements = width * height;
3040
3041    const uint z = index / num_elements;
3042
3043    index %= num_elements;
3044
3045    const uint y = index / width;
3046
3047    index %= width;
3048
3049    const uint x = index;
3050
3051    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
3052}
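
// Worked check (editor's addition): for a tensor with width = 4 and height = 3, linear index 17 gives
// num_elements = 12, z = 17 / 12 = 1, remainder 5, y = 5 / 4 = 1, x = 5 % 4 = 1, i.e. element
// (x=1, y=1, z=1) of a row-major walk, as expected.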
3053
3054#endif // _HELPER_H
3055
3056/** Utility macro to access a vector with the scalar positions
3057 *
 * Supported cases are the scalar_access_<offset>_<n0> combinations defined below (offsets 0, 1, 2, 3, 4, 8, 12, 16)
 *
 * @param[in] offset The offset within the vector. Supported: 0, 1, 2, 3, 4, 8, 12, 16
3061 * @param[in] n0     The number of consecutive columns to access. n0 + offset must be <= 16
3062 * @param[in] x      Vector to access
3063 * @{
3064 */
3065#define SCALAR_ACCESS_STR(offset, n0, x) scalar_access_##offset##_##n0(x)
3066#define SCALAR_ACCESS(offset, n0, x) SCALAR_ACCESS_STR(offset, n0, x)
3067
3068// offset == 0
3069#define scalar_access_0_1(x) ((x).s0)
3070#define scalar_access_0_2(x) ((x).s01)
3071#define scalar_access_0_3(x) ((x).s012)
3072#define scalar_access_0_4(x) ((x).s0123)
3073#define scalar_access_0_8(x) ((x).s01234567)
3074#define scalar_access_0_16(x) ((x).s0123456789ABCDEF)
3075
3076// offset == 1
3077#define scalar_access_1_1(x) ((x).s1)
3078#define scalar_access_1_2(x) ((x).s12)
3079#define scalar_access_1_3(x) ((x).s123)
3080#define scalar_access_1_4(x) ((x).s1234)
3081#define scalar_access_1_8(x) ((x).s12345678)
3082
3083// offset == 2
3084#define scalar_access_2_1(x) ((x).s2)
3085#define scalar_access_2_2(x) ((x).s23)
3086#define scalar_access_2_3(x) ((x).s234)
3087#define scalar_access_2_4(x) ((x).s2345)
3088#define scalar_access_2_8(x) ((x).s23456789)
3089
3090// offset == 3
3091#define scalar_access_3_1(x) ((x).s3)
3092#define scalar_access_3_2(x) ((x).s34)
3093#define scalar_access_3_3(x) ((x).s345)
3094#define scalar_access_3_4(x) ((x).s3456)
3095#define scalar_access_3_8(x) ((x).s3456789A)
3096
3097// offset == 4
3098#define scalar_access_4_1(x) ((x).s4)
3099#define scalar_access_4_2(x) ((x).s45)
3100#define scalar_access_4_3(x) ((x).s456)
3101#define scalar_access_4_4(x) ((x).s4567)
3102#define scalar_access_4_8(x) ((x).s456789AB)
3103
3104// offset == 8
3105#define scalar_access_8_1(x) ((x).s8)
3106#define scalar_access_8_2(x) ((x).s89)
3107#define scalar_access_8_3(x) ((x).s89A)
3108#define scalar_access_8_4(x) ((x).s89AB)
3109#define scalar_access_8_8(x) ((x).s89ABCDEF)
3110
3111// offset == 12
3112#define scalar_access_12_1(x) ((x).sC)
3113#define scalar_access_12_2(x) ((x).sCD)
3114#define scalar_access_12_3(x) ((x).sCDE)
3115#define scalar_access_12_4(x) ((x).sCDEF)
3116
3117// offset == 16
3118#define scalar_access_16_1(x) ((x).sF)
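
// Example (editor's addition): SCALAR_ACCESS(2, 4, x) -> scalar_access_2_4(x) -> x.s2345, a 4-element
// window starting at lane 2. LOAD_TENSOR_ROW_n below uses this to deposit each N0-wide load at column
// COL_OFFSET of the destination row variables.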
3119
/** Load the rows from 0 to n-1 into the given variables (BASENAME0 to BASENAMEn-1) without allocating new variables.
3121 * @name LOAD_TENSOR_ROW_n
3122 *
3123 * @param[in] N0         The number of columns to load
3124 * @param[in] DATA_TYPE  The data type of variables
3125 * @param[in] BASENAME   The basename of the destination variables for the loaded rows
3126 * @param[in] PTR        The base pointer
3127 * @param[in] COL_OFFSET The column vector offset. COL_OFFSET + N0 must be <= 16
3128 * @param[in] STRIDE_Y   The stride value in y-axis direction
3129 * @param[in] Z          The z-axis offset vector
3130 * @{
3131 */
3132#define LOAD_TENSOR_ROW_0(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3133    ({})
3134
3135#define LOAD_TENSOR_ROW_1(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3136    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##0) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
3137
3138#define LOAD_TENSOR_ROW_2(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3139    LOAD_TENSOR_ROW_1(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3140    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##1) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
3141
3142#define LOAD_TENSOR_ROW_3(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3143    LOAD_TENSOR_ROW_2(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3144    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##2) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
3145
3146#define LOAD_TENSOR_ROW_4(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3147    LOAD_TENSOR_ROW_3(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3148    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##3) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
3149
3150#define LOAD_TENSOR_ROW_5(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3151    LOAD_TENSOR_ROW_4(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3152    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##4) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
3153
3154#define LOAD_TENSOR_ROW_6(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3155    LOAD_TENSOR_ROW_5(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3156    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##5) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
3157
3158#define LOAD_TENSOR_ROW_7(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3159    LOAD_TENSOR_ROW_6(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3160    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##6) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
3161
3162#define LOAD_TENSOR_ROW_8(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3163    LOAD_TENSOR_ROW_7(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3164    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##7) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
3165
3166#define LOAD_TENSOR_ROW_9(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3167    LOAD_TENSOR_ROW_8(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3168    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##8) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
3169
3170#define LOAD_TENSOR_ROW_10(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3171    LOAD_TENSOR_ROW_9(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)      \
3172    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##9) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
3173
3174#define LOAD_TENSOR_ROW_11(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3175    LOAD_TENSOR_ROW_10(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3176    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##A) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
3177
3178#define LOAD_TENSOR_ROW_12(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3179    LOAD_TENSOR_ROW_11(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3180    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##B) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
3181
3182#define LOAD_TENSOR_ROW_13(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3183    LOAD_TENSOR_ROW_12(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3184    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##C) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
3185
3186#define LOAD_TENSOR_ROW_14(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3187    LOAD_TENSOR_ROW_13(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3188    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##D) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
3189
3190#define LOAD_TENSOR_ROW_15(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3191    LOAD_TENSOR_ROW_14(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3192    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##E) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
3193
3194#define LOAD_TENSOR_ROW_16(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
3195    LOAD_TENSOR_ROW_15(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
3196    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##F) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
3197/** @}*/ // end of group LOAD_TENSOR_ROW_n
3198
3199/** Load tensor (consecutive rows and columns) with Z offset.
3200 * @name LOAD_TENSOR
3201 *
3202 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16
3203 * The data to load is expected to have consecutive names for each row.
3204 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2.
3205 * The Z offset is expected to have consecutive names.
3206 * E.g., for M0=3, and Z=zin, the expected Z offsets are zin0, zin1 and zin2.
3207 *
3208 * @param[in] M0         The number of consecutive rows
3209 * @param[in] N0         The number of consecutive columns
3210 * @param[in] DATA_TYPE  The data type of the target
3211 * @param[in] BASENAME   The basename of the result variables
3212 * @param[in] PTR        The base pointer for the data
3213 * @param[in] COL_OFFSET The column vector offset. COL_OFFSET + N0 must be <= 16
3214 * @param[in] STRIDE_Y   The stride in y-axis direction
3215 * @param[in] Z          The z-axis offset vector
3216 * @{
3217 */
3218#define LOAD_TENSOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) LOAD_TENSOR_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)
3219#define LOAD_TENSOR(M0, N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) LOAD_TENSOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)
3220/** @} */ // end of group LOAD_TENSOR
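
// Illustrative usage sketch (not part of the library): assuming a0 and a1 are
// float16 variables declared by the caller, a call such as
//
//     LOAD_TENSOR(2, 4, float, a, src_addr, 0, src_stride_y, zin);
//
// expands to one vector load per row,
//
//     SCALAR_ACCESS(0, 4, a0) = VLOAD(4)(0, (__global float *)(src_addr + 0 * src_stride_y + zin0));
//     SCALAR_ACCESS(0, 4, a1) = VLOAD(4)(0, (__global float *)(src_addr + 1 * src_stride_y + zin1));
//
// writing lanes s0..s3 of a0 and a1 while leaving the remaining lanes untouched.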

/** Load 2D tensor (consecutive rows and columns) with Z offset.
 * @name LOAD_TENSOR_M0Xn
 *
 * @param[in] M0        The number of rows to load [0-16]
 * @param[in] N0        The number of columns to load [0-16]
 * @param[in] DATA_TYPE The data type of variables
 * @param[in] BASENAME  The basename of the destination variables for the loaded rows
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The z-axis offset vector
 * @{
 */
#define LOAD_TENSOR_M0X0(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    ({})

#define LOAD_TENSOR_M0X1(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X2(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X3(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X4(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X5(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 1, DATA_TYPE, a, input_ptr + 4 * sizeof(DATA_TYPE), 4, src_stride_y, zin);

#define LOAD_TENSOR_M0X6(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 2, DATA_TYPE, a, input_ptr + 4 * sizeof(DATA_TYPE), 4, src_stride_y, zin);

#define LOAD_TENSOR_M0X7(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 3, DATA_TYPE, a, input_ptr + 4 * sizeof(DATA_TYPE), 4, src_stride_y, zin);

#define LOAD_TENSOR_M0X8(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X9(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 1, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X10(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);        \
    LOAD_TENSOR(M0, 2, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X11(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);        \
    LOAD_TENSOR(M0, 3, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X12(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);        \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X13(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin)                  \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);                         \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin); \
    LOAD_TENSOR(M0, 1, DATA_TYPE, a, input_ptr + 12 * sizeof(DATA_TYPE), 12, src_stride_y, zin);

#define LOAD_TENSOR_M0X14(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin)                  \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);                         \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin); \
    LOAD_TENSOR(M0, 2, DATA_TYPE, a, input_ptr + 12 * sizeof(DATA_TYPE), 12, src_stride_y, zin);

#define LOAD_TENSOR_M0X15(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin)                  \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);                         \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin); \
    LOAD_TENSOR(M0, 3, DATA_TYPE, a, input_ptr + 12 * sizeof(DATA_TYPE), 12, src_stride_y, zin);

#define LOAD_TENSOR_M0X16(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
/** @}*/ // end of group LOAD_TENSOR_M0Xn

/** Load 2D tensor (consecutive rows and columns) with Z offset.
 * @name LOAD_TENSOR_M0XN0
 *
 * @param[in] M0        The number of consecutive rows [0-16]
 * @param[in] N0        The number of consecutive columns [0-16]
 * @param[in] DATA_TYPE The data type of the target
 * @param[in] BASENAME  The basename of the result variables
 * @param[in] PTR       The base pointer for the data
 * @param[in] STRIDE_Y  The stride in y-axis direction
 * @param[in] Z         The z-axis offset vector
 * @{
 */
#define LOAD_TENSOR_M0XN0_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) LOAD_TENSOR_M0X##N0(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define LOAD_TENSOR_M0XN0(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) LOAD_TENSOR_M0XN0_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group LOAD_TENSOR_M0XN0
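
// Illustrative sketch (not part of the library): N0 values without a native
// OpenCL vector width are split into supported loads. For instance, with float
// data, LOAD_TENSOR_M0XN0(3, 5, float, a, input_ptr, src_stride_y, zin)
// dispatches to LOAD_TENSOR_M0X5, which loads each of the 3 rows as one
// 4-element vector (columns 0-3) plus one scalar (column 4), i.e. it fills
// lanes s0..s4 of a0, a1 and a2.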

/** Loads the rows from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1).
 * @name LOAD_ROW_n
 *
 * @param[in] N0        The number of columns to load
 * @param[in] DATA_TYPE The data type of variables
 * @param[in] BASENAME  The basename of the destination variables for the loaded rows
 * @param[in] PTR       The base pointer
 * @param[in] OFFSET    The offset within a row
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The z-axis offset vector
 * @{
 */
#define LOAD_ROW_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##0 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 0 * STRIDE_Y + Z##0));

#define LOAD_ROW_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##1 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 1 * STRIDE_Y + Z##1));

#define LOAD_ROW_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##2 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 2 * STRIDE_Y + Z##2));

#define LOAD_ROW_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##3 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 3 * STRIDE_Y + Z##3));

#define LOAD_ROW_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##4 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 4 * STRIDE_Y + Z##4));

#define LOAD_ROW_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##5 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 5 * STRIDE_Y + Z##5));

#define LOAD_ROW_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##6 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 6 * STRIDE_Y + Z##6));

#define LOAD_ROW_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##7 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 7 * STRIDE_Y + Z##7));

#define LOAD_ROW_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                      \
    BASENAME##8 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 8 * STRIDE_Y + Z##8));

#define LOAD_ROW_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)      \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##9 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 9 * STRIDE_Y + Z##9));

#define LOAD_ROW_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##A = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 10 * STRIDE_Y + Z##A));

#define LOAD_ROW_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##B = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 11 * STRIDE_Y + Z##B));

#define LOAD_ROW_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##C = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 12 * STRIDE_Y + Z##C));

#define LOAD_ROW_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##D = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 13 * STRIDE_Y + Z##D));

#define LOAD_ROW_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##E = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 14 * STRIDE_Y + Z##E));

#define LOAD_ROW_16(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \
    LOAD_ROW_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##F = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 15 * STRIDE_Y + Z##F));

/** @}*/ // end of group LOAD_ROW_n

/** Load Blocks (consecutive rows and columns) with Z offset.
 * @name LOAD_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16
 * The data to load is expected to have consecutive names for each row.
 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3, and Z=zin, the expected Z offsets are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of consecutive rows
 * @param[in] N0        The number of consecutive columns
 * @param[in] DATA_TYPE The data type of the target
 * @param[in] BASENAME  The basename of the result variables
 * @param[in] PTR       The base pointer for the data
 * @param[in] OFFSET    The offset within a row
 * @param[in] STRIDE_Y  The stride in y-axis direction
 * @param[in] Z         The z-axis offset vector
 * @{
 */
#define LOAD_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) LOAD_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)
#define LOAD_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) LOAD_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z)
/** @} */ // end of group LOAD_BLOCK
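
// Illustrative sketch (not part of the library): unlike LOAD_TENSOR above,
// LOAD_ROW_n declares its destination vectors. Assuming the usual helpers
// (VEC_DATA_TYPE(uchar, 4) == uchar4, VLOAD(4) == vload4), the call
//
//     LOAD_BLOCK(2, 4, uchar, b, src_addr, 0, src_stride_y, zin);
//
// expands to
//
//     uchar4 b0 = vload4(0, (__global uchar *)(src_addr + 0 + 0 * src_stride_y + zin0));
//     uchar4 b1 = vload4(0, (__global uchar *)(src_addr + 0 + 1 * src_stride_y + zin1));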

/** Loads the rows from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1).
 * @name LOAD_TEXTURE2D_ROW_n
 *
 * @param[in] N0         The number of pixels to read
 * @param[in] DATA_TYPE  The data type of variables
 * @param[in] BASENAME   The basename of the destination variables for the loaded rows
 * @param[in] IMG        The 2D OpenCL image object
 * @param[in] X_COORD    The x coordinate for the top-left pixel
 * @param[in] Y_COORD    The y coordinate for the top-left pixel
 * @param[in] X_STEP_ROW The incremental step row for the x coordinate (in pixels)
 * @param[in] Y_STEP_ROW The incremental step row for the y coordinate (in pixels)
 * @{
 */
#define LOAD_TEXTURE2D_ROW_1(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    BASENAME##0 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 0 * X_STEP_ROW), (Y_COORD + 0 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_2(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_1(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##1 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 1 * X_STEP_ROW), (Y_COORD + 1 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_3(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_2(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##2 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 2 * X_STEP_ROW), (Y_COORD + 2 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_4(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_3(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##3 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 3 * X_STEP_ROW), (Y_COORD + 3 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_5(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_4(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##4 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 4 * X_STEP_ROW), (Y_COORD + 4 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_6(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_5(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##5 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 5 * X_STEP_ROW), (Y_COORD + 5 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_7(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_6(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##6 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 6 * X_STEP_ROW), (Y_COORD + 6 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_8(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_7(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##7 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 7 * X_STEP_ROW), (Y_COORD + 7 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_9(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_8(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##8 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 8 * X_STEP_ROW), (Y_COORD + 8 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_10(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_9(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)      \
    BASENAME##9 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 9 * X_STEP_ROW), (Y_COORD + 9 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_11(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_10(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##A = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 10 * X_STEP_ROW), (Y_COORD + 10 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_12(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_11(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##B = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 11 * X_STEP_ROW), (Y_COORD + 11 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_13(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_12(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##C = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 12 * X_STEP_ROW), (Y_COORD + 12 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_14(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_13(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##D = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 13 * X_STEP_ROW), (Y_COORD + 13 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_15(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_14(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##E = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 14 * X_STEP_ROW), (Y_COORD + 14 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_16(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_15(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##F = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 15 * X_STEP_ROW), (Y_COORD + 15 * Y_STEP_ROW))
/** @} */ // end of group LOAD_TEXTURE2D_ROW_n

/** Load a 2D texture in units of pixels. A pixel is made of 4 floating-point values.
 * @name LOAD_TEXTURE2D
 *
 * Supported cases are M0=1,2,3,...,16 and N0=1,2,4
 * The data to load is expected to have consecutive names for each row.
 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2.
 *
 * @param[in] M0         The number of consecutive rows
 * @param[in] N0         The number of consecutive pixels. Only 1, 2 and 4 are supported
 * @param[in] DATA_TYPE  The data type of the target
 * @param[in] BASENAME   The basename of the result variables
 * @param[in] IMG        The 2D OpenCL image object
 * @param[in] X_COORD    The x coordinate for the top-left pixel
 * @param[in] Y_COORD    The y coordinate for the top-left pixel
 * @param[in] X_STEP_ROW The incremental step row for the x coordinate (in pixels)
 * @param[in] Y_STEP_ROW The incremental step row for the y coordinate (in pixels)
 * @{
 */
#define LOAD_TEXTURE2D_STR(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) LOAD_TEXTURE2D_ROW_##M0(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)
#define LOAD_TEXTURE2D(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) LOAD_TEXTURE2D_STR(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)
/** @} */ // end of group LOAD_TEXTURE2D
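
// Illustrative sketch (not part of the library): assuming READ_IMAGE2D maps to
// the OpenCL read_image* built-ins as defined in the common helpers, a call
// such as
//
//     LOAD_TEXTURE2D(2, 1, float, a, img, x, y, 0, 1);
//
// reads one RGBA pixel (4 floats) per row: a0 from coordinate (x, y) and
// a1 from (x, y + 1).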

/** Loads the elements from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1).
 * @name LOAD_ELEMENT_n
 *
 * @param[in] N0        The number of rows to load
 * @param[in] DATA_TYPE The data type of variables
 * @param[in] BASENAME  The basename of the destination variables for the loaded rows
 * @param[in] PTR       The base pointer
 * @param[in] OFFSET    The offset within a row
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @{
 */
#define LOAD_ELEMENT_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##0 = *((__global DATA_TYPE *)(PTR + OFFSET + 0 * STRIDE_Y));

#define LOAD_ELEMENT_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##1 = *((__global DATA_TYPE *)(PTR + OFFSET + 1 * STRIDE_Y));

#define LOAD_ELEMENT_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##2 = *((__global DATA_TYPE *)(PTR + OFFSET + 2 * STRIDE_Y));

#define LOAD_ELEMENT_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##3 = *((__global DATA_TYPE *)(PTR + OFFSET + 3 * STRIDE_Y));

#define LOAD_ELEMENT_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##4 = *((__global DATA_TYPE *)(PTR + OFFSET + 4 * STRIDE_Y));

#define LOAD_ELEMENT_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##5 = *((__global DATA_TYPE *)(PTR + OFFSET + 5 * STRIDE_Y));

#define LOAD_ELEMENT_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##6 = *((__global DATA_TYPE *)(PTR + OFFSET + 6 * STRIDE_Y));

#define LOAD_ELEMENT_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##7 = *((__global DATA_TYPE *)(PTR + OFFSET + 7 * STRIDE_Y));

#define LOAD_ELEMENT_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                       \
    BASENAME##8 = *((__global DATA_TYPE *)(PTR + OFFSET + 8 * STRIDE_Y));

#define LOAD_ELEMENT_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)      \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                        \
    BASENAME##9 = *((__global DATA_TYPE *)(PTR + OFFSET + 9 * STRIDE_Y));

#define LOAD_ELEMENT_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                        \
    BASENAME##A = *((__global DATA_TYPE *)(PTR + OFFSET + 10 * STRIDE_Y));

#define LOAD_ELEMENT_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                        \
    BASENAME##B = *((__global DATA_TYPE *)(PTR + OFFSET + 11 * STRIDE_Y));

#define LOAD_ELEMENT_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                        \
    BASENAME##C = *((__global DATA_TYPE *)(PTR + OFFSET + 12 * STRIDE_Y));

#define LOAD_ELEMENT_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                        \
    BASENAME##D = *((__global DATA_TYPE *)(PTR + OFFSET + 13 * STRIDE_Y));

#define LOAD_ELEMENT_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                        \
    BASENAME##E = *((__global DATA_TYPE *)(PTR + OFFSET + 14 * STRIDE_Y));

#define LOAD_ELEMENT_16(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \
    LOAD_ELEMENT_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)     \
    VEC_DATA_TYPE(DATA_TYPE, N0)                                        \
    BASENAME##F = *((__global DATA_TYPE *)(PTR + OFFSET + 15 * STRIDE_Y));

/** @}*/ // end of group LOAD_ELEMENT_n

/** Load Scalar as Vector (consecutive elements).
 * @name LOAD_SCALAR_AS_VECTOR
 *
 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16
 * The data to load is expected to have consecutive names for each row.
 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2.
 *
 * @param[in] M0        The number of consecutive rows
 * @param[in] N0        The number of consecutive columns
 * @param[in] DATA_TYPE The data type of the target
 * @param[in] BASENAME  The basename of the result variables
 * @param[in] PTR       The base pointer for the data
 * @param[in] OFFSET    The offset within a row
 * @param[in] STRIDE_Y  The stride in y-axis direction
 * @{
 */
#define LOAD_SCALAR_AS_VECTOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) LOAD_ELEMENT_##M0(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)
#define LOAD_SCALAR_AS_VECTOR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) LOAD_SCALAR_AS_VECTOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)
/** @} */ // end of group LOAD_SCALAR_AS_VECTOR
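
// Illustrative sketch (not part of the library): LOAD_SCALAR_AS_VECTOR reads a
// single scalar per row and assigns it to a vector variable of size N0; OpenCL
// implicitly replicates the scalar across all lanes. For example,
//
//     LOAD_SCALAR_AS_VECTOR(2, 4, float, c, src_addr, 0, src_stride_y);
//
// declares float4 c0 and c1, each holding one broadcast scalar read from rows
// 0 and 1 respectively.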

/** Basic macros to calculate Z offset values from Z0 to Zn-1
 * @name CALCULATE_Z_OFFSET_n
 *
 * @param[in] M0              The number of offset values to calculate
 * @param[in] DATA_TYPE       The data type of the results
 * @param[in] Z               The basename of the result variables
 * @param[in] Y               The work-item ID of y-axis
 * @param[in] HEIGHT_GEMM3D   The height of GEMM3D
 * @param[in] DEPTH_GEMM3D    The depth of GEMM3D
 * @param[in] CROSS_PLANE_PAD The padding required for plane changes across the z-dimension
 * @param[in] STRIDE_Y        The stride value in y-axis direction
 *
 * @{
 */
#define CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##0 = (0 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##0 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##0);                                                      \
    Z##0 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)     \
    Z##1 = (1 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##1 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##1);                                                      \
    Z##1 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)     \
    Z##2 = (2 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##2 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##2);                                                      \
    Z##2 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)     \
    Z##3 = (3 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##3 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##3);                                                      \
    Z##3 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)     \
    Z##4 = (4 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##4 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##4);                                                      \
    Z##4 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)     \
    Z##5 = (5 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##5 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##5);                                                      \
    Z##5 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)     \
    Z##6 = (6 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##6 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##6);                                                      \
    Z##6 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_8(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)     \
    Z##7 = (7 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D;                                               \
    Z##7 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##7);                                                      \
    Z##7 *= (CROSS_PLANE_PAD * STRIDE_Y);

/** @} */ // end of group CALCULATE_Z_OFFSET_n

/** Calculate Z offset values from Z0 to Zn-1
 * @name CALCULATE_Z_OFFSET
 *
 * The Z offsets are expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected names of Z offsets are zin0, zin1 and zin2.
 * Note that CROSS_PLANE_PAD (cross-plane padding) is required to take into account
 * the possible cross-plane paddings in case the plane changes across the z-dimension.
 *
 * <!--
 * |                  |
 * |      plane0      |
 * |                  |
 * |__________________|
 * |******************|
 * |  cross_plane_pad |
 * |******************|
 * |                  |
 * |      plane1      |
 * |                  |
 * |__________________|
 * -->
 *
 * @param[in] M0              The number of offset values to calculate
 * @param[in] DATA_TYPE       The data type of the results
 * @param[in] Z               The basename of the result variables
 * @param[in] Y               The work-item ID of y-axis
 * @param[in] HEIGHT_GEMM3D   The height of GEMM3D
 * @param[in] DEPTH_GEMM3D    The depth of GEMM3D
 * @param[in] CROSS_PLANE_PAD The padding required for plane changes across the z-dimension
 * @param[in] STRIDE_Y        The stride value in y-axis direction
 * @{
 */
#define CALCULATE_Z_OFFSET_STR(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) CALCULATE_Z_OFFSET_##M0(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)
#define CALCULATE_Z_OFFSET(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) CALCULATE_Z_OFFSET_STR(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)
/** @} */ // end of group CALCULATE_Z_OFFSET
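
// Illustrative worked example (not part of the library): with HEIGHT_GEMM3D=4,
// DEPTH_GEMM3D=2 and uint offsets,
//
//     CALCULATE_Z_OFFSET(2, uint, zin, y, 4, 2, pad, stride_y);
//
// computes zin0 = min((y + 0) / 4, 1) * (pad * stride_y) and
// zin1 = min((y + 1) / 4, 1) * (pad * stride_y), so a row that falls on the
// next z-plane is shifted past the cross-plane padding.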

/** Scale the rows in the given variables (BASENAME0 to BASENAMEn-1)
 * @name SCALE_ROW_n
 *
 * @param[in] DATA_TYPE The data type of the variables
 * @param[in] BASENAME  The basename of the variables
 * @param[in] SCALE     The scale factor
 * @{
 */
#define SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##0 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##1 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##2 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##3 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##4 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##5 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##6 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##7 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##8 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE)      \
    BASENAME##9 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##A *= (DATA_TYPE)SCALE;

#define SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##B *= (DATA_TYPE)SCALE;

#define SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##C *= (DATA_TYPE)SCALE;

#define SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##D *= (DATA_TYPE)SCALE;

#define SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##E *= (DATA_TYPE)SCALE;

#define SCALE_ROW_16(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE)     \
    BASENAME##F *= (DATA_TYPE)SCALE;
/** @} */ // end of group SCALE_ROW_n

/** Scale elements stored in a block (BASENAME)
 * @name SCALE_BLOCK
 *
 * Supported cases are N=1,2,3,...,16
 *
 * @param[in] N         The number of rows in the block
 * @param[in] DATA_TYPE The data type of the block
 * @param[in] BASENAME  The basename of the block
 * @param[in] SCALE     The scale factor
 * @{
 */
#define SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE) SCALE_ROW_##N(DATA_TYPE, BASENAME, SCALE)
#define SCALE_BLOCK(N, DATA_TYPE, BASENAME, SCALE) SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE)
/** @} */ // end of group SCALE_BLOCK
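
// Illustrative sketch (not part of the library): SCALE_BLOCK scales every row
// variable in place, e.g.
//
//     SCALE_BLOCK(3, float, c, 0.5f); // c0 *= 0.5f; c1 *= 0.5f; c2 *= 0.5f;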

/** Create a new vector containing the values at the given index for a set of given vectors
 * @name COLUMN_VECTORn
 *
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] X        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 * @{
 */
#define COLUMN_VECTOR1(IDX_COL, BASENAME, X, TYPE) \
    TYPE BASENAME##IDX_COL = (TYPE)((X##0).s##IDX_COL);
#define COLUMN_VECTOR2(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 2)                         \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 2))((X##0).s##IDX_COL, (X##1).s##IDX_COL);
#define COLUMN_VECTOR3(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 3)                         \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 3))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL);
#define COLUMN_VECTOR4(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 4)                         \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 4))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL);
#define COLUMN_VECTOR8(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 8)                         \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 8))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL);
#define COLUMN_VECTOR16(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 16)                         \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 16))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL, (X##8).s##IDX_COL, (X##9).s##IDX_COL, (X##A).s##IDX_COL, (X##B).s##IDX_COL, (X##C).s##IDX_COL, (X##D).s##IDX_COL, (X##E).s##IDX_COL, (X##F).s##IDX_COL);
/** @} */ // end of group COLUMN_VECTORn
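
// Illustrative sketch (not part of the library): with float4 sources b0..b3,
//
//     COLUMN_VECTOR4(1, t, b, float);
//
// declares float4 t1 = (float4)(b0.s1, b1.s1, b2.s1, b3.s1), i.e. column 1 of
// the 4x4 block gathered into a single vector.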

/** Create a new vector containing the values at the given index. Utility macros for transposing a column-vector
 * @name COLUMN_VECTOR_SCALARn
 *
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] X        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 * @{
 */
#define COLUMN_VECTOR_SCALAR1(IDX_COL, BASENAME, X, TYPE) \
    TYPE BASENAME##IDX_COL = (TYPE)((X##0));
#define COLUMN_VECTOR_SCALAR2(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 2)                                \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 2))((X##0), (X##1));
#define COLUMN_VECTOR_SCALAR3(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 3)                                \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 3))((X##0), (X##1), (X##2));
#define COLUMN_VECTOR_SCALAR4(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 4)                                \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 4))((X##0), (X##1), (X##2), (X##3));
#define COLUMN_VECTOR_SCALAR8(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 8)                                \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 8))((X##0), (X##1), (X##2), (X##3), (X##4), (X##5), (X##6), (X##7));
#define COLUMN_VECTOR_SCALAR16(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 16)                                \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 16))((X##0), (X##1), (X##2), (X##3), (X##4), (X##5), (X##6), (X##7), (X##8), (X##9), (X##A), (X##B), (X##C), (X##D), (X##E), (X##F));
/** @} */ // end of group COLUMN_VECTOR_SCALARn

/** Create transposed vectors of the given vectors
 * @name TRANSPOSE_K0Xn
 *
 * @param[in] K0       The size of the source vectors
 * @param[in] BASENAME The basename of transposed vectors
 * @param[in] B        The basename of source vectors for transposition
 * @param[in] TYPE     The data type of the transposed vectors
 * @{
 */
#define TRANSPOSE_K0X1(K0, BASENAME, B, TYPE) \
    COLUMN_VECTOR_SCALAR(K0, 0, BASENAME, B, TYPE);
#define TRANSPOSE_K0X2(K0, BASENAME, B, TYPE) \
    COLUMN_VECTOR(K0, 0, BASENAME, B, TYPE);  \
    COLUMN_VECTOR(K0, 1, BASENAME, B, TYPE);
#define TRANSPOSE_K0X3(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X2(K0, BASENAME, B, TYPE);    \
    COLUMN_VECTOR(K0, 2, BASENAME, B, TYPE);
#define TRANSPOSE_K0X4(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X3(K0, BASENAME, B, TYPE);    \
    COLUMN_VECTOR(K0, 3, BASENAME, B, TYPE);
#define TRANSPOSE_K0X8(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X4(K0, BASENAME, B, TYPE);    \
    COLUMN_VECTOR(K0, 4, BASENAME, B, TYPE);  \
    COLUMN_VECTOR(K0, 5, BASENAME, B, TYPE);  \
    COLUMN_VECTOR(K0, 6, BASENAME, B, TYPE);  \
    COLUMN_VECTOR(K0, 7, BASENAME, B, TYPE);
#define TRANSPOSE_K0X16(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X8(K0, BASENAME, B, TYPE);     \
    COLUMN_VECTOR(K0, 8, BASENAME, B, TYPE);   \
    COLUMN_VECTOR(K0, 9, BASENAME, B, TYPE);   \
    COLUMN_VECTOR(K0, A, BASENAME, B, TYPE);   \
    COLUMN_VECTOR(K0, B, BASENAME, B, TYPE);   \
    COLUMN_VECTOR(K0, C, BASENAME, B, TYPE);   \
    COLUMN_VECTOR(K0, D, BASENAME, B, TYPE);   \
    COLUMN_VECTOR(K0, E, BASENAME, B, TYPE);   \
    COLUMN_VECTOR(K0, F, BASENAME, B, TYPE);

/** @} */ // end of group TRANSPOSE_K0Xn

/** Create column vectors to contain the values at the given index for a set of given vectors
 *
 * @param[in] K0       The number of source vectors
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] B        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 */
#define COLUMN_VECTOR(K0, IDX_COL, BASENAME, B, TYPE) \
    CONCAT(COLUMN_VECTOR, K0)                         \
    (IDX_COL, BASENAME, B, TYPE);

/** Create column vectors to contain the values at the given index. Utility macro for transposing a column-vector
 *
 * @param[in] K0       The number of source vectors
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] B        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 */
#define COLUMN_VECTOR_SCALAR(K0, IDX_COL, BASENAME, B, TYPE) \
    CONCAT(COLUMN_VECTOR_SCALAR, K0)                         \
    (IDX_COL, BASENAME, B, TYPE);

/** Create transposed vectors from the given source vectors
 *
 * @param[in] K0       The size of source vectors
 * @param[in] N0       The number of source vectors
 * @param[in] BASENAME The basename of transposed vectors
 * @param[in] B        The basename of source vectors for transposition
 * @param[in] TYPE     The data type of the transposed vectors
 *
 */
#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B, TYPE) \
    CONCAT(TRANSPOSE_K0X, N0)                      \
    (K0, BASENAME, B, TYPE);

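// Illustrative sketch (not part of the library): TRANSPOSE_K0XN0 gathers K0
// source vectors into N0 column vectors of size K0. E.g., with float2 sources
// b0..b3,
//
//     TRANSPOSE_K0XN0(4, 2, t, b, float);
//
// declares float4 t0 = (float4)(b0.s0, b1.s0, b2.s0, b3.s0) and
// float4 t1 = (float4)(b0.s1, b1.s1, b2.s1, b3.s1).
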
/** Add the variables (BIAS0 to BIASn-1) to the others (BASENAME0 to BASENAMEn-1)
 * @name ADD_ROW_n
 *
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The basename of the added variables
 * @{
 */
#define ADD_ROW_1(BASENAME, BIAS) \
    BASENAME##0 += BIAS##0;

#define ADD_ROW_2(BASENAME, BIAS) \
    ADD_ROW_1(BASENAME, BIAS)     \
    BASENAME##1 += BIAS##1;

#define ADD_ROW_3(BASENAME, BIAS) \
    ADD_ROW_2(BASENAME, BIAS)     \
    BASENAME##2 += BIAS##2;

#define ADD_ROW_4(BASENAME, BIAS) \
    ADD_ROW_3(BASENAME, BIAS)     \
    BASENAME##3 += BIAS##3;

#define ADD_ROW_5(BASENAME, BIAS) \
    ADD_ROW_4(BASENAME, BIAS)     \
    BASENAME##4 += BIAS##4;

#define ADD_ROW_6(BASENAME, BIAS) \
    ADD_ROW_5(BASENAME, BIAS)     \
    BASENAME##5 += BIAS##5;

#define ADD_ROW_7(BASENAME, BIAS) \
    ADD_ROW_6(BASENAME, BIAS)     \
    BASENAME##6 += BIAS##6;

#define ADD_ROW_8(BASENAME, BIAS) \
    ADD_ROW_7(BASENAME, BIAS)     \
    BASENAME##7 += BIAS##7;

#define ADD_ROW_9(BASENAME, BIAS) \
    ADD_ROW_8(BASENAME, BIAS)     \
    BASENAME##8 += BIAS##8;

#define ADD_ROW_10(BASENAME, BIAS) \
    ADD_ROW_9(BASENAME, BIAS)      \
    BASENAME##9 += BIAS##9;

#define ADD_ROW_11(BASENAME, BIAS) \
    ADD_ROW_10(BASENAME, BIAS)     \
    BASENAME##A += BIAS##A;

#define ADD_ROW_12(BASENAME, BIAS) \
    ADD_ROW_11(BASENAME, BIAS)     \
    BASENAME##B += BIAS##B;

#define ADD_ROW_13(BASENAME, BIAS) \
    ADD_ROW_12(BASENAME, BIAS)     \
    BASENAME##C += BIAS##C;

#define ADD_ROW_14(BASENAME, BIAS) \
    ADD_ROW_13(BASENAME, BIAS)     \
    BASENAME##D += BIAS##D;

#define ADD_ROW_15(BASENAME, BIAS) \
    ADD_ROW_14(BASENAME, BIAS)     \
    BASENAME##E += BIAS##E;

#define ADD_ROW_16(BASENAME, BIAS) \
    ADD_ROW_15(BASENAME, BIAS)     \
    BASENAME##F += BIAS##F;

/** @} */ // end of group ADD_ROW_n

/** Add the block (BIAS) to another block (BASENAME)
 * @name ADD_BLOCK
 *
 * Supported cases are N=1,2,3,...,16
 *
 * @param[in] N        The number of vectors in the block
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The basename of the added variables
 * @{
 */
#define ADD_BLOCK_STR(N, BASENAME, BIAS) ADD_ROW_##N(BASENAME, BIAS)
#define ADD_BLOCK(N, BASENAME, BIAS) ADD_BLOCK_STR(N, BASENAME, BIAS)
/** @} */ // end of group ADD_BLOCK
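
// Illustrative sketch (not part of the library): ADD_BLOCK adds row-wise, e.g.
//
//     ADD_BLOCK(2, c, bias); // c0 += bias0; c1 += bias1;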

/** Broadcast (add single value) to each element of the destination variables
 * @name ADD_ROW_BROADCAST_n
 *
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The variable containing the value to add
 * @{
 */
#define ADD_ROW_BROADCAST_1(BASENAME, BIAS) \
    BASENAME##0 += BIAS;

#define ADD_ROW_BROADCAST_2(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_1(BASENAME, BIAS)     \
    BASENAME##1 += BIAS;

#define ADD_ROW_BROADCAST_3(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_2(BASENAME, BIAS)     \
    BASENAME##2 += BIAS;

#define ADD_ROW_BROADCAST_4(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_3(BASENAME, BIAS)     \
    BASENAME##3 += BIAS;

#define ADD_ROW_BROADCAST_5(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_4(BASENAME, BIAS)     \
    BASENAME##4 += BIAS;

#define ADD_ROW_BROADCAST_6(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_5(BASENAME, BIAS)     \
    BASENAME##5 += BIAS;

#define ADD_ROW_BROADCAST_7(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_6(BASENAME, BIAS)     \
    BASENAME##6 += BIAS;

#define ADD_ROW_BROADCAST_8(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_7(BASENAME, BIAS)     \
    BASENAME##7 += BIAS;

#define ADD_ROW_BROADCAST_9(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_8(BASENAME, BIAS)     \
    BASENAME##8 += BIAS;

#define ADD_ROW_BROADCAST_10(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_9(BASENAME, BIAS)      \
    BASENAME##9 += BIAS;

#define ADD_ROW_BROADCAST_11(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_10(BASENAME, BIAS)     \
    BASENAME##A += BIAS;

#define ADD_ROW_BROADCAST_12(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_11(BASENAME, BIAS)     \
    BASENAME##B += BIAS;

#define ADD_ROW_BROADCAST_13(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_12(BASENAME, BIAS)     \
    BASENAME##C += BIAS;

#define ADD_ROW_BROADCAST_14(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_13(BASENAME, BIAS)     \
    BASENAME##D += BIAS;

#define ADD_ROW_BROADCAST_15(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_14(BASENAME, BIAS)     \
    BASENAME##E += BIAS;

#define ADD_ROW_BROADCAST_16(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_15(BASENAME, BIAS)     \
    BASENAME##F += BIAS;

/** @} */ // end of group ADD_ROW_BROADCAST_n

/** Broadcast (add a value) to each element of the destination block (BASENAME)
 * @name ADD_BLOCK_BROADCAST
 *
 * Supported cases are N=1,2,3,...,16.
 *
 * @param[in] N        The number of vectors in the block
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The variable containing the value to add
 * @{
 */
#define ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS) ADD_ROW_BROADCAST_##N(BASENAME, BIAS)
#define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS) ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS)
/** @} */ // end of group ADD_BLOCK_BROADCAST
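
// Illustrative sketch (not part of the library): the broadcast variant adds
// one and the same value to every row, e.g.
//
//     ADD_BLOCK_BROADCAST(2, c, b0); // c0 += b0; c1 += b0;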
4128
4129/** Apply activation to the given variables
4130 * @name ACTIVATION_ROW_n
4131 *
4132 * @param[in] ACTIVATION_TYPE The type of the activation
4133 * @param[in] DATA_TYPE       The data type of the vectors
4134 * @param[in] BASENAME        The basename of the variables
4135 * @param[in] A_VAL           Additional value required by the activation
4136 * @param[in] B_VAL           Additional value required by the activation
4137 * @{
4138 */
4139#define ACTIVATION_ROW_1(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4140    BASENAME##0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##0, A_VAL, B_VAL);
4141
4142#define ACTIVATION_ROW_2(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4143    ACTIVATION_ROW_1(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4144    BASENAME##1 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##1, A_VAL, B_VAL);
4145
4146#define ACTIVATION_ROW_3(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4147    ACTIVATION_ROW_2(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4148    BASENAME##2 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##2, A_VAL, B_VAL);
4149
4150#define ACTIVATION_ROW_4(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4151    ACTIVATION_ROW_3(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4152    BASENAME##3 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##3, A_VAL, B_VAL);
4153
4154#define ACTIVATION_ROW_5(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4155    ACTIVATION_ROW_4(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4156    BASENAME##4 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##4, A_VAL, B_VAL);
4157
4158#define ACTIVATION_ROW_6(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4159    ACTIVATION_ROW_5(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4160    BASENAME##5 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##5, A_VAL, B_VAL);
4161
4162#define ACTIVATION_ROW_7(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4163    ACTIVATION_ROW_6(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4164    BASENAME##6 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##6, A_VAL, B_VAL);
4165
4166#define ACTIVATION_ROW_8(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4167    ACTIVATION_ROW_7(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4168    BASENAME##7 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##7, A_VAL, B_VAL);
4169
4170#define ACTIVATION_ROW_9(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4171    ACTIVATION_ROW_8(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4172    BASENAME##8 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##8, A_VAL, B_VAL);
4173
4174#define ACTIVATION_ROW_10(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4175    ACTIVATION_ROW_9(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)      \
4176    BASENAME##9 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##9, A_VAL, B_VAL);
4177
4178#define ACTIVATION_ROW_11(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4179    ACTIVATION_ROW_10(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4180    BASENAME##A = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##A, A_VAL, B_VAL);
4181
4182#define ACTIVATION_ROW_12(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4183    ACTIVATION_ROW_11(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4184    BASENAME##B = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##B, A_VAL, B_VAL);
4185
4186#define ACTIVATION_ROW_13(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4187    ACTIVATION_ROW_12(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4188    BASENAME##C = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##C, A_VAL, B_VAL);
4189
4190#define ACTIVATION_ROW_14(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4191    ACTIVATION_ROW_13(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4192    BASENAME##D = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##D, A_VAL, B_VAL);
4193
4194#define ACTIVATION_ROW_15(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4195    ACTIVATION_ROW_14(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4196    BASENAME##E = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##E, A_VAL, B_VAL);
4197
4198#define ACTIVATION_ROW_16(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
4199    ACTIVATION_ROW_15(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)     \
4200    BASENAME##F = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##F, A_VAL, B_VAL);
4201/** @} */ // end of group ACTIVATION_ROW_n
4202
4203/** Apply activation to a block (BASENAME)
4204 * @name ACTIVATION_BLOCK
4205 *
4206 * Supported cases are N=1,2,3,...,16.
4207 *
4208 * @param[in] N               The number of vectors in the block
4209 * @param[in] ACTIVATION_TYPE The type of the activation
4210 * @param[in] DATA_TYPE       The data type of the vectors
 * @param[in] VEC_SIZE        The size of the vectors
4211 * @param[in] BASENAME        The basename of the variables
4212 * @param[in] A_VAL           Additional value required by the activation
4213 * @param[in] B_VAL           Additional value required by the activation
4214 * @{
4215 */
4216#define ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) ACTIVATION_ROW_##N(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)
4217#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)
4218/** @} */ // end of group ACTIVATION_BLOCK
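/** The _STR indirection above forces the macro arguments to be expanded (e.g. when N is
 * itself a build-time macro such as M0) before N is token-pasted into ACTIVATION_ROW_##N.
 * The same two-level pattern is used by all the *_BLOCK dispatch macros in this file.
 *
 * E.g., (illustrative sketch; the ACTIVATION() helper itself is defined elsewhere)
 *
 *     ACTIVATION_BLOCK(2, RELU, float, 4, c, A_VAL, B_VAL)
 *
 * expands to:
 *
 *     c0 = ACTIVATION(RELU, float, 4, c0, A_VAL, B_VAL);
 *     c1 = ACTIVATION(RELU, float, 4, c1, A_VAL, B_VAL);
 */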
4219
4220/** Apply convert_<data_type> to the given variables
4221 * @name CONVERT_ROW_n
4222 *
4223 * @param[in] N            The size of the vectors
4224 * @param[in] DATA_TYPE    The data type of the vectors
4225 * @param[in] BASENAME_SRC The basename of the source variables
4226 * @param[in] BASENAME_DST The basename of the destination variables
4227 * @{
 */
4228#define CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4229    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4230    BASENAME_DST##0 = CONVERT(BASENAME_SRC##0, VEC_DATA_TYPE(DATA_TYPE, N));
4231
4232#define CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4233    CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4234    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4235    BASENAME_DST##1 = CONVERT(BASENAME_SRC##1, VEC_DATA_TYPE(DATA_TYPE, N));
4236
4237#define CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4238    CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4239    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4240    BASENAME_DST##2 = CONVERT(BASENAME_SRC##2, VEC_DATA_TYPE(DATA_TYPE, N));
4241
4242#define CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4243    CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4244    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4245    BASENAME_DST##3 = CONVERT(BASENAME_SRC##3, VEC_DATA_TYPE(DATA_TYPE, N));
4246
4247#define CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4248    CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4249    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4250    BASENAME_DST##4 = CONVERT(BASENAME_SRC##4, VEC_DATA_TYPE(DATA_TYPE, N));
4251
4252#define CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4253    CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4254    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4255    BASENAME_DST##5 = CONVERT(BASENAME_SRC##5, VEC_DATA_TYPE(DATA_TYPE, N));
4256
4257#define CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4258    CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4259    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4260    BASENAME_DST##6 = CONVERT(BASENAME_SRC##6, VEC_DATA_TYPE(DATA_TYPE, N));
4261
4262#define CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4263    CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4264    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4265    BASENAME_DST##7 = CONVERT(BASENAME_SRC##7, VEC_DATA_TYPE(DATA_TYPE, N));
4266
4267#define CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4268    CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4269    VEC_DATA_TYPE(DATA_TYPE, N)                                 \
4270    BASENAME_DST##8 = CONVERT(BASENAME_SRC##8, VEC_DATA_TYPE(DATA_TYPE, N));
4271
4272#define CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4273    CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)      \
4274    VEC_DATA_TYPE(DATA_TYPE, N)                                  \
4275    BASENAME_DST##9 = CONVERT(BASENAME_SRC##9, VEC_DATA_TYPE(DATA_TYPE, N));
4276
4277#define CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4278    CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4279    VEC_DATA_TYPE(DATA_TYPE, N)                                  \
4280    BASENAME_DST##A = CONVERT(BASENAME_SRC##A, VEC_DATA_TYPE(DATA_TYPE, N));
4281
4282#define CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4283    CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4284    VEC_DATA_TYPE(DATA_TYPE, N)                                  \
4285    BASENAME_DST##B = CONVERT(BASENAME_SRC##B, VEC_DATA_TYPE(DATA_TYPE, N));
4286
4287#define CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4288    CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4289    VEC_DATA_TYPE(DATA_TYPE, N)                                  \
4290    BASENAME_DST##C = CONVERT(BASENAME_SRC##C, VEC_DATA_TYPE(DATA_TYPE, N));
4291
4292#define CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4293    CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4294    VEC_DATA_TYPE(DATA_TYPE, N)                                  \
4295    BASENAME_DST##D = CONVERT(BASENAME_SRC##D, VEC_DATA_TYPE(DATA_TYPE, N));
4296
4297#define CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4298    CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4299    VEC_DATA_TYPE(DATA_TYPE, N)                                  \
4300    BASENAME_DST##E = CONVERT(BASENAME_SRC##E, VEC_DATA_TYPE(DATA_TYPE, N));
4301
4302#define CONVERT_ROW_16(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
4303    CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)     \
4304    VEC_DATA_TYPE(DATA_TYPE, N)                                  \
4305    BASENAME_DST##F = CONVERT(BASENAME_SRC##F, VEC_DATA_TYPE(DATA_TYPE, N));
4306/** @} */ // end of group CONVERT_ROW_n
4307
4308/** Apply convert_<data_type> to a block (BASENAME_SRC) and save to another block (BASENAME_DST)
4309 * @name CONVERT_BLOCK
4310 *
4311 * Supported cases are M=1,2,3,...,16 and N=1,2,3,4,8,16.
4312 *
4313 * @param[in] M            The number of vectors to convert
4314 * @param[in] N            The size of the vectors
4315 * @param[in] DATA_TYPE    The data type of the vectors
4316 * @param[in] BASENAME_SRC The basename of the source variables
4317 * @param[in] BASENAME_DST The basename of the destination variables
4318 * @{
 */
4319#define CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_ROW_##M(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)
4320#define CONVERT_BLOCK(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)
4321/** @} */ // end of group CONVERT_BLOCK
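/** E.g., (illustrative sketch; it assumes the usual helper definitions, declared elsewhere
 * in this file, where VEC_DATA_TYPE(half, 4) is half4 and CONVERT(x, half4) is convert_half4(x))
 *
 *     CONVERT_BLOCK(2, 4, half, src, dst)
 *
 * expands to:
 *
 *     half4 dst0 = convert_half4(src0);
 *     half4 dst1 = convert_half4(src1);
 */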
4345#ifndef ARM_COMPUTE_REPEAT_H
4346#define ARM_COMPUTE_REPEAT_H
4347
4371#ifndef ARM_COMPUTE_HELPER_H
4372#define ARM_COMPUTE_HELPER_H
4373
4397
4398/** Store the 0th to (n-1)th rows of the given variables
4399 * @name STORE_ROW_n
4400 *
4401 * @param[in] N0        The width of the passed-in vector. Supported: 1, 2, 3, 4, 8, 16
4402 * @param[in] DATA_TYPE The data type of the vectors
4403 * @param[in] BASENAME  The basename of the variables
4404 * @param[in] PTR       The base pointer
4405 * @param[in] STRIDE_Y  The stride value in y-axis direction
4406 * @param[in] Z         The offset in z-axis direction
4407 * @{
4408 */
4409#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4410    VSTORE(N0)                                                 \
4411    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
4412
4413#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4414    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4415    VSTORE(N0)                                                 \
4416    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
4417
4418#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4419    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4420    VSTORE(N0)                                                 \
4421    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
4422
4423#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4424    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4425    VSTORE(N0)                                                 \
4426    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
4427
4428#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4429    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4430    VSTORE(N0)                                                 \
4431    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
4432
4433#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4434    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4435    VSTORE(N0)                                                 \
4436    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
4437
4438#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4439    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4440    VSTORE(N0)                                                 \
4441    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
4442
4443#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4444    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4445    VSTORE(N0)                                                 \
4446    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
4447
4448#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4449    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4450    VSTORE(N0)                                                 \
4451    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
4452
4453#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4454    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
4455    VSTORE(N0)                                                  \
4456    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
4457
4458#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4459    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4460    VSTORE(N0)                                                  \
4461    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
4462
4463#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4464    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4465    VSTORE(N0)                                                  \
4466    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
4467
4468#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4469    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4470    VSTORE(N0)                                                  \
4471    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
4472
4473#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4474    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4475    VSTORE(N0)                                                  \
4476    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
4477
4478#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4479    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4480    VSTORE(N0)                                                  \
4481    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
4482
4483#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4484    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4485    VSTORE(N0)                                                  \
4486    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
4487/** @} */ // end of group STORE_ROW_n
4488
4489/** Convert and store the 0th to (n-1)th rows of the given variables
4490 * @name CONVERT_STORE_ROW_n
4491 *
4492 * @param[in] N0        The size of the vectors
4493 * @param[in] DATA_TYPE The data type of the vectors
4494 * @param[in] BASENAME  The basename of the variables
4495 * @param[in] PTR       The base pointer
4496 * @param[in] STRIDE_Y  The stride value in y-axis direction
4497 * @param[in] Z         The offset in z-axis direction
4498 * @{
4499 */
4500#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4501    VSTORE(N0)                                                         \
4502    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
4503
4504#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4505    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4506    VSTORE(N0)                                                         \
4507    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
4508
4509#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4510    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4511    VSTORE(N0)                                                         \
4512    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
4513
4514#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4515    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4516    VSTORE(N0)                                                         \
4517    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
4518
4519#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4520    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4521    VSTORE(N0)                                                         \
4522    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
4523
4524#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4525    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4526    VSTORE(N0)                                                         \
4527    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
4528
4529#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4530    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4531    VSTORE(N0)                                                         \
4532    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
4533
4534#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4535    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4536    VSTORE(N0)                                                         \
4537    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
4538
4539#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4540    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4541    VSTORE(N0)                                                         \
4542    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
4543
4544#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4545    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
4546    VSTORE(N0)                                                          \
4547    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
4548
4549#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4550    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4551    VSTORE(N0)                                                          \
4552    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
4553
4554#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4555    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4556    VSTORE(N0)                                                          \
4557    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
4558
4559#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4560    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4561    VSTORE(N0)                                                          \
4562    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
4563
4564#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4565    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4566    VSTORE(N0)                                                          \
4567    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
4568
4569#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4570    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4571    VSTORE(N0)                                                          \
4572    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
4573
4574#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4575    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4576    VSTORE(N0)                                                          \
4577    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
4578
4579/** @} */ // end of group CONVERT_STORE_ROW_n
4580
4581/** Store a block of the given size M0xN0
4582 * @name STORE_BLOCK
4583 *
4584 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16.
4585 * The data to store is expected to have consecutive names for each row.
4586 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
4587 * The Z offset is expected to have consecutive names.
4588 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
4589 *
4590 * @param[in] M0        The number of rows to store
4591 * @param[in] N0        The size of each vector
4592 * @param[in] DATA_TYPE The data type of the vectors
4593 * @param[in] BASENAME  The basename of the variables
4594 * @param[in] PTR       The base pointer
4595 * @param[in] STRIDE_Y  The stride value in y-axis direction
4596 * @param[in] Z         The offset in z-axis direction
4597 * @{
4598 */
4599#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
4600#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
4601/** @} */ // end of group STORE_BLOCK
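/** E.g., with VSTORE as defined later in this file,
 *
 *     STORE_BLOCK(2, 4, float, c, dst_addr, dst_stride_y, zout)
 *
 * expands to:
 *
 *     vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout0));
 *     vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout1));
 */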
4602
4603/** Convert and store a block of the given size M0xN0
4604 * @name CONVERT_STORE_BLOCK
4605 *
4606 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16.
4607 * The data to store is expected to have consecutive names for each row.
4608 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
4609 * The Z offset is expected to have consecutive names.
4610 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
4611 *
4612 * @param[in] M0        The number of rows to store
4613 * @param[in] N0        The size of each vector
4614 * @param[in] DATA_TYPE The data type of the vectors
4615 * @param[in] BASENAME  The basename of the variables
4616 * @param[in] PTR       The base pointer
4617 * @param[in] STRIDE_Y  The stride value in y-axis direction
4618 * @param[in] Z         The offset in z-axis direction
4619 * @{
4620 */
4621#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
4622#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
4623/** @} */ // end of group CONVERT_STORE_BLOCK
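/** E.g., (sketch; CONVERT_SAT is defined elsewhere and is assumed to map to the
 * saturating convert_<type>_sat built-ins)
 *
 *     CONVERT_STORE_BLOCK(1, 4, uchar, c, dst_addr, dst_stride_y, zout)
 *
 * expands to:
 *
 *     vstore4(convert_uchar4_sat((c0)), 0, (__global uchar *)(dst_addr + 0 * dst_stride_y + zout0));
 */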
4624
4625/** Partially store the 0th to (n-1)th rows of the given variables
4626 * @name STORE_ROW_PARTIAL_n
4627 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
4628 *
4629 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
4630 *
4631 * @param[in] N0        The width of the passed-in vector. Supported: 1, 2, 3, 4, 8, 16
4632 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
4633 * @param[in] DATA_TYPE The data type of the vectors
4634 * @param[in] BASENAME  The basename of the variables
4635 * @param[in] PTR       The base pointer
4636 * @param[in] STRIDE_Y  The stride value in y-axis direction
4637 * @param[in] Z         The offset in z-axis direction
4638 * @{
4639 */
4640#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4641    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4642    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
4643
4644#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4645    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4646    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4647    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
4648
4649#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4650    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4651    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4652    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
4653
4654#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4655    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4656    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4657    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
4658
4659#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4660    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4661    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4662    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
4663
4664#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4665    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4666    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4667    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
4668
4669#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4670    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4671    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4672    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
4673
4674#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4675    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4676    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4677    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
4678
4679#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4680    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4681    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
4682    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
4683
4684#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4685    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
4686    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
4687    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
4688
4689#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4690    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4691    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
4692    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
4693
4694#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4695    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4696    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
4697    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
4698
4699#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4700    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4701    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
4702    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
4703
4704#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4705    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4706    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
4707    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
4708
4709#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4710    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4711    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
4712    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
4713
4714#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
4715    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
4716    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
4717    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
4718/** @} */ // end of group STORE_ROW_PARTIAL_n
4719
4720/** Partially store a block of the given size STORE_M0xSTORE_N0
4721 * @name STORE_BLOCK_PARTIAL
4722 *
4723 * @note The vector width @p N0 is also required for correct partial storing behaviour.
4724 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
4725 *
4726 * The data to store is expected to have consecutive names for each row.
4727 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
4728 * The Z offset is expected to have consecutive names.
4729 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
4730 *
4731 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
4732 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
4733 * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
4734 * @param[in] DATA_TYPE The data type of the vectors
4735 * @param[in] BASENAME  The basename of the variables
4736 * @param[in] PTR       The base pointer
4737 * @param[in] STRIDE_Y  The stride value in y-axis direction
4738 * @param[in] Z         The offset in z-axis direction
4739 * @{
4740 */
4741#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
4742#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
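/** E.g., (sketch; the vstore_partial_* helpers are defined further down in this file)
 *
 *     STORE_BLOCK_PARTIAL(2, 3, 4, float, c, dst_addr, dst_stride_y, zout)
 *
 * expands to:
 *
 *     vstore_partial_4_3(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout0));
 *     vstore_partial_4_3(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout1));
 *
 * i.e. only the lower 3 elements of each float4 row are written.
 */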
4743/** Store a block that can be partial in both x and y dimensions
4744 *
4745 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
4746 *
4747 * The data to store is expected to have consecutive names for each row.
4748 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
4749 * The Z offset is expected to have consecutive names.
4750 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
4751 *
4752 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
4753 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
4754 * @param[in] DATA_TYPE        The data type of the vectors
4755 * @param[in] BASENAME         The basename of the variables
4756 * @param[in] PTR              The base pointer
4757 * @param[in] STRIDE_Y         The stride value in y-axis direction
4758 * @param[in] Z                The offset in z-axis direction
4759 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
4760 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
4761 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
4762 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
4763 */
4764#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
4765    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y))                                                                                                            \
4766    {                                                                                                                                                     \
4767        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                                           \
4768    }                                                                                                                                                     \
4769    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X))                                                                                                        \
4770    {                                                                                                                                                     \
4771        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
4772    }                                                                                                                                                     \
4773    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X))                                                                                                        \
4774    {                                                                                                                                                     \
4775        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
4776    }                                                                                                                                                     \
4777    else                                                                                                                                                  \
4778    {                                                                                                                                                     \
4779        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                               \
4780    }
4781/** Store a block that can only be partial in x but not y.
4782 *
4783 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
4784 *
4785 * The data to store is expected to have consecutive names for each row.
4786 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
4787 * The Z offset is expected to have consecutive names.
4788 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
4789 *
4790 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
4791 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
4792 * @param[in] DATA_TYPE        The data type of the vectors
4793 * @param[in] BASENAME         The basename of the variables
4794 * @param[in] PTR              The base pointer
4795 * @param[in] STRIDE_Y         The stride value in y-axis direction
4796 * @param[in] Z                The offset in z-axis direction
4797 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
4798 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
4799 */
4800#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
4801    if(!(PARTIAL_COND_X))                                                                                         \
4802    {                                                                                                             \
4803        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
4804    }                                                                                                             \
4805    else                                                                                                          \
4806    {                                                                                                             \
4807        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
4808    }
4809/** Store a block that can only be partial in y but not x.
4810 *
4811 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
4812 *
4813 * The data to store is expected to have consecutive names for each row.
4814 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
4815 * The Z offset is expected to have consecutive names.
4816 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
4817 *
4818 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
4819 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
4820 * @param[in] DATA_TYPE        The data type of the vectors
4821 * @param[in] BASENAME         The basename of the variables
4822 * @param[in] PTR              The base pointer
4823 * @param[in] STRIDE_Y         The stride value in y-axis direction
4824 * @param[in] Z                The offset in z-axis direction
4825 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
4826 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
4827 */
4828#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
4829    if(!(PARTIAL_COND_Y))                                                                                         \
4830    {                                                                                                             \
4831        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
4832    }                                                                                                             \
4833    else                                                                                                          \
4834    {                                                                                                             \
4835        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
4836    }
4837/** @} */ // end of group STORE_BLOCK_PARTIAL
4838
4839#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
4840
4841/** Boundary-aware GEMM block store
4842 * @name STORE_BLOCK_BOUNDARY_AWARE
4843 * This macro assumes the following schemes to achieve boundary-awareness:
4844 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
4845 *  - Non-overlapping (normal) load from rhs tensor. This implies rhs can have padding.
4846 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
4847 * The macro then ensures that the dst tensor can be stored without any padding in both x and y dim.
4848 *
4849 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
4850 * blocks **at the end**.
4851 * Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/
4852 * "boundary block" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and its various parameters:
4853 *
4854 *  *--x-->                         x == 0                        x == 1
4855 *  |                  |<------------------------------N-------------------------->|
4856 *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
4857 *  |     -------------#############################################################
4858 *  *     |          | |...............................|...........................|
4859 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
4860 *        |          | |...............................|...........................|
4861 *        M          --#############################################################
4862 *        |          | |                               |...........................|
4863 * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
4864 *        |          | |                               |...........................|
4865 *        |------------#############################################################
4866 *
4867 * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
4868 *
4869 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
4870 *
4871 * It automatically detects whether a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
4872 * and selects the corresponding store method so that the boundary detection logic is only added when needed.
4873 *
4874 * The data to store is expected to have consecutive names for each row.
4875 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
4876 * The Z offset is expected to have consecutive names.
4877 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
4878 *
4879 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
4880 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
4881 * @param[in] DATA_TYPE        The data type of the vectors
4882 * @param[in] BASENAME         The basename of the variables
4883 * @param[in] PTR              The base pointer
4884 * @param[in] STRIDE_Y         The stride value in y-axis direction
4885 * @param[in] Z                The offset in z-axis direction
4886 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
4887 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
4888 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
4889 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
4890 * @{
4891 */
4892#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
4893// Case1: No partial blocks in either x or y
4894#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
4895    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
4896
4897#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
4898// Case2: Partial blocks in y
4899#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
4900    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)
4901
4902#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
4903// Case3: Partial blocks in x
4904#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
4905    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)
4906
4907#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
4908// Case4: Partial blocks in both x and y
4909#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
4910    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)
4911
4912#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
4913
4914#endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
4915/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
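/** E.g., building a kernel with -DPARTIAL_STORE_M0=2 -DPARTIAL_STORE_N0=0 selects Case2
 * above at compile time, so
 *
 *     STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout, 2, 0, cond_y, cond_x)
 *
 * compiles to a single runtime branch on cond_y:
 *
 *     if(!(cond_y)) { STORE_BLOCK_PARTIAL(4, 4, 4, float, c, dst_addr, dst_stride_y, zout); }
 *     else          { STORE_BLOCK_PARTIAL(2, 4, 4, float, c, dst_addr, dst_stride_y, zout); }
 */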
4916
4917#if defined(PARTIAL_STORE_M0)
4918/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
4919 * @name COMPUTE_M0_START_ROW
4920 * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
4921 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
4922 * blocks in the y dimension to avoid any padding.
4923 * EG: M0=4, PARTIAL_STORE_M0=1:
4924 *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
4925 * block 0 (partial)| start row = 0   | start row = 0
4926 * block 1 (full)   | start row = 4   | start row = 1
4927 * block 2 (full)   | start row = 8   | start row = 5
4928 *
4929 * @param[in] y                Global id of current block in y.
4930 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
4931 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
4932 * @{
4933 */
4934#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
4935    ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
4936#else // defined(PARTIAL_STORE_M0)
4937#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
4938    ((uint)(y * M0))
4939#endif    // defined(PARTIAL_STORE_M0)
4940/** @} */ // end of group COMPUTE_M0_START_ROW
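/** E.g., for the table above (M0=4, PARTIAL_STORE_M0=1):
 *
 *     COMPUTE_M0_START_ROW(y, 4, 1) == max(0, 4 * y - ((4 - 1) % 4)) == max(0, 4 * y - 3)
 *
 * which gives start rows 0, 1 and 5 for y = 0, 1, 2, matching the overlapping column.
 */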
4941
4942/** Store a vector that can only be partial in x.
4943 *
4944 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
4945 *
4946 * The data to store is expected to end in a 0.
4947 * E.g., for basename=c, the expected name is c0.
4948 *
4949 * @param[in] basename  The name of the variable without trailing 0
4950 * @param[in] data_type The data type of the vector
4951 * @param[in] ptr       The base pointer
4952 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
4953 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
4954 * @param[in] cond      Condition to select either @p vec_size (if false) or @p leftover (if true)
4955 * @{
4956 */
4957#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
4958    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
4959/** @} */ // end of group STORE_VECTOR_SELECT
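/** E.g., (sketch)
 *
 *     STORE_VECTOR_SELECT(res, float, dst_addr, 16, 7, leftover_cond)
 *
 * stores the full float16 res0 at dst_addr when leftover_cond is false, and only its
 * lower 7 elements when leftover_cond is true. Passing 0 for both STRIDE_Y and Z makes
 * the single-row block macro degenerate to a plain vector store at dst_addr.
 */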
4960
4961#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
4962#pragma OPENCL EXTENSION cl_khr_fp16 : enable
4963#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
4964
4965#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
4966#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
4967#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
4968
4969#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
4970#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
4971#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
4972
4973#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
4974#pragma OPENCL EXTENSION cl_arm_printf : enable
4975#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
4976
4977#define GPU_ARCH_MIDGARD 0x100
4978#define GPU_ARCH_BIFROST 0x200
4979
4980/** Concatenate two inputs.
4981 *
4982 * @param[in] a The first input to be concatenated
4983 * @param[in] b The second input to be concatenated
4984 *
4985 * @return The concatenated output
4986 */
4987#define CONCAT(a, b) a##b
4988
4989/** Expand the given vector
4990 *
4991 * @param[in] x The vector to be expanded
4992 *
4993 * @return The expanded output
4994 */
4995#define EXPAND(x) x
4996
4997/** Clamp the given value between an upper and lower bound.
4998 *
4999 * @param[in] x       The value to be clamped
5000 * @param[in] min_val The lower bound
5001 * @param[in] max_val The upper bound
5002 *
5003 * @return The clamped value.
5004 */
5005#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
5006
5007/** REVn reverses the given vector whose size is n.
5008 * @name REVn
5009 *
5010 * @param[in] x The vector to be reversed
5011 *
5012 * @return The reversed vector
5013 * @{
5014 */
5015#define REV1(x) ((x))
5016#define REV2(x) ((x).s10)
5017#define REV3(x) ((x).s210)
5018#define REV4(x) ((x).s3210)
5019#define REV8(x) ((x).s76543210)
5020#define REV16(x) ((x).sFEDCBA9876543210)
5021/** @} */ // end of group REVn
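/** E.g., REV4((int4)(1, 2, 3, 4)) swizzles to .s3210 and yields (int4)(4, 3, 2, 1). */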
5022
5023/** Reverse the given vector.
5024 * @name REVERSE
5025 *
5026 * @param[in] x The vector to be reversed
5027 * @param[in] s The size of the vector
5028 *
5029 * @return The reversed vector
5030 * @{
5031 */
5032#define REVERSE_STR(x, s) REV##s((x))
5033#define REVERSE(x, s) REVERSE_STR(x, s)
5034/** @} */ // end of group REVERSE
5035
5036/** Circular-right-shift (rotate-right) the vector of size s by n elements.
5037 * @name ROTs_n
5038 *
5039 * @param[in] x The vector to be shifted
5040 *
5041 * @return The shifted vector
5042 * @{
5043 */
5044#define ROT1_0(x) ((x))
5045
5046#define ROT2_0(x) ((x))
5047#define ROT2_1(x) ((x).s10)
5048
5049#define ROT3_0(x) ((x))
5050#define ROT3_1(x) ((x).s201)
5051#define ROT3_2(x) ((x).s120)
5052
5053#define ROT4_0(x) ((x))
5054#define ROT4_1(x) ((x).s3012)
5055#define ROT4_2(x) ((x).s2301)
5056#define ROT4_3(x) ((x).s1230)
5057
5058#define ROT8_0(x) ((x))
5059#define ROT8_1(x) ((x).s70123456)
5060#define ROT8_2(x) ((x).s67012345)
5061#define ROT8_3(x) ((x).s56701234)
5062#define ROT8_4(x) ((x).s45670123)
5063#define ROT8_5(x) ((x).s34567012)
5064#define ROT8_6(x) ((x).s23456701)
5065#define ROT8_7(x) ((x).s12345670)
5066
5067#define ROT16_0(x) ((x))
5068#define ROT16_1(x) ((x).sF0123456789ABCDE)
5069#define ROT16_2(x) ((x).sEF0123456789ABCD)
5070#define ROT16_3(x) ((x).sDEF0123456789ABC)
5071#define ROT16_4(x) ((x).sCDEF0123456789AB)
5072#define ROT16_5(x) ((x).sBCDEF0123456789A)
5073#define ROT16_6(x) ((x).sABCDEF0123456789)
5074#define ROT16_7(x) ((x).s9ABCDEF012345678)
5075#define ROT16_8(x) ((x).s89ABCDEF01234567)
5076#define ROT16_9(x) ((x).s789ABCDEF0123456)
5077#define ROT16_10(x) ((x).s6789ABCDEF012345)
5078#define ROT16_11(x) ((x).s56789ABCDEF01234)
5079#define ROT16_12(x) ((x).s456789ABCDEF0123)
5080#define ROT16_13(x) ((x).s3456789ABCDEF012)
5081#define ROT16_14(x) ((x).s23456789ABCDEF01)
5082#define ROT16_15(x) ((x).s123456789ABCDEF0)
5083/** @} */ // end of group ROTs_n
5084
5085/** Circular-right-shift (rotate-right) the given vector by the given amount.
5086 * @name ROTATE
5087 *
5088 * @param[in] x The vector to be shifted
5089 * @param[in] s The size of the vector
5090 * @param[in] n The amount to be shifted
5091 *
5092 * @return The shifted vector
5093 * @{
5094 */
5095#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
5096#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
5097/** @} */ // end of group ROTATE
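/** E.g., ROTATE(x, 4, 1) selects ROT4_1, so for x == (int4)(1, 2, 3, 4) the result is
 * the rotate-right-by-one swizzle x.s3012 == (int4)(4, 1, 2, 3).
 */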
5098
5099/** Creates a vector of size n filled with offset values corresponding to the location of each element.
5100 * @name V_OFFSn
5101 *
5102 * @param[in] dt The data type of the output vector
5103 *
5104 * @return The vector filled with offset values
5105 * @{
5106 */
5107#define V_OFFS1(dt) (dt##1)(0)
5108#define V_OFFS2(dt) (dt##2)(0, 1)
5109#define V_OFFS3(dt) (dt##3)(0, 1, 2)
5110#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
5111#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
5112#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
5113/** @} */ // end of group V_OFFSn
5114
5115/** Create a vector filled with offset values corresponding to the location of each element.
5116 * @name VEC_OFFS
5117 *
5118 * @param[in] dt The data type of the output vector
5119 * @param[in] s  The size of the output vector
5120 *
5121 * @return The vector filled with offset values
5122 * @{
5123 */
5124#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
5125#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
5126/** @} */ // end of group VEC_OFFS
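/** E.g., VEC_OFFS(int, 4) expands to (int4)(0, 1, 2, 3), which is handy for computing
 * per-lane indices or boundary masks from a scalar base offset.
 */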
5127
5128#define VLOAD_STR(size) vload##size
5129#define VLOAD(size) VLOAD_STR(size)
5130
5131#define PIXEL_UNIT4 1
5132#define PIXEL_UNIT8 2
5133#define PIXEL_UNIT16 4
5134
5135/** Utility macro to convert a vector size into the corresponding number of pixel units.
5136 *
5137 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
5138 *
5139 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
5140 *
5141 * @return The pixel unit (number of pixels)
5142 * @{
5143 */
5144#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
5145#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
5146/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
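/** E.g., CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(16) evaluates to 4: a 16-element vector
 * covers four texels, since each read of a 2D image below returns 4 components.
 */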
5147
#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

/** Utility macro to read a 2D OpenCL image object.
 *
 * @note Coordinates are not normalized
 *
 * @param[in] data_type Data type
 * @param[in] n0        Number of pixels to read. Only 1, 2 and 4 are supported
 * @param[in] img       OpenCL image object
 * @param[in] x_coord   The x coordinate for the top-left pixel
 * @param[in] y_coord   The y coordinate for the top-left pixel
 *
 * @return Pixels from the 2D OpenCL image object
 * @{
 */
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
/** @} */ // end of group READ_IMAGE2D

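// For illustration, assuming "img" is a read-only image2d_t holding 4-channel float pixels:
//   float16 px = READ_IMAGE2D(float, 4, img, x, y);
// reads 4 consecutive pixels starting at the (non-normalized) coordinate (x, y).
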
#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)

#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double

#define vload1(OFFSET, PTR) *(OFFSET + PTR)
#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA

/** Extended partial vstore that correctly handles scalar values as well.
 * Store the **lower** 0 to (n-1)th elements of the given vector while minimising the number of vstore ops
 * @name VSTORE_PARTIAL
 *
 * @note With this macro, the passed data can be both a vector and a scalar
 * @note @p store_size needs to be <= @p size
 * e.g. 1: Valid
 * VSTORE_PARTIAL(16, 15) ...;
 * e.g. 2: Invalid
 * VSTORE_PARTIAL(4, 7) ...;
 *
 * @param[in] size       The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
 * @{
 */
#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)

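// For illustration: VSTORE_PARTIAL(8, 5)(data, 0, ptr) resolves to vstore_partial_8_5,
// which aliases vstore_partial_5 below, i.e. one vstore4 of data.s0123 followed by a
// scalar store of data.s4; invalid combinations (store_size > size) resolve to NO_STORE.
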
#define NO_STORE(data, offs, ptr) \
    {                             \
    }

// Size == 1 (scalar)
#define vstore_partial_1_0 NO_STORE
#define vstore_partial_1_1 vstore1
#define vstore_partial_1_2 NO_STORE
#define vstore_partial_1_3 NO_STORE
#define vstore_partial_1_4 NO_STORE
#define vstore_partial_1_5 NO_STORE
#define vstore_partial_1_6 NO_STORE
#define vstore_partial_1_7 NO_STORE
#define vstore_partial_1_8 NO_STORE
#define vstore_partial_1_9 NO_STORE
#define vstore_partial_1_10 NO_STORE
#define vstore_partial_1_11 NO_STORE
#define vstore_partial_1_12 NO_STORE
#define vstore_partial_1_13 NO_STORE
#define vstore_partial_1_14 NO_STORE
#define vstore_partial_1_15 NO_STORE
#define vstore_partial_1_16 NO_STORE
// Size == 2
#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE
// Size == 3
#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE
// Size == 4
#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE
// Size == 8
#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE
// Size == 16
#define vstore_partial_16_0 NO_STORE
#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16

/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the number of vstore ops
 * @name vstore_partial_n
 *
 * @note @p DATA needs to be a vector not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * e.g. 1: Valid
 * vstore_partial_15(var:float16, 0, 0xabcd);
 * e.g. 2: Invalid
 * vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset in n
 * @param[in] PTR    The base pointer
 * @{
 */
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

#define vstore_partial_5(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

#define vstore_partial_9(DATA, OFFSET, PTR)        \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

#define vstore_partial_13(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL

// The convert_* built-in functions with the _sat modifier are not supported for floating point
// data types, so we define aliases without _sat to work around this
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_half
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat

#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)

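// For illustration: CONVERT(x, int4) expands to convert_int4((x)),
// CONVERT_SAT(x, uchar16) expands to convert_uchar16_sat((x)), and
// CONVERT_SAT_ROUND(x, int4, rte) expands to convert_int4_sat_rte((x)).
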
#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size

#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)

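// For illustration: SELECT_VEC_DATA_TYPE(float, 4) expands to int4 and
// SELECT_VEC_DATA_TYPE(half, 8) to short8, matching the signed integer result
// type that OpenCL relational operators produce for those vector operands.
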
#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) ((x).s0) + ((x).s1)
#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2)
#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23)
#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567)
#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF)

#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)

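// For illustration: SUM_REDUCE(v, 4) expands to sum_reduce_4(v), which is equivalent to
// ((v).s0) + ((v).s1) + ((v).s2) + ((v).s3), and MAX_REDUCE(v, 4) expands to the
// corresponding tree of max() calls; both are fully unrolled at compile time.
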
#define VECTOR_DECLARATION(name)     \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_offset_first_element_in_bytes

#define IMAGE_DECLARATION(name)      \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR3D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR4D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_stride_w, \
    uint        name##_step_w,   \
    uint        name##_offset_first_element_in_bytes

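// For illustration, a hypothetical kernel declared as
//   __kernel void example(IMAGE_DECLARATION(src)) { ... }
// receives six parameters: src_ptr, src_stride_x, src_step_x, src_stride_y,
// src_step_y and src_offset_first_element_in_bytes.
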
#define CONVERT_TO_VECTOR_STRUCT(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)

#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)

#define CONVERT_TO_IMAGE_STRUCT(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)

#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)

#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)

#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                                                                                       \
    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                           name##_stride_z, name##_step_z)

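// For illustration, the usual pattern inside a kernel (see gemm_mm_interleaved_transposed_f32
// below) is:
//   Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
// which wraps the dst_* kernel parameters and advances dst.ptr to this workitem's element.
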
/** Structure to hold Vector information */
typedef struct Vector
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
} Vector;

/** Structure to hold Image information */
typedef struct Image
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
} Image;

/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
} Tensor3D;

/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
    int             stride_w;                      /**< Stride of the image in W dimension (in bytes) */
} Tensor4D;

/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 *
 * @return A vector object
 */
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
    Vector vector =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
    };
    vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
    return vector;
}

/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
    return img;
}

/** Wrap 3D tensor information into an image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return img;
}

/** Wrap 3D tensor information into a tensor structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return tensor;
}

/** Wrap 3D tensor information into a tensor structure.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    return tensor;
}

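/** Wrap 4D tensor information into a tensor structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] stride_w                      Stride of the tensor in W dimension (in bytes)
 * @param[in] step_w                        stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in] mod_size                      Number of Z planes per W slice; get_global_id(2) is split into z (id2 % mod_size) and w (id2 / mod_size)
 *
 * @return A 4D tensor object
 */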
inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
                                             uint step_w,
                                             uint mod_size)
{
    Tensor4D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z,
        .stride_w                      = stride_w
    };

    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
    return tensor;
}

/** Get the pointer position of a Vector
 *
 * @param[in] vec Pointer to the vector structure
 * @param[in] x   Relative X position
 */
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
    return vec->ptr + x * vec->stride_x;
}

/** Get the pointer position of an Image
 *
 * @param[in] img Pointer to the image structure
 * @param[in] x   Relative X position
 * @param[in] y   Relative Y position
 */
inline __global uchar *offset(const Image *img, int x, int y)
{
    return img->ptr + x * img->stride_x + y * img->stride_y;
}

/** Get the pointer position of a Tensor3D
 *
 * @param[in] tensor Pointer to the tensor structure
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 */
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
}

/** Get the pointer position of a Tensor4D
 *
 * @param[in] tensor Pointer to the tensor structure
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 * @param[in] w      Relative W position
 */
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
}

/** Get the offset for a given linear index of a Tensor3D
 *
 * @param[in] tensor Pointer to the tensor structure
 * @param[in] width  Width of the input tensor
 * @param[in] height Height of the input tensor
 * @param[in] depth  Depth of the input tensor
 * @param[in] index  Linear index
 */
inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
{
    uint num_elements = width * height;

    const uint z = index / num_elements;

    index %= num_elements;

    const uint y = index / width;

    index %= width;

    const uint x = index;

    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
}

#endif // _HELPER_H

/** Macros that help in loop unrolling */
// Repeat macros with 3 parameters, excluding the implicit ID param
#define REPEAT_3_1(P_X, P_A, P_B, P_C) P_X##_DEF(0, P_A, P_B, P_C)
#define REPEAT_3_2(P_X, P_A, P_B, P_C) \
    P_X##_DEF(1, P_A, P_B, P_C);       \
    REPEAT_3_1(P_X, P_A, P_B, P_C)
#define REPEAT_3_3(P_X, P_A, P_B, P_C) \
    P_X##_DEF(2, P_A, P_B, P_C);       \
    REPEAT_3_2(P_X, P_A, P_B, P_C)
#define REPEAT_3_4(P_X, P_A, P_B, P_C) \
    P_X##_DEF(3, P_A, P_B, P_C);       \
    REPEAT_3_3(P_X, P_A, P_B, P_C)
#define REPEAT_3_5(P_X, P_A, P_B, P_C) \
    P_X##_DEF(4, P_A, P_B, P_C);       \
    REPEAT_3_4(P_X, P_A, P_B, P_C)
#define REPEAT_3_6(P_X, P_A, P_B, P_C) \
    P_X##_DEF(5, P_A, P_B, P_C);       \
    REPEAT_3_5(P_X, P_A, P_B, P_C)
#define REPEAT_3_7(P_X, P_A, P_B, P_C) \
    P_X##_DEF(6, P_A, P_B, P_C);       \
    REPEAT_3_6(P_X, P_A, P_B, P_C)
#define REPEAT_3_8(P_X, P_A, P_B, P_C) \
    P_X##_DEF(7, P_A, P_B, P_C);       \
    REPEAT_3_7(P_X, P_A, P_B, P_C)
#define REPEAT_3_9(P_X, P_A, P_B, P_C) \
    P_X##_DEF(8, P_A, P_B, P_C);       \
    REPEAT_3_8(P_X, P_A, P_B, P_C)
#define REPEAT_3_10(P_X, P_A, P_B, P_C) \
    P_X##_DEF(9, P_A, P_B, P_C);        \
    REPEAT_3_9(P_X, P_A, P_B, P_C)
#define REPEAT_3_11(P_X, P_A, P_B, P_C) \
    P_X##_DEF(A, P_A, P_B, P_C);        \
    REPEAT_3_10(P_X, P_A, P_B, P_C)
#define REPEAT_3_12(P_X, P_A, P_B, P_C) \
    P_X##_DEF(B, P_A, P_B, P_C);        \
    REPEAT_3_11(P_X, P_A, P_B, P_C)
#define REPEAT_3_13(P_X, P_A, P_B, P_C) \
    P_X##_DEF(C, P_A, P_B, P_C);        \
    REPEAT_3_12(P_X, P_A, P_B, P_C)
#define REPEAT_3_14(P_X, P_A, P_B, P_C) \
    P_X##_DEF(D, P_A, P_B, P_C);        \
    REPEAT_3_13(P_X, P_A, P_B, P_C)
#define REPEAT_3_15(P_X, P_A, P_B, P_C) \
    P_X##_DEF(E, P_A, P_B, P_C);        \
    REPEAT_3_14(P_X, P_A, P_B, P_C)
#define REPEAT_3_16(P_X, P_A, P_B, P_C) \
    P_X##_DEF(F, P_A, P_B, P_C);        \
    REPEAT_3_15(P_X, P_A, P_B, P_C)

#define REPEAT_DEF_3_N(P_NUM, P_OP, P_A, P_B, P_C) REPEAT_3_##P_NUM(P_OP, P_A, P_B, P_C) // One level of indirection to ensure the order of expansion does not affect preprocessing of P_NUM
#define REPEAT_3_N(P_NUM, P_OP, P_A, P_B, P_C) REPEAT_DEF_3_N(P_NUM, P_OP, P_A, P_B, P_C)

// Repeat macros with 4 parameters, excluding the implicit ID param
#define REPEAT_4_1(P_X, P_A, P_B, P_C, P_D) P_X##_DEF(0, P_A, P_B, P_C, P_D)
#define REPEAT_4_2(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(1, P_A, P_B, P_C, P_D);       \
    REPEAT_4_1(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_3(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(2, P_A, P_B, P_C, P_D);       \
    REPEAT_4_2(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_4(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(3, P_A, P_B, P_C, P_D);       \
    REPEAT_4_3(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_5(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(4, P_A, P_B, P_C, P_D);       \
    REPEAT_4_4(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_6(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(5, P_A, P_B, P_C, P_D);       \
    REPEAT_4_5(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_7(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(6, P_A, P_B, P_C, P_D);       \
    REPEAT_4_6(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_8(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(7, P_A, P_B, P_C, P_D);       \
    REPEAT_4_7(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_9(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(8, P_A, P_B, P_C, P_D);       \
    REPEAT_4_8(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_10(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(9, P_A, P_B, P_C, P_D);        \
    REPEAT_4_9(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_11(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(A, P_A, P_B, P_C, P_D);        \
    REPEAT_4_10(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_12(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(B, P_A, P_B, P_C, P_D);        \
    REPEAT_4_11(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_13(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(C, P_A, P_B, P_C, P_D);        \
    REPEAT_4_12(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_14(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(D, P_A, P_B, P_C, P_D);        \
    REPEAT_4_13(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_15(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(E, P_A, P_B, P_C, P_D);        \
    REPEAT_4_14(P_X, P_A, P_B, P_C, P_D)
#define REPEAT_4_16(P_X, P_A, P_B, P_C, P_D) \
    P_X##_DEF(F, P_A, P_B, P_C, P_D);        \
    REPEAT_4_15(P_X, P_A, P_B, P_C, P_D)

#define REPEAT_DEF_4_N(P_NUM, P_OP, P_A, P_B, P_C, P_D) REPEAT_4_##P_NUM(P_OP, P_A, P_B, P_C, P_D) // One level of indirection to ensure the order of expansion does not affect preprocessing of P_NUM
#define REPEAT_4_N(P_NUM, P_OP, P_A, P_B, P_C, P_D) REPEAT_DEF_4_N(P_NUM, P_OP, P_A, P_B, P_C, P_D)

// Macro for initializing N variables. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define VAR_INIT_TO_CONST_DEF(ID, TYPE, VAR, VAL) TYPE VAR##ID = VAL
#define REPEAT_VAR_INIT_TO_CONST(N, TYPE, VAR, VAL) REPEAT_3_N(N, VAR_INIT_TO_CONST, TYPE, VAR, VAL)

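// For illustration: REPEAT_VAR_INIT_TO_CONST(3, float4, c, 0) unrolls (in descending ID order) to
//   float4 c2 = 0; float4 c1 = 0; float4 c0 = 0;
// which is how accumulators such as zero0..zero3 further down are declared.
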
// Macro for initializing N variables by converting the data type. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define VAR_INIT_CONVERT_DEF(ID, TYPE_OUT, VAR_IN, VAR_OUT) TYPE_OUT VAR_OUT##ID = CONVERT(VAR_IN##ID, TYPE_OUT)
#define REPEAT_VAR_INIT_CONVERT(N, TYPE_OUT, VAR_IN, VAR_OUT) REPEAT_3_N(N, VAR_INIT_CONVERT, TYPE_OUT, VAR_IN, VAR_OUT)

// Macro for initializing N variables by converting the data type with saturation. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define VAR_INIT_CONVERT_SAT_DEF(ID, TYPE_OUT, VAR_IN, VAR_OUT) TYPE_OUT VAR_OUT##ID = CONVERT_SAT(VAR_IN##ID, TYPE_OUT)
#define REPEAT_VAR_INIT_CONVERT_SAT(N, TYPE_OUT, VAR_IN, VAR_OUT) REPEAT_3_N(N, VAR_INIT_CONVERT_SAT, TYPE_OUT, VAR_IN, VAR_OUT)

// Macro for adding a constant to N variables. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define ADD_CONST_TO_VAR_DEF(ID, TYPE, VAR, VAL) VAR##ID += (TYPE)VAL
#define REPEAT_ADD_CONST_TO_VAR(N, TYPE, VAR, VAL) REPEAT_3_N(N, ADD_CONST_TO_VAR, TYPE, VAR, VAL)

// Macro for multiplying N variables (VAR_B) by a constant (VAL) and adding to other N variables (VAR_A). Generates N statements that define VAR_A##N = RHS_ACCESSOR_DEF(...)
#define MLA_VAR_WITH_CONST_VEC_DEF(ID, VAR_A, VAR_B, VAL) VAR_A##ID += VAR_B##ID * VAL
#define REPEAT_MLA_VAR_WITH_CONST_VEC(N, VAR_A, VAR_B, VAL) REPEAT_3_N(N, MLA_VAR_WITH_CONST_VEC, VAR_A, VAR_B, VAL)

// Macro for adding a vector to N variables. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define ADD_VECTOR_TO_VAR_DEF(ID, TYPE, VAR, VEC) VAR##ID += VEC
#define REPEAT_ADD_VECTOR_TO_VAR(N, VAR, VEC) REPEAT_3_N(N, ADD_VECTOR_TO_VAR, "", VAR, VEC)

// Macro for adding two N-variables. Generates N statements that define VAR_A##N = RHS_ACCESSOR_DEF(...)
#define ADD_TWO_VARS_DEF(ID, TYPE, VAR_A, VAR_B) VAR_A##ID += VAR_B##ID
#define REPEAT_ADD_TWO_VARS(N, VAR_A, VAR_B) REPEAT_3_N(N, ADD_TWO_VARS, "", VAR_A, VAR_B)

// Macro for performing Max between a constant and N variables. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define MAX_CONST_VAR_DEF(ID, TYPE, VAR, VAL) VAR##ID = max(VAR##ID, (TYPE)VAL)
#define REPEAT_MAX_CONST_VAR(N, TYPE, VAR, VAL) REPEAT_3_N(N, MAX_CONST_VAR, TYPE, VAR, VAL)

// Macro for performing Min between a constant and N variables. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define MIN_CONST_VAR_DEF(ID, TYPE, VAR, VAL) VAR##ID = min(VAR##ID, (TYPE)VAL)
#define REPEAT_MIN_CONST_VAR(N, TYPE, VAR, VAL) REPEAT_3_N(N, MIN_CONST_VAR, TYPE, VAR, VAL)

// Macro for applying ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE to N variables. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE_DEF(ID, SIZE, VAR, RES_MUL, RES_SHIFT) VAR##ID = ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, SIZE)
#define REPEAT_ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE(N, SIZE, VAR, RES_MUL, RES_SHIFT) REPEAT_4_N(N, ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE, SIZE, VAR, RES_MUL, RES_SHIFT)

// Macro for applying ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE to N variables. Generates N statements that define VAR##N = RHS_ACCESSOR_DEF(...)
#define ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE_DEF(ID, SIZE, VAR, RES_MUL, RES_SHIFT) VAR##ID = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, SIZE)
#define REPEAT_ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(N, SIZE, VAR, RES_MUL, RES_SHIFT) REPEAT_4_N(N, ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE, SIZE, VAR, RES_MUL, RES_SHIFT)

// Macro for applying per-channel ASYMM_MULT_BY_QUANT_MULTIPLIER to N variables.
#define ASYMM_MULT_BY_QUANT_MULTIPLIER_PER_CHANNEL_DEF(ID, SIZE, VAR, RES_MUL, RES_SHIFT)                     \
    ({                                                                                                        \
        VEC_DATA_TYPE(int, N0)                                                                                \
        VAR##ID_shift_lt0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, N0); \
        VEC_DATA_TYPE(int, N0)                                                                                \
        VAR##ID_shift_gt0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, N0);    \
        VAR##ID           = select(VAR##ID_shift_lt0, VAR##ID_shift_gt0, RES_SHIFT >= 0);                     \
    })
#define REPEAT_ASYMM_MULT_BY_QUANT_MULTIPLIER_PER_CHANNEL(N, SIZE, VAR, RES_MUL, RES_SHIFT) REPEAT_4_N(N, ASYMM_MULT_BY_QUANT_MULTIPLIER_PER_CHANNEL, SIZE, VAR, RES_MUL, RES_SHIFT)

#endif // ARM_COMPUTE_REPEAT_H

#if defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * @note The number of rows of the destination matrix must be passed at compile time using -DM
 * @note The number of columns of the destination matrix must be passed at compile time using -DN
 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
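// For illustration only, a hypothetical set of build options for this kernel
// (the numeric values below are made-up examples, not recommendations):
//   -DM=64 -DN=64 -DK=32 -DH0=2 -DV0=2 -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1 -DALPHA=1.0f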
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif // defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B
    __global float *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(float));

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    float4 c0 = 0.0f;
    float4 c1 = 0.0f;
    float4 c2 = 0.0f;
    float4 c3 = 0.0f;

    for(; src_addr_b <= (src_end_addr_b - (int)(8 * H0)); src_addr_a += 8 * V0, src_addr_b += 8 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c0 += (float4)a0.s0 * b0;
        c1 += (float4)a0.s1 * b0;
        c2 += (float4)a0.s2 * b0;
        c3 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * V0);
        b0 = vload4(0, src_addr_b + 4 * H0);

        c0 += (float4)a0.s0 * b0;
        c1 += (float4)a0.s1 * b0;
        c2 += (float4)a0.s2 * b0;
        c3 += (float4)a0.s3 * b0;
    }

    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 4 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c0 += (float4)a0.s0 * b0;
        c1 += (float4)a0.s1 * b0;
        c2 += (float4)a0.s2 * b0;
        c3 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated by dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}

6198/** This OpenCL kernel is optimized for Bifrost and tt computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
6199 *
6200 * @note The number of rows of destination matrix must be passed at compile time using -DM
6201 * @note The number of columns of the destination matrix must be passed at compile time using -DN
6202 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
6203 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
6204 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is applied after the bias addition
6213 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
6214 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
6215 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
6216 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
6217 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
6218 *
6219 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
6220 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
6221 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
6222 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
6223 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
6224 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
6225 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
6226 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
6227 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
6228 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
6229 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
6230 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
6231 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
6232 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
6233 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
6234 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
6235 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
6236 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
6238 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
6239 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
6240 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
6241 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
6242 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
6243 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
6244 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
6245 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
6246 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
6247 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
6248 */
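/* Illustrative build options (hypothetical values; in practice the host side assembles
 * these from the tensor shapes): an F32 GEMM with M=63, N=62, K=64, H0=4, V0=4 and
 * alpha=1.0f would be compiled with something along the lines of:
 *   "-DM=63 -DN=62 -DK=64 -DH0=4 -DV0=4 -DPARTIAL_STORE_M0=3 -DPARTIAL_STORE_N0=2 -DALPHA=1.0f"
 * where the partial store sizes are the leftovers of M and N modulo the 4x4 block size.
 */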
6249__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
6250                                                         IMAGE_DECLARATION(src1),
6251#if defined(BETA)
6252                                                         IMAGE_DECLARATION(src2),
6253#endif // defined(BETA)
6254                                                         IMAGE_DECLARATION(dst),
6255                                                         uint src0_stride_z,
6256                                                         uint src1_stride_z,
6257#if defined(BETA)
6258                                                         uint src2_stride_z,
#endif // defined(BETA)
6260                                                         uint dst_stride_z
6261#if defined(REINTERPRET_OUTPUT_AS_3D)
6262                                                         ,
6263                                                         uint cross_plane_pad
6264#endif // REINTERPRET_OUTPUT_AS_3D
6265                                                        )
6266{
6267    int x = get_global_id(0) / H0;
6268    int y = get_global_id(1) / V0;
6269    int z = get_global_id(2);
6270
6271    // Offset
6272    const int offset_row_a = (get_global_id(1) % V0) * 4;
6273    const int offset_row_b = (get_global_id(0) % H0) * 4;
6274
6275    // src_addr_a = address of matrix A
6276    // src_addr_b = address of matrix B
6277    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
6278    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
6279
6280#if defined(MATRIX_B_DEPTH)
6281    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
6282    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
6283#else  // defined(MATRIX_B_DEPTH)
6284    src1_addr_in_bytes += z * src1_stride_z;
6285#endif // defined(MATRIX_B_DEPTH)
6286
6287    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
6288    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
6289
6290    src_addr_a += offset_row_a;
6291    src_addr_b += offset_row_b;
6292
6293    // Reset accumulators
6294    float4 c0 = 0.0f;
6295    float4 c1 = 0.0f;
6296    float4 c2 = 0.0f;
6297    float4 c3 = 0.0f;
6298
6299    int i = 0;
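    // The loop below is unrolled 4x over K: each step loads one column of the 4x4 A tile
    // and one row of the 4x4 B tile and accumulates the rank-1 update ci += a0.s<i> * b0
    // (a 4x4 outer product) with fused multiply-adds.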
6300    for(; i <= (int)(K - 4); i += 4)
6301    {
6302        // Load values from matrix A (interleaved) and matrix B (transposed)
6303        float4 a0 = vload4(0, src_addr_a);
6304        float4 b0 = vload4(0, src_addr_b);
6305
6306        src_addr_a += 4 * V0;
6307        src_addr_b += 4 * H0;
6308
6309        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
6310        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
6311        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
6312        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
6313
6314        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
6315        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
6316        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
6317        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
6318
6319        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
6320        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
6321        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
6322        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
6323
6324        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
6325        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
6326        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
6327        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
6328
6329        // Load values from matrix A (interleaved) and matrix B (transposed)
6330        a0 = vload4(0, src_addr_a);
6331        b0 = vload4(0, src_addr_b);
6332
6333        src_addr_a += 4 * V0;
6334        src_addr_b += 4 * H0;
6335
6336        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
6337        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
6338        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
6339        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
6340
6341        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
6342        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
6343        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
6344        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
6345
6346        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
6347        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
6348        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
6349        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
6350
6351        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
6352        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
6353        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
6354        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
6355
6356        // Load values from matrix A (interleaved) and matrix B (transposed)
6357        a0 = vload4(0, src_addr_a);
6358        b0 = vload4(0, src_addr_b);
6359
6360        src_addr_a += 4 * V0;
6361        src_addr_b += 4 * H0;
6362
6363        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
6364        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
6365        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
6366        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
6367
6368        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
6369        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
6370        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
6371        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
6372
6373        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
6374        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
6375        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
6376        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
6377
6378        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
6379        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
6380        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
6381        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
6382
6383        // Load values from matrix A (interleaved) and matrix B (transposed)
6384        a0 = vload4(0, src_addr_a);
6385        b0 = vload4(0, src_addr_b);
6386
6387        src_addr_a += 4 * V0;
6388        src_addr_b += 4 * H0;
6389
6390        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
6391        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
6392        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
6393        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
6394
6395        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
6396        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
6397        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
6398        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
6399
6400        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
6401        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
6402        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
6403        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
6404
6405        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
6406        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
6407        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
6408        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
6409    }
6410
6411    for(; i < (int)K; ++i)
6412    {
6413        // Load values from matrix A (interleaved) and matrix B (transposed)
6414        float4 a0 = vload4(0, src_addr_a);
6415        float4 b0 = vload4(0, src_addr_b);
6416
6417        src_addr_a += 4 * V0;
6418        src_addr_b += 4 * H0;
6419
6420        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
6421        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
6422        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
6423        c0.s3 = fma(a0.s0, b0.s3, c0.s3);
6424
6425        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
6426        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
6427        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
6428        c1.s3 = fma(a0.s1, b0.s3, c1.s3);
6429
6430        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
6431        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
6432        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
6433        c2.s3 = fma(a0.s2, b0.s3, c2.s3);
6434
6435        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
6436        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
6437        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
6438        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
6439    }
6440
6441    // Compute destination address
6442    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6443
6444    // Compute dst address
6445    __global uchar *dst_addr = offset(&dst, 0, 0);
6446
6447    uint4 zout = 0;
6448
6449#if defined(REINTERPRET_OUTPUT_AS_3D)
6450    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
6451    // in order to take into account the presence of possible cross plane paddings
6452    //
6453    //  |                  |
6454    //  |      plane0      |
6455    //  |                  |
6456    //  |__________________|
6457    //  |******************|
6458    //  |  cross_plane_pad |
6459    //  |******************|
6460    //  |                  |
6461    //  |      plane1      |
6462    //  |                  |
6463    //  |__________________|
6464
    // The plane index (zout) is obtained by dividing the output row index (get_global_id(1) * 4) by HEIGHT_GEMM3D
6466    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
6467    zout = min(DEPTH_GEMM3D - 1, zout);
6468
6469    // Add offset due to the cross plane paddings
6470    zout *= (cross_plane_pad * dst_stride_y);
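    // Worked example (illustrative values): with HEIGHT_GEMM3D=6, DEPTH_GEMM3D=2 and
    // get_global_id(1) == 1, this tile writes rows 4..7, so zout = (4,5,6,7) / 6 = (0,0,1,1):
    // the last two rows fall on plane1 and are pushed down by cross_plane_pad * dst_stride_y bytes.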
6471
6472    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6473    // multiply dst_stride_z by DEPTH_GEMM3D
6474    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
6475#else  // defined(REINTERPRET_OUTPUT_AS_3D)
6476    // Add offset for batched GEMM
6477    dst_addr += z * dst_stride_z;
6478#endif // defined(REINTERPRET_OUTPUT_AS_3D)
6479
    // Scale the matrix-matrix product by alpha
6481#if defined(ALPHA)
6482    SCALE_BLOCK(4, float, c, ALPHA);
6483#endif // defined(ALPHA)
6484
6485    // Add beta*bias
6486#if defined(BETA)
6487    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
6488
6489#if defined(BROADCAST_BIAS)
6490    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
6491
6492    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
6493
6494#ifndef UNIT_BETA
6495    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA
6497
6498    // c = c + bias[broadcasted]
6499    ADD_BLOCK_BROADCAST(4, c, bias0);
6500
6501#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;
6504
6505    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
6506
6507#ifndef UNIT_BETA
6508    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA
6510
6511    // c = c + bias
6512    ADD_BLOCK(4, c, bias);
6513
6514#endif // defined(BROADCAST_BIAS)
6515#endif // defined(BETA)
6516
6517#if defined(ACTIVATION_TYPE)
6518    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
6519#endif // defined(ACTIVATION_TYPE)
6520
6521    // Store 4x4 block
6522    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
6523    const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
6524    STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
6525}
6526
6527#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
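// Guard for the half-precision kernels below; the host is expected to define this only
// when the target device supports FP16 (typically via the cl_khr_fp16 extension).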
6528/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
6529 *
6530 * @note The number of rows of destination matrix must be passed at compile time using -DM
6531 * @note The number of columns of the destination matrix must be passed at compile time using -DN
6532 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
6533 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
6534 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is applied after the bias addition
6543 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
6544 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
6545 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
6546 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
6547 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
6548 *
6549 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
6550 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
6551 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
6552 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
6553 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
6554 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
6555 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
6556 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
6557 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
6558 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
6559 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
6560 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
6561 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
6562 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
6563 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
6564 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
6565 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
6566 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
6568 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
6569 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
6570 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
6571 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
6572 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
6573 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
6574 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
6575 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
6576 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
6577 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
6578 */
6579__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
6580                                                 IMAGE_DECLARATION(src1),
6581#if defined(BETA)
6582                                                 IMAGE_DECLARATION(src2),
6583#endif // defined(BETA)
6584                                                 IMAGE_DECLARATION(dst),
6585                                                 uint src0_stride_z,
6586                                                 uint src1_stride_z,
6587#if defined(BETA)
6588                                                 uint src2_stride_z,
#endif // defined(BETA)
6590                                                 uint dst_stride_z
6591#if defined(REINTERPRET_OUTPUT_AS_3D)
6592                                                 ,
6593                                                 uint cross_plane_pad
6594#endif // REINTERPRET_OUTPUT_AS_3D
6595                                                )
6596{
6597    int x = get_global_id(0) / H0;
6598    int y = get_global_id(1) / V0;
6599    int z = get_global_id(2);
6600
6601    // Offset
6602    const int offset_row_a = (get_global_id(1) % V0) * 4;
6603    const int offset_row_b = (get_global_id(0) % H0) * 8;
6604
6605    // src_addr_a = address of matrix A
6606    // src_addr_b = address of matrix B
6607    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
6608    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
6609
6610#if defined(MATRIX_B_DEPTH)
6611    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
6612    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
6613#else  // defined(MATRIX_B_DEPTH)
6614    src1_addr_in_bytes += z * src1_stride_z;
6615#endif // defined(MATRIX_B_DEPTH)
6616
6617    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
6618    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
6619
6620    // Compute end row address for matrix B
6621    __global half *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(half));
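    // One row of the reshaped B tensor spans all K steps for this column block, so its end
    // address doubles as the loop bound: the main loop below consumes two K iterations per
    // pass (16 * H0 halfs of B), the tail loop one (8 * H0).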
6622
6623    src_addr_a += offset_row_a;
6624    src_addr_b += offset_row_b;
6625
6626    // Reset accumulators
6627    half8 c0 = 0.0f;
6628    half8 c1 = 0.0f;
6629    half8 c2 = 0.0f;
6630    half8 c3 = 0.0f;
6631
6632    for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
6633    {
6634        // Load values from matrix A (interleaved) and matrix B (transposed)
6635        half4 a0 = vload4(0, src_addr_a);
6636        half8 b0 = vload8(0, src_addr_b);
6637
6638        c0 += (half8)a0.s0 * b0;
6639        c1 += (half8)a0.s1 * b0;
6640        c2 += (half8)a0.s2 * b0;
6641        c3 += (half8)a0.s3 * b0;
6642
6643        // Load values from matrix A (interleaved) and matrix B (transposed)
6644        a0 = vload4(0, src_addr_a + 4 * V0);
6645        b0 = vload8(0, src_addr_b + 8 * H0);
6646
6647        c0 += (half8)a0.s0 * b0;
6648        c1 += (half8)a0.s1 * b0;
6649        c2 += (half8)a0.s2 * b0;
6650        c3 += (half8)a0.s3 * b0;
6651    }
6652
6653    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
6654    {
6655        // Load values from matrix A (interleaved) and matrix B (transposed)
6656        half4 a0 = vload4(0, src_addr_a);
6657        half8 b0 = vload8(0, src_addr_b);
6658
6659        c0 += (half8)a0.s0 * b0;
6660        c1 += (half8)a0.s1 * b0;
6661        c2 += (half8)a0.s2 * b0;
6662        c3 += (half8)a0.s3 * b0;
6663    }
6664
6665    // Compute destination address
6666    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6667
6668    // Compute dst address
6669    __global uchar *dst_addr = offset(&dst, 0, 0);
6670
6671    uint4 zout = 0;
6672
6673#if defined(REINTERPRET_OUTPUT_AS_3D)
6674    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
6675    // in order to take into account the presence of possible cross plane paddings
6676    //
6677    //  |                  |
6678    //  |      plane0      |
6679    //  |                  |
6680    //  |__________________|
6681    //  |******************|
6682    //  |  cross_plane_pad |
6683    //  |******************|
6684    //  |                  |
6685    //  |      plane1      |
6686    //  |                  |
6687    //  |__________________|
6688
    // The plane index (zout) is obtained by dividing the output row index (get_global_id(1) * 4) by HEIGHT_GEMM3D
6690    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
6691    zout = min(DEPTH_GEMM3D - 1, zout);
6692
6693    // Add offset due to the cross plane paddings
6694    zout *= (cross_plane_pad * dst_stride_y);
6695
6696    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6697    // multiply dst_stride_z by DEPTH_GEMM3D
6698    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
6699#else  // defined(REINTERPRET_OUTPUT_AS_3D)
6700    // Add offset for batched GEMM
6701    dst_addr += z * dst_stride_z;
6702#endif // defined(REINTERPRET_OUTPUT_AS_3D)
6703
    // Scale the matrix-matrix product by alpha
6705#if defined(ALPHA)
6706    SCALE_BLOCK(4, half, c, ALPHA);
6707#endif // defined(ALPHA)
6708
6709    // Add beta*bias
6710#if defined(BETA)
6711    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
6712
6713#if defined(BROADCAST_BIAS)
6714    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
6715
6716    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6717
6718#ifndef UNIT_BETA
6719    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA
6721
6722    // c = c + bias[broadcasted]
6723    ADD_BLOCK_BROADCAST(4, c, bias0);
6724
6725#else // defined(BROADCAST_BIAS)
6726
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;
6729
6730    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6731
6732#ifndef UNIT_BETA
6733    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA
6735
6736    // c = c + bias
6737    ADD_BLOCK(4, c, bias);
6738
6739#endif // defined(BROADCAST_BIAS)
6740#endif // defined(BETA)
6741
6742#if defined(ACTIVATION_TYPE)
6743    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
6744#endif // defined(ACTIVATION_TYPE)
6745
6746    // Store 4x8 block
6747    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
6748    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
6749    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
6750}
6751
/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32-bit floating point variable.
6753 *
6754 * @note The number of rows of destination matrix must be passed at compile time using -DM
6755 * @note The number of columns of the destination matrix must be passed at compile time using -DN
6756 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
6757 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
6758 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is applied after the bias addition
6767 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
6768 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
6769 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
6770 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
6771 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
6772 *
6773 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
6774 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
6775 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
6776 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
6777 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
6778 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
6779 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
6780 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
6781 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
6782 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
6783 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
6784 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
6785 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
6786 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
6787 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
6788 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
6789 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
6790 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
6792 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
6793 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
6794 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
6795 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
6796 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
6797 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
6798 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
6799 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
6800 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
6801 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
6802 */
6803__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
6804                                                       IMAGE_DECLARATION(src1),
6805#if defined(BETA)
6806                                                       IMAGE_DECLARATION(src2),
6807#endif // defined(BETA)
6808                                                       IMAGE_DECLARATION(dst),
6809                                                       uint src0_stride_z,
6810                                                       uint src1_stride_z,
6811#if defined(BETA)
6812                                                       uint src2_stride_z,
#endif // defined(BETA)
6814                                                       uint dst_stride_z
6815#if defined(REINTERPRET_OUTPUT_AS_3D)
6816                                                       ,
6817                                                       uint cross_plane_pad
6818#endif // REINTERPRET_OUTPUT_AS_3D
6819                                                      )
6820{
6821    int x = get_global_id(0) / H0;
6822    int y = get_global_id(1) / V0;
6823    int z = get_global_id(2);
6824
6825    // Offset
6826    const int offset_row_a = (get_global_id(1) % V0) * 4;
6827    const int offset_row_b = (get_global_id(0) % H0) * 8;
6828
6829    // src_addr_a = address of matrix A
6830    // src_addr_b = address of matrix B
6831    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
6832    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
6833
6834#if defined(MATRIX_B_DEPTH)
6835    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
6836    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
6837#else  // defined(MATRIX_B_DEPTH)
6838    src1_addr_in_bytes += z * src1_stride_z;
6839#endif // defined(MATRIX_B_DEPTH)
6840
6841    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
6842    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
6843
6844    // Compute end row address for matrix B
6845    __global half *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(half));
6846
6847    src_addr_a += offset_row_a;
6848    src_addr_b += offset_row_b;
6849
6850    // Reset accumulators
6851    float8 c0 = 0.0f;
6852    float8 c1 = 0.0f;
6853    float8 c2 = 0.0f;
6854    float8 c3 = 0.0f;
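    // Accumulators are kept in F32 (the F16 inputs are widened on load with convert_float4/
    // convert_float8) so rounding error does not build up in half precision across K.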
6855
6856    for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
6857    {
6858        // Load values from matrix A (interleaved) and matrix B (transposed)
6859        float4 a0 = convert_float4(vload4(0, src_addr_a));
6860        float8 b0 = convert_float8(vload8(0, src_addr_b));
6861
6862        c0 += (float8)a0.s0 * b0;
6863        c1 += (float8)a0.s1 * b0;
6864        c2 += (float8)a0.s2 * b0;
6865        c3 += (float8)a0.s3 * b0;
6866
6867        // Load values from matrix A (interleaved) and matrix B (transposed)
6868        a0 = convert_float4(vload4(0, src_addr_a + 4 * V0));
6869        b0 = convert_float8(vload8(0, src_addr_b + 8 * H0));
6870
6871        c0 += (float8)a0.s0 * b0;
6872        c1 += (float8)a0.s1 * b0;
6873        c2 += (float8)a0.s2 * b0;
6874        c3 += (float8)a0.s3 * b0;
6875    }
6876
6877    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
6878    {
6879        // Load values from matrix A (interleaved) and matrix B (transposed)
6880        float4 a0 = convert_float4(vload4(0, src_addr_a));
6881        float8 b0 = convert_float8(vload8(0, src_addr_b));
6882
6883        c0 += (float8)a0.s0 * b0;
6884        c1 += (float8)a0.s1 * b0;
6885        c2 += (float8)a0.s2 * b0;
6886        c3 += (float8)a0.s3 * b0;
6887    }
6888
6889    // Compute destination address
6890    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6891
6892    // Compute dst address
6893    __global uchar *dst_addr = offset(&dst, 0, 0);
6894
6895    uint4 zout = 0;
6896
6897#if defined(REINTERPRET_OUTPUT_AS_3D)
6898    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
6899    // in order to take into account the presence of possible cross plane paddings
6900    //
6901    //  |                  |
6902    //  |      plane0      |
6903    //  |                  |
6904    //  |__________________|
6905    //  |******************|
6906    //  |  cross_plane_pad |
6907    //  |******************|
6908    //  |                  |
6909    //  |      plane1      |
6910    //  |                  |
6911    //  |__________________|
6912
    // The plane index (zout) is obtained by dividing the output row index (get_global_id(1) * 4) by HEIGHT_GEMM3D
6914    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
6915    zout = min(DEPTH_GEMM3D - 1, zout);
6916
6917    // Add offset due to the cross plane paddings
6918    zout *= (cross_plane_pad * dst_stride_y);
6919
6920    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6921    // multiply dst_stride_z by DEPTH_GEMM3D
6922    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
6923#else  // defined(REINTERPRET_OUTPUT_AS_3D)
6924    // Add offset for batched GEMM
6925    dst_addr += z * dst_stride_z;
6926#endif // defined(REINTERPRET_OUTPUT_AS_3D)
6927
    // Scale the matrix-matrix product by alpha
6929#if defined(ALPHA)
6930    SCALE_BLOCK(4, float, c, ALPHA);
6931#endif // defined(ALPHA)
6932
6933#if defined(BETA)
6934    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
6935
6936#if defined(BROADCAST_BIAS)
6937    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
6938
6939    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6940
6941    float8 bias_f0 = convert_float8(bias0);
6942
6943#ifndef UNIT_BETA
6944    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA
6946
6947    // c = c + bias[broadcasted]
6948    ADD_BLOCK_BROADCAST(4, c, bias_f0);
6949
6950#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;
6953
6954    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6955
6956    float8 bias_f0 = convert_float8(bias0);
6957    float8 bias_f1 = convert_float8(bias1);
6958    float8 bias_f2 = convert_float8(bias2);
6959    float8 bias_f3 = convert_float8(bias3);
6960
6961#ifndef UNIT_BETA
6962    SCALE_BLOCK(4, float, bias_f, BETA);
#endif // UNIT_BETA
6964
6965    // c = c + bias
6966    ADD_BLOCK(4, c, bias_f);
6967
6968#endif // defined(BROADCAST_BIAS)
6969#endif // defined(BETA)
6970
6971    half8 c_h0 = convert_half8(c0);
6972    half8 c_h1 = convert_half8(c1);
6973    half8 c_h2 = convert_half8(c2);
6974    half8 c_h3 = convert_half8(c3);
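    // The accumulators are narrowed back to F16 only now, after the alpha/beta processing
    // above, so the activation and the store below operate on half data as the output expects.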
6975
6976#if defined(ACTIVATION_TYPE)
6977    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c_h, A_VAL, B_VAL);
6978#endif // defined(ACTIVATION_TYPE)
6979
6980    // Store 4x8 block
6981    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
6982    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
6983    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
6984}
6985
/** This OpenCL kernel, optimized for Bifrost architectures, computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
6987 *
6988 * @note The number of rows of destination matrix must be passed at compile time using -DM
6989 * @note The number of columns of the destination matrix must be passed at compile time using -DN
6990 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
6991 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
6992 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is applied after the bias addition
7001 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
7002 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
7003 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
7004 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
7005 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
7006 *
7007 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
7008 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
7009 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
7010 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
7011 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
7012 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
7013 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
7014 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
7015 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
7016 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
7017 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
7018 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
7019 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
7020 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
7021 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
7022 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
7023 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
7024 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
7026 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
7027 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
7028 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
7029 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
7030 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
7031 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
7032 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
7034 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
7035 */
7036__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
7037                                                         IMAGE_DECLARATION(src1),
7038#if defined(BETA)
7039                                                         IMAGE_DECLARATION(src2),
7040#endif // defined(BETA)
7041                                                         IMAGE_DECLARATION(dst),
7042                                                         uint src0_stride_z,
7043                                                         uint src1_stride_z,
7044#if defined(BETA)
7045                                                         uint src2_stride_z,
#endif // defined(BETA)
7047                                                         uint dst_stride_z
7048#if defined(REINTERPRET_OUTPUT_AS_3D)
7049                                                         ,
7050                                                         uint cross_plane_pad
7051#endif // REINTERPRET_OUTPUT_AS_3D
7052                                                        )
7053{
7054    int x = get_global_id(0) / H0;
7055    int y = get_global_id(1) / V0;
7056    int z = get_global_id(2);
7057
7058    // Offset
7059    const int offset_row_a = (get_global_id(1) % V0) * 4;
7060    const int offset_row_b = (get_global_id(0) % H0) * 8;
7061
7062    // src_addr_a = address of matrix A
7063    // src_addr_b = address of matrix B
7064    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
7065    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
7066
7067#if defined(MATRIX_B_DEPTH)
7068    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
7069    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
7070#else  // defined(MATRIX_B_DEPTH)
7071    src1_addr_in_bytes += z * src1_stride_z;
7072#endif // defined(MATRIX_B_DEPTH)
7073
7074    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
7075    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
7076
7077    src_addr_a += offset_row_a;
7078    src_addr_b += offset_row_b;
7079
7080    // Reset accumulators
7081    half8 c0 = 0.0f;
7082    half8 c1 = 0.0f;
7083    half8 c2 = 0.0f;
7084    half8 c3 = 0.0f;
7085
7086    int i = 0;
7087    for(; i <= (int)(K - 4); i += 4)
7088    {
7089#if V0 == 1
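        // Specialization: with V0 == 1 the stride between consecutive K steps of A is 4
        // elements, so two K steps are contiguous and a single vload8 fetches both
        // (a0.s0123 for step i, a0.s4567 for step i+1), halving the loads from matrix A.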
7090        // Load values from matrix A (interleaved) and matrix B (transposed)
7091        half8 a0 = vload8(0, src_addr_a);
7092        half8 b0 = vload8(0, src_addr_b);
7093
7094        src_addr_a += 8 * V0;
7095        src_addr_b += 8 * H0;
7096
7097        c0 = fma((half8)a0.s0, b0, c0);
7098        c1 = fma((half8)a0.s1, b0, c1);
7099        c2 = fma((half8)a0.s2, b0, c2);
7100        c3 = fma((half8)a0.s3, b0, c3);
7101
7102        // Load values from matrix B (transposed)
7103        b0 = vload8(0, src_addr_b);
7104
7105        src_addr_b += 8 * H0;
7106
7107        c0 = fma((half8)a0.s4, b0, c0);
7108        c1 = fma((half8)a0.s5, b0, c1);
7109        c2 = fma((half8)a0.s6, b0, c2);
7110        c3 = fma((half8)a0.s7, b0, c3);
7111
7112        // Load values from matrix A (interleaved) and matrix B (transposed)
7113        a0 = vload8(0, src_addr_a);
7114        b0 = vload8(0, src_addr_b);
7115
7116        src_addr_a += 8 * V0;
7117        src_addr_b += 8 * H0;
7118
7119        c0 = fma((half8)a0.s0, b0, c0);
7120        c1 = fma((half8)a0.s1, b0, c1);
7121        c2 = fma((half8)a0.s2, b0, c2);
7122        c3 = fma((half8)a0.s3, b0, c3);
7123
7124        // Load values from matrix B (transposed)
7125        b0 = vload8(0, src_addr_b);
7126
7127        src_addr_b += 8 * H0;
7128
7129        c0 = fma((half8)a0.s4, b0, c0);
7130        c1 = fma((half8)a0.s5, b0, c1);
7131        c2 = fma((half8)a0.s6, b0, c2);
7132        c3 = fma((half8)a0.s7, b0, c3);
7133#else  // V0 == 1
7134        // Load values from matrix A (interleaved) and matrix B (transposed)
7135        half4 a0 = vload4(0, src_addr_a);
7136        half8 b0 = vload8(0, src_addr_b);
7137
7138        src_addr_a += 4 * V0;
7139        src_addr_b += 8 * H0;
7140
7141        c0 = fma((half8)a0.s0, b0, c0);
7142        c1 = fma((half8)a0.s1, b0, c1);
7143        c2 = fma((half8)a0.s2, b0, c2);
7144        c3 = fma((half8)a0.s3, b0, c3);
7145
7146        // Load values from matrix A (interleaved) and matrix B (transposed)
7147        a0 = vload4(0, src_addr_a);
7148        b0 = vload8(0, src_addr_b);
7149
7150        src_addr_a += 4 * V0;
7151        src_addr_b += 8 * H0;
7152
7153        c0 = fma((half8)a0.s0, b0, c0);
7154        c1 = fma((half8)a0.s1, b0, c1);
7155        c2 = fma((half8)a0.s2, b0, c2);
7156        c3 = fma((half8)a0.s3, b0, c3);
7157
7158        // Load values from matrix A (interleaved) and matrix B (transposed)
7159        a0 = vload4(0, src_addr_a);
7160        b0 = vload8(0, src_addr_b);
7161
7162        src_addr_a += 4 * V0;
7163        src_addr_b += 8 * H0;
7164
7165        c0 = fma((half8)a0.s0, b0, c0);
7166        c1 = fma((half8)a0.s1, b0, c1);
7167        c2 = fma((half8)a0.s2, b0, c2);
7168        c3 = fma((half8)a0.s3, b0, c3);
7169
7170        // Load values from matrix A (interleaved) and matrix B (transposed)
7171        a0 = vload4(0, src_addr_a);
7172        b0 = vload8(0, src_addr_b);
7173
7174        src_addr_a += 4 * V0;
7175        src_addr_b += 8 * H0;
7176
7177        c0 = fma((half8)a0.s0, b0, c0);
7178        c1 = fma((half8)a0.s1, b0, c1);
7179        c2 = fma((half8)a0.s2, b0, c2);
7180        c3 = fma((half8)a0.s3, b0, c3);
7181#endif // V0 == 1
7182    }
7183
7184    for(; i < (int)K; ++i)
7185    {
7186        // Load values from matrix A (interleaved) and matrix B (transposed)
7187        half4 a0 = vload4(0, src_addr_a);
7188        half8 b0 = vload8(0, src_addr_b);
7189
7190        src_addr_a += 4 * V0;
7191        src_addr_b += 8 * H0;
7192
7193        c0 = fma((half8)a0.s0, b0, c0);
7194        c1 = fma((half8)a0.s1, b0, c1);
7195        c2 = fma((half8)a0.s2, b0, c2);
7196        c3 = fma((half8)a0.s3, b0, c3);
7197    }
7198
7199    // Compute destination address
7200    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
7201
7202    // Compute dst address
7203    __global uchar *dst_addr = offset(&dst, 0, 0);
7204
7205    uint4 zout = 0;
7206
7207#if defined(REINTERPRET_OUTPUT_AS_3D)
7208    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
7209    // in order to take into account the presence of possible cross plane paddings
7210    //
7211    //  |                  |
7212    //  |      plane0      |
7213    //  |                  |
7214    //  |__________________|
7215    //  |******************|
7216    //  |  cross_plane_pad |
7217    //  |******************|
7218    //  |                  |
7219    //  |      plane1      |
7220    //  |                  |
7221    //  |__________________|
7222
    // The plane index (zout) is obtained by dividing the output row index (get_global_id(1) * 4) by HEIGHT_GEMM3D
7224    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
7225    zout = min(DEPTH_GEMM3D - 1, zout);
7226
7227    // Add offset due to the cross plane paddings
7228    zout *= (cross_plane_pad * dst_stride_y);
7229
7230    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
7231    // multiply dst_stride_z by DEPTH_GEMM3D
7232    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
7233#else  // defined(REINTERPRET_OUTPUT_AS_3D)
7234    // Add offset for batched GEMM
7235    dst_addr += z * dst_stride_z;
7236#endif // defined(REINTERPRET_OUTPUT_AS_3D)
7237
    // Scale the matrix-matrix product by alpha
7239#if defined(ALPHA)
7240    SCALE_BLOCK(4, half, c, ALPHA);
7241#endif // defined(ALPHA)
7242
7243    // Add beta*bias
7244#if defined(BETA)
7245    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
7246
7247#if defined(BROADCAST_BIAS)
7248    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
7249
7250    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
7251
7252#ifndef UNIT_BETA
7253    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA
7255
7256    // c = c + bias[broadcasted]
7257    ADD_BLOCK_BROADCAST(4, c, bias0);
7258
7259#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;
7262
7263    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
7264
7265#ifndef UNIT_BETA
7266    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA
7268
7269    // c = c + bias
7270    ADD_BLOCK(4, c, bias);
7271
7272#endif // defined(BROADCAST_BIAS)
7273#endif // defined(BETA)
7274
7275#if defined(ACTIVATION_TYPE)
7276    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
7277#endif // defined(ACTIVATION_TYPE)
7278
7279    // Store 4x8 block
7280    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
7281    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
7282    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
7283}
7284
7285#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
7286
7287#endif // defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
7288
7289#if defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
7290#if defined(DATA_TYPE)
7291#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, N0)
7292/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
7293 *
7294 * @note This OpenCL kernel works with floating point data types (F16/F32)
7295 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
7296 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0
7297 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
7298 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
7299 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
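 * @note Example build options (illustrative values only, not mandated by this kernel): -DDATA_TYPE=float -DM0=4 -DN0=4 -DK=64 -DN=128 -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1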
7300 * @note The optional alpha value needs to be passed at compile time using -DALPHA
7301 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
7302 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
7303 *
7304 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
7305 *       The activation function is performed after the bias addition
7306 * @note In case the input or the output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
7307 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
7308 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
7309 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
7310 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
7311 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) must be equal to the number of rows of the NOT-reshaped matrix A (i.e. M)
7312 *
7313 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
7314 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
7315 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
7316 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
7317 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
7318 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
7319 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
7320 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
7321 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
7322 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
7323 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
7324 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
7325 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
7326 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
7327 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
7328 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
7329 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
7330 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
7331 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
7332 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
7333 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
7334 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
7335 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
7336 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
7337 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
7338 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
7339 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
7340 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
7341 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
7342 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
7343 */
7344__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
7345                                     IMAGE_DECLARATION(src1),
7346#if defined(BETA)
7347                                     IMAGE_DECLARATION(src2),
7348#endif // defined(BETA)
7349                                     IMAGE_DECLARATION(dst),
7350                                     uint src0_stride_z,
7351                                     uint src1_stride_z,
7352#if defined(BETA)
7353                                     uint src2_stride_z,
7354#endif //defined(BETA)
7355                                     uint dst_stride_z
7356#if defined(REINTERPRET_INPUT_AS_3D)
7357                                     ,
7358                                     uint src_cross_plane_pad
7359#endif // REINTERPRET_INPUT_AS_3D
7360#if defined(REINTERPRET_OUTPUT_AS_3D)
7361                                     ,
7362                                     uint dst_cross_plane_pad
7363#endif // REINTERPRET_OUTPUT_AS_3D
7364                                    )
7365{
7366    int idx = get_global_id(0) * N0;
7367
7368    // Compute starting address for matrix A and matrix B
7369    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
7370
7371    // Update address for the matrix A
7372    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
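    // COMPUTE_M0_START_ROW (used above, defined earlier in this file) maps work-item y to the
    // first row of its M0-row block, shifting the partially-filled block to the first work-item so
    // that every other block stays fully inside the output. Illustrative example (values not from
    // this file): with M0 = 4 and PARTIAL_STORE_M0 = 2, work-item 0 starts at row 0, work-item 1
    // at row 2, work-item 2 at row 6, and only work-item 0 performs a partial store.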
7373
7374    // Update address for the matrix B
7375    src_addr.s1 += idx * sizeof(DATA_TYPE);
7376
7377#if defined(REINTERPRET_INPUT_AS_3D)
7378    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
7379    // in order to take into account the presence of possible cross plane paddings
7380    //
7381    //  |                  |
7382    //  |      plane0      |
7383    //  |                  |
7384    //  |__________________|
7385    //  |******************|
7386    //  |  cross_plane_pad |
7387    //  |******************|
7388    //  |                  |
7389    //  |      plane1      |
7390    //  |                  |
7391    //  |__________________|
7392
7393    // The plane (zin) is calculated by dividing the row index by HEIGHT_GEMM3D
7394    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
7395    zin       = min(DEPTH_GEMM3D - 1, zin);
7396
7397    // Add offset due to the cross plane paddings
7398    zin *= (src_cross_plane_pad * src0_stride_y);
7399
7400    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
7401    // multiply src0_stride_z by DEPTH_GEMM3D
7402    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
7403
7404#else // defined(REINTERPRET_INPUT_AS_3D)
7405
7406    // Add offset for batched GEMM
7407    src_addr.s0 += get_global_id(2) * src0_stride_z;
7408
7409#endif // defined(REINTERPRET_INPUT_AS_3D)
7410
7411#if defined(MATRIX_B_DEPTH)
7412    // Do not slide matrix B if it has 3 dimensions while matrix A has more than 3
7413    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
7414#else  // defined(MATRIX_B_DEPTH)
7415    src_addr.s1 += get_global_id(2) * src1_stride_z;
7416#endif // defined(MATRIX_B_DEPTH)
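    // Illustrative example (values not from this file): with -DMATRIX_B_DEPTH=16, batch z = 17
    // of matrix A is multiplied by slice 17 % 16 = 1 of matrix B.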
7417
7418    int end_row_vec_a = src_addr.s0 + (K * sizeof(DATA_TYPE));
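    // end_row_vec_a marks one-past-the-end (in bytes) of the current row of matrix A; the loops
    // below advance src_addr.s0 across that row while src_addr.s1 moves down matrix B in lockstep.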
7419
7420    VECTOR_TYPE acc0 = 0.0f;
7421#if M0 > 1
7422    VECTOR_TYPE acc1 = 0.0f;
7423#endif // M0 > 1
7424#if M0 > 2
7425    VECTOR_TYPE acc2 = 0.0f;
7426#endif // M0 > 2
7427#if M0 > 3
7428    VECTOR_TYPE acc3 = 0.0f;
7429#endif // M0 > 3
7430
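    // Main K loop, unrolled by 2: each iteration consumes two columns of matrix A and two rows of
    // matrix B; the scalar tail loop afterwards handles the last column when K is odd.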
7431    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
7432    {
7433#if defined(REINTERPRET_INPUT_AS_3D)
7434        // Load values from matrix A
7435        LOAD_BLOCK(M0, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
7436#else // defined(REINTERPRET_INPUT_AS_3D)
7437        // Load values from matrix A
7438        VEC_DATA_TYPE(DATA_TYPE, 2)
7439        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
7440#if M0 > 1
7441        VEC_DATA_TYPE(DATA_TYPE, 2)
7442        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
7443#endif // M0 > 1
7444#if M0 > 2
7445        VEC_DATA_TYPE(DATA_TYPE, 2)
7446        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
7447#endif // M0 > 2
7448#if M0 > 3
7449        VEC_DATA_TYPE(DATA_TYPE, 2)
7450        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
7451#endif // M0 > 3
7452#endif // defined(REINTERPRET_INPUT_AS_3D)
7453
7454        // Load values from matrix B
7455        VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
7456        VECTOR_TYPE b1 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
7457
7458        // Accumulate
7459        acc0 += b0 * (VECTOR_TYPE)a0.s0;
7460        acc0 += b1 * (VECTOR_TYPE)a0.s1;
7461#if M0 > 1
7462        acc1 += b0 * (VECTOR_TYPE)a1.s0;
7463        acc1 += b1 * (VECTOR_TYPE)a1.s1;
7464#endif // M0 > 1
7465#if M0 > 2
7466        acc2 += b0 * (VECTOR_TYPE)a2.s0;
7467        acc2 += b1 * (VECTOR_TYPE)a2.s1;
7468#endif // M0 > 2
7469#if M0 > 3
7470        acc3 += b0 * (VECTOR_TYPE)a3.s0;
7471        acc3 += b1 * (VECTOR_TYPE)a3.s1;
7472#endif // M0 > 3
7473    }
7474
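    // Tail loop: process the remaining K % 2 column (if any) one element at a time.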
7475    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
7476    {
7477#if defined(REINTERPRET_INPUT_AS_3D)
7478        // Load values from matrix A
7479        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
7480#if M0 > 1
7481        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
7482#endif // M0 > 1
7483#if M0 > 2
7484        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
7485#endif // M0 > 2
7486#if M0 > 3
7487        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
7488#endif // M0 > 3
7489#else  // defined(REINTERPRET_INPUT_AS_3D)
7490        // Load values from matrix A
7491        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
7492#if M0 > 1
7493        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
7494#endif // M0 > 1
7495#if M0 > 2
7496        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
7497#endif // M0 > 2
7498#if M0 > 3
7499        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
7500#endif // M0 > 3
7501#endif // defined(REINTERPRET_INPUT_AS_3D)
7502
7503        // Load values from matrix B
7504        VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
7505
7506        // Accumulate
7507        acc0 += b0 * (VECTOR_TYPE)a0;
7508#if M0 > 1
7509        acc1 += b0 * (VECTOR_TYPE)a1;
7510#endif // M0 > 1
7511#if M0 > 2
7512        acc2 += b0 * (VECTOR_TYPE)a2;
7513#endif // M0 > 2
7514#if M0 > 3
7515        acc3 += b0 * (VECTOR_TYPE)a3;
7516#endif // M0 > 3
7517    }
7518
7519    int z = get_global_id(2);
7520
7521    // Compute dst address
7522    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
7523                               PARTIAL_STORE_M0)
7524                               * dst_stride_y);
7525
7526    uint4 zout = 0;
7527
7528#if defined(REINTERPRET_OUTPUT_AS_3D)
7529
7530    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
7531    // in order to take into account the presence of possible cross plane paddings
7532    //
7533    //  |                  |
7534    //  |      plane0      |
7535    //  |                  |
7536    //  |__________________|
7537    //  |******************|
7538    //  |  cross_plane_pad |
7539    //  |******************|
7540    //  |                  |
7541    //  |      plane1      |
7542    //  |                  |
7543    //  |__________________|
7544
7545    // The plane (zout) is calculated by dividing the row index by HEIGHT_GEMM3D
7546    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
7547    zout = min(DEPTH_GEMM3D - 1, zout);
7548
7549    // Add offset due to the cross plane paddings
7550    zout *= (dst_cross_plane_pad * dst_stride_y);
7551
7552    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
7553    // multiply dst_stride_z by DEPTH_GEMM3D
7554    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
7555#else  // defined(REINTERPRET_OUTPUT_AS_3D)
7556    // Add offset for batched GEMM
7557    dst_addr += z * dst_stride_z;
7558#endif // defined(REINTERPRET_OUTPUT_AS_3D)
7559
7560    // Scale the result of the matrix-matrix product by ALPHA
7561#if defined(ALPHA)
7562    SCALE_BLOCK(M0, DATA_TYPE, acc, ALPHA);
7563#endif // defined(ALPHA)
7564
7565    // Add beta*bias
7566#if defined(BETA)
7567    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
7568
7569#if defined(BROADCAST_BIAS)
7570    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
7571
7572    LOAD_BLOCK(1, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);
7573
7574#ifndef UNIT_BETA
7575    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
7576#endif // UNIT_BETA
7577
7578    // acc = acc + bias[broadcasted]
7579    ADD_BLOCK_BROADCAST(M0, acc, bias0);
7580
7581#else // defined(BROADCAST_BIAS)
7582    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
7583                                PARTIAL_STORE_M0)
7584                                * src2_stride_y)
7585                                + z * src2_stride_z;
7586
7587    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);
7588
7589#ifndef UNIT_BETA
7590    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
7591#endif // UNIT_BETA
7592
7593    // acc = acc + bias
7594    ADD_BLOCK(M0, acc, bias);
7595
7596#endif // defined(BROADCAST_BIAS)
7597#endif // defined(BETA)
7598
7599#if defined(ACTIVATION_TYPE)
7600    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, acc, A_VAL, B_VAL);
7601#endif // defined(ACTIVATION_TYPE)
7602
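    // Because COMPUTE_M0_START_ROW shifts the partially-filled block of rows to the first
    // work-item, cond_y is simply get_global_id(1) == 0 here, unlike kernels that keep the
    // partial block at the end of the output.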
7603    // Store output block
7604    const bool cond_y = get_global_id(1) == 0;
7605    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
7606    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
7607}
7608#endif // defined(DATA_TYPE)
7609
7610/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
7611 *
7612 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
7613 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
7614 * @note This kernel processes a fixed number of elements along x: -DN0=4.
7615 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
7616 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
7617 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
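 * @note Example build options (illustrative values only, not mandated by this kernel): -DM0=4 -DN0=4 -DK=64 -DN=128 -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1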
7618 * @note The optional alpha value needs to be passed at compile time using -DALPHA
7619 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
7620 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
7621 *
7622 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
7623 *       The activation function is performed after the bias addition
7624 * @note In case the input or the output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
7625 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
7626 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
7627 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
7628 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
7629 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) must be equal to the number of rows of the NOT-reshaped matrix A (i.e. M)
7630 *
7631 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
7632 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
7633 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
7634 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
7635 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
7636 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
7637 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
7638 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
7639 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
7640 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
7641 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
7642 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
7643 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
7644 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
7645 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
7646 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
7647 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
7648 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
7649 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
7650 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
7651 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
7652 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
7653 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
7654 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
7655 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
7656 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
7657 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
7658 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
7659 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
7660 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
7661 */
7662__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
7663                                                 IMAGE_DECLARATION(src1),
7664#if defined(BETA)
7665                                                 IMAGE_DECLARATION(src2),
7666#endif // defined(BETA)
7667                                                 IMAGE_DECLARATION(dst),
7668                                                 uint src0_stride_z,
7669                                                 uint src1_stride_z,
7670#if defined(BETA)
7671                                                 uint src2_stride_z,
7672#endif //defined(BETA)
7673                                                 uint dst_stride_z
7674#if defined(REINTERPRET_INPUT_AS_3D)
7675                                                 ,
7676                                                 uint src_cross_plane_pad
7677#endif // REINTERPRET_INPUT_AS_3D
7678#if defined(REINTERPRET_OUTPUT_AS_3D)
7679                                                 ,
7680                                                 uint dst_cross_plane_pad
7681#endif // REINTERPRET_OUTPUT_AS_3D
7682                                                )
7683{
7684    int idx = get_global_id(0) * N0;
7685
7686    // Compute starting address for matrix A and matrix B
7687    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
7688
7689    // Update address for matrix A
7690    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
7691
7692    // Update address for matrix B
7693    src_addr.s1 += idx * sizeof(float);
7694
7695#if defined(REINTERPRET_INPUT_AS_3D)
7696    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
7697    // in order to take into account the presence of possible cross plane paddings
7698    //
7699    //  |                  |
7700    //  |      plane0      |
7701    //  |                  |
7702    //  |__________________|
7703    //  |******************|
7704    //  |  cross_plane_pad |
7705    //  |******************|
7706    //  |                  |
7707    //  |      plane1      |
7708    //  |                  |
7709    //  |__________________|
7710
7711    // The plane (zin) is calculated by dividing the row index by HEIGHT_GEMM3D
7712    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
7713    zin       = min(DEPTH_GEMM3D - 1, zin);
7714
7715    // Add offset due to the cross plane paddings
7716    zin *= (src_cross_plane_pad * src0_stride_y);
7717
7718    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
7719    // multiply src0_stride_z by DEPTH_GEMM3D
7720    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
7721
7722#else // defined(REINTERPRET_INPUT_AS_3D)
7723
7724    // Add offset for batched GEMM
7725    src_addr.s0 += get_global_id(2) * src0_stride_z;
7726
7727#endif // defined(REINTERPRET_INPUT_AS_3D)
7728
7729#if defined(MATRIX_B_DEPTH)
7730    // Do not slide matrix B if it has 3 dimensions while matrix A has more than 3
7731    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
7732#else  // defined(MATRIX_B_DEPTH)
7733    src_addr.s1 += get_global_id(2) * src1_stride_z;
7734#endif // defined(MATRIX_B_DEPTH)
7735
7736    // Initialize accumulators
7737    float4 acc0 = 0.0f;
7738
7739#if M0 > 1
7740    float4 acc1 = 0.0f;
7741#endif // M0 > 1
7742
7743#if M0 > 2
7744    float4 acc2 = 0.0f;
7745#endif // M0 > 2
7746
7747#if M0 > 3
7748    float4 acc3 = 0.0f;
7749#endif // M0 > 3
7750
7751    // A and B src indices get incremented at the same time.
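    // Main K loop, unrolled by 4: each iteration loads one float4 per row of matrix A and four
    // rows of matrix B, issuing 16 fma operations per accumulator row to keep the fma units busy.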
7752    int i = 0;
7753    for(; i <= ((int)K - 4); i += 4)
7754    {
7755#if defined(REINTERPRET_INPUT_AS_3D)
7756        // Load values from matrix A and matrix B
7757        LOAD_BLOCK(M0, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
7758#else // defined(REINTERPRET_INPUT_AS_3D)
7759        // Load values from matrix A and matrix B
7760        float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
7761#if M0 > 1
7762        float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
7763#endif // M0 > 1
7764#if M0 > 2
7765        float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
7766#endif // M0 > 2
7767#if M0 > 3
7768        float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
7769#endif // M0 > 3
7770#endif // defined(REINTERPRET_INPUT_AS_3D)
7771
7772        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
7773        src_addr.s1 += src1_stride_y;
7774
7775        // Multiply and accumulate
7776        acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
7777        acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
7778        acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
7779        acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
7780
7781#if M0 > 1
7782
7783        acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
7784        acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
7785        acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
7786        acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
7787
7788#endif // M0 > 1
7789#if M0 > 2
7790
7791        acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
7792        acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
7793        acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
7794        acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
7795
7796#endif // M0 > 2
7797#if M0 > 3
7798
7799        acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
7800        acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
7801        acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
7802        acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
7803#endif // M0 > 3
7804
7805        // Load values from matrix A and matrix B
7806        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
7807        src_addr.s1 += src1_stride_y;
7808
7809        // Multiply and accumulate
7810        acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
7811        acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
7812        acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
7813        acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
7814
7815#if M0 > 1
7816
7817        acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
7818        acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
7819        acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
7820        acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
7821
7822#endif // M0 > 1
7823#if M0 > 2
7824
7825        acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
7826        acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
7827        acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
7828        acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
7829
7830#endif // M0 > 2
7831#if M0 > 3
7832
7833        acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
7834        acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
7835        acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
7836        acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
7837#endif // M0 > 3
7838
7839        // Load values from matrix A and matrix B
7840        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
7841        src_addr.s1 += src1_stride_y;
7842
7843        // Multiply and accumulate
7844        acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
7845        acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
7846        acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
7847        acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
7848
7849#if M0 > 1
7850
7851        acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
7852        acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
7853        acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
7854        acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
7855
7856#endif // M0 > 1
7857#if M0 > 2
7858
7859        acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
7860        acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
7861        acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
7862        acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
7863
7864#endif // M0 > 2
7865#if M0 > 3
7866
7867        acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
7868        acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
7869        acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
7870        acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
7871#endif // M0 > 3
7872
7873        // Load values from matrix A and matrix B
7874        b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
7875        src_addr.s1 += src1_stride_y;
7876
7877        // Multiply and accumulate
7878        acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
7879        acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
7880        acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
7881        acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
7882
7883#if M0 > 1
7884
7885        acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
7886        acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
7887        acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
7888        acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
7889
7890#endif // M0 > 1
7891#if M0 > 2
7892
7893        acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
7894        acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
7895        acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
7896        acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
7897
7898#endif // M0 > 2
7899#if M0 > 3
7900
7901        acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
7902        acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
7903        acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
7904        acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
7905#endif // M0 > 3
7906
7907        src_addr.s0 += 4 * sizeof(float);
7908    }
7909
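    // Tail loop: process the remaining K % 4 columns one element at a time.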
7910    for(; i < (int)K; ++i)
7911    {
7912#if defined(REINTERPRET_INPUT_AS_3D)
7913        // Load values from matrix A
7914        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
7915#if M0 > 1
7916        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
7917#endif // M0 > 1
7918#if M0 > 2
7919        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
7920#endif // M0 > 2
7921#if M0 > 3
7922        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
7923#endif // M0 > 3
7924#else  // defined(REINTERPRET_INPUT_AS_3D)
7925        // Load values from matrix A
7926        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
7927#if M0 > 1
7928        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
7929#endif // M0 > 1
7930#if M0 > 2
7931        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
7932#endif // M0 > 2
7933#if M0 > 3
7934        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
7935#endif // M0 > 3
7936#endif // defined(REINTERPRET_INPUT_AS_3D)
7937
7938        // Load values from matrix B
7939        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
7940        src_addr.s1 += src1_stride_y;
7941
7942        // Multiply and accumulate
7943        acc0.s0 = fma(a0, b0.s0, acc0.s0);
7944        acc0.s1 = fma(a0, b0.s1, acc0.s1);
7945        acc0.s2 = fma(a0, b0.s2, acc0.s2);
7946        acc0.s3 = fma(a0, b0.s3, acc0.s3);
7947#if M0 > 1
7948        acc1.s0 = fma(a1, b0.s0, acc1.s0);
7949        acc1.s1 = fma(a1, b0.s1, acc1.s1);
7950        acc1.s2 = fma(a1, b0.s2, acc1.s2);
7951        acc1.s3 = fma(a1, b0.s3, acc1.s3);
7952#endif // M0 > 1
7953#if M0 > 2
7954        acc2.s0 = fma(a2, b0.s0, acc2.s0);
7955        acc2.s1 = fma(a2, b0.s1, acc2.s1);
7956        acc2.s2 = fma(a2, b0.s2, acc2.s2);
7957        acc2.s3 = fma(a2, b0.s3, acc2.s3);
7958#endif // M0 > 2
7959#if M0 > 3
7960        acc3.s0 = fma(a3, b0.s0, acc3.s0);
7961        acc3.s1 = fma(a3, b0.s1, acc3.s1);
7962        acc3.s2 = fma(a3, b0.s2, acc3.s2);
7963        acc3.s3 = fma(a3, b0.s3, acc3.s3);
7964#endif // M0 > 3
7965
7966        src_addr.s0 += sizeof(float);
7967    }
7968
7969    int z = get_global_id(2);
7970
7971    // Compute dst address
7972    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
7973                               PARTIAL_STORE_M0) * dst_stride_y);
7974
7975    uint4 zout = 0;
7976
7977#if defined(REINTERPRET_OUTPUT_AS_3D)
7978    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
7979    // in order to take into account the presence of possible cross plane paddings
7980    //
7981    //  |                  |
7982    //  |      plane0      |
7983    //  |                  |
7984    //  |__________________|
7985    //  |******************|
7986    //  |  cross_plane_pad |
7987    //  |******************|
7988    //  |                  |
7989    //  |      plane1      |
7990    //  |                  |
7991    //  |__________________|
7992
7993    // The plane (zout) is calculated by dividing the row index by HEIGHT_GEMM3D
7994    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
7995    zout = min(DEPTH_GEMM3D - 1, zout);
7996
7997    // Add offset due to the cross plane paddings
7998    zout *= (dst_cross_plane_pad * dst_stride_y);
7999
8000    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
8001    // multiply dst_stride_z by DEPTH_GEMM3D
8002    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
8003#else  // defined(REINTERPRET_OUTPUT_AS_3D)
8004    // Add offset for batched GEMM
8005    dst_addr += z * dst_stride_z;
8006#endif // defined(REINTERPRET_OUTPUT_AS_3D)
8007
8008    // Scale the result of the matrix-matrix product by ALPHA
8009#if defined(ALPHA)
8010    SCALE_BLOCK(M0, float, acc, ALPHA);
8011#endif // defined(ALPHA)
8012
8013    // Add beta*bias
8014#if defined(BETA)
8015    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
8016
8017#if defined(BROADCAST_BIAS)
8018    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
8019
8020    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
8021
8022#ifndef UNIT_BETA
8023    SCALE_BLOCK(1, float, bias, BETA);
8024#endif // UNIT_BETA
8025
8026    // acc = acc + bias[broadcasted]
8027    ADD_BLOCK_BROADCAST(M0, acc, bias0);
8028
8029#else // defined(BROADCAST_BIAS)
8030    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
8031                                PARTIAL_STORE_M0)
8032                                * src2_stride_y)
8033                                + z * src2_stride_z;
8034
8035    LOAD_BLOCK(M0, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
8036
8037#ifndef UNIT_BETA
8038    SCALE_BLOCK(M0, float, bias, BETA);
8039#endif // UNIT_BETA
8040
8041    // acc = acc + bias
8042    ADD_BLOCK(M0, acc, bias);
8043
8044#endif // defined(BROADCAST_BIAS)
8045#endif // defined(BETA)
8046
8047#if defined(ACTIVATION_TYPE)
8048    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
8049#endif // defined(ACTIVATION_TYPE)
8050
8051    // Store the output block
8052    const bool cond_y = get_global_id(1) == 0;
8053    const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
8054    STORE_BLOCK_BOUNDARY_AWARE(M0, 4, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
8055}
8056
8057/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
8058 *
8059 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
8060 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
8061 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
8062 * @note This kernel processes a fixed number of elements along x: -DN0=2.
8063 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
8064 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
8065 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
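 * @note Example build options (illustrative values only, not mandated by this kernel): -DM0=4 -DN0=2 -DK=64 -DN=1000 -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1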
8066 * @note The optional alpha value needs to be passed at compile time using -DALPHA
8067 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
8068 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
8069 *
8070 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must also be passed at compile time using -DA_VAL= and -DB_VAL= respectively.
8071 *       The activation function is performed after the bias addition
8072 * @note In case the input or the output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
8073 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
8074 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
8075 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
8076 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
8077 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) must be equal to the number of rows of the NOT-reshaped matrix A (i.e. M)
8078 *
8079 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
8080 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
8081 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
8082 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
8083 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
8084 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
8085 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
8086 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
8087 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
8088 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
8089 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
8090 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
8091 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
8092 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
8093 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
8094 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
8095 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
8096 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
8097 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
8098 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
8099 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
8100 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
8101 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
8102 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
8103 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
8104 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
8105 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
8106 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
8107 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
8108 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
8109 */
8110__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
8111                                                      IMAGE_DECLARATION(src1),
8112#if defined(BETA)
8113                                                      IMAGE_DECLARATION(src2),
8114#endif // defined(BETA)
8115                                                      IMAGE_DECLARATION(dst),
8116                                                      uint src0_stride_z,
8117                                                      uint src1_stride_z,
8118#if defined(BETA)
8119                                                      uint src2_stride_z,
8120#endif //defined(BETA)
8121                                                      uint dst_stride_z
8122#if defined(REINTERPRET_INPUT_AS_3D)
8123                                                      ,
8124                                                      uint src_cross_plane_pad
8125#endif // REINTERPRET_INPUT_AS_3D
8126#if defined(REINTERPRET_OUTPUT_AS_3D)
8127                                                      ,
8128                                                      uint dst_cross_plane_pad
8129#endif // REINTERPRET_OUTPUT_AS_3D
8130                                                     )
8131{
8132    // Requires N0 == 2: the accumulators (C) are float2, matrix A is read with vload8 and matrix B with two-element vloads (vload2). TODO: generalize for M0 > 1
8133    int idx = get_global_id(0) * N0;
8134
8135    // Compute starting address for matrix A and matrix B
8136    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
8137
8138    // Update address for the matrix A
8139    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
8140
8141    // Update address for the matrix B
8142    src_addr.s1 += idx * sizeof(float);
8143
8144#if defined(REINTERPRET_INPUT_AS_3D)
8145    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
8146    // in order to take into account the presence of possible cross plane paddings
8147    //
8148    //  |                  |
8149    //  |      plane0      |
8150    //  |                  |
8151    //  |__________________|
8152    //  |******************|
8153    //  |  cross_plane_pad |
8154    //  |******************|
8155    //  |                  |
8156    //  |      plane1      |
8157    //  |                  |
8158    //  |__________________|
8159
8160    // The plane (zin) is calculated by dividing the row index by HEIGHT_GEMM3D
8161    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
8162    zin       = min(DEPTH_GEMM3D - 1, zin);
8163
8164    // Add offset due to the cross plane paddings
8165    zin *= (src_cross_plane_pad * src0_stride_y);
8166
8167    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
8168    // multiply src0_stride_z by DEPTH_GEMM3D
8169    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
8170
8171#else // defined(REINTERPRET_INPUT_AS_3D)
8172
8173    // Add offset for batched GEMM
8174    src_addr.s0 += get_global_id(2) * src0_stride_z;
8175
8176#endif // defined(REINTERPRET_INPUT_AS_3D)
8177
8178#if defined(MATRIX_B_DEPTH)
8179    // Do not slide matrix B if it has 3 dimensions while matrix A has more than 3
8180    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
8181#else  // defined(MATRIX_B_DEPTH)
8182    src_addr.s1 += get_global_id(2) * src1_stride_z;
8183#endif // defined(MATRIX_B_DEPTH)
8184
8185    // Initialize accumulators
8186    float2 acc0 = 0.0f;
8187#if M0 > 1
8188    float2 acc1 = 0.0f;
8189#endif // M0 > 1
8190#if M0 > 2
8191    float2 acc2 = 0.0f;
8192#endif // M0 > 2
8193#if M0 > 3
8194    float2 acc3 = 0.0f;
8195#endif // M0 > 3
8196
8197    // A and B src indices get incremented at the same time.
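    // Main K loop, unrolled by 8: each iteration reads eight consecutive elements of matrix A
    // (a single float8) and eight two-element rows of matrix B, chaining all fma calls on one
    // float2 accumulator per row.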
8198    int i = 0;
8199    for(; i <= ((int)K - 8); i += 8)
8200    {
8201#if defined(REINTERPRET_INPUT_AS_3D)
8202        // Load values from matrix A
8203        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
8204#else  // defined(REINTERPRET_INPUT_AS_3D)
8205        // Load values from matrix A
8206        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
8207#endif // defined(REINTERPRET_INPUT_AS_3D)
8208
8209        // Load values from matrix B
8210        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8211        src_addr.s1 += src1_stride_y;
8212        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8213        src_addr.s1 += src1_stride_y;
8214        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8215        src_addr.s1 += src1_stride_y;
8216        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8217        src_addr.s1 += src1_stride_y;
8218        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8219        src_addr.s1 += src1_stride_y;
8220        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8221        src_addr.s1 += src1_stride_y;
8222        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8223        src_addr.s1 += src1_stride_y;
8224        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8225        src_addr.s1 += src1_stride_y;
8226
8227        // Multiply and accumulate
8228        acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
8229        acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
8230        acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
8231        acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
8232        acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
8233        acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
8234        acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
8235        acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
8236
8237        acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
8238        acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
8239        acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
8240        acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
8241        acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
8242        acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
8243        acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
8244        acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
8245
8246#if M0 > 1
8247#if defined(REINTERPRET_INPUT_AS_3D)
8248        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
8249#else  // defined(REINTERPRET_INPUT_AS_3D)
8250        a0                    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
8251#endif // defined(REINTERPRET_INPUT_AS_3D)
8252        acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
8253        acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
8254        acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
8255        acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
8256        acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
8257        acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
8258        acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
8259        acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
8260
8261        acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
8262        acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
8263        acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
8264        acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
8265        acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
8266        acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
8267        acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
8268        acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
8269#endif // M0 > 1
8270#if M0 > 2
8271#if defined(REINTERPRET_INPUT_AS_3D)
8272        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
8273#else  // defined(REINTERPRET_INPUT_AS_3D)
8274        a0                    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
8275#endif // defined(REINTERPRET_INPUT_AS_3D)
8276        acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
8277        acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
8278        acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
8279        acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
8280        acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
8281        acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
8282        acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
8283        acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
8284
8285        acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
8286        acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
8287        acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
8288        acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
8289        acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
8290        acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
8291        acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
8292        acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
8293#endif // M0 > 2
8294#if M0 > 3
8295#if defined(REINTERPRET_INPUT_AS_3D)
8296        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
8297#else  // defined(REINTERPRET_INPUT_AS_3D)
8298        a0                    = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
8299#endif // defined(REINTERPRET_INPUT_AS_3D)
8300        acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
8301        acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
8302        acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
8303        acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
8304        acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
8305        acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
8306        acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
8307        acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
8308
8309        acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
8310        acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
8311        acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
8312        acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
8313        acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
8314        acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
8315        acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
8316        acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
8317#endif // M0 > 3
8318
8319        src_addr.s0 += sizeof(float) * 8;
8320    }
8321    // Tail loop: process the remaining K % 8 columns one element at a time
8322    for(; i < (int)K; ++i)
8323    {
8324#if defined(REINTERPRET_INPUT_AS_3D)
8325        // Load values from matrix A
8326        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
8327#if M0 > 1
8328        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
8329#endif // M0 > 1
8330#if M0 > 2
8331        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
8332#endif // M0 > 2
8333#if M0 > 3
8334        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
8335#endif // M0 > 3
8336#else  // defined(REINTERPRET_INPUT_AS_3D)
8337        // Load values from matrix A
8338        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
8339#if M0 > 1
8340        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
8341#endif // M0 > 1
8342#if M0 > 2
8343        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
8344#endif // M0 > 2
8345#if M0 > 3
8346        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
8347#endif // M0 > 3
8348#endif // defined(REINTERPRET_INPUT_AS_3D)
8349
8350        // Load values from matrix B
8351        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
8352        src_addr.s1 += src1_stride_y;
8353
8354        // Multiply and accumulate
8355        acc0.s0 = fma(a0, b0.s0, acc0.s0);
8356        acc0.s1 = fma(a0, b0.s1, acc0.s1);
8357#if M0 > 1
8358        acc1.s0 = fma(a1, b0.s0, acc1.s0);
8359        acc1.s1 = fma(a1, b0.s1, acc1.s1);
8360#endif // M0 > 1
8361#if M0 > 2
8362        acc2.s0 = fma(a2, b0.s0, acc2.s0);
8363        acc2.s1 = fma(a2, b0.s1, acc2.s1);
8364#endif // M0 > 2
8365#if M0 > 3
8366        acc3.s0 = fma(a3, b0.s0, acc3.s0);
8367        acc3.s1 = fma(a3, b0.s1, acc3.s1);
8368#endif // M0 > 3
8369
8370        src_addr.s0 += sizeof(float);
8371    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                               PARTIAL_STORE_M0) * dst_stride_y);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);
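
    // Worked example (hypothetical values): with HEIGHT_GEMM3D=4, DEPTH_GEMM3D=2 and a
    // tile whose first output row is 5, the four rows (5,6,7,8) divide to planes
    // (1,1,1,2), which min() clamps to (1,1,1,1). After the multiply above, each lane of
    // zout is plane_index * dst_cross_plane_pad * dst_stride_y: the extra bytes to skip
    // for the padded plane boundaries crossed before that row.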

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Scale the result of the matrix-matrix product by the alpha weight
#if defined(ALPHA)
    SCALE_BLOCK(M0, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));

    LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                                PARTIAL_STORE_M0)
                                * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)
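
    // Note on the bias paths above: with BROADCAST_BIAS a single 1 x 2 row vector is
    // loaded and added to every row of the accumulator block, whereas the non-broadcast
    // path loads an M0 x 2 bias tile matching the output block (including the batch
    // offset z) and adds it element-wise.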

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

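    // cond_y flags the work-item row that COMPUTE_M0_START_ROW assigns the partial M
    // block (global id 1 == 0) and cond_x flags the right-most column block;
    // STORE_BLOCK_BOUNDARY_AWARE uses them to clip the store to PARTIAL_STORE_M0 /
    // PARTIAL_STORE_N0 rows and columns at the tensor borders.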
    // Store the output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 2 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 2, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
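
// Example (illustrative placeholders only, not library defaults): a host could build the
// F32 kernel above with an options string along the lines of
//   "-DM0=4 -DN0=2 -DK=64 -DN=31 -DPARTIAL_STORE_M0=1 -DPARTIAL_STORE_N0=1 -DALPHA=2.0f"
// In practice the option string is assembled by the library when the kernel is configured.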

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
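// ARM_COMPUTE_OPENCL_FP16_ENABLED is expected to be defined by the library only when the
// target device supports the cl_khr_fp16 extension, so the half-precision kernels below
// are compiled only where the half data type is actually available.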
/** This OpenCL kernel computes the matrix-by-matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point variables.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=8.
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in units of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in units of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                       IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
#if defined(BETA)
                                                       uint src2_stride_z,
#endif // defined(BETA)
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    int idx = get_global_id(0) * N0;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)
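    // Note: when MATRIX_B_DEPTH is defined, the modulo above makes batches of A beyond
    // the depth of B wrap around, so a 3D matrix B (e.g. the 16 Winograd batches mentioned
    // in the kernel notes) is reused across the extra batch dimension of A instead of
    // being read out of bounds.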

    float8 acc0 = 0.0f;
#if M0 > 1
    float8 acc1 = 0.0f;
#endif // M0 > 1
#if M0 > 2
    float8 acc2 = 0.0f;
#endif // M0 > 2
#if M0 > 3
    float8 acc3 = 0.0f;
#endif // M0 > 3

    int i = 0;
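    // The accumulators above are deliberately float8: each half-precision product is
    // accumulated in 32-bit floating point so that rounding error does not build up over
    // the K additions, and the result is converted back to half only once, just before
    // the store. The loop below is unrolled by a factor of 4: each iteration loads a
    // 4-element slice of every A row and 4 consecutive rows of B (converted to float),
    // issuing one fma per (A element, B row) pair.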
    for(; i <= ((int)K - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // M0 > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // M0 > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // M0 > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // M0 > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    for(; i < (int)K; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
#if M0 > 1
        acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
#endif                                    // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
#endif                                    // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
#endif                                    // M0 > 3
    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Scale the result of the matrix-matrix product by the alpha weight
#if defined(ALPHA)
    SCALE_BLOCK(M0, float, acc, ALPHA);
#endif // defined(ALPHA)

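    // Add beta*bias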
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias_f0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                                PARTIAL_STORE_M0)
                                * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);
#if M0 > 1
    float8 bias_f1 = convert_float8(bias1);
#endif // M0 > 1
#if M0 > 2
    float8 bias_f2 = convert_float8(bias2);
#endif // M0 > 2
#if M0 > 3
    float8 bias_f3 = convert_float8(bias3);
#endif // M0 > 3

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    half8 acc_h0 = convert_half8(acc0);
#if M0 > 1
    half8 acc_h1 = convert_half8(acc1);
#endif // M0 > 1
#if M0 > 2
    half8 acc_h2 = convert_half8(acc2);
#endif // M0 > 2
#if M0 > 3
    half8 acc_h3 = convert_half8(acc3);
#endif // M0 > 3

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
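
// The kernel below is the companion of gemm_mm_floating_point_f16_bifrost_acc32 above:
// it keeps its accumulators in half precision throughout, avoiding the float conversions
// at the cost of some accuracy, which can be acceptable when K is small or the use case
// tolerates the extra rounding error.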

/** This OpenCL kernel computes the matrix-by-matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=8.
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in units of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in units of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif // defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int idx = get_global_id(0) * N0;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    half8 acc0 = 0.0h;
#if M0 > 1
    half8 acc1 = 0.0h;
#endif // M0 > 1
#if M0 > 2
    half8 acc2 = 0.0h;
#endif // M0 > 2
#if M0 > 3
    half8 acc3 = 0.0h;
#endif // M0 > 3

    int i = 0;
    for(; i <= ((int)K - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // M0 > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // M0 > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // M0 > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // M0 > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    for(; i < (int)K; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
#if M0 > 1
        acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
#endif                                   // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
#endif                                   // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
#endif                                   // M0 > 3
    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Scale the result of the matrix-matrix product by the alpha weight
#if defined(ALPHA)
    SCALE_BLOCK(M0, half, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                                PARTIAL_STORE_M0)
                                * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

)"