• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::Integer;
2 use core::mem;
3 use num_traits::{checked_pow, PrimInt};
4 
5 /// Provides methods to compute an integer's square root, cube root,
6 /// and arbitrary `n`th root.
7 pub trait Roots: Integer {
8     /// Returns the truncated principal `n`th root of an integer
9     /// -- `if x >= 0 { ⌊ⁿ√x⌋ } else { ⌈ⁿ√x⌉ }`
10     ///
11     /// This is solving for `r` in `rⁿ = x`, rounding toward zero.
12     /// If `x` is positive, the result will satisfy `rⁿ ≤ x < (r+1)ⁿ`.
13     /// If `x` is negative and `n` is odd, then `(r-1)ⁿ < x ≤ rⁿ`.
14     ///
15     /// # Panics
16     ///
17     /// Panics if `n` is zero:
18     ///
19     /// ```should_panic
20     /// # use num_integer::Roots;
21     /// println!("can't compute ⁰√x : {}", 123.nth_root(0));
22     /// ```
23     ///
24     /// or if `n` is even and `self` is negative:
25     ///
26     /// ```should_panic
27     /// # use num_integer::Roots;
28     /// println!("no imaginary numbers... {}", (-1).nth_root(10));
29     /// ```
30     ///
31     /// # Examples
32     ///
33     /// ```
34     /// use num_integer::Roots;
35     ///
36     /// let x: i32 = 12345;
37     /// assert_eq!(x.nth_root(1), x);
38     /// assert_eq!(x.nth_root(2), x.sqrt());
39     /// assert_eq!(x.nth_root(3), x.cbrt());
40     /// assert_eq!(x.nth_root(4), 10);
41     /// assert_eq!(x.nth_root(13), 2);
42     /// assert_eq!(x.nth_root(14), 1);
43     /// assert_eq!(x.nth_root(std::u32::MAX), 1);
44     ///
45     /// assert_eq!(std::i32::MAX.nth_root(30), 2);
46     /// assert_eq!(std::i32::MAX.nth_root(31), 1);
47     /// assert_eq!(std::i32::MIN.nth_root(31), -2);
48     /// assert_eq!((std::i32::MIN + 1).nth_root(31), -1);
49     ///
50     /// assert_eq!(std::u32::MAX.nth_root(31), 2);
51     /// assert_eq!(std::u32::MAX.nth_root(32), 1);
52     /// ```
nth_root(&self, n: u32) -> Self53     fn nth_root(&self, n: u32) -> Self;
54 
55     /// Returns the truncated principal square root of an integer -- `⌊√x⌋`
56     ///
57     /// This is solving for `r` in `r² = x`, rounding toward zero.
58     /// The result will satisfy `r² ≤ x < (r+1)²`.
59     ///
60     /// # Panics
61     ///
62     /// Panics if `self` is less than zero:
63     ///
64     /// ```should_panic
65     /// # use num_integer::Roots;
66     /// println!("no imaginary numbers... {}", (-1).sqrt());
67     /// ```
68     ///
69     /// # Examples
70     ///
71     /// ```
72     /// use num_integer::Roots;
73     ///
74     /// let x: i32 = 12345;
75     /// assert_eq!((x * x).sqrt(), x);
76     /// assert_eq!((x * x + 1).sqrt(), x);
77     /// assert_eq!((x * x - 1).sqrt(), x - 1);
78     /// ```
79     #[inline]
sqrt(&self) -> Self80     fn sqrt(&self) -> Self {
81         self.nth_root(2)
82     }
83 
84     /// Returns the truncated principal cube root of an integer --
85     /// `if x >= 0 { ⌊∛x⌋ } else { ⌈∛x⌉ }`
86     ///
87     /// This is solving for `r` in `r³ = x`, rounding toward zero.
88     /// If `x` is positive, the result will satisfy `r³ ≤ x < (r+1)³`.
89     /// If `x` is negative, then `(r-1)³ < x ≤ r³`.
90     ///
91     /// # Examples
92     ///
93     /// ```
94     /// use num_integer::Roots;
95     ///
96     /// let x: i32 = 1234;
97     /// assert_eq!((x * x * x).cbrt(), x);
98     /// assert_eq!((x * x * x + 1).cbrt(), x);
99     /// assert_eq!((x * x * x - 1).cbrt(), x - 1);
100     ///
101     /// assert_eq!((-(x * x * x)).cbrt(), -x);
102     /// assert_eq!((-(x * x * x + 1)).cbrt(), -x);
103     /// assert_eq!((-(x * x * x - 1)).cbrt(), -(x - 1));
104     /// ```
105     #[inline]
cbrt(&self) -> Self106     fn cbrt(&self) -> Self {
107         self.nth_root(3)
108     }
109 }
110 
111 /// Returns the truncated principal square root of an integer --
112 /// see [Roots::sqrt](trait.Roots.html#method.sqrt).
113 #[inline]
sqrt<T: Roots>(x: T) -> T114 pub fn sqrt<T: Roots>(x: T) -> T {
115     x.sqrt()
116 }
117 
118 /// Returns the truncated principal cube root of an integer --
119 /// see [Roots::cbrt](trait.Roots.html#method.cbrt).
120 #[inline]
cbrt<T: Roots>(x: T) -> T121 pub fn cbrt<T: Roots>(x: T) -> T {
122     x.cbrt()
123 }
124 
125 /// Returns the truncated principal `n`th root of an integer --
126 /// see [Roots::nth_root](trait.Roots.html#tymethod.nth_root).
127 #[inline]
nth_root<T: Roots>(x: T, n: u32) -> T128 pub fn nth_root<T: Roots>(x: T, n: u32) -> T {
129     x.nth_root(n)
130 }
131 
/// Implements `Roots` for the signed integer type `$T` by delegating to the
/// unsigned type `$U` of the same width.
///
/// A negative input takes the root of its magnitude in `$U`, then negates
/// the result, which truncates toward zero.
macro_rules! signed_roots {
    ($T:ty, $U:ty) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                if *self >= 0 {
                    (*self as $U).nth_root(n) as Self
                } else {
                    assert!(n.is_odd(), "even roots of a negative are imaginary");
                    // `wrapping_neg` followed by the cast to the unsigned type gives
                    // the exact magnitude, even for `MIN`. The result is negated with
                    // `wrapping_neg` as well: for `n == 1` the root of `MIN` is `MIN`
                    // itself (its magnitude casts back to `MIN`), so a plain `-root`
                    // would overflow (a panic in debug builds), whereas the wrapping
                    // form yields the correct `MIN`. For every other case the root is
                    // a small non-negative value and `wrapping_neg` is exact negation.
                    let root = (self.wrapping_neg() as $U).nth_root(n) as Self;
                    root.wrapping_neg()
                }
            }

            #[inline]
            fn sqrt(&self) -> Self {
                assert!(*self >= 0, "the square root of a negative is imaginary");
                (*self as $U).sqrt() as Self
            }

            #[inline]
            fn cbrt(&self) -> Self {
                if *self >= 0 {
                    (*self as $U).cbrt() as Self
                } else {
                    // Same magnitude-then-negate scheme as `nth_root`. The cube root
                    // of a magnitude never reaches `MIN`, but `wrapping_neg` keeps
                    // the two paths uniform.
                    let root = (self.wrapping_neg() as $U).cbrt() as Self;
                    root.wrapping_neg()
                }
            }
        }
    };
}
162 
163 signed_roots!(i8, u8);
164 signed_roots!(i16, u16);
165 signed_roots!(i32, u32);
166 signed_roots!(i64, u64);
167 signed_roots!(i128, u128);
168 signed_roots!(isize, usize);
169 
170 #[inline]
fixpoint<T, F>(mut x: T, f: F) -> T where T: Integer + Copy, F: Fn(T) -> T,171 fn fixpoint<T, F>(mut x: T, f: F) -> T
172 where
173     T: Integer + Copy,
174     F: Fn(T) -> T,
175 {
176     let mut xn = f(x);
177     while x < xn {
178         x = xn;
179         xn = f(x);
180     }
181     while x > xn {
182         x = xn;
183         xn = f(x);
184     }
185     x
186 }
187 
/// Width of the type `T`, in bits.
#[inline]
fn bits<T>() -> u32 {
    let bytes = mem::size_of::<T>() as u32;
    bytes * 8
}
192 
193 #[inline]
log2<T: PrimInt>(x: T) -> u32194 fn log2<T: PrimInt>(x: T) -> u32 {
195     debug_assert!(x > T::zero());
196     bits::<T>() - 1 - x.leading_zeros()
197 }
198 
/// Implements `Roots` for the unsigned integer type `$T`.
///
/// Types wider than 64 bits shrink the input bit by bit until it fits in a
/// `u64`, because 128-bit division is slow. Otherwise, square and `n`th roots
/// (and cube roots of types wider than 32 bits) run an integer Newton-style
/// iteration through `fixpoint`, seeded from an initial estimate. Cube roots
/// of types up to 32 bits use a digit-by-digit method instead.
macro_rules! unsigned_roots {
    ($T:ident) => {
        impl Roots for $T {
            #[inline]
            fn nth_root(&self, n: u32) -> Self {
                // Truncated `degree`th root of `value`.
                fn root_of(value: $T, degree: u32) -> $T {
                    // Bitwise reduction for types wider than 64 bits: the root of
                    // `value` is twice the root of `value >> degree`, plus
                    // possibly one.
                    fn wide_root(value: $T, degree: u32) -> $T {
                        if value <= core::u64::MAX as $T {
                            return (value as u64).nth_root(degree) as $T;
                        }
                        let lo = (value >> degree).nth_root(degree) << 1;
                        let hi = lo + 1;
                        // A 128-bit `checked_mul` also divides, so use the checked
                        // power only when `hiⁿ` could actually overflow.
                        let may_overflow =
                            hi.next_power_of_two().trailing_zeros() * degree >= bits::<$T>();
                        let hi_fits = if may_overflow {
                            match checked_pow(hi, degree as usize) {
                                Some(power) => power <= value,
                                None => false,
                            }
                        } else {
                            hi.pow(degree) <= value
                        };
                        if hi_fits {
                            hi
                        } else {
                            lo
                        }
                    }

                    #[cfg(feature = "std")]
                    #[inline]
                    fn initial_guess(value: $T, degree: u32) -> $T {
                        // Only large inputs justify the cost of going through `f64`.
                        if bits::<$T>() > 32 && value > core::u32::MAX as $T {
                            ((value as f64).ln() / f64::from(degree)).exp() as $T
                        } else {
                            1 << ((log2(value) + degree - 1) / degree)
                        }
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn initial_guess(value: $T, degree: u32) -> $T {
                        1 << ((log2(value) + degree - 1) / degree)
                    }

                    // Degrees 0 through 3 are answered without iteration.
                    if degree == 0 {
                        panic!("can't find a root of degree 0!");
                    } else if degree == 1 {
                        return value;
                    } else if degree == 2 {
                        return value.sqrt();
                    } else if degree == 3 {
                        return value.cbrt();
                    }

                    // Anything below 2ⁿ has a root of 0 or 1. The width test comes
                    // first so `1 << degree` is never evaluated with an oversized shift.
                    if degree >= bits::<$T>() || value < (1 << degree) {
                        return if value == 0 { 0 } else { 1 };
                    }

                    if bits::<$T>() > 64 {
                        return wide_root(value, degree);
                    }

                    // Integer Newton iteration:
                    // https://en.wikipedia.org/wiki/Nth_root_algorithm
                    let lower = degree - 1;
                    let newton = |x: $T| {
                        // If `x^(n-1)` overflows it certainly exceeds `value`.
                        let quotient = match checked_pow(x, lower as usize) {
                            None => 0,
                            Some(power) => value / power,
                        };
                        (quotient + x * lower as $T) / degree as $T
                    };
                    fixpoint(initial_guess(value, degree), newton)
                }
                root_of(*self, n)
            }

            #[inline]
            fn sqrt(&self) -> Self {
                // Truncated square root of `value`.
                fn square_root(value: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so reduce bitwise until the
                        // value fits in a `u64`.
                        if value <= core::u64::MAX as $T {
                            return (value as u64).sqrt() as $T;
                        }
                        let lo = (value >> 2u32).sqrt() << 1;
                        let hi = lo + 1;
                        return if hi * hi > value { lo } else { hi };
                    }

                    if value < 4 {
                        return if value == 0 { 0 } else { 1 };
                    }

                    #[cfg(feature = "std")]
                    #[inline]
                    fn initial_guess(x: $T) -> $T {
                        (x as f64).sqrt() as $T
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn initial_guess(x: $T) -> $T {
                        1 << ((log2(x) + 1) / 2)
                    }

                    // Heron / Babylonian iteration:
                    // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Babylonian_method
                    let heron = |x: $T| (x + value / x) / 2;
                    fixpoint(initial_guess(value), heron)
                }
                square_root(*self)
            }

            #[inline]
            fn cbrt(&self) -> Self {
                // Truncated cube root of `value`.
                fn cube_root(value: $T) -> $T {
                    if bits::<$T>() > 64 {
                        // 128-bit division is slow, so reduce bitwise until the
                        // value fits in a `u64`.
                        if value <= core::u64::MAX as $T {
                            return (value as u64).cbrt() as $T;
                        }
                        let lo = (value >> 3u32).cbrt() << 1;
                        let hi = lo + 1;
                        return if hi * hi * hi > value { lo } else { hi };
                    }

                    if bits::<$T>() <= 32 {
                        // Digit-by-digit cube root, based on Hacker's Delight
                        // `icbrt2`. It consumes three bits of the input per step,
                        // from the top; `root_sq` tracks `root²`.
                        let mut rem = value;
                        let mut root = 0;
                        let mut root_sq = 0;
                        let mut shift = bits::<$T>() / 3 * 3;
                        return loop {
                            root_sq *= 4;
                            root *= 2;
                            let b = 3 * (root_sq + root) + 1;
                            if rem >> shift >= b {
                                rem -= b << shift;
                                root_sq += 2 * root + 1;
                                root += 1;
                            }
                            if shift == 0 {
                                break root;
                            }
                            shift -= 3;
                        };
                    }

                    if value < 8 {
                        return if value == 0 { 0 } else { 1 };
                    }
                    if value <= core::u32::MAX as $T {
                        return (value as u32).cbrt() as $T;
                    }

                    #[cfg(feature = "std")]
                    #[inline]
                    fn initial_guess(x: $T) -> $T {
                        (x as f64).cbrt() as $T
                    }

                    #[cfg(not(feature = "std"))]
                    #[inline]
                    fn initial_guess(x: $T) -> $T {
                        1 << ((log2(x) + 2) / 3)
                    }

                    // Newton iteration for cube roots:
                    // https://en.wikipedia.org/wiki/Cube_root#Numerical_methods
                    let newton = |x: $T| (x * 2 + value / (x * x)) / 3;
                    fixpoint(initial_guess(value), newton)
                }
                cube_root(*self)
            }
        }
    };
}
381 
382 unsigned_roots!(u8);
383 unsigned_roots!(u16);
384 unsigned_roots!(u32);
385 unsigned_roots!(u64);
386 unsigned_roots!(u128);
387 unsigned_roots!(usize);
388