1R"(
2
3/*
4 * Copyright (c) 2018-2020 Arm Limited.
5 *
6 * SPDX-License-Identifier: MIT
7 *
8 * Permission is hereby granted, free of charge, to any person obtaining a copy
9 * of this software and associated documentation files (the "Software"), to
10 * deal in the Software without restriction, including without limitation the
11 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
12 * sell copies of the Software, and to permit persons to whom the Software is
13 * furnished to do so, subject to the following conditions:
14 *
15 * The above copyright notice and this permission notice shall be included in all
16 * copies or substantial portions of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 * SOFTWARE.
25 */
26#if defined(DATA_TYPE) && defined(ACTIVATION_TYPE) && defined(NUM_CLASSES) && defined(VEC_SIZE)
27
28/*
29 * Copyright (c) 2019-2020 Arm Limited.
30 *
31 * SPDX-License-Identifier: MIT
32 *
33 * Permission is hereby granted, free of charge, to any person obtaining a copy
34 * of this software and associated documentation files (the "Software"), to
35 * deal in the Software without restriction, including without limitation the
36 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
37 * sell copies of the Software, and to permit persons to whom the Software is
38 * furnished to do so, subject to the following conditions:
39 *
40 * The above copyright notice and this permission notice shall be included in all
41 * copies or substantial portions of the Software.
42 *
43 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
44 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
45 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
46 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
47 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
48 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
49 * SOFTWARE.
50 */
51
52/*
53 * Copyright (c) 2016-2020 Arm Limited.
54 *
55 * SPDX-License-Identifier: MIT
56 *
57 * Permission is hereby granted, free of charge, to any person obtaining a copy
58 * of this software and associated documentation files (the "Software"), to
59 * deal in the Software without restriction, including without limitation the
60 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
61 * sell copies of the Software, and to permit persons to whom the Software is
62 * furnished to do so, subject to the following conditions:
63 *
64 * The above copyright notice and this permission notice shall be included in all
65 * copies or substantial portions of the Software.
66 *
67 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
68 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
69 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
70 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
71 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
72 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
73 * SOFTWARE.
74 */
75#ifndef ARM_COMPUTE_HELPER_H
76#define ARM_COMPUTE_HELPER_H
77
78/*
79 * Copyright (c) 2020 Arm Limited.
80 *
81 * SPDX-License-Identifier: MIT
82 *
83 * Permission is hereby granted, free of charge, to any person obtaining a copy
84 * of this software and associated documentation files (the "Software"), to
85 * deal in the Software without restriction, including without limitation the
86 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
87 * sell copies of the Software, and to permit persons to whom the Software is
88 * furnished to do so, subject to the following conditions:
89 *
90 * The above copyright notice and this permission notice shall be included in all
91 * copies or substantial portions of the Software.
92 *
93 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
94 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
95 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
96 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
97 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
98 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
99 * SOFTWARE.
100 */
101
102/** Store the 0th to (n-1)th rows of the given variables
103 * @name STORE_ROW_n
104 *
105 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
106 * @param[in] DATA_TYPE The data type of the vectors
107 * @param[in] BASENAME  The basename of the variables
108 * @param[in] PTR       The base pointer
109 * @param[in] STRIDE_Y  The stride value in y-axis direction
110 * @param[in] Z         The offset in z-axis direction
111 * @{
112 */
113#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
114    VSTORE(N0)                                                 \
115    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
116
117#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
118    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
119    VSTORE(N0)                                                 \
120    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
121
122#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
123    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
124    VSTORE(N0)                                                 \
125    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
126
127#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
128    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
129    VSTORE(N0)                                                 \
130    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
131
132#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
133    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
134    VSTORE(N0)                                                 \
135    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
136
137#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
138    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
139    VSTORE(N0)                                                 \
140    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
141
142#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
143    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
144    VSTORE(N0)                                                 \
145    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
146
147#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
148    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
149    VSTORE(N0)                                                 \
150    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
151
152#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
153    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
154    VSTORE(N0)                                                 \
155    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
156
157#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
158    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
159    VSTORE(N0)                                                  \
160    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
161
162#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
163    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
164    VSTORE(N0)                                                  \
165    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
166
167#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
168    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
169    VSTORE(N0)                                                  \
170    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
171
172#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
173    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
174    VSTORE(N0)                                                  \
175    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
176
177#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
178    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
179    VSTORE(N0)                                                  \
180    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
181
182#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
183    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
184    VSTORE(N0)                                                  \
185    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
186
187#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
188    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
189    VSTORE(N0)                                                  \
190    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
191/** @} */ // end of group STORE_ROW_n
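// Illustrative expansion (added for clarity, not part of the original header): with N0=4,
// DATA_TYPE=float, BASENAME=c and Z=zin, and assuming the caller has declared float4 variables
// c0, c1 and offsets zin0, zin1 (hypothetical names), STORE_ROW_2 expands roughly to:
//
//   STORE_ROW_2(4, float, c, dst_ptr, dst_stride_y, zin)
//   // ==>
//   vstore4(c0, 0, (__global float *)(dst_ptr + 0 * dst_stride_y + zin0));
//   vstore4(c1, 0, (__global float *)(dst_ptr + 1 * dst_stride_y + zin1));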
192
193/** Convert and store the 0th to (n-1)th rows of the given variables
194 * @name CONVERT_STORE_ROW_n
195 *
196 * @param[in] N0        The size of the vectors
197 * @param[in] DATA_TYPE The data type of the vectors
198 * @param[in] BASENAME  The basename of the variables
199 * @param[in] PTR       The base pointer
200 * @param[in] STRIDE_Y  The stride value in y-axis direction
201 * @param[in] Z         The offset in z-axis direction
202 * @{
203 */
204#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
205    VSTORE(N0)                                                         \
206    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
207
208#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
209    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
210    VSTORE(N0)                                                         \
211    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
212
213#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
214    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
215    VSTORE(N0)                                                         \
216    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
217
218#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
219    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
220    VSTORE(N0)                                                         \
221    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
222
223#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
224    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
225    VSTORE(N0)                                                         \
226    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
227
228#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
229    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
230    VSTORE(N0)                                                         \
231    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
232
233#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
234    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
235    VSTORE(N0)                                                         \
236    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
237
238#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
239    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
240    VSTORE(N0)                                                         \
241    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
242
243#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
244    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
245    VSTORE(N0)                                                         \
246    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
247
248#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
249    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
250    VSTORE(N0)                                                          \
251    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
252
253#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
254    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
255    VSTORE(N0)                                                          \
256    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
257
258#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
259    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
260    VSTORE(N0)                                                          \
261    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
262
263#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
264    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
265    VSTORE(N0)                                                          \
266    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
267
268#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
269    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
270    VSTORE(N0)                                                          \
271    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
272
273#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
274    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
275    VSTORE(N0)                                                          \
276    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
277
278#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
279    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
280    VSTORE(N0)                                                          \
281    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
282
283/** @} */ // end of group CONVERT_STORE_ROW_n
284
285/** Store a block of the given size M0xN0
286 * @name STORE_BLOCK
287 *
288 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
289 * The data to store is expected to have consecutive names for each row.
290 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
291 * The Z offset is expected to have consecutive names.
292 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
293 *
294 * @param[in] M0        The number of rows to store
295 * @param[in] N0        The size of each vector
296 * @param[in] DATA_TYPE The data type of the vectors
297 * @param[in] BASENAME  The basename of the variables
298 * @param[in] PTR       The base pointer
299 * @param[in] STRIDE_Y  The stride value in y-axis direction
300 * @param[in] Z         The offset in z-axis direction
301 * @{
302 */
303#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
304#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
305/** @} */ // end of group STORE_BLOCK
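// Illustrative usage sketch (added for clarity; variable names are hypothetical): storing a 3x8
// block of halfs named c0..c2 with per-row z offsets zout0..zout2, as described above:
//
//   STORE_BLOCK(3, 8, half, c, dst_addr, dst_stride_y, zout);
//   // resolves to STORE_ROW_3(8, half, c, dst_addr, dst_stride_y, zout), i.e. three vstore8 calls.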
306
307/** Convert and store a block of the given size M0xN0
308 * @name CONVERT_STORE_BLOCK
309 *
310 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
311 * The data to store is expected to have consecutive names for each row.
312 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
313 * The Z offset is expected to have consecutive names.
314 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
315 *
316 * @param[in] M0        The number of rows to store
317 * @param[in] N0        The size of each vector
318 * @param[in] DATA_TYPE The data type of the vectors
319 * @param[in] BASENAME  The basename of the variables
320 * @param[in] PTR       The base pointer
321 * @param[in] STRIDE_Y  The stride value in y-axis direction
322 * @param[in] Z         The offset in z-axis direction
323 * @{
324 */
325#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
326#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
327/** @} */ // end of group CONVERT_STORE_BLOCK
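// Illustrative sketch (added for clarity): CONVERT_STORE_BLOCK is typically used to saturate and
// narrow wider accumulators before storing, e.g. int accumulators into a uchar destination.
// CONVERT_SAT and VEC_DATA_TYPE are assumed to be provided by the other helper headers.
//
//   CONVERT_STORE_BLOCK(1, 4, uchar, c, dst_addr, dst_stride_y, zout);
//   // ==> roughly vstore4(convert_uchar4_sat(c0), 0, (__global uchar *)(dst_addr + 0 * dst_stride_y + zout0));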
328
329/** Partially store the 0 to (n-1)th rows of the given variables
330 * @name STORE_ROW_PARTIAL_n
331 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
332 *
333 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
334 *
335 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
336 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
337 * @param[in] DATA_TYPE The data type of the vectors
338 * @param[in] BASENAME  The basename of the variables
339 * @param[in] PTR       The base pointer
340 * @param[in] STRIDE_Y  The stride value in y-axis direction
341 * @param[in] Z         The offset in z-axis direction
342 * @{
343 */
344#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
345    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
346    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));
347
348#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
349    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
350    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
351    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));
352
353#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
354    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
355    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
356    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));
357
358#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
359    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
360    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
361    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));
362
363#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
364    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
365    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
366    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));
367
368#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
369    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
370    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
371    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));
372
373#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
374    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
375    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
376    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));
377
378#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
379    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
380    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
381    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));
382
383#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
384    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
385    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
386    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));
387
388#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
389    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
390    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
391    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));
392
393#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
394    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
395    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
396    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));
397
398#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
399    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
400    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
401    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));
402
403#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
404    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
405    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
406    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));
407
408#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
409    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
410    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
411    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));
412
413#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
414    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
415    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
416    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));
417
418#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
419    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
420    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
421    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
422/** @} */ // end of group STORE_ROW_PARTIAL_n
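// Illustrative expansion (added for clarity): with N0=8 and STORE_N0=5, each row stores only its
// lower 5 elements via VSTORE_PARTIAL (defined further down in this file), e.g. for a single row:
//
//   STORE_ROW_PARTIAL_1(8, 5, float, c, dst_ptr, dst_stride_y, zin)
//   // ==> vstore_partial_8_5(c0, 0, ...): one vstore4 of the first four elements of c0
//   //     followed by one vstore1 of the fifth.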
423
424/** Partially store a block of the given size STORE_M0xSTORE_N0
425 * @name STORE_BLOCK_PARTIAL
426 *
427 * @note The vector width @p N0 is also required for correct partial storing behaviour.
428 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
429 *
430 * The data to store is expected to have consecutive names for each row.
431 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
432 * The Z offset is expected to have consecutive names.
433 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
434 *
435 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
436 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
437 * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
438 * @param[in] DATA_TYPE The data type of the vectors
439 * @param[in] BASENAME  The basename of the variables
440 * @param[in] PTR       The base pointer
441 * @param[in] STRIDE_Y  The stride value in y-axis direction
442 * @param[in] Z         The offset in z-axis direction
443 * @{
444 */
445#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
446#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
447/** Store a block that can be partial in both x and y dimensions
448 *
449 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
450 *
451 * The data to store is expected to have consecutive names for each row.
452 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
453 * The Z offset is expected to have consecutive names.
454 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
455 *
456 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
457 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
458 * @param[in] DATA_TYPE        The data type of the vectors
459 * @param[in] BASENAME         The basename of the variables
460 * @param[in] PTR              The base pointer
461 * @param[in] STRIDE_Y         The stride value in y-axis direction
462 * @param[in] Z                The offset in z-axis direction
463 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
464 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
465 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
466 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
467 */
468#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
469    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y))                                                                                                            \
470    {                                                                                                                                                     \
471        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                                           \
472    }                                                                                                                                                     \
473    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X))                                                                                                        \
474    {                                                                                                                                                     \
475        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
476    }                                                                                                                                                     \
477    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X))                                                                                                        \
478    {                                                                                                                                                     \
479        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
480    }                                                                                                                                                     \
481    else                                                                                                                                                  \
482    {                                                                                                                                                     \
483        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                               \
484    }
485/** Store a block that can only be partial in x but not y.
486 *
487 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
488 *
489 * The data to store is expected to have consecutive names for each row.
490 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
491 * The Z offset is expected to have consecutive names.
492 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
493 *
494 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
495 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
496 * @param[in] DATA_TYPE        The data type of the vectors
497 * @param[in] BASENAME         The basename of the variables
498 * @param[in] PTR              The base pointer
499 * @param[in] STRIDE_Y         The stride value in y-axis direction
500 * @param[in] Z                The offset in z-axis direction
501 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
502 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
503 */
504#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
505    if(!(PARTIAL_COND_X))                                                                                         \
506    {                                                                                                             \
507        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
508    }                                                                                                             \
509    else                                                                                                          \
510    {                                                                                                             \
511        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
512    }
513/** Store a block that can only be partial in y but not x.
514 *
515 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
516 *
517 * The data to store is expected to have consecutive names for each row.
518 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
519 * The Z offset is expected to have consecutive names.
520 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
521 *
522 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
523 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
524 * @param[in] DATA_TYPE        The data type of the vectors
525 * @param[in] BASENAME         The basename of the variables
526 * @param[in] PTR              The base pointer
527 * @param[in] STRIDE_Y         The stride value in y-axis direction
528 * @param[in] Z                The offset in z-axis direction
529 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
530 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
531 */
532#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
533    if(!(PARTIAL_COND_Y))                                                                                         \
534    {                                                                                                             \
535        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
536    }                                                                                                             \
537    else                                                                                                          \
538    {                                                                                                             \
539        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
540    }
541/** @} */ // end of group STORE_BLOCK_PARTIAL
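// Illustrative usage sketch (added for clarity; the condition shown is a hypothetical example):
// a kernel whose last block in x may run past the destination width N would typically write:
//
//   const bool cond_x = ((get_global_id(0) + 1) * N0 > N); // true only for the partial block in x
//   STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout, PARTIAL_STORE_N0, cond_x);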
542
543#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
544
545/** Boundary-aware GEMM block store
546 * @name STORE_BLOCK_BOUNDARY_AWARE
547 * This macro assumes the following schemes to achieve boundary-awareness:
548 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
549 *  - Non-Overlapping (normal) load from rhs tensor. This implies rhs can have padding.
550 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
551 * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
552 *
553 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
554 * blocks **at the end**.
555 * Say the dst tensor is of shape MxN and we have M0 and N0 as the block sizes; this is how we define "partial blocks" /
556 * "boundary blocks" (we use the two terms interchangeably) and their various parameters:
557 *
558 *  *--x-->                         x == 0                        x == 1
559 *  |                  |<------------------------------N-------------------------->|
560 *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
561 *  |     -------------#############################################################
562 *  *     |          | |...............................|...........................|
563 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
564 *        |          | |...............................|...........................|
565 *        M          --#############################################################
566 *        |          | |                               |...........................|
567 * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
568 *        |          | |                               |...........................|
569 *        |------------#############################################################
570 *
571 * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
572 *
573 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
574 *
575 * It automatically detects whether a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
576 * and selects the corresponding store methods so that the boundary detection logic is only added when needed.
577 *
578 * The data to store is expected to have consecutive names for each row.
579 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
580 * The Z offset is expected to have consecutive names.
581 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
582 *
583 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
584 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
585 * @param[in] DATA_TYPE        The data type of the vectors
586 * @param[in] BASENAME         The basename of the variables
587 * @param[in] PTR              The base pointer
588 * @param[in] STRIDE_Y         The stride value in y-axis direction
589 * @param[in] Z                The offset in z-axis direction
590 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
591 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
592 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
593 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
594 * @{
595 */
596#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
597// Case1: No partial blocks in either x or y
598#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
599    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
600
601#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
602// Case2: Partial blocks in y
603#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
604    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)
605
606#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
607// Case3: Partial blocks in x
608#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
609    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)
610
611#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
612// Case4: Partial blocks in both x and y
613#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
614    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)
615
616#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
617
618#endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
619/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
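// Illustrative usage sketch (added for clarity; variable names and conditions are hypothetical):
// a GEMM-style kernel with PARTIAL_STORE_M0/PARTIAL_STORE_N0 passed as build options would
// typically store its M0xN0 tile of results c0..c(M0-1) as:
//
//   const bool cond_y = (get_global_id(1) == 0);            // partial rows sit at the beginning in y
//   const bool cond_x = ((get_global_id(0) + 1) * N0 > N);  // partial columns sit at the end in x
//   STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout,
//                              PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
//
// Since the four cases are resolved at compile time from PARTIAL_STORE_M0/N0, the runtime branches
// are only emitted when a partial block can actually occur.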
620
621#if defined(PARTIAL_STORE_M0)
622/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
623 * @name COMPUTE_M0_START_ROW
624 * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
625 * A shift of ((M0 - PARTIAL_STORE_M0) % M0) rows is applied to all subsequent blocks so that the partial block
626 * (at the beginning) overlaps with them in the y dimension, avoiding any padding.
627 * EG: M0=4, PARTIAL_STORE_M0=1:
628 *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
629 * block 0 (partial)| start row = 0   | start row = 0
630 * block 1 (full)   | start row = 4   | start row = 1
631 * block 2 (full)   | start row = 8   | start row = 5
632 *
633 * @param[in] y                Global id of current block in y.
634 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
635 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
636 * @{
637 */
638#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
639    ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
640#else // defined(PARTIAL_STORE_M0)
641#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
642    ((uint)(y * M0))
643#endif    // defined(PARTIAL_STORE_M0)
644/** @} */ // end of group COMPUTE_M0_START_ROW
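// Worked example (added for clarity): with M0=4 and PARTIAL_STORE_M0=1 the shift is
// (M0 - PARTIAL_STORE_M0) % M0 = 3, so
//   COMPUTE_M0_START_ROW(0, 4, 1) = max(0, 0 - 3) = 0
//   COMPUTE_M0_START_ROW(1, 4, 1) = max(0, 4 - 3) = 1
//   COMPUTE_M0_START_ROW(2, 4, 1) = max(0, 8 - 3) = 5
// matching the table in the documentation above.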
645
646/** Store a vector that can only be partial in x.
647 *
648 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
649 *
650 * The data to store is expected to end in a 0.
651 * E.g., for basename=c, the expected name is c0.
652 *
653 * @param[in] basename  The name of the variable without trailing 0
654 * @param[in] data_type The data type of the vector
655 * @param[in] ptr       The base pointer
656 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
657 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
658 * @param[in] cond      Condition to select either vec_size (false) or leftover (true)
659 * @{
660 */
661#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
662    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
663/** @} */ // end of group STORE_VECTOR_SELECT
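// Illustrative usage sketch (added for clarity; names are hypothetical): storing a float4 result
// res0 where only `leftover` elements are valid for the last work-item in x:
//
//   STORE_VECTOR_SELECT(res, float, dst_ptr, 4, leftover, is_last_block_x);
//   // ==> STORE_BLOCK_PARTIAL_IN_X(1, 4, float, res, dst_ptr, 0, 0, leftover, is_last_block_x)
//
// Note that `leftover` must expand to an integer literal (e.g. a -D build option), since it is
// pasted into the vstore_partial_<size>_<store_size> macro name.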
664
665#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
666#pragma OPENCL EXTENSION cl_khr_fp16 : enable
667#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
668
669#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
670#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
671#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
672
673#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
674#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
675#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
676
677#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
678#pragma OPENCL EXTENSION cl_arm_printf : enable
679#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
680
681#define GPU_ARCH_MIDGARD 0x100
682#define GPU_ARCH_BIFROST 0x200
683
684/** Concatenate two inputs.
685 *
686 * @param[in] a The first input to be concatenated
687 * @param[in] b The second input to be concatenated
688 *
689 * @return The concatenated output
690 */
691#define CONCAT(a, b) a##b
692
693/** Expand the given vector
694 *
695 * @param[in] x The vector to be expanded
696 *
697 * @return The expanded output
698 */
699#define EXPAND(x) x
700
701/** Clamp the given value between an upper and lower bound.
702 *
703 * @param[in] x       The value to be clamped
704 * @param[in] min_val The lower bound
705 * @param[in] max_val The upper bound
706 *
707 * @return The clamped value.
708 */
709#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
710
711/** REVn reverses the given vector whose size is n.
712 * @name REVn
713 *
714 * @param[in] x The vector to be reversed
715 *
716 * @return The reversed vector
717 * @{
718 */
719#define REV1(x) ((x))
720#define REV2(x) ((x).s10)
721#define REV3(x) ((x).s210)
722#define REV4(x) ((x).s3210)
723#define REV8(x) ((x).s76543210)
724#define REV16(x) ((x).sFEDCBA9876543210)
725/** @} */ // end of group REVn
726
727/** Reverse the given vector.
728 * @name REVERSE
729 *
730 * @param[in] x The vector to be reversed
731 * @param[in] s The size of the vector
732 *
733 * @return The reversed vector
734 * @{
735 */
736#define REVERSE_STR(x, s) REV##s((x))
737#define REVERSE(x, s) REVERSE_STR(x, s)
738/** @} */ // end of group REVERSE
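// Illustrative expansion (added for clarity), assuming x is a 4-element vector (e.g. float4):
//   REVERSE(x, 4)  ==>  REV4((x))  ==>  ((x).s3210)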
739
740/** Circular-right-shift (rotate-right) the vector of size s by n elements.
741 * @name ROTs_n
742 *
743 * @param[in] x The vector to be shifted
744 *
745 * @return The shifted vector
746 * @{
747 */
748#define ROT1_0(x) ((x))
749
750#define ROT2_0(x) ((x))
751#define ROT2_1(x) ((x).s10)
752
753#define ROT3_0(x) ((x))
754#define ROT3_1(x) ((x).s201)
755#define ROT3_2(x) ((x).s120)
756
757#define ROT4_0(x) ((x))
758#define ROT4_1(x) ((x).s3012)
759#define ROT4_2(x) ((x).s2301)
760#define ROT4_3(x) ((x).s1230)
761
762#define ROT8_0(x) ((x))
763#define ROT8_1(x) ((x).s70123456)
764#define ROT8_2(x) ((x).s67012345)
765#define ROT8_3(x) ((x).s56701234)
766#define ROT8_4(x) ((x).s45670123)
767#define ROT8_5(x) ((x).s34567012)
768#define ROT8_6(x) ((x).s23456701)
769#define ROT8_7(x) ((x).s12345670)
770
771#define ROT16_0(x) ((x))
772#define ROT16_1(x) ((x).sF0123456789ABCDE)
773#define ROT16_2(x) ((x).sEF0123456789ABCD)
774#define ROT16_3(x) ((x).sDEF0123456789ABC)
775#define ROT16_4(x) ((x).sCDEF0123456789AB)
776#define ROT16_5(x) ((x).sBCDEF0123456789A)
777#define ROT16_6(x) ((x).sABCDEF0123456789)
778#define ROT16_7(x) ((x).s9ABCDEF012345678)
779#define ROT16_8(x) ((x).s89ABCDEF01234567)
780#define ROT16_9(x) ((x).s789ABCDEF0123456)
781#define ROT16_10(x) ((x).s6789ABCDEF012345)
782#define ROT16_11(x) ((x).s56789ABCDEF01234)
783#define ROT16_12(x) ((x).s456789ABCDEF0123)
784#define ROT16_13(x) ((x).s3456789ABCDEF012)
785#define ROT16_14(x) ((x).s23456789ABCDEF01)
786#define ROT16_15(x) ((x).s123456789ABCDEF0)
787/** @} */ // end of group ROTs_n
788
789/** Circular-right-shift (rotate-right) the given vector by the given amount.
790 * @name ROTATE
791 *
792 * @param[in] x The vector to be shifted
793 * @param[in] s The size of the vector
794 * @param[in] n The amount to be shifted
795 *
796 * @return The shifted vector
797 * @{
798 */
799#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
800#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
801/** @} */ // end of group ROTATE
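// Illustrative expansion (added for clarity), assuming x is an 8-element vector:
//   ROTATE(x, 8, 2)  ==>  ROT8_2(x)  ==>  ((x).s67012345)
// i.e. every element moves two positions to the right, wrapping around at the end.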
802
803/** Creates a vector of size n filled with offset values corresponding to the location of each element.
804 * @name V_OFFSn
805 *
806 * @param[in] dt The data type of the output vector
807 *
808 * @return The vector filled with offset values
809 * @{
810 */
811#define V_OFFS1(dt) (dt##1)(0)
812#define V_OFFS2(dt) (dt##2)(0, 1)
813#define V_OFFS3(dt) (dt##3)(0, 1, 2)
814#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
815#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
816#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
817/** @} */ // end of group V_OFFSn
818
819/** Create a vector filled with offset values corresponding to the location of each element.
820 * @name VEC_OFFS
821 *
822 * @param[in] dt The data type of the output vector
823 * @param[in] s  The size of the output vector
824 *
825 * @return The vector filled with offset values
826 * @{
827 */
828#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
829#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
830/** @} */ // end of group VEC_OFFS
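// Illustrative expansion (added for clarity):
//   VEC_OFFS(int, 4)  ==>  V_OFFS4(int)  ==>  (int4)(0, 1, 2, 3)
// typically used to build per-lane indices or offsets.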
831
832#define VLOAD_STR(size) vload##size
833#define VLOAD(size) VLOAD_STR(size)
834
835#define PIXEL_UNIT4 1
836#define PIXEL_UNIT8 2
837#define PIXEL_UNIT16 4
838
839/** Utility macro to convert a vector size to pixel units.
840 *
841 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
842 *
843 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
844 *
845 * @return The pixel unit (number of pixels)
846 * @{
847 */
848#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
849#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
850/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
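// Illustrative expansion (added for clarity): image reads return 4 channels per access, so
//   CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(16)  ==>  PIXEL_UNIT16  ==>  4
// i.e. a 16-wide vector corresponds to 4 pixels read from the image.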
851
852#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
853#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
854#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));
855
856#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
857#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
858#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
859#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
860#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
861
862/** Utility macro to read a 2D OpenCL image object.
863 *
864 * @note Coordinates are not normalized
865 *
866 * @param[in] data_type Data type
867 * @param[in] n0        Number of pixels to read. Only 1, 2 and 4 are supported
868 * @param[in] img       OpenCL image object
869 * @param[in] x_coord   The x coordinate for the top-left pixel
870 * @param[in] y_coord   The y coordinate for the top-left pixel
871 *
872 * @return Pixels from the 2D OpenCL image object
873 * @{
874 */
875#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
876#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
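// Illustrative usage sketch (added for clarity; img, x and y are hypothetical):
//   float16 vals = READ_IMAGE2D(float, 4, img, x, y);
//   // ==> read_image2d_floatx4(img, x, y): four read_imagef calls at (x, y) .. (x + 3, y),
//   //     packed into a float16, i.e. 4 consecutive pixels on the same row.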
877
878#define VSTORE_STR(size) vstore##size
879#define VSTORE(size) VSTORE_STR(size)
880
881#define float1 float
882#define half1 half
883#define char1 char
884#define uchar1 uchar
885#define short1 short
886#define ushort1 ushort
887#define int1 int
888#define uint1 uint
889#define long1 long
890#define ulong1 ulong
891#define double1 double
892
893#define vload1(OFFSET, PTR) *(OFFSET + PTR)
894#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA
895
896/** Extended partial vstore that correctly handles scalar values as well.
897 * Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
898 * @name VSTORE_PARTIAL
899 *
900 * @note With this macro, the passed data can be both a vector and a scalar
901 * @note @p store_size needs to be <= @p size
902 * eg 1: Valid
903 * VSTORE_PARTIAL(16, 15) ...;
904 * eg 2: Invalid
905 * VSTORE_PARTIAL(4, 7) ...;
906 *
907 * @param[in] size       The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
908 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
909 * @{
910 */
911#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
912#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)
913
914#define NO_STORE(data, offs, ptr) \
915    {                             \
916    }
917
918// Size == 1 (scalar)
919#define vstore_partial_1_0 NO_STORE
920#define vstore_partial_1_1 vstore1
921#define vstore_partial_1_2 NO_STORE
922#define vstore_partial_1_3 NO_STORE
923#define vstore_partial_1_4 NO_STORE
924#define vstore_partial_1_5 NO_STORE
925#define vstore_partial_1_6 NO_STORE
926#define vstore_partial_1_7 NO_STORE
927#define vstore_partial_1_8 NO_STORE
928#define vstore_partial_1_9 NO_STORE
929#define vstore_partial_1_10 NO_STORE
930#define vstore_partial_1_11 NO_STORE
931#define vstore_partial_1_12 NO_STORE
932#define vstore_partial_1_13 NO_STORE
933#define vstore_partial_1_14 NO_STORE
934#define vstore_partial_1_15 NO_STORE
935#define vstore_partial_1_16 NO_STORE
936// Size == 2
937#define vstore_partial_2_0 NO_STORE
938#define vstore_partial_2_1 vstore_partial_1
939#define vstore_partial_2_2 vstore_partial_2
940#define vstore_partial_2_3 NO_STORE
941#define vstore_partial_2_4 NO_STORE
942#define vstore_partial_2_5 NO_STORE
943#define vstore_partial_2_6 NO_STORE
944#define vstore_partial_2_7 NO_STORE
945#define vstore_partial_2_8 NO_STORE
946#define vstore_partial_2_9 NO_STORE
947#define vstore_partial_2_10 NO_STORE
948#define vstore_partial_2_11 NO_STORE
949#define vstore_partial_2_12 NO_STORE
950#define vstore_partial_2_13 NO_STORE
951#define vstore_partial_2_14 NO_STORE
952#define vstore_partial_2_15 NO_STORE
953#define vstore_partial_2_16 NO_STORE
954// Size == 3
955#define vstore_partial_3_0 NO_STORE
956#define vstore_partial_3_1 vstore_partial_1
957#define vstore_partial_3_2 vstore_partial_2
958#define vstore_partial_3_3 vstore_partial_3
959#define vstore_partial_3_4 NO_STORE
960#define vstore_partial_3_5 NO_STORE
961#define vstore_partial_3_6 NO_STORE
962#define vstore_partial_3_7 NO_STORE
963#define vstore_partial_3_8 NO_STORE
964#define vstore_partial_3_9 NO_STORE
965#define vstore_partial_3_10 NO_STORE
966#define vstore_partial_3_11 NO_STORE
967#define vstore_partial_3_12 NO_STORE
968#define vstore_partial_3_13 NO_STORE
969#define vstore_partial_3_14 NO_STORE
970#define vstore_partial_3_15 NO_STORE
971#define vstore_partial_3_16 NO_STORE
972// Size == 4
973#define vstore_partial_4_0 NO_STORE
974#define vstore_partial_4_1 vstore_partial_1
975#define vstore_partial_4_2 vstore_partial_2
976#define vstore_partial_4_3 vstore_partial_3
977#define vstore_partial_4_4 vstore_partial_4
978#define vstore_partial_4_5 NO_STORE
979#define vstore_partial_4_6 NO_STORE
980#define vstore_partial_4_7 NO_STORE
981#define vstore_partial_4_8 NO_STORE
982#define vstore_partial_4_9 NO_STORE
983#define vstore_partial_4_10 NO_STORE
984#define vstore_partial_4_11 NO_STORE
985#define vstore_partial_4_12 NO_STORE
986#define vstore_partial_4_13 NO_STORE
987#define vstore_partial_4_14 NO_STORE
988#define vstore_partial_4_15 NO_STORE
989#define vstore_partial_4_16 NO_STORE
990// Size == 8
991#define vstore_partial_8_0 NO_STORE
992#define vstore_partial_8_1 vstore_partial_1
993#define vstore_partial_8_2 vstore_partial_2
994#define vstore_partial_8_3 vstore_partial_3
995#define vstore_partial_8_4 vstore_partial_4
996#define vstore_partial_8_5 vstore_partial_5
997#define vstore_partial_8_6 vstore_partial_6
998#define vstore_partial_8_7 vstore_partial_7
999#define vstore_partial_8_8 vstore_partial_8
1000#define vstore_partial_8_9 NO_STORE
1001#define vstore_partial_8_10 NO_STORE
1002#define vstore_partial_8_11 NO_STORE
1003#define vstore_partial_8_12 NO_STORE
1004#define vstore_partial_8_13 NO_STORE
1005#define vstore_partial_8_14 NO_STORE
1006#define vstore_partial_8_15 NO_STORE
1007#define vstore_partial_8_16 NO_STORE
1008// Size == 16
1009#define vstore_partial_16_0 NO_STORE
1010#define vstore_partial_16_1 vstore_partial_1
1011#define vstore_partial_16_2 vstore_partial_2
1012#define vstore_partial_16_3 vstore_partial_3
1013#define vstore_partial_16_4 vstore_partial_4
1014#define vstore_partial_16_5 vstore_partial_5
1015#define vstore_partial_16_6 vstore_partial_6
1016#define vstore_partial_16_7 vstore_partial_7
1017#define vstore_partial_16_8 vstore_partial_8
1018#define vstore_partial_16_9 vstore_partial_9
1019#define vstore_partial_16_10 vstore_partial_10
1020#define vstore_partial_16_11 vstore_partial_11
1021#define vstore_partial_16_12 vstore_partial_12
1022#define vstore_partial_16_13 vstore_partial_13
1023#define vstore_partial_16_14 vstore_partial_14
1024#define vstore_partial_16_15 vstore_partial_15
1025#define vstore_partial_16_16 vstore_partial_16

/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the number of vstore ops
 * @name vstore_partial_n
 *
 * @note @p DATA needs to be a vector, not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * e.g. 1: Valid
 * vstore_partial_15(var:float16, 0, 0xabcd);
 * e.g. 2: Invalid
 * vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note In the cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, so there is no performance penalty.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset in n
 * @param[in] PTR    The base pointer
 * @{
 */
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

#define vstore_partial_5(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

#define vstore_partial_9(DATA, OFFSET, PTR)        \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

#define vstore_partial_13(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
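
// Illustrative usage sketch, not taken from the original kernels: vstore_partial_5 writes the first
// five lanes of a vector (one vstore4 of the lower lanes followed by a vstore1 of lane 4 at dst + 4),
// leaving the remaining destination elements untouched. HELPERS_USAGE_EXAMPLES is a hypothetical
// guard so the sketch is never built by default.
#if defined(HELPERS_USAGE_EXAMPLES)
inline void example_store_first_five(float8 v, __global float *dst)
{
    // Writes v.s0..v.s4 to dst[0..4]; dst[5..7] are not written.
    vstore_partial_5(v, 0, dst);
}
#endif // defined(HELPERS_USAGE_EXAMPLES)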

// The convert built-in functions with the _sat modifier are not supported for floating point, so we
// create defines without _sat to work around this issue
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_float
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat

#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)

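// Illustrative sketch, not part of the original sources: CONVERT pastes the destination type onto
// the convert_ built-in, and CONVERT_SAT adds the _sat suffix (which the defines above alias back
// to the plain conversion for float/half). HELPERS_USAGE_EXAMPLES is a hypothetical guard so the
// sketch is never built by default.
#if defined(HELPERS_USAGE_EXAMPLES)
inline uchar4 example_convert_sat(float4 in)
{
    // Expands to convert_uchar4_sat((in)): values are clamped to [0, 255] before narrowing.
    return CONVERT_SAT(in, uchar4);
}
#endif // defined(HELPERS_USAGE_EXAMPLES)

// The select_vec_dt_* defines below map each data type to an integer vector type of the same
// element width, as expected by the third operand of the OpenCL select() built-in.
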
#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size

#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)

#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) ((x).s0) + ((x).s1)
#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2)
#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23)
#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567)
#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF)

#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
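
// Illustrative sketch, not part of the original sources: SUM_REDUCE and MAX_REDUCE expand to a tree
// of pairwise additions / max() calls across the vector lanes (e.g. SUM_REDUCE(v, 4) effectively
// computes v.s0 + v.s1 + v.s2 + v.s3). HELPERS_USAGE_EXAMPLES is a hypothetical guard so the sketch
// is never built by default.
#if defined(HELPERS_USAGE_EXAMPLES)
inline float2 example_reduce(float4 v)
{
    const float total   = SUM_REDUCE(v, 4); // horizontal sum of the four lanes
    const float largest = MAX_REDUCE(v, 4); // horizontal maximum of the four lanes
    return (float2)(total, largest);
}
#endif // defined(HELPERS_USAGE_EXAMPLES)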

#define VECTOR_DECLARATION(name)     \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_offset_first_element_in_bytes

#define IMAGE_DECLARATION(name)      \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR3D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR4D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_stride_w, \
    uint        name##_step_w,   \
    uint        name##_offset_first_element_in_bytes

#define CONVERT_TO_VECTOR_STRUCT(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)

#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)

#define CONVERT_TO_IMAGE_STRUCT(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)

#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)

#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)

#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                                                                                       \
    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                           name##_stride_z, name##_step_z)

/** Structure to hold Vector information */
typedef struct Vector
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source vector */
    int             stride_x;                      /**< Stride of the vector in X dimension (in bytes) */
} Vector;

/** Structure to hold Image information */
typedef struct Image
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
} Image;

/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
} Tensor3D;

/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
    int             stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
} Tensor4D;

/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 *
 * @return A vector object
 */
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
    Vector vector =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
    };
    vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
    return vector;
}

/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
    return img;
}

/** Wrap 3D tensor information into an image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return img;
}

/** Wrap 3D tensor information into a tensor structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return tensor;
}

/** Wrap 3D tensor information into a tensor structure.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    return tensor;
}

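/** Wrap 4D tensor information into a tensor structure, and make the pointer point at this workitem's data.
 *
 * @note The third global id is split across the Z and W dimensions using @p mod_size
 *       (assumed to be the number of work-items along Z): z = get_global_id(2) % mod_size,
 *       w = get_global_id(2) / mod_size.
 */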
inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
                                             uint step_w,
                                             uint mod_size)
{
    Tensor4D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z,
        .stride_w                      = stride_w
    };

    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
    return tensor;
}

/** Get the pointer position of a Vector
 *
 * @param[in] vec Pointer to the vector structure
 * @param[in] x   Relative X position
 */
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
    return vec->ptr + x * vec->stride_x;
}

/** Get the pointer position of an Image
 *
 * @param[in] img Pointer to the image structure
 * @param[in] x   Relative X position
 * @param[in] y   Relative Y position
 */
inline __global uchar *offset(const Image *img, int x, int y)
{
    return img->ptr + x * img->stride_x + y * img->stride_y;
}

/** Get the pointer position of a Tensor3D
 *
 * @param[in] tensor Pointer to the tensor structure
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 */
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
}

/** Get the pointer position of a Tensor4D
 *
 * @param[in] tensor Pointer to the tensor structure
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 * @param[in] w      Relative W position
 */
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
}

/** Get the offset for a given linear index of a Tensor3D
 *
 * @param[in] tensor Pointer to the tensor structure
 * @param[in] width  Width of the input tensor
 * @param[in] height Height of the input tensor
 * @param[in] depth  Depth of the input tensor
 * @param[in] index  Linear index
 */
inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
{
    uint num_elements = width * height;

    const uint z = index / num_elements;

    index %= num_elements;

    const uint y = index / width;

    index %= width;

    const uint x = index;

    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
}

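// Illustrative sketch, not part of the original sources: a minimal kernel wired up with the helpers
// above. TENSOR3D_DECLARATION expands to the pointer/stride/step/offset arguments,
// CONVERT_TO_TENSOR3D_STRUCT points src.ptr at this work-item's element, and tensor3D_offset applies
// the byte strides for relative addressing. For tensor3D_index2ptr, e.g. width = 4, height = 3 and
// index = 13 decompose to z = 1, y = 0, x = 1. HELPERS_USAGE_EXAMPLES is a hypothetical guard so the
// sketch is never built by default.
#if defined(HELPERS_USAGE_EXAMPLES)
__kernel void example_copy_next_row(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Copy the byte one row below the current position into the destination element.
    *((__global uchar *)dst.ptr) = *tensor3D_offset(&src, 0, 1, 0);
}
#endif // defined(HELPERS_USAGE_EXAMPLES)
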
#endif // _HELPER_H

#if GPU_ARCH == GPU_ARCH_BIFROST
#define MLA(a, b, c) (fma(c, b, a))
#else // GPU_ARCH == GPU_ARCH_BIFROST
#define MLA(a, b, c) ((b) * (c) + (a))
#endif // GPU_ARCH == GPU_ARCH_BIFROST
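// Both variants compute (b) * (c) + (a); on Bifrost the fma(c, b, a) form maps onto a single fused multiply-add.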

// Hard-Swish
#define hard_swish_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x * ((min(max((x + (DATA_TYPE)3.0), (DATA_TYPE)0.0), (DATA_TYPE)6.0)) * (DATA_TYPE)0.166666667))

// Logistic Activation
#define logistic_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((DATA_TYPE)1.0 / ((DATA_TYPE)1.0 + exp(-x)))

// Hyperbolic Tangent Activation
#define tanh_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((DATA_TYPE)A_VAL * tanh((DATA_TYPE)B_VAL * x))

// RELU Activation
#define relu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (max((DATA_TYPE)0.0, x))

// Bounded RELU Activation
#define brelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (min((DATA_TYPE)A_VAL, max((DATA_TYPE)0.0, x)))

// Lower Upper Bounded RELU Activation
#define lu_brelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (min(max(x, (DATA_TYPE)B_VAL), (DATA_TYPE)A_VAL))

// Leaky RELU Activation
#define lrelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((min(x, (DATA_TYPE)0.0) * (DATA_TYPE)A_VAL) + max(x, (DATA_TYPE)0.0))

// Soft RELU Activation
#define srelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (log((DATA_TYPE)1.0 + exp(x)))

// ELU Activation
#define elu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (select(((DATA_TYPE)A_VAL * (exp(x) - (DATA_TYPE)1.0)), x, (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))isgreaterequal(x, (DATA_TYPE)0.0)))

// Absolute Activation
#define abs_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (fabs(x))

// Square Activation
#define square_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x * x)

// Square-root Activation
#define sqrt_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (sqrt(x))

// Linear Activation
#define linear_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (MLA((DATA_TYPE)B_VAL, (DATA_TYPE)A_VAL, x))

// Identity Activation
#define identity_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x)

#define ACT_OP(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) op##_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL)

#define ACTIVATION(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ACT_OP(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL)

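// Illustrative sketch, not part of the original sources: ACTIVATION(op, ...) pastes "_op" onto the
// operator name, so ACTIVATION(logistic, float, 4, x, A_VAL, B_VAL) resolves to
// logistic_op(float, 4, x, A_VAL, B_VAL), i.e. 1 / (1 + exp(-x)) applied lane-wise.
// HELPERS_USAGE_EXAMPLES is a hypothetical guard so the sketch is never built by default.
#if defined(HELPERS_USAGE_EXAMPLES)
inline float4 example_activation(float4 x)
{
    // Logistic / sigmoid activation of a float4; A_VAL and B_VAL are unused by this operator.
    return ACTIVATION(logistic, float, 4, x, 1.0f, 0.0f);
}
#endif // defined(HELPERS_USAGE_EXAMPLES)
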
#define SELECT_TYPE SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)

#if VEC_SIZE != 1
#define TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE)

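// One plausible set of build options for the kernels below (an illustrative assumption, not taken
// from the library's host code):
//   -DDATA_TYPE=float -DVEC_SIZE=16 -DNUM_CLASSES=80 -DACTIVATION_TYPE=logistic -DA_VAL=1.0f -DB_VAL=0.0f
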
/** This performs a YOLO partial activation function for NCHW data layout
 *
 * @note In order to perform the activation function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 *
 * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
 * @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=16
 * @note Activation function should be given as a preprocessor argument using -DACTIVATION_TYPE=name. e.g. -DACTIVATION_TYPE=tanh
 * @note The number of classes should be given as a preprocessor argument using -DNUM_CLASSES=num. e.g. -DNUM_CLASSES=80
 * @note The A and B variables required by some activation functions are set using -DA_VAL= and -DB_VAL= respectively.
 *
 * @param[in]  input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]  input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  input_step_x                         input_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  input_step_y                         input_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  input_step_z                         input_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]  output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  output_step_x                        (Optional) output_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  output_step_y                        (Optional) output_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  output_step_z                        (Optional) output_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void yolo_layer_nchw(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Get pixels pointer
    Tensor3D input = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    const int  box_ch_id = get_global_id(2) % (NUM_CLASSES + 5);
    const bool activate  = box_ch_id != 2 && box_ch_id != 3;

    if(activate)
    {
        // Load data
        TYPE data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)input.ptr);
        data      = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, data, A_VAL, B_VAL); // select(1.0f, ACTIVATION_OP(ACTIVATION_TYPE, data), (SELECT_TYPE)activate);

        // Store result
        VSTORE(VEC_SIZE)
        (data, 0, (__global DATA_TYPE *)output.ptr);
    }
#ifndef IN_PLACE
    else
    {
        // Load data
        TYPE data = VLOAD(VEC_SIZE)(0, (__global DATA_TYPE *)input.ptr);

        // Store result
        VSTORE(VEC_SIZE)
        (data, 0, (__global DATA_TYPE *)output.ptr);
    }
#endif // IN_PLACE
}

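// Channels 2 and 3 of each box prediction are presumably the raw width/height terms of the YOLO box
// encoding, which is why box_ch_id == 2 / 3 bypass the activation in the kernels above and below.
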
#else // VEC_SIZE != 1

/** This performs a YOLO partial activation function for NHWC data layout
 *
 * @note In order to perform the activation function "in-place", the pre-processor -DIN_PLACE must be passed at compile time
 *
 * @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
 * @note Vector size should be given as a preprocessor argument using -DVEC_SIZE=size. e.g. -DVEC_SIZE=1
 * @note Activation function should be given as a preprocessor argument using -DACTIVATION_TYPE=name. e.g. -DACTIVATION_TYPE=tanh
 * @note The number of classes should be given as a preprocessor argument using -DNUM_CLASSES=num. e.g. -DNUM_CLASSES=80
 * @note The A and B variables required by some activation functions are set using -DA_VAL= and -DB_VAL= respectively.
 *
 * @param[in]  input_ptr                            Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]  input_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  input_step_x                         input_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  input_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  input_step_y                         input_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  input_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  input_step_z                         input_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  input_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] output_ptr                           (Optional) Pointer to the destination tensor. Supported data types: same as @p input_ptr
 * @param[in]  output_stride_x                      (Optional) Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  output_step_x                        (Optional) output_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  output_stride_y                      (Optional) Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  output_step_y                        (Optional) output_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  output_stride_z                      (Optional) Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  output_step_z                        (Optional) output_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  output_offset_first_element_in_bytes (Optional) The offset of the first element in the destination tensor
 */
__kernel void yolo_layer_nhwc(
    TENSOR3D_DECLARATION(input)
#ifndef IN_PLACE
    ,
    TENSOR3D_DECLARATION(output)
#endif /* not IN_PLACE */
)
{
    // Get pixels pointer
    Tensor3D input  = CONVERT_TO_TENSOR3D_STRUCT(input);
#ifdef IN_PLACE
    Tensor3D output = input;
#else  /* IN_PLACE */
    Tensor3D output = CONVERT_TO_TENSOR3D_STRUCT(output);
#endif /* IN_PLACE */

    const int  box_ch_id = get_global_id(0) % (NUM_CLASSES + 5);
    const bool activate  = box_ch_id != 2 && box_ch_id != 3;

    if(activate)
    {
        // Load data
        DATA_TYPE data = *((__global DATA_TYPE *)input.ptr);
        data           = select(data, ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, data, A_VAL, B_VAL), (SELECT_TYPE)activate);

        // Store result
        *((__global DATA_TYPE *)output.ptr) = data;
    }
#ifndef IN_PLACE
    else
    {
        // Load data
        DATA_TYPE data = *((__global DATA_TYPE *)input.ptr);

        // Store result
        *((__global DATA_TYPE *)output.ptr) = data;
    }
#endif // IN_PLACE
}

#endif // VEC_SIZE != 1
#endif // defined(DATA_TYPE) && defined(ACTIVATION_TYPE) && defined(NUM_CLASSES) && defined(VEC_SIZE)

)"