// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020, Arm Limited and Contributors
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_PACKET_MATH_SVE_H
#define EIGEN_PACKET_MATH_SVE_H

namespace Eigen
{
namespace internal
{
#ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
#define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
#endif

#ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
#endif

#define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32

template <typename Scalar, int SVEVectorLength>
struct sve_packet_size_selector {
  enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
};

/********************************* int32 **************************************/
typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

template <>
struct packet_traits<numext::int32_t> : default_packet_traits {
  typedef PacketXi type;
  typedef PacketXi half;  // Half not implemented yet
  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    HasHalfPacket = 0,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0  // Not implemented in SVE
  };
};

template <>
struct unpacket_traits<PacketXi> {
  typedef numext::int32_t type;
  typedef PacketXi half;  // Half not yet implemented
  enum {
    size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

template <>
EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr)
{
  svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
}

template <>
EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from)
{
  return svdup_n_s32(from);
}

template <>
EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a)
{
  numext::int32_t c[packet_traits<numext::int32_t>::size];
  for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
  return svadd_s32_z(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
}

template <>
EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svadd_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svsub_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a)
{
  return svneg_s32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a)
{
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmul_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdiv_s32_z(svptrue_b32(), a, b);
}
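
// Illustrative example (not part of the implementation): assuming a 512-bit SVE
// vector length, i.e. EIGEN_ARM64_SVE_VL == 512, sve_packet_size_selector gives
// 512 / (sizeof(int32_t) * 8) == 16 lanes per PacketXi. A minimal round trip
// using the primitives above plus the load/store primitives defined further down:
//
//   numext::int32_t in[16], out[16];
//   for (int i = 0; i < 16; ++i) in[i] = i;
//   PacketXi p = ploadu<PacketXi>(in);            // {0, 1, ..., 15}
//   p = padd<PacketXi>(p, pset1<PacketXi>(100));  // {100, 101, ..., 115}
//   pstoreu<numext::int32_t>(out, p);             // out[i] == 100 + i
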
template <>
EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c)
{
  // Fused multiply-add: a * b + c
  return svmla_s32_z(svptrue_b32(), c, a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmin_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svmax_s32_z(svptrue_b32(), a, b);
}

// Integer comparisons in SVE return an svbool_t (predicate). Use svdup to set
// active lanes to all-ones (0xffffffffu) and inactive lanes to 0.
template <>
EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmple_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/)
{
  return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
}

template <>
EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/)
{
  return svdup_n_s32_z(svptrue_b32(), 0);
}

template <>
EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svand_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return svorr_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  return sveor_s32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b)
{
  // svbic computes a & ~b
  return svbic_s32_z(svptrue_b32(), a, b);
}

template <int N>
EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a)
{
  return svasrd_n_s32_z(svptrue_b32(), a, N);
}

template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a)
{
  return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N)));
}

template <int N>
EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a)
{
  return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N));
}

template <>
EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from)
{
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from)
{
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}
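
// Illustrative example (not part of the implementation): assuming 8 int32 lanes
// (EIGEN_ARM64_SVE_VL == 256), svindex_u32(0, 1) yields {0, 1, 2, 3, ...} and one
// svzip1 of that index vector with itself gives {0, 0, 1, 1, ...}, so the gather
// load reads each source element twice:
//
//   numext::int32_t from[4] = {10, 20, 30, 40};
//   PacketXi p = ploaddup<PacketXi>(from);  // {10, 10, 20, 20, 30, 30, 40, 40}
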
template <>
EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
  EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
{
  EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
}

template <>
EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
}

template <>
EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
}

template <>
EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a)
{
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_s32(svpfalse_b(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a)
{
  return svrev_s32(a);
}

template <>
EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a)
{
  return svabs_s32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a)
{
  return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
}
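
// SVE has no single-instruction multiplicative reduction, so predux_mul below
// halves the problem repeatedly: multiply the vector by its reverse, then keep
// folding the upper half onto the lower half with svtbl until the full product
// sits in lane 0. Illustrative walk-through (not part of the implementation),
// assuming 4 int32 lanes (EIGEN_ARM64_SVE_VL == 128):
//
//   a                 = {a0, a1, a2, a3}
//   prod = a * rev(a) = {a0*a3, a1*a2, a2*a1, a3*a0}
//   half_prod = svtbl(prod, {2, 3, 4, 5})  // out-of-range svtbl indices yield 0
//   prod *= half_prod                      // lane 0 now holds a0*a1*a2*a3
//   pfirst(prod)                           // extract the result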
template <>
EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a)
{
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
                      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);

  // Multiply the vector by its reverse
  svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
  svint32_t half_prod;

  // Extract the high half of the vector. Depending on the VL more reductions need to be done
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_s32(prod, svindex_u32(32, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_s32(prod, svindex_u32(16, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_s32(prod, svindex_u32(8, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_s32(prod, svindex_u32(4, 1));
    prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_s32(prod, svindex_u32(2, 1));
  prod = svmul_s32_z(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXi>(prod);
}

template <>
EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a)
{
  return svminv_s32(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a)
{
  return svmaxv_s32(svptrue_b32(), a);
}

template <int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel)
{
  // Scatter each packet into a column of an N-strided buffer, then reload the rows.
  int buffer[packet_traits<numext::int32_t>::size * N] = {0};
  int i = 0;

  PacketXi stride_index = svindex_s32(0, N);

  for (i = 0; i < N; i++) {
    svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
  }
  for (i = 0; i < N; i++) {
    kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
  }
}

/********************************* float32 ************************************/

typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));

template <>
struct packet_traits<float> : default_packet_traits {
  typedef PacketXf type;
  typedef PacketXf half;

  enum {
    Vectorizable = 1,
    AlignedOnScalar = 1,
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    HasHalfPacket = 0,

    HasAdd = 1,
    HasSub = 1,
    HasShift = 1,
    HasMul = 1,
    HasNegate = 1,
    HasAbs = 1,
    HasArg = 0,
    HasAbs2 = 1,
    HasMin = 1,
    HasMax = 1,
    HasConj = 1,
    HasSetLinear = 0,
    HasBlend = 0,
    HasReduxp = 0,  // Not implemented in SVE

    HasDiv = 1,
    HasFloor = 1,

    HasSin = EIGEN_FAST_MATH,
    HasCos = EIGEN_FAST_MATH,
    HasLog = 1,
    HasExp = 1,
    HasSqrt = 0,
    HasTanh = EIGEN_FAST_MATH,
    HasErf = EIGEN_FAST_MATH
  };
};

template <>
struct unpacket_traits<PacketXf> {
  typedef float type;
  typedef PacketXf half;  // Half not yet implemented
  typedef PacketXi integer_packet;

  enum {
    size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
    alignment = Aligned64,
    vectorizable = true,
    masked_load_available = false,
    masked_store_available = false
  };
};

template <>
EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from)
{
  return svdup_n_f32(from);
}

template <>
EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
}

template <>
EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a)
{
  float c[packet_traits<float>::size];
  for (int i = 0; i < packet_traits<float>::size; i++) c[i] = static_cast<float>(i);
  return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
}

template <>
EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svadd_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svsub_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a)
{
  return svneg_f32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a)
{
  return a;
}

template <>
EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmul_f32_z(svptrue_b32(), a, b);
}
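
// Illustrative example (not part of the implementation):
//   pset1frombits<PacketXf>(0x3f800000u) fills every lane with 1.0f, since
//   0x3f800000 is the IEEE-754 bit pattern of 1.0f, while
//   plset<PacketXf>(1.5f) yields the linear sequence {1.5f, 2.5f, 3.5f, ...}.
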
template <>
EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svdiv_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c)
{
  // Fused multiply-add: a * b + c
  return svmla_f32_z(svptrue_b32(), c, a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmin_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  // svmin already propagates NaN, so the default pmin has the right semantics
  return pmin<PacketXf>(a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svminnm_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmax_f32_z(svptrue_b32(), a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return pmax<PacketXf>(a, b);
}

template <>
EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svmaxnm_f32_z(svptrue_b32(), a, b);
}

// Float comparisons in SVE return an svbool_t (predicate). Use svdup to set active
// lanes to all-ones (0xffffffffu) and inactive lanes to 0.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmple_f32(svptrue_b32(), a, b), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
}
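
// Illustrative example (not part of the implementation), assuming 4 float lanes:
//   pcmp_eq<PacketXf>({1.f, 2.f, 3.f, 4.f}, {1.f, 0.f, 3.f, 0.f})
// yields lanes {0xffffffff, 0x00000000, 0xffffffff, 0x00000000} reinterpreted as
// float, i.e. an all-ones bit mask exactly in the lanes where the comparison holds.
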
// Invert (svnot_b_z) the predicate resulting from the greater-than-or-equal
// comparison (svcmpge_f32), then set the active lanes to all-ones. This flags the
// lanes where a < b as well as those where either operand is NaN.
template <>
EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
}

template <>
EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a)
{
  // svrintm rounds toward minus infinity
  return svrintm_f32_z(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/)
{
  return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
}

// Logical operations are not supported for float, so use reinterpret casts
template <>
EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b)
{
  return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
}

template <>
EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from)
{
  EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from)
{
  EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
}

template <>
EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from)
{
  svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
  indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
  return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
}

template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from)
{
  EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}

template <>
EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from)
{
  EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
}

template <>
EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
}
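
// Illustrative example (not part of the implementation), assuming 4 float lanes:
// with stride == 2 the offsets are {0, 2, 4, 6}, so pgather above reads
// {from[0], from[2], from[4], from[6]} and pscatter below writes the four lanes
// back out to the same strided locations.
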
template <>
EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride)
{
  // Index format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
  svint32_t indices = svindex_s32(0, stride);
  svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
}

template <>
EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a)
{
  // svlasta returns the first element if all predicate bits are 0
  return svlasta_f32(svpfalse_b(), a);
}

template <>
EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a)
{
  return svrev_f32(a);
}

template <>
EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
{
  return svabs_f32_z(svptrue_b32(), a);
}

// TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
// all vector extensions and the generic version.
template <>
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
{
  return pfrexp_generic(a, exponent);
}

template <>
EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a)
{
  return svaddv_f32(svptrue_b32(), a);
}

// Other reduction functions:
// mul
// Only works for SVE VLs that are multiples of 128
template <>
EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a)
{
  EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
                      EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
  // Multiply the vector by its reverse
  svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
  svfloat32_t half_prod;

  // Extract the high half of the vector. Depending on the VL more reductions need to be done
  if (EIGEN_ARM64_SVE_VL >= 2048) {
    half_prod = svtbl_f32(prod, svindex_u32(32, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 1024) {
    half_prod = svtbl_f32(prod, svindex_u32(16, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 512) {
    half_prod = svtbl_f32(prod, svindex_u32(8, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  if (EIGEN_ARM64_SVE_VL >= 256) {
    half_prod = svtbl_f32(prod, svindex_u32(4, 1));
    prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
  }
  // Last reduction
  half_prod = svtbl_f32(prod, svindex_u32(2, 1));
  prod = svmul_f32_z(svptrue_b32(), prod, half_prod);

  // The reduction is done to the first element.
  return pfirst<PacketXf>(prod);
}

template <>
EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a)
{
  return svminv_f32(svptrue_b32(), a);
}

template <>
EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a)
{
  return svmaxv_f32(svptrue_b32(), a);
}

template <int N>
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
{
  // Scatter each packet into a column of an N-strided buffer, then reload the rows.
  float buffer[packet_traits<float>::size * N] = {0};
  int i = 0;

  PacketXi stride_index = svindex_s32(0, N);

  for (i = 0; i < N; i++) {
    svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
  }

  for (i = 0; i < N; i++) {
    kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
  }
}

template <>
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
{
  return pldexp_generic(a, exponent);
}

}  // namespace internal
}  // namespace Eigen

#endif  // EIGEN_PACKET_MATH_SVE_H