1R"( 2 3/* 4 * Copyright (c) 2020 Arm Limited. 5 * 6 * SPDX-License-Identifier: MIT 7 * 8 * Permission is hereby granted, free of charge, to any person obtaining a copy 9 * of this software and associated documentation files (the "Software"), to 10 * deal in the Software without restriction, including without limitation the 11 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 12 * sell copies of the Software, and to permit persons to whom the Software is 13 * furnished to do so, subject to the following conditions: 14 * 15 * The above copyright notice and this permission notice shall be included in all 16 * copies or substantial portions of the Software. 17 * 18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 * SOFTWARE. 25 */ 26/* 27 * Copyright (c) 2019-2020 Arm Limited. 28 * 29 * SPDX-License-Identifier: MIT 30 * 31 * Permission is hereby granted, free of charge, to any person obtaining a copy 32 * of this software and associated documentation files (the "Software"), to 33 * deal in the Software without restriction, including without limitation the 34 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 35 * sell copies of the Software, and to permit persons to whom the Software is 36 * furnished to do so, subject to the following conditions: 37 * 38 * The above copyright notice and this permission notice shall be included in all 39 * copies or substantial portions of the Software. 40 * 41 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 * SOFTWARE. 48 */ 49/* 50 * Copyright (c) 2019-2020 Arm Limited. 51 * 52 * SPDX-License-Identifier: MIT 53 * 54 * Permission is hereby granted, free of charge, to any person obtaining a copy 55 * of this software and associated documentation files (the "Software"), to 56 * deal in the Software without restriction, including without limitation the 57 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 58 * sell copies of the Software, and to permit persons to whom the Software is 59 * furnished to do so, subject to the following conditions: 60 * 61 * The above copyright notice and this permission notice shall be included in all 62 * copies or substantial portions of the Software. 63 * 64 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 65 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 66 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
#ifndef ARM_COMPUTE_HELPER_H
#define ARM_COMPUTE_HELPER_H

/** Store the 0th to (n-1)th rows of the given variables
 * @name STORE_ROW_n
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n
/** Convert and store the 0th to (n-1)th rows of the given variables
 * @name CONVERT_STORE_ROW_n
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));

/** @} */ // end of group CONVERT_STORE_ROW_n
/** Store a block of the given size M0xN0
 * @name STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
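// Illustrative expansion (not part of the original header): with a hypothetical
// destination pointer dst_addr and row stride dst_stride_y,
//   STORE_BLOCK(2, 4, float, c, dst_addr, dst_stride_y, zin)
// expands to:
//   vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zin0));
//   vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zin1));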
/** Convert and store a block of the given size M0xN0
 * @name CONVERT_STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK
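// Illustrative expansion (not part of the original header): with int4 accumulators
// c0, c1 and hypothetical names dst_addr / dst_stride_y,
//   CONVERT_STORE_BLOCK(2, 4, uchar, c, dst_addr, dst_stride_y, zin)
// saturate-converts each row to uchar4 before storing:
//   vstore4(convert_uchar4_sat(c0), 0, (__global uchar *)(dst_addr + 0 * dst_stride_y + zin0));
//   vstore4(convert_uchar4_sat(c1), 0, (__global uchar *)(dst_addr + 1 * dst_stride_y + zin1));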
/** Partially store the 0th to (n-1)th rows of the given variables
 * @name STORE_ROW_PARTIAL_n
 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
 *
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n
Supported: 1, 2, 3, 4, 8, 16 459 * @param[in] DATA_TYPE The data type of the vectors 460 * @param[in] BASENAME The basename of the variables 461 * @param[in] PTR The base pointer 462 * @param[in] STRIDE_Y The stride value in y-axis direction 463 * @param[in] Z The offset in z-axis direction 464 * @{ 465 */ 466#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 467#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 468/** Store a block that can be partial in both x and y dimensions 469 * 470 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 471 * 472 * The data to store is expected to have consecutive names for each row. 473 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 474 * The Z offset is expected to have consecutive names. 475 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 476 * 477 * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16 478 * @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16 479 * @param[in] DATA_TYPE The data type of the vectors 480 * @param[in] BASENAME The basename of the variables 481 * @param[in] PTR The base pointer 482 * @param[in] STRIDE_Y The stride value in y-axis direction 483 * @param[in] Z The offset in z-axis direction 484 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0) 485 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0) 486 * @param[in] PARTIAL_COND_Y Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0. 487 * @param[in] PARTIAL_COND_X Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0. 488 */ 489#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \ 490 if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y)) \ 491 { \ 492 STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 493 } \ 494 else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X)) \ 495 { \ 496 STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 497 } \ 498 else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X)) \ 499 { \ 500 STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 501 } \ 502 else \ 503 { \ 504 STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 505 } 506/** Store a block that can only be partial in x but not y. 507 * 508 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 509 * 510 * The data to store is expected to have consecutive names for each row. 511 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 512 * The Z offset is expected to have consecutive names. 513 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 514 * 515 * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16 516 * @param[in] N0 The size of each vector, for non-partial blocks. 
/** Store a block that can only be partial in x but not y.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** Store a block that can only be partial in y but not x.
 *
 * @note in case @p N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 */
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
    if(!(PARTIAL_COND_Y)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** @} */ // end of group STORE_BLOCK_PARTIAL
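// Illustrative use (not part of the original header): store a 4x8 block of float
// accumulators c0..c3 where the last thread in x has only 5 valid columns left;
// dst_addr, dst_stride_y and at_x_border are hypothetical names:
//   STORE_BLOCK_PARTIAL_IN_X(4, 8, float, c, dst_addr, dst_stride_y, zin, 5, at_x_border);
// When at_x_border is false this stores full vstore8 rows; when true, each row is
// stored with vstore_partial_5 (a vstore4 followed by a vstore1).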
#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

/** Boundary-aware GEMM block store
 * @name STORE_BLOCK_BOUNDARY_AWARE
 * This macro assumes the following schemes to achieve boundary-awareness:
 *  - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 *  - Non-Overlapping (normal) load from rhs tensor. This implies rhs can have paddings.
 *  - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
 * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
 *
 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial
 * blocks **at the end**.
 * Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/
 * "boundary blocks" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and their various parameters:
 *
 *  *--x-->                         x == 0                         x == 1
 *  |                  |<------------------------------N-------------------------->|
 *  y       |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
 *  |       -------------#############################################################
 *  *       |          | |...............................|...........................|
 * y == 0   | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
 *          |          | |...............................|...........................|
 *  M       --           #############################################################
 *          |          | |                               |...........................|
 * y == 1   |         M0 |      Non-boundary block       |....Boundary block in x....|
 *          |          | |                               |...........................|
 *          |------------#############################################################
 *
 * Then @p PARTIAL_STORE_M0 = M % M0 and @p PARTIAL_STORE_N0 = N % N0
 *
 * @note in case @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
 * and selects the corresponding store methods such that the boundary detection logic is only added when needed.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 * @{
 */
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)

#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)

#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)

#else // PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 > 0
// Case4: Partial blocks in both x and y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)

#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0

#endif // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE

#if defined(PARTIAL_STORE_M0)
/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
 * @name COMPUTE_M0_START_ROW
 * If there are any partial blocks in y dimension, they are placed at the beginning of the rows.
 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
 * blocks in the y dimension to avoid any padding.
 * E.g., M0=4, PARTIAL_STORE_M0=1:
 *                   | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
 *  block 0 (partial)| start row = 0   | start row = 0
 *  block 1 (full)   | start row = 4   | start row = 1
 *  block 2 (full)   | start row = 8   | start row = 5
 *
 * @param[in] y                Global id of current block in y.
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @{
 */
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
#else // defined(PARTIAL_STORE_M0)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(y * M0))
#endif // defined(PARTIAL_STORE_M0)
/** @} */ // end of group COMPUTE_M0_START_ROW
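// Illustrative values (not part of the original header): with M0=4 and
// PARTIAL_STORE_M0=1, the row shift is (M0 - PARTIAL_STORE_M0) % M0 = 3, so:
//   COMPUTE_M0_START_ROW(0, 4, 1) == max(0, 0 - 3) == 0 // partial block
//   COMPUTE_M0_START_ROW(1, 4, 1) == max(0, 4 - 3) == 1
//   COMPUTE_M0_START_ROW(2, 4, 1) == max(0, 8 - 3) == 5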
/** Store a vector that can only be partial in x.
 *
 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to end in a 0.
 * E.g., for basename=c, the expected name is c0.
 *
 * @param[in] basename  The name of the variable without trailing 0
 * @param[in] data_type The data type of the vector
 * @param[in] ptr       The base pointer
 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
 * @param[in] cond      Condition to select either @p vec_size or @p leftover
 * @{
 */
#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
/** @} */ // end of group STORE_VECTOR_SELECT
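// Illustrative use (not part of the original header): store an 8-wide result
// vector c0 of which only 3 lanes are valid for the last thread in x; dst_addr
// and at_x_border are hypothetical names:
//   STORE_VECTOR_SELECT(c, float, dst_addr, 8, 3, at_x_border);
// This expands to vstore8(c0, ...) when at_x_border is false, and to a single
// vstore3 of the low lanes (c0.s012) when it is true.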
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)

#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
#pragma OPENCL EXTENSION cl_arm_printf : enable
#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)

#define GPU_ARCH_MIDGARD 0x100
#define GPU_ARCH_BIFROST 0x200

/** Concatenate two inputs.
 *
 * @param[in] a The first input to be concatenated
 * @param[in] b The second input to be concatenated
 *
 * @return The concatenated output
 */
#define CONCAT(a, b) a##b

/** Expand the given vector
 *
 * @param[in] x The vector to be expanded
 *
 * @return The expanded output
 */
#define EXPAND(x) x

/** Clamp the given value between an upper and lower bound.
 *
 * @param[in] x       The value to be clamped
 * @param[in] min_val The lower bound
 * @param[in] max_val The upper bound
 *
 * @return The clamped value.
 */
#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)

/** REVn reverses the given vector whose size is n.
 * @name REVn
 *
 * @param[in] x The vector to be reversed
 *
 * @return The reversed vector
 * @{
 */
#define REV1(x) ((x))
#define REV2(x) ((x).s10)
#define REV3(x) ((x).s210)
#define REV4(x) ((x).s3210)
#define REV8(x) ((x).s76543210)
#define REV16(x) ((x).sFEDCBA9876543210)
/** @} */ // end of group REVn

/** Reverse the given vector.
 * @name REVERSE
 *
 * @param[in] x The vector to be reversed
 * @param[in] s The size of the vector
 *
 * @return The reversed vector
 * @{
 */
#define REVERSE_STR(x, s) REV##s((x))
#define REVERSE(x, s) REVERSE_STR(x, s)
/** @} */ // end of group REVERSE

/** Circular-right-shift (rotate-right) the vector of size s by the amount of n.
 * @name ROTs_n
 *
 * @param[in] x The vector to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROT1_0(x) ((x))

#define ROT2_0(x) ((x))
#define ROT2_1(x) ((x).s10)

#define ROT3_0(x) ((x))
#define ROT3_1(x) ((x).s201)
#define ROT3_2(x) ((x).s120)

#define ROT4_0(x) ((x))
#define ROT4_1(x) ((x).s3012)
#define ROT4_2(x) ((x).s2301)
#define ROT4_3(x) ((x).s1230)

#define ROT8_0(x) ((x))
#define ROT8_1(x) ((x).s70123456)
#define ROT8_2(x) ((x).s67012345)
#define ROT8_3(x) ((x).s56701234)
#define ROT8_4(x) ((x).s45670123)
#define ROT8_5(x) ((x).s34567012)
#define ROT8_6(x) ((x).s23456701)
#define ROT8_7(x) ((x).s12345670)

#define ROT16_0(x) ((x))
#define ROT16_1(x) ((x).sF0123456789ABCDE)
#define ROT16_2(x) ((x).sEF0123456789ABCD)
#define ROT16_3(x) ((x).sDEF0123456789ABC)
#define ROT16_4(x) ((x).sCDEF0123456789AB)
#define ROT16_5(x) ((x).sBCDEF0123456789A)
#define ROT16_6(x) ((x).sABCDEF0123456789)
#define ROT16_7(x) ((x).s9ABCDEF012345678)
#define ROT16_8(x) ((x).s89ABCDEF01234567)
#define ROT16_9(x) ((x).s789ABCDEF0123456)
#define ROT16_10(x) ((x).s6789ABCDEF012345)
#define ROT16_11(x) ((x).s56789ABCDEF01234)
#define ROT16_12(x) ((x).s456789ABCDEF0123)
#define ROT16_13(x) ((x).s3456789ABCDEF012)
#define ROT16_14(x) ((x).s23456789ABCDEF01)
#define ROT16_15(x) ((x).s123456789ABCDEF0)
/** @} */ // end of group ROTs_n

/** Circular-right-shift (rotate-right) the given vector by the given amount.
 * @name ROTATE
 *
 * @param[in] x The vector to be shifted
 * @param[in] s The size of the vector
 * @param[in] n The amount to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
/** @} */ // end of group ROTATE

/** Creates a vector of size n filled with offset values corresponding to the location of each element.
 * @name V_OFFSn
 *
 * @param[in] dt The data type of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define V_OFFS1(dt) (dt##1)(0)
#define V_OFFS2(dt) (dt##2)(0, 1)
#define V_OFFS3(dt) (dt##3)(0, 1, 2)
#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
/** @} */ // end of group V_OFFSn

/** Create a vector filled with offset values corresponding to the location of each element.
 * @name VEC_OFFS
 *
 * @param[in] dt The data type of the output vector
 * @param[in] s  The size of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
/** @} */ // end of group VEC_OFFS
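// Illustrative expansions (not part of the original header), for a 4-wide vector v:
//   VEC_OFFS(int, 4) -> (int4)(0, 1, 2, 3)
//   REVERSE(v, 4)    -> ((v).s3210)
//   ROTATE(v, 4, 1)  -> ((v).s3012) // rotate right by one element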
#define VLOAD_STR(size) vload##size
#define VLOAD(size) VLOAD_STR(size)

#define PIXEL_UNIT4 1
#define PIXEL_UNIT8 2
#define PIXEL_UNIT16 4

/** Utility macro to convert a vector size in pixel unit.
 *
 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
 *
 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
 *
 * @return The pixel unit (number of pixels)
 * @{
 */
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT

#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

/** Utility macro to read a 2D OpenCL image object.
 *
 * @note Coordinates are not normalized
 *
 * @param[in] data_type Data type
 * @param[in] n0        Number of pixels to read. Only 1, 2 and 4 are supported
 * @param[in] img       OpenCL image object
 * @param[in] x_coord   The x coordinate for the top-left pixel
 * @param[in] y_coord   The y coordinate for the top-left pixel
 *
 * @return Pixels from the 2D OpenCL image object
 * @{
 */
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
/** @} */ // end of group READ_IMAGE2D
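// Illustrative use (not part of the original header): read 8 consecutive floats
// (2 RGBA texels) starting at texel (x, y) of a hypothetical image object img:
//   float8 vals = READ_IMAGE2D(float, 2, img, x, y);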
#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)

#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double

#define vload1(OFFSET, PTR) *(OFFSET + PTR)
#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA

/** Extended partial vstore that correctly handles scalar values as well.
 * Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name VSTORE_PARTIAL
 *
 * @note With this macro, the passed data can be both a vector and a scalar
 * @note @p store_size needs to be <= @p size
 * eg 1: Valid
 *   VSTORE_PARTIAL(16, 15) ...;
 * eg 2: Invalid
 *   VSTORE_PARTIAL(4, 7) ...;
 *
 * @param[in] size       The width of @p DATA. Supported values: 1 (scalar), 2, 3, 4, 8, 16
 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size
 * @{
 */
#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size
#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size)

#define NO_STORE(data, offs, ptr) \
    {                             \
    }

// Size == 1 (scalar)
#define vstore_partial_1_0 NO_STORE
#define vstore_partial_1_1 vstore1
#define vstore_partial_1_2 NO_STORE
#define vstore_partial_1_3 NO_STORE
#define vstore_partial_1_4 NO_STORE
#define vstore_partial_1_5 NO_STORE
#define vstore_partial_1_6 NO_STORE
#define vstore_partial_1_7 NO_STORE
#define vstore_partial_1_8 NO_STORE
#define vstore_partial_1_9 NO_STORE
#define vstore_partial_1_10 NO_STORE
#define vstore_partial_1_11 NO_STORE
#define vstore_partial_1_12 NO_STORE
#define vstore_partial_1_13 NO_STORE
#define vstore_partial_1_14 NO_STORE
#define vstore_partial_1_15 NO_STORE
#define vstore_partial_1_16 NO_STORE
// Size == 2
#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE
// Size == 3
#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE
// Size == 4
#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE
// Size == 8
#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE
// Size == 16
#define vstore_partial_16_0 NO_STORE
#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16

/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name vstore_partial_n
 *
 * @note @p DATA needs to be a vector not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * eg 1: Valid
 *   vstore_partial_15(var:float16, 0, 0xabcd);
 * eg 2: Invalid
 *   vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset in n
 * @param[in] PTR    The base pointer
 * @{
 */
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

#define vstore_partial_5(DATA, OFFSET, PTR) \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR) \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR) \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

#define vstore_partial_9(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

#define vstore_partial_13(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
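// Illustrative expansion (not part of the original header): storing the lower 7
// lanes of a float8 value v at a hypothetical pointer ptr:
//   VSTORE_PARTIAL(8, 7)(v, 0, ptr);
// resolves to vstore_partial_7, i.e. a vstore4 of v.s0123 followed by a vstore3
// of v.s456 at ptr + 4.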
// The convert_* built-in functions with the _sat modifier are not supported for
// floating point types, so we create defines without _sat to overcome this issue
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_float
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat

#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)

#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size

#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)
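// Illustrative use (not part of the original header): select() needs an integer
// mask type of matching width; SELECT_VEC_DATA_TYPE maps each data type to it,
// e.g. SELECT_VEC_DATA_TYPE(float, 4) -> int4:
//   SELECT_VEC_DATA_TYPE(float, 4) mask = a > b; // int4, lanes are 0 or -1
//   float4 r = select(x, y, mask);               // per-lane choice of x or y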
#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) ((x).s0) + ((x).s1)
#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2)
#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23)
#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567)
#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF)

#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
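// Illustrative use (not part of the original header): horizontal reductions over
// a float4 accumulator acc:
//   float total = SUM_REDUCE(acc, 4); // pairwise sum tree over all four lanes
//   float peak  = MAX_REDUCE(acc, 4); // pairwise max tree over all four lanes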
1269 update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \ 1270 name##_stride_z, name##_step_z) 1271 1272#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \ 1273 update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0) 1274 1275#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size) \ 1276 update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \ 1277 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size) 1278 1279#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \ 1280 update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size) 1281 1282#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name) \ 1283 tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \ 1284 name##_stride_z, name##_step_z) 1285 1286/** Structure to hold Vector information */ 1287typedef struct Vector 1288{ 1289 __global uchar *ptr; /**< Pointer to the starting position of the buffer */ 1290 int offset_first_element_in_bytes; /**< The offset of the first element in the source vector */ 1291 int stride_x; /**< Stride of the vector in X dimension (in bytes) */ 1292} Vector; 1293 1294/** Structure to hold Image information */ 1295typedef struct Image 1296{ 1297 __global uchar *ptr; /**< Pointer to the starting position of the buffer */ 1298 int offset_first_element_in_bytes; /**< The offset of the first element in the source image */ 1299 int stride_x; /**< Stride of the image in X dimension (in bytes) */ 1300 int stride_y; /**< Stride of the image in Y dimension (in bytes) */ 1301} Image; 1302 1303/** Structure to hold 3D tensor information */ 1304typedef struct Tensor3D 1305{ 1306 __global uchar *ptr; /**< Pointer to the starting position of the buffer */ 1307 int offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */ 1308 int stride_x; /**< Stride of the tensor in X dimension (in bytes) */ 1309 int stride_y; /**< Stride of the tensor in Y dimension (in bytes) */ 1310 int stride_z; /**< Stride of the tensor in Z dimension (in bytes) */ 1311} Tensor3D; 1312 1313/** Structure to hold 4D tensor information */ 1314typedef struct Tensor4D 1315{ 1316 __global uchar *ptr; /**< Pointer to the starting position of the buffer */ 1317 int offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */ 1318 int stride_x; /**< Stride of the tensor in X dimension (in bytes) */ 1319 int stride_y; /**< Stride of the tensor in Y dimension (in bytes) */ 1320 int stride_z; /**< Stride of the tensor in Z dimension (in bytes) */ 1321 int stride_w; /**< Stride of the tensor in W dimension (in bytes) */ 1322} Tensor4D; 1323 1324/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
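 *
 * A minimal usage sketch (illustrative only: the kernel name example_scale and the float payload
 * are assumptions, not part of this header):
 * @code
 * __kernel void example_scale(VECTOR_DECLARATION(src))
 * {
 *     Vector src = CONVERT_TO_VECTOR_STRUCT(src); // src.ptr now points at this workitem's element
 *     *((__global float *)src.ptr) *= 2.0f;       // assumes the buffer holds float data
 * }
 * @endcode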
1325 * 1326 * @param[in] ptr Pointer to the starting position of the buffer 1327 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector 1328 * @param[in] stride_x Stride of the vector in X dimension (in bytes) 1329 * @param[in] step_x stride_x * number of elements along X processed per workitem (in bytes) 1330 * 1331 * @return A vector object 1332 */ 1333inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x) 1334{ 1335 Vector vector = 1336 { 1337 .ptr = ptr, 1338 .offset_first_element_in_bytes = offset_first_element_in_bytes, 1339 .stride_x = stride_x, 1340 }; 1341 vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x; 1342 return vector; 1343} 1344 1345/** Wrap image information into an Image structure, and make the pointer point at this workitem's data. 1346 * 1347 * @param[in] ptr Pointer to the starting position of the buffer 1348 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image 1349 * @param[in] stride_x Stride of the image in X dimension (in bytes) 1350 * @param[in] step_x stride_x * number of elements along X processed per workitem (in bytes) 1351 * @param[in] stride_y Stride of the image in Y dimension (in bytes) 1352 * @param[in] step_y stride_y * number of elements along Y processed per workitem (in bytes) 1353 * 1354 * @return An image object 1355 */ 1356inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y) 1357{ 1358 Image img = 1359 { 1360 .ptr = ptr, 1361 .offset_first_element_in_bytes = offset_first_element_in_bytes, 1362 .stride_x = stride_x, 1363 .stride_y = stride_y 1364 }; 1365 img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y; 1366 return img; 1367} 1368 1369/** Wrap 3D tensor information into an image structure, and make the pointer point at this workitem's data. 1370 * 1371 * @param[in] ptr Pointer to the starting position of the buffer 1372 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image 1373 * @param[in] stride_x Stride of the image in X dimension (in bytes) 1374 * @param[in] step_x stride_x * number of elements along X processed per workitem (in bytes) 1375 * @param[in] stride_y Stride of the image in Y dimension (in bytes) 1376 * @param[in] step_y stride_y * number of elements along Y processed per workitem (in bytes) 1377 * @param[in] stride_z Stride of the image in Z dimension (in bytes) 1378 * @param[in] step_z stride_z * number of elements along Z processed per workitem (in bytes) 1379 * 1380 * @return An image object 1381 */ 1382inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z) 1383{ 1384 Image img = 1385 { 1386 .ptr = ptr, 1387 .offset_first_element_in_bytes = offset_first_element_in_bytes, 1388 .stride_x = stride_x, 1389 .stride_y = stride_y 1390 }; 1391 img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z; 1392 return img; 1393} 1394 1395/** Wrap 3D tensor information into a tensor structure, and make the pointer point at this workitem's data.
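 *
 * A minimal kernel sketch (illustrative; the kernel name example_memset3d and the float cast are
 * assumptions):
 * @code
 * __kernel void example_memset3d(TENSOR3D_DECLARATION(dst))
 * {
 *     Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst); // one element per workitem in x/y/z
 *     *((__global float *)dst.ptr) = 0.0f;
 * }
 * @endcode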
1396 * 1397 * @param[in] ptr Pointer to the starting position of the buffer 1398 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image 1399 * @param[in] stride_x Stride of the image in X dimension (in bytes) 1400 * @param[in] step_x stride_x * number of elements along X processed per workitem (in bytes) 1401 * @param[in] stride_y Stride of the image in Y dimension (in bytes) 1402 * @param[in] step_y stride_y * number of elements along Y processed per workitem (in bytes) 1403 * @param[in] stride_z Stride of the image in Z dimension (in bytes) 1404 * @param[in] step_z stride_z * number of elements along Z processed per workitem (in bytes) 1405 * 1406 * @return A 3D tensor object 1407 */ 1408inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z) 1409{ 1410 Tensor3D tensor = 1411 { 1412 .ptr = ptr, 1413 .offset_first_element_in_bytes = offset_first_element_in_bytes, 1414 .stride_x = stride_x, 1415 .stride_y = stride_y, 1416 .stride_z = stride_z 1417 }; 1418 tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z; 1419 return tensor; 1420} 1421 1422/** Wrap 3D tensor information into a tensor structure, without moving the pointer to this workitem's data. 1423 * 1424 * @param[in] ptr Pointer to the starting position of the buffer 1425 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image 1426 * @param[in] stride_x Stride of the image in X dimension (in bytes) 1427 * @param[in] step_x stride_x * number of elements along X processed per workitem (in bytes) 1428 * @param[in] stride_y Stride of the image in Y dimension (in bytes) 1429 * @param[in] step_y stride_y * number of elements along Y processed per workitem (in bytes) 1430 * @param[in] stride_z Stride of the image in Z dimension (in bytes) 1431 * @param[in] step_z stride_z * number of elements along Z processed per workitem (in bytes) 1432 * 1433 * @return A 3D tensor object 1434 */ 1435inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z) 1436{ 1437 Tensor3D tensor = 1438 { 1439 .ptr = ptr, 1440 .offset_first_element_in_bytes = offset_first_element_in_bytes, 1441 .stride_x = stride_x, 1442 .stride_y = stride_y, 1443 .stride_z = stride_z 1444 }; 1445 return tensor; 1446} 1447 1448inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w, 1449 uint step_w, 1450 uint mod_size) 1451{ 1452 Tensor4D tensor = 1453 { 1454 .ptr = ptr, 1455 .offset_first_element_in_bytes = offset_first_element_in_bytes, 1456 .stride_x = stride_x, 1457 .stride_y = stride_y, 1458 .stride_z = stride_z, 1459 .stride_w = stride_w 1460 }; 1461 1462 tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w; 1463 return tensor; 1464} 1465 1466/** Get the pointer position of a Vector 1467 * 1468 * @param[in] vec Pointer to the starting position of the buffer 1469 * @param[in] x Relative X position 1470 */ 1471inline __global const uchar *vector_offset(const Vector *vec, int x) 1472{ 1473 return vec->ptr + x * vec->stride_x; 1474} 1475 1476/** Get the
pointer position of an Image 1477 * 1478 * @param[in] img Pointer to the starting position of the buffer 1479 * @param[in] x Relative X position 1480 * @param[in] y Relative Y position 1481 */ 1482inline __global uchar *offset(const Image *img, int x, int y) 1483{ 1484 return img->ptr + x * img->stride_x + y * img->stride_y; 1485} 1486 1487/** Get the pointer position of a Tensor3D 1488 * 1489 * @param[in] tensor Pointer to the starting position of the buffer 1490 * @param[in] x Relative X position 1491 * @param[in] y Relative Y position 1492 * @param[in] z Relative Z position 1493 */ 1494inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z) 1495{ 1496 return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z; 1497} 1498 1499/** Get the pointer position of a Tensor4D 1500 * 1501 * @param[in] tensor Pointer to the starting position of the buffer 1502 * @param[in] x Relative X position 1503 * @param[in] y Relative Y position 1504 * @param[in] z Relative Z position 1505 * @param[in] w Relative W position 1506 */ 1507inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w) 1508{ 1509 return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w; 1510} 1511 1512/** Get the offset for a given linear index of a Tensor3D 1513 * 1514 * @param[in] tensor Pointer to the starting position of the buffer 1515 * @param[in] width Width of the input tensor 1516 * @param[in] height Height of the input tensor 1517 * @param[in] depth Depth of the input tensor 1518 * @param[in] index Linear index 1519 */ 1520inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index) 1521{ 1522 uint num_elements = width * height; 1523 1524 const uint z = index / num_elements; 1525 1526 index %= num_elements; 1527 1528 const uint y = index / width; 1529 1530 index %= width; 1531 1532 const uint x = index; 1533 1534 return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes; 1535} 1536 1537#endif // ARM_COMPUTE_HELPER_H 1538 1539#if GPU_ARCH == GPU_ARCH_BIFROST 1540#define MLA(a, b, c) (fma(c, b, a)) 1541#else // GPU_ARCH == GPU_ARCH_BIFROST 1542#define MLA(a, b, c) ((b) * (c) + (a)) 1543#endif // GPU_ARCH == GPU_ARCH_BIFROST 1544 1545// Hard-Swish 1546#define hard_swish_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x * ((min(max((x + (DATA_TYPE)3.0), (DATA_TYPE)0.0), (DATA_TYPE)6.0)) * (DATA_TYPE)0.166666667)) 1547 1548// Logistic Activation 1549#define logistic_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((DATA_TYPE)1.0 / ((DATA_TYPE)1.0 + exp(-x))) 1550 1551// Hyperbolic Tangent Activation 1552#define tanh_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((DATA_TYPE)A_VAL * tanh((DATA_TYPE)B_VAL * x)) 1553 1554// RELU Activation 1555#define relu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (max((DATA_TYPE)0.0, x)) 1556 1557// Bounded RELU Activation 1558#define brelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (min((DATA_TYPE)A_VAL, max((DATA_TYPE)0.0, x))) 1559 1560// Lower Upper Bounded RELU Activation 1561#define lu_brelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (min(max(x, (DATA_TYPE)B_VAL), (DATA_TYPE)A_VAL)) 1562 1563// Leaky RELU Activation 1564#define lrelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ((min(x, (DATA_TYPE)0.0) * (DATA_TYPE)A_VAL) + max(x, (DATA_TYPE)0.0)) 1566 1566// Soft RELU Activation 1567#define srelu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL)
(log((DATA_TYPE)1.0 + exp(x))) 1568 1569// ELU Activation 1570#define elu_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (select(((DATA_TYPE)A_VAL * (exp(x) - (DATA_TYPE)1.0)), x, (SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE))isgreaterequal(x, (DATA_TYPE)0.0))) 1571 1572// Absolute Activation 1573#define abs_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (fabs(x)) 1574 1575// Square Activation 1576#define square_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x * x) 1577 1578// Square-root Activation 1579#define sqrt_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (sqrt(x)) 1580 1581// Linear Activation 1582#define linear_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (MLA((DATA_TYPE)B_VAL, (DATA_TYPE)A_VAL, x)) 1583 1584// Identity Activation 1585#define identity_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) (x) 1586 1587#define ACT_OP(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) op##_op(DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) 1588 1589#define ACTIVATION(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) ACT_OP(op, DATA_TYPE, VEC_SIZE, x, A_VAL, B_VAL) 1590/* 1591 * Copyright (c) 2016-2020 Arm Limited. 1592 * 1593 * SPDX-License-Identifier: MIT 1594 * 1595 * Permission is hereby granted, free of charge, to any person obtaining a copy 1596 * of this software and associated documentation files (the "Software"), to 1597 * deal in the Software without restriction, including without limitation the 1598 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 1599 * sell copies of the Software, and to permit persons to whom the Software is 1600 * furnished to do so, subject to the following conditions: 1601 * 1602 * The above copyright notice and this permission notice shall be included in all 1603 * copies or substantial portions of the Software. 1604 * 1605 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1606 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1607 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 1608 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 1609 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 1610 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 1611 * SOFTWARE. 1612 */ 1613#ifndef ARM_COMPUTE_HELPER_H 1614#define ARM_COMPUTE_HELPER_H 1615 1616/* 1617 * Copyright (c) 2020 Arm Limited. 1618 * 1619 * SPDX-License-Identifier: MIT 1620 * 1621 * Permission is hereby granted, free of charge, to any person obtaining a copy 1622 * of this software and associated documentation files (the "Software"), to 1623 * deal in the Software without restriction, including without limitation the 1624 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 1625 * sell copies of the Software, and to permit persons to whom the Software is 1626 * furnished to do so, subject to the following conditions: 1627 * 1628 * The above copyright notice and this permission notice shall be included in all 1629 * copies or substantial portions of the Software. 1630 * 1631 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1632 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1633 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 1634 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 1635 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 1636 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 1637 * SOFTWARE. 1638 */ 1639 1640/** Store the 0 to (n-1)th rows of the given variables 1641 * @name STORE_ROW_n 1642 * 1643 * @param[in] N0 The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16 1644 * @param[in] DATA_TYPE The data type of the vectors 1645 * @param[in] BASENAME The basename of the variables 1646 * @param[in] PTR The base pointer 1647 * @param[in] STRIDE_Y The stride value in y-axis direction 1648 * @param[in] Z The offset in z-axis direction 1649 * @{ 1650 */ 1651#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1652 VSTORE(N0) \ 1653 (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0)); 1654 1655#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1656 STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1657 VSTORE(N0) \ 1658 (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1)); 1659 1660#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1661 STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1662 VSTORE(N0) \ 1663 (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2)); 1664 1665#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1666 STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1667 VSTORE(N0) \ 1668 (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3)); 1669 1670#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1671 STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1672 VSTORE(N0) \ 1673 (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4)); 1674 1675#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1676 STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1677 VSTORE(N0) \ 1678 (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5)); 1679 1680#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1681 STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1682 VSTORE(N0) \ 1683 (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6)); 1684 1685#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1686 STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1687 VSTORE(N0) \ 1688 (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7)); 1689 1690#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1691 STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1692 VSTORE(N0) \ 1693 (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8)); 1694 1695#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1696 STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1697 VSTORE(N0) \ 1698 (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9)); 1699 1700#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1701 STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1702 VSTORE(N0) \ 1703 (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A)); 1704 1705#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1706 STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1707 VSTORE(N0) \ 1708 (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B)); 1709 1710#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1711 STORE_ROW_12(N0, DATA_TYPE, 
BASENAME, PTR, STRIDE_Y, Z) \ 1712 VSTORE(N0) \ 1713 (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C)); 1714 1715#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1716 STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1717 VSTORE(N0) \ 1718 (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D)); 1719 1720#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1721 STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1722 VSTORE(N0) \ 1723 (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E)); 1724 1725#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1726 STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1727 VSTORE(N0) \ 1728 (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F)); 1729/** @} */ // end of group STORE_ROW_n 1730 1731/** Convert and store the 0th to (n-1)th rows of the given variables 1732 * @name CONVERT_STORE_ROW_n 1733 * 1734 * @param[in] N0 The size of the vectors 1735 * @param[in] DATA_TYPE The data type of the vectors 1736 * @param[in] BASENAME The basename of the variables 1737 * @param[in] PTR The base pointer 1738 * @param[in] STRIDE_Y The stride value in y-axis direction 1739 * @param[in] Z The offset in z-axis direction 1740 * @{ 1741 */ 1742#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1743 VSTORE(N0) \ 1744 (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0)); 1745 1746#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1747 CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1748 VSTORE(N0) \ 1749 (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1)); 1750 1751#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1752 CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1753 VSTORE(N0) \ 1754 (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2)); 1755 1756#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1757 CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1758 VSTORE(N0) \ 1759 (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3)); 1760 1761#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1762 CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1763 VSTORE(N0) \ 1764 (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4)); 1765 1766#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1767 CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1768 VSTORE(N0) \ 1769 (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5)); 1770 1771#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1772 CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1773 VSTORE(N0) \ 1774 (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6)); 1775 1776#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1777 CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1778 VSTORE(N0) \ 1779 (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7)); 1780
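/* Editorial note (illustrative expansion, not part of the original header): each
 * CONVERT_STORE_ROW_n chains the (n-1) case and appends one saturating store, e.g.
 * CONVERT_STORE_ROW_2(4, uchar, c, p, stride_y, z) expands, roughly, to:
 *
 *   vstore4(convert_uchar4_sat(c0), 0, (__global uchar *)(p + 0 * stride_y + z0));
 *   vstore4(convert_uchar4_sat(c1), 0, (__global uchar *)(p + 1 * stride_y + z1));
 */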
1781#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1782 CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1783 VSTORE(N0) \ 1784 (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8)); 1785 1786#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1787 CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1788 VSTORE(N0) \ 1789 (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9)); 1790 1791#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1792 CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1793 VSTORE(N0) \ 1794 (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A)); 1795 1796#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1797 CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1798 VSTORE(N0) \ 1799 (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B)); 1800 1801#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1802 CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1803 VSTORE(N0) \ 1804 (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C)); 1805 1806#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1807 CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1808 VSTORE(N0) \ 1809 (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D)); 1810 1811#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1812 CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1813 VSTORE(N0) \ 1814 (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E)); 1815 1816#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1817 CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1818 VSTORE(N0) \ 1819 (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F)); 1820 1821/** @} */ // end of group CONVERT_STORE_ROW_n 1822 1823/** Store a block of the given size M0xN0 1824 * @name STORE_BLOCK 1825 * 1826 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16. 1827 * The data to store is expected to have consecutive names for each row. 1828 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 1829 * The Z offset is expected to have consecutive names. 1830 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
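 *
 * A usage sketch (illustrative; c0..c2, dst_addr, dst_stride_y and zin0..zin2 are assumed to be
 * in scope, following the naming convention above):
 * @code
 * float4 c0 = (float4)1.0f;
 * float4 c1 = (float4)2.0f;
 * float4 c2 = (float4)3.0f;
 * STORE_BLOCK(3, 4, float, c, dst_addr, dst_stride_y, zin); // stores rows c0, c1 and c2
 * @endcode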
1831 * 1832 * @param[in] M0 The number of rows to store 1833 * @param[in] N0 The size of each vector 1834 * @param[in] DATA_TYPE The data type of the vectors 1835 * @param[in] BASENAME The basename of the variables 1836 * @param[in] PTR The base pointer 1837 * @param[in] STRIDE_Y The stride value in y-axis direction 1838 * @param[in] Z The offset in z-axis direction 1839 * @{ 1840 */ 1841#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 1842#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 1843/** @} */ // end of group STORE_BLOCK 1844 1845/** Convert and store a block of the given size M0xN0 1846 * @name CONVERT_STORE_BLOCK 1847 * 1848 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16. 1849 * The data to store is expected to have consecutive names for each row. 1850 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 1851 * The Z offset is expected to have consecutive names. 1852 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 1853 * 1854 * @param[in] M0 The number of rows to store 1855 * @param[in] N0 The size of each vector 1856 * @param[in] DATA_TYPE The data type of the vectors 1857 * @param[in] BASENAME The basename of the variables 1858 * @param[in] PTR The base pointer 1859 * @param[in] STRIDE_Y The stride value in y-axis direction 1860 * @param[in] Z The offset in z-axis direction 1861 * @{ 1862 */ 1863#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 1864#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 1865/** @} */ // end of group CONVERT_STORE_BLOCK 1866 1867/** Partially store the 0 to (n-1)th rows of the given variables 1868 * @name STORE_ROW_PARTIAL_n 1869 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0 1870 * 1871 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 1872 * 1873 * @param[in] N0 The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16 1874 * @param[in] STORE_N0 The **lower** size of the vectors to store. 
Supported: 1-16 and <= @p N0 1875 * @param[in] DATA_TYPE The data type of the vectors 1876 * @param[in] BASENAME The basename of the variables 1877 * @param[in] PTR The base pointer 1878 * @param[in] STRIDE_Y The stride value in y-axis direction 1879 * @param[in] Z The offset in z-axis direction 1880 * @{ 1881 */ 1882#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1883 VSTORE_PARTIAL(N0, STORE_N0) \ 1884 (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0)); 1885 1886#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1887 STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1888 VSTORE_PARTIAL(N0, STORE_N0) \ 1889 (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1)); 1890 1891#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1892 STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1893 VSTORE_PARTIAL(N0, STORE_N0) \ 1894 (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2)); 1895 1896#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1897 STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1898 VSTORE_PARTIAL(N0, STORE_N0) \ 1899 (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3)); 1900 1901#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1902 STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1903 VSTORE_PARTIAL(N0, STORE_N0) \ 1904 (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4)); 1905 1906#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1907 STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1908 VSTORE_PARTIAL(N0, STORE_N0) \ 1909 (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5)); 1910 1911#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1912 STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1913 VSTORE_PARTIAL(N0, STORE_N0) \ 1914 (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6)); 1915 1916#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1917 STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1918 VSTORE_PARTIAL(N0, STORE_N0) \ 1919 (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7)); 1920 1921#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1922 STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1923 VSTORE_PARTIAL(N0, STORE_N0) \ 1924 (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8)); 1925 1926#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1927 STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1928 VSTORE_PARTIAL(N0, STORE_N0) \ 1929 (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9)); 1930 1931#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1932 STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1933 VSTORE_PARTIAL(N0, STORE_N0) \ 1934 (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A)); 1935 1936#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1937 STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1938 VSTORE_PARTIAL(N0, STORE_N0) \ 1939
(BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B)); 1940 1941#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1942 STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1943 VSTORE_PARTIAL(N0, STORE_N0) \ 1944 (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C)); 1945 1946#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1947 STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1948 VSTORE_PARTIAL(N0, STORE_N0) \ 1949 (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D)); 1950 1951#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1952 STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1953 VSTORE_PARTIAL(N0, STORE_N0) \ 1954 (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E)); 1955 1956#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1957 STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \ 1958 VSTORE_PARTIAL(N0, STORE_N0) \ 1959 (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F)); 1960/** @} */ // end of group STORE_ROW_PARTIAL_n 1961 1962/** Partially store a block of the given size STORE_M0xSTORE_N0 1963 * @name STORE_BLOCK_PARTIAL 1964 * 1965 * @note The vector width @p N0 is also required for correct partial storing behaviour. 1966 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 1967 * 1968 * The data to store is expected to have consecutive names for each row. 1969 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2. 1970 * The Z offset is expected to have consecutive names. 1971 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 1972 * 1973 * @param[in] STORE_M0 The number of rows to store. Supported: 1-16 1974 * @param[in] STORE_N0 The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0 1975 * @param[in] N0 The size of each vector. Supported: 1, 2, 3, 4, 8, 16 1976 * @param[in] DATA_TYPE The data type of the vectors 1977 * @param[in] BASENAME The basename of the variables 1978 * @param[in] PTR The base pointer 1979 * @param[in] STRIDE_Y The stride value in y-axis direction 1980 * @param[in] Z The offset in z-axis direction 1981 * @{ 1982 */ 1983#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 1984#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 1985/** Store a block that can be partial in both x and y dimensions 1986 * 1987 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 1988 * 1989 * The data to store is expected to have consecutive names for each row. 1990 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 1991 * The Z offset is expected to have consecutive names. 1992 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 1993 * 1994 * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16 1995 * @param[in] N0 The size of each vector, for non-partial blocks.
Supported: 1, 2, 3, 4, 8, 16 1996 * @param[in] DATA_TYPE The data type of the vectors 1997 * @param[in] BASENAME The basename of the variables 1998 * @param[in] PTR The base pointer 1999 * @param[in] STRIDE_Y The stride value in y-axis direction 2000 * @param[in] Z The offset in z-axis direction 2001 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0) 2002 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0) 2003 * @param[in] PARTIAL_COND_Y Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0. 2004 * @param[in] PARTIAL_COND_X Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0. 2005 */ 2006#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \ 2007 if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y)) \ 2008 { \ 2009 STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2010 } \ 2011 else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X)) \ 2012 { \ 2013 STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2014 } \ 2015 else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X)) \ 2016 { \ 2017 STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2018 } \ 2019 else \ 2020 { \ 2021 STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2022 } 2023/** Store a block that can only be partial in x but not y. 2024 * 2025 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 2026 * 2027 * The data to store is expected to have consecutive names for each row. 2028 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 2029 * The Z offset is expected to have consecutive names. 2030 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 2031 * 2032 * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16 2033 * @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16 2034 * @param[in] DATA_TYPE The data type of the vectors 2035 * @param[in] BASENAME The basename of the variables 2036 * @param[in] PTR The base pointer 2037 * @param[in] STRIDE_Y The stride value in y-axis direction 2038 * @param[in] Z The offset in z-axis direction 2039 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0) 2040 * @param[in] PARTIAL_COND_X Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0. 2041 */ 2042#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \ 2043 if(!(PARTIAL_COND_X)) \ 2044 { \ 2045 STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2046 } \ 2047 else \ 2048 { \ 2049 STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2050 } 2051/** Store a block that can only be partial in y but not x. 2052 * 2053 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 2054 * 2055 * The data to store is expected to have consecutive names for each row. 2056 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 
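 *
 * An illustrative call site (the way cond_y is derived here is an assumption about typical usage,
 * with partial rows placed at the start of the y dimension):
 * @code
 * const bool cond_y = (get_global_id(1) == 0); // only the first block along y is partial
 * STORE_BLOCK_PARTIAL_IN_Y(4, 16, float, c, dst_addr, dst_stride_y, zout, 3, cond_y);
 * @endcode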
2057 * The Z offset is expected to have consecutive names. 2058 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 2059 * 2060 * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16 2061 * @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16 2062 * @param[in] DATA_TYPE The data type of the vectors 2063 * @param[in] BASENAME The basename of the variables 2064 * @param[in] PTR The base pointer 2065 * @param[in] STRIDE_Y The stride value in y-axis direction 2066 * @param[in] Z The offset in z-axis direction 2067 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0) 2068 * @param[in] PARTIAL_COND_Y Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0. 2069 */ 2070#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \ 2071 if(!(PARTIAL_COND_Y)) \ 2072 { \ 2073 STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2074 } \ 2075 else \ 2076 { \ 2077 STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \ 2078 } 2079/** @} */ // end of group STORE_BLOCK_PARTIAL 2080 2081#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) 2082 2083/** Boundary-aware GEMM block store 2084 * @name STORE_BLOCK_BOUNDARY_AWARE 2085 * This macro assumes the following schemes to achieve boundary-awareness: 2086 * - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim. 2087 * - Non-Overlapping (normal) load from rhs tensor. This implies rhs can have padding. 2088 * - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim. 2089 * The macro then ensures that the dst tensor can be stored without any padding in both x and y dim. 2090 * 2091 * In the y dimension, we place the partial blocks **at the beginning** while in the x dimension, we place the partial 2092 * blocks **at the end**. 2093 * Say, the dst tensor is of shape MxN and we have M0 and N0 as the block size, this is how we define "partial blocks"/ 2094 * "boundary block" (we use the 2 terms "partial blocks" and "boundary blocks" interchangeably) and its various parameters: 2095 * 2096 * *--x--> x == 0 x == 1 2097 * | |<------------------------------N-------------------------->| 2098 * y |<--------------N0------------->|<----PARTIAL_STORE_N0----->| 2099 * | -------------############################################################# 2100 * * | | |...............................|...........................| 2101 * y == 0 | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.| 2102 * | | |...............................|...........................| 2103 * M --############################################################# 2104 * | | | |...........................| 2105 * y == 1 | M0 | Non-boundary block |....Boundary block in x....| 2106 * | | | |...........................| 2107 * |------------############################################################# 2108 * 2109 * Then @p PARTIAL_STORE_M0 = M % M0 and @p PARTIAL_STORE_N0 = N % N0 2110 * 2111 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
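 *
 * A worked example (illustrative numbers only): for a dst tensor of shape M=7, N=30 tiled with
 * M0=4, N0=16, we get PARTIAL_STORE_M0 = 7 % 4 = 3 and PARTIAL_STORE_N0 = 30 % 16 = 14, so the
 * boundary-aware paths for both x and y are compiled in (Case4 below).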
2112 * 2113 * It automatically detects whether a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension, 2114 * and selects the corresponding store method so that the boundary-detection logic is only added when needed. 2115 * 2116 * The data to store is expected to have consecutive names for each row. 2117 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2. 2118 * The Z offset is expected to have consecutive names. 2119 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2. 2120 * 2121 * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16 2122 * @param[in] N0 The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16 2123 * @param[in] DATA_TYPE The data type of the vectors 2124 * @param[in] BASENAME The basename of the variables 2125 * @param[in] PTR The base pointer 2126 * @param[in] STRIDE_Y The stride value in y-axis direction 2127 * @param[in] Z The offset in z-axis direction 2128 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0) 2129 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0) 2130 * @param[in] PARTIAL_COND_Y Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0. 2131 * @param[in] PARTIAL_COND_X Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0. 2132 * @{ 2133 */ 2134#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0 2135// Case1: No partial blocks in either x or y 2136#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \ 2137 STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) 2138 2139#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0 2140// Case2: Partial blocks in y 2141#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \ 2142 STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) 2143 2144#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0 2145// Case3: Partial blocks in x 2146#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \ 2147 STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) 2148 2149#else // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0 2150// Case4: Partial blocks in both x and y 2151#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \ 2152 STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) 2153 2154#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0 2155 2156#endif // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) 2157/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE 2158 2159#if defined(PARTIAL_STORE_M0) 2160/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding 2161 * @name COMPUTE_M0_START_ROW 2162 * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
2163 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent 2164 * blocks in the y dimension to avoid any padding. 2165 * EG: M0=4, PARTIAL_STORE_M0=1: 2166 * | Non-overlapping | +M0_ROW_SHIFT (Overlapping) 2167 * block 0 (partial)| start row = 0 | start row = 0 2168 * block 1 (full) | start row = 4 | start row = 1 2169 * block 2 (full) | start row = 8 | start row = 5 2170 * 2171 * @param[in] y Global id of current block in y. 2172 * @param[in] M0 The number of rows to store, for non-partial blocks. Supported: 1-16 2173 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0) 2174 * @{ 2175 */ 2176#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \ 2177 ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0)))) 2178#else // defined(PARTIAL_STORE_M0) 2179#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \ 2180 ((uint)(y * M0)) 2181#endif // defined(PARTIAL_STORE_M0) 2182/** @} */ // end of group COMPUTE_M0_START_ROW 2183 2184/** Store a vector that can only be partial in x. 2185 * 2186 * @note in case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty. 2187 * 2188 * The data to store is expected to end in a 0. 2189 * E.g., for basename=c, the expected name is c0. 2190 * 2191 * @param[in] basename The name of the variable without trailing 0 2192 * @param[in] data_type The data type of the vector 2193 * @param[in] ptr The base pointer 2194 * @param[in] vec_size The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16 2195 * @param[in] leftover The vector size if cond = true. Supported range: [1, @p vec_size) 2196 * @param[in] cond Condition to select either vec_size or leftover 2197 * @{ 2198 */ 2199#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \ 2200 STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond) 2201/** @} */ // end of group STORE_VECTOR_SELECT 2202 2203#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16) 2204#pragma OPENCL EXTENSION cl_khr_fp16 : enable 2205#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16) 2206 2207#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) 2208#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable 2209#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) 2210 2211#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) 2212#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable 2213#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) 2214 2215#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf) 2216#pragma OPENCL EXTENSION cl_arm_printf : enable 2217#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf) 2218 2219#define GPU_ARCH_MIDGARD 0x100 2220#define GPU_ARCH_BIFROST 0x200 2221 2222/** Concatenate two inputs.
2223 * 2224 * @param[in] a The first input to be concatenated 2225 * @param[in] b The second input to be concatenated 2226 * 2227 * @return The concatenated output 2228 */ 2229#define CONCAT(a, b) a##b 2230 2231/** Expand the given vector 2232 * 2233 * @param[in] x The vector to be expanded 2234 * 2235 * @return The expanded output 2236 */ 2237#define EXPAND(x) x 2238 2239/** Clamp the given value between an upper and lower bound. 2240 * 2241 * @param[in] x The value to be clamped 2242 * @param[in] min_val The lower bound 2243 * @param[in] max_val The upper bound 2244 * 2245 * @return The clamped value. 2246 */ 2247#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val) 2248 2249/** REVn reverses the given vector whose size is n. 2250 * @name REVn 2251 * 2252 * @param[in] x The vector to be reversed 2253 * 2254 * @return The reversed vector 2255 * @{ 2256 */ 2257#define REV1(x) ((x)) 2258#define REV2(x) ((x).s10) 2259#define REV3(x) ((x).s210) 2260#define REV4(x) ((x).s3210) 2261#define REV8(x) ((x).s76543210) 2262#define REV16(x) ((x).sFEDCBA9876543210) 2263/** @} */ // end of group REVn 2264 2265/** Reverse the given vector. 2266 * @name REVERSE 2267 * 2268 * @param[in] x The vector to be reversed 2269 * @param[in] s The size of the vector 2270 * 2271 * @return The reversed vector 2272 * @{ 2273 */ 2274#define REVERSE_STR(x, s) REV##s((x)) 2275#define REVERSE(x, s) REVERSE_STR(x, s) 2276/** @} */ // end of group REVERSE 2277 2278/** Circular-right-shift (rotate-right) the vector of size s by the amount of n. 2279 * @name ROTs_n 2280 * 2281 * @param[in] x The vector to be shifted 2282 * 2283 * @return The shifted vector 2284 * @{ 2285 */ 2286#define ROT1_0(x) ((x)) 2287 2288#define ROT2_0(x) ((x)) 2289#define ROT2_1(x) ((x).s10) 2290 2291#define ROT3_0(x) ((x)) 2292#define ROT3_1(x) ((x).s201) 2293#define ROT3_2(x) ((x).s120) 2294 2295#define ROT4_0(x) ((x)) 2296#define ROT4_1(x) ((x).s3012) 2297#define ROT4_2(x) ((x).s2301) 2298#define ROT4_3(x) ((x).s1230) 2299 2300#define ROT8_0(x) ((x)) 2301#define ROT8_1(x) ((x).s70123456) 2302#define ROT8_2(x) ((x).s67012345) 2303#define ROT8_3(x) ((x).s56701234) 2304#define ROT8_4(x) ((x).s45670123) 2305#define ROT8_5(x) ((x).s34567012) 2306#define ROT8_6(x) ((x).s23456701) 2307#define ROT8_7(x) ((x).s12345670) 2308 2309#define ROT16_0(x) ((x)) 2310#define ROT16_1(x) ((x).sF0123456789ABCDE) 2311#define ROT16_2(x) ((x).sEF0123456789ABCD) 2312#define ROT16_3(x) ((x).sDEF0123456789ABC) 2313#define ROT16_4(x) ((x).sCDEF0123456789AB) 2314#define ROT16_5(x) ((x).sBCDEF0123456789A) 2315#define ROT16_6(x) ((x).sABCDEF0123456789) 2316#define ROT16_7(x) ((x).s9ABCDEF012345678) 2317#define ROT16_8(x) ((x).s89ABCDEF01234567) 2318#define ROT16_9(x) ((x).s789ABCDEF0123456) 2319#define ROT16_10(x) ((x).s6789ABCDEF012345) 2320#define ROT16_11(x) ((x).s56789ABCDEF01234) 2321#define ROT16_12(x) ((x).s456789ABCDEF0123) 2322#define ROT16_13(x) ((x).s3456789ABCDEF012) 2323#define ROT16_14(x) ((x).s23456789ABCDEF01) 2324#define ROT16_15(x) ((x).s123456789ABCDEF0) 2325/** @} */ // end of group ROTs_n 2326 2327/** Circular-right-shift (rotate-right) the given vector by the given amount. 
2328 * @name ROTATE 2329 * 2330 * @param[in] x The vector to be shifted 2331 * @param[in] s The size of the vector 2332 * @param[in] n The amount to be shifted 2333 * 2334 * @return The shifted vector 2335 * @{ 2336 */ 2337#define ROTATE_STR(x, s, n) ROT##s##_##n(x) 2338#define ROTATE(x, s, n) ROTATE_STR(x, s, n) 2339/** @} */ // end of group ROTATE 2340 2341/** Creates a vector of size n filled with offset values corresponding to the location of each element. 2342 * @name V_OFFSn 2343 * 2344 * @param[in] dt The data type of the output vector 2345 * 2346 * @return The vector filled with offset values 2347 * @{ 2348 */ 2349#define V_OFFS1(dt) (dt##1)(0) 2350#define V_OFFS2(dt) (dt##2)(0, 1) 2351#define V_OFFS3(dt) (dt##3)(0, 1, 2) 2352#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3) 2353#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7) 2354#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) 2355/** @} */ // end of group V_OFFSn 2356 2357/** Create a vector filled with offset values corresponding to the location of each element. 2358 * @name VEC_OFFS 2359 * 2360 * @param[in] dt The data type of the output vector 2361 * @param[in] s The size of the output vector 2362 * 2363 * @return The vector filled with offset values 2364 * @{ 2365 */ 2366#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt) 2367#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s) 2368/** @} */ // end of group VEC_OFFS 2369 2370#define VLOAD_STR(size) vload##size 2371#define VLOAD(size) VLOAD_STR(size) 2372 2373#define PIXEL_UNIT4 1 2374#define PIXEL_UNIT8 2 2375#define PIXEL_UNIT16 4 2376 2377/** Utility macro to convert a vector size to pixel units. 2378 * 2379 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT 2380 * 2381 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported 2382 * 2383 * @return The pixel unit (number of pixels) 2384 * @{ 2385 */ 2386#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size 2387#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) 2388/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT 2389 2390#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord))); 2391#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord))); 2392#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord))); 2393 2394#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16) 2395#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord))); 2396#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord))); 2397#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord))); 2398#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16) 2399 2400/** Utility macro to read a 2D OpenCL image object. 2401 * 2402 * @note Coordinates are not normalized 2403 * 2404 * @param[in] data_type Data type 2405 * @param[in] n0 Number of pixels to read.
Only 1, 2 and 4 are supported 2406 * @param[in] img OpenCL image object 2407 * @param[in] x_coord The x coordinate for the top-left pixel 2408 * @param[in] y_coord The y coordinate for the top-left pixel 2409 * 2410 * @return Pixels from the 2D OpenCL image object 2411 * @{ 2412 */ 2413#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord) 2414#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) 2415 2416#define VSTORE_STR(size) vstore##size 2417#define VSTORE(size) VSTORE_STR(size) 2418 2419#define float1 float 2420#define half1 half 2421#define char1 char 2422#define uchar1 uchar 2423#define short1 short 2424#define ushort1 ushort 2425#define int1 int 2426#define uint1 uint 2427#define long1 long 2428#define ulong1 ulong 2429#define double1 double 2430 2431#define vload1(OFFSET, PTR) *(OFFSET + PTR) 2432#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA 2433 2434/** Extended partial vstore that correctly handles scalar values as well. 2435 * Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops 2436 * @name VSTORE_PARTIAL 2437 * 2438 * @note With this macro, the passed data can be either a vector or a scalar 2439 * @note @p store_size needs to be <= @p size 2440 * eg 1: Valid 2441 * VSTORE_PARTIAL(16, 15) ...; 2442 * eg 2: Invalid 2443 * VSTORE_PARTIAL(4, 7) ...; 2444 * 2445 * @param[in] size The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16 2446 * @param[in] store_size The number of lower elements to store. Supported values: 1-16, but has to be <= @p size 2447 * @{ 2448 */ 2449#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size 2450#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size) 2451 2452#define NO_STORE(data, offs, ptr) \ 2453 { \ 2454 } 2455 2456// Size == 1 (scalar) 2457#define vstore_partial_1_0 NO_STORE 2458#define vstore_partial_1_1 vstore1 2459#define vstore_partial_1_2 NO_STORE 2460#define vstore_partial_1_3 NO_STORE 2461#define vstore_partial_1_4 NO_STORE 2462#define vstore_partial_1_5 NO_STORE 2463#define vstore_partial_1_6 NO_STORE 2464#define vstore_partial_1_7 NO_STORE 2465#define vstore_partial_1_8 NO_STORE 2466#define vstore_partial_1_9 NO_STORE 2467#define vstore_partial_1_10 NO_STORE 2468#define vstore_partial_1_11 NO_STORE 2469#define vstore_partial_1_12 NO_STORE 2470#define vstore_partial_1_13 NO_STORE 2471#define vstore_partial_1_14 NO_STORE 2472#define vstore_partial_1_15 NO_STORE 2473#define vstore_partial_1_16 NO_STORE 2474// Size == 2 2475#define vstore_partial_2_0 NO_STORE 2476#define vstore_partial_2_1 vstore_partial_1 2477#define vstore_partial_2_2 vstore_partial_2 2478#define vstore_partial_2_3 NO_STORE 2479#define vstore_partial_2_4 NO_STORE 2480#define vstore_partial_2_5 NO_STORE 2481#define vstore_partial_2_6 NO_STORE 2482#define vstore_partial_2_7 NO_STORE 2483#define vstore_partial_2_8 NO_STORE 2484#define vstore_partial_2_9 NO_STORE 2485#define vstore_partial_2_10 NO_STORE 2486#define vstore_partial_2_11 NO_STORE 2487#define vstore_partial_2_12 NO_STORE 2488#define vstore_partial_2_13 NO_STORE 2489#define vstore_partial_2_14 NO_STORE 2490#define vstore_partial_2_15 NO_STORE 2491#define vstore_partial_2_16 NO_STORE 2492// Size == 3 2493#define vstore_partial_3_0 NO_STORE 2494#define vstore_partial_3_1 vstore_partial_1 2495#define vstore_partial_3_2 vstore_partial_2 2496#define vstore_partial_3_3
// Size == 2
#define vstore_partial_2_0 NO_STORE
#define vstore_partial_2_1 vstore_partial_1
#define vstore_partial_2_2 vstore_partial_2
#define vstore_partial_2_3 NO_STORE
#define vstore_partial_2_4 NO_STORE
#define vstore_partial_2_5 NO_STORE
#define vstore_partial_2_6 NO_STORE
#define vstore_partial_2_7 NO_STORE
#define vstore_partial_2_8 NO_STORE
#define vstore_partial_2_9 NO_STORE
#define vstore_partial_2_10 NO_STORE
#define vstore_partial_2_11 NO_STORE
#define vstore_partial_2_12 NO_STORE
#define vstore_partial_2_13 NO_STORE
#define vstore_partial_2_14 NO_STORE
#define vstore_partial_2_15 NO_STORE
#define vstore_partial_2_16 NO_STORE
// Size == 3
#define vstore_partial_3_0 NO_STORE
#define vstore_partial_3_1 vstore_partial_1
#define vstore_partial_3_2 vstore_partial_2
#define vstore_partial_3_3 vstore_partial_3
#define vstore_partial_3_4 NO_STORE
#define vstore_partial_3_5 NO_STORE
#define vstore_partial_3_6 NO_STORE
#define vstore_partial_3_7 NO_STORE
#define vstore_partial_3_8 NO_STORE
#define vstore_partial_3_9 NO_STORE
#define vstore_partial_3_10 NO_STORE
#define vstore_partial_3_11 NO_STORE
#define vstore_partial_3_12 NO_STORE
#define vstore_partial_3_13 NO_STORE
#define vstore_partial_3_14 NO_STORE
#define vstore_partial_3_15 NO_STORE
#define vstore_partial_3_16 NO_STORE
// Size == 4
#define vstore_partial_4_0 NO_STORE
#define vstore_partial_4_1 vstore_partial_1
#define vstore_partial_4_2 vstore_partial_2
#define vstore_partial_4_3 vstore_partial_3
#define vstore_partial_4_4 vstore_partial_4
#define vstore_partial_4_5 NO_STORE
#define vstore_partial_4_6 NO_STORE
#define vstore_partial_4_7 NO_STORE
#define vstore_partial_4_8 NO_STORE
#define vstore_partial_4_9 NO_STORE
#define vstore_partial_4_10 NO_STORE
#define vstore_partial_4_11 NO_STORE
#define vstore_partial_4_12 NO_STORE
#define vstore_partial_4_13 NO_STORE
#define vstore_partial_4_14 NO_STORE
#define vstore_partial_4_15 NO_STORE
#define vstore_partial_4_16 NO_STORE
// Size == 8
#define vstore_partial_8_0 NO_STORE
#define vstore_partial_8_1 vstore_partial_1
#define vstore_partial_8_2 vstore_partial_2
#define vstore_partial_8_3 vstore_partial_3
#define vstore_partial_8_4 vstore_partial_4
#define vstore_partial_8_5 vstore_partial_5
#define vstore_partial_8_6 vstore_partial_6
#define vstore_partial_8_7 vstore_partial_7
#define vstore_partial_8_8 vstore_partial_8
#define vstore_partial_8_9 NO_STORE
#define vstore_partial_8_10 NO_STORE
#define vstore_partial_8_11 NO_STORE
#define vstore_partial_8_12 NO_STORE
#define vstore_partial_8_13 NO_STORE
#define vstore_partial_8_14 NO_STORE
#define vstore_partial_8_15 NO_STORE
#define vstore_partial_8_16 NO_STORE
// Size == 16
#define vstore_partial_16_0 NO_STORE
#define vstore_partial_16_1 vstore_partial_1
#define vstore_partial_16_2 vstore_partial_2
#define vstore_partial_16_3 vstore_partial_3
#define vstore_partial_16_4 vstore_partial_4
#define vstore_partial_16_5 vstore_partial_5
#define vstore_partial_16_6 vstore_partial_6
#define vstore_partial_16_7 vstore_partial_7
#define vstore_partial_16_8 vstore_partial_8
#define vstore_partial_16_9 vstore_partial_9
#define vstore_partial_16_10 vstore_partial_10
#define vstore_partial_16_11 vstore_partial_11
#define vstore_partial_16_12 vstore_partial_12
#define vstore_partial_16_13 vstore_partial_13
#define vstore_partial_16_14 vstore_partial_14
#define vstore_partial_16_15 vstore_partial_15
#define vstore_partial_16_16 vstore_partial_16
/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name vstore_partial_n
 *
 * @note @p DATA needs to be a vector not a scalar
 * @note n needs to be <= the vector width of the input variable @p DATA
 * eg 1: Valid
 * vstore_partial_15(var:float16, 0, 0xabcd);
 * eg 2: Invalid
 * vstore_partial_7(var:float4, 0, 0xabcd);
 *
 * @note In cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty.
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset in units of n
 * @param[in] PTR    The base pointer
 * @{
 */
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

#define vstore_partial_5(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR)    \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

#define vstore_partial_9(DATA, OFFSET, PTR)        \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

#define vstore_partial_13(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR)       \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL
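/** Example (illustrative only; 'p' is assumed to be a __global float *):
 * storing the lower 3 elements of a float8 through the dispatch table above.
 * @code
 * float8 data = (float8)(1.0f);
 * VSTORE_PARTIAL(8, 3)
 * (data, 0, p); // expands to vstore_partial_8_3, i.e. vstore3(data.s012, 0, p)
 * @endcode
 */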
// The convert_* built-in functions with the _sat modifier are not supported for
// floating-point types, so we provide defines without _sat to work around this.
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_half
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define convert_ulong1 convert_ulong
#define convert_double1 convert_double

#define convert_char1_sat convert_char_sat
#define convert_uchar1_sat convert_uchar_sat
#define convert_short1_sat convert_short_sat
#define convert_ushort1_sat convert_ushort_sat
#define convert_int1_sat convert_int_sat
#define convert_uint1_sat convert_uint_sat
#define convert_long1_sat convert_long_sat
#define convert_ulong1_sat convert_ulong_sat
#define convert_double1_sat convert_double_sat

#define VEC_DATA_TYPE_STR(type, size) type##size
#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size)

#define CONVERT_STR(x, type) (convert_##type((x)))
#define CONVERT(x, type) CONVERT_STR(x, type)

#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x)))
#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type)

#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x)))
#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round)

#define select_vec_dt_uchar(size) uchar##size
#define select_vec_dt_char(size) char##size
#define select_vec_dt_ushort(size) ushort##size
#define select_vec_dt_short(size) short##size
#define select_vec_dt_half(size) short##size
#define select_vec_dt_uint(size) uint##size
#define select_vec_dt_int(size) int##size
#define select_vec_dt_float(size) int##size
#define select_vec_dt_ulong(size) ulong##size
#define select_vec_dt_long(size) long##size

#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size)
#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size)
#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1)

#define sum_reduce_1(x) (x)
#define sum_reduce_2(x) ((x).s0) + ((x).s1)
#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2)
#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23)
#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567)
#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF)

#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x)
#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size)

#define max_reduce_1(x) (x)
#define max_reduce_2(x) max(((x).s0), ((x).s1))
#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2))
#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23))
#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567))
#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF))

#define MAX_REDUCE_STR(x, size) max_reduce_##size(x)
#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size)
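/** Example (illustrative only): reducing a float4 with the helpers above.
 * @code
 * float4 v = (float4)(1.0f, 2.0f, 3.0f, 4.0f);
 * float s  = SUM_REDUCE(v, 4); // ((v.s0) + (v.s1)) + ((v.s2) + (v.s3)) == 10.0f
 * float m  = MAX_REDUCE(v, 4); // max(max(v.s0, v.s1), max(v.s2, v.s3)) == 4.0f
 * @endcode
 */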
#define VECTOR_DECLARATION(name) \
    __global uchar *name##_ptr,  \
    uint name##_stride_x,        \
    uint name##_step_x,          \
    uint name##_offset_first_element_in_bytes

#define IMAGE_DECLARATION(name) \
    __global uchar *name##_ptr, \
    uint name##_stride_x,       \
    uint name##_step_x,         \
    uint name##_stride_y,       \
    uint name##_step_y,         \
    uint name##_offset_first_element_in_bytes

#define TENSOR3D_DECLARATION(name) \
    __global uchar *name##_ptr,    \
    uint name##_stride_x,          \
    uint name##_step_x,            \
    uint name##_stride_y,          \
    uint name##_step_y,            \
    uint name##_stride_z,          \
    uint name##_step_z,            \
    uint name##_offset_first_element_in_bytes

#define TENSOR4D_DECLARATION(name) \
    __global uchar *name##_ptr,    \
    uint name##_stride_x,          \
    uint name##_step_x,            \
    uint name##_stride_y,          \
    uint name##_step_y,            \
    uint name##_stride_z,          \
    uint name##_step_z,            \
    uint name##_stride_w,          \
    uint name##_step_w,            \
    uint name##_offset_first_element_in_bytes

#define CONVERT_TO_VECTOR_STRUCT(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x)

#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \
    update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0)

#define CONVERT_TO_IMAGE_STRUCT(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y)

#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z)

#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \
    update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT(name)                                                                                                           \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \
    update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0)

#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size)                                                                                                 \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                                 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size)

#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \
    update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size)

#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name)                                                                                       \
    tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \
                           name##_stride_z, name##_step_z)
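/** Example (illustrative sketch, not part of the original header; the kernel
 * name and argument name are hypothetical): a kernel signature written with
 * IMAGE_DECLARATION expands to the six dst_* arguments that
 * CONVERT_TO_IMAGE_STRUCT needs to build an Image struct (defined below).
 * @code
 * __kernel void example_memset(IMAGE_DECLARATION(dst))
 * {
 *     Image dst = CONVERT_TO_IMAGE_STRUCT(dst); // dst.ptr now points at this work-item's element
 *     *((__global uchar *)dst.ptr) = 0;
 * }
 * @endcode
 */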
/** Structure to hold Vector information */
typedef struct Vector
{
    __global uchar *ptr;               /**< Pointer to the starting position of the buffer */
    int offset_first_element_in_bytes; /**< The offset of the first element in the source vector */
    int stride_x;                      /**< Stride of the vector in X dimension (in bytes) */
} Vector;

/** Structure to hold Image information */
typedef struct Image
{
    __global uchar *ptr;               /**< Pointer to the starting position of the buffer */
    int offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
} Image;

/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
    __global uchar *ptr;               /**< Pointer to the starting position of the buffer */
    int offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
} Tensor3D;

/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
    __global uchar *ptr;               /**< Pointer to the starting position of the buffer */
    int offset_first_element_in_bytes; /**< The offset of the first element in the source tensor */
    int stride_x;                      /**< Stride of the tensor in X dimension (in bytes) */
    int stride_y;                      /**< Stride of the tensor in Y dimension (in bytes) */
    int stride_z;                      /**< Stride of the tensor in Z dimension (in bytes) */
    int stride_w;                      /**< Stride of the tensor in W dimension (in bytes) */
} Tensor4D;

/** Wrap vector information into a Vector structure, and make the pointer point at this work-item's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per work-item (in bytes)
 *
 * @return A vector object
 */
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
    Vector vector =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
    };
    vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
    return vector;
}
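/** Example (illustrative only): for a kernel whose work-items step along X,
 * the wrapped pointer resolves to
 * @code
 * Vector v = CONVERT_TO_VECTOR_STRUCT(src);
 * // v.ptr == src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x
 * @endcode
 * where the src_* names are the arguments introduced by VECTOR_DECLARATION(src).
 */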
/** Wrap image information into an Image structure, and make the pointer point at this work-item's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per work-item (in bytes)
 *
 * @return An image object
 */
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
    return img;
}

/** Wrap 3D tensor information into an image structure, and make the pointer point at this work-item's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per work-item (in bytes)
 *
 * @return An image object
 */
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return img;
}

/** Wrap 3D tensor information into a tensor structure, and make the pointer point at this work-item's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per work-item (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return tensor;
}
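/** Example (illustrative only): the _NO_STEP variants pass 0 for the steps, so
 * the wrapped pointer is not advanced per work-item:
 * @code
 * Tensor3D t0 = CONVERT_TO_TENSOR3D_STRUCT(src);         // ptr advanced by get_global_id(0/1/2) * src_step_*
 * Tensor3D t1 = CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(src); // ptr advanced only by src_offset_first_element_in_bytes
 * @endcode
 */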
/** Wrap 3D tensor information into a tensor structure.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per work-item (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    return tensor;
}

/** Wrap 4D tensor information into a tensor structure, and make the pointer point at this work-item's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] stride_x                      Stride of the tensor in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] stride_y                      Stride of the tensor in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] stride_z                      Stride of the tensor in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] stride_w                      Stride of the tensor in W dimension (in bytes)
 * @param[in] step_w                        stride_w * number of elements along W processed per work-item (in bytes)
 * @param[in] mod_size                      The depth used to split the Z global id into Z and W indices
 *
 * @return A 4D tensor object
 */
inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
                                             uint step_w,
                                             uint mod_size)
{
    Tensor4D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z,
        .stride_w                      = stride_w
    };

    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
    return tensor;
}

/** Get the pointer position of a Vector
 *
 * @param[in] vec Pointer to the starting position of the buffer
 * @param[in] x   Relative X position
 */
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
    return vec->ptr + x * vec->stride_x;
}

/** Get the pointer position of an Image
 *
 * @param[in] img Pointer to the starting position of the buffer
 * @param[in] x   Relative X position
 * @param[in] y   Relative Y position
 */
inline __global uchar *offset(const Image *img, int x, int y)
{
    return img->ptr + x * img->stride_x + y * img->stride_y;
}

/** Get the pointer position of a Tensor3D
 *
 * @param[in] tensor Pointer to the starting position of the buffer
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 */
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
}

/** Get the pointer position of a Tensor4D
 *
 * @param[in] tensor Pointer to the starting position of the buffer
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 * @param[in] w      Relative W position
 */
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
}
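/** Example (illustrative only; 'img' is assumed to be an Image built with
 * CONVERT_TO_IMAGE_STRUCT): the *_offset helpers advance the pointer by byte
 * strides, so the result must be cast to the element type before use.
 * @code
 * __global float *pixel = (__global float *)offset(&img, 1, 2); // one column right, two rows down
 * @endcode
 */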
/** Get the offset for a given linear index of a Tensor3D
 *
 * @param[in] tensor Pointer to the starting position of the buffer
 * @param[in] width  Width of the input tensor
 * @param[in] height Height of the input tensor
 * @param[in] depth  Depth of the input tensor
 * @param[in] index  Linear index
 */
inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index)
{
    uint num_elements = width * height;

    const uint z = index / num_elements;

    index %= num_elements;

    const uint y = index / width;

    index %= width;

    const uint x = index;

    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes;
}

#endif // ARM_COMPUTE_HELPER_H

/** Utility macro to access a vector with the scalar positions
 *
 * Supported cases: the offset must be one of the values for which an accessor
 * is defined below (0, 1, 2, 3, 4, 8, 12, 16).
 *
 * @param[in] offset The offset within the vector. Must be one of the supported offsets listed above
 * @param[in] n0     The number of consecutive columns to access. n0 + offset must be <= 16
 * @param[in] x      Vector to access
 * @{
 */
#define SCALAR_ACCESS_STR(offset, n0, x) scalar_access_##offset##_##n0(x)
#define SCALAR_ACCESS(offset, n0, x) SCALAR_ACCESS_STR(offset, n0, x)

// offset == 0
#define scalar_access_0_1(x) ((x).s0)
#define scalar_access_0_2(x) ((x).s01)
#define scalar_access_0_3(x) ((x).s012)
#define scalar_access_0_4(x) ((x).s0123)
#define scalar_access_0_8(x) ((x).s01234567)
#define scalar_access_0_16(x) ((x).s0123456789ABCDEF)

// offset == 1
#define scalar_access_1_1(x) ((x).s1)
#define scalar_access_1_2(x) ((x).s12)
#define scalar_access_1_3(x) ((x).s123)
#define scalar_access_1_4(x) ((x).s1234)
#define scalar_access_1_8(x) ((x).s12345678)

// offset == 2
#define scalar_access_2_1(x) ((x).s2)
#define scalar_access_2_2(x) ((x).s23)
#define scalar_access_2_3(x) ((x).s234)
#define scalar_access_2_4(x) ((x).s2345)
#define scalar_access_2_8(x) ((x).s23456789)

// offset == 3
#define scalar_access_3_1(x) ((x).s3)
#define scalar_access_3_2(x) ((x).s34)
#define scalar_access_3_3(x) ((x).s345)
#define scalar_access_3_4(x) ((x).s3456)
#define scalar_access_3_8(x) ((x).s3456789A)

// offset == 4
#define scalar_access_4_1(x) ((x).s4)
#define scalar_access_4_2(x) ((x).s45)
#define scalar_access_4_3(x) ((x).s456)
#define scalar_access_4_4(x) ((x).s4567)
#define scalar_access_4_8(x) ((x).s456789AB)

// offset == 8
#define scalar_access_8_1(x) ((x).s8)
#define scalar_access_8_2(x) ((x).s89)
#define scalar_access_8_3(x) ((x).s89A)
#define scalar_access_8_4(x) ((x).s89AB)
#define scalar_access_8_8(x) ((x).s89ABCDEF)

// offset == 12
#define scalar_access_12_1(x) ((x).sC)
#define scalar_access_12_2(x) ((x).sCD)
#define scalar_access_12_3(x) ((x).sCDE)
#define scalar_access_12_4(x) ((x).sCDEF)

// offset == 16
#define scalar_access_16_1(x) ((x).sF)
/** @} */ // end of group SCALAR_ACCESS
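/** Example (illustrative only): writing the middle four lanes of a float16 accumulator.
 * @code
 * float16 acc = (float16)(0.0f);
 * SCALAR_ACCESS(4, 4, acc) = (float4)(1.0f); // expands to scalar_access_4_4(acc), i.e. ((acc).s4567)
 * @endcode
 */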
/** Loads the rows from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1) without allocating variables.
 * @name LOAD_TENSOR_ROW_n
 *
 * @param[in] N0         The number of columns to load
 * @param[in] DATA_TYPE  The data type of variables
 * @param[in] BASENAME   The basename of the destination variables for the loaded rows
 * @param[in] PTR        The base pointer
 * @param[in] COL_OFFSET The column vector offset. COL_OFFSET + N0 must be <= 16
 * @param[in] STRIDE_Y   The stride value in y-axis direction
 * @param[in] Z          The z-axis offset vector
 * @{
 */
#define LOAD_TENSOR_ROW_0(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    ({})

#define LOAD_TENSOR_ROW_1(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##0) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define LOAD_TENSOR_ROW_2(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_1(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##1) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define LOAD_TENSOR_ROW_3(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_2(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##2) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define LOAD_TENSOR_ROW_4(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_3(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##3) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define LOAD_TENSOR_ROW_5(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_4(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##4) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define LOAD_TENSOR_ROW_6(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_5(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##5) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define LOAD_TENSOR_ROW_7(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_6(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##6) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define LOAD_TENSOR_ROW_8(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_7(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##7) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define LOAD_TENSOR_ROW_9(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_8(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##8) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define LOAD_TENSOR_ROW_10(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_9(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)      \
    SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##9) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define LOAD_TENSOR_ROW_11(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \
    LOAD_TENSOR_ROW_10(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z)     \
SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##A) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A)); 3177 3178#define LOAD_TENSOR_ROW_12(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3179 LOAD_TENSOR_ROW_11(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3180 SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##B) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B)); 3181 3182#define LOAD_TENSOR_ROW_13(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3183 LOAD_TENSOR_ROW_12(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3184 SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##C) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C)); 3185 3186#define LOAD_TENSOR_ROW_14(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3187 LOAD_TENSOR_ROW_13(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3188 SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##D) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D)); 3189 3190#define LOAD_TENSOR_ROW_15(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3191 LOAD_TENSOR_ROW_14(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3192 SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##E) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E)); 3193 3194#define LOAD_TENSOR_ROW_16(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3195 LOAD_TENSOR_ROW_15(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) \ 3196 SCALAR_ACCESS(COL_OFFSET, N0, BASENAME##F) = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F)); 3197/** @}*/ // end of group LOAD_TENSOR_ROW_n 3198 3199/** Load tensor (consecutive rows and columns) with Z offset. 3200 * @name LOAD_TENSOR 3201 * 3202 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16 3203 * The data to load is expected to have consecutive names for each row. 3204 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2. 3205 * The Z offset is expected to have consecutive names. 3206 * E.g., for M0=3, and Z=zin, the expected Z offsets are zin0, zin1 and zin2. 3207 * 3208 * @param[in] M0 The number of consecutive rows 3209 * @param[in] N0 The number of consecutive columns 3210 * @param[in] DATA_TYPE The data type of the target 3211 * @param[in] BASENAME The basename of the result variables 3212 * @param[in] PTR The base pointer for the data 3213 * @param[in] COL_OFFSET The column vector offset. COL_OFFSET + N0 must be <= 16 3214 * @param[in] STRIDE_Y The stride in y-axis direction 3215 * @param[in] Z The z-axis offset vector 3216 * @{ 3217 */ 3218#define LOAD_TENSOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) LOAD_TENSOR_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) 3219#define LOAD_TENSOR(M0, N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) LOAD_TENSOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, COL_OFFSET, STRIDE_Y, Z) 3220/** @} */ // end of group LOAD_TENSOR 3221 3222/** Load 2D tensor (consecutive rows and columns) with Z offset. 
 * @name LOAD_TENSOR_M0Xn
 *
 * @param[in] M0        The number of rows to load [0-16]
 * @param[in] N0        The number of columns to load [0-16]
 * @param[in] DATA_TYPE The data type of variables
 * @param[in] BASENAME  The basename of the destination variables for the loaded rows
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The z-axis offset vector
 * @{
 */
#define LOAD_TENSOR_M0X0(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    ({})

#define LOAD_TENSOR_M0X1(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X2(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X3(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X4(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X5(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 1, DATA_TYPE, a, input_ptr + 4 * sizeof(DATA_TYPE), 4, src_stride_y, zin);

#define LOAD_TENSOR_M0X6(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 2, DATA_TYPE, a, input_ptr + 4 * sizeof(DATA_TYPE), 4, src_stride_y, zin);

#define LOAD_TENSOR_M0X7(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 3, DATA_TYPE, a, input_ptr + 4 * sizeof(DATA_TYPE), 4, src_stride_y, zin);

#define LOAD_TENSOR_M0X8(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

#define LOAD_TENSOR_M0X9(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);       \
    LOAD_TENSOR(M0, 1, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X10(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);        \
    LOAD_TENSOR(M0, 2, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X11(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);        \
    LOAD_TENSOR(M0, 3, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X12(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);        \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin);

#define LOAD_TENSOR_M0X13(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin)                  \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);                         \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin); \
    LOAD_TENSOR(M0, 1, DATA_TYPE, a, input_ptr + 12 * sizeof(DATA_TYPE), 12, src_stride_y, zin);
#define LOAD_TENSOR_M0X14(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin)                  \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);                         \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin); \
    LOAD_TENSOR(M0, 2, DATA_TYPE, a, input_ptr + 12 * sizeof(DATA_TYPE), 12, src_stride_y, zin);

#define LOAD_TENSOR_M0X15(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin)                  \
    LOAD_TENSOR(M0, 8, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);                         \
    LOAD_TENSOR(M0, 4, DATA_TYPE, a, input_ptr + 8 * sizeof(DATA_TYPE), 8, src_stride_y, zin); \
    LOAD_TENSOR(M0, 3, DATA_TYPE, a, input_ptr + 12 * sizeof(DATA_TYPE), 12, src_stride_y, zin);

#define LOAD_TENSOR_M0X16(M0, N0, DATA_TYPE, a, input_ptr, src_stride_y, zin) \
    LOAD_TENSOR(M0, N0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);
/** @}*/ // end of group LOAD_TENSOR_M0Xn

/** Load 2D tensor (consecutive rows and columns) with Z offset.
 * @name LOAD_TENSOR_M0XN0
 *
 * @param[in] M0        The number of consecutive rows [0-16]
 * @param[in] N0        The number of consecutive columns [0-16]
 * @param[in] DATA_TYPE The data type of the target
 * @param[in] BASENAME  The basename of the result variables
 * @param[in] PTR       The base pointer for the data
 * @param[in] STRIDE_Y  The stride in y-axis direction
 * @param[in] Z         The z-axis offset vector
 * @{
 */
#define LOAD_TENSOR_M0XN0_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) LOAD_TENSOR_M0X##N0(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define LOAD_TENSOR_M0XN0(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) LOAD_TENSOR_M0XN0_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group LOAD_TENSOR_M0XN0

/** Loads the rows from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1).
3315 * @name LOAD_ROW_n 3316 * 3317 * @param[in] N0 The number of columns to load 3318 * @param[in] DATA_TYPE The data type of variables 3319 * @param[in] BASENAME The basename of the destination variables for the loaded rows 3320 * @param[in] PTR The base pointer 3321 * @param[in] OFFSET The offset within a row 3322 * @param[in] STRIDE_Y The stride value in y-axis direction 3323 * @param[in] Z The z-axis offset vector 3324 * @{ 3325 */ 3326#define LOAD_ROW_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3327 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3328 BASENAME##0 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 0 * STRIDE_Y + Z##0)); 3329 3330#define LOAD_ROW_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3331 LOAD_ROW_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3332 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3333 BASENAME##1 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 1 * STRIDE_Y + Z##1)); 3334 3335#define LOAD_ROW_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3336 LOAD_ROW_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3337 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3338 BASENAME##2 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 2 * STRIDE_Y + Z##2)); 3339 3340#define LOAD_ROW_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3341 LOAD_ROW_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3342 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3343 BASENAME##3 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 3 * STRIDE_Y + Z##3)); 3344 3345#define LOAD_ROW_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3346 LOAD_ROW_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3347 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3348 BASENAME##4 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 4 * STRIDE_Y + Z##4)); 3349 3350#define LOAD_ROW_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3351 LOAD_ROW_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3352 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3353 BASENAME##5 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 5 * STRIDE_Y + Z##5)); 3354 3355#define LOAD_ROW_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3356 LOAD_ROW_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3357 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3358 BASENAME##6 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 6 * STRIDE_Y + Z##6)); 3359 3360#define LOAD_ROW_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3361 LOAD_ROW_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3362 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3363 BASENAME##7 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 7 * STRIDE_Y + Z##7)); 3364 3365#define LOAD_ROW_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3366 LOAD_ROW_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3367 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3368 BASENAME##8 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 8 * STRIDE_Y + Z##8)); 3369 3370#define LOAD_ROW_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3371 LOAD_ROW_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3372 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3373 BASENAME##9 = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 9 * STRIDE_Y + Z##9)); 3374 3375#define LOAD_ROW_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3376 LOAD_ROW_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3377 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3378 BASENAME##A = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 10 * STRIDE_Y + Z##A)); 3379 3380#define LOAD_ROW_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3381 LOAD_ROW_11(N0, DATA_TYPE, 
BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3382 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3383 BASENAME##B = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 11 * STRIDE_Y + Z##B)); 3384 3385#define LOAD_ROW_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3386 LOAD_ROW_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3387 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3388 BASENAME##C = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 12 * STRIDE_Y + Z##C)); 3389 3390#define LOAD_ROW_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3391 LOAD_ROW_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3392 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3393 BASENAME##D = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 13 * STRIDE_Y + Z##D)); 3394 3395#define LOAD_ROW_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3396 LOAD_ROW_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3397 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3398 BASENAME##E = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 14 * STRIDE_Y + Z##E)); 3399 3400#define LOAD_ROW_16(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3401 LOAD_ROW_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) \ 3402 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3403 BASENAME##F = VLOAD(N0)(0, (__global DATA_TYPE *)(PTR + OFFSET + 15 * STRIDE_Y + Z##F)); 3404 3405/** @}*/ // end of group LOAD_ROW_n 3406 3407/** Load Blocks (consecutive rows and columns) with Z offset. 3408 * @name LOAD_BLOCK 3409 * 3410 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16 3411 * The data to load is expected to have consecutive names for each row. 3412 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2. 3413 * The Z offset is expected to have consecutive names. 3414 * E.g., for M0=3, and Z=zin, the expected Z offsets are zin0, zin1 and zin2. 3415 * 3416 * @param[in] M0 The number of consecutive rows 3417 * @param[in] N0 The number of consecutive columns 3418 * @param[in] DATA_TYPE The data type of the target 3419 * @param[in] BASENAME The basename of the result variables 3420 * @param[in] PTR The base pointer for the data 3421 * @param[in] OFFSET The offset within a row 3422 * @param[in] STRIDE_Y The stride in y-axis direction 3423 * @param[in] Z The z-axis offset vector 3424 * @{ 3425 */ 3426#define LOAD_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) LOAD_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) 3427#define LOAD_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) LOAD_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y, Z) 3428/** @} */ // end of group LOAD_BLOCK 3429 3430/** Loads the rows from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1). 
3431 * @name LOAD_TEXTURE2D_ROW_n 3432 * 3433 * @param[in] N0 The number of pixels to read 3434 * @param[in] DATA_TYPE The data type of variables 3435 * @param[in] BASENAME The basename of the destination variables for the loaded rows 3436 * @param[in] IMG The 2D OpenCL image object 3437 * @param[in] X_COORD The x coordinate for the top-left pixel 3438 * @param[in] Y_COORD The y coordinate for the top-left pixel 3439 * @param[in] X_STEP_ROW The incremental step row for the x coordinate (in pixels) 3440 * @param[in] Y_STEP_ROW The incremental step row for the y coordinate (in pixels) 3441 * @{ 3442 */ 3443#define LOAD_TEXTURE2D_ROW_1(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3444 BASENAME##0 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 0 * X_STEP_ROW), (Y_COORD + 0 * Y_STEP_ROW)) 3445 3446#define LOAD_TEXTURE2D_ROW_2(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3447 LOAD_TEXTURE2D_ROW_1(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3448 BASENAME##1 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 1 * X_STEP_ROW), (Y_COORD + 1 * Y_STEP_ROW)) 3449 3450#define LOAD_TEXTURE2D_ROW_3(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3451 LOAD_TEXTURE2D_ROW_2(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3452 BASENAME##2 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 2 * X_STEP_ROW), (Y_COORD + 2 * Y_STEP_ROW)) 3453 3454#define LOAD_TEXTURE2D_ROW_4(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3455 LOAD_TEXTURE2D_ROW_3(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3456 BASENAME##3 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 3 * X_STEP_ROW), (Y_COORD + 3 * Y_STEP_ROW)) 3457 3458#define LOAD_TEXTURE2D_ROW_5(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3459 LOAD_TEXTURE2D_ROW_4(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3460 BASENAME##4 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 4 * X_STEP_ROW), (Y_COORD + 4 * Y_STEP_ROW)) 3461 3462#define LOAD_TEXTURE2D_ROW_6(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3463 LOAD_TEXTURE2D_ROW_5(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3464 BASENAME##5 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 5 * X_STEP_ROW), (Y_COORD + 5 * Y_STEP_ROW)) 3465 3466#define LOAD_TEXTURE2D_ROW_7(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3467 LOAD_TEXTURE2D_ROW_6(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3468 BASENAME##6 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 6 * X_STEP_ROW), (Y_COORD + 6 * Y_STEP_ROW)) 3469 3470#define LOAD_TEXTURE2D_ROW_8(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3471 LOAD_TEXTURE2D_ROW_7(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3472 BASENAME##7 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 7 * X_STEP_ROW), (Y_COORD + 7 * Y_STEP_ROW)) 3473 3474#define LOAD_TEXTURE2D_ROW_9(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3475 LOAD_TEXTURE2D_ROW_8(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3476 BASENAME##8 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 8 * X_STEP_ROW), (Y_COORD + 8 * Y_STEP_ROW)) 3477 3478#define LOAD_TEXTURE2D_ROW_10(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \ 3479 LOAD_TEXTURE2D_ROW_9(N0, DATA_TYPE, BASENAME, IMG, 
X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    BASENAME##9 = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 9 * X_STEP_ROW), (Y_COORD + 9 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_11(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_10(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##A = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 10 * X_STEP_ROW), (Y_COORD + 10 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_12(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_11(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##B = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 11 * X_STEP_ROW), (Y_COORD + 11 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_13(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_12(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##C = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 12 * X_STEP_ROW), (Y_COORD + 12 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_14(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_13(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##D = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 13 * X_STEP_ROW), (Y_COORD + 13 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_15(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_14(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##E = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 14 * X_STEP_ROW), (Y_COORD + 14 * Y_STEP_ROW))

#define LOAD_TEXTURE2D_ROW_16(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) \
    LOAD_TEXTURE2D_ROW_15(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW)     \
    BASENAME##F = READ_IMAGE2D(DATA_TYPE, N0, IMG, (X_COORD + 15 * X_STEP_ROW), (Y_COORD + 15 * Y_STEP_ROW))
/** @} */ // end of group LOAD_TEXTURE2D_ROW_n

/** Load a 2D texture in units of pixels. A pixel is made of 4 floating point values
 * @name LOAD_TEXTURE2D
 *
 * Supported cases are M0=1,2,3,...,16 and N0=1,2,4
 * The data to load is expected to have consecutive names for each row.
 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2.
 *
 * @param[in] M0 The number of consecutive rows
 * @param[in] N0 The number of consecutive pixels.
Only 1, 2 and 4 are supported 3516 * @param[in] DATA_TYPE The data type of the target 3517 * @param[in] BASENAME The basename of the result variables 3518 * @param[in] IMG The 2D OpenCL image object 3519 * @param[in] X_COORD The x coordinate for the top-left pixel 3520 * @param[in] Y_COORD The y coordinate for the top-left pixel 3521 * @param[in] X_STEP_ROW The incremental step row for the x coordinate (in pixels) 3522 * @param[in] Y_STEP_ROW The incremental step row for the y coordinate (in pixels) 3523 * @{ 3524 */ 3525#define LOAD_TEXTURE2D_STR(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) LOAD_TEXTURE2D_ROW_##M0(N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) 3526#define LOAD_TEXTURE2D(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) LOAD_TEXTURE2D_STR(M0, N0, DATA_TYPE, BASENAME, IMG, X_COORD, Y_COORD, X_STEP_ROW, Y_STEP_ROW) 3527/** @} */ // end of group LOAD_TEXTURE2D 3528 3529/** Loads the elements from 0 to n-1 in the given variables (BASENAME0 to BASENAMEn-1). 3530 * @name LOAD_ELEMENT_n 3531 * 3532 * @param[in] N0 The number of rows to load 3533 * @param[in] DATA_TYPE The data type of variables 3534 * @param[in] BASENAME The basename of the destination variables for the loaded rows 3535 * @param[in] PTR The base pointer 3536 * @param[in] OFFSET The offset within a row 3537 * @param[in] STRIDE_Y The stride value in y-axis direction 3538 * @{ 3539 */ 3540#define LOAD_ELEMENT_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3541 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3542 BASENAME##0 = *((__global DATA_TYPE *)(PTR + OFFSET + 0 * STRIDE_Y)); 3543 3544#define LOAD_ELEMENT_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3545 LOAD_ELEMENT_1(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3546 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3547 BASENAME##1 = *((__global DATA_TYPE *)(PTR + OFFSET + 1 * STRIDE_Y)); 3548 3549#define LOAD_ELEMENT_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3550 LOAD_ELEMENT_2(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3551 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3552 BASENAME##2 = *((__global DATA_TYPE *)(PTR + OFFSET + 2 * STRIDE_Y)); 3553 3554#define LOAD_ELEMENT_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3555 LOAD_ELEMENT_3(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3556 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3557 BASENAME##3 = *((__global DATA_TYPE *)(PTR + OFFSET + 3 * STRIDE_Y)); 3558 3559#define LOAD_ELEMENT_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3560 LOAD_ELEMENT_4(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3561 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3562 BASENAME##4 = *((__global DATA_TYPE *)(PTR + OFFSET + 4 * STRIDE_Y)); 3563 3564#define LOAD_ELEMENT_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3565 LOAD_ELEMENT_5(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3566 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3567 BASENAME##5 = *((__global DATA_TYPE *)(PTR + OFFSET + 5 * STRIDE_Y)); 3568 3569#define LOAD_ELEMENT_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3570 LOAD_ELEMENT_6(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3571 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3572 BASENAME##6 = *((__global DATA_TYPE *)(PTR + OFFSET + 6 * STRIDE_Y)); 3573 3574#define LOAD_ELEMENT_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3575 LOAD_ELEMENT_7(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3576 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3577 BASENAME##7 = *((__global DATA_TYPE *)(PTR + OFFSET + 7 * STRIDE_Y)); 3578 3579#define LOAD_ELEMENT_9(N0, DATA_TYPE, BASENAME, PTR, 
OFFSET, STRIDE_Y) \ 3580 LOAD_ELEMENT_8(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3581 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3582 BASENAME##8 = *((__global DATA_TYPE *)(PTR + OFFSET + 8 * STRIDE_Y)); 3583 3584#define LOAD_ELEMENT_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3585 LOAD_ELEMENT_9(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3586 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3587 BASENAME##9 = *((__global DATA_TYPE *)(PTR + OFFSET + 9 * STRIDE_Y)); 3588 3589#define LOAD_ELEMENT_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3590 LOAD_ELEMENT_10(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3591 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3592 BASENAME##A = *((__global DATA_TYPE *)(PTR + OFFSET + 10 * STRIDE_Y)); 3593 3594#define LOAD_ELEMENT_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3595 LOAD_ELEMENT_11(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3596 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3597 BASENAME##B = *((__global DATA_TYPE *)(PTR + OFFSET + 11 * STRIDE_Y)); 3598 3599#define LOAD_ELEMENT_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3600 LOAD_ELEMENT_12(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3601 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3602 BASENAME##C = *((__global DATA_TYPE *)(PTR + OFFSET + 12 * STRIDE_Y)); 3603 3604#define LOAD_ELEMENT_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3605 LOAD_ELEMENT_13(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3606 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3607 BASENAME##D = *((__global DATA_TYPE *)(PTR + OFFSET + 13 * STRIDE_Y)); 3608 3609#define LOAD_ELEMENT_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3610 LOAD_ELEMENT_14(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3611 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3612 BASENAME##E = *((__global DATA_TYPE *)(PTR + OFFSET + 14 * STRIDE_Y)); 3613 3614#define LOAD_ELEMENT_16(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3615 LOAD_ELEMENT_15(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) \ 3616 VEC_DATA_TYPE(DATA_TYPE, N0) \ 3617 BASENAME##F = *((__global DATA_TYPE *)(PTR + OFFSET + 15 * STRIDE_Y)); 3618 3619/** @}*/ // end of group LOAD_ELEMENT_n 3620 3621/** Load Scalar as Vector (consecutive elements). 3622 * @name LOAD_SCALAR_AS_VECTOR 3623 * 3624 * Supported cases are M0=1,2,3,...,16 and N0=1,2,3,4,8,16 3625 * The data to load is expected to have consecutive names for each row. 3626 * E.g., for M0=3, and BASENAME=c, the expected data is c0, c1 and c2. 
 *
 * @param[in] M0        The number of consecutive rows
 * @param[in] N0        The number of consecutive columns
 * @param[in] DATA_TYPE The data type of the target
 * @param[in] BASENAME  The basename of the result variables
 * @param[in] PTR       The base pointer for the data
 * @param[in] OFFSET    The offset within a row
 * @param[in] STRIDE_Y  The stride in y-axis direction
 * @{
 */
#define LOAD_SCALAR_AS_VECTOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) LOAD_ELEMENT_##M0(N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)
#define LOAD_SCALAR_AS_VECTOR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y) LOAD_SCALAR_AS_VECTOR_STR(M0, N0, DATA_TYPE, BASENAME, PTR, OFFSET, STRIDE_Y)
/** @} */ // end of group LOAD_SCALAR_AS_VECTOR

/** Basic macros to calculate Z offset values from Z0 to Zn-1
 * @name CALCULATE_Z_OFFSET_n
 *
 * @param[in] M0              The number of offset values to calculate
 * @param[in] DATA_TYPE       The data type of the results
 * @param[in] Z               The basename of the result variables
 * @param[in] Y               The work-item ID of the y-axis
 * @param[in] HEIGHT_GEMM3D   The height of GEMM3D
 * @param[in] DEPTH_GEMM3D    The depth of GEMM3D
 * @param[in] CROSS_PLANE_PAD The padding required for plane changes across the z-dimension
 * @param[in] STRIDE_Y        The stride value in y-axis direction
 *
 * @{
 */
#define CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##0 = (0 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##0 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##0); \
    Z##0 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_1(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##1 = (1 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##1 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##1); \
    Z##1 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_2(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##2 = (2 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##2 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##2); \
    Z##2 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_3(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##3 = (3 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##3 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##3); \
    Z##3 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_4(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##4 = (4 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##4 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##4); \
    Z##4 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_5(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##5 = (5 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##5 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##5); \
    Z##5 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_6(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##6 = (6 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##6 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##6); \
    Z##6 *= (CROSS_PLANE_PAD * STRIDE_Y);

#define CALCULATE_Z_OFFSET_8(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    CALCULATE_Z_OFFSET_7(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) \
    Z##7 = (7 + (DATA_TYPE)(Y)) / (DATA_TYPE)HEIGHT_GEMM3D; \
    Z##7 = min((DATA_TYPE)(DEPTH_GEMM3D - 1), Z##7); \
    Z##7 *= (CROSS_PLANE_PAD * STRIDE_Y);

/** @} */ // end of group CALCULATE_Z_OFFSET_n

/** Calculate Z offset values from Z0 to Zn-1
 * @name CALCULATE_Z_OFFSET
 *
 * The Z offsets are expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected names of the Z offsets are zin0, zin1 and zin2.
 * Note that CROSS_PLANE_PAD (cross-plane padding) is required to take into account
 * the possible cross-plane paddings in case of plane changes across the z-dimension.
 *
 * <!--
 * |                  |
 * |      plane0      |
 * |                  |
 * |__________________|
 * |******************|
 * |  cross_plane_pad |
 * |******************|
 * |                  |
 * |      plane1      |
 * |                  |
 * |__________________|
 * -->
 *
 * @param[in] M0              The number of offset values to calculate
 * @param[in] DATA_TYPE       The data type of the results
 * @param[in] Z               The basename of the result variables
 * @param[in] Y               The work-item ID of the y-axis
 * @param[in] HEIGHT_GEMM3D   The height of GEMM3D
 * @param[in] DEPTH_GEMM3D    The depth of GEMM3D
 * @param[in] CROSS_PLANE_PAD The padding required for plane changes across the z-dimension
 * @param[in] STRIDE_Y        The stride value in y-axis direction
 * @{
 */
#define CALCULATE_Z_OFFSET_STR(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) CALCULATE_Z_OFFSET_##M0(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)
#define CALCULATE_Z_OFFSET(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y) CALCULATE_Z_OFFSET_STR(M0, DATA_TYPE, Z, Y, HEIGHT_GEMM3D, DEPTH_GEMM3D, CROSS_PLANE_PAD, STRIDE_Y)
/** @} */ // end of group CALCULATE_Z_OFFSET
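
/* Usage sketch for CALCULATE_Z_OFFSET (illustrative only, not part of the original
 * header; zin0..zin3 must already be declared, and y is assumed to be the first row
 * index handled by this work-item):
 *
 *   uint zin0 = 0, zin1 = 0, zin2 = 0, zin3 = 0;
 *   CALCULATE_Z_OFFSET(4, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, pad, src_stride_y);
 *   // zin0..zin3 now hold the extra byte offsets that skip the cross-plane padding
 *   // whenever rows y+0..y+3 fall on a later GEMM3D plane.
 */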

/** Scale the rows in the given variables (BASENAME0 to BASENAMEn-1)
 * @name SCALE_ROW_n
 *
 * @param[in] DATA_TYPE The data type of the variables
 * @param[in] BASENAME  The basename of the variables
 * @param[in] SCALE     The scale factor
 * @{
 */
#define SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##0 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_1(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##1 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_2(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##2 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_3(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##3 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_4(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##4 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_5(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##5 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_6(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##6 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_7(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##7 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_8(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##8 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_9(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##9 *= (DATA_TYPE)SCALE;

#define SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_10(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##A *= (DATA_TYPE)SCALE;

#define SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_11(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##B *= (DATA_TYPE)SCALE;

#define SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_12(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##C *= (DATA_TYPE)SCALE;

#define SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_13(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##D *= (DATA_TYPE)SCALE;

#define SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_14(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##E *= (DATA_TYPE)SCALE;

#define SCALE_ROW_16(DATA_TYPE, BASENAME, SCALE) \
    SCALE_ROW_15(DATA_TYPE, BASENAME, SCALE) \
    BASENAME##F *= (DATA_TYPE)SCALE;
/** @} */ // end of group SCALE_ROW_n

/** Scale elements stored in a block (BASENAME)
 * @name SCALE_BLOCK
 *
 * Supported cases are N=1,2,3,...,16
 *
 * @param[in] N         The number of rows in the block
 * @param[in] DATA_TYPE The data type of the block
 * @param[in] BASENAME  The basename of the block
 * @param[in] SCALE     The scale factor
 * @{
 */
#define SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE) SCALE_ROW_##N(DATA_TYPE, BASENAME, SCALE)
#define SCALE_BLOCK(N, DATA_TYPE, BASENAME, SCALE) SCALE_BLOCK_STR(N, DATA_TYPE, BASENAME, SCALE)
/** @} */ // end of group SCALE_BLOCK

/** Create a new vector containing the values at the given index for a set of given vectors
 * @name COLUMN_VECTORn
 *
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] X        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 * @{
 */
#define COLUMN_VECTOR1(IDX_COL, BASENAME, X, TYPE) \
    TYPE BASENAME##IDX_COL = (TYPE)((X##0).s##IDX_COL);
#define COLUMN_VECTOR2(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 2) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 2))((X##0).s##IDX_COL, (X##1).s##IDX_COL);
#define COLUMN_VECTOR3(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 3) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 3))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL);
#define COLUMN_VECTOR4(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 4) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 4))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL);
#define COLUMN_VECTOR8(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 8) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 8))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL);
#define COLUMN_VECTOR16(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 16) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 16))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL, (X##8).s##IDX_COL, (X##9).s##IDX_COL, (X##A).s##IDX_COL, (X##B).s##IDX_COL, (X##C).s##IDX_COL, (X##D).s##IDX_COL, (X##E).s##IDX_COL, (X##F).s##IDX_COL);
/** @} */ // end of group COLUMN_VECTORn

/** Create a new vector containing the values at the given index. Utility macros for transposing a column-vector
 * @name COLUMN_VECTOR_SCALARn
 *
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] X        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 * @{
 */
#define COLUMN_VECTOR_SCALAR1(IDX_COL, BASENAME, X, TYPE) \
    TYPE BASENAME##IDX_COL = (TYPE)((X##0));
#define COLUMN_VECTOR_SCALAR2(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 2) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 2))((X##0), (X##1));
#define COLUMN_VECTOR_SCALAR3(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 3) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 3))((X##0), (X##1), (X##2));
#define COLUMN_VECTOR_SCALAR4(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 4) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 4))((X##0), (X##1), (X##2), (X##3));
#define COLUMN_VECTOR_SCALAR8(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 8) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 8))((X##0), (X##1), (X##2), (X##3), (X##4), (X##5), (X##6), (X##7));
#define COLUMN_VECTOR_SCALAR16(IDX_COL, BASENAME, X, TYPE) \
    VEC_DATA_TYPE(TYPE, 16) \
    BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 16))((X##0), (X##1), (X##2), (X##3), (X##4), (X##5), (X##6), (X##7), (X##8), (X##9), (X##A), (X##B), (X##C), (X##D), (X##E), (X##F));
/** @} */ // end of group COLUMN_VECTOR_SCALARn

/** Create transposed vectors of the given vectors
 * @name TRANSPOSE_K0Xn
 *
 * @param[in] K0       The size of the source vectors
 * @param[in] BASENAME The basename of transposed vectors
 * @param[in] B        The basename of source vectors for transposition
 * @param[in] TYPE     The data type of the transposed vectors
 * @{
 */
#define TRANSPOSE_K0X1(K0, BASENAME, B, TYPE) \
    COLUMN_VECTOR_SCALAR(K0, 0, BASENAME, B, TYPE);
#define TRANSPOSE_K0X2(K0, BASENAME, B, TYPE) \
    COLUMN_VECTOR(K0, 0, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 1, BASENAME, B, TYPE);
#define TRANSPOSE_K0X3(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X2(K0, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 2, BASENAME, B, TYPE);
#define TRANSPOSE_K0X4(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X3(K0, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 3, BASENAME, B, TYPE);
#define TRANSPOSE_K0X8(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X4(K0, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 4, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 5, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 6, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 7, BASENAME, B, TYPE);
#define TRANSPOSE_K0X16(K0, BASENAME, B, TYPE) \
    TRANSPOSE_K0X8(K0, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 8, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, 9, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, A, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, B, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, C, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, D, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, E, BASENAME, B, TYPE); \
    COLUMN_VECTOR(K0, F, BASENAME, B, TYPE);

/** @} */ // end of group TRANSPOSE_K0Xn

/** Create column vectors to contain the values at the given index for a set of given vectors
 *
 * @param[in] K0       The number of source vectors
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] B        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 */
#define COLUMN_VECTOR(K0, IDX_COL, BASENAME, B, TYPE) \
    CONCAT(COLUMN_VECTOR, K0) \
    (IDX_COL, BASENAME, B, TYPE);

/** Create column vectors to contain the values at the given index. Utility macro for transposing a column-vector
 *
 * @param[in] K0       The number of source vectors
 * @param[in] IDX_COL  The index value
 * @param[in] BASENAME The basename of the destination vectors
 * @param[in] B        The basename of the source vectors
 * @param[in] TYPE     The data type of the destination vectors
 */
#define COLUMN_VECTOR_SCALAR(K0, IDX_COL, BASENAME, B, TYPE) \
    CONCAT(COLUMN_VECTOR_SCALAR, K0) \
    (IDX_COL, BASENAME, B, TYPE);

/** Create transposed vectors from the given source vectors
 *
 * @param[in] K0       The size of source vectors
 * @param[in] N0       The number of source vectors
 * @param[in] BASENAME The basename of transposed vectors
 * @param[in] B        The basename of source vectors for transposition
 * @param[in] TYPE     The data type of the transposed vectors
 *
 */
#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B, TYPE) \
    CONCAT(TRANSPOSE_K0X, N0) \
    (K0, BASENAME, B, TYPE);
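
/* Usage sketch (illustrative only): transpose a block held in two float4 vectors
 * b0 and b1 into four float2 column vectors a0..a3:
 *
 *   TRANSPOSE_K0XN0(2, 4, a, b, float);
 *   // declares float2 a0, a1, a2, a3, where ai = (float2)(b0.si, b1.si)
 */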

/** Add the variables (BIAS0 to BIASn-1) to the others (BASENAME0 to BASENAMEn-1)
 * @name ADD_ROW_n
 *
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The basename of the added variables
 * @{
 */
#define ADD_ROW_1(BASENAME, BIAS) \
    BASENAME##0 += BIAS##0;

#define ADD_ROW_2(BASENAME, BIAS) \
    ADD_ROW_1(BASENAME, BIAS) \
    BASENAME##1 += BIAS##1;

#define ADD_ROW_3(BASENAME, BIAS) \
    ADD_ROW_2(BASENAME, BIAS) \
    BASENAME##2 += BIAS##2;

#define ADD_ROW_4(BASENAME, BIAS) \
    ADD_ROW_3(BASENAME, BIAS) \
    BASENAME##3 += BIAS##3;

#define ADD_ROW_5(BASENAME, BIAS) \
    ADD_ROW_4(BASENAME, BIAS) \
    BASENAME##4 += BIAS##4;

#define ADD_ROW_6(BASENAME, BIAS) \
    ADD_ROW_5(BASENAME, BIAS) \
    BASENAME##5 += BIAS##5;

#define ADD_ROW_7(BASENAME, BIAS) \
    ADD_ROW_6(BASENAME, BIAS) \
    BASENAME##6 += BIAS##6;

#define ADD_ROW_8(BASENAME, BIAS) \
    ADD_ROW_7(BASENAME, BIAS) \
    BASENAME##7 += BIAS##7;

#define ADD_ROW_9(BASENAME, BIAS) \
    ADD_ROW_8(BASENAME, BIAS) \
    BASENAME##8 += BIAS##8;

#define ADD_ROW_10(BASENAME, BIAS) \
    ADD_ROW_9(BASENAME, BIAS) \
    BASENAME##9 += BIAS##9;

#define ADD_ROW_11(BASENAME, BIAS) \
    ADD_ROW_10(BASENAME, BIAS) \
    BASENAME##A += BIAS##A;

#define ADD_ROW_12(BASENAME, BIAS) \
    ADD_ROW_11(BASENAME, BIAS) \
    BASENAME##B += BIAS##B;

#define ADD_ROW_13(BASENAME, BIAS) \
    ADD_ROW_12(BASENAME, BIAS) \
    BASENAME##C += BIAS##C;

#define ADD_ROW_14(BASENAME, BIAS) \
    ADD_ROW_13(BASENAME, BIAS) \
    BASENAME##D += BIAS##D;

#define ADD_ROW_15(BASENAME, BIAS) \
    ADD_ROW_14(BASENAME, BIAS) \
    BASENAME##E += BIAS##E;

#define ADD_ROW_16(BASENAME, BIAS) \
    ADD_ROW_15(BASENAME, BIAS) \
    BASENAME##F += BIAS##F;

/** @} */ // end of group ADD_ROW_n

/** Add the block (BIAS) to another block (BASENAME)
 * @name ADD_BLOCK
 *
 * Supported cases are N=1,2,3,...,16
 *
 * @param[in] N        The number of vectors in the block
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The basename of the added variables
 * @{
 */
#define ADD_BLOCK_STR(N, BASENAME, BIAS) ADD_ROW_##N(BASENAME, BIAS)
#define ADD_BLOCK(N, BASENAME, BIAS) ADD_BLOCK_STR(N, BASENAME, BIAS)
/** @} */ // end of group ADD_BLOCK
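
/* Usage sketch (illustrative only): with N=3, ADD_BLOCK(3, c, bias) expands to
 * c0 += bias0; c1 += bias1; c2 += bias2; */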

/** Broadcast (add a single value) to each element of the destination variables
 * @name ADD_ROW_BROADCAST_n
 *
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The variable containing the value to add
 * @{
 */
#define ADD_ROW_BROADCAST_1(BASENAME, BIAS) \
    BASENAME##0 += BIAS;

#define ADD_ROW_BROADCAST_2(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_1(BASENAME, BIAS) \
    BASENAME##1 += BIAS;

#define ADD_ROW_BROADCAST_3(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_2(BASENAME, BIAS) \
    BASENAME##2 += BIAS;

#define ADD_ROW_BROADCAST_4(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_3(BASENAME, BIAS) \
    BASENAME##3 += BIAS;

#define ADD_ROW_BROADCAST_5(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_4(BASENAME, BIAS) \
    BASENAME##4 += BIAS;

#define ADD_ROW_BROADCAST_6(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_5(BASENAME, BIAS) \
    BASENAME##5 += BIAS;

#define ADD_ROW_BROADCAST_7(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_6(BASENAME, BIAS) \
    BASENAME##6 += BIAS;

#define ADD_ROW_BROADCAST_8(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_7(BASENAME, BIAS) \
    BASENAME##7 += BIAS;

#define ADD_ROW_BROADCAST_9(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_8(BASENAME, BIAS) \
    BASENAME##8 += BIAS;

#define ADD_ROW_BROADCAST_10(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_9(BASENAME, BIAS) \
    BASENAME##9 += BIAS;

#define ADD_ROW_BROADCAST_11(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_10(BASENAME, BIAS) \
    BASENAME##A += BIAS;

#define ADD_ROW_BROADCAST_12(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_11(BASENAME, BIAS) \
    BASENAME##B += BIAS;

#define ADD_ROW_BROADCAST_13(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_12(BASENAME, BIAS) \
    BASENAME##C += BIAS;

#define ADD_ROW_BROADCAST_14(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_13(BASENAME, BIAS) \
    BASENAME##D += BIAS;

#define ADD_ROW_BROADCAST_15(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_14(BASENAME, BIAS) \
    BASENAME##E += BIAS;

#define ADD_ROW_BROADCAST_16(BASENAME, BIAS) \
    ADD_ROW_BROADCAST_15(BASENAME, BIAS) \
    BASENAME##F += BIAS;
/** @} */ // end of group ADD_ROW_BROADCAST_n

/** Broadcast (add a value) to each element of the destination block (BASENAME)
 * @name ADD_BLOCK_BROADCAST
 *
 * Supported cases are N=1,2,3,...,16.
 *
 * @param[in] N        The number of vectors in the block
 * @param[in] BASENAME The basename of the destination variables
 * @param[in] BIAS     The variable containing the value to add
 * @{
 */
#define ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS) ADD_ROW_BROADCAST_##N(BASENAME, BIAS)
#define ADD_BLOCK_BROADCAST(N, BASENAME, BIAS) ADD_BLOCK_BROADCAST_STR(N, BASENAME, BIAS)
/** @} */ // end of group ADD_BLOCK_BROADCAST
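
/* Usage sketch (illustrative only): with N=3 and a single addend b,
 * ADD_BLOCK_BROADCAST(3, c, b) expands to c0 += b; c1 += b; c2 += b; */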

/** Apply activation to the given variables
 * @name ACTIVATION_ROW_n
 *
 * @param[in] ACTIVATION_TYPE The type of the activation
 * @param[in] DATA_TYPE       The data type of the vectors
 * @param[in] VEC_SIZE        The vector size
 * @param[in] BASENAME        The basename of the variables
 * @param[in] A_VAL           Additional value required by the activation
 * @param[in] B_VAL           Additional value required by the activation
 * @{
 */
#define ACTIVATION_ROW_1(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##0, A_VAL, B_VAL);

#define ACTIVATION_ROW_2(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_1(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##1 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##1, A_VAL, B_VAL);

#define ACTIVATION_ROW_3(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_2(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##2 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##2, A_VAL, B_VAL);

#define ACTIVATION_ROW_4(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_3(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##3 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##3, A_VAL, B_VAL);

#define ACTIVATION_ROW_5(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_4(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##4 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##4, A_VAL, B_VAL);

#define ACTIVATION_ROW_6(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_5(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##5 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##5, A_VAL, B_VAL);

#define ACTIVATION_ROW_7(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_6(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##6 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##6, A_VAL, B_VAL);

#define ACTIVATION_ROW_8(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_7(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##7 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##7, A_VAL, B_VAL);

#define ACTIVATION_ROW_9(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_8(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##8 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##8, A_VAL, B_VAL);

#define ACTIVATION_ROW_10(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_9(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##9 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##9, A_VAL, B_VAL);

#define ACTIVATION_ROW_11(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_10(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##A = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##A, A_VAL, B_VAL);

#define ACTIVATION_ROW_12(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_11(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##B = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##B, A_VAL, B_VAL);

#define ACTIVATION_ROW_13(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_12(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##C = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##C, A_VAL, B_VAL);

#define ACTIVATION_ROW_14(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_13(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##D = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##D, A_VAL, B_VAL);

#define ACTIVATION_ROW_15(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_14(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##E = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##E, A_VAL, B_VAL);

#define ACTIVATION_ROW_16(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    ACTIVATION_ROW_15(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) \
    BASENAME##F = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME##F, A_VAL, B_VAL);
/** @} */ // end of group ACTIVATION_ROW_n

/** Apply activation to a block (BASENAME)
 * @name ACTIVATION_BLOCK
 *
 * Supported cases are N=1,2,3,...,16.
 *
 * @param[in] N               The number of vectors in the block
 * @param[in] ACTIVATION_TYPE The type of the activation
 * @param[in] DATA_TYPE       The data type of the vectors
 * @param[in] VEC_SIZE        The vector size
 * @param[in] BASENAME        The basename of the variables
 * @param[in] A_VAL           Additional value required by the activation
 * @param[in] B_VAL           Additional value required by the activation
 * @{
 */
#define ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) ACTIVATION_ROW_##N(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)
#define ACTIVATION_BLOCK(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL) ACTIVATION_BLOCK_STR(N, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, BASENAME, A_VAL, B_VAL)
/** @} */ // end of group ACTIVATION_BLOCK
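
/* Usage sketch (illustrative only; ACTIVATION_TYPE, A_VAL and B_VAL are assumed to
 * come from the kernel build options, and c0..c2 to be float4 accumulators):
 *
 *   ACTIVATION_BLOCK(3, ACTIVATION_TYPE, float, 4, c, A_VAL, B_VAL);
 */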

/** Apply convert_<data_type> to the given variables
 * @name CONVERT_ROW_n
 *
 * @param[in] N            The size of the vectors
 * @param[in] DATA_TYPE    The data type of the vectors
 * @param[in] BASENAME_SRC The basename of the source variables
 * @param[in] BASENAME_DST The basename of the destination variables
 * @{
 */
#define CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##0 = CONVERT(BASENAME_SRC##0, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_1(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##1 = CONVERT(BASENAME_SRC##1, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_2(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##2 = CONVERT(BASENAME_SRC##2, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_3(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##3 = CONVERT(BASENAME_SRC##3, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_4(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##4 = CONVERT(BASENAME_SRC##4, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_5(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##5 = CONVERT(BASENAME_SRC##5, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_6(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##6 = CONVERT(BASENAME_SRC##6, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_7(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##7 = CONVERT(BASENAME_SRC##7, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_8(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##8 = CONVERT(BASENAME_SRC##8, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_9(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##9 = CONVERT(BASENAME_SRC##9, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_10(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##A = CONVERT(BASENAME_SRC##A, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_11(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##B = CONVERT(BASENAME_SRC##B, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_12(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##C = CONVERT(BASENAME_SRC##C, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_13(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##D = CONVERT(BASENAME_SRC##D, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_14(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##E = CONVERT(BASENAME_SRC##E, VEC_DATA_TYPE(DATA_TYPE, N));

#define CONVERT_ROW_16(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    CONVERT_ROW_15(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) \
    VEC_DATA_TYPE(DATA_TYPE, N) \
    BASENAME_DST##F = CONVERT(BASENAME_SRC##F, VEC_DATA_TYPE(DATA_TYPE, N));
/** @} */ // end of group CONVERT_ROW_n

/** Apply convert_<data_type> to a block (BASENAME_SRC) and save to another block (BASENAME_DST)
 * @name CONVERT_BLOCK
 *
 * Supported cases are M=1,2,3,...,16.
 *
 * @param[in] M            The number of vectors to convert
 * @param[in] N            The size of the vectors
 * @param[in] DATA_TYPE    The data type of the vectors
 * @param[in] BASENAME_SRC The basename of the source variables
 * @param[in] BASENAME_DST The basename of the destination variables
 * @{
 */
#define CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_ROW_##M(N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)
#define CONVERT_BLOCK(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST) CONVERT_BLOCK_STR(M, N, DATA_TYPE, BASENAME_SRC, BASENAME_DST)
/** @} */ // end of group CONVERT_BLOCK
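
/* Usage sketch (illustrative only): convert three float4 accumulators c0..c2 into
 * three int4 vectors c_int0..c_int2:
 *
 *   CONVERT_BLOCK(3, 4, int, c, c_int);
 */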

/*
 * Copyright (c) 2019-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_REPEAT_H
#define ARM_COMPUTE_REPEAT_H

/*
 * Copyright (c) 2016-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_HELPER_H
#define ARM_COMPUTE_HELPER_H

/*
 * Copyright (c) 2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

/** Store the 0 to (n-1)th rows of the given variables
 * @name STORE_ROW_n
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_n

/** Convert and store the 0th to (n-1)th rows of the given variables
 * @name CONVERT_STORE_ROW_n
 *
 * @param[in] N0        The size of the vectors
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##0), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_1(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##1), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_2(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##2), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_3(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##3), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_4(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##4), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_5(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##5), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_6(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##6), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_7(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##7), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_8(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##8), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_9(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##9), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_10(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##A), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_11(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##B), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_12(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##C), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_13(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##D), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_14(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##E), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define CONVERT_STORE_ROW_16(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    CONVERT_STORE_ROW_15(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE(N0) \
    (CONVERT_SAT((BASENAME##F), VEC_DATA_TYPE(DATA_TYPE, N0)), 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));

/** @} */ // end of group CONVERT_STORE_ROW_n

/** Store a block of the given size M0xN0
 * @name STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group STORE_BLOCK
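
/* Usage sketch (illustrative only; dst_addr, dst_stride_y and the zout0..zout3
 * offsets are assumed to be set up by the surrounding kernel):
 *
 *   STORE_BLOCK(4, 4, float, c, dst_addr, dst_stride_y, zout);
 *   // stores float4 c0..c3 to four consecutive rows starting at dst_addr
 */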

/** Convert and store a block of the given size M0xN0
 * @name CONVERT_STORE_BLOCK
 *
 * Supported cases are M0=1,2,3,...,16 and N0=2,3,4,8,16.
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0        The number of rows to store
 * @param[in] N0        The size of each vector
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_ROW_##M0(N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) CONVERT_STORE_BLOCK_STR(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** @} */ // end of group CONVERT_STORE_BLOCK

/** Partially store the 0 to (n-1)th rows of the given variables
 * @name STORE_ROW_PARTIAL_n
 * Within each row, store the lower @p STORE_N0 elements of vectors of width @p N0
 *
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * @param[in] N0        The width of the passed in vector. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] STORE_N0  The **lower** size of the vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##0, 0, (__global DATA_TYPE *)(PTR + 0 * STRIDE_Y + Z##0));

#define STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_1(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##1, 0, (__global DATA_TYPE *)(PTR + 1 * STRIDE_Y + Z##1));

#define STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_2(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##2, 0, (__global DATA_TYPE *)(PTR + 2 * STRIDE_Y + Z##2));

#define STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_3(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##3, 0, (__global DATA_TYPE *)(PTR + 3 * STRIDE_Y + Z##3));

#define STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_4(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##4, 0, (__global DATA_TYPE *)(PTR + 4 * STRIDE_Y + Z##4));

#define STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_5(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##5, 0, (__global DATA_TYPE *)(PTR + 5 * STRIDE_Y + Z##5));

#define STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_6(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##6, 0, (__global DATA_TYPE *)(PTR + 6 * STRIDE_Y + Z##6));

#define STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_7(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##7, 0, (__global DATA_TYPE *)(PTR + 7 * STRIDE_Y + Z##7));

#define STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_8(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##8, 0, (__global DATA_TYPE *)(PTR + 8 * STRIDE_Y + Z##8));

#define STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_9(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##9, 0, (__global DATA_TYPE *)(PTR + 9 * STRIDE_Y + Z##9));

#define STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_10(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##A, 0, (__global DATA_TYPE *)(PTR + 10 * STRIDE_Y + Z##A));

#define STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_11(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##B, 0, (__global DATA_TYPE *)(PTR + 11 * STRIDE_Y + Z##B));

#define STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_12(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##C, 0, (__global DATA_TYPE *)(PTR + 12 * STRIDE_Y + Z##C));

#define STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_13(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##D, 0, (__global DATA_TYPE *)(PTR + 13 * STRIDE_Y + Z##D));

#define STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_14(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##E, 0, (__global DATA_TYPE *)(PTR + 14 * STRIDE_Y + Z##E));

#define STORE_ROW_PARTIAL_16(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    STORE_ROW_PARTIAL_15(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) \
    VSTORE_PARTIAL(N0, STORE_N0) \
    (BASENAME##F, 0, (__global DATA_TYPE *)(PTR + 15 * STRIDE_Y + Z##F));
/** @} */ // end of group STORE_ROW_PARTIAL_n

/** Partially store a block of the given size STORE_M0xSTORE_N0
 * @name STORE_BLOCK_PARTIAL
 *
 * @note The vector width @p N0 is also required for correct partial storing behaviour.
 * @note in case @p STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for STORE_M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for STORE_M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] STORE_M0  The number of rows to store. Supported: 1-16
 * @param[in] STORE_N0  The lower number of elements of vectors to store. Supported: 1-16 and <= @p N0
 * @param[in] N0        The size of each vector.
Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE The data type of the vectors
 * @param[in] BASENAME  The basename of the variables
 * @param[in] PTR       The base pointer
 * @param[in] STRIDE_Y  The stride value in y-axis direction
 * @param[in] Z         The offset in z-axis direction
 * @{
 */
#define STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_ROW_PARTIAL_##STORE_M0(N0, STORE_N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
#define STORE_BLOCK_PARTIAL(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z) STORE_BLOCK_PARTIAL_STR(STORE_M0, STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)
/** Store a block that can be partial in both x and y dimensions
 *
 * @note in cases @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0                The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0                The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE         The data type of the vectors
 * @param[in] BASENAME          The basename of the variables
 * @param[in] PTR               The base pointer
 * @param[in] STRIDE_Y          The stride value in y-axis direction
 * @param[in] Z                 The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0  The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_STORE_N0  The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_Y    Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X    Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** Store a block that can only be partial in x but not y.
 *
 * @note in case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0                The number of rows to store, for non-partial blocks.
/** Store a block that can be partial in both x and y dimensions
 *
 * @note In case @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X) && !(PARTIAL_COND_Y)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else if((PARTIAL_COND_Y) && !(PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else if(!(PARTIAL_COND_Y) && (PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** Store a block that can only be partial in x but not y.
 *
 * @note In case @p N0 or @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported range: [1, @p N0)
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 */
#define STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X) \
    if(!(PARTIAL_COND_X)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(M0, PARTIAL_STORE_N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** Store a block that can only be partial in y but not x.
 *
 * @note In case @p N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported range: [1, @p M0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 */
#define STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y) \
    if(!(PARTIAL_COND_Y)) \
    { \
        STORE_BLOCK_PARTIAL(M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    } \
    else \
    { \
        STORE_BLOCK_PARTIAL(PARTIAL_STORE_M0, N0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z); \
    }
/** @} */ // end of group STORE_BLOCK_PARTIAL
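
// The branches of STORE_BLOCK_PARTIAL_IN_X_AND_Y above cover the four
// combinations (full, partial in y, partial in x, partial in both). For
// instance, with PARTIAL_COND_Y true and PARTIAL_COND_X false, only
// PARTIAL_STORE_M0 rows of full N0-wide vectors are stored (the conditions
// named here are illustrative).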

#if defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

/** Boundary-aware GEMM block store
 * @name STORE_BLOCK_BOUNDARY_AWARE
 * This macro assumes the following schemes to achieve boundary-awareness:
 * - Overlapping load in Y axis from lhs tensor. This implies lhs has no padding along y dim.
 * - Non-Overlapping (normal) load from rhs tensor. This implies rhs can have paddings.
 * - Overlapping load in Y axis from bias tensor. This implies bias has no padding along y dim.
 * The macro then ensures that the dst tensor can be stored without any paddings in both x and y dim.
 *
 * In the y dimension, we place the partial blocks **at the beginning**, while in the x dimension, we place the partial
 * blocks **at the end**.
 * Say the dst tensor is of shape MxN and we have M0 and N0 as the block size; this is how we define "partial blocks" /
 * "boundary blocks" (we use the two terms interchangeably) and their various parameters:
 *
 *     *--x-->                       x == 0                        x == 1
 *     |            |<------------------------------N-------------------------->|
 *     y            |<--------------N0------------->|<----PARTIAL_STORE_N0----->|
 *     |     --------#############################################################
 *     *     |      ||...............................|...........................|
 *  y == 0   | PAR_..._M0 |......Boundary block in y......|.Boundary block in x and y.|
 *           |      ||...............................|...........................|
 *     M     --------#############################################################
 *           |      ||                               |...........................|
 *  y == 1   |  M0  ||      Non-boundary block       |....Boundary block in x....|
 *           |      ||                               |...........................|
 *           |-------#############################################################
 *
 * Then @p PARTIAL_STORE_M0 = M % M0 and @p PARTIAL_STORE_N0 = N % N0
 *
 * @note In case @p PARTIAL_STORE_N0 != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * It automatically detects if a given M,N,M0,N0 combination can yield partial blocks in either the X or Y dimension,
 * and selects the corresponding store methods such that the boundary detection logic is only added when needed.
 *
 * The data to store is expected to have consecutive names for each row.
 * E.g., for M0=3 and basename=c, the expected names are c0, c1 and c2.
 * The Z offset is expected to have consecutive names.
 * E.g., for M0=3 and Z=zin, the expected z offset names are zin0, zin1 and zin2.
 *
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] N0               The size of each vector, for non-partial blocks. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] DATA_TYPE        The data type of the vectors
 * @param[in] BASENAME         The basename of the variables
 * @param[in] PTR              The base pointer
 * @param[in] STRIDE_Y         The stride value in y-axis direction
 * @param[in] Z                The offset in z-axis direction
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @param[in] PARTIAL_STORE_N0 The partial size in x, for partial blocks. Supported: [0, @p N0)
 * @param[in] PARTIAL_COND_Y   Condition on the y axis to perform the partial store Y. True to use PARTIAL_STORE_M0 rather than M0.
 * @param[in] PARTIAL_COND_X   Condition on the x axis to perform the partial store X. True to use PARTIAL_STORE_N0 rather than N0.
 * @{
 */
#if PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0
// Case1: No partial blocks in either x or y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z)

#elif PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 == 0
// Case2: Partial blocks in y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_COND_Y)

#elif PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 > 0
// Case3: Partial blocks in x
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_N0, PARTIAL_COND_X)

#else // PARTIAL_STORE_M0 > 0 && PARTIAL_STORE_N0 > 0
// Case4: Partial blocks in both x and y
#define STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X) \
    STORE_BLOCK_PARTIAL_IN_X_AND_Y(M0, N0, DATA_TYPE, BASENAME, PTR, STRIDE_Y, Z, PARTIAL_STORE_M0, PARTIAL_STORE_N0, PARTIAL_COND_Y, PARTIAL_COND_X)

#endif // PARTIAL_STORE_M0 == 0 && PARTIAL_STORE_N0 == 0

#endif // defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
/** @} */ // end of group STORE_BLOCK_BOUNDARY_AWARE
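
// Example of the compile-time dispatch above (illustrative values): for
// M=64, N=33, M0=4, N0=16 the host compiles with -DPARTIAL_STORE_M0=0 and
// -DPARTIAL_STORE_N0=1 (M % M0 = 0, N % N0 = 1), so Case3 is selected and
// only the x-boundary check is emitted into the kernel.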

#if defined(PARTIAL_STORE_M0)
/** Compute the start m0 row (LHS, BIAS and DST) in a boundary-aware way so as to avoid padding
 * @name COMPUTE_M0_START_ROW
 * If there are any partial blocks in the y dimension, they are placed at the beginning of the rows.
 * This shift amount is added to all rows such that the partial block (at the beginning) overlaps with the subsequent
 * blocks in the y dimension to avoid any padding.
 * E.g., M0=4, PARTIAL_STORE_M0=1:
 *                   | Non-overlapping | +M0_ROW_SHIFT (Overlapping)
 *  block 0 (partial)| start row = 0   | start row = 0
 *  block 1 (full)   | start row = 4   | start row = 1
 *  block 2 (full)   | start row = 8   | start row = 5
 *
 * @param[in] y                Global id of current block in y.
 * @param[in] M0               The number of rows to store, for non-partial blocks. Supported: 1-16
 * @param[in] PARTIAL_STORE_M0 The partial size in y, for partial blocks. Supported: [0, @p M0)
 * @{
 */
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(max(0, (int)(y * M0) - (int)((M0 - PARTIAL_STORE_M0) % M0))))
#else // defined(PARTIAL_STORE_M0)
#define COMPUTE_M0_START_ROW(y, M0, PARTIAL_STORE_M0) \
    ((uint)(y * M0))
#endif // defined(PARTIAL_STORE_M0)
/** @} */ // end of group COMPUTE_M0_START_ROW
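
// Worked example of the shift above: with M0=4 and PARTIAL_STORE_M0=1, the
// block at y=2 starts at row max(0, 2 * 4 - ((4 - 1) % 4)) = 5, matching the
// "block 2 ... start row = 5" entry in the table.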

/** Store a vector that can only be partial in x.
 *
 * @note In case @p vec_size or @p leftover != 1, 2, 3, 4, 8, 16, extra vstore(s) will be invoked, thus incurring a small performance penalty.
 *
 * The data to store is expected to end in a 0.
 * E.g., for basename=c, the expected name is c0.
 *
 * @param[in] basename  The name of the variable without trailing 0
 * @param[in] data_type The data type of the vector
 * @param[in] ptr       The base pointer
 * @param[in] vec_size  The vector size if cond = false. Supported: 1, 2, 3, 4, 8, 16
 * @param[in] leftover  The vector size if cond = true. Supported range: [1, @p vec_size)
 * @param[in] cond      Condition to select either @p vec_size or @p leftover
 * @{
 */
#define STORE_VECTOR_SELECT(basename, data_type, ptr, vec_size, leftover, cond) \
    STORE_BLOCK_PARTIAL_IN_X(1, vec_size, data_type, basename, ptr, 0, 0, leftover, cond)
/** @} */ // end of group STORE_VECTOR_SELECT

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)

#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#pragma OPENCL EXTENSION cl_arm_integer_dot_product_accumulate_int8 : enable
#endif // defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)

#if defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)
#pragma OPENCL EXTENSION cl_arm_printf : enable
#endif // defined(ARM_COMPUTE_DEBUG_ENABLED) && defined(cl_arm_printf)

#define GPU_ARCH_MIDGARD 0x100
#define GPU_ARCH_BIFROST 0x200

/** Concatenate two inputs.
 *
 * @param[in] a The first input to be concatenated
 * @param[in] b The second input to be concatenated
 *
 * @return The concatenated output
 */
#define CONCAT(a, b) a##b

/** Expand the given vector
 *
 * @param[in] x The vector to be expanded
 *
 * @return The expanded output
 */
#define EXPAND(x) x

/** Clamp the given value between an upper and lower bound.
 *
 * @param[in] x       The value to be clamped
 * @param[in] min_val The lower bound
 * @param[in] max_val The upper bound
 *
 * @return The clamped value.
 */
#define CLAMP(x, min_val, max_val) min(max(x, min_val), max_val)

/** REVn reverses the given vector whose size is n.
 * @name REVn
 *
 * @param[in] x The vector to be reversed
 *
 * @return The reversed vector
 * @{
 */
#define REV1(x) ((x))
#define REV2(x) ((x).s10)
#define REV3(x) ((x).s210)
#define REV4(x) ((x).s3210)
#define REV8(x) ((x).s76543210)
#define REV16(x) ((x).sFEDCBA9876543210)
/** @} */ // end of group REVn

/** Reverse the given vector.
 * @name REVERSE
 *
 * @param[in] x The vector to be reversed
 * @param[in] s The size of the vector
 *
 * @return The reversed vector
 * @{
 */
#define REVERSE_STR(x, s) REV##s((x))
#define REVERSE(x, s) REVERSE_STR(x, s)
/** @} */ // end of group REVERSE
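
// For example, REVERSE(v, 4) resolves to REV4(v), i.e. the swizzle (v).s3210.
// The extra _STR level forces an argument such as s=N0 to be expanded to its
// value before token pasting.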

/** Circular-right-shift (rotate-right) the vector of size s by the amount of n.
 * @name ROTs_n
 *
 * @param[in] x The vector to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROT1_0(x) ((x))

#define ROT2_0(x) ((x))
#define ROT2_1(x) ((x).s10)

#define ROT3_0(x) ((x))
#define ROT3_1(x) ((x).s201)
#define ROT3_2(x) ((x).s120)

#define ROT4_0(x) ((x))
#define ROT4_1(x) ((x).s3012)
#define ROT4_2(x) ((x).s2301)
#define ROT4_3(x) ((x).s1230)

#define ROT8_0(x) ((x))
#define ROT8_1(x) ((x).s70123456)
#define ROT8_2(x) ((x).s67012345)
#define ROT8_3(x) ((x).s56701234)
#define ROT8_4(x) ((x).s45670123)
#define ROT8_5(x) ((x).s34567012)
#define ROT8_6(x) ((x).s23456701)
#define ROT8_7(x) ((x).s12345670)

#define ROT16_0(x) ((x))
#define ROT16_1(x) ((x).sF0123456789ABCDE)
#define ROT16_2(x) ((x).sEF0123456789ABCD)
#define ROT16_3(x) ((x).sDEF0123456789ABC)
#define ROT16_4(x) ((x).sCDEF0123456789AB)
#define ROT16_5(x) ((x).sBCDEF0123456789A)
#define ROT16_6(x) ((x).sABCDEF0123456789)
#define ROT16_7(x) ((x).s9ABCDEF012345678)
#define ROT16_8(x) ((x).s89ABCDEF01234567)
#define ROT16_9(x) ((x).s789ABCDEF0123456)
#define ROT16_10(x) ((x).s6789ABCDEF012345)
#define ROT16_11(x) ((x).s56789ABCDEF01234)
#define ROT16_12(x) ((x).s456789ABCDEF0123)
#define ROT16_13(x) ((x).s3456789ABCDEF012)
#define ROT16_14(x) ((x).s23456789ABCDEF01)
#define ROT16_15(x) ((x).s123456789ABCDEF0)
/** @} */ // end of group ROTs_n

/** Circular-right-shift (rotate-right) the given vector by the given amount.
 * @name ROTATE
 *
 * @param[in] x The vector to be shifted
 * @param[in] s The size of the vector
 * @param[in] n The amount to be shifted
 *
 * @return The shifted vector
 * @{
 */
#define ROTATE_STR(x, s, n) ROT##s##_##n(x)
#define ROTATE(x, s, n) ROTATE_STR(x, s, n)
/** @} */ // end of group ROTATE

/** Creates a vector of size n filled with offset values corresponding to the location of each element.
 * @name V_OFFSn
 *
 * @param[in] dt The data type of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define V_OFFS1(dt) (dt##1)(0)
#define V_OFFS2(dt) (dt##2)(0, 1)
#define V_OFFS3(dt) (dt##3)(0, 1, 2)
#define V_OFFS4(dt) (dt##4)(0, 1, 2, 3)
#define V_OFFS8(dt) (dt##8)(0, 1, 2, 3, 4, 5, 6, 7)
#define V_OFFS16(dt) (dt##16)(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
/** @} */ // end of group V_OFFSn

/** Create a vector filled with offset values corresponding to the location of each element.
 * @name VEC_OFFS
 *
 * @param[in] dt The data type of the output vector
 * @param[in] s  The size of the output vector
 *
 * @return The vector filled with offset values
 * @{
 */
#define VEC_OFFS_STR(dt, s) V_OFFS##s(dt)
#define VEC_OFFS(dt, s) VEC_OFFS_STR(dt, s)
/** @} */ // end of group VEC_OFFS
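
// For example, VEC_OFFS(int, 4) yields (int4)(0, 1, 2, 3); a typical use
// (sketch, not from this file) is computing per-lane indices:
//   int4 idx = (int4)(base) + VEC_OFFS(int, 4);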

#define VLOAD_STR(size) vload##size
#define VLOAD(size) VLOAD_STR(size)

#define PIXEL_UNIT4 1
#define PIXEL_UNIT8 2
#define PIXEL_UNIT16 4

/** Utility macro to convert a vector size in pixel unit.
 *
 * @name CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT
 *
 * @param[in] vec_size Vector size. Only 4, 8 and 16 are supported
 *
 * @return The pixel unit (number of pixels)
 * @{
 */
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size) PIXEL_UNIT##vec_size
#define CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(vec_size) CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT_STR(vec_size)
/** @} */ // end of group CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT

#define read_image2d_floatx1(img, x_coord, y_coord) (float4)(read_imagef(img, (int2)(x_coord, y_coord)));
#define read_image2d_floatx2(img, x_coord, y_coord) (float8)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_floatx4(img, x_coord, y_coord) (float16)(read_imagef(img, (int2)(x_coord, y_coord)), read_imagef(img, (int2)(x_coord + 1, y_coord)), read_imagef(img, (int2)(x_coord + 2, y_coord)), read_imagef(img, (int2)(x_coord + 3, y_coord)));

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)
#define read_image2d_halfx1(img, x_coord, y_coord) (half4)(read_imageh(img, (int2)(x_coord, y_coord)));
#define read_image2d_halfx2(img, x_coord, y_coord) (half8)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)));
#define read_image2d_halfx4(img, x_coord, y_coord) (half16)(read_imageh(img, (int2)(x_coord, y_coord)), read_imageh(img, (int2)(x_coord + 1, y_coord)), read_imageh(img, (int2)(x_coord + 2, y_coord)), read_imageh(img, (int2)(x_coord + 3, y_coord)));
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) && defined(cl_khr_fp16)

/** Utility macro to read a 2D OpenCL image object.
 *
 * @note Coordinates are not normalized
 *
 * @param[in] data_type Data type
 * @param[in] n0        Number of pixels to read. Only 1, 2 and 4 are supported
 * @param[in] img       OpenCL image object
 * @param[in] x_coord   The x coordinate for the top-left pixel
 * @param[in] y_coord   The y coordinate for the top-left pixel
 *
 * @return Pixels from the 2D OpenCL image object
 * @{
 */
#define READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord) read_image2d_##data_type##x##n0(img, x_coord, y_coord)
#define READ_IMAGE2D(data_type, n0, img, x_coord, y_coord) READ_IMAGE2D_STR(data_type, n0, img, x_coord, y_coord)
/** @} */ // end of group READ_IMAGE2D

#define VSTORE_STR(size) vstore##size
#define VSTORE(size) VSTORE_STR(size)

#define float1 float
#define half1 half
#define char1 char
#define uchar1 uchar
#define short1 short
#define ushort1 ushort
#define int1 int
#define uint1 uint
#define long1 long
#define ulong1 ulong
#define double1 double

#define vload1(OFFSET, PTR) *(OFFSET + PTR)
#define vstore1(DATA, OFFSET, PTR) *(OFFSET + PTR) = DATA

/** Extended partial vstore that correctly handles scalar values as well.
 * Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops
 * @name VSTORE_PARTIAL
 *
 * @note With this macro, the passed data can be both a vector and a scalar
 * @note @p store_size needs to be <= @p size
 * eg 1: Valid
 * VSTORE_PARTIAL(16, 15) ...;
 * eg 2: Invalid
 * VSTORE_PARTIAL(4, 7) ...;
 *
 * @param[in] size       The width of @p DATA. Supported values: 1(scalar), 2, 3, 4, 8, 16
 * @param[in] store_size The number of lower elements to store.
Supported values: 1-16, but has to be <= @p size 5205 * @{ 5206 */ 5207#define VSTORE_PARTIAL_STR(size, store_size) vstore_partial_##size##_##store_size 5208#define VSTORE_PARTIAL(size, store_size) VSTORE_PARTIAL_STR(size, store_size) 5209 5210#define NO_STORE(data, offs, ptr) \ 5211 { \ 5212 } 5213 5214// Size == 1 (scalar) 5215#define vstore_partial_1_0 NO_STORE 5216#define vstore_partial_1_1 vstore1 5217#define vstore_partial_1_2 NO_STORE 5218#define vstore_partial_1_3 NO_STORE 5219#define vstore_partial_1_4 NO_STORE 5220#define vstore_partial_1_5 NO_STORE 5221#define vstore_partial_1_6 NO_STORE 5222#define vstore_partial_1_7 NO_STORE 5223#define vstore_partial_1_8 NO_STORE 5224#define vstore_partial_1_9 NO_STORE 5225#define vstore_partial_1_10 NO_STORE 5226#define vstore_partial_1_11 NO_STORE 5227#define vstore_partial_1_12 NO_STORE 5228#define vstore_partial_1_13 NO_STORE 5229#define vstore_partial_1_14 NO_STORE 5230#define vstore_partial_1_15 NO_STORE 5231#define vstore_partial_1_16 NO_STORE 5232// Size == 2 5233#define vstore_partial_2_0 NO_STORE 5234#define vstore_partial_2_1 vstore_partial_1 5235#define vstore_partial_2_2 vstore_partial_2 5236#define vstore_partial_2_3 NO_STORE 5237#define vstore_partial_2_4 NO_STORE 5238#define vstore_partial_2_5 NO_STORE 5239#define vstore_partial_2_6 NO_STORE 5240#define vstore_partial_2_7 NO_STORE 5241#define vstore_partial_2_8 NO_STORE 5242#define vstore_partial_2_9 NO_STORE 5243#define vstore_partial_2_10 NO_STORE 5244#define vstore_partial_2_11 NO_STORE 5245#define vstore_partial_2_12 NO_STORE 5246#define vstore_partial_2_13 NO_STORE 5247#define vstore_partial_2_14 NO_STORE 5248#define vstore_partial_2_15 NO_STORE 5249#define vstore_partial_2_16 NO_STORE 5250// Size == 3 5251#define vstore_partial_3_0 NO_STORE 5252#define vstore_partial_3_1 vstore_partial_1 5253#define vstore_partial_3_2 vstore_partial_2 5254#define vstore_partial_3_3 vstore_partial_3 5255#define vstore_partial_3_4 NO_STORE 5256#define vstore_partial_3_5 NO_STORE 5257#define vstore_partial_3_6 NO_STORE 5258#define vstore_partial_3_7 NO_STORE 5259#define vstore_partial_3_8 NO_STORE 5260#define vstore_partial_3_9 NO_STORE 5261#define vstore_partial_3_10 NO_STORE 5262#define vstore_partial_3_11 NO_STORE 5263#define vstore_partial_3_12 NO_STORE 5264#define vstore_partial_3_13 NO_STORE 5265#define vstore_partial_3_14 NO_STORE 5266#define vstore_partial_3_15 NO_STORE 5267#define vstore_partial_3_16 NO_STORE 5268// Size == 4 5269#define vstore_partial_4_0 NO_STORE 5270#define vstore_partial_4_1 vstore_partial_1 5271#define vstore_partial_4_2 vstore_partial_2 5272#define vstore_partial_4_3 vstore_partial_3 5273#define vstore_partial_4_4 vstore_partial_4 5274#define vstore_partial_4_5 NO_STORE 5275#define vstore_partial_4_6 NO_STORE 5276#define vstore_partial_4_7 NO_STORE 5277#define vstore_partial_4_8 NO_STORE 5278#define vstore_partial_4_9 NO_STORE 5279#define vstore_partial_4_10 NO_STORE 5280#define vstore_partial_4_11 NO_STORE 5281#define vstore_partial_4_12 NO_STORE 5282#define vstore_partial_4_13 NO_STORE 5283#define vstore_partial_4_14 NO_STORE 5284#define vstore_partial_4_15 NO_STORE 5285#define vstore_partial_4_16 NO_STORE 5286// Size == 8 5287#define vstore_partial_8_0 NO_STORE 5288#define vstore_partial_8_1 vstore_partial_1 5289#define vstore_partial_8_2 vstore_partial_2 5290#define vstore_partial_8_3 vstore_partial_3 5291#define vstore_partial_8_4 vstore_partial_4 5292#define vstore_partial_8_5 vstore_partial_5 5293#define vstore_partial_8_6 vstore_partial_6 5294#define 
vstore_partial_8_7 vstore_partial_7 5295#define vstore_partial_8_8 vstore_partial_8 5296#define vstore_partial_8_9 NO_STORE 5297#define vstore_partial_8_10 NO_STORE 5298#define vstore_partial_8_11 NO_STORE 5299#define vstore_partial_8_12 NO_STORE 5300#define vstore_partial_8_13 NO_STORE 5301#define vstore_partial_8_14 NO_STORE 5302#define vstore_partial_8_15 NO_STORE 5303#define vstore_partial_8_16 NO_STORE 5304// Size == 16 5305#define vstore_partial_16_0 NO_STORE 5306#define vstore_partial_16_1 vstore_partial_1 5307#define vstore_partial_16_2 vstore_partial_2 5308#define vstore_partial_16_3 vstore_partial_3 5309#define vstore_partial_16_4 vstore_partial_4 5310#define vstore_partial_16_5 vstore_partial_5 5311#define vstore_partial_16_6 vstore_partial_6 5312#define vstore_partial_16_7 vstore_partial_7 5313#define vstore_partial_16_8 vstore_partial_8 5314#define vstore_partial_16_9 vstore_partial_9 5315#define vstore_partial_16_10 vstore_partial_10 5316#define vstore_partial_16_11 vstore_partial_11 5317#define vstore_partial_16_12 vstore_partial_12 5318#define vstore_partial_16_13 vstore_partial_13 5319#define vstore_partial_16_14 vstore_partial_14 5320#define vstore_partial_16_15 vstore_partial_15 5321#define vstore_partial_16_16 vstore_partial_16 5322 5323/** Partial vstore. Store the **lower** 0 to (n-1)th elements of the given vector while minimising the amount of vstore ops 5324 * @name vstore_partial_n 5325 * 5326 * @note @p DATA needs to be a vector not a scalar 5327 * @note n needs to be <= the vector width of the input variable @p DATA 5328 * eg 1: Valid 5329 * vstore_partial_15(var:float16, 0, 0xabcd); 5330 * eg 2: Invalid 5331 * vstore_partial_7(var:float4, 0, 0xabcd); 5332 * 5333 * @note in cases n == 1, 2, 3, 4, 8, 16, no extra vstore is invoked, thus there's no performance penalty. 
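 *
 * For example, vstore_partial_5(v, 0, p) below decomposes into a vstore4 of
 * v.s0123 at p followed by a vstore1 of v.s4 at p + 4.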
 *
 * @param[in] DATA   The name of the variable
 * @param[in] OFFSET Offset in n
 * @param[in] PTR    The base pointer
 * @{
 */
#define vstore_partial_1(DATA, OFFSET, PTR) \
    vstore1(DATA.s0, OFFSET, PTR);

#define vstore_partial_2(DATA, OFFSET, PTR) \
    vstore2(DATA.s01, OFFSET, PTR);

#define vstore_partial_3(DATA, OFFSET, PTR) \
    vstore3(DATA.s012, OFFSET, PTR);

#define vstore_partial_4(DATA, OFFSET, PTR) \
    vstore4(DATA.s0123, OFFSET, PTR);

#define vstore_partial_5(DATA, OFFSET, PTR) \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore1(DATA.s4, OFFSET, PTR + 4);

#define vstore_partial_6(DATA, OFFSET, PTR) \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_2(DATA.s45, OFFSET, PTR + 4);

#define vstore_partial_7(DATA, OFFSET, PTR) \
    vstore_partial_4(DATA.s0123, OFFSET, PTR); \
    vstore_partial_3(DATA.s456, OFFSET, PTR + 4);

#define vstore_partial_8(DATA, OFFSET, PTR) \
    vstore8(DATA.s01234567, OFFSET, PTR);

#define vstore_partial_9(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore1(DATA.s8, OFFSET, PTR + 8);

#define vstore_partial_10(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_2(DATA.s89, OFFSET, PTR + 8);

#define vstore_partial_11(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_3(DATA.s89a, OFFSET, PTR + 8);

#define vstore_partial_12(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_4(DATA.s89ab, OFFSET, PTR + 8);

#define vstore_partial_13(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_5(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_14(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_6(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_15(DATA, OFFSET, PTR) \
    vstore_partial_8(DATA.s01234567, OFFSET, PTR); \
    vstore_partial_7(DATA.s89abcdef, OFFSET, PTR + 8);

#define vstore_partial_16(DATA, OFFSET, PTR) \
    vstore16(DATA, OFFSET, PTR);
/** @} */ // end of group vstore_partial_n
/** @} */ // end of group VSTORE_PARTIAL

// Convert built-in functions with the _sat modifier are not supported in floating point, so we create defines
// without _sat to overcome this issue
#define convert_float_sat convert_float
#define convert_float1_sat convert_float
#define convert_float2_sat convert_float2
#define convert_float3_sat convert_float3
#define convert_float4_sat convert_float4
#define convert_float8_sat convert_float8
#define convert_float16_sat convert_float16
#define convert_half_sat convert_float
#define convert_half1_sat convert_half
#define convert_half2_sat convert_half2
#define convert_half3_sat convert_half3
#define convert_half4_sat convert_half4
#define convert_half8_sat convert_half8
#define convert_half16_sat convert_half16

#define convert_float1 convert_float
#define convert_half1 convert_half
#define convert_char1 convert_char
#define convert_uchar1 convert_uchar
#define convert_short1 convert_short
#define convert_ushort1 convert_ushort
#define convert_int1 convert_int
#define convert_uint1 convert_uint
#define convert_long1 convert_long
#define
convert_ulong1 convert_ulong 5427#define convert_double1 convert_double 5428 5429#define convert_char1_sat convert_char_sat 5430#define convert_uchar1_sat convert_uchar_sat 5431#define convert_short1_sat convert_short_sat 5432#define convert_ushort1_sat convert_ushort_sat 5433#define convert_int1_sat convert_int_sat 5434#define convert_uint1_sat convert_uint_sat 5435#define convert_long1_sat convert_long_sat 5436#define convert_ulong1_sat convert_ulong_sat 5437#define convert_double1_sat convert_double_sat 5438 5439#define VEC_DATA_TYPE_STR(type, size) type##size 5440#define VEC_DATA_TYPE(type, size) VEC_DATA_TYPE_STR(type, size) 5441 5442#define CONVERT_STR(x, type) (convert_##type((x))) 5443#define CONVERT(x, type) CONVERT_STR(x, type) 5444 5445#define CONVERT_SAT_STR(x, type) (convert_##type##_sat((x))) 5446#define CONVERT_SAT(x, type) CONVERT_SAT_STR(x, type) 5447 5448#define CONVERT_SAT_ROUND_STR(x, type, round) (convert_##type##_sat_##round((x))) 5449#define CONVERT_SAT_ROUND(x, type, round) CONVERT_SAT_ROUND_STR(x, type, round) 5450 5451#define select_vec_dt_uchar(size) uchar##size 5452#define select_vec_dt_char(size) char##size 5453#define select_vec_dt_ushort(size) ushort##size 5454#define select_vec_dt_short(size) short##size 5455#define select_vec_dt_half(size) short##size 5456#define select_vec_dt_uint(size) uint##size 5457#define select_vec_dt_int(size) int##size 5458#define select_vec_dt_float(size) int##size 5459#define select_vec_dt_ulong(size) ulong##size 5460#define select_vec_dt_long(size) long##size 5461 5462#define SELECT_VEC_DATA_TYPE_STR(type, size) select_vec_dt_##type(size) 5463#define SELECT_VEC_DATA_TYPE(type, size) SELECT_VEC_DATA_TYPE_STR(type, size) 5464#define SELECT_DATA_TYPE(type) SELECT_VEC_DATA_TYPE_STR(type, 1) 5465 5466#define sum_reduce_1(x) (x) 5467#define sum_reduce_2(x) ((x).s0) + ((x).s1) 5468#define sum_reduce_3(x) sum_reduce_2((x).s01) + ((x).s2) 5469#define sum_reduce_4(x) sum_reduce_2((x).s01) + sum_reduce_2((x).s23) 5470#define sum_reduce_8(x) sum_reduce_4((x).s0123) + sum_reduce_4((x).s4567) 5471#define sum_reduce_16(x) sum_reduce_8((x).s01234567) + sum_reduce_8((x).s89ABCDEF) 5472 5473#define SUM_REDUCE_STR(x, size) sum_reduce_##size(x) 5474#define SUM_REDUCE(x, size) SUM_REDUCE_STR(x, size) 5475 5476#define max_reduce_1(x) (x) 5477#define max_reduce_2(x) max(((x).s0), ((x).s1)) 5478#define max_reduce_3(x) max(max_reduce_2((x).s01), ((x).s2)) 5479#define max_reduce_4(x) max(max_reduce_2((x).s01), max_reduce_2((x).s23)) 5480#define max_reduce_8(x) max(max_reduce_4((x).s0123), max_reduce_4((x).s4567)) 5481#define max_reduce_16(x) max(max_reduce_8((x).s01234567), max_reduce_8((x).s89ABCDEF)) 5482 5483#define MAX_REDUCE_STR(x, size) max_reduce_##size(x) 5484#define MAX_REDUCE(x, size) MAX_REDUCE_STR(x, size) 5485 5486#define VECTOR_DECLARATION(name) \ 5487 __global uchar *name##_ptr, \ 5488 uint name##_stride_x, \ 5489 uint name##_step_x, \ 5490 uint name##_offset_first_element_in_bytes 5491 5492#define IMAGE_DECLARATION(name) \ 5493 __global uchar *name##_ptr, \ 5494 uint name##_stride_x, \ 5495 uint name##_step_x, \ 5496 uint name##_stride_y, \ 5497 uint name##_step_y, \ 5498 uint name##_offset_first_element_in_bytes 5499 5500#define TENSOR3D_DECLARATION(name) \ 5501 __global uchar *name##_ptr, \ 5502 uint name##_stride_x, \ 5503 uint name##_step_x, \ 5504 uint name##_stride_y, \ 5505 uint name##_step_y, \ 5506 uint name##_stride_z, \ 5507 uint name##_step_z, \ 5508 uint name##_offset_first_element_in_bytes 5509 5510#define 
TENSOR4D_DECLARATION(name) \ 5511 __global uchar *name##_ptr, \ 5512 uint name##_stride_x, \ 5513 uint name##_step_x, \ 5514 uint name##_stride_y, \ 5515 uint name##_step_y, \ 5516 uint name##_stride_z, \ 5517 uint name##_step_z, \ 5518 uint name##_stride_w, \ 5519 uint name##_step_w, \ 5520 uint name##_offset_first_element_in_bytes 5521 5522#define CONVERT_TO_VECTOR_STRUCT(name) \ 5523 update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x) 5524 5525#define CONVERT_TO_VECTOR_STRUCT_NO_STEP(name) \ 5526 update_vector_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0) 5527 5528#define CONVERT_TO_IMAGE_STRUCT(name) \ 5529 update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y) 5530 5531#define CONVERT_TO_IMAGE_STRUCT_NO_STEP(name) \ 5532 update_image_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0) 5533 5534#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \ 5535 update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z) 5536 5537#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT_NO_STEP(name) \ 5538 update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, name##_step_z) 5539 5540#define CONVERT_TENSOR3D_TO_IMAGE_STRUCT(name) \ 5541 update_image_from_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, name##_stride_z, name##_step_z) 5542 5543#define CONVERT_TO_TENSOR3D_STRUCT(name) \ 5544 update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \ 5545 name##_stride_z, name##_step_z) 5546 5547#define CONVERT_TO_TENSOR3D_STRUCT_NO_STEP(name) \ 5548 update_tensor3D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0) 5549 5550#define CONVERT_TO_TENSOR4D_STRUCT(name, mod_size) \ 5551 update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \ 5552 name##_stride_z, name##_step_z, name##_stride_w, name##_step_w, mod_size) 5553 5554#define CONVERT_TO_TENSOR4D_STRUCT_NO_STEP(name, mod_size) \ 5555 update_tensor4D_workitem_ptr(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, 0, name##_stride_y, 0, name##_stride_z, 0, name##_stride_w, 0, mod_size) 5556 5557#define CONVERT_TO_TENSOR3D_STRUCT_NO_UPDATE_PTR(name) \ 5558 tensor3D_ptr_no_update(name##_ptr, name##_offset_first_element_in_bytes, name##_stride_x, name##_step_x, name##_stride_y, name##_step_y, \ 5559 name##_stride_z, name##_step_z) 5560 5561/** Structure to hold Vector information */ 5562typedef struct Vector 5563{ 5564 __global uchar *ptr; /**< Pointer to the starting postion of the buffer */ 5565 int offset_first_element_in_bytes; /**< The offset of the first element in the source image */ 5566 int stride_x; /**< Stride of the image in X dimension (in bytes) */ 5567} Vector; 5568 5569/** Structure to hold Image information */ 5570typedef struct Image 5571{ 5572 __global uchar *ptr; /**< Pointer to the starting postion of the buffer */ 5573 int offset_first_element_in_bytes; /**< The 
offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
} Image;

/** Structure to hold 3D tensor information */
typedef struct Tensor3D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
} Tensor3D;

/** Structure to hold 4D tensor information */
typedef struct Tensor4D
{
    __global uchar *ptr;                           /**< Pointer to the starting position of the buffer */
    int             offset_first_element_in_bytes; /**< The offset of the first element in the source image */
    int             stride_x;                      /**< Stride of the image in X dimension (in bytes) */
    int             stride_y;                      /**< Stride of the image in Y dimension (in bytes) */
    int             stride_z;                      /**< Stride of the image in Z dimension (in bytes) */
    int             stride_w;                      /**< Stride of the image in W dimension (in bytes) */
} Tensor4D;

/** Wrap vector information into a Vector structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source vector
 * @param[in] stride_x                      Stride of the vector in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 *
 * @return A vector object
 */
inline Vector update_vector_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x)
{
    Vector vector =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
    };
    vector.ptr += vector.offset_first_element_in_bytes + get_global_id(0) * step_x;
    return vector;
}
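
// Typical kernel-side usage of these helpers (the kernel name and body are
// illustrative, not part of this file):
//   __kernel void example(IMAGE_DECLARATION(src))
//   {
//       Image src = CONVERT_TO_IMAGE_STRUCT(src);
//       __global uchar *pixel = offset(&src, 0, 0); // this workitem's element
//   }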

/** Wrap image information into an Image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y;
    return img;
}

/** Wrap 3D tensor information into an image structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return An image object
 */
inline Image update_image_from_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Image img =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y
    };
    img.ptr += img.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return img;
}
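
// Note: via CONVERT_TENSOR3D_TO_IMAGE_STRUCT, the function above folds the z
// dimension into the returned Image's pointer, so a kernel can address one
// plane of a 3D tensor with plain 2D offset() arithmetic.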

/** Wrap 3D tensor information into a tensor structure, and make the pointer point at this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D update_tensor3D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + get_global_id(2) * step_z;
    return tensor;
}

/** Wrap 3D tensor information into a tensor structure, without updating the pointer to this workitem's data.
 *
 * @param[in] ptr                           Pointer to the starting position of the buffer
 * @param[in] offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] stride_x                      Stride of the image in X dimension (in bytes)
 * @param[in] step_x                        stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] stride_y                      Stride of the image in Y dimension (in bytes)
 * @param[in] step_y                        stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] stride_z                      Stride of the image in Z dimension (in bytes)
 * @param[in] step_z                        stride_z * number of elements along Z processed per workitem (in bytes)
 *
 * @return A 3D tensor object
 */
inline Tensor3D tensor3D_ptr_no_update(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z)
{
    Tensor3D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z
    };
    return tensor;
}

inline Tensor4D update_tensor4D_workitem_ptr(__global uchar *ptr, uint offset_first_element_in_bytes, uint stride_x, uint step_x, uint stride_y, uint step_y, uint stride_z, uint step_z, uint stride_w,
                                             uint step_w,
                                             uint mod_size)
{
    Tensor4D tensor =
    {
        .ptr                           = ptr,
        .offset_first_element_in_bytes = offset_first_element_in_bytes,
        .stride_x                      = stride_x,
        .stride_y                      = stride_y,
        .stride_z                      = stride_z,
        .stride_w                      = stride_w
    };

    tensor.ptr += tensor.offset_first_element_in_bytes + get_global_id(0) * step_x + get_global_id(1) * step_y + (get_global_id(2) % mod_size) * step_z + (get_global_id(2) / mod_size) * step_w;
    return tensor;
}

/** Get the pointer position of a Vector
 *
 * @param[in] vec Pointer to the starting position of the buffer
 * @param[in] x   Relative X position
 */
inline __global const uchar *vector_offset(const Vector *vec, int x)
{
    return vec->ptr + x * vec->stride_x;
}

/** Get the pointer position of an Image
 *
 * @param[in] img Pointer to the starting position of the buffer
 * @param[in] x   Relative X position
 * @param[in] y   Relative Y position
 */
inline __global uchar *offset(const Image *img, int x, int y)
{
    return img->ptr + x * img->stride_x + y * img->stride_y;
}

/** Get the pointer position of a Tensor3D
 *
 * @param[in] tensor Pointer to the starting position of the buffer
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 */
inline __global const uchar *tensor3D_offset(const Tensor3D *tensor, int x, int y, int z)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z;
}

/** Get the pointer position of a Tensor4D
 *
 * @param[in] tensor Pointer to the starting position of the buffer
 * @param[in] x      Relative X position
 * @param[in] y      Relative Y position
 * @param[in] z      Relative Z position
 * @param[in] w      Relative W position
 */
inline __global const uchar *tensor4D_offset(const Tensor4D *tensor, int x, int y, int z, int w)
{
    return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + w * tensor->stride_w;
}

/** Get the offset for a given linear
index of a Tensor3D 5788 * 5789 * @param[in] tensor Pointer to the starting position of the buffer 5790 * @param[in] width Width of the input tensor 5791 * @param[in] height Height of the input tensor 5792 * @param[in] depth Depth of the input tensor 5793 * @param[in] index Linear index 5794 */ 5795inline __global const uchar *tensor3D_index2ptr(const Tensor3D *tensor, uint width, uint height, uint depth, uint index) 5796{ 5797 uint num_elements = width * height; 5798 5799 const uint z = index / num_elements; 5800 5801 index %= num_elements; 5802 5803 const uint y = index / width; 5804 5805 index %= width; 5806 5807 const uint x = index; 5808 5809 return tensor->ptr + x * tensor->stride_x + y * tensor->stride_y + z * tensor->stride_z + tensor->offset_first_element_in_bytes; 5810} 5811 5812#endif // _HELPER_H 5813 5814/** Macros that help in loop unrolling */ 5815//Repeat macros with 3 param, excluding the implicit ID param 5816#define REPEAT_3_1(P_X, P_A, P_B, P_C) P_X##_DEF(0, P_A, P_B, P_C) 5817#define REPEAT_3_2(P_X, P_A, P_B, P_C) \ 5818 P_X##_DEF(1, P_A, P_B, P_C); \ 5819 REPEAT_3_1(P_X, P_A, P_B, P_C) 5820#define REPEAT_3_3(P_X, P_A, P_B, P_C) \ 5821 P_X##_DEF(2, P_A, P_B, P_C); \ 5822 REPEAT_3_2(P_X, P_A, P_B, P_C) 5823#define REPEAT_3_4(P_X, P_A, P_B, P_C) \ 5824 P_X##_DEF(3, P_A, P_B, P_C); \ 5825 REPEAT_3_3(P_X, P_A, P_B, P_C) 5826#define REPEAT_3_5(P_X, P_A, P_B, P_C) \ 5827 P_X##_DEF(4, P_A, P_B, P_C); \ 5828 REPEAT_3_4(P_X, P_A, P_B, P_C) 5829#define REPEAT_3_6(P_X, P_A, P_B, P_C) \ 5830 P_X##_DEF(5, P_A, P_B, P_C); \ 5831 REPEAT_3_5(P_X, P_A, P_B, P_C) 5832#define REPEAT_3_7(P_X, P_A, P_B, P_C) \ 5833 P_X##_DEF(6, P_A, P_B, P_C); \ 5834 REPEAT_3_6(P_X, P_A, P_B, P_C) 5835#define REPEAT_3_8(P_X, P_A, P_B, P_C) \ 5836 P_X##_DEF(7, P_A, P_B, P_C); \ 5837 REPEAT_3_7(P_X, P_A, P_B, P_C) 5838#define REPEAT_3_9(P_X, P_A, P_B, P_C) \ 5839 P_X##_DEF(8, P_A, P_B, P_C); \ 5840 REPEAT_3_8(P_X, P_A, P_B, P_C) 5841#define REPEAT_3_10(P_X, P_A, P_B, P_C) \ 5842 P_X##_DEF(9, P_A, P_B, P_C); \ 5843 REPEAT_3_9(P_X, P_A, P_B, P_C) 5844#define REPEAT_3_11(P_X, P_A, P_B, P_C) \ 5845 P_X##_DEF(A, P_A, P_B, P_C); \ 5846 REPEAT_3_10(P_X, P_A, P_B, P_C) 5847#define REPEAT_3_12(P_X, P_A, P_B, P_C) \ 5848 P_X##_DEF(B, P_A, P_B, P_C); \ 5849 REPEAT_3_11(P_X, P_A, P_B, P_C) 5850#define REPEAT_3_13(P_X, P_A, P_B, P_C) \ 5851 P_X##_DEF(C, P_A, P_B, P_C); \ 5852 REPEAT_3_12(P_X, P_A, P_B, P_C) 5853#define REPEAT_3_14(P_X, P_A, P_B, P_C) \ 5854 P_X##_DEF(D, P_A, P_B, P_C); \ 5855 REPEAT_3_13(P_X, P_A, P_B, P_C) 5856#define REPEAT_3_15(P_X, P_A, P_B, P_C) \ 5857 P_X##_DEF(E, P_A, P_B, P_C); \ 5858 REPEAT_3_14(P_X, P_A, P_B, P_C) 5859#define REPEAT_3_16(P_X, P_A, P_B, P_C) \ 5860 P_X##_DEF(F, P_A, P_B, P_C); \ 5861 REPEAT_3_15(P_X, P_A, P_B, P_C) 5862 5863#define REPEAT_DEF_3_N(P_NUM, P_OP, P_A, P_B, P_C) REPEAT_3_##P_NUM(P_OP, P_A, P_B, P_C) //One level of indirection to ensure order of expansion does not affect preprocessing P_NUM 5864#define REPEAT_3_N(P_NUM, P_OP, P_A, P_B, P_C) REPEAT_DEF_3_N(P_NUM, P_OP, P_A, P_B, P_C) 5865 5866// Repeat macros with 4 param, excluding the implicit ID param 5867#define REPEAT_4_1(P_X, P_A, P_B, P_C, P_D) P_X##_DEF(0, P_A, P_B, P_C, P_D) 5868#define REPEAT_4_2(P_X, P_A, P_B, P_C, P_D) \ 5869 P_X##_DEF(1, P_A, P_B, P_C, P_D); \ 5870 REPEAT_4_1(P_X, P_A, P_B, P_C, P_D) 5871#define REPEAT_4_3(P_X, P_A, P_B, P_C, P_D) \ 5872 P_X##_DEF(2, P_A, P_B, P_C, P_D); \ 5873 REPEAT_4_2(P_X, P_A, P_B, P_C, P_D) 5874#define REPEAT_4_4(P_X, P_A, P_B, P_C, P_D) \ 5875 P_X##_DEF(3, P_A, P_B, 
P_C, P_D); \ 5876 REPEAT_4_3(P_X, P_A, P_B, P_C, P_D) 5877#define REPEAT_4_5(P_X, P_A, P_B, P_C, P_D) \ 5878 P_X##_DEF(4, P_A, P_B, P_C, P_D); \ 5879 REPEAT_4_4(P_X, P_A, P_B, P_C, P_D) 5880#define REPEAT_4_6(P_X, P_A, P_B, P_C, P_D) \ 5881 P_X##_DEF(5, P_A, P_B, P_C, P_D); \ 5882 REPEAT_4_5(P_X, P_A, P_B, P_C, P_D) 5883#define REPEAT_4_7(P_X, P_A, P_B, P_C, P_D) \ 5884 P_X##_DEF(6, P_A, P_B, P_C, P_D); \ 5885 REPEAT_4_6(P_X, P_A, P_B, P_C, P_D) 5886#define REPEAT_4_8(P_X, P_A, P_B, P_C, P_D) \ 5887 P_X##_DEF(7, P_A, P_B, P_C, P_D); \ 5888 REPEAT_4_7(P_X, P_A, P_B, P_C, P_D) 5889#define REPEAT_4_9(P_X, P_A, P_B, P_C, P_D) \ 5890 P_X##_DEF(8, P_A, P_B, P_C, P_D); \ 5891 REPEAT_4_8(P_X, P_A, P_B, P_C, P_D) 5892#define REPEAT_4_10(P_X, P_A, P_B, P_C, P_D) \ 5893 P_X##_DEF(9, P_A, P_B, P_C, P_D); \ 5894 REPEAT_4_9(P_X, P_A, P_B, P_C, P_D) 5895#define REPEAT_4_11(P_X, P_A, P_B, P_C, P_D) \ 5896 P_X##_DEF(A, P_A, P_B, P_C, P_D); \ 5897 REPEAT_4_10(P_X, P_A, P_B, P_C, P_D) 5898#define REPEAT_4_12(P_X, P_A, P_B, P_C, P_D) \ 5899 P_X##_DEF(B, P_A, P_B, P_C, P_D); \ 5900 REPEAT_4_11(P_X, P_A, P_B, P_C, P_D) 5901#define REPEAT_4_13(P_X, P_A, P_B, P_C, P_D) \ 5902 P_X##_DEF(C, P_A, P_B, P_C, P_D); \ 5903 REPEAT_4_12(P_X, P_A, P_B, P_C, P_D) 5904#define REPEAT_4_14(P_X, P_A, P_B, P_C, P_D) \ 5905 P_X##_DEF(D, P_A, P_B, P_C, P_D); \ 5906 REPEAT_4_13(P_X, P_A, P_B, P_C, P_D) 5907#define REPEAT_4_15(P_X, P_A, P_B, P_C, P_D) \ 5908 P_X##_DEF(E, P_A, P_B, P_C, P_D); \ 5909 REPEAT_4_14(P_X, P_A, P_B, P_C, P_D) 5910#define REPEAT_4_16(P_X, P_A, P_B, P_C, P_D) \ 5911 P_X##_DEF(F, P_A, P_B, P_C, P_D); \ 5912 REPEAT_4_15(P_X, P_A, P_B, P_C, P_D) 5913 5914#define REPEAT_DEF_4_N(P_NUM, P_OP, P_A, P_B, P_C, P_D) REPEAT_4_##P_NUM(P_OP, P_A, P_B, P_C, P_D) //One level of indirection to ensure order of expansion does not affect preprocessing P_NUM 5915#define REPEAT_4_N(P_NUM, P_OP, P_A, P_B, P_C, P_D) REPEAT_DEF_4_N(P_NUM, P_OP, P_A, P_B, P_C, P_D) 5916 5917// Macro for initializing N variables. Generates N statements that defines VAR##N = RHS_ACCESSOR_DEF(...) 5918#define VAR_INIT_TO_CONST_DEF(ID, TYPE, VAR, VAL) TYPE VAR##ID = VAL 5919#define REPEAT_VAR_INIT_TO_CONST(N, TYPE, VAR, VAL) REPEAT_3_N(N, VAR_INIT_TO_CONST, TYPE, VAR, VAL) 5920 5921// Macro for initializing N variables by converting the data type. Generates N statements that defines VAR##N = RHS_ACCESSOR_DEF(...) 5922#define VAR_INIT_CONVERT_DEF(ID, TYPE_OUT, VAR_IN, VAR_OUT) TYPE_OUT VAR_OUT##ID = CONVERT(VAR_IN##ID, TYPE_OUT) 5923#define REPEAT_VAR_INIT_CONVERT(N, TYPE_OUT, VAR_IN, VAR_OUT) REPEAT_3_N(N, VAR_INIT_CONVERT, TYPE_OUT, VAR_IN, VAR_OUT) 5924 5925// Macro for initializing N variables by converting the data type with saturation. Generates N statements that defines VAR##N = RHS_ACCESSOR_DEF(...) 5926#define VAR_INIT_CONVERT_SAT_DEF(ID, TYPE_OUT, VAR_IN, VAR_OUT) TYPE_OUT VAR_OUT##ID = CONVERT_SAT(VAR_IN##ID, TYPE_OUT) 5927#define REPEAT_VAR_INIT_CONVERT_SAT(N, TYPE_OUT, VAR_IN, VAR_OUT) REPEAT_3_N(N, VAR_INIT_CONVERT_SAT, TYPE_OUT, VAR_IN, VAR_OUT) 5928 5929// Macro for adding a constant to N variables. Generates N statements that defines VAR##N =RHS_ACCESSOR_DEF(...) 5930#define ADD_CONST_TO_VAR_DEF(ID, TYPE, VAR, VAL) VAR##ID += (TYPE)VAL 5931#define REPEAT_ADD_CONST_TO_VAR(N, TYPE, VAR, VAL) REPEAT_3_N(N, ADD_CONST_TO_VAR, TYPE, VAR, VAL) 5932 5933// Macro for multiplying N variables (VAR_B) by a constant (VAL) and adding to other N variables (VAR_A). Generates N statements that defines VAR_A##N =RHS_ACCESSOR_DEF(...) 
5934#define MLA_VAR_WITH_CONST_VEC_DEF(ID, VAR_A, VAR_B, VAL) VAR_A##ID += VAR_B##ID * VAL 5935#define REPEAT_MLA_VAR_WITH_CONST_VEC(N, VAR_A, VAR_B, VAL) REPEAT_3_N(N, MLA_VAR_WITH_CONST_VEC, VAR_A, VAR_B, VAL) 5936 5937// Macro for adding a vector to N-variables. Generates N statements that defines VAR##N =RHS_ACCESSOR_DEF(...) 5938#define ADD_VECTOR_TO_VAR_DEF(ID, TYPE, VAR, VEC) VAR##ID += VEC 5939#define REPEAT_ADD_VECTOR_TO_VAR(N, VAR, VEC) REPEAT_3_N(N, ADD_VECTOR_TO_VAR, "", VAR, VEC) 5940 5941// Macro for adding a two N-variables. Generates N statements that defines VAR##N =RHS_ACCESSOR_DEF(...) 5942#define ADD_TWO_VARS_DEF(ID, TYPE, VAR_A, VAR_B) VAR_A##ID += VAR_B##ID 5943#define REPEAT_ADD_TWO_VARS(N, VAR_A, VAR_B) REPEAT_3_N(N, ADD_TWO_VARS, "", VAR_A, VAR_B) 5944 5945// Macro for performing Max between a constant and N variables. Generates N statements that defines VAR##N =RHS_ACCESSOR_DEF(...) 5946#define MAX_CONST_VAR_DEF(ID, TYPE, VAR, VAL) VAR##ID = max(VAR##ID, (TYPE)VAL) 5947#define REPEAT_MAX_CONST_VAR(N, TYPE, VAR, VAL) REPEAT_3_N(N, MAX_CONST_VAR, TYPE, VAR, VAL) 5948 5949// Macro for performing Min between a constant and N variables. Generates N statements that defines VAR##N =RHS_ACCESSOR_DEF(...) 5950#define MIN_CONST_VAR_DEF(ID, TYPE, VAR, VAL) VAR##ID = min(VAR##ID, (TYPE)VAL) 5951#define REPEAT_MIN_CONST_VAR(N, TYPE, VAR, VAL) REPEAT_3_N(N, MIN_CONST_VAR, TYPE, VAR, VAL) 5952 5953// Macro for performing ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE to N variables. Generates N statements that defines VAR##N =RHS_ACCESSOR_DEF(...) 5954#define ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE_DEF(ID, SIZE, VAR, RES_MUL, RES_SHIFT) VAR##ID = ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, SIZE) 5955#define REPEAT_ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE(N, SIZE, VAR, RES_MUL, RES_SHIFT) REPEAT_4_N(N, ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE, SIZE, VAR, RES_MUL, RES_SHIFT) 5956 5957// Macro for performing ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE to N variables. Generates N statements that defines VAR##N =RHS_ACCESSOR_DEF(...) 5958#define ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE_DEF(ID, SIZE, VAR, RES_MUL, RES_SHIFT) VAR##ID = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, SIZE) 5959#define REPEAT_ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(N, SIZE, VAR, RES_MUL, RES_SHIFT) REPEAT_4_N(N, ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE, SIZE, VAR, RES_MUL, RES_SHIFT) 5960 5961// Macro for performing per-channel ASYMM_MULT_BY_QUANT_MULTIPLIER to N variables. 5962#define ASYMM_MULT_BY_QUANT_MULTIPLIER_PER_CHANNEL_DEF(ID, SIZE, VAR, RES_MUL, RES_SHIFT) \ 5963 ({ \ 5964 VEC_DATA_TYPE(int, N0) \ 5965 VAR##ID_shift_lt0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_GREATER_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, N0); \ 5966 VEC_DATA_TYPE(int, N0) \ 5967 VAR##ID_shift_gt0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(VAR##ID, RES_MUL, RES_SHIFT, N0); \ 5968 VAR##ID = select(VAR##ID_shift_lt0, VAR##ID_shift_gt0, RES_SHIFT >= 0); \ 5969 }) 5970#define REPEAT_ASYMM_MULT_BY_QUANT_MULTIPLIER_PER_CHANNEL(N, SIZE, VAR, RES_MUL, RES_SHIFT) REPEAT_4_N(N, ASYMM_MULT_BY_QUANT_MULTIPLIER_PER_CHANNEL, SIZE, VAR, RES_MUL, RES_SHIFT) 5971 5972#endif // ARM_COMPUTE_REPEAT_H 5973 5974#if defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) 5975/** This OpenCL kernel is optimised for Midgard. 
It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * @note The number of rows of the destination matrix must be passed at compile time using -DM
 * @note The number of columns of the destination matrix must be passed at compile time using -DN
 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note In case matrix B has 3 dimensions and matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns of matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix.
Supported data types: same as @p src0_ptr 6003 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 6004 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 6005 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 6006 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 6007 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 6008 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr 6009 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 6010 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 6011 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 6012 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 6013 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 6014 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 6015 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 6016 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) 6017 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 6018 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) 6019 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 6020 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 6021 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 6022 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 6023 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 6024 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 6025 */ 6026__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0), 6027 IMAGE_DECLARATION(src1), 6028#if defined(BETA) 6029 IMAGE_DECLARATION(src2), 6030#endif // defined(BETA) 6031 IMAGE_DECLARATION(dst), 6032 uint src0_stride_z, 6033 uint src1_stride_z, 6034#if defined(BETA) 6035 uint src2_stride_z, 6036#endif //defined(BETA) 6037 uint dst_stride_z 6038#if defined(REINTERPRET_OUTPUT_AS_3D) 6039 , 6040 uint cross_plane_pad 6041#endif // REINTERPRET_OUTPUT_AS_3D 6042 ) 6043{ 6044 int x = get_global_id(0) / H0; 6045 int y = get_global_id(1) / V0; 6046 int z = get_global_id(2); 6047 6048 // Offset 6049 const int offset_row_a = (get_global_id(1) % V0) * 4; 6050 const int offset_row_b = (get_global_id(0) % H0) * 4; 6051 6052 // src_addr_a = address of matrix A 6053 // src_addr_b = address of matrix B 6054 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes; 6055 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes; 6056 6057#if defined(MATRIX_B_DEPTH) 6058 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 6059 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z; 6060#else // defined(MATRIX_B_DEPTH) 6061 src1_addr_in_bytes += z * src1_stride_z; 
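    // Illustrative note (based only on the code above): with e.g. z = 3, matrix B
    // simply advances by 3 * src1_stride_z bytes, so each batch of A reads its own
    // Z-plane of B; the MATRIX_B_DEPTH branch above instead wraps z so that a B
    // tensor with fewer planes can be reused across batches.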
6062#endif // defined(MATRIX_B_DEPTH) 6063 6064 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes); 6065 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes); 6066 6067 // Compute end row address for matrix B 6068 __global float *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(float)); 6069 6070 src_addr_a += offset_row_a; 6071 src_addr_b += offset_row_b; 6072 6073 // Reset accumulators 6074 float4 c0 = 0.0f; 6075 float4 c1 = 0.0f; 6076 float4 c2 = 0.0f; 6077 float4 c3 = 0.0f; 6078 6079 for(; src_addr_b <= (src_end_addr_b - (int)(8 * H0)); src_addr_a += 8 * V0, src_addr_b += 8 * H0) 6080 { 6081 // Load values from matrix A (interleaved) and matrix B (transposed) 6082 float4 a0 = vload4(0, src_addr_a); 6083 float4 b0 = vload4(0, src_addr_b); 6084 6085 c0 += (float4)a0.s0 * b0; 6086 c1 += (float4)a0.s1 * b0; 6087 c2 += (float4)a0.s2 * b0; 6088 c3 += (float4)a0.s3 * b0; 6089 6090 // Load values from matrix A (interleaved) and matrix B (transposed) 6091 a0 = vload4(0, src_addr_a + 4 * V0); 6092 b0 = vload4(0, src_addr_b + 4 * H0); 6093 6094 c0 += (float4)a0.s0 * b0; 6095 c1 += (float4)a0.s1 * b0; 6096 c2 += (float4)a0.s2 * b0; 6097 c3 += (float4)a0.s3 * b0; 6098 } 6099 6100 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 4 * H0) 6101 { 6102 // Load values from matrix A (interleaved) and matrix B (transposed) 6103 float4 a0 = vload4(0, src_addr_a); 6104 float4 b0 = vload4(0, src_addr_b); 6105 6106 c0 += (float4)a0.s0 * b0; 6107 c1 += (float4)a0.s1 * b0; 6108 c2 += (float4)a0.s2 * b0; 6109 c3 += (float4)a0.s3 * b0; 6110 } 6111 6112 // Compute destination address 6113 Image dst = CONVERT_TO_IMAGE_STRUCT(dst); 6114 6115 // Compute dst address 6116 __global uchar *dst_addr = offset(&dst, 0, 0); 6117 6118 uint4 zout = 0; 6119 6120#if defined(REINTERPRET_OUTPUT_AS_3D) 6121 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension 6122 // in order to take into account the presence of possible cross plane paddings 6123 // 6124 // | | 6125 // | plane0 | 6126 // | | 6127 // |__________________| 6128 // |******************| 6129 // | cross_plane_pad | 6130 // |******************| 6131 // | | 6132 // | plane1 | 6133 // | | 6134 // |__________________| 6135 6136 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D 6137 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; 6138 zout = min(DEPTH_GEMM3D - 1, zout); 6139 6140 // Add offset due to the cross plane paddings 6141 zout *= (cross_plane_pad * dst_stride_y); 6142 6143 // Add offset for batched GEMM. 
The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of the matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}

/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * @note The number of rows of the destination matrix must be passed at compile time using -DM
 * @note The number of columns of the destination matrix must be passed at compile time using -DN
 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block (V0) must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note If matrix B has 3 dimensions and matrix A has more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads.
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g.
a = [K, M, 16, Batches], b = [N, K, 16]) 6210 * 6211 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 6212 * The activation function is performed after the bias addition 6213 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: 6214 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D 6215 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 6216 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor 6217 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped 6218 * 6219 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32 6220 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) 6221 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 6222 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) 6223 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 6224 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix 6225 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr 6226 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 6227 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 6228 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 6229 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 6230 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 6231 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr 6232 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 6233 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 6234 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 6235 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 6236 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 6237 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 6238 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 6239 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) 6240 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 6241 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) 6242 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 6243 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 6244 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 6245 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 6246 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 6247 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 6248 */ 6249__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0), 6250 IMAGE_DECLARATION(src1), 6251#if defined(BETA) 6252 IMAGE_DECLARATION(src2), 6253#endif // defined(BETA) 6254 IMAGE_DECLARATION(dst), 6255 uint src0_stride_z, 6256 uint src1_stride_z, 6257#if defined(BETA) 6258 uint src2_stride_z, 6259#endif //defined(BETA) 6260 uint dst_stride_z 6261#if defined(REINTERPRET_OUTPUT_AS_3D) 6262 , 6263 uint cross_plane_pad 6264#endif // REINTERPRET_OUTPUT_AS_3D 6265 ) 6266{ 6267 int x = get_global_id(0) / H0; 6268 int y = get_global_id(1) / V0; 6269 int z = get_global_id(2); 6270 6271 // Offset 6272 const int offset_row_a = (get_global_id(1) % V0) * 4; 6273 const int offset_row_b = (get_global_id(0) % H0) * 4; 6274 6275 // src_addr_a = address of matrix A 6276 // src_addr_b = address of matrix B 6277 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes; 6278 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes; 6279 6280#if defined(MATRIX_B_DEPTH) 6281 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 6282 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z; 6283#else // defined(MATRIX_B_DEPTH) 6284 src1_addr_in_bytes += z * src1_stride_z; 6285#endif // defined(MATRIX_B_DEPTH) 6286 6287 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes); 6288 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes); 6289 6290 src_addr_a += offset_row_a; 6291 src_addr_b += offset_row_b; 6292 6293 // Reset accumulators 6294 float4 c0 = 0.0f; 6295 float4 c1 = 0.0f; 6296 float4 c2 = 0.0f; 6297 float4 c3 = 0.0f; 6298 6299 int i = 0; 6300 for(; i <= (int)(K - 4); i += 4) 6301 { 6302 // Load values from matrix A (interleaved) and matrix B (transposed) 6303 float4 a0 = vload4(0, src_addr_a); 6304 
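        // (Descriptive note: the interleave reshape lays A out so that a0 holds one
        // element from each of the 4 rows of the output tile at the current k; the
        // vload4 from src_addr_b below fetches the matching 4 columns of B.)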
float4 b0 = vload4(0, src_addr_b); 6305 6306 src_addr_a += 4 * V0; 6307 src_addr_b += 4 * H0; 6308 6309 c0.s0 = fma(a0.s0, b0.s0, c0.s0); 6310 c0.s1 = fma(a0.s0, b0.s1, c0.s1); 6311 c0.s2 = fma(a0.s0, b0.s2, c0.s2); 6312 c0.s3 = fma(a0.s0, b0.s3, c0.s3); 6313 6314 c1.s0 = fma(a0.s1, b0.s0, c1.s0); 6315 c1.s1 = fma(a0.s1, b0.s1, c1.s1); 6316 c1.s2 = fma(a0.s1, b0.s2, c1.s2); 6317 c1.s3 = fma(a0.s1, b0.s3, c1.s3); 6318 6319 c2.s0 = fma(a0.s2, b0.s0, c2.s0); 6320 c2.s1 = fma(a0.s2, b0.s1, c2.s1); 6321 c2.s2 = fma(a0.s2, b0.s2, c2.s2); 6322 c2.s3 = fma(a0.s2, b0.s3, c2.s3); 6323 6324 c3.s0 = fma(a0.s3, b0.s0, c3.s0); 6325 c3.s1 = fma(a0.s3, b0.s1, c3.s1); 6326 c3.s2 = fma(a0.s3, b0.s2, c3.s2); 6327 c3.s3 = fma(a0.s3, b0.s3, c3.s3); 6328 6329 // Load values from matrix A (interleaved) and matrix B (transposed) 6330 a0 = vload4(0, src_addr_a); 6331 b0 = vload4(0, src_addr_b); 6332 6333 src_addr_a += 4 * V0; 6334 src_addr_b += 4 * H0; 6335 6336 c0.s0 = fma(a0.s0, b0.s0, c0.s0); 6337 c0.s1 = fma(a0.s0, b0.s1, c0.s1); 6338 c0.s2 = fma(a0.s0, b0.s2, c0.s2); 6339 c0.s3 = fma(a0.s0, b0.s3, c0.s3); 6340 6341 c1.s0 = fma(a0.s1, b0.s0, c1.s0); 6342 c1.s1 = fma(a0.s1, b0.s1, c1.s1); 6343 c1.s2 = fma(a0.s1, b0.s2, c1.s2); 6344 c1.s3 = fma(a0.s1, b0.s3, c1.s3); 6345 6346 c2.s0 = fma(a0.s2, b0.s0, c2.s0); 6347 c2.s1 = fma(a0.s2, b0.s1, c2.s1); 6348 c2.s2 = fma(a0.s2, b0.s2, c2.s2); 6349 c2.s3 = fma(a0.s2, b0.s3, c2.s3); 6350 6351 c3.s0 = fma(a0.s3, b0.s0, c3.s0); 6352 c3.s1 = fma(a0.s3, b0.s1, c3.s1); 6353 c3.s2 = fma(a0.s3, b0.s2, c3.s2); 6354 c3.s3 = fma(a0.s3, b0.s3, c3.s3); 6355 6356 // Load values from matrix A (interleaved) and matrix B (transposed) 6357 a0 = vload4(0, src_addr_a); 6358 b0 = vload4(0, src_addr_b); 6359 6360 src_addr_a += 4 * V0; 6361 src_addr_b += 4 * H0; 6362 6363 c0.s0 = fma(a0.s0, b0.s0, c0.s0); 6364 c0.s1 = fma(a0.s0, b0.s1, c0.s1); 6365 c0.s2 = fma(a0.s0, b0.s2, c0.s2); 6366 c0.s3 = fma(a0.s0, b0.s3, c0.s3); 6367 6368 c1.s0 = fma(a0.s1, b0.s0, c1.s0); 6369 c1.s1 = fma(a0.s1, b0.s1, c1.s1); 6370 c1.s2 = fma(a0.s1, b0.s2, c1.s2); 6371 c1.s3 = fma(a0.s1, b0.s3, c1.s3); 6372 6373 c2.s0 = fma(a0.s2, b0.s0, c2.s0); 6374 c2.s1 = fma(a0.s2, b0.s1, c2.s1); 6375 c2.s2 = fma(a0.s2, b0.s2, c2.s2); 6376 c2.s3 = fma(a0.s2, b0.s3, c2.s3); 6377 6378 c3.s0 = fma(a0.s3, b0.s0, c3.s0); 6379 c3.s1 = fma(a0.s3, b0.s1, c3.s1); 6380 c3.s2 = fma(a0.s3, b0.s2, c3.s2); 6381 c3.s3 = fma(a0.s3, b0.s3, c3.s3); 6382 6383 // Load values from matrix A (interleaved) and matrix B (transposed) 6384 a0 = vload4(0, src_addr_a); 6385 b0 = vload4(0, src_addr_b); 6386 6387 src_addr_a += 4 * V0; 6388 src_addr_b += 4 * H0; 6389 6390 c0.s0 = fma(a0.s0, b0.s0, c0.s0); 6391 c0.s1 = fma(a0.s0, b0.s1, c0.s1); 6392 c0.s2 = fma(a0.s0, b0.s2, c0.s2); 6393 c0.s3 = fma(a0.s0, b0.s3, c0.s3); 6394 6395 c1.s0 = fma(a0.s1, b0.s0, c1.s0); 6396 c1.s1 = fma(a0.s1, b0.s1, c1.s1); 6397 c1.s2 = fma(a0.s1, b0.s2, c1.s2); 6398 c1.s3 = fma(a0.s1, b0.s3, c1.s3); 6399 6400 c2.s0 = fma(a0.s2, b0.s0, c2.s0); 6401 c2.s1 = fma(a0.s2, b0.s1, c2.s1); 6402 c2.s2 = fma(a0.s2, b0.s2, c2.s2); 6403 c2.s3 = fma(a0.s2, b0.s3, c2.s3); 6404 6405 c3.s0 = fma(a0.s3, b0.s0, c3.s0); 6406 c3.s1 = fma(a0.s3, b0.s1, c3.s1); 6407 c3.s2 = fma(a0.s3, b0.s2, c3.s2); 6408 c3.s3 = fma(a0.s3, b0.s3, c3.s3); 6409 } 6410 6411 for(; i < (int)K; ++i) 6412 { 6413 // Load values from matrix A (interleaved) and matrix B (transposed) 6414 float4 a0 = vload4(0, src_addr_a); 6415 float4 b0 = vload4(0, src_addr_b); 6416 6417 src_addr_a += 4 * V0; 6418 src_addr_b += 4 * H0; 6419 6420 
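        // Rank-1 update of the 4x4 accumulator tile: each element of a0 scales the
        // whole row vector b0, producing one row of the output tile per fma group.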
c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of the matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 4 >=
N); 6524 STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x); 6525} 6526 6527#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED) 6528/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) 6529 * 6530 * @note The number of rows of destination matrix must be passed at compile time using -DM 6531 * @note The number of columns of the destination matrix must be passed at compile time using -DN 6532 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK 6533 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1) 6534 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1) 6535 * @note The optional alpha's value need to be passed at compile time using -DALPHA 6536 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2) 6537 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2) 6538 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) 6539 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) 6540 * 6541 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 6542 * The activation function is performed after the bias addition 6543 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: 6544 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D 6545 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 6546 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor 6547 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped 6548 * 6549 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 6550 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) 6551 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 6552 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) 6553 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 6554 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix 6555 * @param[in] src1_ptr Pointer to the source matrix. 
Supported data types: same as @p src0_ptr 6556 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 6557 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 6558 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 6559 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 6560 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 6561 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr 6562 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 6563 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 6564 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 6565 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 6566 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 6567 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 6568 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 6569 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) 6570 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 6571 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) 6572 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 6573 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 6574 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 6575 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 6576 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 6577 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 6578 */ 6579__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0), 6580 IMAGE_DECLARATION(src1), 6581#if defined(BETA) 6582 IMAGE_DECLARATION(src2), 6583#endif // defined(BETA) 6584 IMAGE_DECLARATION(dst), 6585 uint src0_stride_z, 6586 uint src1_stride_z, 6587#if defined(BETA) 6588 uint src2_stride_z, 6589#endif //defined(BETA) 6590 uint dst_stride_z 6591#if defined(REINTERPRET_OUTPUT_AS_3D) 6592 , 6593 uint cross_plane_pad 6594#endif // REINTERPRET_OUTPUT_AS_3D 6595 ) 6596{ 6597 int x = get_global_id(0) / H0; 6598 int y = get_global_id(1) / V0; 6599 int z = get_global_id(2); 6600 6601 // Offset 6602 const int offset_row_a = (get_global_id(1) % V0) * 4; 6603 const int offset_row_b = (get_global_id(0) % H0) * 8; 6604 6605 // src_addr_a = address of matrix A 6606 // src_addr_b = address of matrix B 6607 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes; 6608 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes; 6609 6610#if defined(MATRIX_B_DEPTH) 6611 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 6612 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z; 6613#else // defined(MATRIX_B_DEPTH) 6614 src1_addr_in_bytes += z * src1_stride_z; 
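    // (Note: unlike the F32 kernels above, this F16 path accumulates a 4x8 output
    // tile: offset_row_b above advances in multiples of 8, and the loops below read
    // B eight halves at a time.)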
6615#endif // defined(MATRIX_B_DEPTH) 6616 6617 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes); 6618 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes); 6619 6620 // Compute end row address for matrix B 6621 __global half *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(half)); 6622 6623 src_addr_a += offset_row_a; 6624 src_addr_b += offset_row_b; 6625 6626 // Reset accumulators 6627 half8 c0 = 0.0f; 6628 half8 c1 = 0.0f; 6629 half8 c2 = 0.0f; 6630 half8 c3 = 0.0f; 6631 6632 for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0) 6633 { 6634 // Load values from matrix A (interleaved) and matrix B (transposed) 6635 half4 a0 = vload4(0, src_addr_a); 6636 half8 b0 = vload8(0, src_addr_b); 6637 6638 c0 += (half8)a0.s0 * b0; 6639 c1 += (half8)a0.s1 * b0; 6640 c2 += (half8)a0.s2 * b0; 6641 c3 += (half8)a0.s3 * b0; 6642 6643 // Load values from matrix A (interleaved) and matrix B (transposed) 6644 a0 = vload4(0, src_addr_a + 4 * V0); 6645 b0 = vload8(0, src_addr_b + 8 * H0); 6646 6647 c0 += (half8)a0.s0 * b0; 6648 c1 += (half8)a0.s1 * b0; 6649 c2 += (half8)a0.s2 * b0; 6650 c3 += (half8)a0.s3 * b0; 6651 } 6652 6653 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0) 6654 { 6655 // Load values from matrix A (interleaved) and matrix B (transposed) 6656 half4 a0 = vload4(0, src_addr_a); 6657 half8 b0 = vload8(0, src_addr_b); 6658 6659 c0 += (half8)a0.s0 * b0; 6660 c1 += (half8)a0.s1 * b0; 6661 c2 += (half8)a0.s2 * b0; 6662 c3 += (half8)a0.s3 * b0; 6663 } 6664 6665 // Compute destination address 6666 Image dst = CONVERT_TO_IMAGE_STRUCT(dst); 6667 6668 // Compute dst address 6669 __global uchar *dst_addr = offset(&dst, 0, 0); 6670 6671 uint4 zout = 0; 6672 6673#if defined(REINTERPRET_OUTPUT_AS_3D) 6674 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension 6675 // in order to take into account the presence of possible cross plane paddings 6676 // 6677 // | | 6678 // | plane0 | 6679 // | | 6680 // |__________________| 6681 // |******************| 6682 // | cross_plane_pad | 6683 // |******************| 6684 // | | 6685 // | plane1 | 6686 // | | 6687 // |__________________| 6688 6689 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D 6690 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; 6691 zout = min(DEPTH_GEMM3D - 1, zout); 6692 6693 // Add offset due to the cross plane paddings 6694 zout *= (cross_plane_pad * dst_stride_y); 6695 6696 // Add offset for batched GEMM. 
The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of the matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)

    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}

/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32-bit floating point variable.
 *
 * @note The number of rows of the destination matrix must be passed at compile time using -DM
 * @note The number of columns of the destination matrix must be passed at compile time using -DN
 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block (V0) must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note If matrix B has 3 dimensions and matrix A has more than 3, the number of channels of matrix B must be passed at compile time using -DMATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) in order to avoid out-of-bounds reads.
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g.
a = [K, M, 16, Batches], b = [N, K, 16]) 6764 * 6765 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 6766 * The activation function is performed after the bias addition 6767 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: 6768 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D 6769 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 6770 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor 6771 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped 6772 * 6773 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 6774 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) 6775 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 6776 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) 6777 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 6778 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix 6779 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr 6780 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 6781 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 6782 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 6783 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 6784 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 6785 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr 6786 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 6787 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 6788 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 6789 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 6790 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 6791 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 6792 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 6793 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) 6794 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 6795 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) 6796 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 6797 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 6798 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 6799 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 6800 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 6801 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 6802 */ 6803__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0), 6804 IMAGE_DECLARATION(src1), 6805#if defined(BETA) 6806 IMAGE_DECLARATION(src2), 6807#endif // defined(BETA) 6808 IMAGE_DECLARATION(dst), 6809 uint src0_stride_z, 6810 uint src1_stride_z, 6811#if defined(BETA) 6812 uint src2_stride_z, 6813#endif //defined(BETA) 6814 uint dst_stride_z 6815#if defined(REINTERPRET_OUTPUT_AS_3D) 6816 , 6817 uint cross_plane_pad 6818#endif // REINTERPRET_OUTPUT_AS_3D 6819 ) 6820{ 6821 int x = get_global_id(0) / H0; 6822 int y = get_global_id(1) / V0; 6823 int z = get_global_id(2); 6824 6825 // Offset 6826 const int offset_row_a = (get_global_id(1) % V0) * 4; 6827 const int offset_row_b = (get_global_id(0) % H0) * 8; 6828 6829 // src_addr_a = address of matrix A 6830 // src_addr_b = address of matrix B 6831 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes; 6832 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes; 6833 6834#if defined(MATRIX_B_DEPTH) 6835 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 6836 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z; 6837#else // defined(MATRIX_B_DEPTH) 6838 src1_addr_in_bytes += z * src1_stride_z; 6839#endif // defined(MATRIX_B_DEPTH) 6840 6841 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes); 6842 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes); 6843 6844 // Compute end row address for matrix B 6845 __global half *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(half)); 6846 6847 src_addr_a += offset_row_a; 6848 src_addr_b += offset_row_b; 6849 6850 // Reset accumulators 6851 float8 c0 = 0.0f; 6852 float8 c1 = 0.0f; 6853 float8 c2 = 0.0f; 6854 float8 c3 = 0.0f; 6855 6856 for(; src_addr_b <= (src_end_addr_b - (int)(16 * 
H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0) 6857 { 6858 // Load values from matrix A (interleaved) and matrix B (transposed) 6859 float4 a0 = convert_float4(vload4(0, src_addr_a)); 6860 float8 b0 = convert_float8(vload8(0, src_addr_b)); 6861 6862 c0 += (float8)a0.s0 * b0; 6863 c1 += (float8)a0.s1 * b0; 6864 c2 += (float8)a0.s2 * b0; 6865 c3 += (float8)a0.s3 * b0; 6866 6867 // Load values from matrix A (interleaved) and matrix B (transposed) 6868 a0 = convert_float4(vload4(0, src_addr_a + 4 * V0)); 6869 b0 = convert_float8(vload8(0, src_addr_b + 8 * H0)); 6870 6871 c0 += (float8)a0.s0 * b0; 6872 c1 += (float8)a0.s1 * b0; 6873 c2 += (float8)a0.s2 * b0; 6874 c3 += (float8)a0.s3 * b0; 6875 } 6876 6877 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0) 6878 { 6879 // Load values from matrix A (interleaved) and matrix B (transposed) 6880 float4 a0 = convert_float4(vload4(0, src_addr_a)); 6881 float8 b0 = convert_float8(vload8(0, src_addr_b)); 6882 6883 c0 += (float8)a0.s0 * b0; 6884 c1 += (float8)a0.s1 * b0; 6885 c2 += (float8)a0.s2 * b0; 6886 c3 += (float8)a0.s3 * b0; 6887 } 6888 6889 // Compute destination address 6890 Image dst = CONVERT_TO_IMAGE_STRUCT(dst); 6891 6892 // Compute dst address 6893 __global uchar *dst_addr = offset(&dst, 0, 0); 6894 6895 uint4 zout = 0; 6896 6897#if defined(REINTERPRET_OUTPUT_AS_3D) 6898 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension 6899 // in order to take into account the presence of possible cross plane paddings 6900 // 6901 // | | 6902 // | plane0 | 6903 // | | 6904 // |__________________| 6905 // |******************| 6906 // | cross_plane_pad | 6907 // |******************| 6908 // | | 6909 // | plane1 | 6910 // | | 6911 // |__________________| 6912 6913 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D 6914 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D; 6915 zout = min(DEPTH_GEMM3D - 1, zout); 6916 6917 // Add offset due to the cross plane paddings 6918 zout *= (cross_plane_pad * dst_stride_y); 6919 6920 // Add offset for batched GEMM. 
The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of the matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias_f0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);
    float8 bias_f1 = convert_float8(bias1);
    float8 bias_f2 = convert_float8(bias2);
    float8 bias_f3 = convert_float8(bias3);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    half8 c_h0 = convert_half8(c0);
    half8 c_h1 = convert_half8(c1);
    half8 c_h2 = convert_half8(c2);
    half8 c_h3 = convert_half8(c3);

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}

/** This OpenCL kernel, optimized for Bifrost architectures, computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * @note The number of rows of the destination matrix must be passed at compile time using -DM
 * @note The number of columns of the destination matrix must be passed at compile time using -DN
 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block (V0) must be passed at compile time using -DV0 (e.g.
-DV0=2) 6996 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16) 6997 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) 6998 * 6999 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 7000 * The activation function is performed after the bias addition 7001 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time: 7002 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D 7003 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 7004 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor 7005 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped 7006 * 7007 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 7008 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) 7009 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 7010 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) 7011 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 7012 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix 7013 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr 7014 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 7015 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 7016 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 7017 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 7018 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 7019 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr 7020 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 7021 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 7022 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 7023 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 7024 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 7025 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 7026 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 7027 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) 7028 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 7029 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) 7030 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 7031 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 7032 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 7033 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 7034 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 7035 */ 7036__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0), 7037 IMAGE_DECLARATION(src1), 7038#if defined(BETA) 7039 IMAGE_DECLARATION(src2), 7040#endif // defined(BETA) 7041 IMAGE_DECLARATION(dst), 7042 uint src0_stride_z, 7043 uint src1_stride_z, 7044#if defined(BETA) 7045 uint src2_stride_z, 7046#endif //defined(BETA) 7047 uint dst_stride_z 7048#if defined(REINTERPRET_OUTPUT_AS_3D) 7049 , 7050 uint cross_plane_pad 7051#endif // REINTERPRET_OUTPUT_AS_3D 7052 ) 7053{ 7054 int x = get_global_id(0) / H0; 7055 int y = get_global_id(1) / V0; 7056 int z = get_global_id(2); 7057 7058 // Offset 7059 const int offset_row_a = (get_global_id(1) % V0) * 4; 7060 const int offset_row_b = (get_global_id(0) % H0) * 8; 7061 7062 // src_addr_a = address of matrix A 7063 // src_addr_b = address of matrix B 7064 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes; 7065 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes; 7066 7067#if defined(MATRIX_B_DEPTH) 7068 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 7069 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z; 7070#else // defined(MATRIX_B_DEPTH) 7071 src1_addr_in_bytes += z * src1_stride_z; 7072#endif // defined(MATRIX_B_DEPTH) 7073 7074 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes); 7075 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes); 7076 7077 src_addr_a += offset_row_a; 7078 src_addr_b += offset_row_b; 7079 7080 // Reset accumulators 7081 half8 c0 = 0.0f; 7082 half8 c1 = 0.0f; 7083 half8 c2 = 0.0f; 7084 half8 c3 = 0.0f; 7085 7086 int i = 0; 7087 for(; i <= (int)(K - 4); i += 4) 7088 { 7089#if V0 == 1 7090 // Load values from matrix A (interleaved) and matrix B (transposed) 7091 half8 a0 = vload8(0, src_addr_a); 7092 half8 b0 = vload8(0, src_addr_b); 7093 7094 src_addr_a += 8 * V0; 7095 src_addr_b 
+= 8 * H0; 7096 7097 c0 = fma((half8)a0.s0, b0, c0); 7098 c1 = fma((half8)a0.s1, b0, c1); 7099 c2 = fma((half8)a0.s2, b0, c2); 7100 c3 = fma((half8)a0.s3, b0, c3); 7101 7102 // Load values from matrix B (transposed) 7103 b0 = vload8(0, src_addr_b); 7104 7105 src_addr_b += 8 * H0; 7106 7107 c0 = fma((half8)a0.s4, b0, c0); 7108 c1 = fma((half8)a0.s5, b0, c1); 7109 c2 = fma((half8)a0.s6, b0, c2); 7110 c3 = fma((half8)a0.s7, b0, c3); 7111 7112 // Load values from matrix A (interleaved) and matrix B (transposed) 7113 a0 = vload8(0, src_addr_a); 7114 b0 = vload8(0, src_addr_b); 7115 7116 src_addr_a += 8 * V0; 7117 src_addr_b += 8 * H0; 7118 7119 c0 = fma((half8)a0.s0, b0, c0); 7120 c1 = fma((half8)a0.s1, b0, c1); 7121 c2 = fma((half8)a0.s2, b0, c2); 7122 c3 = fma((half8)a0.s3, b0, c3); 7123 7124 // Load values from matrix B (transposed) 7125 b0 = vload8(0, src_addr_b); 7126 7127 src_addr_b += 8 * H0; 7128 7129 c0 = fma((half8)a0.s4, b0, c0); 7130 c1 = fma((half8)a0.s5, b0, c1); 7131 c2 = fma((half8)a0.s6, b0, c2); 7132 c3 = fma((half8)a0.s7, b0, c3); 7133#else // V0 == 1 7134 // Load values from matrix A (interleaved) and matrix B (transposed) 7135 half4 a0 = vload4(0, src_addr_a); 7136 half8 b0 = vload8(0, src_addr_b); 7137 7138 src_addr_a += 4 * V0; 7139 src_addr_b += 8 * H0; 7140 7141 c0 = fma((half8)a0.s0, b0, c0); 7142 c1 = fma((half8)a0.s1, b0, c1); 7143 c2 = fma((half8)a0.s2, b0, c2); 7144 c3 = fma((half8)a0.s3, b0, c3); 7145 7146 // Load values from matrix A (interleaved) and matrix B (transposed) 7147 a0 = vload4(0, src_addr_a); 7148 b0 = vload8(0, src_addr_b); 7149 7150 src_addr_a += 4 * V0; 7151 src_addr_b += 8 * H0; 7152 7153 c0 = fma((half8)a0.s0, b0, c0); 7154 c1 = fma((half8)a0.s1, b0, c1); 7155 c2 = fma((half8)a0.s2, b0, c2); 7156 c3 = fma((half8)a0.s3, b0, c3); 7157 7158 // Load values from matrix A (interleaved) and matrix B (transposed) 7159 a0 = vload4(0, src_addr_a); 7160 b0 = vload8(0, src_addr_b); 7161 7162 src_addr_a += 4 * V0; 7163 src_addr_b += 8 * H0; 7164 7165 c0 = fma((half8)a0.s0, b0, c0); 7166 c1 = fma((half8)a0.s1, b0, c1); 7167 c2 = fma((half8)a0.s2, b0, c2); 7168 c3 = fma((half8)a0.s3, b0, c3); 7169 7170 // Load values from matrix A (interleaved) and matrix B (transposed) 7171 a0 = vload4(0, src_addr_a); 7172 b0 = vload8(0, src_addr_b); 7173 7174 src_addr_a += 4 * V0; 7175 src_addr_b += 8 * H0; 7176 7177 c0 = fma((half8)a0.s0, b0, c0); 7178 c1 = fma((half8)a0.s1, b0, c1); 7179 c2 = fma((half8)a0.s2, b0, c2); 7180 c3 = fma((half8)a0.s3, b0, c3); 7181#endif // V0 == 1 7182 } 7183 7184 for(; i < (int)K; ++i) 7185 { 7186 // Load values from matrix A (interleaved) and matrix B (transposed) 7187 half4 a0 = vload4(0, src_addr_a); 7188 half8 b0 = vload8(0, src_addr_b); 7189 7190 src_addr_a += 4 * V0; 7191 src_addr_b += 8 * H0; 7192 7193 c0 = fma((half8)a0.s0, b0, c0); 7194 c1 = fma((half8)a0.s1, b0, c1); 7195 c2 = fma((half8)a0.s2, b0, c2); 7196 c3 = fma((half8)a0.s3, b0, c3); 7197 } 7198 7199 // Compute destination address 7200 Image dst = CONVERT_TO_IMAGE_STRUCT(dst); 7201 7202 // Compute dst address 7203 __global uchar *dst_addr = offset(&dst, 0, 0); 7204 7205 uint4 zout = 0; 7206 7207#if defined(REINTERPRET_OUTPUT_AS_3D) 7208 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension 7209 // in order to take into account the presence of possible cross plane paddings 7210 // 7211 // | | 7212 // | plane0 | 7213 // | | 7214 // |__________________| 7215 // |******************| 7216 // | cross_plane_pad | 7217 // 
|******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of the matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}

#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

#if defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, N0)
/** This OpenCL kernel computes the matrix-by-matrix multiplication between matrix A (src0) and matrix B (src1) in case neither matrix has been reshaped.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g.
-DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in] src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in] src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in] src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in] src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in] src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in] src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr                           (Optional) Pointer to the bias matrix.
Supported data type: same as @p lhs_ptr 7326 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 7327 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 7328 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 7329 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 7330 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 7331 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 7332 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 7333 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) 7334 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 7335 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) 7336 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 7337 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 7338 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 7339 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 7340 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 7341 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) 7342 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D) 7343 */ 7344__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0), 7345 IMAGE_DECLARATION(src1), 7346#if defined(BETA) 7347 IMAGE_DECLARATION(src2), 7348#endif // defined(BETA) 7349 IMAGE_DECLARATION(dst), 7350 uint src0_stride_z, 7351 uint src1_stride_z, 7352#if defined(BETA) 7353 uint src2_stride_z, 7354#endif //defined(BETA) 7355 uint dst_stride_z 7356#if defined(REINTERPRET_INPUT_AS_3D) 7357 , 7358 uint src_cross_plane_pad 7359#endif // REINTERPRET_INPUT_AS_3D 7360#if defined(REINTERPRET_OUTPUT_AS_3D) 7361 , 7362 uint dst_cross_plane_pad 7363#endif // REINTERPRET_OUTPUT_AS_3D 7364 ) 7365{ 7366 int idx = get_global_id(0) * N0; 7367 7368 // Compute starting address for matrix A and Matrix B 7369 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); 7370 7371 // Update address for the matrix A 7372 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y; 7373 7374 // Update address for the matrix B 7375 src_addr.s1 += idx * sizeof(DATA_TYPE); 7376 7377#if defined(REINTERPRET_INPUT_AS_3D) 7378 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension 7379 // in order to take into account the presence of possible cross plane paddings 7380 // 7381 // | | 7382 // | plane0 | 7383 // | | 7384 // |__________________| 7385 // |******************| 7386 // | cross_plane_pad | 7387 // |******************| 7388 // | | 7389 // | plane1 | 7390 // | | 7391 // |__________________| 7392 7393 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D 7394 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, 
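    /* Worked example of this plane computation (illustrative values only;
     * HEIGHT_GEMM3D = 4, DEPTH_GEMM3D = 2, M0 = 4 are assumptions, not values
     * taken from this file):
     *
     *   block start row = 4              -> rows (4, 5, 6, 7)
     *   zin = rows / HEIGHT_GEMM3D       -> (1, 1, 1, 1)
     *   zin = min(DEPTH_GEMM3D - 1, zin) clamps any row past the last plane
     *
     * All four rows of the tile therefore belong to plane 1; the multiply by
     * (src_cross_plane_pad * src0_stride_y) below turns that plane index into
     * the byte offset that skips one band of cross-plane padding.
     */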
PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D; 7395 zin = min(DEPTH_GEMM3D - 1, zin); 7396 7397 // Add offset due to the cross plane paddings 7398 zin *= (src_cross_plane_pad * src0_stride_y); 7399 7400 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we 7401 // multiply src0_stride_z by DEPTH_GEMM3D 7402 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; 7403 7404#else // defined(REINTERPRET_INPUT_AS_3D) 7405 7406 // Add offset for batched GEMM 7407 src_addr.s0 += get_global_id(2) * src0_stride_z; 7408 7409#endif // defined(REINTERPRET_INPUT_AS_3D) 7410 7411#if defined(MATRIX_B_DEPTH) 7412 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 7413 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; 7414#else // defined(MATRIX_B_DEPTH) 7415 src_addr.s1 += get_global_id(2) * src1_stride_z; 7416#endif // defined(MATRIX_B_DEPTH) 7417 7418 int end_row_vec_a = src_addr.s0 + (K * sizeof(DATA_TYPE)); 7419 7420 VECTOR_TYPE acc0 = 0.0f; 7421#if M0 > 1 7422 VECTOR_TYPE acc1 = 0.0f; 7423#endif // M0 > 1 7424#if M0 > 2 7425 VECTOR_TYPE acc2 = 0.0f; 7426#endif // M0 > 2 7427#if M0 > 3 7428 VECTOR_TYPE acc3 = 0.0f; 7429#endif // M0 > 3 7430 7431 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y)) 7432 { 7433#if defined(REINTERPRET_INPUT_AS_3D) 7434 // Load values from matrix A 7435 LOAD_BLOCK(M0, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s); 7436#else // defined(REINTERPRET_INPUT_AS_3D) 7437 // Load values from matrix A 7438 VEC_DATA_TYPE(DATA_TYPE, 2) 7439 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); 7440#if M0 > 1 7441 VEC_DATA_TYPE(DATA_TYPE, 2) 7442 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); 7443#endif // M0 > 1 7444#if M0 > 2 7445 VEC_DATA_TYPE(DATA_TYPE, 2) 7446 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); 7447#endif // M0 > 2 7448#if M0 > 3 7449 VEC_DATA_TYPE(DATA_TYPE, 2) 7450 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); 7451#endif // M0 > 3 7452#endif // defined(REINTERPRET_INPUT_AS_3D) 7453 7454 // Load values from matrix B 7455 VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); 7456 VECTOR_TYPE b1 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y)); 7457 7458 // Accumulate 7459 acc0 += b0 * (VECTOR_TYPE)a0.s0; 7460 acc0 += b1 * (VECTOR_TYPE)a0.s1; 7461#if M0 > 1 7462 acc1 += b0 * (VECTOR_TYPE)a1.s0; 7463 acc1 += b1 * (VECTOR_TYPE)a1.s1; 7464#endif // M0 > 1 7465#if M0 > 2 7466 acc2 += b0 * (VECTOR_TYPE)a2.s0; 7467 acc2 += b1 * (VECTOR_TYPE)a2.s1; 7468#endif // M0 > 2 7469#if M0 > 3 7470 acc3 += b0 * (VECTOR_TYPE)a3.s0; 7471 acc3 += b1 * (VECTOR_TYPE)a3.s1; 7472#endif // M0 > 3 7473 } 7474 7475 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y)) 7476 { 7477#if defined(REINTERPRET_INPUT_AS_3D) 7478 // Load values from matrix A 7479 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); 7480#if M0 > 1 7481 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); 7482#endif // M0 > 1 7483#if M0 > 2 7484 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); 7485#endif // M0 > 2 7486#if M0 > 3 7487 DATA_TYPE a3 = *((__global DATA_TYPE 
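    /* This scalar loop is the tail of the two-stage K traversal: the loop
     * above consumes two columns of A (and two rows of B) per iteration, and
     * this one finishes the odd remainder. Illustrative example (the K value
     * is an assumption, not from this file): for K = 7 the vectorized loop
     * covers k = 0..5 in three steps and this loop handles k = 6, so no load
     * reads past end_row_vec_a.
     */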
*)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); 7488#endif // M0 > 3 7489#else // defined(REINTERPRET_INPUT_AS_3D) 7490 // Load values from matrix A 7491 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); 7492#if M0 > 1 7493 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); 7494#endif // M0 > 1 7495#if M0 > 2 7496 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); 7497#endif // M0 > 2 7498#if M0 > 3 7499 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); 7500#endif // M0 > 3 7501#endif // defined(REINTERPRET_INPUT_AS_3D) 7502 7503 // Load values from matrix B 7504 VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); 7505 7506 // Accumulate 7507 acc0 += b0 * (VECTOR_TYPE)a0; 7508#if M0 > 1 7509 acc1 += b0 * (VECTOR_TYPE)a1; 7510#endif // M0 > 1 7511#if M0 > 2 7512 acc2 += b0 * (VECTOR_TYPE)a2; 7513#endif // M0 > 2 7514#if M0 > 3 7515 acc3 += b0 * (VECTOR_TYPE)a3; 7516#endif // M0 > 3 7517 } 7518 7519 int z = get_global_id(2); 7520 7521 // Compute dst address 7522 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, 7523 PARTIAL_STORE_M0) 7524 * dst_stride_y); 7525 7526 uint4 zout = 0; 7527 7528#if defined(REINTERPRET_OUTPUT_AS_3D) 7529 7530 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension 7531 // in order to take into account the presence of possible cross plane paddings 7532 // 7533 // | | 7534 // | plane0 | 7535 // | | 7536 // |__________________| 7537 // |******************| 7538 // | cross_plane_pad | 7539 // |******************| 7540 // | | 7541 // | plane1 | 7542 // | | 7543 // |__________________| 7544 7545 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D 7546 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D; 7547 zout = min(DEPTH_GEMM3D - 1, zout); 7548 7549 // Add offset due to the cross plane paddings 7550 zout *= (dst_cross_plane_pad * dst_stride_y); 7551 7552 // Add offset for batched GEMM. 
// The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src2_stride_y) + z * src2_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
#endif // defined(DATA_TYPE)

/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=4.
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g.
a = [K, M, 16, Batches], b = [N, K, 16]) 7621 * 7622 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 7623 * The activation function is performed after the bias addition 7624 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: 7625 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D 7626 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D 7627 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 7628 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor 7629 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped 7630 * 7631 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32 7632 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) 7633 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 7634 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) 7635 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 7636 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix 7637 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr 7638 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 7639 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 7640 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 7641 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 7642 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 7643 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr 7644 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 7645 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 7646 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 7647 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 7648 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 7649 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 7650 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 7651 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) 7652 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 7653 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) 7654 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 7655 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 7656 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 7657 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 7658 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 7659 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) 7660 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 7661 */ 7662__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0), 7663 IMAGE_DECLARATION(src1), 7664#if defined(BETA) 7665 IMAGE_DECLARATION(src2), 7666#endif // defined(BETA) 7667 IMAGE_DECLARATION(dst), 7668 uint src0_stride_z, 7669 uint src1_stride_z, 7670#if defined(BETA) 7671 uint src2_stride_z, 7672#endif //defined(BETA) 7673 uint dst_stride_z 7674#if defined(REINTERPRET_INPUT_AS_3D) 7675 , 7676 uint src_cross_plane_pad 7677#endif // REINTERPRET_INPUT_AS_3D 7678#if defined(REINTERPRET_OUTPUT_AS_3D) 7679 , 7680 uint dst_cross_plane_pad 7681#endif // REINTERPRET_OUTPUT_AS_3D 7682 ) 7683{ 7684 int idx = get_global_id(0) * N0; 7685 7686 // Compute starting address for matrix A and matrix B 7687 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); 7688 7689 // Update address for matrix A 7690 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y; 7691 7692 // Update address for matrix B 7693 src_addr.s1 += idx * sizeof(float); 7694 7695#if defined(REINTERPRET_INPUT_AS_3D) 7696 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension 7697 // in order to take into account the presence of possible cross plane paddings 7698 // 7699 // | | 7700 // | plane0 | 7701 // | | 7702 // |__________________| 7703 // |******************| 7704 // | cross_plane_pad | 7705 // |******************| 7706 // | | 7707 // | plane1 | 7708 // | | 7709 // |__________________| 7710 7711 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D 7712 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / 
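    /* Sketch of how zin becomes a byte offset (numbers are illustrative
     * assumptions, not taken from this file): with src_cross_plane_pad = 2
     * padding rows and src0_stride_y = 64 bytes, the statement
     * zin *= (src_cross_plane_pad * src0_stride_y) below shifts a row that
     * lives in plane p by p * 2 * 64 bytes, exactly the storage occupied by
     * the padding bands of the p planes before it.
     */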
(uint4)HEIGHT_GEMM3D; 7713 zin = min(DEPTH_GEMM3D - 1, zin); 7714 7715 // Add offset due to the cross plane paddings 7716 zin *= (src_cross_plane_pad * src0_stride_y); 7717 7718 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we 7719 // multiply src0_stride_z by DEPTH_GEMM3D 7720 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; 7721 7722#else // defined(REINTERPRET_INPUT_AS_3D) 7723 7724 // Add offset for batched GEMM 7725 src_addr.s0 += get_global_id(2) * src0_stride_z; 7726 7727#endif // defined(REINTERPRET_INPUT_AS_3D) 7728 7729#if defined(MATRIX_B_DEPTH) 7730 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 7731 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; 7732#else // defined(MATRIX_B_DEPTH) 7733 src_addr.s1 += get_global_id(2) * src1_stride_z; 7734#endif // defined(MATRIX_B_DEPTH) 7735 7736 // Initialize accumulators 7737 float4 acc0 = 0.0f; 7738 7739#if M0 > 1 7740 float4 acc1 = 0.0f; 7741#endif // M0 > 1 7742 7743#if M0 > 2 7744 float4 acc2 = 0.0f; 7745#endif // M0 > 2 7746 7747#if M0 > 3 7748 float4 acc3 = 0.0f; 7749#endif // M0 > 3 7750 7751 // A and B src indices get incremented at the same time. 7752 int i = 0; 7753 for(; i <= ((int)K - 4); i += 4) 7754 { 7755#if defined(REINTERPRET_INPUT_AS_3D) 7756 // Load values from matrix A and matrix B 7757 LOAD_BLOCK(M0, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s); 7758#else // defined(REINTERPRET_INPUT_AS_3D) 7759 // Load values from matrix A and matrix B 7760 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); 7761#if M0 > 1 7762 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); 7763#endif // M0 > 1 7764#if M0 > 2 7765 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); 7766#endif // M0 > 2 7767#if M0 > 3 7768 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); 7769#endif // M0 > 3 7770#endif // defined(REINTERPRET_INPUT_AS_3D) 7771 7772 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); 7773 src_addr.s1 += src1_stride_y; 7774 7775 // Multiply and accumulate 7776 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0); 7777 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1); 7778 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2); 7779 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3); 7780 7781#if M0 > 1 7782 7783 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0); 7784 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1); 7785 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2); 7786 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3); 7787 7788#endif // M0 > 1 7789#if M0 > 2 7790 7791 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0); 7792 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1); 7793 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2); 7794 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3); 7795 7796#endif // M0 > 2 7797#if M0 > 3 7798 7799 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0); 7800 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1); 7801 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2); 7802 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3); 7803#endif // M0 > 3 7804 7805 // Load values from matrix A and matrix B 7806 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); 7807 src_addr.s1 += src1_stride_y; 7808 7809 // Multiply and accumulate 7810 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0); 7811 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1); 7812 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2); 7813 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3); 7814 7815#if M0 > 1 7816 7817 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0); 7818 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1); 7819 acc1.s2 = 
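        /* Each group of four fma statements here is one rank-1 update of the
         * accumulator tile: acc[m].sn += a[m].sk * b0.sn, where b0 holds row
         * k of matrix B. The loop unrolls k = 0..3 so the four B rows stream
         * through the fma units without reloading the A values.
         */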
fma(a1.s1, b0.s2, acc1.s2); 7820 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3); 7821 7822#endif // M0 > 1 7823#if M0 > 2 7824 7825 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0); 7826 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1); 7827 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2); 7828 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3); 7829 7830#endif // M0 > 2 7831#if M0 > 3 7832 7833 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0); 7834 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1); 7835 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2); 7836 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3); 7837#endif // M0 > 3 7838 7839 // Load values from matrix A and matrix B 7840 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); 7841 src_addr.s1 += src1_stride_y; 7842 7843 // Multiply and accumulate 7844 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0); 7845 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1); 7846 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2); 7847 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3); 7848 7849#if M0 > 1 7850 7851 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0); 7852 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1); 7853 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2); 7854 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3); 7855 7856#endif // M0 > 1 7857#if M0 > 2 7858 7859 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0); 7860 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1); 7861 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2); 7862 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3); 7863 7864#endif // M0 > 2 7865#if M0 > 3 7866 7867 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0); 7868 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1); 7869 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2); 7870 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3); 7871#endif // M0 > 3 7872 7873 // Load values from matrix A and matrix B 7874 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); 7875 src_addr.s1 += src1_stride_y; 7876 7877 // Multiply and accumulate 7878 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0); 7879 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1); 7880 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2); 7881 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3); 7882 7883#if M0 > 1 7884 7885 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0); 7886 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1); 7887 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2); 7888 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3); 7889 7890#endif // M0 > 1 7891#if M0 > 2 7892 7893 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0); 7894 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1); 7895 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2); 7896 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3); 7897 7898#endif // M0 > 2 7899#if M0 > 3 7900 7901 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0); 7902 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1); 7903 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2); 7904 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3); 7905#endif // M0 > 3 7906 7907 src_addr.s0 += 4 * sizeof(float); 7908 } 7909 7910 for(; i < (int)K; ++i) 7911 { 7912#if defined(REINTERPRET_INPUT_AS_3D) 7913 // Load values from matrix A 7914 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); 7915#if M0 > 1 7916 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); 7917#endif // M0 > 1 7918#if M0 > 2 7919 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); 7920#endif // M0 > 2 7921#if M0 > 3 7922 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); 7923#endif // M0 > 3 7924#else // defined(REINTERPRET_INPUT_AS_3D) 7925 // Load values from matrix A 7926 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); 7927#if M0 > 1 7928 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); 7929#endif // M0 > 1 7930#if M0 > 2 7931 float a2 = *((__global float *)(src0_ptr + 
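        /* Tail loop: K need not be a multiple of 4, so leftover columns are
         * processed one at a time. Illustrative example (the K value is an
         * assumption, not from this file): K = 10 gives two unrolled
         * iterations above (k = 0..3 and k = 4..7) and two scalar iterations
         * here (k = 8 and k = 9).
         */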
src_addr.s0 + 2 * src0_stride_y)); 7932#endif // M0 > 2 7933#if M0 > 3 7934 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); 7935#endif // M0 > 3 7936#endif // defined(REINTERPRET_INPUT_AS_3D) 7937 7938 // Load values from matrix B 7939 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1)); 7940 src_addr.s1 += src1_stride_y; 7941 7942 // Multiply and accumulate 7943 acc0.s0 = fma(a0, b0.s0, acc0.s0); 7944 acc0.s1 = fma(a0, b0.s1, acc0.s1); 7945 acc0.s2 = fma(a0, b0.s2, acc0.s2); 7946 acc0.s3 = fma(a0, b0.s3, acc0.s3); 7947#if M0 > 1 7948 acc1.s0 = fma(a1, b0.s0, acc1.s0); 7949 acc1.s1 = fma(a1, b0.s1, acc1.s1); 7950 acc1.s2 = fma(a1, b0.s2, acc1.s2); 7951 acc1.s3 = fma(a1, b0.s3, acc1.s3); 7952#endif // M0 > 1 7953#if M0 > 2 7954 acc2.s0 = fma(a2, b0.s0, acc2.s0); 7955 acc2.s1 = fma(a2, b0.s1, acc2.s1); 7956 acc2.s2 = fma(a2, b0.s2, acc2.s2); 7957 acc2.s3 = fma(a2, b0.s3, acc2.s3); 7958#endif // M0 > 2 7959#if M0 > 3 7960 acc3.s0 = fma(a3, b0.s0, acc3.s0); 7961 acc3.s1 = fma(a3, b0.s1, acc3.s1); 7962 acc3.s2 = fma(a3, b0.s2, acc3.s2); 7963 acc3.s3 = fma(a3, b0.s3, acc3.s3); 7964#endif // M0 > 3 7965 7966 src_addr.s0 += sizeof(float); 7967 } 7968 7969 int z = get_global_id(2); 7970 7971 // Compute dst address 7972 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, 7973 PARTIAL_STORE_M0) * dst_stride_y); 7974 7975 uint4 zout = 0; 7976 7977#if defined(REINTERPRET_OUTPUT_AS_3D) 7978 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension 7979 // in order to take into account the presence of possible cross plane paddings 7980 // 7981 // | | 7982 // | plane0 | 7983 // | | 7984 // |__________________| 7985 // |******************| 7986 // | cross_plane_pad | 7987 // |******************| 7988 // | | 7989 // | plane1 | 7990 // | | 7991 // |__________________| 7992 7993 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D 7994 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D; 7995 zout = min(DEPTH_GEMM3D - 1, zout); 7996 7997 // Add offset due to the cross plane paddings 7998 zout *= (dst_cross_plane_pad * dst_stride_y); 7999 8000 // Add offset for batched GEMM. 
// The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src2_stride_y) + z * src2_stride_z;

    LOAD_BLOCK(M0, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 4, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}

/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=2.
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g.
-DMATRIX_B_DEPTH=16) 8068 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16]) 8069 * 8070 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 8071 * The activation function is performed after the bias addition 8072 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: 8073 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D 8074 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D 8075 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 8076 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor 8077 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped 8078 * 8079 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32 8080 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) 8081 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 8082 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) 8083 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 8084 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix 8085 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr 8086 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 8087 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 8088 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 8089 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 8090 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 8091 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr 8092 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 8093 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 8094 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 8095 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 8096 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 8097 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 8098 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 8099 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) 8100 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 8101 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) 8102 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 8103 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 8104 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 8105 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 8106 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 8107 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) 8108 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 8109 */ 8110__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0), 8111 IMAGE_DECLARATION(src1), 8112#if defined(BETA) 8113 IMAGE_DECLARATION(src2), 8114#endif // defined(BETA) 8115 IMAGE_DECLARATION(dst), 8116 uint src0_stride_z, 8117 uint src1_stride_z, 8118#if defined(BETA) 8119 uint src2_stride_z, 8120#endif //defined(BETA) 8121 uint dst_stride_z 8122#if defined(REINTERPRET_INPUT_AS_3D) 8123 , 8124 uint src_cross_plane_pad 8125#endif // REINTERPRET_INPUT_AS_3D 8126#if defined(REINTERPRET_OUTPUT_AS_3D) 8127 , 8128 uint dst_cross_plane_pad 8129#endif // REINTERPRET_OUTPUT_AS_3D 8130 ) 8131{ 8132 // Requires 2 N0, C vect2, A vect4, B (2 vload2) // to fix for M0 > 1 8133 int idx = get_global_id(0) * N0; 8134 8135 // Compute starting address for matrix A and Matrix B 8136 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); 8137 8138 // Update address for the matrix A 8139 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y; 8140 8141 // Update address for the matrix B 8142 src_addr.s1 += idx * sizeof(float); 8143 8144#if defined(REINTERPRET_INPUT_AS_3D) 8145 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension 8146 // in order to take into account the presence of possible cross plane paddings 8147 // 8148 // | | 8149 // | plane0 | 8150 // | | 8151 // |__________________| 8152 // |******************| 8153 // | cross_plane_pad | 8154 // |******************| 8155 // | | 8156 // | plane1 | 8157 // | | 8158 // |__________________| 8159 8160 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D 8161 uint4 zin = ((uint4)(0, 1, 2, 3) + 
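    /* Numeric sketch of the MATRIX_B_DEPTH handling a few lines below
     * (values are illustrative assumptions, not from this file): with
     * MATRIX_B_DEPTH = 16 and batch indices get_global_id(2) = 0..31, matrix
     * B advances by (get_global_id(2) % 16) * src1_stride_z, so batches
     * 16..31 reuse the same 16 slices of B instead of reading past its end.
     */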
(uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D; 8162 zin = min(DEPTH_GEMM3D - 1, zin); 8163 8164 // Add offset due to the cross plane paddings 8165 zin *= (src_cross_plane_pad * src0_stride_y); 8166 8167 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we 8168 // multiply src0_stride_z by DEPTH_GEMM3D 8169 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D; 8170 8171#else // defined(REINTERPRET_INPUT_AS_3D) 8172 8173 // Add offset for batched GEMM 8174 src_addr.s0 += get_global_id(2) * src0_stride_z; 8175 8176#endif // defined(REINTERPRET_INPUT_AS_3D) 8177 8178#if defined(MATRIX_B_DEPTH) 8179 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3 8180 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z; 8181#else // defined(MATRIX_B_DEPTH) 8182 src_addr.s1 += get_global_id(2) * src1_stride_z; 8183#endif // defined(MATRIX_B_DEPTH) 8184 8185 // Initialize accumulators 8186 float2 acc0 = 0.0f; 8187#if M0 > 1 8188 float2 acc1 = 0.0f; 8189#endif // M0 > 1 8190#if M0 > 2 8191 float2 acc2 = 0.0f; 8192#endif // M0 > 2 8193#if M0 > 3 8194 float2 acc3 = 0.0f; 8195#endif // M0 > 3 8196 8197 // A and B src indices get incremented at the same time. 8198 int i = 0; 8199 for(; i <= ((int)K - 8); i += 8) 8200 { 8201#if defined(REINTERPRET_INPUT_AS_3D) 8202 // Load values from matrix A 8203 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0)); 8204#else // defined(REINTERPRET_INPUT_AS_3D) 8205 // Load values from matrix A 8206 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0)); 8207#endif // defined(REINTERPRET_INPUT_AS_3D) 8208 8209 // Load values from matrix B 8210 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8211 src_addr.s1 += src1_stride_y; 8212 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8213 src_addr.s1 += src1_stride_y; 8214 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8215 src_addr.s1 += src1_stride_y; 8216 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8217 src_addr.s1 += src1_stride_y; 8218 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8219 src_addr.s1 += src1_stride_y; 8220 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8221 src_addr.s1 += src1_stride_y; 8222 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8223 src_addr.s1 += src1_stride_y; 8224 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8225 src_addr.s1 += src1_stride_y; 8226 8227 // Multiply and accumulate 8228 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0); 8229 acc0.s0 = fma(a0.s1, b1.s0, acc0.s0); 8230 acc0.s0 = fma(a0.s2, b2.s0, acc0.s0); 8231 acc0.s0 = fma(a0.s3, b3.s0, acc0.s0); 8232 acc0.s0 = fma(a0.s4, b4.s0, acc0.s0); 8233 acc0.s0 = fma(a0.s5, b5.s0, acc0.s0); 8234 acc0.s0 = fma(a0.s6, b6.s0, acc0.s0); 8235 acc0.s0 = fma(a0.s7, b7.s0, acc0.s0); 8236 8237 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1); 8238 acc0.s1 = fma(a0.s1, b1.s1, acc0.s1); 8239 acc0.s1 = fma(a0.s2, b2.s1, acc0.s1); 8240 acc0.s1 = fma(a0.s3, b3.s1, acc0.s1); 8241 acc0.s1 = fma(a0.s4, b4.s1, acc0.s1); 8242 acc0.s1 = fma(a0.s5, b5.s1, acc0.s1); 8243 acc0.s1 = fma(a0.s6, b6.s1, acc0.s1); 8244 acc0.s1 = fma(a0.s7, b7.s1, acc0.s1); 8245 8246#if M0 > 1 8247#if defined(REINTERPRET_INPUT_AS_3D) 8248 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); 8249#else // defined(REINTERPRET_INPUT_AS_3D) 8250 a0 = 
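        /* The eight vectors b0..b7 loaded above stay in registers for the
         * whole 8-wide k-block, while a0 is overwritten with the next row of
         * A for each additional output row (M0 > 1). Reusing a single A
         * register trades extra A loads for keeping all eight B rows
         * resident, which is what the vload2/vload8 split in this kernel is
         * arranged around.
         */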
vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); 8251#endif // defined(REINTERPRET_INPUT_AS_3D) 8252 acc1.s0 = fma(a0.s0, b0.s0, acc1.s0); 8253 acc1.s0 = fma(a0.s1, b1.s0, acc1.s0); 8254 acc1.s0 = fma(a0.s2, b2.s0, acc1.s0); 8255 acc1.s0 = fma(a0.s3, b3.s0, acc1.s0); 8256 acc1.s0 = fma(a0.s4, b4.s0, acc1.s0); 8257 acc1.s0 = fma(a0.s5, b5.s0, acc1.s0); 8258 acc1.s0 = fma(a0.s6, b6.s0, acc1.s0); 8259 acc1.s0 = fma(a0.s7, b7.s0, acc1.s0); 8260 8261 acc1.s1 = fma(a0.s0, b0.s1, acc1.s1); 8262 acc1.s1 = fma(a0.s1, b1.s1, acc1.s1); 8263 acc1.s1 = fma(a0.s2, b2.s1, acc1.s1); 8264 acc1.s1 = fma(a0.s3, b3.s1, acc1.s1); 8265 acc1.s1 = fma(a0.s4, b4.s1, acc1.s1); 8266 acc1.s1 = fma(a0.s5, b5.s1, acc1.s1); 8267 acc1.s1 = fma(a0.s6, b6.s1, acc1.s1); 8268 acc1.s1 = fma(a0.s7, b7.s1, acc1.s1); 8269#endif // M0 > 1 8270#if M0 > 2 8271#if defined(REINTERPRET_INPUT_AS_3D) 8272 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); 8273#else // defined(REINTERPRET_INPUT_AS_3D) 8274 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); 8275#endif // defined(REINTERPRET_INPUT_AS_3D) 8276 acc2.s0 = fma(a0.s0, b0.s0, acc2.s0); 8277 acc2.s0 = fma(a0.s1, b1.s0, acc2.s0); 8278 acc2.s0 = fma(a0.s2, b2.s0, acc2.s0); 8279 acc2.s0 = fma(a0.s3, b3.s0, acc2.s0); 8280 acc2.s0 = fma(a0.s4, b4.s0, acc2.s0); 8281 acc2.s0 = fma(a0.s5, b5.s0, acc2.s0); 8282 acc2.s0 = fma(a0.s6, b6.s0, acc2.s0); 8283 acc2.s0 = fma(a0.s7, b7.s0, acc2.s0); 8284 8285 acc2.s1 = fma(a0.s0, b0.s1, acc2.s1); 8286 acc2.s1 = fma(a0.s1, b1.s1, acc2.s1); 8287 acc2.s1 = fma(a0.s2, b2.s1, acc2.s1); 8288 acc2.s1 = fma(a0.s3, b3.s1, acc2.s1); 8289 acc2.s1 = fma(a0.s4, b4.s1, acc2.s1); 8290 acc2.s1 = fma(a0.s5, b5.s1, acc2.s1); 8291 acc2.s1 = fma(a0.s6, b6.s1, acc2.s1); 8292 acc2.s1 = fma(a0.s7, b7.s1, acc2.s1); 8293#endif // M0 > 2 8294#if M0 > 3 8295#if defined(REINTERPRET_INPUT_AS_3D) 8296 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3)); 8297#else // defined(REINTERPRET_INPUT_AS_3D) 8298 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); 8299#endif // defined(REINTERPRET_INPUT_AS_3D) 8300 acc3.s0 = fma(a0.s0, b0.s0, acc3.s0); 8301 acc3.s0 = fma(a0.s1, b1.s0, acc3.s0); 8302 acc3.s0 = fma(a0.s2, b2.s0, acc3.s0); 8303 acc3.s0 = fma(a0.s3, b3.s0, acc3.s0); 8304 acc3.s0 = fma(a0.s4, b4.s0, acc3.s0); 8305 acc3.s0 = fma(a0.s5, b5.s0, acc3.s0); 8306 acc3.s0 = fma(a0.s6, b6.s0, acc3.s0); 8307 acc3.s0 = fma(a0.s7, b7.s0, acc3.s0); 8308 8309 acc3.s1 = fma(a0.s0, b0.s1, acc3.s1); 8310 acc3.s1 = fma(a0.s1, b1.s1, acc3.s1); 8311 acc3.s1 = fma(a0.s2, b2.s1, acc3.s1); 8312 acc3.s1 = fma(a0.s3, b3.s1, acc3.s1); 8313 acc3.s1 = fma(a0.s4, b4.s1, acc3.s1); 8314 acc3.s1 = fma(a0.s5, b5.s1, acc3.s1); 8315 acc3.s1 = fma(a0.s6, b6.s1, acc3.s1); 8316 acc3.s1 = fma(a0.s7, b7.s1, acc3.s1); 8317#endif // M0 > 3 8318 8319 src_addr.s0 += sizeof(float) * 8; 8320 } 8321 // float size increment 8322 for(; i < (int)K; ++i) 8323 { 8324#if defined(REINTERPRET_INPUT_AS_3D) 8325 // Load values from matrix A 8326 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0)); 8327#if M0 > 1 8328 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1)); 8329#endif // M0 > 1 8330#if M0 > 2 8331 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2)); 8332#endif // M0 > 2 8333#if M0 > 3 8334 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + 
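        /* Tail of the 8-wide main loop: the K % 8 leftover columns are
         * handled one at a time here. Illustrative example (the K value is an
         * assumption, not from this file): K = 20 gives two 8-wide blocks
         * above and four scalar iterations in this loop.
         */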
zin.s3)); 8335#endif // M0 > 3 8336#else // defined(REINTERPRET_INPUT_AS_3D) 8337 // Load values from matrix A 8338 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); 8339#if M0 > 1 8340 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); 8341#endif // M0 > 1 8342#if M0 > 2 8343 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); 8344#endif // M0 > 2 8345#if M0 > 3 8346 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); 8347#endif // M0 > 3 8348#endif // defined(REINTERPRET_INPUT_AS_3D) 8349 8350 // Load values from matrix B 8351 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1)); 8352 src_addr.s1 += src1_stride_y; 8353 8354 // Multiply and accumulate 8355 acc0.s0 = fma(a0, b0.s0, acc0.s0); 8356 acc0.s1 = fma(a0, b0.s1, acc0.s1); 8357#if M0 > 1 8358 acc1.s0 = fma(a1, b0.s0, acc1.s0); 8359 acc1.s1 = fma(a1, b0.s1, acc1.s1); 8360#endif // M0 > 1 8361#if M0 > 2 8362 acc2.s0 = fma(a2, b0.s0, acc2.s0); 8363 acc2.s1 = fma(a2, b0.s1, acc2.s1); 8364#endif // M0 > 2 8365#if M0 > 3 8366 acc3.s0 = fma(a3, b0.s0, acc3.s0); 8367 acc3.s1 = fma(a3, b0.s1, acc3.s1); 8368#endif // M0 > 3 8369 8370 src_addr.s0 += sizeof(float); 8371 } 8372 8373 int z = get_global_id(2); 8374 8375 // Compute dst address 8376 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, 8377 PARTIAL_STORE_M0) * dst_stride_y); 8378 8379 uint4 zout = 0; 8380 8381#if defined(REINTERPRET_OUTPUT_AS_3D) 8382 8383 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension 8384 // in order to take into account the presence of possible cross plane paddings 8385 // 8386 // | | 8387 // | plane0 | 8388 // | | 8389 // |__________________| 8390 // |******************| 8391 // | cross_plane_pad | 8392 // |******************| 8393 // | | 8394 // | plane1 | 8395 // | | 8396 // |__________________| 8397 8398 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D 8399 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D; 8400 zout = min(DEPTH_GEMM3D - 1, zout); 8401 8402 // Add offset due to the cross plane paddings 8403 zout *= (dst_cross_plane_pad * dst_stride_y); 8404 8405 // Add offset for batched GEMM. 
// The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));

    LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src2_stride_y) + z * src2_stride_z;

    LOAD_BLOCK(M0, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 2 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 2, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}

#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=8.
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g.
a = [K, M, 16, Batches], b = [N, K, 16]) 8474 * 8475 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively. 8476 * The activation function is performed after the bias addition 8477 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: 8478 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D 8479 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D 8480 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor. 8481 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor 8482 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped 8483 * 8484 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16 8485 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes) 8486 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 8487 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes) 8488 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 8489 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix 8490 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr 8491 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes) 8492 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes) 8493 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes) 8494 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) 8495 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix 8496 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. 
Supported data type: same as @p lhs_ptr 8497 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes) 8498 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes) 8499 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes) 8500 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes) 8501 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix 8502 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr 8503 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) 8504 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes) 8505 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) 8506 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes) 8507 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix 8508 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes) 8509 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes) 8510 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes) 8511 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) 8512 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D) 8513 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D) 8514 */ 8515__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0), 8516 IMAGE_DECLARATION(src1), 8517#if defined(BETA) 8518 IMAGE_DECLARATION(src2), 8519#endif // defined(BETA) 8520 IMAGE_DECLARATION(dst), 8521 uint src0_stride_z, 8522 uint src1_stride_z, 8523#if defined(BETA) 8524 uint src2_stride_z, 8525#endif //defined(BETA) 8526 uint dst_stride_z 8527#if defined(REINTERPRET_INPUT_AS_3D) 8528 , 8529 uint src_cross_plane_pad 8530#endif // REINTERPRET_INPUT_AS_3D 8531#if defined(REINTERPRET_OUTPUT_AS_3D) 8532 , 8533 uint dst_cross_plane_pad 8534#endif // REINTERPRET_OUTPUT_AS_3D 8535 ) 8536{ 8537 int idx = get_global_id(0) * N0; 8538 8539 // Compute starting address for matrix A and Matrix B 8540 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes)); 8541 8542 // Update address for the matrix A 8543 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y; 8544 8545 // Update address for the matrix B 8546 src_addr.s1 += idx * sizeof(half); 8547 8548#if defined(REINTERPRET_INPUT_AS_3D) 8549 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension 8550 // in order to take into account the presence of possible cross plane paddings 8551 // 8552 // | | 8553 // | plane0 | 8554 // | | 8555 // |__________________| 8556 // |******************| 8557 // | cross_plane_pad | 8558 // |******************| 8559 // | | 8560 // | plane1 | 8561 // | | 8562 // |__________________| 8563 8564 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D 8565 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / 
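    /* Unlike the plain fp16 kernels above, this kernel widens each row of B
     * with convert_float8() and keeps acc0..acc3 in float, as its name
     * (acc32) suggests. A sketch of the precision rationale (reasoning added
     * for illustration, not stated in this file): accumulating many
     * half-precision products directly in half can drop low-order bits once
     * the running sum grows, e.g. 2048.0h + 1.0h rounds back to 2048.0h,
     * while a float accumulator still resolves the addition.
     */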
#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated by dividing the row index by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Accumulate in single precision (acc32 variant), hence float literals
    float8 acc0 = 0.0f;
#if M0 > 1
    float8 acc1 = 0.0f;
#endif // M0 > 1
#if M0 > 2
    float8 acc2 = 0.0f;
#endif // M0 > 2
#if M0 > 3
    float8 acc3 = 0.0f;
#endif // M0 > 3

    int i = 0;
    for(; i <= ((int)K - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B and convert them to float for fp32 accumulation
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // M0 > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // M0 > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // M0 > 3

        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if M0 > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // M0 > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    for(; i < (int)K; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0); // b0 * (float8)a0;
#if M0 > 1
        acc1 = fma(b0, (float8)a1, acc1); // b0 * (float8)a1;
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (float8)a2, acc2); // b0 * (float8)a2;
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (float8)a3, acc3); // b0 * (float8)a3;
#endif // M0 > 3
    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated by dividing the row index by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);
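
    // Example (hypothetical values): with HEIGHT_GEMM3D = 4, DEPTH_GEMM3D = 2 and
    // dst_cross_plane_pad = 1, output rows 0..3 land in plane 0 (zout = 0, no extra offset) and
    // rows 4..7 in plane 1 (zout = dst_cross_plane_pad * dst_stride_y), i.e. each store skips one
    // padded row per plane boundary already crossed.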
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias_f0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                                                                                                                                                   PARTIAL_STORE_M0)
                                                                                                                              * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);
#if M0 > 1
    float8 bias_f1 = convert_float8(bias1);
#endif // M0 > 1
#if M0 > 2
    float8 bias_f2 = convert_float8(bias2);
#endif // M0 > 2
#if M0 > 3
    float8 bias_f3 = convert_float8(bias3);
#endif // M0 > 3

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Convert the accumulators back to half precision for the activation and the store
    half8 acc_h0 = convert_half8(acc0);
#if M0 > 1
    half8 acc_h1 = convert_half8(acc1);
#endif // M0 > 1
#if M0 > 2
    half8 acc_h2 = convert_half8(acc2);
#endif // M0 > 2
#if M0 > 3
    half8 acc_h3 = convert_half8(acc3);
#endif // M0 > 3

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
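
/* The two fp16 kernels in this file essentially differ only in accumulator precision:
 * gemm_mm_floating_point_f16_bifrost_acc32 (above) converts each B tile to float and keeps the
 * M0 x 8 accumulators in fp32, trading extra convert_float8/convert_half8 work for accuracy, while
 * gemm_mm_floating_point_f16_bifrost (below) accumulates directly in fp16. A small worked example
 * of why this matters: in fp16, 2048.0h + 1.0h rounds back to 2048.0h (above 2048 only even
 * integers are representable), so a long dot product can stop growing, whereas an fp32 accumulator
 * represents such partial sums exactly.
 */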

/** This OpenCL kernel computes the matrix-by-matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=8.
 * @note The number of columns of matrix A and the number of columns of matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value needs to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B values required by some activation functions should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition.
 * @note In case the input or output has to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows of matrix A NOT reshaped
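 *
 * @note The final STORE_BLOCK_BOUNDARY_AWARE uses cond_y (the first y block holds the leftover
 *       PARTIAL_STORE_M0 rows) and cond_x (the last x block holds the leftover PARTIAL_STORE_N0
 *       columns) to clamp the 8-wide stores. For example (hypothetical shapes), with N = 100 the
 *       work-item with get_global_id(0) = 12 satisfies (12 + 1) * 8 >= 100 and takes the
 *       partial-store path.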
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in units of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in units of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif // defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int idx = get_global_id(0) * N0;

    // Compute starting address for matrix A and matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated by dividing the row index by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    half8 acc0 = 0.0h;
#if M0 > 1
    half8 acc1 = 0.0h;
#endif // M0 > 1
#if M0 > 2
    half8 acc2 = 0.0h;
#endif // M0 > 2
#if M0 > 3
    half8 acc3 = 0.0h;
#endif // M0 > 3

    int i = 0;
    for(; i <= ((int)K - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // M0 > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // M0 > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // M0 > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // M0 > 3

        src_addr.s0 += 4 * sizeof(half);
    }

    for(; i < (int)K; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
#if M0 > 1
        acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
#endif // M0 > 3
    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated by dividing the row index by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, half, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                                                                                                                                                   PARTIAL_STORE_M0)
                                                                                                                              * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)

#endif // defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)

)"