1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner (benoit.steiner.goog@gmail.com) 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_COMPLEX_AVX_H 11 #define EIGEN_COMPLEX_AVX_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 //---------- float ---------- 18 struct Packet4cf 19 { Packet4cfPacket4cf20 EIGEN_STRONG_INLINE Packet4cf() {} Packet4cfPacket4cf21 EIGEN_STRONG_INLINE explicit Packet4cf(const __m256& a) : v(a) {} 22 __m256 v; 23 }; 24 25 #ifndef EIGEN_VECTORIZE_AVX512 26 template<> struct packet_traits<std::complex<float> > : default_packet_traits 27 { 28 typedef Packet4cf type; 29 typedef Packet2cf half; 30 enum { 31 Vectorizable = 1, 32 AlignedOnScalar = 1, 33 size = 4, 34 HasHalfPacket = 1, 35 36 HasAdd = 1, 37 HasSub = 1, 38 HasMul = 1, 39 HasDiv = 1, 40 HasNegate = 1, 41 HasSqrt = 1, 42 HasAbs = 0, 43 HasAbs2 = 0, 44 HasMin = 0, 45 HasMax = 0, 46 HasSetLinear = 0 47 }; 48 }; 49 #endif 50 51 template<> struct unpacket_traits<Packet4cf> { 52 typedef std::complex<float> type; 53 typedef Packet2cf half; 54 typedef Packet8f as_real; 55 enum { 56 size=4, 57 alignment=Aligned32, 58 vectorizable=true, 59 masked_load_available=false, 60 masked_store_available=false 61 }; 62 }; 63 64 template<> EIGEN_STRONG_INLINE Packet4cf padd<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_add_ps(a.v,b.v)); } 65 template<> EIGEN_STRONG_INLINE Packet4cf psub<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_sub_ps(a.v,b.v)); } 66 template<> EIGEN_STRONG_INLINE Packet4cf pnegate(const Packet4cf& a) 67 { 68 return Packet4cf(pnegate(a.v)); 69 } 70 template<> EIGEN_STRONG_INLINE Packet4cf pconj(const Packet4cf& a) 71 { 72 const __m256 mask = _mm256_castsi256_ps(_mm256_setr_epi32(0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000)); 73 return Packet4cf(_mm256_xor_ps(a.v,mask)); 74 } 75 76 template<> EIGEN_STRONG_INLINE Packet4cf pmul<Packet4cf>(const Packet4cf& a, const Packet4cf& b) 77 { 78 __m256 tmp1 = _mm256_mul_ps(_mm256_moveldup_ps(a.v), b.v); 79 __m256 tmp2 = _mm256_mul_ps(_mm256_movehdup_ps(a.v), _mm256_permute_ps(b.v, _MM_SHUFFLE(2,3,0,1))); 80 __m256 result = _mm256_addsub_ps(tmp1, tmp2); 81 return Packet4cf(result); 82 } 83 84 template <> 85 EIGEN_STRONG_INLINE Packet4cf pcmp_eq(const Packet4cf& a, const Packet4cf& b) { 86 __m256 eq = _mm256_cmp_ps(a.v, b.v, _CMP_EQ_OQ); 87 return Packet4cf(_mm256_and_ps(eq, _mm256_permute_ps(eq, 0xb1))); 88 } 89 90 template<> EIGEN_STRONG_INLINE Packet4cf ptrue<Packet4cf>(const Packet4cf& a) { return Packet4cf(ptrue(Packet8f(a.v))); } 91 template<> EIGEN_STRONG_INLINE Packet4cf pand <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_and_ps(a.v,b.v)); } 92 template<> EIGEN_STRONG_INLINE Packet4cf por <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_or_ps(a.v,b.v)); } 93 template<> EIGEN_STRONG_INLINE Packet4cf pxor <Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_xor_ps(a.v,b.v)); } 94 template<> EIGEN_STRONG_INLINE Packet4cf pandnot<Packet4cf>(const Packet4cf& a, const Packet4cf& b) { return Packet4cf(_mm256_andnot_ps(b.v,a.v)); } 95 96 template<> EIGEN_STRONG_INLINE Packet4cf pload <Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet4cf(pload<Packet8f>(&numext::real_ref(*from))); } 97 template<> EIGEN_STRONG_INLINE Packet4cf ploadu<Packet4cf>(const std::complex<float>* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cf(ploadu<Packet8f>(&numext::real_ref(*from))); } 98 99 100 template<> EIGEN_STRONG_INLINE Packet4cf pset1<Packet4cf>(const std::complex<float>& from) 101 { 102 return Packet4cf(_mm256_castpd_ps(_mm256_broadcast_sd((const double*)(const void*)&from))); 103 } 104 105 template<> EIGEN_STRONG_INLINE Packet4cf ploaddup<Packet4cf>(const std::complex<float>* from) 106 { 107 // FIXME The following might be optimized using _mm256_movedup_pd 108 Packet2cf a = ploaddup<Packet2cf>(from); 109 Packet2cf b = ploaddup<Packet2cf>(from+1); 110 return Packet4cf(_mm256_insertf128_ps(_mm256_castps128_ps256(a.v), b.v, 1)); 111 } 112 113 template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float>* to, const Packet4cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); } 114 template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float>* to, const Packet4cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); } 115 116 template<> EIGEN_DEVICE_FUNC inline Packet4cf pgather<std::complex<float>, Packet4cf>(const std::complex<float>* from, Index stride) 117 { 118 return Packet4cf(_mm256_set_ps(std::imag(from[3*stride]), std::real(from[3*stride]), 119 std::imag(from[2*stride]), std::real(from[2*stride]), 120 std::imag(from[1*stride]), std::real(from[1*stride]), 121 std::imag(from[0*stride]), std::real(from[0*stride]))); 122 } 123 124 template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<float>, Packet4cf>(std::complex<float>* to, const Packet4cf& from, Index stride) 125 { 126 __m128 low = _mm256_extractf128_ps(from.v, 0); 127 to[stride*0] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 0)), 128 _mm_cvtss_f32(_mm_shuffle_ps(low, low, 1))); 129 to[stride*1] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(low, low, 2)), 130 _mm_cvtss_f32(_mm_shuffle_ps(low, low, 3))); 131 132 __m128 high = _mm256_extractf128_ps(from.v, 1); 133 to[stride*2] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 0)), 134 _mm_cvtss_f32(_mm_shuffle_ps(high, high, 1))); 135 to[stride*3] = std::complex<float>(_mm_cvtss_f32(_mm_shuffle_ps(high, high, 2)), 136 _mm_cvtss_f32(_mm_shuffle_ps(high, high, 3))); 137 138 } 139 140 template<> EIGEN_STRONG_INLINE std::complex<float> pfirst<Packet4cf>(const Packet4cf& a) 141 { 142 return pfirst(Packet2cf(_mm256_castps256_ps128(a.v))); 143 } 144 145 template<> EIGEN_STRONG_INLINE Packet4cf preverse(const Packet4cf& a) { 146 __m128 low = _mm256_extractf128_ps(a.v, 0); 147 __m128 high = _mm256_extractf128_ps(a.v, 1); 148 __m128d lowd = _mm_castps_pd(low); 149 __m128d highd = _mm_castps_pd(high); 150 low = _mm_castpd_ps(_mm_shuffle_pd(lowd,lowd,0x1)); 151 high = _mm_castpd_ps(_mm_shuffle_pd(highd,highd,0x1)); 152 __m256 result = _mm256_setzero_ps(); 153 result = _mm256_insertf128_ps(result, low, 1); 154 result = _mm256_insertf128_ps(result, high, 0); 155 return Packet4cf(result); 156 } 157 158 template<> EIGEN_STRONG_INLINE std::complex<float> predux<Packet4cf>(const Packet4cf& a) 159 { 160 return predux(padd(Packet2cf(_mm256_extractf128_ps(a.v,0)), 161 Packet2cf(_mm256_extractf128_ps(a.v,1)))); 162 } 163 164 template<> EIGEN_STRONG_INLINE std::complex<float> predux_mul<Packet4cf>(const Packet4cf& a) 165 { 166 return predux_mul(pmul(Packet2cf(_mm256_extractf128_ps(a.v, 0)), 167 Packet2cf(_mm256_extractf128_ps(a.v, 1)))); 168 } 169 170 EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cf,Packet8f) 171 172 template<> EIGEN_STRONG_INLINE Packet4cf pdiv<Packet4cf>(const Packet4cf& a, const Packet4cf& b) 173 { 174 Packet4cf num = pmul(a, pconj(b)); 175 __m256 tmp = _mm256_mul_ps(b.v, b.v); 176 __m256 tmp2 = _mm256_shuffle_ps(tmp,tmp,0xB1); 177 __m256 denom = _mm256_add_ps(tmp, tmp2); 178 return Packet4cf(_mm256_div_ps(num.v, denom)); 179 } 180 181 template<> EIGEN_STRONG_INLINE Packet4cf pcplxflip<Packet4cf>(const Packet4cf& x) 182 { 183 return Packet4cf(_mm256_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0 ,1))); 184 } 185 186 //---------- double ---------- 187 struct Packet2cd 188 { 189 EIGEN_STRONG_INLINE Packet2cd() {} 190 EIGEN_STRONG_INLINE explicit Packet2cd(const __m256d& a) : v(a) {} 191 __m256d v; 192 }; 193 194 #ifndef EIGEN_VECTORIZE_AVX512 195 template<> struct packet_traits<std::complex<double> > : default_packet_traits 196 { 197 typedef Packet2cd type; 198 typedef Packet1cd half; 199 enum { 200 Vectorizable = 1, 201 AlignedOnScalar = 0, 202 size = 2, 203 HasHalfPacket = 1, 204 205 HasAdd = 1, 206 HasSub = 1, 207 HasMul = 1, 208 HasDiv = 1, 209 HasNegate = 1, 210 HasSqrt = 1, 211 HasAbs = 0, 212 HasAbs2 = 0, 213 HasMin = 0, 214 HasMax = 0, 215 HasSetLinear = 0 216 }; 217 }; 218 #endif 219 220 template<> struct unpacket_traits<Packet2cd> { 221 typedef std::complex<double> type; 222 typedef Packet1cd half; 223 typedef Packet4d as_real; 224 enum { 225 size=2, 226 alignment=Aligned32, 227 vectorizable=true, 228 masked_load_available=false, 229 masked_store_available=false 230 }; 231 }; 232 233 template<> EIGEN_STRONG_INLINE Packet2cd padd<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_add_pd(a.v,b.v)); } 234 template<> EIGEN_STRONG_INLINE Packet2cd psub<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_sub_pd(a.v,b.v)); } 235 template<> EIGEN_STRONG_INLINE Packet2cd pnegate(const Packet2cd& a) { return Packet2cd(pnegate(a.v)); } 236 template<> EIGEN_STRONG_INLINE Packet2cd pconj(const Packet2cd& a) 237 { 238 const __m256d mask = _mm256_castsi256_pd(_mm256_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0)); 239 return Packet2cd(_mm256_xor_pd(a.v,mask)); 240 } 241 242 template<> EIGEN_STRONG_INLINE Packet2cd pmul<Packet2cd>(const Packet2cd& a, const Packet2cd& b) 243 { 244 __m256d tmp1 = _mm256_shuffle_pd(a.v,a.v,0x0); 245 __m256d even = _mm256_mul_pd(tmp1, b.v); 246 __m256d tmp2 = _mm256_shuffle_pd(a.v,a.v,0xF); 247 __m256d tmp3 = _mm256_shuffle_pd(b.v,b.v,0x5); 248 __m256d odd = _mm256_mul_pd(tmp2, tmp3); 249 return Packet2cd(_mm256_addsub_pd(even, odd)); 250 } 251 252 template <> 253 EIGEN_STRONG_INLINE Packet2cd pcmp_eq(const Packet2cd& a, const Packet2cd& b) { 254 __m256d eq = _mm256_cmp_pd(a.v, b.v, _CMP_EQ_OQ); 255 return Packet2cd(pand(eq, _mm256_permute_pd(eq, 0x5))); 256 } 257 258 template<> EIGEN_STRONG_INLINE Packet2cd ptrue<Packet2cd>(const Packet2cd& a) { return Packet2cd(ptrue(Packet4d(a.v))); } 259 template<> EIGEN_STRONG_INLINE Packet2cd pand <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_and_pd(a.v,b.v)); } 260 template<> EIGEN_STRONG_INLINE Packet2cd por <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_or_pd(a.v,b.v)); } 261 template<> EIGEN_STRONG_INLINE Packet2cd pxor <Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_xor_pd(a.v,b.v)); } 262 template<> EIGEN_STRONG_INLINE Packet2cd pandnot<Packet2cd>(const Packet2cd& a, const Packet2cd& b) { return Packet2cd(_mm256_andnot_pd(b.v,a.v)); } 263 264 template<> EIGEN_STRONG_INLINE Packet2cd pload <Packet2cd>(const std::complex<double>* from) 265 { EIGEN_DEBUG_ALIGNED_LOAD return Packet2cd(pload<Packet4d>((const double*)from)); } 266 template<> EIGEN_STRONG_INLINE Packet2cd ploadu<Packet2cd>(const std::complex<double>* from) 267 { EIGEN_DEBUG_UNALIGNED_LOAD return Packet2cd(ploadu<Packet4d>((const double*)from)); } 268 269 template<> EIGEN_STRONG_INLINE Packet2cd pset1<Packet2cd>(const std::complex<double>& from) 270 { 271 // in case casting to a __m128d* is really not safe, then we can still fallback to this version: (much slower though) 272 // return Packet2cd(_mm256_loadu2_m128d((const double*)&from,(const double*)&from)); 273 return Packet2cd(_mm256_broadcast_pd((const __m128d*)(const void*)&from)); 274 } 275 276 template<> EIGEN_STRONG_INLINE Packet2cd ploaddup<Packet2cd>(const std::complex<double>* from) { return pset1<Packet2cd>(*from); } 277 278 template<> EIGEN_STRONG_INLINE void pstore <std::complex<double> >(std::complex<double> * to, const Packet2cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); } 279 template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<double> >(std::complex<double> * to, const Packet2cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); } 280 281 template<> EIGEN_DEVICE_FUNC inline Packet2cd pgather<std::complex<double>, Packet2cd>(const std::complex<double>* from, Index stride) 282 { 283 return Packet2cd(_mm256_set_pd(std::imag(from[1*stride]), std::real(from[1*stride]), 284 std::imag(from[0*stride]), std::real(from[0*stride]))); 285 } 286 287 template<> EIGEN_DEVICE_FUNC inline void pscatter<std::complex<double>, Packet2cd>(std::complex<double>* to, const Packet2cd& from, Index stride) 288 { 289 __m128d low = _mm256_extractf128_pd(from.v, 0); 290 to[stride*0] = std::complex<double>(_mm_cvtsd_f64(low), _mm_cvtsd_f64(_mm_shuffle_pd(low, low, 1))); 291 __m128d high = _mm256_extractf128_pd(from.v, 1); 292 to[stride*1] = std::complex<double>(_mm_cvtsd_f64(high), _mm_cvtsd_f64(_mm_shuffle_pd(high, high, 1))); 293 } 294 295 template<> EIGEN_STRONG_INLINE std::complex<double> pfirst<Packet2cd>(const Packet2cd& a) 296 { 297 __m128d low = _mm256_extractf128_pd(a.v, 0); 298 EIGEN_ALIGN16 double res[2]; 299 _mm_store_pd(res, low); 300 return std::complex<double>(res[0],res[1]); 301 } 302 303 template<> EIGEN_STRONG_INLINE Packet2cd preverse(const Packet2cd& a) { 304 __m256d result = _mm256_permute2f128_pd(a.v, a.v, 1); 305 return Packet2cd(result); 306 } 307 308 template<> EIGEN_STRONG_INLINE std::complex<double> predux<Packet2cd>(const Packet2cd& a) 309 { 310 return predux(padd(Packet1cd(_mm256_extractf128_pd(a.v,0)), 311 Packet1cd(_mm256_extractf128_pd(a.v,1)))); 312 } 313 314 template<> EIGEN_STRONG_INLINE std::complex<double> predux_mul<Packet2cd>(const Packet2cd& a) 315 { 316 return predux(pmul(Packet1cd(_mm256_extractf128_pd(a.v,0)), 317 Packet1cd(_mm256_extractf128_pd(a.v,1)))); 318 } 319 320 EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet2cd,Packet4d) 321 322 template<> EIGEN_STRONG_INLINE Packet2cd pdiv<Packet2cd>(const Packet2cd& a, const Packet2cd& b) 323 { 324 Packet2cd num = pmul(a, pconj(b)); 325 __m256d tmp = _mm256_mul_pd(b.v, b.v); 326 __m256d denom = _mm256_hadd_pd(tmp, tmp); 327 return Packet2cd(_mm256_div_pd(num.v, denom)); 328 } 329 330 template<> EIGEN_STRONG_INLINE Packet2cd pcplxflip<Packet2cd>(const Packet2cd& x) 331 { 332 return Packet2cd(_mm256_shuffle_pd(x.v, x.v, 0x5)); 333 } 334 335 EIGEN_DEVICE_FUNC inline void 336 ptranspose(PacketBlock<Packet4cf,4>& kernel) { 337 __m256d P0 = _mm256_castps_pd(kernel.packet[0].v); 338 __m256d P1 = _mm256_castps_pd(kernel.packet[1].v); 339 __m256d P2 = _mm256_castps_pd(kernel.packet[2].v); 340 __m256d P3 = _mm256_castps_pd(kernel.packet[3].v); 341 342 __m256d T0 = _mm256_shuffle_pd(P0, P1, 15); 343 __m256d T1 = _mm256_shuffle_pd(P0, P1, 0); 344 __m256d T2 = _mm256_shuffle_pd(P2, P3, 15); 345 __m256d T3 = _mm256_shuffle_pd(P2, P3, 0); 346 347 kernel.packet[1].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T0, T2, 32)); 348 kernel.packet[3].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T0, T2, 49)); 349 kernel.packet[0].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T1, T3, 32)); 350 kernel.packet[2].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T1, T3, 49)); 351 } 352 353 EIGEN_DEVICE_FUNC inline void 354 ptranspose(PacketBlock<Packet2cd,2>& kernel) { 355 __m256d tmp = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 0+(2<<4)); 356 kernel.packet[1].v = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 1+(3<<4)); 357 kernel.packet[0].v = tmp; 358 } 359 360 template<> EIGEN_STRONG_INLINE Packet2cd psqrt<Packet2cd>(const Packet2cd& a) { 361 return psqrt_complex<Packet2cd>(a); 362 } 363 364 template<> EIGEN_STRONG_INLINE Packet4cf psqrt<Packet4cf>(const Packet4cf& a) { 365 return psqrt_complex<Packet4cf>(a); 366 } 367 368 } // end namespace internal 369 370 } // end namespace Eigen 371 372 #endif // EIGEN_COMPLEX_AVX_H 373