1 use core;
2 use core::mem;
3 use traits::checked_pow;
4 use traits::PrimInt;
5 use Integer;
6
7 /// Provides methods to compute an integer's square root, cube root,
8 /// and arbitrary `n`th root.
9 pub trait Roots: Integer {
10 /// Returns the truncated principal `n`th root of an integer
11 /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }`
12 ///
13 /// This is solving for `r` in `rⁿ = x`, rounding toward zero.
14 /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`.
15 /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`.
16 ///
17 /// # Panics
18 ///
19 /// Panics if `n` is zero:
20 ///
21 /// ```should_panic
22 /// # use num_integer::Roots;
23 /// println!("can't compute ⁰√x : {}", 123.nth_root(0));
24 /// ```
25 ///
26 /// or if `n` is even and `self` is negative:
27 ///
28 /// ```should_panic
29 /// # use num_integer::Roots;
30 /// println!("no imaginary numbers... {}", (-1).nth_root(10));
31 /// ```
32 ///
33 /// # Examples
34 ///
35 /// ```
36 /// use num_integer::Roots;
37 ///
38 /// let x: i32 = 12345;
39 /// assert_eq!(x.nth_root(1), x);
40 /// assert_eq!(x.nth_root(2), x.sqrt());
41 /// assert_eq!(x.nth_root(3), x.cbrt());
42 /// assert_eq!(x.nth_root(4), 10);
43 /// assert_eq!(x.nth_root(13), 2);
44 /// assert_eq!(x.nth_root(14), 1);
45 /// assert_eq!(x.nth_root(std::u32::MAX), 1);
46 ///
47 /// assert_eq!(std::i32::MAX.nth_root(30), 2);
48 /// assert_eq!(std::i32::MAX.nth_root(31), 1);
49 /// assert_eq!(std::i32::MIN.nth_root(31), -2);
50 /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1);
51 ///
52 /// assert_eq!(std::u32::MAX.nth_root(31), 2);
53 /// assert_eq!(std::u32::MAX.nth_root(32), 1);
54 /// ```
nth_root(&self, n: u32) -> Self55 fn nth_root(&self, n: u32) -> Self;
56
57 /// Returns the truncated principal square root of an integer -- `⌊√x⌋`
58 ///
59 /// This is solving for `r` in `r² = x`, rounding toward zero.
60 /// The result will satisfy `r² ≤ x < (r+1)²`.
61 ///
62 /// # Panics
63 ///
64 /// Panics if `self` is less than zero:
65 ///
66 /// ```should_panic
67 /// # use num_integer::Roots;
68 /// println!("no imaginary numbers... {}", (-1).sqrt());
69 /// ```
70 ///
71 /// # Examples
72 ///
73 /// ```
74 /// use num_integer::Roots;
75 ///
76 /// let x: i32 = 12345;
77 /// assert_eq!((x * x).sqrt(), x);
78 /// assert_eq!((x * x + 1).sqrt(), x);
79 /// assert_eq!((x * x - 1).sqrt(), x - 1);
80 /// ```
81 #[inline]
sqrt(&self) -> Self82 fn sqrt(&self) -> Self {
83 self.nth_root(2)
84 }
85
86 /// Returns the truncated principal cube root of an integer --
87 /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }`
88 ///
89 /// This is solving for `r` in `r³ = x`, rounding toward zero.
90 /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`.
91 /// If `x` is negative, then `(r-1)³ < x ≤ r³`.
92 ///
93 /// # Examples
94 ///
95 /// ```
96 /// use num_integer::Roots;
97 ///
98 /// let x: i32 = 1234;
99 /// assert_eq!((x * x * x).cbrt(), x);
100 /// assert_eq!((x * x * x + 1).cbrt(), x);
101 /// assert_eq!((x * x * x - 1).cbrt(), x - 1);
102 ///
103 /// assert_eq!((-(x * x * x)).cbrt(), -x);
104 /// assert_eq!((-(x * x * x + 1)).cbrt(), -x);
105 /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1));
106 /// ```
107 #[inline]
cbrt(&self) -> Self108 fn cbrt(&self) -> Self {
109 self.nth_root(3)
110 }
111 }
112
113 /// Returns the truncated principal square root of an integer --
114 /// see [Roots::sqrt](trait.Roots.html#method.sqrt).
115 #[inline]
sqrt<T: Roots>(x: T) -> T116 pub fn sqrt<T: Roots>(x: T) -> T {
117 x.sqrt()
118 }
119
120 /// Returns the truncated principal cube root of an integer --
121 /// see [Roots::cbrt](trait.Roots.html#method.cbrt).
122 #[inline]
cbrt<T: Roots>(x: T) -> T123 pub fn cbrt<T: Roots>(x: T) -> T {
124 x.cbrt()
125 }
126
127 /// Returns the truncated principal `n`th root of an integer --
128 /// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root).
129 #[inline]
nth_root<T: Roots>(x: T, n: u32) -> T130 pub fn nth_root<T: Roots>(x: T, n: u32) -> T {
131 x.nth_root(n)
132 }
133
// Implements `Roots` for the signed integer type `$T` by delegating to the
// unsigned type `$U` of the same width.
//
// Non-negative inputs are cast straight to `$U`. A negative input is mapped to
// its magnitude with `wrapping_neg` (as `$U`, so even `$T::MIN`, whose
// magnitude does not fit in `$T`, is represented exactly). The unsigned root of
// that magnitude is then negated.
macro_rules! signed_roots {
    ($T:ty, $U:ty) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                if *self >= 0 {
                    (*self as $U).nth_root(n) as Self
                } else {
                    assert!(n.is_odd(), "even roots of a negative are imaginary");
                    // Negate with `wrapping_neg` rather than `-`: for `n == 1` the
                    // root of `MIN` is `MIN` itself, and `-MIN` overflows (a panic
                    // in debug builds). For every other odd `n` the root's magnitude
                    // is far below `MAX`, so this is ordinary negation.
                    ((self.wrapping_neg() as $U).nth_root(n) as Self).wrapping_neg()
                }
            }

            #[inline]
            fn sqrt(&self) -> Self {
                assert!(*self >= 0, "the square root of a negative is imaginary");
                (*self as $U).sqrt() as Self
            }

            #[inline]
            fn cbrt(&self) -> Self {
                if *self >= 0 {
                    (*self as $U).cbrt() as Self
                } else {
                    // The cube root of any negative value is well inside the range
                    // of `$T`, so plain negation cannot overflow here.
                    -((self.wrapping_neg() as $U).cbrt() as Self)
                }
            }
        }
    };
}
164
// Each signed primitive delegates to the unsigned primitive of the same width.
signed_roots!(i8, u8);
signed_roots!(i16, u16);
signed_roots!(i32, u32);
signed_roots!(i64, u64);
// NOTE(review): `has_i128` is presumably set by the build script when the
// compiler supports 128-bit integers -- confirm.
#[cfg(has_i128)]
signed_roots!(i128, u128);
signed_roots!(isize, usize);
172
173 #[inline]
fixpoint<T, F>(mut x: T, f: F) -> T where T: Integer + Copy, F: Fn(T) -> T,174 fn fixpoint<T, F>(mut x: T, f: F) -> T
175 where
176 T: Integer + Copy,
177 F: Fn(T) -> T,
178 {
179 let mut xn = f(x);
180 while x < xn {
181 x = xn;
182 xn = f(x);
183 }
184 while x > xn {
185 x = xn;
186 xn = f(x);
187 }
188 x
189 }
190
/// Width of `T` in bits, derived from its size in bytes.
#[inline]
fn bits<T>() -> u32 {
    let bytes = mem::size_of::<T>() as u32;
    bytes * 8
}
195
196 #[inline]
log2<T: PrimInt>(x: T) -> u32197 fn log2<T: PrimInt>(x: T) -> u32 {
198 debug_assert!(x > T::zero());
199 bits::<T>() - 1 - x.leading_zeros()
200 }
201
// Implements `Roots` for the unsigned primitive integer type `$T`.
//
// Each method hands its argument to an inner `fn go` that operates on `$T`.
// Tests such as `bits::<$T>() > 64` compare the fixed width of `$T`, so for any
// one expansion only one side of each such test is ever live.
macro_rules! unsigned_roots {
    ($T:ident) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                fn go(a: $T, n: u32) -> $T {
                    // Specialize small roots
                    match n {
                        0 => panic!("can't find a root of degree 0!"),
                        1 => return a,
                        2 => return a.sqrt(),
                        3 => return a.cbrt(),
                        _ => (),
                    }

                    // The root of values less than 2ⁿ can only be 0 or 1.
                    // (Testing `bits <= n` first also keeps `1 << n` from overflowing.)
                    if bits::<$T>() <= n || a < (1 << n) {
                        return (a > 0) as $T;
                    }

                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `nth_root` until it's small enough.
                        // With r = ⌊ⁿ√(a >> n)⌋ we have 2r ≤ ⌊ⁿ√a⌋ ≤ 2r + 1, so the
                        // answer is either `lo` (= 2r) or `hi` (= 2r + 1).
                        return if a <= core::u64::MAX as $T {
                            (a as u64).nth_root(n) as $T
                        } else {
                            let lo = (a >> n).nth_root(n) << 1;
                            let hi = lo + 1;
                            // 128-bit `checked_mul` also involves division, but we can't always
                            // compute `hiⁿ` without risking overflow. Try to avoid it though...
                            // `hi ≤ 2^k` where `k` is the exponent of `hi.next_power_of_two()`,
                            // so `hiⁿ ≤ 2^(k·n)`; only when `k·n` reaches the bit width can it
                            // overflow, and only then is the checked power used.
                            if hi.next_power_of_two().trailing_zeros() * n >= bits::<$T>() {
                                match checked_pow(hi, n as usize) {
                                    Some(x) if x <= a => hi,
                                    _ => lo,
                                }
                            } else {
                                if hi.pow(n) <= a {
                                    hi
                                } else {
                                    lo
                                }
                            }
                        };
                    }

                    // Starting estimate for the Newton iteration below.
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        // for smaller inputs, `f64` doesn't justify its cost.
                        if bits::<$T>() <= 32 || x <= core::u32::MAX as $T {
                            1 << ((log2(x) + n - 1) / n)
                        } else {
                            ((x as f64).ln() / f64::from(n)).exp() as $T
                        }
                    }

                    // Without `std` the `f64` `ln`/`exp` methods are unavailable, so
                    // always use the power-of-two estimate 2^⌈⌊log₂ x⌋ / n⌉.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T, n: u32) -> $T {
                        1 << ((log2(x) + n - 1) / n)
                    }

                    // https://en.wikipedia.org/wiki/Nth_root_algorithm
                    // Integer Newton step: x' = ((n-1)·x + a / xⁿ⁻¹) / n.
                    // If xⁿ⁻¹ overflows `$T` it certainly exceeds `a`, so the
                    // truncated quotient `a / xⁿ⁻¹` is taken as 0.
                    let n1 = n - 1;
                    let next = |x: $T| {
                        let y = match checked_pow(x, n1 as usize) {
                            Some(ax) => a / ax,
                            None => 0,
                        };
                        (y + x * n1 as $T) / n as $T
                    };
                    fixpoint(guess(a, n), next)
                }
                go(*self, n)
            }

            #[inline]
            fn sqrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `sqrt` until it's small enough.
                        // With r = ⌊√(a >> 2)⌋ the root is 2r or 2r + 1.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).sqrt() as $T
                        } else {
                            let lo = (a >> 2u32).sqrt() << 1;
                            let hi = lo + 1;
                            if hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    // √0 = 0, and √1 through √3 truncate to 1.
                    if a < 4 {
                        return (a > 0) as $T;
                    }

                    // Starting estimate: the `f64` square root when `std` is available...
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).sqrt() as $T
                    }

                    // ...otherwise the power of two 2^⌈⌊log₂ x⌋ / 2⌉.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 1) / 2)
                    }

                    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
                    // Integer Babylonian step: x' = (a / x + x) / 2.
                    let next = |x: $T| (a / x + x) >> 1;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }

            #[inline]
            fn cbrt(&self) -> Self {
                fn go(a: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so do a bitwise `cbrt` until it's small enough.
                        // With r = ⌊∛(a >> 3)⌋ the root is 2r or 2r + 1.
                        return if a <= core::u64::MAX as $T {
                            (a as u64).cbrt() as $T
                        } else {
                            let lo = (a >> 3u32).cbrt() << 1;
                            let hi = lo + 1;
                            if hi * hi * hi <= a {
                                hi
                            } else {
                                lo
                            }
                        };
                    }

                    if bits::<$T>() <= 32 {
                        // Implementation based on Hacker's Delight `icbrt2`
                        // Division-free: the root is built one bit at a time from the
                        // top, consuming three bits of the input per step. `y` is the
                        // partial root, `y2` tracks `y²`, and `b = 3(y² + y) + 1`
                        // equals `(y+1)³ - y³`, the cost of setting the next root bit.
                        let mut x = a;
                        let mut y2 = 0;
                        let mut y = 0;
                        let smax = bits::<$T>() / 3;
                        for s in (0..smax + 1).rev() {
                            let s = s * 3;
                            y2 *= 4;
                            y *= 2;
                            let b = 3 * (y2 + y) + 1;
                            if x >> s >= b {
                                x -= b << s;
                                y2 += 2 * y + 1;
                                y += 1;
                            }
                        }
                        return y;
                    }

                    // ∛0 = 0, and ∛1 through ∛7 truncate to 1.
                    if a < 8 {
                        return (a > 0) as $T;
                    }
                    // Inputs that fit in 32 bits reuse the division-free `u32` path.
                    if a <= core::u32::MAX as $T {
                        return (a as u32).cbrt() as $T;
                    }

                    // Starting estimate: the `f64` cube root when `std` is available...
                    #[cfg(feature = "std")]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        (x as f64).cbrt() as $T
                    }

                    // ...otherwise the power of two 2^⌈⌊log₂ x⌋ / 3⌉.
                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn guess(x: $T) -> $T {
                        1 << ((log2(x) + 2) / 3)
                    }

                    // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
                    // Integer Newton step: x' = (a / x² + 2x) / 3.
                    let next = |x: $T| (a / (x * x) + x * 2) / 3;
                    fixpoint(guess(a), next)
                }
                go(*self)
            }
        }
    };
}
384
// Unsigned primitive implementations; the signed implementations cast into
// these and delegate to them.
unsigned_roots!(u8);
unsigned_roots!(u16);
unsigned_roots!(u32);
unsigned_roots!(u64);
// NOTE(review): `has_i128` is presumably set by the build script when the
// compiler supports 128-bit integers -- confirm.
#[cfg(has_i128)]
unsigned_roots!(u128);
unsigned_roots!(usize);
392