• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1R"(
2
3/*
4 * Copyright (c) 2016-2019 Arm Limited.
5 *
6 * SPDX-License-Identifier: MIT
7 *
8 * Permission is hereby granted, free of charge, to any person obtaining a copy
9 * of this software and associated documentation files (the "Software"), to
10 * deal in the Software without restriction, including without limitation the
11 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
12 * sell copies of the Software, and to permit persons to whom the Software is
13 * furnished to do so, subject to the following conditions:
14 *
15 * The above copyright notice and this permission notice shall be included in all
16 * copies or substantial portions of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 * SOFTWARE.
25 */
26/*
27 * Copyright (c) 2016-2020 Arm Limited.
28 *
29 * SPDX-License-Identifier: MIT
30 *
31 * Permission is hereby granted, free of charge, to any person obtaining a copy
32 * of this software and associated documentation files (the "Software"), to
33 * deal in the Software without restriction, including without limitation the
34 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
35 * sell copies of the Software, and to permit persons to whom the Software is
36 * furnished to do so, subject to the following conditions:
37 *
38 * The above copyright notice and this permission notice shall be included in all
39 * copies or substantial portions of the Software.
40 *
41 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
44 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
45 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
46 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
47 * SOFTWARE.
48 */
49#ifndef ARM_COMPUTE_HELPER_H
50#define ARM_COMPUTE_HELPER_H
51
52/*
53 * Copyright (c) 2020 Arm Limited.
54 *
55 * SPDX-License-Identifier: MIT
56 *
57 * Permission is hereby granted, free of charge, to any person obtaining a copy
58 * of this software and associated documentation files (the "Software"), to
59 * deal in the Software without restriction, including without limitation the
60 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
61 * sell copies of the Software, and to permit persons to whom the Software is
62 * furnished to do so, subject to the following conditions:
63 *
64 * The above copyright notice and this permission notice shall be included in all
65 * copies or substantial portions of the Software.
66 *
67 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
68 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
69 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
70 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
71 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
72 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
73 * SOFTWARE.
74 */
75
/** Store the 0 to (n-1)th rows of the given variables
 * @name STORE_ROW_n
 * Each STORE_ROW_m expands to STORE_ROW_(m-1) plus one extra VSTORE, so
 * STORE_ROW_m stores rows 0..m-1. Rows 10-15 use hex digit suffixes (A-F)
 * for the variable and Z-offset names.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                 \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                 \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                  \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                  \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n
166
/** Convert and store the 0th to (n-1)th rows of the given variables
 * @name CONVERT_STORE_ROW_n
 * Same recursive structure as STORE_ROW_n, but each row is first converted
 * to VEC_DATA_TYPE(DATA_TYPE, N0) with saturation (CONVERT_SAT) before the
 * store.
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                         \
    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

/* Fix: the parameter was previously named DATA while the body used DATA_TYPE,
 * so the body's DATA_TYPE would only resolve if a macro of that name happened
 * to be defined at the call site. Renamed to DATA_TYPE for consistency with
 * every other CONVERT_STORE_ROW_n macro. */
#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE(N0)                                                          \
    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));

/** @} */ // end of group CONVERT_STORE_ROW_n
258
/** Store a block of the given size M0xN0
 * @name STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
/* Two-level expansion so that an M0 passed as a macro is expanded before token pasting. */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
280
/** Convert and store a block of the given size M0xN0
 * @name CONVERT_STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
/* Two-level expansion so that an M0 passed as a macro is expanded before token pasting. */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK
302
/** Partially store the 0 to (n-1)th rows of the given variables
 * @name STORE_ROW_PARTIAL_n
 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
 *
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                 \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)      \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)     \
    VSTORE_PARTIAL(N0, STORE_N0)                                                  \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n
397
/** Partially store a block of the given size STORE_M0xSTORE_N0
 * @name STORE_BLOCK_PARTIAL
 *
 * @note The vector width @p N0 is also required for correct partial storing behaviour.
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] N0        The size of each vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
/* Two-level expansion so that a STORE_M0 passed as a macro is expanded before token pasting. */
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** Store a block that can be partial in both x and y dimensions
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
/* Four cases: full block, partial rows only, partial columns only, partial in both. */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y))                                                                                                            \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                                           \
    }                                                                                                                                                     \
    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X))                                                                                                        \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                                             \
    }                                                                                                                                                     \
    else                                                                                                                                                  \
    {                                                                                                                                                     \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                               \
    }
/** Store a block that can only be partial in x but not y.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X))                                                                                         \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
    }                                                                                                             \
    else                                                                                                          \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
    }
/** Store a block that can only be partial in y but not x.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 */
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
    if(!(PARTIAL_COND_Y))                                                                                         \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                                   \
    }                                                                                                             \
    else                                                                                                          \
    {                                                                                                             \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z);                     \
    }
/** @} */ // end of group STORE_BLOCK_PARTIAL
516
517#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
518
519/** Boundary-aware GEMM block store
520 * @name STORE_BLOCK_BOUNDARY_AWARE
521 * This macro assumes the following schemes to achieve boundary-awareness:
522 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 *  - Non-Overlapping(normal) load from rhs tensor. This implies rhs can have paddings.
 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
525 * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
526 *
527 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
528 * blocks **at the end**.
529 * Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/
530 * "boundary block" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and its various parameters:
531 *
532 *  *--x-->                         x == 0                        x == 1
533 *  |                  |<------------------------------N-------------------------->|
534 *  y                  |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
535 *  |     -------------#############################################################
536 *  *     |          | |...............................|...........................|
537 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
538 *        |          | |...............................|...........................|
539 *        M          --#############################################################
540 *        |          | |                               |...........................|
541 * y == 1 |         M0 |      Non-boundary block       |....Boundary block in x....|
542 *        |          | |                               |...........................|
543 *        |------------#############################################################
544 *
545 * Then @p PARTIAL_STORE_M0 = M % M0      and @p PARTIAL_STORE_N0 = N % N0
546 *
547 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
548 *
 * It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either X and Y dimension,
 * and selects the corresponding store methods such that the boundary detection logic is only added when needed.
551 *
552 * The data to store is expected to have consecutive names for each row.
553 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
554 * The Z offset is expected to have consecutive names.
555 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
556 *
557 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
558 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
559 * @param[in] DATA_TYPE        The data type of the vectors
560 * @param[in] BASENAME         The basename of the variables
561 * @param[in] PTR              The base pointer
562 * @param[in] STRIDE_Y         The stride value in y-axis direction
563 * @param[in] Z                The offset in z-axis direction
564 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
565 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
566 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
567 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
568 * @{
569 */
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y: a plain full-block store, no runtime checks
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)

#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y only: runtime check on the y-condition alone
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)

#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x only: runtime check on the x-condition alone
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)

#else // i.e. PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 > 0
// Case4: Partial blocks in both x and y: runtime checks on both conditions
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)

#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
591
592#endif    // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
593/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
594
#if defined(PARTIAL_STORE_M0)
/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
 * @name COMPUTE_M0_START_ROW
 * If there are any partial blocks in y dimension, they are placed at the beginning of the rows.
 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
 * blocks in the y dimension to avoid any padding.
 * EG: M0=4, PARTIAL_STORE_M0=1:
 *                  | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
 * block 0 (partial)| start row = 0   | start row = 0
 * block 1 (full)   | start row = 4   | start row = 1
 * block 2 (full)   | start row = 8   | start row = 5
 *
 * @note Each macro argument is parenthesised at every use so that expression arguments
 *       (e.g. y = gid + 1) group correctly; the previous form expanded `y * M0` unparenthesised.
 *
 * @param[in] y                Global id of current block in y.
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @{
 */
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(max(0, (int)((y) * (M0)) - (int)(((M0) - (PARTIAL_STORE_M0)) % (M0)))))
#else // defined(PARTIAL_STORE_M0)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)((y) * (M0)))
#endif    // defined(PARTIAL_STORE_M0)
/** @} */ // end of group COMPUTE_M0_START_ROW
619
/** Store a vector that can only be partial in x.
 *
 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to end in a 0.
 * E.g., for basename=c, the expected name is c0.
 *
 * @param[in] basename  The name of the variable without trailing 0
 * @param[in] data_type The data type of the vector
 * @param[in] ptr       The base pointer
 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
 * @param[in] cond      Condition to select either @p vec_size or @p leftover
 * @{
 */
#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
/** @} */ // end of group STORE_VECTOR_SELECT
638
639#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
640#pragma OPENCL EXTENSION cl_khr_fp16 : enable
641#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
642
643#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
644#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
645#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
646
647#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
648#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
649#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
650
651#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
652#pragma OPENCL EXTENSION cl_arm_printf : enable
653#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
654
// GPU architecture identifiers, used to select architecture-specific code paths
#define GPU_ARCH_MIDGARD 0x100
#define GPU_ARCH_BIFROST 0x200

/** Concatenate two inputs.
 *
 * @note The arguments are pasted without prior macro expansion (plain `##`);
 *       wrap with an EXPAND()-style indirection when expanded arguments are needed.
 *
 * @param[in] a The first input to be concatenated
 * @param[in] b The second input to be concatenated
 *
 * @return The concatenated output
 */
#define CONCAT(a, b) a##b

/** Expand the given vector
 *
 * @param[in] x The vector to be expanded
 *
 * @return The expanded output
 */
#define EXPAND(x) x

/** Clamp the given value between an upper and lower bound.
 *
 * @param[in] x       The value to be clamped
 * @param[in] min_val The lower bound
 * @param[in] max_val The upper bound
 *
 * @return The clamped value.
 */
#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)
684
/** REVn reverses the given vector whose size is n.
 * @name REVn
 *
 * @param[in] x The vector to be reversed
 *
 * @return The reversed vector
 * @{
 */
#define REV1(x) ((x))
#define REV2(x) ((x).s10)
#define REV3(x) ((x).s210)
#define REV4(x) ((x).s3210)
#define REV8(x) ((x).s76543210)
#define REV16(x) ((x).sFEDCBA9876543210)
/** @} */ // end of group REVn

/** Reverse the given vector.
 * @name REVERSE
 *
 * @note @p s must be one of 1, 2, 3, 4, 8, 16, matching the REVn variants defined above.
 *
 * @param[in] x The vector to be reversed
 * @param[in] s The size of the vector
 *
 * @return The reversed vector
 * @{
 */
#define REVERSE_STR(x, s) REV##s((x))
#define REVERSE(x, s) REVERSE_STR(x, s)
/** @} */ // end of group REVERSE
713
/** Circular-right-shift (rotate-right) the vector of size s by the amount of n.
 * @name ROTs_n
 *
 * @param[in] x The vector to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROT1_0(x) ((x))

#define ROT2_0(x) ((x))
#define ROT2_1(x) ((x).s10)

#define ROT3_0(x) ((x))
#define ROT3_1(x) ((x).s201)
#define ROT3_2(x) ((x).s120)

#define ROT4_0(x) ((x))
#define ROT4_1(x) ((x).s3012)
#define ROT4_2(x) ((x).s2301)
#define ROT4_3(x) ((x).s1230)

#define ROT8_0(x) ((x))
#define ROT8_1(x) ((x).s70123456)
#define ROT8_2(x) ((x).s67012345)
#define ROT8_3(x) ((x).s56701234)
#define ROT8_4(x) ((x).s45670123)
#define ROT8_5(x) ((x).s34567012)
#define ROT8_6(x) ((x).s23456701)
#define ROT8_7(x) ((x).s12345670)

#define ROT16_0(x) ((x))
#define ROT16_1(x) ((x).sF0123456789ABCDE)
#define ROT16_2(x) ((x).sEF0123456789ABCD)
#define ROT16_3(x) ((x).sDEF0123456789ABC)
#define ROT16_4(x) ((x).sCDEF0123456789AB)
#define ROT16_5(x) ((x).sBCDEF0123456789A)
#define ROT16_6(x) ((x).sABCDEF0123456789)
#define ROT16_7(x) ((x).s9ABCDEF012345678)
#define ROT16_8(x) ((x).s89ABCDEF01234567)
#define ROT16_9(x) ((x).s789ABCDEF0123456)
#define ROT16_10(x) ((x).s6789ABCDEF012345)
#define ROT16_11(x) ((x).s56789ABCDEF01234)
#define ROT16_12(x) ((x).s456789ABCDEF0123)
#define ROT16_13(x) ((x).s3456789ABCDEF012)
#define ROT16_14(x) ((x).s23456789ABCDEF01)
#define ROT16_15(x) ((x).s123456789ABCDEF0)
/** @} */ // end of group ROTs_n

/** Circular-right-shift (rotate-right) the given vector by the given amount.
 * @name ROTATE
 *
 * @note Requires @p s in {1, 2, 3, 4, 8, 16} and @p n in [0, @p s), matching the ROTs_n variants above.
 *
 * @param[in] x The vector to be shifted
 * @param[in] s The size of the vector
 * @param[in] n The amount to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
/** @} */ // end of group ROTATE
776
/** Creates a vector of size n filled with offset values corresponding to the location of each element.
 * @name V_OFFSn
 *
 * @param[in] dt The data type of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define V_OFFS1(dt) (dt##1)(0)
#define V_OFFS2(dt) (dt##2)(0, 1)
#define V_OFFS3(dt) (dt##3)(0, 1, 2)
#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
/** @} */ // end of group V_OFFSn

/** Create a vector filled with offset values corresponding to the location of each element.
 * @name VEC_OFFS
 *
 * @note @p s must be one of 1, 2, 3, 4, 8, 16, matching the V_OFFSn variants defined above.
 *
 * @param[in] dt The data type of the output vector
 * @param[in] s  The size of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
/** @} */ // end of group VEC_OFFS
805
// Two-level expansion so that 'size' is macro-expanded before pasting (e.g. VLOAD(N0) -> vload4)
#define VLOAD_STR(size) vload##size
#define VLOAD(size) VLOAD_STR(size)

// Number of image pixels covered by a vector of 4, 8 or 16 elements
// (each pixel read yields 4 components, see read_image2d_* below)
#define PIXEL_UNIT4 1
#define PIXEL_UNIT8 2
#define PIXEL_UNIT16 4

/** Utility macro to convert a vector size in pixel unit.
 *
 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
 *
 * @param[in] vec_size Vector size. Only 4,8 and 16 is supported
 *
 * @return The pixel unit (number of pixels)
 * @{
 */
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
825
// Read 1, 2 or 4 consecutive 4-component pixels starting at (x_coord, y_coord).
// NOTE(review): these expansions end with a ';', so they are only safe in statement
// position or as a complete initialiser expression — confirm before using mid-expression.
#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
// Half-precision counterparts, only available when the fp16 extension is enabled
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
835
/** Utility macro to read a 2D OpenCL image object.
 *
 * @note Coordinates are not normalized
 * @note Only the float variants (and, when fp16 is enabled, the half variants) are defined above.
 *
 * @param[in] data_type Data type
 * @param[in] n0        Number of pixel to read. Only 1,2 and 4 is supported
 * @param[in] img       OpenCL image object
 * @param[in] x_coord   The x coordinate for the top-left pixel
 * @param[in] y_coord   The y coordinate for the top-left pixel
 *
 * @return Pixels from the 2D OpenCL image object
 * @{
 */
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)

// Two-level expansion so that 'size' is macro-expanded before pasting (e.g. VSTORE(N0) -> vstore4)
#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)
854
// Width-1 "vector" aliases so that VEC_DATA_TYPE(type, 1) resolves to the scalar type
#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double

// Scalar counterparts of vloadN/vstoreN.
// Arguments are parenthesised so that expression arguments group correctly:
// the previous form *(OFFSET + PTR) mis-bound operators of lower precedence than '+'
// (e.g. vload1(i & 1, p) expanded to *(i & (1 + p))).
#define vload1(OFFSET, PTR) (*((PTR) + (OFFSET)))
#define vstore1(DATA, OFFSET, PTR) *((PTR) + (OFFSET)) = (DATA)
869
/** Extended partial vstore that correctly handles scalar values as well.
 * Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name VSTORE_PARTIAL
 *
 * @note With this macro, the passed data can be both a vector and a scalar
 * @note @p store_size needs to be <= @p size
 * @note Invalid combinations (store_size > size) and store_size == 0 expand to NO_STORE, i.e. nothing is stored
 * eg 1: Valid
 * VSTORE_PARTIAL(16, 15) ...;
 * eg 2: Invalid
 * VSTORE_PARTIAL(4, 7) ...;
 *
 * @param[in] size       The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
 * @{
 */
#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)

// No-op store used for unsupported (size, store_size) combinations
#define NO_STORE(data, offs, ptr) \
    {                             \
    }

// Size == 1 (scalar)
#define vstore_partial_1_0 NO_STORE
#define vstore_partial_1_1 vstore1
#define vstore_partial_1_2 NO_STORE
#define vstore_partial_1_3 NO_STORE
#define vstore_partial_1_4 NO_STORE
#define vstore_partial_1_5 NO_STORE
#define vstore_partial_1_6 NO_STORE
#define vstore_partial_1_7 NO_STORE
#define vstore_partial_1_8 NO_STORE
#define vstore_partial_1_9 NO_STORE
#define vstore_partial_1_10 NO_STORE
#define vstore_partial_1_11 NO_STORE
#define vstore_partial_1_12 NO_STORE
#define vstore_partial_1_13 NO_STORE
#define vstore_partial_1_14 NO_STORE
#define vstore_partial_1_15 NO_STORE
#define vstore_partial_1_16 NO_STORE
// Size == 2
#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE
// Size == 3
#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE
// Size == 4
#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE
// Size == 8
#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE
// Size == 16
#define vstore_partial_16_0 NO_STORE
#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16
1000
/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name vstore_partial_n
 *
 * @note @p DATA needs to be a vector not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * @note Non-native widths are composed from a native-width store (4 or 8) plus a smaller partial store.
 * eg 1: Valid
 * vstore_partial_15(var:float16, 0, 0xabcd);
 * eg 2: Invalid
 * vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset in n
 * @param[in] PTR    The base pointer
 * @{
 */
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

// 5 = 4 + 1
#define vstore_partial_5(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

// 6 = 4 + 2
#define vstore_partial_6(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

// 7 = 4 + 3
#define vstore_partial_7(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

// 9..15 = 8 + a partial store of the upper half
#define vstore_partial_9(DATA, OFFSET, PTR)        \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

#define vstore_partial_13(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of groupd vstore_partial_n
/** @} */ // end of groupd VSTORE_PARTIAL
1077
// Convert built-in functions with _sat modifier are not supported in floating point so we create defines
// without _sat to overcome this issue
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
// NOTE(review): convert_half_sat aliases convert_float, unlike every other
// convert_<type>_sat alias here (cf. convert_half1_sat -> convert_half).
// Looks like a copy-paste slip, but confirm before changing: callers may rely
// on the float-typed result in mixed-precision expressions.
#define convert_half_sat convert_float
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

// Width-1 convert aliases so that CONVERT(x, VEC_DATA_TYPE(type, 1)) resolves to the scalar convert
#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat
1116
// Build a vector type name from an element type and a width (e.g. VEC_DATA_TYPE(float, 4) -> float4).
// Two-level expansion so that macro arguments are expanded before pasting.
#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

// Non-saturating, saturating and saturating-with-rounding conversions to the given type
#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)

// Map a data type to an integer type of the same bit width, suitable as the
// result type of OpenCL select() (note half -> short and float -> int)
#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size

#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)
1143
// Horizontal sum of a vector of size n.
// Each expansion is wrapped in outer parentheses so it composes safely inside larger
// expressions: previously 2 * SUM_REDUCE(x, 2) expanded to 2 * ((x).s0) + ((x).s1).
#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) (((x).s0) + ((x).s1))
#define sum_reduce_3(x) (sum_reduce_2((x).s01) + ((x).s2))
#define sum_reduce_4(x) (sum_reduce_2((x).s01) + sum_reduce_2((x).s23))
#define sum_reduce_8(x) (sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567))
#define sum_reduce_16(x) (sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF))

#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

// Horizontal max of a vector of size n (the max() call already isolates each operand)
#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
1163
// Kernel argument lists for an N-dimensional tensor: a byte pointer, a stride/step
// pair per dimension, and the byte offset of the tensor's first element.
#define VECTOR_DECLARATION(name)     \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_offset_first_element_in_bytes

#define IMAGE_DECLARATION(name)      \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR3D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_offset_first_element_in_bytes

#define TENSOR4D_DECLARATION(name)   \
    __global uchar *name##_ptr,      \
    uint        name##_stride_x, \
    uint        name##_step_x,   \
    uint        name##_stride_y, \
    uint        name##_step_y,   \
    uint        name##_stride_z, \
    uint        name##_step_z,   \
    uint        name##_stride_w, \
    uint        name##_step_w,   \
    uint        name##_offset_first_element_in_bytes
1199
/* Build a Vector positioned at this work-item's data (uses the per-work-item step). */
#define CONVERT_TO_VECTOR_STRUCT(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)

/* Build a Vector whose pointer is NOT advanced per work-item (step forced to 0). */
#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)

/* Build an Image positioned at this work-item's data (uses X/Y steps). */
#define CONVERT_TO_IMAGE_STRUCT(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)

/* Build an Image whose pointer is NOT advanced per work-item (X/Y steps forced to 0). */
#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)

/* Collapse a 3D tensor into an Image positioned at this work-item's data;
 * the Z component is folded into the pointer via get_global_id(2) * step_z. */
#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)

/* NOTE(review): unlike the other *_NO_STEP helpers, this one still passes
 * name##_step_z (not 0) — presumably intentional so the Z plane still follows
 * the work-item id; confirm against callers before changing. */
#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)
1217
/* Removed: a second, byte-identical definition of CONVERT_TENSOR3D_TO_IMAGE_STRUCT
 * appeared here. The macro is already defined above; identical redefinition is
 * legal in C but redundant and triggers warnings on some OpenCL compilers. */
1220
/* Build a Tensor3D positioned at this work-item's data (uses X/Y/Z steps). */
#define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z)

/* Build a Tensor3D whose pointer is NOT advanced per work-item (all steps forced to 0). */
#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)

/* Build a Tensor4D positioned at this work-item's data; Z and W are both packed
 * into global id 2 and separated by mod_size (see update_tensor4D_workitem_ptr). */
#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)

/* Build a Tensor4D whose pointer is NOT advanced per work-item (all steps forced to 0). */
#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)

/* Build a Tensor3D that keeps the raw buffer pointer (no per-work-item advance
 * and no first-element offset applied); pair with tensor3D_index2ptr. */
#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                                                                                       \
    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                           name##_stride_z, name##_step_z)
1238
/** Structure to hold Vector information */
typedef struct Vector
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source vector */
    int             stride_x;                      /**< Stride of the vector in X dimension (in bytes) */
} Vector;
1246
/** Structure to hold Image information */
typedef struct Image
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
} Image;
1255
/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
} Tensor3D;
1265
/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int             stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
    int             stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
} Tensor4D;
1276
/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 *
 * @return A Vector object whose ptr addresses this work-item's first element
 */
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
    Vector vec;
    vec.offset_first_element_in_bytes = offset_first_element_in_bytes;
    vec.stride_x                      = stride_x;
    // Skip the leading offset, then advance to this work-item's X position.
    vec.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x;
    return vec;
}
1297
/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 *
 * @return An Image object whose ptr addresses this work-item's first element
 */
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
    Image img;
    img.offset_first_element_in_bytes = offset_first_element_in_bytes;
    img.stride_x                      = stride_x;
    img.stride_y                      = stride_y;
    // Skip the leading offset, then advance to this work-item's (x, y) position.
    img.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
    return img;
}
1321
/** Wrap 3D tensor information into an Image structure, and make the pointer point at this workitem's data.
 *
 * The Z dimension is collapsed into the pointer (no stride_z is stored in the
 * returned Image); only X/Y strides remain addressable afterwards.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes); unused except via step_z
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return An Image object whose ptr addresses this work-item's first element
 */
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Image img;
    img.offset_first_element_in_bytes = offset_first_element_in_bytes;
    img.stride_x                      = stride_x;
    img.stride_y                      = stride_y;
    // Fold the Z plane selection into the base pointer using global id 2.
    img.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return img;
}
1347
/** Wrap 3D tensor information into a Tensor3D structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A Tensor3D object whose ptr addresses this work-item's first element
 */
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D t;
    t.offset_first_element_in_bytes = offset_first_element_in_bytes;
    t.stride_x                      = stride_x;
    t.stride_y                      = stride_y;
    t.stride_z                      = stride_z;
    // Skip the leading offset, then advance to this work-item's (x, y, z) position.
    t.ptr = ptr + offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return t;
}
1374
/** Wrap 3D tensor information into a Tensor3D structure without moving the pointer.
 *
 * Unlike update_tensor3D_workitem_ptr, neither offset_first_element_in_bytes
 * nor the per-work-item steps are applied to ptr; the step parameters are
 * accepted only so the CONVERT_* macros can share one argument list.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        Unused
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        Unused
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        Unused
 *
 * @return A Tensor3D object whose ptr is the raw buffer pointer
 */
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D t;
    t.ptr                           = ptr;
    t.offset_first_element_in_bytes = offset_first_element_in_bytes;
    t.stride_x                      = stride_x;
    t.stride_y                      = stride_y;
    t.stride_z                      = stride_z;
    return t;
}
1400
/** Wrap 4D tensor information into a Tensor4D structure, and make the pointer point at this workitem's data.
 *
 * Z and W are both packed into global id 2: the Z index is
 * get_global_id(2) % mod_size and the W index is get_global_id(2) / mod_size.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] stride_w                      Stride of the tensor in W dimension (in bytes)
 * @param[in] step_w                        stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in] mod_size                      Number of Z planes per W slice (divisor used to split global id 2)
 *
 * @return A Tensor4D object whose ptr addresses this work-item's first element
 */
inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
                                             uint step_w,
                                             uint mod_size)
{
    Tensor4D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z,
        .stride_w                      = stride_w
    };

    // Global id 2 carries both Z and W: split it with mod_size before stepping.
    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
    return tensor;
}
1418
/** Compute the address of an element of a Vector relative to its current pointer.
 *
 * @param[in] vec Vector descriptor
 * @param[in] x   Relative X position (elements)
 *
 * @return Pointer to the requested element
 */
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
    const int byte_offset = x * vec->stride_x;
    return vec->ptr + byte_offset;
}
1428
/** Compute the address of an element of an Image relative to its current pointer.
 *
 * @param[in] img Image descriptor
 * @param[in] x   Relative X position (elements)
 * @param[in] y   Relative Y position (elements)
 *
 * @return Pointer to the requested element
 */
inline __global uchar *offset(const Image *img, int x, int y)
{
    const int byte_offset = x * img->stride_x + y * img->stride_y;
    return img->ptr + byte_offset;
}
1439
/** Compute the address of an element of a Tensor3D relative to its current pointer.
 *
 * @param[in] tensor Tensor3D descriptor
 * @param[in] x      Relative X position (elements)
 * @param[in] y      Relative Y position (elements)
 * @param[in] z      Relative Z position (elements)
 *
 * @return Pointer to the requested element
 */
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
    const int byte_offset = x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
    return tensor->ptr + byte_offset;
}
1451
/** Compute the address of an element of a Tensor4D relative to its current pointer.
 *
 * @param[in] tensor Tensor4D descriptor
 * @param[in] x      Relative X position (elements)
 * @param[in] y      Relative Y position (elements)
 * @param[in] z      Relative Z position (elements)
 * @param[in] w      Relative W position (elements)
 *
 * @return Pointer to the requested element
 */
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
    const int byte_offset = x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
    return tensor->ptr + byte_offset;
}
1464
/** Convert a linear element index into a Tensor3D element address.
 *
 * The linear index is decomposed in row-major order (x fastest, z slowest);
 * offset_first_element_in_bytes is applied, so pair this with a descriptor
 * built by CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR.
 *
 * @param[in] tensor Tensor3D descriptor
 * @param[in] width  Width of the tensor (elements)
 * @param[in] height Height of the tensor (elements)
 * @param[in] depth  Depth of the tensor (elements); unused, kept for interface symmetry
 * @param[in] index  Linear element index
 *
 * @return Pointer to the element at the given linear index
 */
inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
{
    const uint slice_elems = width * height;
    const uint z           = index / slice_elems;
    const uint rem         = index % slice_elems;
    const uint y           = rem / width;
    const uint x           = rem % width;

    return tensor->ptr + tensor->offset_first_element_in_bytes + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
}
1489
1490#endif // _HELPER_H
1491
/** This kernel performs l2 normalization on x-axis
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
 * @note The data size must be passed at compile time using -DDATA_SIZE e.g. -DDATA_SIZE=32
 * @note Processes 16 elements per work-item; the execution window is assumed to
 *       guarantee that vload16/vstore16 stay in bounds — confirm against the host-side kernel setup.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  sum_ptr                           Pointer to the sum-of-squares tensor. Supported data types: same as @p src_ptr
 * @param[in]  sum_stride_x                      Stride of the sum tensor in X dimension (in bytes)
 * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_stride_y                      Stride of the sum tensor in Y dimension (in bytes)
 * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  epsilon                           Lower bound applied to the sum before rsqrt (guards divide-by-zero)
 */
__kernel void l2_normalize_x(
    IMAGE_DECLARATION(src),
    IMAGE_DECLARATION(sum),
    IMAGE_DECLARATION(dst),
    DATA_TYPE epsilon)
{
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image sum = CONVERT_TO_IMAGE_STRUCT(sum);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // One scalar sum-of-squares per row: clamp with epsilon, then take 1/sqrt.
    const DATA_TYPE inv_norm = rsqrt(fmax(((__global DATA_TYPE *)sum.ptr)[0], epsilon));

    VEC_DATA_TYPE(DATA_TYPE, 16)
    data = vload16(0, (__global DATA_TYPE *)src.ptr);

    vstore16(data * (VEC_DATA_TYPE(DATA_TYPE, 16))inv_norm, 0, (__global DATA_TYPE *)dst.ptr);
}
1534
/** This kernel performs l2 normalization on y-axis.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
 * @note The data size must be passed at compile time using -DDATA_SIZE e.g. -DDATA_SIZE=32
 * @note Processes 16 elements per work-item; the execution window is assumed to
 *       guarantee that vload16/vstore16 stay in bounds — confirm against the host-side kernel setup.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  sum_ptr                           Pointer to the sum-of-squares tensor. Supported data types: same as @p src_ptr
 * @param[in]  sum_stride_x                      Stride of the sum tensor in X dimension (in bytes)
 * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_stride_y                      Stride of the sum tensor in Y dimension (in bytes)
 * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  epsilon                           Lower bound applied to each sum before rsqrt (guards divide-by-zero)
 */
__kernel void l2_normalize_y(
    IMAGE_DECLARATION(src),
    IMAGE_DECLARATION(sum),
    IMAGE_DECLARATION(dst),
    DATA_TYPE epsilon)
{
    Image src = CONVERT_TO_IMAGE_STRUCT(src);
    Image sum = CONVERT_TO_IMAGE_STRUCT(sum);
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    VEC_DATA_TYPE(DATA_TYPE, 16)
    data = vload16(0, (__global DATA_TYPE *)src.ptr);
    // Per-column sums of squares: one normalization factor per lane.
    VEC_DATA_TYPE(DATA_TYPE, 16)
    sq_sums = vload16(0, (__global DATA_TYPE *)sum.ptr);

    vstore16(data * rsqrt(fmax(sq_sums, epsilon)), 0, (__global DATA_TYPE *)dst.ptr);
}
/** This kernel performs l2 normalization on z-axis.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
 * @note The data size must be passed at compile time using -DDATA_SIZE e.g. -DDATA_SIZE=32
 * @note Processes 16 elements per work-item; the execution window is assumed to
 *       guarantee that vload16/vstore16 stay in bounds — confirm against the host-side kernel setup.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F16/F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  sum_ptr                           Pointer to the sum-of-squares tensor. Supported data types: same as @p src_ptr
 * @param[in]  sum_stride_x                      Stride of the sum tensor in X dimension (in bytes)
 * @param[in]  sum_step_x                        sum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  sum_stride_y                      Stride of the sum tensor in Y dimension (in bytes)
 * @param[in]  sum_step_y                        sum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  sum_stride_z                      Stride of the sum tensor in Z dimension (in bytes)
 * @param[in]  sum_step_z                        sum_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  sum_offset_first_element_in_bytes The offset of the first element in the sum tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  epsilon                           Lower bound applied to each sum before rsqrt (guards divide-by-zero)
 */
__kernel void l2_normalize_z(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(sum),
    TENSOR3D_DECLARATION(dst),
    DATA_TYPE epsilon)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D sum = CONVERT_TO_TENSOR3D_STRUCT(sum);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    VEC_DATA_TYPE(DATA_TYPE, 16)
    data = vload16(0, (__global DATA_TYPE *)src.ptr);
    // Per-lane sums of squares along Z: one normalization factor per element.
    VEC_DATA_TYPE(DATA_TYPE, 16)
    sq_sums = vload16(0, (__global DATA_TYPE *)sum.ptr);

    vstore16(data * rsqrt(fmax(sq_sums, epsilon)), 0, (__global DATA_TYPE *)dst.ptr);
}
1631
1632)"