/*
 * Wrapper functions for SVE ACLE.
 *
 * Copyright (c) 2019-2023, Arm Limited.
 * SPDX-License-Identifier: MIT OR Apache-2.0 WITH LLVM-exception
 */
7
8 #ifndef SV_MATH_H
9 #define SV_MATH_H
10
11 #ifndef WANT_VMATH
12 /* Enable the build of vector math code. */
13 #define WANT_VMATH 1
14 #endif
15 #if WANT_VMATH
16
17 #if WANT_SVE_MATH
18 #define SV_SUPPORTED 1
19
20 #include <arm_sve.h>
21 #include <stdbool.h>
22
23 #include "math_config.h"
24
25 typedef float f32_t;
26 typedef uint32_t u32_t;
27 typedef int32_t s32_t;
28 typedef double f64_t;
29 typedef uint64_t u64_t;
30 typedef int64_t s64_t;
31
32 typedef svfloat64_t sv_f64_t;
33 typedef svuint64_t sv_u64_t;
34 typedef svint64_t sv_s64_t;
35
36 typedef svfloat32_t sv_f32_t;
37 typedef svuint32_t sv_u32_t;
38 typedef svint32_t sv_s32_t;
39
40 /* Double precision. */
41 static inline sv_s64_t
sv_s64(s64_t x)42 sv_s64 (s64_t x)
43 {
44 return svdup_n_s64 (x);
45 }
46
47 static inline sv_u64_t
sv_u64(u64_t x)48 sv_u64 (u64_t x)
49 {
50 return svdup_n_u64 (x);
51 }
52
53 static inline sv_f64_t
sv_f64(f64_t x)54 sv_f64 (f64_t x)
55 {
56 return svdup_n_f64 (x);
57 }
58
59 static inline sv_f64_t
sv_fma_f64_x(svbool_t pg,sv_f64_t x,sv_f64_t y,sv_f64_t z)60 sv_fma_f64_x (svbool_t pg, sv_f64_t x, sv_f64_t y, sv_f64_t z)
61 {
62 return svmla_f64_x (pg, z, x, y);
63 }
64
65 /* res = z + x * y with x scalar. */
66 static inline sv_f64_t
sv_fma_n_f64_x(svbool_t pg,f64_t x,sv_f64_t y,sv_f64_t z)67 sv_fma_n_f64_x (svbool_t pg, f64_t x, sv_f64_t y, sv_f64_t z)
68 {
69 return svmla_n_f64_x (pg, z, y, x);
70 }
71
72 static inline sv_s64_t
sv_as_s64_u64(sv_u64_t x)73 sv_as_s64_u64 (sv_u64_t x)
74 {
75 return svreinterpret_s64_u64 (x);
76 }
77
78 static inline sv_u64_t
sv_as_u64_f64(sv_f64_t x)79 sv_as_u64_f64 (sv_f64_t x)
80 {
81 return svreinterpret_u64_f64 (x);
82 }
83
84 static inline sv_f64_t
sv_as_f64_u64(sv_u64_t x)85 sv_as_f64_u64 (sv_u64_t x)
86 {
87 return svreinterpret_f64_u64 (x);
88 }
89
90 static inline sv_f64_t
sv_to_f64_s64_x(svbool_t pg,sv_s64_t s)91 sv_to_f64_s64_x (svbool_t pg, sv_s64_t s)
92 {
93 return svcvt_f64_x (pg, s);
94 }
95
96 static inline sv_f64_t
sv_call_f64(f64_t (* f)(f64_t),sv_f64_t x,sv_f64_t y,svbool_t cmp)97 sv_call_f64 (f64_t (*f) (f64_t), sv_f64_t x, sv_f64_t y, svbool_t cmp)
98 {
99 svbool_t p = svpfirst (cmp, svpfalse ());
100 while (svptest_any (cmp, p))
101 {
102 f64_t elem = svclastb_n_f64 (p, 0, x);
103 elem = (*f) (elem);
104 sv_f64_t y2 = svdup_n_f64 (elem);
105 y = svsel_f64 (p, y2, y);
106 p = svpnext_b64 (cmp, p);
107 }
108 return y;
109 }
110
111 static inline sv_f64_t
sv_call2_f64(f64_t (* f)(f64_t,f64_t),sv_f64_t x1,sv_f64_t x2,sv_f64_t y,svbool_t cmp)112 sv_call2_f64 (f64_t (*f) (f64_t, f64_t), sv_f64_t x1, sv_f64_t x2, sv_f64_t y,
113 svbool_t cmp)
114 {
115 svbool_t p = svpfirst (cmp, svpfalse ());
116 while (svptest_any (cmp, p))
117 {
118 f64_t elem1 = svclastb_n_f64 (p, 0, x1);
119 f64_t elem2 = svclastb_n_f64 (p, 0, x2);
120 f64_t ret = (*f) (elem1, elem2);
121 sv_f64_t y2 = svdup_n_f64 (ret);
122 y = svsel_f64 (p, y2, y);
123 p = svpnext_b64 (cmp, p);
124 }
125 return y;
126 }
127
128 /* Load array of uint64_t into svuint64_t. */
129 static inline sv_u64_t
sv_lookup_u64_x(svbool_t pg,const u64_t * tab,sv_u64_t idx)130 sv_lookup_u64_x (svbool_t pg, const u64_t *tab, sv_u64_t idx)
131 {
132 return svld1_gather_u64index_u64 (pg, tab, idx);
133 }
134
135 /* Load array of double into svfloat64_t. */
136 static inline sv_f64_t
sv_lookup_f64_x(svbool_t pg,const f64_t * tab,sv_u64_t idx)137 sv_lookup_f64_x (svbool_t pg, const f64_t *tab, sv_u64_t idx)
138 {
139 return svld1_gather_u64index_f64 (pg, tab, idx);
140 }
141
142 static inline sv_u64_t
sv_mod_n_u64_x(svbool_t pg,sv_u64_t x,u64_t y)143 sv_mod_n_u64_x (svbool_t pg, sv_u64_t x, u64_t y)
144 {
145 sv_u64_t q = svdiv_n_u64_x (pg, x, y);
146 return svmls_n_u64_x (pg, x, q, y);
147 }
148
149 /* Single precision. */
150 static inline sv_s32_t
sv_s32(s32_t x)151 sv_s32 (s32_t x)
152 {
153 return svdup_n_s32 (x);
154 }
155
156 static inline sv_u32_t
sv_u32(u32_t x)157 sv_u32 (u32_t x)
158 {
159 return svdup_n_u32 (x);
160 }
161
162 static inline sv_f32_t
sv_f32(f32_t x)163 sv_f32 (f32_t x)
164 {
165 return svdup_n_f32 (x);
166 }
167
168 static inline sv_f32_t
sv_fma_f32_x(svbool_t pg,sv_f32_t x,sv_f32_t y,sv_f32_t z)169 sv_fma_f32_x (svbool_t pg, sv_f32_t x, sv_f32_t y, sv_f32_t z)
170 {
171 return svmla_f32_x (pg, z, x, y);
172 }
173
174 /* res = z + x * y with x scalar. */
175 static inline sv_f32_t
sv_fma_n_f32_x(svbool_t pg,f32_t x,sv_f32_t y,sv_f32_t z)176 sv_fma_n_f32_x (svbool_t pg, f32_t x, sv_f32_t y, sv_f32_t z)
177 {
178 return svmla_n_f32_x (pg, z, y, x);
179 }
180
181 static inline sv_u32_t
sv_as_u32_f32(sv_f32_t x)182 sv_as_u32_f32 (sv_f32_t x)
183 {
184 return svreinterpret_u32_f32 (x);
185 }
186
187 static inline sv_f32_t
sv_as_f32_u32(sv_u32_t x)188 sv_as_f32_u32 (sv_u32_t x)
189 {
190 return svreinterpret_f32_u32 (x);
191 }
192
193 static inline sv_s32_t
sv_as_s32_u32(sv_u32_t x)194 sv_as_s32_u32 (sv_u32_t x)
195 {
196 return svreinterpret_s32_u32 (x);
197 }
198
199 static inline sv_f32_t
sv_to_f32_s32_x(svbool_t pg,sv_s32_t s)200 sv_to_f32_s32_x (svbool_t pg, sv_s32_t s)
201 {
202 return svcvt_f32_x (pg, s);
203 }
204
205 static inline sv_s32_t
sv_to_s32_f32_x(svbool_t pg,sv_f32_t x)206 sv_to_s32_f32_x (svbool_t pg, sv_f32_t x)
207 {
208 return svcvt_s32_f32_x (pg, x);
209 }
210
211 static inline sv_f32_t
sv_call_f32(f32_t (* f)(f32_t),sv_f32_t x,sv_f32_t y,svbool_t cmp)212 sv_call_f32 (f32_t (*f) (f32_t), sv_f32_t x, sv_f32_t y, svbool_t cmp)
213 {
214 svbool_t p = svpfirst (cmp, svpfalse ());
215 while (svptest_any (cmp, p))
216 {
217 f32_t elem = svclastb_n_f32 (p, 0, x);
218 elem = (*f) (elem);
219 sv_f32_t y2 = svdup_n_f32 (elem);
220 y = svsel_f32 (p, y2, y);
221 p = svpnext_b32 (cmp, p);
222 }
223 return y;
224 }
225
226 static inline sv_f32_t
sv_call2_f32(f32_t (* f)(f32_t,f32_t),sv_f32_t x1,sv_f32_t x2,sv_f32_t y,svbool_t cmp)227 sv_call2_f32 (f32_t (*f) (f32_t, f32_t), sv_f32_t x1, sv_f32_t x2, sv_f32_t y,
228 svbool_t cmp)
229 {
230 svbool_t p = svpfirst (cmp, svpfalse ());
231 while (svptest_any (cmp, p))
232 {
233 f32_t elem1 = svclastb_n_f32 (p, 0, x1);
234 f32_t elem2 = svclastb_n_f32 (p, 0, x2);
235 f32_t ret = (*f) (elem1, elem2);
236 sv_f32_t y2 = svdup_n_f32 (ret);
237 y = svsel_f32 (p, y2, y);
238 p = svpnext_b32 (cmp, p);
239 }
240 return y;
241 }
242
243 #endif
244 #endif
245 #endif
246