1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #include <stdbool.h>
12 #include <stddef.h>
13 #include <stdint.h>
14 #include <assert.h>
15
16 #ifdef _MSC_VER
17 #include <intrin.h>
18 #include <stdlib.h> // For _rotl.
19 #endif
20
21 #include <xnnpack/common.h>
22
23
24 // stdlib.h from Windows 10 SDK defines min & max macros.
25 // Undefine them before defining the corresponding functions.
26 #ifdef min
27 #undef min
28 #endif
29 #ifdef max
30 #undef max
31 #endif
32
33
min(size_t a,size_t b)34 XNN_INLINE static size_t min(size_t a, size_t b) {
35 return XNN_UNPREDICTABLE(b < a) ? b : a;
36 }
37
max(size_t a,size_t b)38 XNN_INLINE static size_t max(size_t a, size_t b) {
39 return XNN_UNPREDICTABLE(b < a) ? a : b;
40 }
41
doz(size_t a,size_t b)42 XNN_INLINE static size_t doz(size_t a, size_t b) {
43 return XNN_UNPREDICTABLE(b < a) ? a - b : 0;
44 }
45
divide_round_up(size_t n,size_t q)46 XNN_INLINE static size_t divide_round_up(size_t n, size_t q) {
47 return XNN_UNPREDICTABLE(n % q == 0) ? n / q : n / q + 1;
48 }
49
round_up(size_t n,size_t q)50 XNN_INLINE static size_t round_up(size_t n, size_t q) {
51 return divide_round_up(n, q) * q;
52 }
53
is_po2(size_t n)54 XNN_INLINE static bool is_po2(size_t n) {
55 return (n != 0) && ((n & (n - 1)) == 0);
56 }
round_down_po2(size_t n,size_t q)57 XNN_INLINE static size_t round_down_po2(size_t n, size_t q) {
58 assert(is_po2(q));
59 return n & -q;
60 }
61
round_up_po2(size_t n,size_t q)62 XNN_INLINE static size_t round_up_po2(size_t n, size_t q) {
63 return round_down_po2(n + q - 1, q);
64 }
65
subtract_modulo(size_t a,size_t b,size_t m)66 XNN_INLINE static size_t subtract_modulo(size_t a, size_t b, size_t m) {
67 assert(a < m);
68 assert(b < m);
69 return XNN_UNPREDICTABLE(a >= b) ? a - b : a - b + m;
70 }
71
uint32_as_float(uint32_t i)72 XNN_INLINE static float uint32_as_float(uint32_t i) {
73 union {
74 uint32_t as_uint32;
75 float as_float;
76 } bits = { i };
77 return bits.as_float;
78 }
79
float_as_uint32(float f)80 XNN_INLINE static uint32_t float_as_uint32(float f) {
81 union {
82 float as_float;
83 uint32_t as_uint32;
84 } bits = { f };
85 return bits.as_uint32;
86 }
87
uint64_as_double(uint64_t i)88 XNN_INLINE static double uint64_as_double(uint64_t i) {
89 union {
90 uint64_t as_uint64;
91 double as_double;
92 } bits = { i };
93 return bits.as_double;
94 }
95
double_as_uint64(double f)96 XNN_INLINE static uint64_t double_as_uint64(double f) {
97 union {
98 double as_double;
99 uint64_t as_uint64;
100 } bits = { f };
101 return bits.as_uint64;
102 }
103
math_abs_s32(int32_t n)104 XNN_INLINE static uint32_t math_abs_s32(int32_t n) {
105 #if defined(_MSC_VER)
106 return (uint32_t) abs((int) n);
107 #else
108 return XNN_UNPREDICTABLE(n >= 0) ? (uint32_t) n : -(uint32_t) n;
109 #endif
110 }
111
math_min_s32(int32_t a,int32_t b)112 XNN_INLINE static int32_t math_min_s32(int32_t a, int32_t b) {
113 return XNN_UNPREDICTABLE(a < b) ? a : b;
114 }
115
math_max_s32(int32_t a,int32_t b)116 XNN_INLINE static int32_t math_max_s32(int32_t a, int32_t b) {
117 return XNN_UNPREDICTABLE(a > b) ? a : b;
118 }
119
math_min_u32(uint32_t a,uint32_t b)120 XNN_INLINE static uint32_t math_min_u32(uint32_t a, uint32_t b) {
121 return XNN_UNPREDICTABLE(a < b) ? a : b;
122 }
123
math_max_u32(uint32_t a,uint32_t b)124 XNN_INLINE static uint32_t math_max_u32(uint32_t a, uint32_t b) {
125 return XNN_UNPREDICTABLE(a > b) ? a : b;
126 }
127
math_doz_u32(uint32_t a,uint32_t b)128 XNN_INLINE static uint32_t math_doz_u32(uint32_t a, uint32_t b) {
129 return XNN_UNPREDICTABLE(a > b) ? a - b : 0;
130 }
131
// Multiplies two signed 32-bit integers, producing the full (widening)
// 64-bit product.
XNN_INLINE static int64_t math_mulext_s32(int32_t a, int32_t b) {
#if defined(_MSC_VER) && defined(_M_IX86)
  // 32-bit MSVC: __emul produces a 64-bit product from 32-bit operands,
  // avoiding a full 64x64 multiply.
  return (int64_t) __emul((int) a, (int) b);
#else
  return (int64_t) a * (int64_t) b;
#endif
}
139
// Multiplies two unsigned 32-bit integers, producing the full (widening)
// 64-bit product.
XNN_INLINE static uint64_t math_mulext_u32(uint32_t a, uint32_t b) {
#if defined(_MSC_VER) && defined(_M_IX86)
  // 32-bit MSVC: __emulu produces a 64-bit product from 32-bit operands,
  // avoiding a full 64x64 multiply.
  return (uint64_t) __emulu((unsigned int) a, (unsigned int) b);
#else
  return (uint64_t) a * (uint64_t) b;
#endif
}
147
// Computes x * y + acc.
//
// Uses a fused multiply-add when the compiler advertises a fast one
// (GCC-compatible compilers define __FP_FAST_FMAF; clang targeting
// RISC-V has native FMA); otherwise a separate multiply and add.
// NOTE(review): the fused path rounds once and the unfused path rounds
// twice, so results may differ in the last bit across build
// configurations.
XNN_INLINE static float math_muladd_f32(float x, float y, float acc) {
#if defined(__GNUC__) && defined(__FP_FAST_FMAF)
  return __builtin_fmaf(x, y, acc);
#elif defined(__clang__) && defined(__riscv)
  return __builtin_fmaf(x, y, acc);
#else
  return x * y + acc;
#endif
}
157
// Returns the minimum of two floats.
//
// ARMv8+ and clang/RISC-V targets use __builtin_fminf, which maps to a
// single hardware min instruction; other targets use compare-and-select.
// NOTE(review): the two paths disagree when an argument is NaN -- fminf
// returns the non-NaN operand, while the fallback returns 'a' whenever
// b < a is false -- confirm callers do not rely on NaN propagation.
XNN_INLINE static float math_min_f32(float a, float b) {
#if defined(__GNUC__) && defined(__ARM_ARCH) && (__ARM_ARCH >= 8)
  return __builtin_fminf(a, b);
#elif defined(__clang__) && defined(__riscv)
  return __builtin_fminf(a, b);
#else
  return XNN_UNPREDICTABLE(b < a) ? b : a;
#endif
}
167
// Returns the maximum of two floats.
//
// ARMv8+ and clang/RISC-V targets use __builtin_fmaxf, which maps to a
// single hardware max instruction; other targets use compare-and-select.
// NOTE(review): the two paths disagree when an argument is NaN -- fmaxf
// returns the non-NaN operand, while the fallback returns 'a' whenever
// b < a is true, 'b' otherwise -- confirm callers do not rely on NaN
// propagation.
XNN_INLINE static float math_max_f32(float a, float b) {
#if defined(__GNUC__) && defined(__ARM_ARCH) && (__ARM_ARCH >= 8)
  return __builtin_fmaxf(a, b);
#elif defined(__clang__) && defined(__riscv)
  return __builtin_fmaxf(a, b);
#else
  return XNN_UNPREDICTABLE(b < a) ? a : b;
#endif
}
177
// Returns the minimum of two doubles.
//
// ARMv8+ and clang/RISC-V targets use __builtin_fmin (single hardware
// instruction); other targets use compare-and-select.
// NOTE(review): as with the float variants, the builtin and fallback
// paths differ for NaN inputs.
XNN_INLINE static double math_min_f64(double a, double b) {
#if defined(__GNUC__) && defined(__ARM_ARCH) && (__ARM_ARCH >= 8)
  return __builtin_fmin(a, b);
#elif defined(__clang__) && defined(__riscv)
  return __builtin_fmin(a, b);
#else
  return XNN_UNPREDICTABLE(b < a) ? b : a;
#endif
}
187
// Returns the maximum of two doubles.
//
// ARMv8+ and clang/RISC-V targets use __builtin_fmax (single hardware
// instruction); other targets use compare-and-select.
// NOTE(review): as with the float variants, the builtin and fallback
// paths differ for NaN inputs.
XNN_INLINE static double math_max_f64(double a, double b) {
#if defined(__GNUC__) && defined(__ARM_ARCH) && (__ARM_ARCH >= 8)
  return __builtin_fmax(a, b);
#elif defined(__clang__) && defined(__riscv)
  return __builtin_fmax(a, b);
#else
  return XNN_UNPREDICTABLE(b < a) ? a : b;
#endif
}
197
math_nonsign_mask_f32()198 XNN_INLINE static float math_nonsign_mask_f32() {
199 #if defined(__INTEL_COMPILER)
200 // Surprisingly, Intel compiler ignores __builtin_nanf payload
201 return _castu32_f32(0x7FFFFFFF);
202 #elif defined(__GNUC__)
203 return __builtin_nanf("0x7FFFFF");
204 #else
205 union {
206 uint32_t as_word;
207 float as_float;
208 } f;
209 f.as_word = 0x7FFFFFFF;
210 return f.as_float;
211 #endif
212 }
213
214
// XNN_IGNORE_SHIFT_BASE_UB expands to an attribute suppressing UBSan's
// "shift-base" check on the annotated function; it is applied below to
// the arithmetic right-shift helpers, which shift negative values.
// When the compiler supports UBSan but not the no_sanitize attribute
// (GCC 4.9 - 7.x), the attribute expands to nothing and
// XNN_USE_SHIFT_BASE_UB_WORKAROUND selects a sanitizer-clean
// formulation of the shift instead.
#if defined(__clang__)
  #if __clang_major__ == 3 && __clang_minor__ >= 7 || __clang_major__ > 3
    // Clang 3.7+ supports no_sanitize("shift-base").
    #define XNN_IGNORE_SHIFT_BASE_UB __attribute__((__no_sanitize__("shift-base")))
  #else
    #define XNN_IGNORE_SHIFT_BASE_UB
  #endif
#elif defined(__GNUC__)
  #if __GNUC__ >= 8
    #define XNN_IGNORE_SHIFT_BASE_UB __attribute__((__no_sanitize__("shift-base")))
  #elif __GNUC__ == 4 && __GNUC_MINOR__ >= 9 || __GNUC__ > 4
    // 4.9 <= gcc < 8 support ubsan, but doesn't support no_sanitize attribute
    #define XNN_IGNORE_SHIFT_BASE_UB
    #ifndef XNN_USE_SHIFT_BASE_UB_WORKAROUND
      #define XNN_USE_SHIFT_BASE_UB_WORKAROUND 1
    #endif
  #else
    #define XNN_IGNORE_SHIFT_BASE_UB
  #endif
#else
  #define XNN_IGNORE_SHIFT_BASE_UB
#endif
236
// Arithmetic (sign-extending) right shift of a signed 32-bit value by n
// bits. Right-shifting a negative value is implementation-defined in C
// and flagged by UBSan's shift-base check, hence the attribute and the
// workaround paths below.
XNN_IGNORE_SHIFT_BASE_UB
XNN_INLINE static int32_t math_asr_s32(int32_t x, uint32_t n) {
#ifdef XNN_USE_SHIFT_BASE_UB_WORKAROUND
  #if XNN_ARCH_X86_64 || XNN_ARCH_ARM64
    // 64-bit targets: sign-extend to 64 bits, then a logical shift of the
    // unsigned representation reproduces the arithmetic shift result in
    // the low 32 bits without a sanitizer-visible negative shift.
    return (int32_t) ((uint64_t) (int64_t) x >> n);
  #else
    // ~(~x >> n) equals the arithmetic shift for negative x; non-negative
    // x takes the plain shift.
    return x >= 0 ? x >> n : ~(~x >> n);
  #endif
#else
  return x >> n;
#endif
}
249
// Arithmetic (sign-extending) right shift of a signed 64-bit value by n
// bits. Same rationale as math_asr_s32: right-shifting a negative value
// is implementation-defined in C, so a sanitizer-clean formulation is
// used when the no_sanitize attribute is unavailable.
XNN_IGNORE_SHIFT_BASE_UB
XNN_INLINE static int64_t math_asr_s64(int64_t x, uint32_t n) {
#ifdef XNN_USE_SHIFT_BASE_UB_WORKAROUND
  // ~(~x >> n) equals the arithmetic shift for negative x.
  return x >= 0 ? x >> n : ~(~x >> n);
#else
  return x >> n;
#endif
}
258
// Returns the number of leading zero bits in x, or 32 when x == 0.
XNN_INLINE static uint32_t math_clz_u32(uint32_t x) {
#if defined(_MSC_VER) && !defined(__clang__)
  unsigned long index;
  // _BitScanReverse returns 0 iff x == 0; otherwise index holds the
  // position of the highest set bit, and index ^ 31 == 31 - index is the
  // leading-zero count.
  if XNN_UNPREDICTABLE(_BitScanReverse(&index, (unsigned long) x) != 0) {
    return (uint32_t) index ^ 31;
  } else {
    return 32;
  }
#else
  // __builtin_clz(0) is undefined, so zero must be special-cased.
  if XNN_UNPREDICTABLE(x == 0) {
    return 32;
  } else {
    return (uint32_t) __builtin_clz((unsigned int) x);
  }
#endif
}
275
// Returns the number of leading zero bits in x; x must be nonzero
// (asserted). Skipping the zero check lets this compile down to a
// single bit-scan/CLZ instruction.
XNN_INLINE static uint32_t math_clz_nonzero_u32(uint32_t x) {
  assert(x != 0);
#if defined(_MSC_VER) && !defined(__clang__)
  unsigned long index;
  _BitScanReverse(&index, (unsigned long) x);
  // index is the position of the highest set bit; XOR with 31 gives
  // 31 - index, the leading-zero count.
  return (uint32_t) index ^ 31;
#else
  return (uint32_t) __builtin_clz((unsigned int) x);
#endif
}
286
// Returns the number of trailing zero bits in x.
// NOTE(review): the result is unspecified for x == 0 on both paths
// (__builtin_ctz(0) is undefined behavior, and _BitScanForward leaves
// index unwritten when no bit is set) -- callers must pass a nonzero
// value; consider adding assert(x != 0) as in math_clz_nonzero_u32.
XNN_INLINE static uint32_t math_ctz_u32(uint32_t x) {
#if defined(_MSC_VER) && !defined(__clang__)
  unsigned long index;
  _BitScanForward(&index, (unsigned long) x);
  return (uint32_t) index;
#else
  return (uint32_t) __builtin_ctz((unsigned int) x);
#endif
}
296
math_rotl_u32(uint32_t x,int8_t r)297 XNN_INLINE static uint32_t math_rotl_u32(uint32_t x, int8_t r)
298 {
299 #if XNN_COMPILER_MSVC
300 return _rotl((unsigned int) x, (int) r);
301 #else
302 return (x << r) | (x >> (32 - r));
303 #endif
304 }
305
306 #ifndef __cplusplus
math_cvt_sat_u32_f64(double x)307 XNN_INLINE static uint32_t math_cvt_sat_u32_f64(double x) {
308 #if defined(__GNUC__) && defined(__arm__)
309 float i; // float instead of uint32_t because vcvt.u32.f64 writes to an S register
310 __asm__ ("vcvt.u32.f64 %[i], %P[x]"
311 : [i] "=w" (i)
312 : [x] "w" (x));
313 return float_as_uint32(i);
314 #elif defined(__GNUC__) && defined(__aarch64__)
315 uint32_t i;
316 __asm__ ("fcvtnu %w[i], %d[x]"
317 : [i] "=r" (i)
318 : [x] "w" (x));
319 return i;
320 #elif defined(__GNUC__) && defined(__riscv)
321 uint32_t i;
322 __asm__ ("fcvt.wu.d %[i], %[x], rne"
323 : [i] "=r" (i)
324 : [x] "f" (x));
325 return i;
326 #elif defined(__clang__) && defined(__wasm__) && defined(__wasm_nontrapping_fptoint__)
327 return __builtin_wasm_trunc_saturate_u_i32_f64(rint(x));
328 #else
329 x = math_max_f64(x, 0.0);
330 x = math_min_f64(x, 4294967295.0);
331 return (uint32_t) double_as_uint64(x + 0x1.0p+52);
332 #endif
333 }
334 #endif
335