// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2020 Everton Constantino (everton.constantino@ibm.com)
// Copyright (C) 2021 Chip Kerchner (chip.kerchner@ibm.com)
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_MATRIX_PRODUCT_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_ALTIVEC_H

// Enable the custom packing routines of this file unless the user already chose a value.
#ifndef EIGEN_ALTIVEC_USE_CUSTOM_PACK
#define EIGEN_ALTIVEC_USE_CUSTOM_PACK 1
#endif

#include "MatrixProductCommon.h"

// Since LLVM doesn't support dynamic dispatching, force a single code path at compile time:
// MMA-only when __MMA__ is defined, otherwise MMA is disabled (VSX only). An explicit user
// choice of either macro is left untouched.
#if EIGEN_COMP_LLVM
#if !defined(EIGEN_ALTIVEC_DISABLE_MMA) && !defined(EIGEN_ALTIVEC_MMA_ONLY)
#ifdef __MMA__
#define EIGEN_ALTIVEC_MMA_ONLY
#else
#define EIGEN_ALTIVEC_DISABLE_MMA
#endif
#endif
#endif

// ALTIVEC_MMA_SUPPORT is defined only if the compiler reports the MMA accumulator builtin.
// The test is nested so compilers without __has_builtin never parse the __has_builtin(...) call.
#ifdef __has_builtin
#if __has_builtin(__builtin_mma_assemble_acc)
#define ALTIVEC_MMA_SUPPORT
#endif
#endif

// Pull in the MMA kernels only when the compiler supports them and they were not disabled above.
#if defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA)
#include "MatrixProductMMA.h"
#endif

/**************************************************************************************************
 * TODO                                                                                           *
 * - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could). *
 * - Check the possibility of transposing as GETREAL and GETIMAG when needed.                     *
 **************************************************************************************************/
namespace Eigen {

namespace internal {

/**************************
 * Constants and typedefs *
 **************************/

// Per-scalar traits read by the packing and GEMM code in this file.
//   vectortype : the native packet for Scalar (packet_traits<Scalar>::type).
//   type       : a block of 4 such packets.
//   rhstype    : the packet type used to hold rhs values (a single packet here).
//   vectorsize : number of Scalars held by one packet.
//   size, rows : block-shape constants; NOTE(review): their exact meaning is defined by the
//                kernels that read them (not visible in this section) -- confirm there.
template<typename Scalar>
struct quad_traits
{
  typedef typename packet_traits<Scalar>::type vectortype;
  typedef PacketBlock<vectortype,4> type;
  typedef vectortype rhstype;
  enum
  {
    vectorsize = packet_traits<Scalar>::size,
    size = 4,
    rows = 4
  };
};

// double specialization: the vector is Packet2d, rhs values are held as a pair of Packet2d,
// and `size` is 2 instead of 4.
template<>
struct quad_traits<double>
{
  typedef Packet2d vectortype;
  typedef PacketBlock<vectortype,4> type;
  typedef PacketBlock<Packet2d,2> rhstype;
  enum
  {
    vectorsize = packet_traits<double>::size,
    size = 2,
    rows = 4
  };
};

// MatrixProduct decomposes interleaved real/imaginary vectors into a real vector and an imaginary vector; this
// turned out to be faster than Eigen's usual approach of keeping real/imaginary pairs in a single vector. The
// permutation constants below convert between Eigen's interleaved layout and MatrixProduct's split layout.
84 85 const static Packet16uc p16uc_GETREAL32 = { 0, 1, 2, 3, 86 8, 9, 10, 11, 87 16, 17, 18, 19, 88 24, 25, 26, 27}; 89 90 const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7, 91 12, 13, 14, 15, 92 20, 21, 22, 23, 93 28, 29, 30, 31}; 94 const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7, 95 16, 17, 18, 19, 20, 21, 22, 23}; 96 97 //[a,ai],[b,bi] = [ai,bi] 98 const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15, 99 24, 25, 26, 27, 28, 29, 30, 31}; 100 101 /********************************************* 102 * Single precision real and complex packing * 103 * *******************************************/ 104 105 /** 106 * Symm packing is related to packing of symmetric adjoint blocks, as expected the packing leaves 107 * the diagonal real, whatever is below it is copied from the respective upper diagonal element and 108 * conjugated. There's no PanelMode available for symm packing. 109 * 110 * Packing in general is supposed to leave the lhs block and the rhs block easy to be read by gemm using 111 * its respective rank-update instructions. The float32/64 versions are different because at this moment 112 * the size of the accumulator is fixed at 512-bits so you can't have a 4x4 accumulator of 64-bit elements. 113 * 114 * As mentioned earlier MatrixProduct breaks complex numbers into a real vector and a complex vector so packing has 115 * to take that into account, at the moment, we run pack the real part and then the imaginary part, this is the main 116 * reason why packing for complex is broken down into several different parts, also the reason why we endup having a 117 * float32/64 and complex float32/64 version. 
118 **/ 119 template<typename Scalar, typename Index, int StorageOrder> 120 EIGEN_ALWAYS_INLINE std::complex<Scalar> getAdjointVal(Index i, Index j, const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder>& dt) 121 { 122 std::complex<Scalar> v; 123 if(i < j) 124 { 125 v.real( dt(j,i).real()); 126 v.imag(-dt(j,i).imag()); 127 } else if(i > j) 128 { 129 v.real( dt(i,j).real()); 130 v.imag( dt(i,j).imag()); 131 } else { 132 v.real( dt(i,j).real()); 133 v.imag((Scalar)0.0); 134 } 135 return v; 136 } 137 138 template<typename Scalar, typename Index, int StorageOrder, int N> 139 EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* blockB, const std::complex<Scalar>* _rhs, Index rhsStride, Index rows, Index cols, Index k2) 140 { 141 const Index depth = k2 + rows; 142 const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> rhs(_rhs, rhsStride); 143 const Index vectorSize = N*quad_traits<Scalar>::vectorsize; 144 const Index vectorDelta = vectorSize * rows; 145 Scalar* blockBf = reinterpret_cast<Scalar *>(blockB); 146 147 Index rir = 0, rii, j = 0; 148 for(; j + vectorSize <= cols; j+=vectorSize) 149 { 150 rii = rir + vectorDelta; 151 152 for(Index i = k2; i < depth; i++) 153 { 154 for(Index k = 0; k < vectorSize; k++) 155 { 156 std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j + k, rhs); 157 158 blockBf[rir + k] = v.real(); 159 blockBf[rii + k] = v.imag(); 160 } 161 rir += vectorSize; 162 rii += vectorSize; 163 } 164 165 rir += vectorDelta; 166 } 167 if (j < cols) 168 { 169 rii = rir + ((cols - j) * rows); 170 171 for(Index i = k2; i < depth; i++) 172 { 173 Index k = j; 174 for(; k < cols; k++) 175 { 176 std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs); 177 178 blockBf[rir] = v.real(); 179 blockBf[rii] = v.imag(); 180 181 rir += 1; 182 rii += 1; 183 } 184 } 185 } 186 } 187 188 template<typename Scalar, typename Index, int StorageOrder> 189 EIGEN_STRONG_INLINE void 
symm_pack_complex_lhs_helper(std::complex<Scalar>* blockA, const std::complex<Scalar>* _lhs, Index lhsStride, Index cols, Index rows) 190 { 191 const Index depth = cols; 192 const_blas_data_mapper<std::complex<Scalar>, Index, StorageOrder> lhs(_lhs, lhsStride); 193 const Index vectorSize = quad_traits<Scalar>::vectorsize; 194 const Index vectorDelta = vectorSize * depth; 195 Scalar* blockAf = (Scalar *)(blockA); 196 197 Index rir = 0, rii, j = 0; 198 for(; j + vectorSize <= rows; j+=vectorSize) 199 { 200 rii = rir + vectorDelta; 201 202 for(Index i = 0; i < depth; i++) 203 { 204 for(Index k = 0; k < vectorSize; k++) 205 { 206 std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(j+k, i, lhs); 207 208 blockAf[rir + k] = v.real(); 209 blockAf[rii + k] = v.imag(); 210 } 211 rir += vectorSize; 212 rii += vectorSize; 213 } 214 215 rir += vectorDelta; 216 } 217 218 if (j < rows) 219 { 220 rii = rir + ((rows - j) * depth); 221 222 for(Index i = 0; i < depth; i++) 223 { 224 Index k = j; 225 for(; k < rows; k++) 226 { 227 std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(k, i, lhs); 228 229 blockAf[rir] = v.real(); 230 blockAf[rii] = v.imag(); 231 232 rir += 1; 233 rii += 1; 234 } 235 } 236 } 237 } 238 239 template<typename Scalar, typename Index, int StorageOrder, int N> 240 EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2) 241 { 242 const Index depth = k2 + rows; 243 const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(_rhs, rhsStride); 244 const Index vectorSize = quad_traits<Scalar>::vectorsize; 245 246 Index ri = 0, j = 0; 247 for(; j + N*vectorSize <= cols; j+=N*vectorSize) 248 { 249 Index i = k2; 250 for(; i < depth; i++) 251 { 252 for(Index k = 0; k < N*vectorSize; k++) 253 { 254 if(i <= j+k) 255 blockB[ri + k] = rhs(j+k, i); 256 else 257 blockB[ri + k] = rhs(i, j+k); 258 } 259 ri += N*vectorSize; 260 } 261 } 262 263 if (j < cols) 264 { 265 for(Index 
i = k2; i < depth; i++) 266 { 267 Index k = j; 268 for(; k < cols; k++) 269 { 270 if(k <= i) 271 blockB[ri] = rhs(i, k); 272 else 273 blockB[ri] = rhs(k, i); 274 ri += 1; 275 } 276 } 277 } 278 } 279 280 template<typename Scalar, typename Index, int StorageOrder> 281 EIGEN_STRONG_INLINE void symm_pack_lhs_helper(Scalar* blockA, const Scalar* _lhs, Index lhsStride, Index cols, Index rows) 282 { 283 const Index depth = cols; 284 const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs, lhsStride); 285 const Index vectorSize = quad_traits<Scalar>::vectorsize; 286 287 Index ri = 0, j = 0; 288 for(; j + vectorSize <= rows; j+=vectorSize) 289 { 290 Index i = 0; 291 292 for(; i < depth; i++) 293 { 294 for(Index k = 0; k < vectorSize; k++) 295 { 296 if(i <= j+k) 297 blockA[ri + k] = lhs(j+k, i); 298 else 299 blockA[ri + k] = lhs(i, j+k); 300 } 301 ri += vectorSize; 302 } 303 } 304 305 if (j < rows) 306 { 307 for(Index i = 0; i < depth; i++) 308 { 309 Index k = j; 310 for(; k < rows; k++) 311 { 312 if(i <= k) 313 blockA[ri] = lhs(k, i); 314 else 315 blockA[ri] = lhs(i, k); 316 ri += 1; 317 } 318 } 319 } 320 } 321 322 template<typename Index, int nr, int StorageOrder> 323 struct symm_pack_rhs<std::complex<float>, Index, nr, StorageOrder> 324 { 325 void operator()(std::complex<float>* blockB, const std::complex<float>* _rhs, Index rhsStride, Index rows, Index cols, Index k2) 326 { 327 symm_pack_complex_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2); 328 } 329 }; 330 331 template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder> 332 struct symm_pack_lhs<std::complex<float>, Index, Pack1, Pack2_dummy, StorageOrder> 333 { 334 void operator()(std::complex<float>* blockA, const std::complex<float>* _lhs, Index lhsStride, Index cols, Index rows) 335 { 336 symm_pack_complex_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows); 337 } 338 }; 339 340 // *********** symm_pack std::complex<float64> *********** 341 
342 template<typename Index, int nr, int StorageOrder> 343 struct symm_pack_rhs<std::complex<double>, Index, nr, StorageOrder> 344 { 345 void operator()(std::complex<double>* blockB, const std::complex<double>* _rhs, Index rhsStride, Index rows, Index cols, Index k2) 346 { 347 symm_pack_complex_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2); 348 } 349 }; 350 351 template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder> 352 struct symm_pack_lhs<std::complex<double>, Index, Pack1, Pack2_dummy, StorageOrder> 353 { 354 void operator()(std::complex<double>* blockA, const std::complex<double>* _lhs, Index lhsStride, Index cols, Index rows) 355 { 356 symm_pack_complex_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows); 357 } 358 }; 359 360 // *********** symm_pack float32 *********** 361 template<typename Index, int nr, int StorageOrder> 362 struct symm_pack_rhs<float, Index, nr, StorageOrder> 363 { 364 void operator()(float* blockB, const float* _rhs, Index rhsStride, Index rows, Index cols, Index k2) 365 { 366 symm_pack_rhs_helper<float, Index, StorageOrder, 1>(blockB, _rhs, rhsStride, rows, cols, k2); 367 } 368 }; 369 370 template<typename Index, int Pack1, int Pack2_dummy, int StorageOrder> 371 struct symm_pack_lhs<float, Index, Pack1, Pack2_dummy, StorageOrder> 372 { 373 void operator()(float* blockA, const float* _lhs, Index lhsStride, Index cols, Index rows) 374 { 375 symm_pack_lhs_helper<float, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows); 376 } 377 }; 378 379 // *********** symm_pack float64 *********** 380 template<typename Index, int nr, int StorageOrder> 381 struct symm_pack_rhs<double, Index, nr, StorageOrder> 382 { 383 void operator()(double* blockB, const double* _rhs, Index rhsStride, Index rows, Index cols, Index k2) 384 { 385 symm_pack_rhs_helper<double, Index, StorageOrder, 2>(blockB, _rhs, rhsStride, rows, cols, k2); 386 } 387 }; 388 389 template<typename Index, 
int Pack1, int Pack2_dummy, int StorageOrder> 390 struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder> 391 { 392 void operator()(double* blockA, const double* _lhs, Index lhsStride, Index cols, Index rows) 393 { 394 symm_pack_lhs_helper<double, Index, StorageOrder>(blockA, _lhs, lhsStride, cols, rows); 395 } 396 }; 397 398 /** 399 * PanelMode 400 * Packing might be called several times before being multiplied by gebp_kernel, this happens because 401 * on special occasions it fills part of block with other parts of the matrix. Two variables control 402 * how PanelMode should behave: offset and stride. The idea is that those variables represent whatever 403 * is going to be the real offset and stride in the future and this is what you should obey. The process 404 * is to behave as you would with normal packing but leave the start of each part with the correct offset 405 * and the end as well respecting the real stride the block will have. Gebp is aware of both blocks stride 406 * and offset and behaves accordingly. 407 **/ 408 409 template<typename Scalar, typename Packet, typename Index> 410 EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,4>& block) 411 { 412 const Index size = 16 / sizeof(Scalar); 413 pstore<Scalar>(to + (0 * size), block.packet[0]); 414 pstore<Scalar>(to + (1 * size), block.packet[1]); 415 pstore<Scalar>(to + (2 * size), block.packet[2]); 416 pstore<Scalar>(to + (3 * size), block.packet[3]); 417 } 418 419 template<typename Scalar, typename Packet, typename Index> 420 EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,2>& block) 421 { 422 const Index size = 16 / sizeof(Scalar); 423 pstore<Scalar>(to + (0 * size), block.packet[0]); 424 pstore<Scalar>(to + (1 * size), block.packet[1]); 425 } 426 427 // General template for lhs & rhs complex packing. 
// Packs a complex lhs (UseLhs) or rhs (!UseLhs) operand into split real/imaginary form.
// For !UseLhs the `lhs` mapper is the rhs operand and `rows` counts its columns (it is indexed lhs(i, k)).
// Panels of vectorSize rows/columns are packed; within a panel the real parts are written from rir and
// the imaginary parts from rii = rir + vectorDelta. In PanelMode vectorDelta is based on `stride`
// instead of `depth`, and every half starts vectorSize*offset scalars into its slot.
// Conjugate negates the imaginary parts.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
struct dhs_cpack {
  EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    Scalar* blockAt = reinterpret_cast<Scalar *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      // Full vectorSize x vectorSize tiles along the depth.
      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> blockr, blocki;
        PacketBlock<PacketC,8> cblock;

        // NOTE(review): presumably loads the tile as 8 PacketC (two complex values each) -- confirm bload's layout.
        if (UseLhs) {
          bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, j, i);
        } else {
          bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, i, j);
        }

        // De-interleave: packet k of blockr/blocki combines the real/imaginary parts of cblock k and k+4.
        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
        blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
        blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
        blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
        blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
        blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
        blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
          blocki.packet[2] = -blocki.packet[2];
          blocki.packet[3] = -blocki.packet[3];
        }

        // Row-major lhs / column-major rhs tiles are transposed so that each stored packet holds the
        // panel's values for one depth index (the same layout the scalar tail below writes).
        if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
        {
          ptranspose(blockr);
          ptranspose(blocki);
        }

        storeBlock<Scalar, Packet, Index>(blockAt + rir, blockr);
        storeBlock<Scalar, Packet, Index>(blockAt + rii, blocki);

        rir += 4*vectorSize;
        rii += 4*vectorSize;
      }
      // Remaining depth indices of this panel, one packet per index.
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs)))
        {
          // The panel's values for this depth index are contiguous: two packet loads.
          if (UseLhs) {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 2, i);
          } else {
            cblock.packet[0] = lhs.template loadPacket<PacketC>(i, j + 0);
            cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
          }
        } else {
          // Strided source: gather the four complex values pairwise.
          std::complex<Scalar> lhs0, lhs1;
          if (UseLhs) {
            lhs0 = lhs(j + 0, i);
            lhs1 = lhs(j + 1, i);
            cblock.packet[0] = pload2(&lhs0, &lhs1);
            lhs0 = lhs(j + 2, i);
            lhs1 = lhs(j + 3, i);
            cblock.packet[1] = pload2(&lhs0, &lhs1);
          } else {
            lhs0 = lhs(i, j + 0);
            lhs1 = lhs(i, j + 1);
            cblock.packet[0] = pload2(&lhs0, &lhs1);
            lhs0 = lhs(i, j + 2);
            lhs1 = lhs(i, j + 3);
            cblock.packet[1] = pload2(&lhs0, &lhs1);
          }
        }

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL32);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG32);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<Scalar>(blockAt + rir, blockr.packet[0]);
        pstore<Scalar>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      // Advance to the next panel. In PanelMode this skips the real half's trailing padding, the whole
      // imaginary half and the next panel's leading offset: vectorSize*((stride-offset-depth) + stride + offset).
      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    // Leftover rows/columns (fewer than vectorSize), packed element by element.
    if (j < rows)
    {
      // rir already carries vectorSize*offset of leading padding; the narrower tail panel needs offset*(rows - j).
      if(PanelMode) rir += (offset*(rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          if (UseLhs) {
            blockAt[rir] = lhs(k, i).real();

            if(Conjugate)
              blockAt[rii] = -lhs(k, i).imag();
            else
              blockAt[rii] = lhs(k, i).imag();
          } else {
            blockAt[rir] = lhs(i, k).real();

            if(Conjugate)
              blockAt[rii] = -lhs(i, k).imag();
            else
              blockAt[rii] = lhs(i, k).imag();
          }

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

// General template for lhs & rhs packing.
// Real counterpart of dhs_cpack: panels of vectorSize rows (lhs) or columns (rhs, where `rows` counts
// rhs columns), one run of vectorSize values per depth index. PanelMode pads each panel with
// vectorSize*offset scalars before and vectorSize*(stride - offset - depth) after.
template<typename Scalar, typename Index, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack{
  EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<Scalar>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,4> block;

        if (UseLhs) {
          bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, j, i);
        } else {
          bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, i, j);
        }
        // Same transposition rule as dhs_cpack: bring the tile to one packet per depth index.
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          ptranspose(block);
        }

        storeBlock<Scalar, Packet, Index>(blockA + ri, block);

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
        {
          // NOTE(review): writes exactly 4 values, which matches vectorSize only for 4-lane packets;
          // double uses its own specializations below.
          if (UseLhs) {
            blockA[ri+0] = lhs(j+0, i);
            blockA[ri+1] = lhs(j+1, i);
            blockA[ri+2] = lhs(j+2, i);
            blockA[ri+3] = lhs(j+3, i);
          } else {
            blockA[ri+0] = lhs(i, j+0);
            blockA[ri+1] = lhs(i, j+1);
            blockA[ri+2] = lhs(i, j+2);
            blockA[ri+3] = lhs(i, j+3);
          }
        } else {
          Packet lhsV;
          if (UseLhs) {
            lhsV = lhs.template loadPacket<Packet>(j, i);
          } else {
            lhsV = lhs.template loadPacket<Packet>(i, j);
          }
          pstore<Scalar>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    // Leftover rows/columns, packed element by element.
    if (j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          if (UseLhs) {
            blockA[ri] = lhs(k, i);
          } else {
            blockA[ri] = lhs(i, k);
          }
          ri += 1;
        }
      }
    }
  }
};

// General template for lhs packing, float64 specialization.
// Panels of vectorSize (= Packet2d size) rows; full depth tiles are 2x2.
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += vectorSize*offset;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,2> block;
        if(StorageOrder == RowMajor)
        {
          // Load two row segments and transpose them into per-depth-index packets.
          block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);

          ptranspose(block);
        } else {
          block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
          block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
        }

        storeBlock<double, Packet2d, Index>(blockA + ri, block);

        ri += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == RowMajor)
        {
          blockA[ri+0] = lhs(j+0, i);
          blockA[ri+1] = lhs(j+1, i);
        } else {
          Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
          pstore<double>(blockA + ri, lhsV);
        }

        ri += vectorSize;
      }

      if(PanelMode) ri += vectorSize*(stride - offset - depth);
    }

    if (j < rows)
    {
      if(PanelMode) ri += offset*(rows - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockA[ri] = lhs(k, i);
          ri += 1;
        }
      }
    }
  }
};

// General template for rhs packing, float64 specialization.
// Panels of 2*vectorSize (four) columns: per depth index i the packed order is
// rhs(i,j), rhs(i,j+1), rhs(i,j+2), rhs(i,j+3).
template<typename Index, typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    Index ri = 0, j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      if(PanelMode) ri += offset*(2*vectorSize);

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet2d,4> block;
        if(StorageOrder == ColMajor)
        {
          // Column segments are transposed in 2x2 pairs, then interleaved so each depth index's four
          // column values are stored consecutively.
          PacketBlock<Packet2d,2> block1, block2;
          block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
          block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
          block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
          block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);

          ptranspose(block1);
          ptranspose(block2);

          pstore<double>(blockB + ri    , block1.packet[0]);
          pstore<double>(blockB + ri + 2, block2.packet[0]);
          pstore<double>(blockB + ri + 4, block1.packet[1]);
          pstore<double>(blockB + ri + 6, block2.packet[1]);
        } else {
          block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
          block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
          block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
          block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]

          storeBlock<double, Packet2d, Index>(blockB + ri, block);
        }

        ri += 4*vectorSize;
      }
      for(; i < depth; i++)
      {
        if(StorageOrder == ColMajor)
        {
          blockB[ri+0] = rhs(i, j+0);
          blockB[ri+1] = rhs(i, j+1);

          ri += vectorSize;

          blockB[ri+0] = rhs(i, j+2);
          blockB[ri+1] = rhs(i, j+3);
        } else {
          Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
          pstore<double>(blockB + ri, rhsV);

          ri += vectorSize;

          rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
          pstore<double>(blockB + ri, rhsV);
        }
        ri += vectorSize;
      }

      if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
    }

    if (j < cols)
    {
      if(PanelMode) ri += offset*(cols - j);

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < cols; k++)
        {
          blockB[ri] = rhs(i, k);
          ri += 1;
        }
      }
    }
  }
};

// General template for lhs complex packing, float64 specialization.
// Panels of vectorSize (two) rows; each PacketC load brings one interleaved complex<double>.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (vectorSize*offset) : 0), rii;
    double* blockAt = reinterpret_cast<double *>(blockA);
    Index j = 0;

    for(; j + vectorSize <= rows; j+=vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i + vectorSize <= depth; i+=vectorSize)
      {
        PacketBlock<Packet,2> blockr, blocki;
        PacketBlock<PacketC,4> cblock;

        // Both storage orders end with blockr/blocki packet k holding rows j, j+1 at depth i+k;
        // only the load order differs.
        if(StorageOrder == ColMajor)
        {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
        } else {
          cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
          cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]

          cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
          cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]

          blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
          blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]

          blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
          blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
        }

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index>(blockAt + rir, blockr);
        storeBlock<double, Packet, Index>(blockAt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }
      for(; i < depth; i++)
      {
        PacketBlock<Packet,1> blockr, blocki;
        PacketBlock<PacketC,2> cblock;

        cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
        cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
        }

        pstore<double>(blockAt + rir, blockr.packet[0]);
        pstore<double>(blockAt + rii, blocki.packet[0]);

        rir += vectorSize;
        rii += vectorSize;
      }

      // Same panel-advance arithmetic as the generic dhs_cpack.
      rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (j < rows)
    {
      if(PanelMode) rir += (offset*(rows - j - vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (rows - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < rows; k++)
        {
          blockAt[rir] = lhs(k, i).real();

          if(Conjugate)
            blockAt[rii] = -lhs(k, i).imag();
          else
            blockAt[rii] = lhs(k, i).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

// General template for rhs complex packing, float64 specialization.
// Panels of 2*vectorSize (four) columns, one depth index per iteration.
template<typename Index, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
{
  EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
  {
    const Index vectorSize = quad_traits<double>::vectorsize;
    const Index vectorDelta = 2*vectorSize * ((PanelMode) ? stride : depth);
    Index rir = ((PanelMode) ? (2*vectorSize*offset) : 0), rii;
    double* blockBt = reinterpret_cast<double *>(blockB);
    Index j = 0;

    for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
    {
      Index i = 0;

      rii = rir + vectorDelta;

      for(; i < depth; i++)
      {
        PacketBlock<PacketC,4> cblock;
        PacketBlock<Packet,2> blockr, blocki;

        // NOTE(review): ColMajor is passed regardless of StorageOrder; presumably this loads
        // rhs(i, j..j+3) as four PacketC -- confirm against bload.
        bload<DataMapper, PacketC, Index, 2, 0, ColMajor>(cblock, rhs, i, j);

        blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
        blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);

        blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
        blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);

        if(Conjugate)
        {
          blocki.packet[0] = -blocki.packet[0];
          blocki.packet[1] = -blocki.packet[1];
        }

        storeBlock<double, Packet, Index>(blockBt + rir, blockr);
        storeBlock<double, Packet, Index>(blockBt + rii, blocki);

        rir += 2*vectorSize;
        rii += 2*vectorSize;
      }

      // Advance to the next panel (panel width here is 2*vectorSize).
      rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
    }

    if (j < cols)
    {
      if(PanelMode) rir += (offset*(cols - j - 2*vectorSize));
      rii = rir + (((PanelMode) ? stride : depth) * (cols - j));

      for(Index i = 0; i < depth; i++)
      {
        Index k = j;
        for(; k < cols; k++)
        {
          blockBt[rir] = rhs(i, k).real();

          if(Conjugate)
            blockBt[rii] = -rhs(i, k).imag();
          else
            blockBt[rii] = rhs(i, k).imag();

          rir += 1;
          rii += 1;
        }
      }
    }
  }
};

/**************
 * GEMM utils *
 **************/

// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex gemm).
998 template<typename Packet, bool NegativeAccumulate> 999 EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,4>* acc, const Packet& lhsV, const Packet* rhsV) 1000 { 1001 if(NegativeAccumulate) 1002 { 1003 acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); 1004 acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); 1005 acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); 1006 acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); 1007 } else { 1008 acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); 1009 acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); 1010 acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); 1011 acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); 1012 } 1013 } 1014 1015 template<typename Packet, bool NegativeAccumulate> 1016 EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,1>* acc, const Packet& lhsV, const Packet* rhsV) 1017 { 1018 if(NegativeAccumulate) 1019 { 1020 acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); 1021 } else { 1022 acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); 1023 } 1024 } 1025 1026 template<int N, typename Scalar, typename Packet, bool NegativeAccumulate> 1027 EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV) 1028 { 1029 Packet lhsV = pload<Packet>(lhs); 1030 1031 pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV); 1032 } 1033 1034 template<typename Scalar, typename Packet, typename Index> 1035 EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows) 1036 { 1037 #ifdef _ARCH_PWR9 1038 lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); 1039 #else 1040 Index i = 0; 1041 do { 1042 lhsV[i] = lhs[i]; 1043 } while (++i < remaining_rows); 1044 #endif 1045 } 1046 1047 template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate> 1048 EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, 
const Packet* rhsV, Index remaining_rows) 1049 { 1050 Packet lhsV; 1051 loadPacketRemaining<Scalar, Packet, Index>(lhs, lhsV, remaining_rows); 1052 1053 pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV); 1054 } 1055 1056 // 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real. 1057 template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> 1058 EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi) 1059 { 1060 pger_common<Packet, false>(accReal, lhsV, rhsV); 1061 if(LhsIsReal) 1062 { 1063 pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi); 1064 EIGEN_UNUSED_VARIABLE(lhsVi); 1065 } else { 1066 if (!RhsIsReal) { 1067 pger_common<Packet, ConjugateLhs == ConjugateRhs>(accReal, lhsVi, rhsVi); 1068 pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi); 1069 } else { 1070 EIGEN_UNUSED_VARIABLE(rhsVi); 1071 } 1072 pger_common<Packet, ConjugateLhs>(accImag, lhsVi, rhsV); 1073 } 1074 } 1075 1076 template<int N, typename Scalar, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> 1077 EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) 1078 { 1079 Packet lhsV = ploadLhs<Scalar, Packet>(lhs_ptr); 1080 Packet lhsVi; 1081 if(!LhsIsReal) lhsVi = ploadLhs<Scalar, Packet>(lhs_ptr_imag); 1082 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); 1083 1084 pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); 1085 } 1086 1087 template<typename Scalar, typename Packet, typename Index, bool LhsIsReal> 1088 EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* 
lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows) 1089 { 1090 #ifdef _ARCH_PWR9 1091 lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar)); 1092 if(!LhsIsReal) lhsVi = vec_xl_len((Scalar *)lhs_ptr_imag, remaining_rows * sizeof(Scalar)); 1093 else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); 1094 #else 1095 Index i = 0; 1096 do { 1097 lhsV[i] = lhs_ptr[i]; 1098 if(!LhsIsReal) lhsVi[i] = lhs_ptr_imag[i]; 1099 } while (++i < remaining_rows); 1100 if(LhsIsReal) EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); 1101 #endif 1102 } 1103 1104 template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> 1105 EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows) 1106 { 1107 Packet lhsV, lhsVi; 1108 loadPacketRemaining<Scalar, Packet, Index, LhsIsReal>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows); 1109 1110 pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); 1111 } 1112 1113 template<typename Scalar, typename Packet> 1114 EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs) 1115 { 1116 return ploadu<Packet>(lhs); 1117 } 1118 1119 // Zero the accumulator on PacketBlock. 1120 template<typename Scalar, typename Packet> 1121 EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,4>& acc) 1122 { 1123 acc.packet[0] = pset1<Packet>((Scalar)0); 1124 acc.packet[1] = pset1<Packet>((Scalar)0); 1125 acc.packet[2] = pset1<Packet>((Scalar)0); 1126 acc.packet[3] = pset1<Packet>((Scalar)0); 1127 } 1128 1129 template<typename Scalar, typename Packet> 1130 EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,1>& acc) 1131 { 1132 acc.packet[0] = pset1<Packet>((Scalar)0); 1133 } 1134 1135 // Scale the PacketBlock vectors by alpha. 
1136 template<typename Packet> 1137 EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha) 1138 { 1139 acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); 1140 acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]); 1141 acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]); 1142 acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]); 1143 } 1144 1145 template<typename Packet> 1146 EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha) 1147 { 1148 acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); 1149 } 1150 1151 template<typename Packet> 1152 EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha) 1153 { 1154 acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha); 1155 acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha); 1156 acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha); 1157 acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha); 1158 } 1159 1160 template<typename Packet> 1161 EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha) 1162 { 1163 acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha); 1164 } 1165 1166 // Complex version of PacketBlock scaling. 
1167 template<typename Packet, int N> 1168 EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag) 1169 { 1170 bscalec_common<Packet>(cReal, aReal, bReal); 1171 1172 bscalec_common<Packet>(cImag, aImag, bReal); 1173 1174 pger_common<Packet, true>(&cReal, bImag, aImag.packet); 1175 1176 pger_common<Packet, false>(&cImag, bImag, aReal.packet); 1177 } 1178 1179 template<typename Packet> 1180 EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,4>& acc, const Packet& pMask) 1181 { 1182 acc.packet[0] = pand(acc.packet[0], pMask); 1183 acc.packet[1] = pand(acc.packet[1], pMask); 1184 acc.packet[2] = pand(acc.packet[2], pMask); 1185 acc.packet[3] = pand(acc.packet[3], pMask); 1186 } 1187 1188 template<typename Packet> 1189 EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag, const Packet& pMask) 1190 { 1191 band<Packet>(aReal, pMask); 1192 band<Packet>(aImag, pMask); 1193 1194 bscalec<Packet,4>(aReal, aImag, bReal, bImag, cReal, cImag); 1195 } 1196 1197 // Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed. 
1198 template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder> 1199 EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col) 1200 { 1201 if (StorageOrder == RowMajor) { 1202 acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols); 1203 acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols); 1204 acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols); 1205 acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols); 1206 } else { 1207 acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0); 1208 acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1); 1209 acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2); 1210 acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3); 1211 } 1212 } 1213 1214 // An overload of bload when you have a PacketBLock with 8 vectors. 
1215 template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder> 1216 EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col) 1217 { 1218 if (StorageOrder == RowMajor) { 1219 acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols); 1220 acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols); 1221 acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols); 1222 acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols); 1223 acc.packet[4] = res.template loadPacket<Packet>(row + 0, col + (N+1)*accCols); 1224 acc.packet[5] = res.template loadPacket<Packet>(row + 1, col + (N+1)*accCols); 1225 acc.packet[6] = res.template loadPacket<Packet>(row + 2, col + (N+1)*accCols); 1226 acc.packet[7] = res.template loadPacket<Packet>(row + 3, col + (N+1)*accCols); 1227 } else { 1228 acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0); 1229 acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1); 1230 acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2); 1231 acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3); 1232 acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0); 1233 acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1); 1234 acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2); 1235 acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3); 1236 } 1237 } 1238 1239 template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder> 1240 EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,2>& acc, const DataMapper& res, Index row, Index col) 1241 { 1242 acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0); 1243 acc.packet[1] = res.template loadPacket<Packet>(row + (N+1)*accCols, col 
+ 0); 1244 } 1245 1246 const static Packet4i mask41 = { -1, 0, 0, 0 }; 1247 const static Packet4i mask42 = { -1, -1, 0, 0 }; 1248 const static Packet4i mask43 = { -1, -1, -1, 0 }; 1249 1250 const static Packet2l mask21 = { -1, 0 }; 1251 1252 template<typename Packet> 1253 EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows) 1254 { 1255 if (remaining_rows == 0) { 1256 return pset1<Packet>(float(0.0)); // Not used 1257 } else { 1258 switch (remaining_rows) { 1259 case 1: return Packet(mask41); 1260 case 2: return Packet(mask42); 1261 default: return Packet(mask43); 1262 } 1263 } 1264 } 1265 1266 template<> 1267 EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows) 1268 { 1269 if (remaining_rows == 0) { 1270 return pset1<Packet2d>(double(0.0)); // Not used 1271 } else { 1272 return Packet2d(mask21); 1273 } 1274 } 1275 1276 template<typename Packet> 1277 EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask) 1278 { 1279 band<Packet>(accZ, pMask); 1280 1281 bscale<Packet>(acc, accZ, pAlpha); 1282 } 1283 1284 template<typename Packet> 1285 EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3) 1286 { 1287 pbroadcast4<Packet>(a, a0, a1, a2, a3); 1288 } 1289 1290 template<> 1291 EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) 1292 { 1293 a1 = pload<Packet2d>(a); 1294 a3 = pload<Packet2d>(a + 2); 1295 a0 = vec_splat(a1, 0); 1296 a1 = vec_splat(a1, 1); 1297 a2 = vec_splat(a3, 0); 1298 a3 = vec_splat(a3, 1); 1299 } 1300 1301 // PEEL loop factor. 
// Number of depth steps handled per pass of the peeled (prefetching) loops below.
#define PEEL 7

// One depth step of the leftover-column path: broadcasts rhs_ptr[0], applies a
// full-packet rank-1 update from lhs_ptr to accZero, then advances lhs_ptr by
// remaining_rows and rhs_ptr by remaining_cols (both pointers are passed by reference).
// NOTE(review): pger loads a full packet even though only remaining_rows LHS values
// belong to this step; the extra lanes are never written back by gemm_extra_col.
template<typename Scalar, typename Packet, typename Index>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL(
  const Scalar* &lhs_ptr,
  const Scalar* &rhs_ptr,
  PacketBlock<Packet,1> &accZero,
  Index remaining_rows,
  Index remaining_cols)
{
  Packet rhsV[1];
  rhsV[0] = pset1<Packet>(rhs_ptr[0]);
  pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
  rhs_ptr += remaining_cols;
}

// Computes the leftover rows [row, row + remaining_rows) of a single column col:
// accumulates over depth, scales by pAlpha and adds the result into res element by element.
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
EIGEN_STRONG_INLINE void gemm_extra_col(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index remaining_rows,
  Index remaining_cols,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
  PacketBlock<Packet,1> accZero;

  bsetzero<Scalar, Packet>(accZero);

  // Rounds depth down to a multiple of accRows (valid when accRows is a power of two).
  // Steps below this bound use full-packet loads; the tail uses length-limited loads.
  // NOTE(review): presumably this keeps the full loads from reading past the end of
  // the packed LHS panel — confirm.
  Index remaining_depth = (depth & -accRows);
  Index k = 0;
  for(; k + PEEL <= remaining_depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    EIGEN_POWER_PREFETCH(lhs_ptr);
    for (int l = 0; l < PEEL; l++) {
      MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
    }
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
  }
  // Tail: load only remaining_rows LHS values per step.
  for(; k < depth; k++)
  {
    Packet rhsV[1];
    rhsV[0] = pset1<Packet>(rhs_ptr[0]);
    pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
    lhs_ptr += remaining_rows;
    rhs_ptr += remaining_cols;
  }

  // res(row + i, col) += alpha * acc[i] for the valid lanes only.
  accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
  for(Index i = 0; i < remaining_rows; i++) {
    res(row + i, col) += accZero.packet[0][i];
  }
}

// One depth step of the leftover-row path: broadcasts four RHS values, applies a
// full-packet rank-1 update to the four accumulators, then advances lhs_ptr by
// remaining_rows and rhs_ptr by accRows.
template<typename Scalar, typename Packet, typename Index, const Index accRows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
  const Scalar* &lhs_ptr,
  const Scalar* &rhs_ptr,
  PacketBlock<Packet,4> &accZero,
  Index remaining_rows)
{
  Packet rhsV[4];
  pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
  pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
  lhs_ptr += remaining_rows;
  rhs_ptr += accRows;
}

// Computes the leftover rows [row, row + remaining_rows) for the four columns starting at col.
// Write-back strategy:
//  - If every depth step used full loads and rows >= accCols, full packets of res are
//    loaded and updated with the pMask-masked accumulator, then stored back.
//  - Otherwise the depth tail uses length-limited loads, and the results are added to
//    res element by element.
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index row,
  Index col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlpha,
  const Packet& pMask)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
  PacketBlock<Packet,4> accZero, acc;

  bsetzero<Scalar, Packet>(accZero);

  // If more columns follow this panel, all depth steps use full-packet loads;
  // otherwise the last (depth mod accRows) steps fall back to length-limited loads.
  Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
  Index k = 0;
  for(; k + PEEL <= remaining_depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    EIGEN_POWER_PREFETCH(lhs_ptr);
    for (int l = 0; l < PEEL; l++) {
      MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
    }
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
  }

  if ((remaining_depth == depth) && (rows >= accCols))
  {
    // Full-packet write-back; lanes outside pMask are zeroed in accZero before scaling.
    for(Index j = 0; j < 4; j++) {
      acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
    }
    bscale<Packet>(acc, accZero, pAlpha, pMask);
    res.template storePacketBlock<Packet,4>(row, col, acc);
  } else {
    for(; k < depth; k++)
    {
      Packet rhsV[4];
      pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
      pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
      lhs_ptr += remaining_rows;
      rhs_ptr += accRows;
    }

    // Element-wise write-back of the valid rows only.
    for(Index j = 0; j < 4; j++) {
      accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
    }
    for(Index j = 0; j < 4; j++) {
      for(Index i = 0; i < remaining_rows; i++) {
        res(row + i, col + j) += accZero.packet[j][i];
      }
    }
  }
}

// Expands func(0) .. func(7). Every *_ONE macro below is guarded by
// `unroll_factor > iter`, so only the first unroll_factor slots do work.
#define MICRO_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

// One depth step `peel`: first load every active LHS packet (func2), then apply
// the rank-1 updates (func).
#define MICRO_UNROLL_WORK(func, func2, peel) \
    MICRO_UNROLL(func2); \
    func(0,peel) func(1,peel) func(2,peel) func(3,peel) \
    func(4,peel) func(5,peel) func(6,peel) func(7,peel)

// Load the next LHS packet of row block `iter` and advance its pointer by accCols.
#define MICRO_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr##iter); \
    lhs_ptr##iter += accCols; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
  }

// Rank-1 update of accumulator `iter` with the RHS broadcasts of step `peel`.
#define MICRO_WORK_ONE(iter, peel) \
  if (unroll_factor > iter) { \
    pger_common<Packet, false>(&accZero##iter, lhsV##iter, rhsV##peel); \
  }

// Depth step `peel` for a 4-column panel: broadcast the four RHS values at
// rhs_ptr + accRows*peel. Steps with peel >= PEEL are disabled.
#define MICRO_TYPE_PEEL4(func, func2, peel) \
  if (PEEL > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    pbroadcast4<Packet>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    MICRO_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

// Depth step `peel` for a single leftover column: broadcast rhs_ptr[remaining_cols*peel].
#define MICRO_TYPE_PEEL1(func, func2, peel) \
  if (PEEL > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
    rhsV##peel[0] = pset1<Packet>(rhs_ptr[remaining_cols * peel]); \
    MICRO_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
  }

// Declares ten RHS broadcast arrays and emits steps 0..9; steps at or beyond PEEL
// are disabled by the `PEEL > peel` guard inside `func`.
#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
  func(func1,func2,0); func(func1,func2,1); \
  func(func1,func2,2); func(func1,func2,3); \
  func(func1,func2,4); func(func1,func2,5); \
  func(func1,func2,6); func(func1,func2,7); \
  func(func1,func2,8); func(func1,func2,9);

// Single (non-peeled) depth step.
#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M]; \
  func(func1,func2,0);

// PEEL steps of a 4-column panel, then advance rhs_ptr past them.
#define MICRO_ONE_PEEL4 \
  MICRO_UNROLL_TYPE_PEEL(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += (accRows * PEEL);

// One step of a 4-column panel.
#define MICRO_ONE4 \
  MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += accRows;

// PEEL steps of a single leftover column.
#define MICRO_ONE_PEEL1 \
  MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += (remaining_cols * PEEL);

// One step of a single leftover column.
#define MICRO_ONE1 \
  MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
  rhs_ptr += remaining_cols;

// Zero each active accumulator.
#define MICRO_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Scalar, Packet>(accZero##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accZero##iter); \
  }

#define MICRO_DST_PTR MICRO_UNROLL(MICRO_DST_PTR_ONE)

// Point lhs_ptr##iter at the packed LHS panel of row block (row/accCols) + iter.
#define MICRO_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
  }

#define MICRO_SRC_PTR MICRO_UNROLL(MICRO_SRC_PTR_ONE)

#define MICRO_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr##iter); \
  }

#define MICRO_PREFETCH MICRO_UNROLL(MICRO_PREFETCH_ONE)

// res tile (row block `iter`, columns col..col+3) += alpha * accZero##iter.
#define MICRO_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
    acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
    acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
    acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
    bscale<Packet>(acc, accZero##iter, pAlpha); \
    res.template storePacketBlock<Packet,4>(row + iter*accCols, col, acc); \
  }

#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)

// Single-column variant of MICRO_STORE_ONE.
#define MICRO_COL_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
    bscale<Packet>(acc, accZero##iter, pAlpha); \
    res.template storePacketBlock<Packet,1>(row + iter*accCols, col, acc); \
  }

#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)

// Computes unroll_factor vertically stacked tiles of accCols rows x 4 columns of res,
// starting at (row, col). On return, row has been advanced by unroll_factor*accCols.
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index col,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
  PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet,4> acc;

  MICRO_SRC_PTR
  MICRO_DST_PTR

  Index k = 0;
  for(; k + PEEL <= depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_PREFETCH
    MICRO_ONE_PEEL4
  }
  for(; k < depth; k++)
  {
    MICRO_ONE4
  }
  MICRO_STORE

  row += unroll_factor*accCols;
}

// Single-column counterpart of gemm_unrolled_iteration, used for the leftover columns.
// On return, row has been advanced by unroll_factor*accCols.
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index col,
  Index remaining_cols,
  const Packet& pAlpha)
{
  const Scalar* rhs_ptr = rhs_base;
  const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
  PacketBlock<Packet,1> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
  PacketBlock<Packet,1> acc;

  MICRO_SRC_PTR
  MICRO_DST_PTR

  Index k = 0;
  for(; k + PEEL <= depth; k+= PEEL)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr);
    MICRO_PREFETCH
    MICRO_ONE_PEEL1
  }
  for(; k < depth; k++)
  {
    MICRO_ONE1
  }
  MICRO_COL_STORE

  row += unroll_factor*accCols;
}

// Processes all whole accCols-row blocks of column col: first in chunks of MAX_UNROLL
// blocks, then the remaining whole blocks in one dispatch. Rows that do not fill a
// whole block are left for the caller; row ends just past the last processed block.
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_col(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index& row,
  Index rows,
  Index col,
  Index remaining_cols,
  const Packet& pAlpha)
{
#define MAX_UNROLL 6
  while(row + MAX_UNROLL*accCols <= rows) {
    gemm_unrolled_col_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
  }
  switch( (rows-row)/accCols ) {
#if MAX_UNROLL > 7
    case 7:
      gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 6
    case 6:
      gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 5
    case 5:
      gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 4
    case 4:
      gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 3
    case 3:
      gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 2
    case 2:
      gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
#if MAX_UNROLL > 1
    case 1:
      gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
      break;
#endif
    default:
      break;
  }
#undef MAX_UNROLL
}

/****************
 * GEMM kernels *
 * **************/
// Real GEMM driver: res += alpha * (packed blockA) * (packed blockB).
// Columns are processed in panels of accRows and rows in blocks of accCols.
// Leftover rows (rows % accCols) go through gemm_extra_row, and leftover columns
// (cols % accRows) through gemm_unrolled_col / gemm_extra_col.
// strideA / strideB == -1 means "use depth".
template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
      const Index remaining_rows = rows % accCols;
      const Index remaining_cols = cols % accRows;

      if( strideA == -1 ) strideA = depth;
      if( strideB == -1 ) strideB = depth;

      const Packet pAlpha = pset1<Packet>(alpha);
      const Packet pMask  = bmask<Packet>((const int)(remaining_rows));

      Index col = 0;
      for(; col + accRows <= cols; col += accRows)
      {
        const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
        const Scalar* lhs_base = blockA;
        Index row = 0;

#define MAX_UNROLL 6
        while(row + MAX_UNROLL*accCols <= rows) {
          gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
        }
        switch( (rows-row)/accCols ) {
#if MAX_UNROLL > 7
          case 7:
            gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_UNROLL > 6
          case 6:
            gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_UNROLL > 5
          case 5:
            gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_UNROLL > 4
          case 4:
            gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_UNROLL > 3
          case 3:
            gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_UNROLL > 2
          case 2:
            gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
#if MAX_UNROLL > 1
          case 1:
            gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
            break;
#endif
          default:
            break;
        }
#undef MAX_UNROLL

        // Rows below the last whole accCols block in this panel.
        if(remaining_rows > 0)
        {
          gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
        }
    }

    // Columns that do not fill a whole accRows panel, handled one column at a time.
    if(remaining_cols > 0)
    {
      const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
      const Scalar* lhs_base = blockA;

      for(; col < cols; col++)
      {
        Index row = 0;

        gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);

        if (remaining_rows > 0)
        {
          gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
        }
        rhs_base++;
      }
    }
}

// Half of accCols; used below as the split point when writing coupled complex results.
// NOTE(review): presumably the number of complex values per Packetc — confirm.
#define accColsC (accCols / 2)
// Panel stride multipliers: 1 for a real operand, 2 for a complex operand, whose packed
// panel holds a real plane followed by an imaginary plane (see lhs_ptr_imag below).
#define advanceRows ((LhsIsReal) ? 1 : 2)
#define advanceCols ((RhsIsReal) ? 1 : 2)

// Number of depth steps handled per pass of the peeled loops in the complex kernels.
#define PEEL_COMPLEX 3

// One depth step of the complex leftover-column path: broadcasts the RHS real (and,
// if complex, imaginary) value, applies a full-packet complex rank-1 update, then
// advances each active pointer by remaining_rows (LHS) or remaining_cols (RHS).
template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_COL(
  const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
  const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
  PacketBlock<Packet,1> &accReal, PacketBlock<Packet,1> &accImag,
  Index remaining_rows,
  Index remaining_cols)
{
  Packet rhsV[1], rhsVi[1];
  rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
  if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
  pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
  lhs_ptr_real += remaining_rows;
  if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
  rhs_ptr_real += remaining_cols;
  if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
}

// Complex counterpart of gemm_extra_col: computes the leftover rows of one column
// with split real/imaginary accumulators, multiplies by alpha = pAlphaReal + i*pAlphaImag,
// recombines into coupled complex packets (bcouple_common) and adds the result into res.
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_extra_col(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index col,
  Index remaining_rows,
  Index remaining_cols,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  // The imaginary plane of a complex operand follows its real plane.
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag;
  if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB;
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
  const Scalar* lhs_ptr_imag;
  if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
  PacketBlock<Packet,1> accReal, accImag;
  PacketBlock<Packet,1> taccReal, taccImag;
  PacketBlock<Packetc,1> acc0, acc1;

  bsetzero<Scalar, Packet>(accReal);
  bsetzero<Scalar, Packet>(accImag);

  // Full-packet loads up to depth rounded down to a multiple of accRows; length-limited tail after.
  Index remaining_depth = (depth & -accRows);
  Index k = 0;
  for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    EIGEN_POWER_PREFETCH(lhs_ptr_real);
    if(!LhsIsReal) {
      EIGEN_POWER_PREFETCH(lhs_ptr_imag);
    }
    for (int l = 0; l < PEEL_COMPLEX; l++) {
      MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
    }
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
  }

  for(; k < depth; k++)
  {
    Packet rhsV[1], rhsVi[1];
    rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
    if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
    pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
    lhs_ptr_real += remaining_rows;
    if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
    rhs_ptr_real += remaining_cols;
    if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
  }

  // Multiply by alpha on split parts, then interleave into coupled complex packets.
  bscalec<Packet,1>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
  bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);

  // Write-back: a single float row is added as a scalar. Otherwise one Packetc is
  // read-modify-written at (row, col), and one more element is added from acc1 when
  // remaining_rows > accColsC.
  if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
  {
    res(row + 0, col + 0) += pfirst<Packetc>(acc0.packet[0]);
  } else {
    acc0.packet[0] += res.template loadPacket<Packetc>(row + 0, col + 0);
    res.template storePacketBlock<Packetc,1>(row + 0, col + 0, acc0);
    if(remaining_rows > accColsC) {
      res(row + accColsC, col + 0) += pfirst<Packetc>(acc1.packet[0]);
    }
  }
}

// One depth step of the complex leftover-row path: broadcasts four RHS real (and, if
// complex, imaginary) values with pbroadcast4_old, applies a full-packet complex
// rank-1 update, then advances the LHS pointers by remaining_rows and the RHS by accRows.
template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
  const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
  const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
  PacketBlock<Packet,4> &accReal, PacketBlock<Packet,4> &accImag,
  Index remaining_rows)
{
  Packet rhsV[4], rhsVi[4];
  pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
  if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
  pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
  lhs_ptr_real += remaining_rows;
  if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
  rhs_ptr_real += accRows;
  if(!RhsIsReal) rhs_ptr_imag += accRows;
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
}

// Presumably the complex counterpart of gemm_extra_row (leftover rows of a
// 4-column panel) — confirm against the full definition.
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_extra_row(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index row,
  Index
col,
  Index rows,
  Index cols,
  Index remaining_rows,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag;
  if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
  else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
  const Scalar* lhs_ptr_imag;
  if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
  else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
  PacketBlock<Packet,4> accReal, accImag;
  PacketBlock<Packet,4> taccReal, taccImag;
  PacketBlock<Packetc,4> acc0, acc1;
  PacketBlock<Packetc,8> tRes;

  bsetzero<Scalar, Packet>(accReal);
  bsetzero<Scalar, Packet>(accImag);

  // When another column block follows (col + accRows < cols), the whole depth
  // runs through MICRO_COMPLEX_EXTRA_ROW. For the last column block, depth is
  // rounded down with `depth & -accRows`. The remaining steps are then done
  // below with the row-count-aware pgerc overload.
  Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
  Index k = 0;
  for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
    if(!RhsIsReal) {
      EIGEN_POWER_PREFETCH(rhs_ptr_imag);
    }
    EIGEN_POWER_PREFETCH(lhs_ptr_real);
    if(!LhsIsReal) {
      EIGEN_POWER_PREFETCH(lhs_ptr_imag);
    }
    for (int l = 0; l < PEEL_COMPLEX; l++) {
      MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
    }
  }
  for(; k < remaining_depth; k++)
  {
    MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
  }

  // Masked fast path, taken when no depth steps are left and rows >= accCols:
  //   * load a full tile from res into tRes;
  //   * scale with pMask applied;
  //   * combine with tRes and store two Packetc blocks of four columns, at row
  //     and at row + accColsC.
  // NOTE(review): this stores more rows than remaining_rows. It presumably
  // relies on the masked lanes writing back the values just loaded from res,
  // and on that memory being addressable. Confirm against the DataMapper layout.
  if ((remaining_depth == depth) && (rows >= accCols))
  {
    bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row, col);
    bscalec<Packet>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
    bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1);
    res.template storePacketBlock<Packetc,4>(row + 0, col, acc0);
    res.template storePacketBlock<Packetc,4>(row + accColsC, col, acc1);
  } else {
    // Remaining depth steps. The pgerc overload used here is given remaining_rows.
    for(; k < depth; k++)
    {
      Packet rhsV[4], rhsVi[4];
      pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
      if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
      pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
      lhs_ptr_real += remaining_rows;
      if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
      rhs_ptr_real += accRows;
      if(!RhsIsReal) rhs_ptr_imag += accRows;
    }

    // Scale by alpha (no mask) and interleave into complex packets.
    bscalec<Packet,4>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
    bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);

    // Write back only the remaining rows, one column at a time.
    if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
    {
      for(Index j = 0; j < 4; j++) {
        res(row + 0, col + j) += pfirst<Packetc>(acc0.packet[j]);
      }
    } else {
      for(Index j = 0; j < 4; j++) {
        PacketBlock<Packetc,1> acc2;
        acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, col + j) + acc0.packet[j];
        res.template storePacketBlock<Packetc,1>(row + 0, col + j, acc2);
        if(remaining_rows > accColsC) {
          res(row + accColsC, col + j) += pfirst<Packetc>(acc1.packet[j]);
        }
      }
    }
  }
}

// Expands func(iter) for iter = 0..4. The *_ONE macros below guard each
// expansion with `unroll_factor > iter`, so at most five row blocks are handled.
#define MICRO_COMPLEX_UNROLL(func) \
  func(0) func(1) func(2) func(3) func(4)

// For one peel index:
//   * run the load step func2 for every row block;
//   * then run the multiply-accumulate step func(iter, peel) for every row block.
#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
    MICRO_COMPLEX_UNROLL(func2); \
    func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel)

// Loads accCols LHS values for row block `iter` and advances that block's
// pointer(s) by accCols. The imaginary part is loaded only for a complex LHS.
#define MICRO_COMPLEX_LOAD_ONE(iter) \
  if (unroll_factor > iter) { \
    lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
    lhs_ptr_real##iter += accCols; \
    if(!LhsIsReal) { \
      lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
      lhs_ptr_imag##iter += accCols; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhsV##iter); \
    EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
  }

// Accumulates row block `iter` times the four broadcast RHS columns of `peel`.
#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
  if (unroll_factor > iter) { \
    pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

// Accumulates row block `iter` times the single broadcast RHS column of `peel`.
#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \
  if (unroll_factor > iter) { \
    pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
  }

// One peeled depth step over four RHS columns. For peel < PEEL_COMPLEX it:
//   * declares the per-block LHS registers;
//   * broadcasts four RHS values at offset accRows*peel;
//   * runs the load + multiply-accumulate sequence.
// Larger peel indices only mark the per-peel arrays as unused.
#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
  if (PEEL_COMPLEX > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
    pbroadcast4_old<Packet>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
    if(!RhsIsReal) { \
      pbroadcast4_old<Packet>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

// Single-column variant of MICRO_COMPLEX_TYPE_PEEL4. It broadcasts one RHS
// value at offset remaining_cols*peel.
#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \
  if (PEEL_COMPLEX > peel) { \
    Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
    Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
    rhsV##peel[0] = pset1<Packet>(rhs_ptr_real[remaining_cols * peel]); \
    if(!RhsIsReal) { \
      rhsVi##peel[0] = pset1<Packet>(rhs_ptr_imag[remaining_cols * peel]); \
    } else { \
      EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
    } \
    MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
  } else { \
    EIGEN_UNUSED_VARIABLE(rhsV##peel); \
    EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
  }

// Declares RHS register arrays of M packets for peel indices 0..9 and expands
// func for each index. Indices >= PEEL_COMPLEX only produce the
// unused-variable markers.
#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
  Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
  Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \
  func(func1,func2,0); func(func1,func2,1); \
  func(func1,func2,2); func(func1,func2,3); \
  func(func1,func2,4); func(func1,func2,5); \
  func(func1,func2,6); func(func1,func2,7); \
  func(func1,func2,8); func(func1,func2,9);

// Non-peeled variant: a single depth step (peel index 0).
#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
  Packet rhsV0[M], rhsVi0[M];\
  func(func1,func2,0);

// PEEL_COMPLEX depth steps over four RHS columns, then advance the RHS pointers.
#define MICRO_COMPLEX_ONE_PEEL4 \
  MICRO_COMPLEX_UNROLL_TYPE_PEEL(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
  rhs_ptr_real += (accRows * PEEL_COMPLEX); \
  if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX);

// One depth step over four RHS columns, then advance the RHS pointers.
#define MICRO_COMPLEX_ONE4 \
  MICRO_COMPLEX_UNROLL_TYPE_ONE(4, MICRO_COMPLEX_TYPE_PEEL4, MICRO_COMPLEX_WORK_ONE4, MICRO_COMPLEX_LOAD_ONE); \
  rhs_ptr_real += accRows; \
  if(!RhsIsReal) rhs_ptr_imag += accRows;

// PEEL_COMPLEX depth steps for a single RHS column whose values are strided
// by remaining_cols.
#define MICRO_COMPLEX_ONE_PEEL1 \
  MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
  rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \
  if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX);

// One depth step for a single RHS column whose values are strided by remaining_cols.
#define MICRO_COMPLEX_ONE1 \
  MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
  rhs_ptr_real += remaining_cols; \
  if(!RhsIsReal) rhs_ptr_imag += remaining_cols;

// Zeroes the real/imaginary accumulators of row block `iter`.
#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    bsetzero<Scalar, Packet>(accReal##iter); \
    bsetzero<Scalar, Packet>(accImag##iter); \
  } else { \
    EIGEN_UNUSED_VARIABLE(accReal##iter); \
    EIGEN_UNUSED_VARIABLE(accImag##iter); \
  }

#define MICRO_COMPLEX_DST_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_DST_PTR_ONE)

// Points row block `iter` at its packed LHS panel. A complex LHS has its
// imaginary part accCols*strideA values after the real part.
#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
  if (unroll_factor > iter) { \
    lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
    if(!LhsIsReal) { \
      lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
    } else { \
      EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
    } \
  } else { \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
    EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
  }

#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)

// Prefetches the LHS pointer(s) of row block `iter`.
#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
  if (unroll_factor > iter) { \
    EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
    if(!LhsIsReal) { \
      EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
    } \
  }

#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)

// Store for row block `iter` over four columns:
//   * load the current res tile into tRes;
//   * scale the accumulators by alpha;
//   * combine with tRes via bcouple;
//   * store two Packetc blocks, at row + iter*accCols and accColsC rows below.
#define MICRO_COMPLEX_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
    bscalec<Packet,4>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
    bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
    res.template storePacketBlock<Packetc,4>(row + iter*accCols + 0, col, acc0); \
    res.template storePacketBlock<Packetc,4>(row + iter*accCols + accColsC, col, acc1); \
  }

#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)

// Single-column variant of MICRO_COMPLEX_STORE_ONE.
#define MICRO_COMPLEX_COL_STORE_ONE(iter) \
  if (unroll_factor > iter) { \
    bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
    bscalec<Packet,1>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
    bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
    res.template storePacketBlock<Packetc,1>(row + iter*accCols + 0, col, acc0); \
    res.template storePacketBlock<Packetc,1>(row + iter*accCols + accColsC, col, acc1); \
  }

#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE)

// Computes unroll_factor consecutive row blocks of accCols rows for the
// four-column block starting at col. At most five blocks are supported (the
// MICRO_COMPLEX_UNROLL limit).
//   * accumulates over the whole depth;
//   * scales by alpha and combines the result into res;
//   * advances `row` by unroll_factor*accCols.
// The MICRO_COMPLEX_* macros refer to this function's locals by name
// (lhs_ptr_real0.., accReal0.., rhs_ptr_real, tRes, ...).
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
  const DataMapper& res,
  const Scalar* lhs_base,
  const Scalar* rhs_base,
  Index depth,
  Index strideA,
  Index offsetA,
  Index strideB,
  Index& row,
  Index col,
  const Packet& pAlphaReal,
  const Packet& pAlphaImag)
{
  const Scalar* rhs_ptr_real = rhs_base;
  const Scalar* rhs_ptr_imag;
  if(!RhsIsReal) {
    rhs_ptr_imag = rhs_base + accRows*strideB;
  } else {
    EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
  }
  const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
  const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
  const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
  PacketBlock<Packet,4> accReal0, accImag0, accReal1, accImag1;
  PacketBlock<Packet,4> accReal2, accImag2, accReal3, accImag3;
  PacketBlock<Packet,4> accReal4, accImag4;
  PacketBlock<Packet,4> taccReal, taccImag;
  PacketBlock<Packetc,4> acc0, acc1;
  PacketBlock<Packetc,8> tRes;

  // Set up the LHS pointers and zero the accumulators of each active row block.
  MICRO_COMPLEX_SRC_PTR
  MICRO_COMPLEX_DST_PTR

  Index k = 0;
  // Main loop: PEEL_COMPLEX depth steps per iteration.
  for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
  {
    EIGEN_POWER_PREFETCH(rhs_ptr_real);
2200 if(!RhsIsReal) { 2201 EIGEN_POWER_PREFETCH(rhs_ptr_imag); 2202 } 2203 MICRO_COMPLEX_PREFETCH 2204 MICRO_COMPLEX_ONE_PEEL4 2205 } 2206 for(; k < depth; k++) 2207 { 2208 MICRO_COMPLEX_ONE4 2209 } 2210 MICRO_COMPLEX_STORE 2211 2212 row += unroll_factor*accCols; 2213 } 2214 2215 template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> 2216 EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration( 2217 const DataMapper& res, 2218 const Scalar* lhs_base, 2219 const Scalar* rhs_base, 2220 Index depth, 2221 Index strideA, 2222 Index offsetA, 2223 Index strideB, 2224 Index& row, 2225 Index col, 2226 Index remaining_cols, 2227 const Packet& pAlphaReal, 2228 const Packet& pAlphaImag) 2229 { 2230 const Scalar* rhs_ptr_real = rhs_base; 2231 const Scalar* rhs_ptr_imag; 2232 if(!RhsIsReal) { 2233 rhs_ptr_imag = rhs_base + remaining_cols*strideB; 2234 } else { 2235 EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); 2236 } 2237 const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL; 2238 const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL; 2239 const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL; 2240 PacketBlock<Packet,1> accReal0, accImag0, accReal1, accImag1; 2241 PacketBlock<Packet,1> accReal2, accImag2, accReal3, accImag3; 2242 PacketBlock<Packet,1> accReal4, accImag4; 2243 PacketBlock<Packet,1> taccReal, taccImag; 2244 PacketBlock<Packetc,1> acc0, acc1; 2245 PacketBlock<Packetc,2> tRes; 2246 2247 MICRO_COMPLEX_SRC_PTR 2248 MICRO_COMPLEX_DST_PTR 2249 2250 Index k = 0; 2251 for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX) 2252 { 2253 EIGEN_POWER_PREFETCH(rhs_ptr_real); 2254 if(!RhsIsReal) { 2255 EIGEN_POWER_PREFETCH(rhs_ptr_imag); 2256 } 2257 MICRO_COMPLEX_PREFETCH 2258 MICRO_COMPLEX_ONE_PEEL1 2259 } 2260 for(; k < 
depth; k++) 2261 { 2262 MICRO_COMPLEX_ONE1 2263 } 2264 MICRO_COMPLEX_COL_STORE 2265 2266 row += unroll_factor*accCols; 2267 } 2268 2269 template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> 2270 EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( 2271 const DataMapper& res, 2272 const Scalar* lhs_base, 2273 const Scalar* rhs_base, 2274 Index depth, 2275 Index strideA, 2276 Index offsetA, 2277 Index strideB, 2278 Index& row, 2279 Index rows, 2280 Index col, 2281 Index remaining_cols, 2282 const Packet& pAlphaReal, 2283 const Packet& pAlphaImag) 2284 { 2285 #define MAX_COMPLEX_UNROLL 3 2286 while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { 2287 gemm_complex_unrolled_col_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); 2288 } 2289 switch( (rows-row)/accCols ) { 2290 #if MAX_COMPLEX_UNROLL > 4 2291 case 4: 2292 gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); 2293 break; 2294 #endif 2295 #if MAX_COMPLEX_UNROLL > 3 2296 case 3: 2297 gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); 2298 break; 2299 #endif 2300 #if MAX_COMPLEX_UNROLL > 2 2301 case 2: 2302 gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, 
strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); 2303 break; 2304 #endif 2305 #if MAX_COMPLEX_UNROLL > 1 2306 case 1: 2307 gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); 2308 break; 2309 #endif 2310 default: 2311 break; 2312 } 2313 #undef MAX_COMPLEX_UNROLL 2314 } 2315 2316 template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal> 2317 EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) 2318 { 2319 const Index remaining_rows = rows % accCols; 2320 const Index remaining_cols = cols % accRows; 2321 2322 if( strideA == -1 ) strideA = depth; 2323 if( strideB == -1 ) strideB = depth; 2324 2325 const Packet pAlphaReal = pset1<Packet>(alpha.real()); 2326 const Packet pAlphaImag = pset1<Packet>(alpha.imag()); 2327 const Packet pMask = bmask<Packet>((const int)(remaining_rows)); 2328 2329 const Scalar* blockA = (Scalar *) blockAc; 2330 const Scalar* blockB = (Scalar *) blockBc; 2331 2332 Index col = 0; 2333 for(; col + accRows <= cols; col += accRows) 2334 { 2335 const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; 2336 const Scalar* lhs_base = blockA; 2337 Index row = 0; 2338 2339 #define MAX_COMPLEX_UNROLL 3 2340 while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { 2341 gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, 
lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); 2342 } 2343 switch( (rows-row)/accCols ) { 2344 #if MAX_COMPLEX_UNROLL > 4 2345 case 4: 2346 gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); 2347 break; 2348 #endif 2349 #if MAX_COMPLEX_UNROLL > 3 2350 case 3: 2351 gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); 2352 break; 2353 #endif 2354 #if MAX_COMPLEX_UNROLL > 2 2355 case 2: 2356 gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); 2357 break; 2358 #endif 2359 #if MAX_COMPLEX_UNROLL > 1 2360 case 1: 2361 gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); 2362 break; 2363 #endif 2364 default: 2365 break; 2366 } 2367 #undef MAX_COMPLEX_UNROLL 2368 2369 if(remaining_rows > 0) 2370 { 2371 gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); 2372 } 2373 } 2374 2375 if(remaining_cols > 0) 2376 { 2377 const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB; 2378 const Scalar* lhs_base = blockA; 2379 2380 for(; col < cols; col++) 2381 { 2382 Index 
row = 0; 2383 2384 gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag); 2385 2386 if (remaining_rows > 0) 2387 { 2388 gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag); 2389 } 2390 rhs_base++; 2391 } 2392 } 2393 } 2394 2395 #undef accColsC 2396 #undef advanceCols 2397 #undef advanceRows 2398 2399 /************************************ 2400 * ppc64le template specializations * 2401 * **********************************/ 2402 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2403 struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2404 { 2405 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); 2406 }; 2407 2408 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2409 void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2410 ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2411 { 2412 dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, true> pack; 2413 pack(blockA, lhs, depth, rows, stride, offset); 2414 } 2415 2416 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2417 struct gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2418 { 2419 void operator()(double* blockA, const DataMapper& lhs, Index depth, Index 
rows, Index stride=0, Index offset=0); 2420 }; 2421 2422 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2423 void gemm_pack_lhs<double, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2424 ::operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2425 { 2426 dhs_pack<double, Index, DataMapper, Packet2d, RowMajor, PanelMode, true> pack; 2427 pack(blockA, lhs, depth, rows, stride, offset); 2428 } 2429 2430 #if EIGEN_ALTIVEC_USE_CUSTOM_PACK 2431 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2432 struct gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2433 { 2434 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); 2435 }; 2436 2437 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2438 void gemm_pack_rhs<double, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2439 ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2440 { 2441 dhs_pack<double, Index, DataMapper, Packet2d, ColMajor, PanelMode, false> pack; 2442 pack(blockB, rhs, depth, cols, stride, offset); 2443 } 2444 2445 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2446 struct gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2447 { 2448 void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); 2449 }; 2450 2451 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2452 void gemm_pack_rhs<double, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2453 ::operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2454 { 2455 dhs_pack<double, Index, 
DataMapper, Packet2d, RowMajor, PanelMode, false> pack; 2456 pack(blockB, rhs, depth, cols, stride, offset); 2457 } 2458 #endif 2459 2460 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2461 struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2462 { 2463 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); 2464 }; 2465 2466 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2467 void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2468 ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2469 { 2470 dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, true> pack; 2471 pack(blockA, lhs, depth, rows, stride, offset); 2472 } 2473 2474 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2475 struct gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2476 { 2477 void operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); 2478 }; 2479 2480 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2481 void gemm_pack_lhs<float, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2482 ::operator()(float* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2483 { 2484 dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, true> pack; 2485 pack(blockA, lhs, depth, rows, stride, offset); 2486 } 2487 2488 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2489 struct 
gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2490 { 2491 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); 2492 }; 2493 2494 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2495 void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2496 ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2497 { 2498 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, true> pack; 2499 pack(blockA, lhs, depth, rows, stride, offset); 2500 } 2501 2502 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2503 struct gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2504 { 2505 void operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); 2506 }; 2507 2508 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2509 void gemm_pack_lhs<std::complex<float>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2510 ::operator()(std::complex<float>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2511 { 2512 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, true> pack; 2513 pack(blockA, lhs, depth, rows, stride, offset); 2514 } 2515 2516 #if EIGEN_ALTIVEC_USE_CUSTOM_PACK 2517 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2518 struct gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2519 { 2520 void operator()(float* blockB, 
const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); 2521 }; 2522 2523 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2524 void gemm_pack_rhs<float, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2525 ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2526 { 2527 dhs_pack<float, Index, DataMapper, Packet4f, ColMajor, PanelMode, false> pack; 2528 pack(blockB, rhs, depth, cols, stride, offset); 2529 } 2530 2531 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2532 struct gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2533 { 2534 void operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); 2535 }; 2536 2537 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2538 void gemm_pack_rhs<float, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2539 ::operator()(float* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2540 { 2541 dhs_pack<float, Index, DataMapper, Packet4f, RowMajor, PanelMode, false> pack; 2542 pack(blockB, rhs, depth, cols, stride, offset); 2543 } 2544 #endif 2545 2546 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2547 struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2548 { 2549 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); 2550 }; 2551 2552 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2553 void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2554 ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2555 { 2556 dhs_cpack<float, Index, 
DataMapper, Packet4f, Packet2cf, ColMajor, Conjugate, PanelMode, false> pack; 2557 pack(blockB, rhs, depth, cols, stride, offset); 2558 } 2559 2560 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2561 struct gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2562 { 2563 void operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); 2564 }; 2565 2566 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2567 void gemm_pack_rhs<std::complex<float>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2568 ::operator()(std::complex<float>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2569 { 2570 dhs_cpack<float, Index, DataMapper, Packet4f, Packet2cf, RowMajor, Conjugate, PanelMode, false> pack; 2571 pack(blockB, rhs, depth, cols, stride, offset); 2572 } 2573 2574 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2575 struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2576 { 2577 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); 2578 }; 2579 2580 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2581 void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, RowMajor, Conjugate, PanelMode> 2582 ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2583 { 2584 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, true> pack; 2585 pack(blockA, lhs, depth, rows, stride, offset); 2586 } 2587 2588 template<typename Index, typename DataMapper, int Pack1, int Pack2, 
typename Packet, bool Conjugate, bool PanelMode> 2589 struct gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2590 { 2591 void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride=0, Index offset=0); 2592 }; 2593 2594 template<typename Index, typename DataMapper, int Pack1, int Pack2, typename Packet, bool Conjugate, bool PanelMode> 2595 void gemm_pack_lhs<std::complex<double>, Index, DataMapper, Pack1, Pack2, Packet, ColMajor, Conjugate, PanelMode> 2596 ::operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset) 2597 { 2598 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, true> pack; 2599 pack(blockA, lhs, depth, rows, stride, offset); 2600 } 2601 2602 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2603 struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2604 { 2605 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0); 2606 }; 2607 2608 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2609 void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, ColMajor, Conjugate, PanelMode> 2610 ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2611 { 2612 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, ColMajor, Conjugate, PanelMode, false> pack; 2613 pack(blockB, rhs, depth, cols, stride, offset); 2614 } 2615 2616 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2617 struct gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2618 { 2619 void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, 
Index cols, Index stride=0, Index offset=0); 2620 }; 2621 2622 template<typename Index, typename DataMapper, int nr, bool Conjugate, bool PanelMode> 2623 void gemm_pack_rhs<std::complex<double>, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode> 2624 ::operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset) 2625 { 2626 dhs_cpack<double, Index, DataMapper, Packet2d, Packet1cd, RowMajor, Conjugate, PanelMode, false> pack; 2627 pack(blockB, rhs, depth, cols, stride, offset); 2628 } 2629 2630 // ********* gebp specializations ********* 2631 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2632 struct gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2633 { 2634 typedef typename quad_traits<float>::vectortype Packet; 2635 typedef typename quad_traits<float>::rhstype RhsPacket; 2636 2637 void operator()(const DataMapper& res, const float* blockA, const float* blockB, 2638 Index rows, Index depth, Index cols, float alpha, 2639 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2640 }; 2641 2642 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2643 void gebp_kernel<float, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2644 ::operator()(const DataMapper& res, const float* blockA, const float* blockB, 2645 Index rows, Index depth, Index cols, float alpha, 2646 Index strideA, Index strideB, Index offsetA, Index offsetB) 2647 { 2648 const Index accRows = quad_traits<float>::rows; 2649 const Index accCols = quad_traits<float>::size; 2650 void (*gemm_function)(const DataMapper&, const float*, const float*, Index, Index, Index, float, Index, Index, Index, Index); 2651 2652 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2653 //generate with MMA only 2654 gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2655 #elif 
defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2656 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 2657 gemm_function = &Eigen::internal::gemmMMA<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2658 } 2659 else{ 2660 gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2661 } 2662 #else 2663 gemm_function = &Eigen::internal::gemm<float, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2664 #endif 2665 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2666 } 2667 2668 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2669 struct gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2670 { 2671 typedef Packet4f Packet; 2672 typedef Packet2cf Packetc; 2673 typedef Packet4f RhsPacket; 2674 2675 void operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB, 2676 Index rows, Index depth, Index cols, std::complex<float> alpha, 2677 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2678 }; 2679 2680 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2681 void gebp_kernel<std::complex<float>, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2682 ::operator()(const DataMapper& res, const std::complex<float>* blockA, const std::complex<float>* blockB, 2683 Index rows, Index depth, Index cols, std::complex<float> alpha, 2684 Index strideA, Index strideB, Index offsetA, Index offsetB) 2685 { 2686 const Index accRows = quad_traits<float>::rows; 2687 const Index accCols = quad_traits<float>::size; 2688 void (*gemm_function)(const DataMapper&, const std::complex<float>*, const std::complex<float>*, 2689 Index, Index, Index, std::complex<float>, Index, Index, Index, 
Index); 2690 2691 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2692 //generate with MMA only 2693 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2694 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2695 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 2696 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2697 } 2698 else{ 2699 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2700 } 2701 #else 2702 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2703 #endif 2704 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2705 } 2706 2707 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2708 struct gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2709 { 2710 typedef Packet4f Packet; 2711 typedef Packet2cf Packetc; 2712 typedef Packet4f RhsPacket; 2713 2714 void operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB, 2715 Index rows, Index depth, Index cols, std::complex<float> alpha, 2716 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2717 }; 2718 2719 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool 
ConjugateRhs> 2720 void gebp_kernel<float, std::complex<float>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2721 ::operator()(const DataMapper& res, const float* blockA, const std::complex<float>* blockB, 2722 Index rows, Index depth, Index cols, std::complex<float> alpha, 2723 Index strideA, Index strideB, Index offsetA, Index offsetB) 2724 { 2725 const Index accRows = quad_traits<float>::rows; 2726 const Index accCols = quad_traits<float>::size; 2727 void (*gemm_function)(const DataMapper&, const float*, const std::complex<float>*, 2728 Index, Index, Index, std::complex<float>, Index, Index, Index, Index); 2729 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2730 //generate with MMA only 2731 gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2732 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2733 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 2734 gemm_function = &Eigen::internal::gemm_complexMMA<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2735 } 2736 else{ 2737 gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2738 } 2739 #else 2740 gemm_function = &Eigen::internal::gemm_complex<float, std::complex<float>, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2741 #endif 2742 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2743 } 2744 2745 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2746 struct 
gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2747 { 2748 typedef Packet4f Packet; 2749 typedef Packet2cf Packetc; 2750 typedef Packet4f RhsPacket; 2751 2752 void operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB, 2753 Index rows, Index depth, Index cols, std::complex<float> alpha, 2754 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2755 }; 2756 2757 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2758 void gebp_kernel<std::complex<float>, float, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2759 ::operator()(const DataMapper& res, const std::complex<float>* blockA, const float* blockB, 2760 Index rows, Index depth, Index cols, std::complex<float> alpha, 2761 Index strideA, Index strideB, Index offsetA, Index offsetB) 2762 { 2763 const Index accRows = quad_traits<float>::rows; 2764 const Index accCols = quad_traits<float>::size; 2765 void (*gemm_function)(const DataMapper&, const std::complex<float>*, const float*, 2766 Index, Index, Index, std::complex<float>, Index, Index, Index, Index); 2767 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2768 //generate with MMA only 2769 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2770 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2771 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 2772 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2773 } 2774 else{ 2775 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, 
DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2776 } 2777 #else 2778 gemm_function = &Eigen::internal::gemm_complex<std::complex<float>, float, std::complex<float>, float, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2779 #endif 2780 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2781 } 2782 2783 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2784 struct gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2785 { 2786 typedef typename quad_traits<double>::vectortype Packet; 2787 typedef typename quad_traits<double>::rhstype RhsPacket; 2788 2789 void operator()(const DataMapper& res, const double* blockA, const double* blockB, 2790 Index rows, Index depth, Index cols, double alpha, 2791 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2792 }; 2793 2794 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2795 void gebp_kernel<double, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2796 ::operator()(const DataMapper& res, const double* blockA, const double* blockB, 2797 Index rows, Index depth, Index cols, double alpha, 2798 Index strideA, Index strideB, Index offsetA, Index offsetB) 2799 { 2800 const Index accRows = quad_traits<double>::rows; 2801 const Index accCols = quad_traits<double>::size; 2802 void (*gemm_function)(const DataMapper&, const double*, const double*, Index, Index, Index, double, Index, Index, Index, Index); 2803 2804 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2805 //generate with MMA only 2806 gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2807 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2808 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 
2809 gemm_function = &Eigen::internal::gemmMMA<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2810 } 2811 else{ 2812 gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2813 } 2814 #else 2815 gemm_function = &Eigen::internal::gemm<double, Index, Packet, RhsPacket, DataMapper, accRows, accCols>; 2816 #endif 2817 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2818 } 2819 2820 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2821 struct gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2822 { 2823 typedef quad_traits<double>::vectortype Packet; 2824 typedef Packet1cd Packetc; 2825 typedef quad_traits<double>::rhstype RhsPacket; 2826 2827 void operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB, 2828 Index rows, Index depth, Index cols, std::complex<double> alpha, 2829 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2830 }; 2831 2832 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2833 void gebp_kernel<std::complex<double>, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2834 ::operator()(const DataMapper& res, const std::complex<double>* blockA, const std::complex<double>* blockB, 2835 Index rows, Index depth, Index cols, std::complex<double> alpha, 2836 Index strideA, Index strideB, Index offsetA, Index offsetB) 2837 { 2838 const Index accRows = quad_traits<double>::rows; 2839 const Index accCols = quad_traits<double>::size; 2840 void (*gemm_function)(const DataMapper&, const std::complex<double>*, const std::complex<double>*, 2841 Index, Index, Index, std::complex<double>, Index, Index, Index, Index); 2842 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2843 //generate with MMA only 2844 gemm_function 
= &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2845 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2846 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 2847 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2848 } 2849 else{ 2850 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2851 } 2852 #else 2853 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, false>; 2854 #endif 2855 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2856 } 2857 2858 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2859 struct gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2860 { 2861 typedef quad_traits<double>::vectortype Packet; 2862 typedef Packet1cd Packetc; 2863 typedef quad_traits<double>::rhstype RhsPacket; 2864 2865 void operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB, 2866 Index rows, Index depth, Index cols, std::complex<double> alpha, 2867 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2868 }; 2869 2870 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2871 void 
gebp_kernel<std::complex<double>, double, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2872 ::operator()(const DataMapper& res, const std::complex<double>* blockA, const double* blockB, 2873 Index rows, Index depth, Index cols, std::complex<double> alpha, 2874 Index strideA, Index strideB, Index offsetA, Index offsetB) 2875 { 2876 const Index accRows = quad_traits<double>::rows; 2877 const Index accCols = quad_traits<double>::size; 2878 void (*gemm_function)(const DataMapper&, const std::complex<double>*, const double*, 2879 Index, Index, Index, std::complex<double>, Index, Index, Index, Index); 2880 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2881 //generate with MMA only 2882 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2883 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2884 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 2885 gemm_function = &Eigen::internal::gemm_complexMMA<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2886 } 2887 else{ 2888 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2889 } 2890 #else 2891 gemm_function = &Eigen::internal::gemm_complex<std::complex<double>, double, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, false, true>; 2892 #endif 2893 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2894 } 2895 2896 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2897 struct 
gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2898 { 2899 typedef quad_traits<double>::vectortype Packet; 2900 typedef Packet1cd Packetc; 2901 typedef quad_traits<double>::rhstype RhsPacket; 2902 2903 void operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB, 2904 Index rows, Index depth, Index cols, std::complex<double> alpha, 2905 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); 2906 }; 2907 2908 template<typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs, bool ConjugateRhs> 2909 void gebp_kernel<double, std::complex<double>, Index, DataMapper, mr, nr, ConjugateLhs, ConjugateRhs> 2910 ::operator()(const DataMapper& res, const double* blockA, const std::complex<double>* blockB, 2911 Index rows, Index depth, Index cols, std::complex<double> alpha, 2912 Index strideA, Index strideB, Index offsetA, Index offsetB) 2913 { 2914 const Index accRows = quad_traits<double>::rows; 2915 const Index accCols = quad_traits<double>::size; 2916 void (*gemm_function)(const DataMapper&, const double*, const std::complex<double>*, 2917 Index, Index, Index, std::complex<double>, Index, Index, Index, Index); 2918 #ifdef EIGEN_ALTIVEC_MMA_ONLY 2919 //generate with MMA only 2920 gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2921 #elif defined(ALTIVEC_MMA_SUPPORT) && !defined(EIGEN_ALTIVEC_DISABLE_MMA) 2922 if (__builtin_cpu_supports ("arch_3_1") && __builtin_cpu_supports ("mma")){ 2923 gemm_function = &Eigen::internal::gemm_complexMMA<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2924 } 2925 else{ 2926 gemm_function = &Eigen::internal::gemm_complex<double, 
std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2927 } 2928 #else 2929 gemm_function = &Eigen::internal::gemm_complex<double, std::complex<double>, std::complex<double>, double, Index, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, true, false>; 2930 #endif 2931 gemm_function(res, blockA, blockB, rows, depth, cols, alpha, strideA, strideB, offsetA, offsetB); 2932 } 2933 } // end namespace internal 2934 2935 } // end namespace Eigen 2936 2937 #endif // EIGEN_MATRIX_PRODUCT_ALTIVEC_H 2938