• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2020, Arm Limited and Contributors
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_PACKET_MATH_SVE_H
11 #define EIGEN_PACKET_MATH_SVE_H
12 
13 namespace Eigen
14 {
15 namespace internal
16 {
17 #ifndef EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD
18 #define EIGEN_CACHEFRIENDLY_PRODUCT_THRESHOLD 8
19 #endif
20 
21 #ifndef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
22 #define EIGEN_HAS_SINGLE_INSTRUCTION_MADD
23 #endif
24 
25 #define EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS 32
26 
27 template <typename Scalar, int SVEVectorLength>
28 struct sve_packet_size_selector {
29   enum { size = SVEVectorLength / (sizeof(Scalar) * CHAR_BIT) };
30 };
31 
32 /********************************* int32 **************************************/
33 typedef svint32_t PacketXi __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
34 
35 template <>
36 struct packet_traits<numext::int32_t> : default_packet_traits {
37   typedef PacketXi type;
38   typedef PacketXi half;  // Half not implemented yet
39   enum {
40     Vectorizable = 1,
41     AlignedOnScalar = 1,
42     size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
43     HasHalfPacket = 0,
44 
45     HasAdd = 1,
46     HasSub = 1,
47     HasShift = 1,
48     HasMul = 1,
49     HasNegate = 1,
50     HasAbs = 1,
51     HasArg = 0,
52     HasAbs2 = 1,
53     HasMin = 1,
54     HasMax = 1,
55     HasConj = 1,
56     HasSetLinear = 0,
57     HasBlend = 0,
58     HasReduxp = 0  // Not implemented in SVE
59   };
60 };
61 
62 template <>
63 struct unpacket_traits<PacketXi> {
64   typedef numext::int32_t type;
65   typedef PacketXi half;  // Half not yet implemented
66   enum {
67     size = sve_packet_size_selector<numext::int32_t, EIGEN_ARM64_SVE_VL>::size,
68     alignment = Aligned64,
69     vectorizable = true,
70     masked_load_available = false,
71     masked_store_available = false
72   };
73 };
74 
75 template <>
76 EIGEN_STRONG_INLINE void prefetch<numext::int32_t>(const numext::int32_t* addr)
77 {
78   svprfw(svptrue_b32(), addr, SV_PLDL1KEEP);
79 }
80 
81 template <>
82 EIGEN_STRONG_INLINE PacketXi pset1<PacketXi>(const numext::int32_t& from)
83 {
84   return svdup_n_s32(from);
85 }
86 
87 template <>
88 EIGEN_STRONG_INLINE PacketXi plset<PacketXi>(const numext::int32_t& a)
89 {
90   numext::int32_t c[packet_traits<numext::int32_t>::size];
91   for (int i = 0; i < packet_traits<numext::int32_t>::size; i++) c[i] = i;
92   return svadd_s32_z(svptrue_b32(), pset1<PacketXi>(a), svld1_s32(svptrue_b32(), c));
93 }
94 
95 template <>
96 EIGEN_STRONG_INLINE PacketXi padd<PacketXi>(const PacketXi& a, const PacketXi& b)
97 {
98   return svadd_s32_z(svptrue_b32(), a, b);
99 }
100 
101 template <>
102 EIGEN_STRONG_INLINE PacketXi psub<PacketXi>(const PacketXi& a, const PacketXi& b)
103 {
104   return svsub_s32_z(svptrue_b32(), a, b);
105 }
106 
107 template <>
108 EIGEN_STRONG_INLINE PacketXi pnegate(const PacketXi& a)
109 {
110   return svneg_s32_z(svptrue_b32(), a);
111 }
112 
113 template <>
114 EIGEN_STRONG_INLINE PacketXi pconj(const PacketXi& a)
115 {
116   return a;
117 }
118 
119 template <>
120 EIGEN_STRONG_INLINE PacketXi pmul<PacketXi>(const PacketXi& a, const PacketXi& b)
121 {
122   return svmul_s32_z(svptrue_b32(), a, b);
123 }
124 
125 template <>
126 EIGEN_STRONG_INLINE PacketXi pdiv<PacketXi>(const PacketXi& a, const PacketXi& b)
127 {
128   return svdiv_s32_z(svptrue_b32(), a, b);
129 }
130 
131 template <>
132 EIGEN_STRONG_INLINE PacketXi pmadd(const PacketXi& a, const PacketXi& b, const PacketXi& c)
133 {
134   return svmla_s32_z(svptrue_b32(), c, a, b);
135 }
136 
137 template <>
138 EIGEN_STRONG_INLINE PacketXi pmin<PacketXi>(const PacketXi& a, const PacketXi& b)
139 {
140   return svmin_s32_z(svptrue_b32(), a, b);
141 }
142 
143 template <>
144 EIGEN_STRONG_INLINE PacketXi pmax<PacketXi>(const PacketXi& a, const PacketXi& b)
145 {
146   return svmax_s32_z(svptrue_b32(), a, b);
147 }
148 
149 template <>
150 EIGEN_STRONG_INLINE PacketXi pcmp_le<PacketXi>(const PacketXi& a, const PacketXi& b)
151 {
152   return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
153 }
154 
155 template <>
156 EIGEN_STRONG_INLINE PacketXi pcmp_lt<PacketXi>(const PacketXi& a, const PacketXi& b)
157 {
158   return svdup_n_s32_z(svcmplt_s32(svptrue_b32(), a, b), 0xffffffffu);
159 }
160 
161 template <>
162 EIGEN_STRONG_INLINE PacketXi pcmp_eq<PacketXi>(const PacketXi& a, const PacketXi& b)
163 {
164   return svdup_n_s32_z(svcmpeq_s32(svptrue_b32(), a, b), 0xffffffffu);
165 }
166 
167 template <>
168 EIGEN_STRONG_INLINE PacketXi ptrue<PacketXi>(const PacketXi& /*a*/)
169 {
170   return svdup_n_s32_z(svptrue_b32(), 0xffffffffu);
171 }
172 
173 template <>
174 EIGEN_STRONG_INLINE PacketXi pzero<PacketXi>(const PacketXi& /*a*/)
175 {
176   return svdup_n_s32_z(svptrue_b32(), 0);
177 }
178 
179 template <>
180 EIGEN_STRONG_INLINE PacketXi pand<PacketXi>(const PacketXi& a, const PacketXi& b)
181 {
182   return svand_s32_z(svptrue_b32(), a, b);
183 }
184 
185 template <>
186 EIGEN_STRONG_INLINE PacketXi por<PacketXi>(const PacketXi& a, const PacketXi& b)
187 {
188   return svorr_s32_z(svptrue_b32(), a, b);
189 }
190 
191 template <>
192 EIGEN_STRONG_INLINE PacketXi pxor<PacketXi>(const PacketXi& a, const PacketXi& b)
193 {
194   return sveor_s32_z(svptrue_b32(), a, b);
195 }
196 
197 template <>
198 EIGEN_STRONG_INLINE PacketXi pandnot<PacketXi>(const PacketXi& a, const PacketXi& b)
199 {
200   return svbic_s32_z(svptrue_b32(), a, b);
201 }
202 
203 template <int N>
204 EIGEN_STRONG_INLINE PacketXi parithmetic_shift_right(PacketXi a)
205 {
206   return svasrd_n_s32_z(svptrue_b32(), a, N);
207 }
208 
209 template <int N>
210 EIGEN_STRONG_INLINE PacketXi plogical_shift_right(PacketXi a)
211 {
212   return svreinterpret_s32_u32(svlsr_u32_z(svptrue_b32(), svreinterpret_u32_s32(a), svdup_n_u32_z(svptrue_b32(), N)));
213 }
214 
215 template <int N>
216 EIGEN_STRONG_INLINE PacketXi plogical_shift_left(PacketXi a)
217 {
218   return svlsl_s32_z(svptrue_b32(), a, svdup_n_u32_z(svptrue_b32(), N));
219 }
220 
221 template <>
222 EIGEN_STRONG_INLINE PacketXi pload<PacketXi>(const numext::int32_t* from)
223 {
224   EIGEN_DEBUG_ALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
225 }
226 
227 template <>
228 EIGEN_STRONG_INLINE PacketXi ploadu<PacketXi>(const numext::int32_t* from)
229 {
230   EIGEN_DEBUG_UNALIGNED_LOAD return svld1_s32(svptrue_b32(), from);
231 }
232 
233 template <>
234 EIGEN_STRONG_INLINE PacketXi ploaddup<PacketXi>(const numext::int32_t* from)
235 {
236   svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
237   indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
238   return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
239 }
240 
241 template <>
242 EIGEN_STRONG_INLINE PacketXi ploadquad<PacketXi>(const numext::int32_t* from)
243 {
244   svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
245   indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
246   indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
247   return svld1_gather_u32index_s32(svptrue_b32(), from, indices);
248 }
249 
250 template <>
251 EIGEN_STRONG_INLINE void pstore<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
252 {
253   EIGEN_DEBUG_ALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
254 }
255 
256 template <>
257 EIGEN_STRONG_INLINE void pstoreu<numext::int32_t>(numext::int32_t* to, const PacketXi& from)
258 {
259   EIGEN_DEBUG_UNALIGNED_STORE svst1_s32(svptrue_b32(), to, from);
260 }
261 
262 template <>
263 EIGEN_DEVICE_FUNC inline PacketXi pgather<numext::int32_t, PacketXi>(const numext::int32_t* from, Index stride)
264 {
265   // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
266   svint32_t indices = svindex_s32(0, stride);
267   return svld1_gather_s32index_s32(svptrue_b32(), from, indices);
268 }
269 
270 template <>
271 EIGEN_DEVICE_FUNC inline void pscatter<numext::int32_t, PacketXi>(numext::int32_t* to, const PacketXi& from, Index stride)
272 {
273   // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
274   svint32_t indices = svindex_s32(0, stride);
275   svst1_scatter_s32index_s32(svptrue_b32(), to, indices, from);
276 }
277 
278 template <>
279 EIGEN_STRONG_INLINE numext::int32_t pfirst<PacketXi>(const PacketXi& a)
280 {
281   // svlasta returns the first element if all predicate bits are 0
282   return svlasta_s32(svpfalse_b(), a);
283 }
284 
285 template <>
286 EIGEN_STRONG_INLINE PacketXi preverse(const PacketXi& a)
287 {
288   return svrev_s32(a);
289 }
290 
291 template <>
292 EIGEN_STRONG_INLINE PacketXi pabs(const PacketXi& a)
293 {
294   return svabs_s32_z(svptrue_b32(), a);
295 }
296 
297 template <>
298 EIGEN_STRONG_INLINE numext::int32_t predux<PacketXi>(const PacketXi& a)
299 {
300   return static_cast<numext::int32_t>(svaddv_s32(svptrue_b32(), a));
301 }
302 
303 template <>
304 EIGEN_STRONG_INLINE numext::int32_t predux_mul<PacketXi>(const PacketXi& a)
305 {
306   EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
307                       EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
308 
309   // Multiply the vector by its reverse
310   svint32_t prod = svmul_s32_z(svptrue_b32(), a, svrev_s32(a));
311   svint32_t half_prod;
312 
313   // Extract the high half of the vector. Depending on the VL more reductions need to be done
314   if (EIGEN_ARM64_SVE_VL >= 2048) {
315     half_prod = svtbl_s32(prod, svindex_u32(32, 1));
316     prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
317   }
318   if (EIGEN_ARM64_SVE_VL >= 1024) {
319     half_prod = svtbl_s32(prod, svindex_u32(16, 1));
320     prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
321   }
322   if (EIGEN_ARM64_SVE_VL >= 512) {
323     half_prod = svtbl_s32(prod, svindex_u32(8, 1));
324     prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
325   }
326   if (EIGEN_ARM64_SVE_VL >= 256) {
327     half_prod = svtbl_s32(prod, svindex_u32(4, 1));
328     prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
329   }
330   // Last reduction
331   half_prod = svtbl_s32(prod, svindex_u32(2, 1));
332   prod = svmul_s32_z(svptrue_b32(), prod, half_prod);
333 
334   // The reduction is done to the first element.
335   return pfirst<PacketXi>(prod);
336 }
337 
338 template <>
339 EIGEN_STRONG_INLINE numext::int32_t predux_min<PacketXi>(const PacketXi& a)
340 {
341   return svminv_s32(svptrue_b32(), a);
342 }
343 
344 template <>
345 EIGEN_STRONG_INLINE numext::int32_t predux_max<PacketXi>(const PacketXi& a)
346 {
347   return svmaxv_s32(svptrue_b32(), a);
348 }
349 
350 template <int N>
351 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXi, N>& kernel) {
352   int buffer[packet_traits<numext::int32_t>::size * N] = {0};
353   int i = 0;
354 
355   PacketXi stride_index = svindex_s32(0, N);
356 
357   for (i = 0; i < N; i++) {
358     svst1_scatter_s32index_s32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
359   }
360   for (i = 0; i < N; i++) {
361     kernel.packet[i] = svld1_s32(svptrue_b32(), buffer + i * packet_traits<numext::int32_t>::size);
362   }
363 }
364 
365 /********************************* float32 ************************************/
366 
367 typedef svfloat32_t PacketXf __attribute__((arm_sve_vector_bits(EIGEN_ARM64_SVE_VL)));
368 
369 template <>
370 struct packet_traits<float> : default_packet_traits {
371   typedef PacketXf type;
372   typedef PacketXf half;
373 
374   enum {
375     Vectorizable = 1,
376     AlignedOnScalar = 1,
377     size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
378     HasHalfPacket = 0,
379 
380     HasAdd = 1,
381     HasSub = 1,
382     HasShift = 1,
383     HasMul = 1,
384     HasNegate = 1,
385     HasAbs = 1,
386     HasArg = 0,
387     HasAbs2 = 1,
388     HasMin = 1,
389     HasMax = 1,
390     HasConj = 1,
391     HasSetLinear = 0,
392     HasBlend = 0,
393     HasReduxp = 0,  // Not implemented in SVE
394 
395     HasDiv = 1,
396     HasFloor = 1,
397 
398     HasSin = EIGEN_FAST_MATH,
399     HasCos = EIGEN_FAST_MATH,
400     HasLog = 1,
401     HasExp = 1,
402     HasSqrt = 0,
403     HasTanh = EIGEN_FAST_MATH,
404     HasErf = EIGEN_FAST_MATH
405   };
406 };
407 
408 template <>
409 struct unpacket_traits<PacketXf> {
410   typedef float type;
411   typedef PacketXf half;  // Half not yet implemented
412   typedef PacketXi integer_packet;
413 
414   enum {
415     size = sve_packet_size_selector<float, EIGEN_ARM64_SVE_VL>::size,
416     alignment = Aligned64,
417     vectorizable = true,
418     masked_load_available = false,
419     masked_store_available = false
420   };
421 };
422 
423 template <>
424 EIGEN_STRONG_INLINE PacketXf pset1<PacketXf>(const float& from)
425 {
426   return svdup_n_f32(from);
427 }
428 
429 template <>
430 EIGEN_STRONG_INLINE PacketXf pset1frombits<PacketXf>(numext::uint32_t from)
431 {
432   return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), from));
433 }
434 
435 template <>
436 EIGEN_STRONG_INLINE PacketXf plset<PacketXf>(const float& a)
437 {
438   float c[packet_traits<float>::size];
439   for (int i = 0; i < packet_traits<float>::size; i++) c[i] = i;
440   return svadd_f32_z(svptrue_b32(), pset1<PacketXf>(a), svld1_f32(svptrue_b32(), c));
441 }
442 
443 template <>
444 EIGEN_STRONG_INLINE PacketXf padd<PacketXf>(const PacketXf& a, const PacketXf& b)
445 {
446   return svadd_f32_z(svptrue_b32(), a, b);
447 }
448 
449 template <>
450 EIGEN_STRONG_INLINE PacketXf psub<PacketXf>(const PacketXf& a, const PacketXf& b)
451 {
452   return svsub_f32_z(svptrue_b32(), a, b);
453 }
454 
455 template <>
456 EIGEN_STRONG_INLINE PacketXf pnegate(const PacketXf& a)
457 {
458   return svneg_f32_z(svptrue_b32(), a);
459 }
460 
461 template <>
462 EIGEN_STRONG_INLINE PacketXf pconj(const PacketXf& a)
463 {
464   return a;
465 }
466 
467 template <>
468 EIGEN_STRONG_INLINE PacketXf pmul<PacketXf>(const PacketXf& a, const PacketXf& b)
469 {
470   return svmul_f32_z(svptrue_b32(), a, b);
471 }
472 
473 template <>
474 EIGEN_STRONG_INLINE PacketXf pdiv<PacketXf>(const PacketXf& a, const PacketXf& b)
475 {
476   return svdiv_f32_z(svptrue_b32(), a, b);
477 }
478 
479 template <>
480 EIGEN_STRONG_INLINE PacketXf pmadd(const PacketXf& a, const PacketXf& b, const PacketXf& c)
481 {
482   return svmla_f32_z(svptrue_b32(), c, a, b);
483 }
484 
485 template <>
486 EIGEN_STRONG_INLINE PacketXf pmin<PacketXf>(const PacketXf& a, const PacketXf& b)
487 {
488   return svmin_f32_z(svptrue_b32(), a, b);
489 }
490 
491 template <>
492 EIGEN_STRONG_INLINE PacketXf pmin<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
493 {
494   return pmin<PacketXf>(a, b);
495 }
496 
497 template <>
498 EIGEN_STRONG_INLINE PacketXf pmin<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
499 {
500   return svminnm_f32_z(svptrue_b32(), a, b);
501 }
502 
503 template <>
504 EIGEN_STRONG_INLINE PacketXf pmax<PacketXf>(const PacketXf& a, const PacketXf& b)
505 {
506   return svmax_f32_z(svptrue_b32(), a, b);
507 }
508 
509 template <>
510 EIGEN_STRONG_INLINE PacketXf pmax<PropagateNaN, PacketXf>(const PacketXf& a, const PacketXf& b)
511 {
512   return pmax<PacketXf>(a, b);
513 }
514 
515 template <>
516 EIGEN_STRONG_INLINE PacketXf pmax<PropagateNumbers, PacketXf>(const PacketXf& a, const PacketXf& b)
517 {
518   return svmaxnm_f32_z(svptrue_b32(), a, b);
519 }
520 
521 // Float comparisons in SVE return svbool (predicate). Use svdup to set active
522 // lanes to 1 (0xffffffffu) and inactive lanes to 0.
523 template <>
524 EIGEN_STRONG_INLINE PacketXf pcmp_le<PacketXf>(const PacketXf& a, const PacketXf& b)
525 {
526   return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
527 }
528 
529 template <>
530 EIGEN_STRONG_INLINE PacketXf pcmp_lt<PacketXf>(const PacketXf& a, const PacketXf& b)
531 {
532   return svreinterpret_f32_u32(svdup_n_u32_z(svcmplt_f32(svptrue_b32(), a, b), 0xffffffffu));
533 }
534 
535 template <>
536 EIGEN_STRONG_INLINE PacketXf pcmp_eq<PacketXf>(const PacketXf& a, const PacketXf& b)
537 {
538   return svreinterpret_f32_u32(svdup_n_u32_z(svcmpeq_f32(svptrue_b32(), a, b), 0xffffffffu));
539 }
540 
541 // Do a predicate inverse (svnot_b_z) on the predicate resulted from the
542 // greater/equal comparison (svcmpge_f32). Then fill a float vector with the
543 // active elements.
544 template <>
545 EIGEN_STRONG_INLINE PacketXf pcmp_lt_or_nan<PacketXf>(const PacketXf& a, const PacketXf& b)
546 {
547   return svreinterpret_f32_u32(svdup_n_u32_z(svnot_b_z(svptrue_b32(), svcmpge_f32(svptrue_b32(), a, b)), 0xffffffffu));
548 }
549 
550 template <>
551 EIGEN_STRONG_INLINE PacketXf pfloor<PacketXf>(const PacketXf& a)
552 {
553   return svrintm_f32_z(svptrue_b32(), a);
554 }
555 
556 template <>
557 EIGEN_STRONG_INLINE PacketXf ptrue<PacketXf>(const PacketXf& /*a*/)
558 {
559   return svreinterpret_f32_u32(svdup_n_u32_z(svptrue_b32(), 0xffffffffu));
560 }
561 
562 // Logical Operations are not supported for float, so reinterpret casts
563 template <>
564 EIGEN_STRONG_INLINE PacketXf pand<PacketXf>(const PacketXf& a, const PacketXf& b)
565 {
566   return svreinterpret_f32_u32(svand_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
567 }
568 
569 template <>
570 EIGEN_STRONG_INLINE PacketXf por<PacketXf>(const PacketXf& a, const PacketXf& b)
571 {
572   return svreinterpret_f32_u32(svorr_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
573 }
574 
575 template <>
576 EIGEN_STRONG_INLINE PacketXf pxor<PacketXf>(const PacketXf& a, const PacketXf& b)
577 {
578   return svreinterpret_f32_u32(sveor_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
579 }
580 
581 template <>
582 EIGEN_STRONG_INLINE PacketXf pandnot<PacketXf>(const PacketXf& a, const PacketXf& b)
583 {
584   return svreinterpret_f32_u32(svbic_u32_z(svptrue_b32(), svreinterpret_u32_f32(a), svreinterpret_u32_f32(b)));
585 }
586 
587 template <>
588 EIGEN_STRONG_INLINE PacketXf pload<PacketXf>(const float* from)
589 {
590   EIGEN_DEBUG_ALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
591 }
592 
593 template <>
594 EIGEN_STRONG_INLINE PacketXf ploadu<PacketXf>(const float* from)
595 {
596   EIGEN_DEBUG_UNALIGNED_LOAD return svld1_f32(svptrue_b32(), from);
597 }
598 
599 template <>
600 EIGEN_STRONG_INLINE PacketXf ploaddup<PacketXf>(const float* from)
601 {
602   svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
603   indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
604   return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
605 }
606 
607 template <>
608 EIGEN_STRONG_INLINE PacketXf ploadquad<PacketXf>(const float* from)
609 {
610   svuint32_t indices = svindex_u32(0, 1);  // index {base=0, base+step=1, base+step*2, ...}
611   indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a1, a1, a2, a2, ...}
612   indices = svzip1_u32(indices, indices);  // index in the format {a0, a0, a0, a0, a1, a1, a1, a1, ...}
613   return svld1_gather_u32index_f32(svptrue_b32(), from, indices);
614 }
615 
616 template <>
617 EIGEN_STRONG_INLINE void pstore<float>(float* to, const PacketXf& from)
618 {
619   EIGEN_DEBUG_ALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
620 }
621 
622 template <>
623 EIGEN_STRONG_INLINE void pstoreu<float>(float* to, const PacketXf& from)
624 {
625   EIGEN_DEBUG_UNALIGNED_STORE svst1_f32(svptrue_b32(), to, from);
626 }
627 
628 template <>
629 EIGEN_DEVICE_FUNC inline PacketXf pgather<float, PacketXf>(const float* from, Index stride)
630 {
631   // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
632   svint32_t indices = svindex_s32(0, stride);
633   return svld1_gather_s32index_f32(svptrue_b32(), from, indices);
634 }
635 
636 template <>
637 EIGEN_DEVICE_FUNC inline void pscatter<float, PacketXf>(float* to, const PacketXf& from, Index stride)
638 {
639   // Indice format: {base=0, base+stride, base+stride*2, base+stride*3, ...}
640   svint32_t indices = svindex_s32(0, stride);
641   svst1_scatter_s32index_f32(svptrue_b32(), to, indices, from);
642 }
643 
644 template <>
645 EIGEN_STRONG_INLINE float pfirst<PacketXf>(const PacketXf& a)
646 {
647   // svlasta returns the first element if all predicate bits are 0
648   return svlasta_f32(svpfalse_b(), a);
649 }
650 
651 template <>
652 EIGEN_STRONG_INLINE PacketXf preverse(const PacketXf& a)
653 {
654   return svrev_f32(a);
655 }
656 
657 template <>
658 EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
659 {
660   return svabs_f32_z(svptrue_b32(), a);
661 }
662 
663 // TODO(tellenbach): Should this go into MathFunctions.h? If so, change for
664 // all vector extensions and the generic version.
665 template <>
666 EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
667 {
668   return pfrexp_generic(a, exponent);
669 }
670 
671 template <>
672 EIGEN_STRONG_INLINE float predux<PacketXf>(const PacketXf& a)
673 {
674   return svaddv_f32(svptrue_b32(), a);
675 }
676 
677 // Other reduction functions:
678 // mul
679 // Only works for SVE Vls multiple of 128
680 template <>
681 EIGEN_STRONG_INLINE float predux_mul<PacketXf>(const PacketXf& a)
682 {
683   EIGEN_STATIC_ASSERT((EIGEN_ARM64_SVE_VL % 128 == 0),
684                       EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
685   // Multiply the vector by its reverse
686   svfloat32_t prod = svmul_f32_z(svptrue_b32(), a, svrev_f32(a));
687   svfloat32_t half_prod;
688 
689   // Extract the high half of the vector. Depending on the VL more reductions need to be done
690   if (EIGEN_ARM64_SVE_VL >= 2048) {
691     half_prod = svtbl_f32(prod, svindex_u32(32, 1));
692     prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
693   }
694   if (EIGEN_ARM64_SVE_VL >= 1024) {
695     half_prod = svtbl_f32(prod, svindex_u32(16, 1));
696     prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
697   }
698   if (EIGEN_ARM64_SVE_VL >= 512) {
699     half_prod = svtbl_f32(prod, svindex_u32(8, 1));
700     prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
701   }
702   if (EIGEN_ARM64_SVE_VL >= 256) {
703     half_prod = svtbl_f32(prod, svindex_u32(4, 1));
704     prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
705   }
706   // Last reduction
707   half_prod = svtbl_f32(prod, svindex_u32(2, 1));
708   prod = svmul_f32_z(svptrue_b32(), prod, half_prod);
709 
710   // The reduction is done to the first element.
711   return pfirst<PacketXf>(prod);
712 }
713 
714 template <>
715 EIGEN_STRONG_INLINE float predux_min<PacketXf>(const PacketXf& a)
716 {
717   return svminv_f32(svptrue_b32(), a);
718 }
719 
720 template <>
721 EIGEN_STRONG_INLINE float predux_max<PacketXf>(const PacketXf& a)
722 {
723   return svmaxv_f32(svptrue_b32(), a);
724 }
725 
726 template<int N>
727 EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
728 {
729   float buffer[packet_traits<float>::size * N] = {0};
730   int i = 0;
731 
732   PacketXi stride_index = svindex_s32(0, N);
733 
734   for (i = 0; i < N; i++) {
735     svst1_scatter_s32index_f32(svptrue_b32(), buffer + i, stride_index, kernel.packet[i]);
736   }
737 
738   for (i = 0; i < N; i++) {
739     kernel.packet[i] = svld1_f32(svptrue_b32(), buffer + i * packet_traits<float>::size);
740   }
741 }
742 
743 template<>
744 EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
745 {
746   return pldexp_generic(a, exponent);
747 }
748 
749 }  // namespace internal
750 }  // namespace Eigen
751 
752 #endif  // EIGEN_PACKET_MATH_SVE_H
753