• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1R"(
2
3/*
4 * Copyright (c) 2020 Arm Limited.
5 *
6 * SPDX-License-Identifier: MIT
7 *
8 * Permission is hereby granted, free of charge, to any person obtaining a copy
9 * of this software and associated documentation files (the "Software"), to
10 * deal in the Software without restriction, including without limitation the
11 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
12 * sell copies of the Software, and to permit persons to whom the Software is
13 * furnished to do so, subject to the following conditions:
14 *
15 * The above copyright notice and this permission notice shall be included in all
16 * copies or substantial portions of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 * SOFTWARE.
25 */
26
27/*
28 * Copyright (c) 2016-2020 Arm Limited.
29 *
30 * SPDX-License-Identifier: MIT
31 *
32 * Permission is hereby granted, free of charge, to any person obtaining a copy
33 * of this software and associated documentation files (the "Software"), to
34 * deal in the Software without restriction, including without limitation the
35 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
36 * sell copies of the Software, and to permit persons to whom the Software is
37 * furnished to do so, subject to the following conditions:
38 *
39 * The above copyright notice and this permission notice shall be included in all
40 * copies or substantial portions of the Software.
41 *
42 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
43 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
44 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
45 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
46 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
47 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
48 * SOFTWARE.
49 */
50#ifndef ARM_COMPUTE_HELPER_H
51#define ARM_COMPUTE_HELPER_H
52
53/*
54 * Copyright (c) 2020 Arm Limited.
55 *
56 * SPDX-License-Identifier: MIT
57 *
58 * Permission is hereby granted, free of charge, to any person obtaining a copy
59 * of this software and associated documentation files (the "Software"), to
60 * deal in the Software without restriction, including without limitation the
61 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
62 * sell copies of the Software, and to permit persons to whom the Software is
63 * furnished to do so, subject to the following conditions:
64 *
65 * The above copyright notice and this permission notice shall be included in all
66 * copies or substantial portions of the Software.
67 *
68 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
69 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
70 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
71 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
72 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
73 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
74 * SOFTWARE.
75 */
76
/** Store rows 0 to (n-1) of a block, one full vector per row
 * @name STORE_ROW_n
 *
 * STORE_ROW_n emits n statements. Statement r writes the variable BASENAMEr with VSTORE(N0)
 * to (__global DATA_TYPE *)(PTR + r * STRIDE_Y + Zr). Row suffixes are 0-9 for rows 0-9 and
 * A-F for rows 10-15, e.g. STORE_ROW_12 with BASENAME=c stores c0 ... c9, cA and cB.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The basename of the per-row offsets in z-axis direction
 * @{
 */

/* Store one row: the already-suffixed variable VAR is written to row ROW, shifted by ZOFF.
 * The STORE_ROW_n macros paste the row suffix onto BASENAME and Z themselves and hand the
 * result in, so the pasting happens at the same expansion depth as a hand-written store. */
#define STORE_ONE_ROW(N0, DATA_TYPE, VAR, PTR, STRIDE_Y, ZOFF, ROW) \
    VSTORE(N0)                                                      \
    (VAR, 0, (__global DATA_TYPE *)(PTR + ROW * STRIDE_Y + ZOFF));

#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##0, PTR, STRIDE_Y, Z##0, 0)

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##1, PTR, STRIDE_Y, Z##1, 1)

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##2, PTR, STRIDE_Y, Z##2, 2)

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##3, PTR, STRIDE_Y, Z##3, 3)

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##4, PTR, STRIDE_Y, Z##4, 4)

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##5, PTR, STRIDE_Y, Z##5, 5)

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##6, PTR, STRIDE_Y, Z##6, 6)

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##7, PTR, STRIDE_Y, Z##7, 7)

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##8, PTR, STRIDE_Y, Z##8, 8)

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##9, PTR, STRIDE_Y, Z##9, 9)

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##A, PTR, STRIDE_Y, Z##A, 10)

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##B, PTR, STRIDE_Y, Z##B, 11)

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##C, PTR, STRIDE_Y, Z##C, 12)

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##D, PTR, STRIDE_Y, Z##D, 13)

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##E, PTR, STRIDE_Y, Z##E, 14)

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW(N0, DATA_TYPE, BASENAME##F, PTR, STRIDE_Y, Z##F, 15)
/** @} */ // end of group STORE_ROW_n
167
/** Convert and store the 0th to (n-1)th rows of the given variables
 * @name CONVERT_STORE_ROW_n
 *
 * CONVERT_STORE_ROW_n emits n statements. Statement r converts BASENAMEr with CONVERT_SAT to
 * VEC_DATA_TYPE(DATA_TYPE, N0) and writes it with VSTORE(N0) to
 * (__global DATA_TYPE *)(PTR + r * STRIDE_Y + Zr). Row suffixes are 0-9 for rows 0-9 and
 * A-F for rows 10-15.
 *
 * @note Every macro in this family must name its type parameter DATA_TYPE. If it does not,
 *       the DATA_TYPE used in its body silently binds to any global DATA_TYPE macro instead
 *       of the caller's argument.
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type to convert to and store
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The basename of the per-row offsets in z-axis direction
 * @{
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

/* The second parameter must be spelled DATA_TYPE (it was previously DATA), otherwise the body
 * below binds to a global DATA_TYPE macro rather than to the caller's argument. */
#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));

/** @} */ // end of group CONVERT_STORE_ROW_n
259
/** Store an M0 x N0 block, one vector of width N0 per row
 * @name STORE_BLOCK
 *
 * STORE_BLOCK dispatches to STORE_ROW_<M0>. It forwards its arguments through
 * STORE_BLOCK_STR so that M0 is macro-expanded before being pasted onto STORE_ROW_; M0 may
 * therefore itself be a macro that expands to a row count.
 *
 * Supported cases are M0 = 1, 2, 3, ..., 16 and the N0 widths accepted by STORE_ROW_n
 * (1, 2, 3, 4, 8, 16).
 * Row variables must be named consecutively: for M0=3 and BASENAME=c they are c0, c1 and c2.
 * The z offsets follow the same scheme: for M0=3 and Z=zin they are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The basename of the per-row offsets in z-axis direction
 * @{
 */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
281
/** Convert and store an M0 x N0 block, one vector of width N0 per row
 * @name CONVERT_STORE_BLOCK
 *
 * CONVERT_STORE_BLOCK dispatches to CONVERT_STORE_ROW_<M0>. It forwards its arguments through
 * CONVERT_STORE_BLOCK_STR so that M0 is macro-expanded before being pasted onto
 * CONVERT_STORE_ROW_; M0 may therefore itself be a macro that expands to a row count.
 *
 * Supported cases are M0 = 1, 2, 3, ..., 16 and N0 = 1, 2, 3, 4, 8, 16 (the widths listed for
 * STORE_ROW_n, which uses the same VSTORE(N0) store).
 * Row variables must be named consecutively: for M0=3 and BASENAME=c they are c0, c1 and c2.
 * The z offsets follow the same scheme: for M0=3 and Z=zin they are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type to convert to and store
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The basename of the per-row offsets in z-axis direction
 * @{
 */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK
303
/** Partially store rows 0 to (n-1) of a block
 * @name STORE_ROW_PARTIAL_n
 *
 * Within each row, only the lower @p STORE_N0 elements of a vector of width @p N0 are stored.
 * STORE_ROW_PARTIAL_n emits n statements. Statement r writes BASENAMEr with
 * VSTORE_PARTIAL(N0, STORE_N0) to (__global DATA_TYPE *)(PTR + r * STRIDE_Y + Zr). Row
 * suffixes are 0-9 for rows 0-9 and A-F for rows 10-15.
 *
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The basename of the per-row offsets in z-axis direction
 * @{
 */

/* Partially store one row: the already-suffixed variable VAR is written to row ROW, shifted by
 * ZOFF. The STORE_ROW_PARTIAL_n macros paste the row suffix onto BASENAME and Z themselves, so
 * the pasting happens at the same expansion depth as a hand-written store. */
#define STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, VAR, PTR, STRIDE_Y, ZOFF, ROW) \
    VSTORE_PARTIAL(N0, STORE_N0)                                                      \
    (VAR, 0, (__global DATA_TYPE *)(PTR + ROW * STRIDE_Y + ZOFF));

#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##0, PTR, STRIDE_Y, Z##0, 0)

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##1, PTR, STRIDE_Y, Z##1, 1)

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##2, PTR, STRIDE_Y, Z##2, 2)

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##3, PTR, STRIDE_Y, Z##3, 3)

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##4, PTR, STRIDE_Y, Z##4, 4)

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##5, PTR, STRIDE_Y, Z##5, 5)

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##6, PTR, STRIDE_Y, Z##6, 6)

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##7, PTR, STRIDE_Y, Z##7, 7)

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##8, PTR, STRIDE_Y, Z##8, 8)

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##9, PTR, STRIDE_Y, Z##9, 9)

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##A, PTR, STRIDE_Y, Z##A, 10)

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##B, PTR, STRIDE_Y, Z##B, 11)

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##C, PTR, STRIDE_Y, Z##C, 12)

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##D, PTR, STRIDE_Y, Z##D, 13)

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##E, PTR, STRIDE_Y, Z##E, 14)

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    STORE_ONE_ROW_PARTIAL(N0, STORE_N0, DATA_TYPE, BASENAME##F, PTR, STRIDE_Y, Z##F, 15)
/** @} */ // end of group STORE_ROW_PARTIAL_n
398
/** Partially store a block of the given size STORE_M0xSTORE_N0
 * @name STORE_BLOCK_PARTIAL
 *
 * STORE_BLOCK_PARTIAL dispatches to STORE_ROW_PARTIAL_<STORE_M0>. It forwards its arguments
 * through STORE_BLOCK_PARTIAL_STR so that STORE_M0 is macro-expanded before being pasted onto
 * STORE_ROW_PARTIAL_; STORE_M0 may therefore itself be a macro that expands to a row count.
 *
 * @note The vector width @p N0 is also required for correct partial storing behaviour.
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * Row variables must be named consecutively: for STORE_M0=3 and BASENAME=c they are c0, c1
 * and c2. The z offsets follow the same scheme: for STORE_M0=3 and Z=zin they are zin0, zin1
 * and zin2.
 *
 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The basename of the per-row offsets in z-axis direction
 * @{
 */
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)

/** Store a block that can be partial in both x and y dimensions
 *
 * Expands to an if / else-if chain that selects one of four STORE_BLOCK_PARTIAL calls:
 *  - neither condition true: M0 rows of N0 elements
 *  - only PARTIAL_COND_Y:    PARTIAL_STORE_M0 rows of N0 elements
 *  - only PARTIAL_COND_X:    M0 rows of PARTIAL_STORE_N0 elements
 *  - both conditions true:   PARTIAL_STORE_M0 rows of PARTIAL_STORE_N0 elements
 * The conditions are re-evaluated by each test in the chain, so they should be free of side effects.
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * Row variables must be named consecutively: for M0=3 and BASENAME=c they are c0, c1 and c2.
 * The z offsets follow the same scheme: for M0=3 and Z=zin they are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The basename of the per-row offsets in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y))                                                                                                            \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                                           \
    }                                                                                                                                                     \
    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else                                                                                                                                                  \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                               \
    }

/** Store a block that can only be partial in x but not y.
 *
 * Always stores M0 rows; each row stores PARTIAL_STORE_N0 elements when PARTIAL_COND_X is true,
 * and N0 elements otherwise.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * Row variables must be named consecutively: for M0=3 and BASENAME=c they are c0, c1 and c2.
 * The z offsets follow the same scheme: for M0=3 and Z=zin they are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The basename of the per-row offsets in z-axis direction
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
    if(PARTIAL_COND_X)                                                                                            \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
    }                                                                                                             \
    else                                                                                                          \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
    }

/** Store a block that can only be partial in y but not x.
 *
 * Always stores full rows of N0 elements (STORE_N0 == N0); stores PARTIAL_STORE_M0 rows when
 * PARTIAL_COND_Y is true, and M0 rows otherwise.
 *
 * @note in case @p N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * Row variables must be named consecutively: for M0=3 and BASENAME=c they are c0, c1 and c2.
 * The z offsets follow the same scheme: for M0=3 and Z=zin they are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The basename of the per-row offsets in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 */
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
    if(PARTIAL_COND_Y)                                                                                            \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
    }                                                                                                             \
    else                                                                                                          \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
    }
/** @} */ // end of group STORE_BLOCK_PARTIAL
517
#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

/** Boundary-aware GEMM block store
 * @name STORE_BLOCK_BOUNDARY_AWARE
 * This macro assumes the following schemes to achieve boundary-awareness:
 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 *  - Non-overlapping (normal) load from rhs tensor. This implies rhs can have paddings.
 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
 * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
 *
 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
 * blocks **at the end**.
 * Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size. This is how we define "partial blocks" /
 * "boundary blocks" (the two terms are used interchangeably) and their various parameters:
 *
 *  *--x-->                         x == 0                        x == 1
 *  |                  |<------------------------------N-------------------------->|
 *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
 *  |     -------------#############################################################
 *  *     |          | |...............................|...........................|
 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
 *        |          | |...............................|...........................|
 *        M          --#############################################################
 *        |          | |                               |...........................|
 * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
 *        |          | |                               |...........................|
 *        |------------#############################################################
 *
 * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The store method is chosen at preprocessing time from the values of PARTIAL_STORE_M0 and PARTIAL_STORE_N0
 * (so both must be preprocessor-evaluable constants). Boundary-detection logic is therefore only emitted for the
 * dimension(s) that can actually contain partial blocks:
 *  - PARTIAL_STORE_M0 == 0, PARTIAL_STORE_N0 == 0 : STORE_BLOCK (no boundary checks)
 *  - PARTIAL_STORE_M0 == 0, PARTIAL_STORE_N0 >  0 : STORE_BLOCK_PARTIAL_IN_X
 *  - PARTIAL_STORE_M0 >  0, PARTIAL_STORE_N0 == 0 : STORE_BLOCK_PARTIAL_IN_Y
 *  - any other combination                        : STORE_BLOCK_PARTIAL_IN_X_AND_Y
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 * @{
 */
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case 1: no partial blocks in either x or y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)

#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case 3: partial blocks in x only
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)

#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case 2: partial blocks in y only
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)

#else // for supported (non-negative) values: PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 > 0
// Case 4: partial blocks in both x and y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)

#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0

#endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
595
/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
 * @name COMPUTE_M0_START_ROW
 * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
 * blocks in the y dimension to avoid any padding.
 * EG: M0=4, PARTIAL_STORE_M0=1:
 *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
 * block 0 (partial)| start row = 0   | start row = 0
 * block 1 (full)   | start row = 4   | start row = 1
 * block 2 (full)   | start row = 8   | start row = 5
 *
 * When PARTIAL_STORE_M0 is not defined at preprocessing time, the start row is simply y * M0.
 *
 * @note Every argument is parenthesized in the expansion, so expressions (e.g. `y + 1`) can be passed safely.
 *
 * @param[in] y                Global id of current block in y.
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 *
 * @return The start row as uint
 * @{
 */
#if defined(PARTIAL_STORE_M0)
// (M0 - PARTIAL_STORE_M0) % M0 is the overlap shift: 0 when there is no partial block (PARTIAL_STORE_M0 == 0).
// max(0, ...) keeps the first (partial) block anchored at row 0.
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(max(0, (int)((y) * (M0)) - (int)(((M0) - (PARTIAL_STORE_M0)) % (M0)))))
#else // defined(PARTIAL_STORE_M0)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)((y) * (M0)))
#endif    // defined(PARTIAL_STORE_M0)
/** @} */ // end of group COMPUTE_M0_START_ROW
620
/** Store a vector that can only be partial in x.
 *
 * Implemented as a single-row (M0 = 1) STORE_BLOCK_PARTIAL_IN_X with zero y-stride and zero z offset.
 *
 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to end in a 0.
 * E.g., for basename=c, the expected name is c0.
 *
 * @param[in] basename  The name of the variable without trailing 0
 * @param[in] data_type The data type of the vector
 * @param[in] ptr       The base pointer
 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
 * @param[in] cond      Condition selecting the store size: @p leftover when true, @p vec_size when false
 * @{
 */
#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
/** @} */ // end of group STORE_VECTOR_SELECT
639
// Optional OpenCL extensions: each one is enabled only when the build-time feature flag is set
// AND the compiler defines the matching extension macro.
#ifdef ARM_COMPUTE_OPENCL_FP16_ENABLED
#ifdef cl_khr_fp16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif // cl_khr_fp16
#endif // ARM_COMPUTE_OPENCL_FP16_ENABLED

#ifdef ARM_COMPUTE_OPENCL_DOT8_ENABLED
#ifdef cl_arm_integer_dot_product_int8
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
#endif // cl_arm_integer_dot_product_int8
#endif // ARM_COMPUTE_OPENCL_DOT8_ENABLED

#ifdef ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED
#ifdef cl_arm_integer_dot_product_accumulate_int8
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
#endif // cl_arm_integer_dot_product_accumulate_int8
#endif // ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED

#ifdef ARM_COMPUTE_DEBUG_ENABLED
#ifdef cl_arm_printf
#pragma OPENCL EXTENSION cl_arm_printf : enable
#endif // cl_arm_printf
#endif // ARM_COMPUTE_DEBUG_ENABLED

// GPU architecture identifiers.
#define GPU_ARCH_MIDGARD 0x100
#define GPU_ARCH_BIFROST 0x200
658
/** Concatenate two tokens.
 *
 * @note The operands are pasted verbatim: macro arguments are not expanded before pasting.
 *
 * @param[in] a The first token
 * @param[in] b The second token
 *
 * @return The single token formed by pasting @p a and @p b
 */
#define CONCAT(a, b) a ## b

/** Identity macro: expands to its (already macro-expanded) argument, unchanged.
 *
 * @param[in] x The tokens to be expanded
 *
 * @return @p x
 */
#define EXPAND(x) x

/** Clamp a value to the range [@p min_val, @p max_val] using the min/max built-ins.
 *
 * @note The lower bound is applied first, so if @p min_val > @p max_val the result is @p max_val.
 *
 * @param[in] x       The value to be clamped
 * @param[in] min_val The lower bound
 * @param[in] max_val The upper bound
 *
 * @return The clamped value
 */
#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
685
/** REVn reverses the given vector whose size is n: element i of the result is element (n - 1 - i) of @p x.
 * @name REVn
 *
 * @param[in] x The vector to be reversed
 *
 * @return The reversed vector
 * @{
 */
#define REV1(x) (x)
#define REV2(x) ((x).s10)
#define REV3(x) ((x).s210)
#define REV4(x) ((x).s3210)
#define REV8(x) ((x).s76543210)
#define REV16(x) ((x).sFEDCBA9876543210)
/** @} */ // end of group REVn

/** Reverse the given vector by dispatching to REVn for n = @p s.
 * @name REVERSE
 *
 * @param[in] x The vector to be reversed
 * @param[in] s The size of the vector (1, 2, 3, 4, 8 or 16); macro-expanded before the dispatch
 *
 * @return The reversed vector
 * @{
 */
#define REVERSE_STR(x, s) REV##s(x)
#define REVERSE(x, s) REVERSE_STR(x, s)
/** @} */ // end of group REVERSE
714
/** Circular-right-shift (rotate-right) the vector of size s by the amount of n.
 * Element i of the result is element ((i - n) mod s) of @p x.
 * @name ROTs_n
 *
 * @param[in] x The vector to be shifted
 *
 * @return The shifted vector
 * @{
 */
// s == 1
#define ROT1_0(x) (x)

// s == 2
#define ROT2_0(x) (x)
#define ROT2_1(x) ((x).s10)

// s == 3
#define ROT3_0(x) (x)
#define ROT3_1(x) ((x).s201)
#define ROT3_2(x) ((x).s120)

// s == 4
#define ROT4_0(x) (x)
#define ROT4_1(x) ((x).s3012)
#define ROT4_2(x) ((x).s2301)
#define ROT4_3(x) ((x).s1230)

// s == 8
#define ROT8_0(x) (x)
#define ROT8_1(x) ((x).s70123456)
#define ROT8_2(x) ((x).s67012345)
#define ROT8_3(x) ((x).s56701234)
#define ROT8_4(x) ((x).s45670123)
#define ROT8_5(x) ((x).s34567012)
#define ROT8_6(x) ((x).s23456701)
#define ROT8_7(x) ((x).s12345670)

// s == 16
#define ROT16_0(x) (x)
#define ROT16_1(x) ((x).sF0123456789ABCDE)
#define ROT16_2(x) ((x).sEF0123456789ABCD)
#define ROT16_3(x) ((x).sDEF0123456789ABC)
#define ROT16_4(x) ((x).sCDEF0123456789AB)
#define ROT16_5(x) ((x).sBCDEF0123456789A)
#define ROT16_6(x) ((x).sABCDEF0123456789)
#define ROT16_7(x) ((x).s9ABCDEF012345678)
#define ROT16_8(x) ((x).s89ABCDEF01234567)
#define ROT16_9(x) ((x).s789ABCDEF0123456)
#define ROT16_10(x) ((x).s6789ABCDEF012345)
#define ROT16_11(x) ((x).s56789ABCDEF01234)
#define ROT16_12(x) ((x).s456789ABCDEF0123)
#define ROT16_13(x) ((x).s3456789ABCDEF012)
#define ROT16_14(x) ((x).s23456789ABCDEF01)
#define ROT16_15(x) ((x).s123456789ABCDEF0)
/** @} */ // end of group ROTs_n

/** Circular-right-shift (rotate-right) the given vector by the given amount, dispatching to ROTs_n.
 * @name ROTATE
 *
 * @param[in] x The vector to be shifted
 * @param[in] s The size of the vector (1, 2, 3, 4, 8 or 16)
 * @param[in] n The amount to be shifted, in [0, @p s)
 *
 * @return The shifted vector
 * @{
 */
#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
/** @} */ // end of group ROTATE
777
/** Create a vector of size n whose element i holds the value i (0, 1, ..., n - 1).
 * @name V_OFFSn
 *
 * @param[in] dt The scalar data type; the result type is dt##n
 *
 * @return The vector filled with offset values
 * @{
 */
#define V_OFFS1(dt) (dt##1)(0)
#define V_OFFS2(dt) (dt##2)(0, 1)
#define V_OFFS3(dt) (dt##3)(0, 1, 2)
#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
/** @} */ // end of group V_OFFSn

/** Create a vector of size @p s filled with per-element offsets, dispatching to V_OFFSn.
 * @name VEC_OFFS
 *
 * @param[in] dt The data type of the output vector
 * @param[in] s  The size of the output vector (1, 2, 3, 4, 8 or 16); macro-expanded before the dispatch
 *
 * @return The vector filled with offset values
 * @{
 */
#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
/** @} */ // end of group VEC_OFFS
806
// VLOAD(size) names the vloadN built-in for the (macro-expanded) size.
#define VLOAD_STR(size) vload##size
#define VLOAD(size) VLOAD_STR(size)

// Pixel count for a vector size: one pixel per 4 elements. Kept as plain literal tokens because
// the result is token-pasted later (e.g. into read_image2d_<type>x<n0>).
#define PIXEL_UNIT4 1
#define PIXEL_UNIT8 2
#define PIXEL_UNIT16 4

/** Utility macro to convert a vector size to a pixel unit.
 *
 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
 *
 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
 *
 * @return The pixel unit (number of pixels, i.e. vec_size / 4)
 * @{
 */
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
826
// read_image2d_<type>x<n>: read n consecutive pixels along x starting at (x_coord, y_coord) and pack them
// into a single 4*n-wide vector. The coordinates are parenthesized so that compound expressions
// (e.g. `c ? a : b`) are offset correctly by the "+ 1/2/3" terms.
// NOTE(review): every expansion deliberately ends with ';'. Callers (e.g. chained row loaders) may rely
// on it, so it is kept; removing it would require auditing all call sites.
#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)((x_coord), (y_coord))));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)((x_coord), (y_coord))), read_imagef(img, (int2)((x_coord) + 1, (y_coord))));
#define read_image2d_floatx4(img, x_coord, y_coord)                                                                                                                                          \
    (float16)(read_imagef(img, (int2)((x_coord), (y_coord))), read_imagef(img, (int2)((x_coord) + 1, (y_coord))), read_imagef(img, (int2)((x_coord) + 2, (y_coord))), \
              read_imagef(img, (int2)((x_coord) + 3, (y_coord))));

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)((x_coord), (y_coord))));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)((x_coord), (y_coord))), read_imageh(img, (int2)((x_coord) + 1, (y_coord))));
#define read_image2d_halfx4(img, x_coord, y_coord)                                                                                                                                       \
    (half16)(read_imageh(img, (int2)((x_coord), (y_coord))), read_imageh(img, (int2)((x_coord) + 1, (y_coord))), read_imageh(img, (int2)((x_coord) + 2, (y_coord))), \
             read_imageh(img, (int2)((x_coord) + 3, (y_coord))));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

/** Utility macro to read a 2D OpenCL image object.
 *
 * @note Coordinates are not normalized
 * @note The expansion ends with ';' (see read_image2d_* above)
 *
 * @param[in] data_type Data type (float, or half when FP16 is enabled)
 * @param[in] n0        Number of pixels to read. Only 1, 2 and 4 are supported
 * @param[in] img       OpenCL image object
 * @param[in] x_coord   The x coordinate for the top-left pixel
 * @param[in] y_coord   The y coordinate for the top-left pixel
 *
 * @return Pixels from the 2D OpenCL image object
 * @{
 */
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
/** @} */ // end of group READ_IMAGE2D
852
// VSTORE(size) names the vstoreN built-in for the (macro-expanded) size.
#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)

// "<type>1" aliases so that token-pasted size-1 vector types (e.g. VEC_DATA_TYPE(float, 1)) name the scalar type.
#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double

// Scalar counterparts of vloadN/vstoreN: access the element at PTR + OFFSET.
// Arguments and the whole expansion are parenthesized so compound expressions are safe to pass.
#define vload1(OFFSET, PTR) (*((PTR) + (OFFSET)))
#define vstore1(DATA, OFFSET, PTR) (*((PTR) + (OFFSET)) = (DATA))
870
/** Extended partial vstore that correctly handles scalar values as well.
 * Stores the **lower** @p store_size elements (indices 0 to store_size - 1) of the given data while minimising the
 * number of vstore operations.
 * @name VSTORE_PARTIAL
 *
 * The macro only builds the name vstore_partial_<size>_<store_size>; the actual store is selected by that
 * dispatch table.
 *
 * @note With this macro, the passed data can be both a vector and a scalar
 * @note @p store_size needs to be <= @p size
 * eg 1: Valid
 * VSTORE_PARTIAL(16, 15) ...;
 * eg 2: Invalid
 * VSTORE_PARTIAL(4, 7) ...;
 *
 * @param[in] size       The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
 * @{
 */
#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)

// Placeholder store with the vstore signature that writes nothing (expands to an empty block).
#define NO_STORE(data, offs, ptr) {}
892
// Dispatch table used by VSTORE_PARTIAL: vstore_partial_<size>_<store_size>.
//
// Part 1 - supported combinations (1 <= store_size <= size).
// A scalar (size == 1) is written with vstore1 directly; vectors use vstore_partial_<store_size>.
#define vstore_partial_1_1 vstore1

#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2

#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3

#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4

#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8

#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16

// Part 2 - store_size == 0 or store_size > size: these map to NO_STORE, i.e. nothing is written.
// NOTE(review): an out-of-range store_size is silently ignored rather than rejected at compile time.
#define vstore_partial_1_0 NO_STORE
#define vstore_partial_1_2 NO_STORE
#define vstore_partial_1_3 NO_STORE
#define vstore_partial_1_4 NO_STORE
#define vstore_partial_1_5 NO_STORE
#define vstore_partial_1_6 NO_STORE
#define vstore_partial_1_7 NO_STORE
#define vstore_partial_1_8 NO_STORE
#define vstore_partial_1_9 NO_STORE
#define vstore_partial_1_10 NO_STORE
#define vstore_partial_1_11 NO_STORE
#define vstore_partial_1_12 NO_STORE
#define vstore_partial_1_13 NO_STORE
#define vstore_partial_1_14 NO_STORE
#define vstore_partial_1_15 NO_STORE
#define vstore_partial_1_16 NO_STORE

#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE

#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE

#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE

#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE

#define vstore_partial_16_0 NO_STORE
1001
/** Partial vstore. Store the **lower** n elements (indices 0 to n-1) of the given vector while minimising the number of vstore ops.
 * @name vstore_partial_n
 *
 * Each macro is a fixed sequence of native vstores (at most 8 + 4 + 1/2/3 elements), written lowest elements first.
 * Every sub-store receives @p OFFSET unchanged, while @p PTR is advanced by the number of elements already written.
 * NOTE(review): with OpenCL's vstoreN addressing (PTR + OFFSET * N) the pieces are only contiguous when
 * @p OFFSET == 0 — confirm callers always pass 0.
 *
 * @note @p DATA needs to be a vector, not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * eg 1: Valid
 * vstore_partial_15(var:float16, 0, 0xabcd);
 * eg 2: Invalid
 * vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
 * @note Each expansion already ends with ';' and multi-store variants expand to several statements.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset passed to every underlying vstore
 * @param[in] PTR    The base pointer
 * @{
 */
// Sizes with a native vstore: a single store.
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

// 5..7: a 4-wide store followed by the 1..3 element tail.
#define vstore_partial_5(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);       \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);       \
    vstore2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);       \
    vstore3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

// 9..12: an 8-wide store followed by the 1..4 element tail.
#define vstore_partial_9(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);   \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);    \
    vstore2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);    \
    vstore3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);    \
    vstore4(DATA.s89ab, OFFSET, PTR + 8);

// 13..15: 8-wide + 4-wide stores followed by the 1..3 element tail.
#define vstore_partial_13(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);    \
    vstore4(DATA.s89ab, OFFSET, PTR + 8);    \
    vstore1(DATA.sc, OFFSET, PTR + 12);

#define vstore_partial_14(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);    \
    vstore4(DATA.s89ab, OFFSET, PTR + 8);    \
    vstore2(DATA.scd, OFFSET, PTR + 12);

#define vstore_partial_15(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);    \
    vstore4(DATA.s89ab, OFFSET, PTR + 8);    \
    vstore3(DATA.scde, OFFSET, PTR + 12);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
1078
// The _sat modifier is not allowed for conversions to floating-point types, so every floating-point
// *_sat name is mapped to the plain conversion of the SAME destination type (float -> float,
// half -> half, double -> double).
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_half
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16
#define convert_double_sat convert_double
#define convert_double2_sat convert_double2
#define convert_double3_sat convert_double3
#define convert_double4_sat convert_double4
#define convert_double8_sat convert_double8
#define convert_double16_sat convert_double16

// "<type>1" aliases so that token-pasted size-1 conversions resolve to the scalar built-ins.
#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

// Saturating size-1 aliases. Integer types keep _sat; double drops it (see above).
#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double
1117
// VEC_DATA_TYPE(type, size) -> type##size (e.g. float4); size 1 relies on the "<type>1" aliases.
#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

// CONVERT(x, type) -> convert_<type>(x)
#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

// CONVERT_SAT(x, type) -> convert_<type>_sat(x)
#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

// CONVERT_SAT_ROUND(x, type, round) -> convert_<type>_sat_<round>(x)
#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)

// Integer type of the same element width, as required for the mask argument of select().
#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size
#define select_vec_dt_double(size) long##size

#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)

// Horizontal sum of the elements of a vector of the given size. Each expansion is fully
// parenthesized so that e.g. SUM_REDUCE(v, 2) * k multiplies the whole sum.
#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) (((x).s0) + ((x).s1))
#define sum_reduce_3(x) (sum_reduce_2((x).s01) + ((x).s2))
#define sum_reduce_4(x) (sum_reduce_2((x).s01) + sum_reduce_2((x).s23))
#define sum_reduce_8(x) (sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567))
#define sum_reduce_16(x) (sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF))

#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

// Horizontal maximum of the elements of a vector of the given size (built-in max()).
#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
1164
// Kernel-argument list for a 1D tensor `name`: buffer pointer, x stride and step, and the byte offset
// of the first element (4 arguments).
#define VECTOR_DECLARATION(name)  \
    __global uchar *name##_ptr,   \
    uint name##_stride_x,         \
    uint name##_step_x,           \
    uint name##_offset_first_element_in_bytes

// Kernel-argument list for a 2D tensor `name`: as above plus y stride and step (6 arguments).
#define IMAGE_DECLARATION(name)   \
    __global uchar *name##_ptr,   \
    uint name##_stride_x,         \
    uint name##_step_x,           \
    uint name##_stride_y,         \
    uint name##_step_y,           \
    uint name##_offset_first_element_in_bytes

// Kernel-argument list for a 3D tensor `name`: as above plus z stride and step (8 arguments).
#define TENSOR3D_DECLARATION(name) \
    __global uchar *name##_ptr,    \
    uint name##_stride_x,          \
    uint name##_step_x,            \
    uint name##_stride_y,          \
    uint name##_step_y,            \
    uint name##_stride_z,          \
    uint name##_step_z,            \
    uint name##_offset_first_element_in_bytes

// Kernel-argument list for a 4D tensor `name`: as above plus w stride and step (10 arguments).
#define TENSOR4D_DECLARATION(name) \
    __global uchar *name##_ptr,    \
    uint name##_stride_x,          \
    uint name##_step_x,            \
    uint name##_stride_y,          \
    uint name##_step_y,            \
    uint name##_stride_z,          \
    uint name##_step_z,            \
    uint name##_stride_w,          \
    uint name##_step_w,            \
    uint name##_offset_first_element_in_bytes
1200
/* Build Vector / Image views from kernel arguments declared with
 * VECTOR_DECLARATION / IMAGE_DECLARATION. The plain variants forward the
 * per-work-item steps; the _NO_STEP variants pass 0 for every step. */

/** Vector view for this work-item (X step forwarded). */
#define CONVERT_TO_VECTOR_STRUCT(name)                                        \
    update_vector_workitem_ptr(name##_ptr,                                    \
                               name##_offset_first_element_in_bytes,          \
                               name##_stride_x, name##_step_x)

/** Vector view with the X step forced to 0. */
#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name)                                \
    update_vector_workitem_ptr(name##_ptr,                                    \
                               name##_offset_first_element_in_bytes,          \
                               name##_stride_x, 0)

/** Image view for this work-item (X and Y steps forwarded). */
#define CONVERT_TO_IMAGE_STRUCT(name)                                         \
    update_image_workitem_ptr(name##_ptr,                                     \
                              name##_offset_first_element_in_bytes,           \
                              name##_stride_x, name##_step_x,                 \
                              name##_stride_y, name##_step_y)

/** Image view with the X and Y steps forced to 0. */
#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name)                                 \
    update_image_workitem_ptr(name##_ptr,                                     \
                              name##_offset_first_element_in_bytes,           \
                              name##_stride_x, 0,                             \
                              name##_stride_y, 0)
1212
/** Image view of a TENSOR3D_DECLARATION argument set, advanced to this
 *  work-item's data by the X, Y and Z steps.
 *  (Defined exactly once; a verbatim duplicate definition was removed.) */
#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)

/** Like CONVERT_TENSOR3D_TO_IMAGE_STRUCT, but the X and Y steps are forced to 0.
 *  The Z step (name##_step_z) is still forwarded, so the pointer is still
 *  advanced along Z by the work-item's third global id. */
#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
1221
/** Tensor3D view for this work-item (X, Y and Z steps forwarded). */
#define CONVERT_TO_TENSOR3D_STRUCT(name)                                            \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes,  \
                                 name##_stride_x, name##_step_x,                    \
                                 name##_stride_y, name##_step_y,                    \
                                 name##_stride_z, name##_step_z)

/** Tensor3D view with every step forced to 0. */
#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name)                                    \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes,  \
                                 name##_stride_x, 0,                                \
                                 name##_stride_y, 0,                                \
                                 name##_stride_z, 0)

/** Tensor4D view for this work-item. mod_size is forwarded to
 *  update_tensor4D_workitem_ptr, which splits global dimension 2 into
 *  a Z index (id % mod_size) and a W index (id / mod_size). */
#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                  \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes,  \
                                 name##_stride_x, name##_step_x,                    \
                                 name##_stride_y, name##_step_y,                    \
                                 name##_stride_z, name##_step_z,                    \
                                 name##_stride_w, name##_step_w,                    \
                                 mod_size)

/** Tensor4D view with every step forced to 0. */
#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size)                          \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes,  \
                                 name##_stride_x, 0,                                \
                                 name##_stride_y, 0,                                \
                                 name##_stride_z, 0,                                \
                                 name##_stride_w, 0,                                \
                                 mod_size)

/** Tensor3D view whose pointer is left at the buffer start: the step values
 *  are forwarded, but tensor3D_ptr_no_update does not use them. */
#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                              \
    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes,        \
                           name##_stride_x, name##_step_x,                          \
                           name##_stride_y, name##_step_y,                          \
                           name##_stride_z, name##_step_z)
1239
1240/** Structure to hold Vector information */
1241typedef struct Vector
1242{
1243    __global uchar *ptr;                           /**< Pointer to the starting postion of the buffer */
1244    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
1245    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
1246} Vector;
1247
1248/** Structure to hold Image information */
1249typedef struct Image
1250{
1251    __global uchar *ptr;                           /**< Pointer to the starting postion of the buffer */
1252    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
1253    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
1254    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
1255} Image;
1256
1257/** Structure to hold 3D tensor information */
1258typedef struct Tensor3D
1259{
1260    __global uchar *ptr;                           /**< Pointer to the starting postion of the buffer */
1261    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
1262    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
1263    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
1264    int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
1265} Tensor3D;
1266
1267/** Structure to hold 4D tensor information */
1268typedef struct Tensor4D
1269{
1270    __global uchar *ptr;                           /**< Pointer to the starting postion of the buffer */
1271    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
1272    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
1273    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
1274    int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
1275    int             stride_w;                      /**< Stride of the image in W dimension (in bytes) */
1276} Tensor4D;
1277
1278/** Wrap vector information into an Vector structure, and make the pointer point at this workitem's data.
1279 *
1280 * @param[in] ptr                           Pointer to the starting postion of the buffer
1281 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
1282 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
1283 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
1284 *
1285 * @return An image object
1286 */
1287inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
1288{
1289    Vector vector =
1290    {
1291        .ptr                           = ptr,
1292        .offset_first_element_in_bytes = offset_first_element_in_bytes,
1293        .stride_x                      = stride_x,
1294    };
1295    vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
1296    return vector;
1297}
1298
1299/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
1300 *
1301 * @param[in] ptr                           Pointer to the starting postion of the buffer
1302 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
1303 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
1304 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
1305 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
1306 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
1307 *
1308 * @return An image object
1309 */
1310inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
1311{
1312    Image img =
1313    {
1314        .ptr                           = ptr,
1315        .offset_first_element_in_bytes = offset_first_element_in_bytes,
1316        .stride_x                      = stride_x,
1317        .stride_y                      = stride_y
1318    };
1319    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
1320    return img;
1321}
1322
1323/** Wrap 3D tensor information into an image structure, and make the pointer point at this workitem's data.
1324 *
1325 * @param[in] ptr                           Pointer to the starting postion of the buffer
1326 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
1327 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
1328 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
1329 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
1330 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
1331 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
1332 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
1333 *
1334 * @return A 3D tensor object
1335 */
1336inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
1337{
1338    Image img =
1339    {
1340        .ptr                           = ptr,
1341        .offset_first_element_in_bytes = offset_first_element_in_bytes,
1342        .stride_x                      = stride_x,
1343        .stride_y                      = stride_y
1344    };
1345    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
1346    return img;
1347}
1348
1349/** Wrap 3D tensor information into an tensor structure, and make the pointer point at this workitem's data.
1350 *
1351 * @param[in] ptr                           Pointer to the starting postion of the buffer
1352 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
1353 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
1354 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
1355 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
1356 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
1357 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
1358 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
1359 *
1360 * @return A 3D tensor object
1361 */
1362inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
1363{
1364    Tensor3D tensor =
1365    {
1366        .ptr                           = ptr,
1367        .offset_first_element_in_bytes = offset_first_element_in_bytes,
1368        .stride_x                      = stride_x,
1369        .stride_y                      = stride_y,
1370        .stride_z                      = stride_z
1371    };
1372    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
1373    return tensor;
1374}
1375
1376/** Wrap 3D tensor information into an tensor structure.
1377 *
1378 * @param[in] ptr                           Pointer to the starting postion of the buffer
1379 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
1380 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
1381 * @param[in] step_x                        stride_x * number of elements along X processed per workitem(in bytes)
1382 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
1383 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem(in bytes)
1384 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
1385 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem(in bytes)
1386 *
1387 * @return A 3D tensor object
1388 */
1389inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
1390{
1391    Tensor3D tensor =
1392    {
1393        .ptr                           = ptr,
1394        .offset_first_element_in_bytes = offset_first_element_in_bytes,
1395        .stride_x                      = stride_x,
1396        .stride_y                      = stride_y,
1397        .stride_z                      = stride_z
1398    };
1399    return tensor;
1400}
1401
1402inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
1403                                             uint step_w,
1404                                             uint mod_size)
1405{
1406    Tensor4D tensor =
1407    {
1408        .ptr                           = ptr,
1409        .offset_first_element_in_bytes = offset_first_element_in_bytes,
1410        .stride_x                      = stride_x,
1411        .stride_y                      = stride_y,
1412        .stride_z                      = stride_z,
1413        .stride_w                      = stride_w
1414    };
1415
1416    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
1417    return tensor;
1418}
1419
1420/** Get the pointer position of a Vector
1421 *
1422 * @param[in] vec Pointer to the starting position of the buffer
1423 * @param[in] x   Relative X position
1424 */
1425inline __global const uchar *vector_offset(const Vector *vec, int x)
1426{
1427    return vec->ptr + x * vec->stride_x;
1428}
1429
1430/** Get the pointer position of a Image
1431 *
1432 * @param[in] img Pointer to the starting position of the buffer
1433 * @param[in] x   Relative X position
1434 * @param[in] y   Relative Y position
1435 */
1436inline __global uchar *offset(const Image *img, int x, int y)
1437{
1438    return img->ptr + x * img->stride_x + y * img->stride_y;
1439}
1440
1441/** Get the pointer position of a Tensor3D
1442 *
1443 * @param[in] tensor Pointer to the starting position of the buffer
1444 * @param[in] x      Relative X position
1445 * @param[in] y      Relative Y position
1446 * @param[in] z      Relative Z position
1447 */
1448inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
1449{
1450    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
1451}
1452
1453/** Get the pointer position of a Tensor4D
1454 *
1455 * @param[in] tensor Pointer to the starting position of the buffer
1456 * @param[in] x      Relative X position
1457 * @param[in] y      Relative Y position
1458 * @param[in] z      Relative Z position
1459 * @param[in] w      Relative W position
1460 */
1461inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
1462{
1463    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
1464}
1465
1466/** Get the offset for a given linear index of a Tensor3D
1467 *
1468 * @param[in] tensor Pointer to the starting position of the buffer
1469 * @param[in] width  Width of the input tensor
1470 * @param[in] height Height of the input tensor
1471 * @param[in] depth  Depth of the input tensor
1472 * @param[in] index  Linear index
1473 */
1474inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
1475{
1476    uint num_elements = width * height;
1477
1478    const uint z = index / num_elements;
1479
1480    index %= num_elements;
1481
1482    const uint y = index / width;
1483
1484    index %= width;
1485
1486    const uint x = index;
1487
1488    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
1489}
1490
1491#endif // _HELPER_H
1492
1493/** Performs max unpooling function with pool size equal to 2.
1494 *
1495 * @note Datatype must be passed using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types are F16/F32
1496 * @note The width of the output tensor must be passed using -DWIDTH_DST e.g. -DWIDTH_DST=24
1497 * @note The height of the output tensor must be passed using -DHEIGHT_DST e.g. -DHEIGHT_DST=54
1498 * @note The depth of the output tensor must be passed using -DDEPTH_DST e.g. -DDEPTH_DST=32
1499 *
1500 * @param[in]  input_ptr                             Pointer to the source tensor. Supported data types: F16/F32
1501 * @param[in]  input_stride_x                        Stride of the source tensor in X dimension (in bytes)
1502 * @param[in]  input_step_x                          input_stride_x * number of elements along X processed per workitem(in bytes)
1503 * @param[in]  input_stride_y                        Stride of the source tensor in Y dimension (in bytes)
1504 * @param[in]  input_step_y                          input_stride_y * number of elements along Y processed per workitem(in bytes)
1505 * @param[in]  input_stride_z                        Stride of the source tensor in Z dimension (in bytes)
1506 * @param[in]  input_step_z                          input_stride_z * number of elements along Z processed per workitem(in bytes)
1507 * @param[in]  input_offset_first_element_in_bytes   The offset of the first element in the source tensor
1508 * @param[out] output_ptr                            Pointer to the output tensor. Supported data types: same as @p input_ptr
1509 * @param[in]  output_stride_x                       Stride of the output tensor in X dimension (in bytes)
1510 * @param[in]  output_step_x                         output_stride_x * number of elements along X processed per workitem(in bytes)
1511 * @param[in]  output_stride_y                       Stride of the output tensor in Y dimension (in bytes)
1512 * @param[in]  output_step_y                         output_stride_y * number of elements along Y processed per workitem(in bytes)
1513 * @param[in]  output_stride_z                       Stride of the output tensor in Z dimension (in bytes)
1514 * @param[in]  output_step_z                         output_stride_z * number of elements along Z processed per workitem(in bytes)
1515 * @param[in]  output_offset_first_element_in_bytes  The offset of the first element in the output tensor
1516 * @param[in]  indices_ptr                           Pointer to the indices tensor. Supported data types: U32
1517 * @param[in]  indices_stride_x                      Stride of the indices tensor in X dimension (in bytes)
1518 * @param[in]  indices_step_x                        indices_stride_x * number of elements along X processed per workitem(in bytes)
1519 * @param[in]  indices_stride_y                      Stride of the indices tensor in Y dimension (in bytes)
1520 * @param[in]  indices_step_y                        indices_stride_y * number of elements along Y processed per workitem(in bytes)
1521 * @param[in]  indices_stride_z                      Stride of the indices tensor in Z dimension (in bytes)
1522 * @param[in]  indices_step_z                        indices_stride_z * number of elements along Z processed per workitem(in bytes)
1523 * @param[in]  indices_offset_first_element_in_bytes The offset of the first element in the indices tensor
1524 */
1525__kernel void max_unpooling_layer_2(
1526    TENSOR3D_DECLARATION(input),
1527    TENSOR3D_DECLARATION(output),
1528    TENSOR3D_DECLARATION(indices))
1529{
1530    Tensor3D input   = CONVERT_TO_TENSOR3D_STRUCT(input);
1531    Tensor3D output  = CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(output);
1532    Tensor3D indices = CONVERT_TO_TENSOR3D_STRUCT(indices);
1533
1534    unsigned int index = *((__global unsigned int *)indices.ptr);
1535    DATA_TYPE value    = *((__global DATA_TYPE *)input.ptr);
1536
1537    *((__global DATA_TYPE *)tensor3D_index2ptr(&output, WIDTH_DST, HEIGHT_DST, DEPTH_DST, index)) = value;
1538}
1539
1540)"