• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use super::addition::{__add2, add2};
2 use super::subtraction::sub2;
3 use super::{biguint_from_vec, cmp_slice, BigUint, IntDigits};
4 
5 use crate::big_digit::{self, BigDigit, DoubleBigDigit};
6 use crate::Sign::{self, Minus, NoSign, Plus};
7 use crate::{BigInt, UsizePromotion};
8 
9 use core::cmp::Ordering;
10 use core::iter::Product;
11 use core::ops::{Mul, MulAssign};
12 use num_traits::{CheckedMul, FromPrimitive, One, Zero};
13 
14 #[inline]
mac_with_carry( a: BigDigit, b: BigDigit, c: BigDigit, acc: &mut DoubleBigDigit, ) -> BigDigit15 pub(super) fn mac_with_carry(
16     a: BigDigit,
17     b: BigDigit,
18     c: BigDigit,
19     acc: &mut DoubleBigDigit,
20 ) -> BigDigit {
21     *acc += DoubleBigDigit::from(a);
22     *acc += DoubleBigDigit::from(b) * DoubleBigDigit::from(c);
23     let lo = *acc as BigDigit;
24     *acc >>= big_digit::BITS;
25     lo
26 }
27 
28 #[inline]
mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit29 fn mul_with_carry(a: BigDigit, b: BigDigit, acc: &mut DoubleBigDigit) -> BigDigit {
30     *acc += DoubleBigDigit::from(a) * DoubleBigDigit::from(b);
31     let lo = *acc as BigDigit;
32     *acc >>= big_digit::BITS;
33     lo
34 }
35 
36 /// Three argument multiply accumulate:
37 /// acc += b * c
mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit)38 fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
39     if c == 0 {
40         return;
41     }
42 
43     let mut carry = 0;
44     let (a_lo, a_hi) = acc.split_at_mut(b.len());
45 
46     for (a, &b) in a_lo.iter_mut().zip(b) {
47         *a = mac_with_carry(*a, b, c, &mut carry);
48     }
49 
50     let (carry_hi, carry_lo) = big_digit::from_doublebigdigit(carry);
51 
52     let final_carry = if carry_hi == 0 {
53         __add2(a_hi, &[carry_lo])
54     } else {
55         __add2(a_hi, &[carry_hi, carry_lo])
56     };
57     assert_eq!(final_carry, 0, "carry overflow during multiplication!");
58 }
59 
bigint_from_slice(slice: &[BigDigit]) -> BigInt60 fn bigint_from_slice(slice: &[BigDigit]) -> BigInt {
61     BigInt::from(biguint_from_vec(slice.to_vec()))
62 }
63 
64 /// Three argument multiply accumulate:
65 /// acc += b * c
66 #[allow(clippy::many_single_char_names)]
mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit])67 fn mac3(mut acc: &mut [BigDigit], mut b: &[BigDigit], mut c: &[BigDigit]) {
68     // Least-significant zeros have no effect on the output.
69     if let Some(&0) = b.first() {
70         if let Some(nz) = b.iter().position(|&d| d != 0) {
71             b = &b[nz..];
72             acc = &mut acc[nz..];
73         } else {
74             return;
75         }
76     }
77     if let Some(&0) = c.first() {
78         if let Some(nz) = c.iter().position(|&d| d != 0) {
79             c = &c[nz..];
80             acc = &mut acc[nz..];
81         } else {
82             return;
83         }
84     }
85 
86     let acc = acc;
87     let (x, y) = if b.len() < c.len() { (b, c) } else { (c, b) };
88 
89     // We use four algorithms for different input sizes.
90     //
91     // - For small inputs, long multiplication is fastest.
92     // - If y is at least least twice as long as x, split using Half-Karatsuba.
93     // - Next we use Karatsuba multiplication (Toom-2), which we have optimized
94     //   to avoid unnecessary allocations for intermediate values.
95     // - For the largest inputs we use Toom-3, which better optimizes the
96     //   number of operations, but uses more temporary allocations.
97     //
98     // The thresholds are somewhat arbitrary, chosen by evaluating the results
99     // of `cargo bench --bench bigint multiply`.
100 
101     if x.len() <= 32 {
102         // Long multiplication:
103         for (i, xi) in x.iter().enumerate() {
104             mac_digit(&mut acc[i..], y, *xi);
105         }
106     } else if x.len() * 2 <= y.len() {
107         // Karatsuba Multiplication for factors with significant length disparity.
108         //
109         // The Half-Karatsuba Multiplication Algorithm is a specialized case of
110         // the normal Karatsuba multiplication algorithm, designed for the scenario
111         // where y has at least twice as many base digits as x.
112         //
113         // In this case y (the longer input) is split into high2 and low2,
114         // at m2 (half the length of y) and x (the shorter input),
115         // is used directly without splitting.
116         //
117         // The algorithm then proceeds as follows:
118         //
119         // 1. Compute the product z0 = x * low2.
120         // 2. Compute the product temp = x * high2.
121         // 3. Adjust the weight of temp by adding m2 (* NBASE ^ m2)
122         // 4. Add temp and z0 to obtain the final result.
123         //
124         // Proof:
125         //
126         // The algorithm can be derived from the original Karatsuba algorithm by
127         // simplifying the formula when the shorter factor x is not split into
128         // high and low parts, as shown below.
129         //
130         // Original Karatsuba formula:
131         //
132         //     result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
133         //
134         // Substitutions:
135         //
136         //     low1 = x
137         //     high1 = 0
138         //
139         // Applying substitutions:
140         //
141         //     z0 = (low1 * low2)
142         //        = (x * low2)
143         //
144         //     z1 = ((low1 + high1) * (low2 + high2))
145         //        = ((x + 0) * (low2 + high2))
146         //        = (x * low2) + (x * high2)
147         //
148         //     z2 = (high1 * high2)
149         //        = (0 * high2)
150         //        = 0
151         //
152         // Simplified using the above substitutions:
153         //
154         //     result = (z2 * NBASE ^ (m2 × 2)) + ((z1 - z2 - z0) * NBASE ^ m2) + z0
155         //            = (0 * NBASE ^ (m2 × 2)) + ((z1 - 0 - z0) * NBASE ^ m2) + z0
156         //            = ((z1 - z0) * NBASE ^ m2) + z0
157         //            = ((z1 - z0) * NBASE ^ m2) + z0
158         //            = (x * high2) * NBASE ^ m2 + z0
159         let m2 = y.len() / 2;
160         let (low2, high2) = y.split_at(m2);
161 
162         // (x * high2) * NBASE ^ m2 + z0
163         mac3(acc, x, low2);
164         mac3(&mut acc[m2..], x, high2);
165     } else if x.len() <= 256 {
166         // Karatsuba multiplication:
167         //
168         // The idea is that we break x and y up into two smaller numbers that each have about half
169         // as many digits, like so (note that multiplying by b is just a shift):
170         //
171         // x = x0 + x1 * b
172         // y = y0 + y1 * b
173         //
174         // With some algebra, we can compute x * y with three smaller products, where the inputs to
175         // each of the smaller products have only about half as many digits as x and y:
176         //
177         // x * y = (x0 + x1 * b) * (y0 + y1 * b)
178         //
179         // x * y = x0 * y0
180         //       + x0 * y1 * b
181         //       + x1 * y0 * b
182         //       + x1 * y1 * b^2
183         //
184         // Let p0 = x0 * y0 and p2 = x1 * y1:
185         //
186         // x * y = p0
187         //       + (x0 * y1 + x1 * y0) * b
188         //       + p2 * b^2
189         //
190         // The real trick is that middle term:
191         //
192         //         x0 * y1 + x1 * y0
193         //
194         //       = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
195         //
196         //       = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
197         //
198         // Now we complete the square:
199         //
200         //       = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
201         //
202         //       = -((x1 - x0) * (y1 - y0)) + p0 + p2
203         //
204         // Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
205         //
206         // x * y = p0
207         //       + (p0 + p2 - p1) * b
208         //       + p2 * b^2
209         //
210         // Where the three intermediate products are:
211         //
212         // p0 = x0 * y0
213         // p1 = (x1 - x0) * (y1 - y0)
214         // p2 = x1 * y1
215         //
216         // In doing the computation, we take great care to avoid unnecessary temporary variables
217         // (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
218         // bit so we can use the same temporary variable for all the intermediate products:
219         //
220         // x * y = p2 * b^2 + p2 * b
221         //       + p0 * b + p0
222         //       - p1 * b
223         //
224         // The other trick we use is instead of doing explicit shifts, we slice acc at the
225         // appropriate offset when doing the add.
226 
227         // When x is smaller than y, it's significantly faster to pick b such that x is split in
228         // half, not y:
229         let b = x.len() / 2;
230         let (x0, x1) = x.split_at(b);
231         let (y0, y1) = y.split_at(b);
232 
233         // We reuse the same BigUint for all the intermediate multiplies and have to size p
234         // appropriately here: x1.len() >= x0.len and y1.len() >= y0.len():
235         let len = x1.len() + y1.len() + 1;
236         let mut p = BigUint { data: vec![0; len] };
237 
238         // p2 = x1 * y1
239         mac3(&mut p.data, x1, y1);
240 
241         // Not required, but the adds go faster if we drop any unneeded 0s from the end:
242         p.normalize();
243 
244         add2(&mut acc[b..], &p.data);
245         add2(&mut acc[b * 2..], &p.data);
246 
247         // Zero out p before the next multiply:
248         p.data.truncate(0);
249         p.data.resize(len, 0);
250 
251         // p0 = x0 * y0
252         mac3(&mut p.data, x0, y0);
253         p.normalize();
254 
255         add2(acc, &p.data);
256         add2(&mut acc[b..], &p.data);
257 
258         // p1 = (x1 - x0) * (y1 - y0)
259         // We do this one last, since it may be negative and acc can't ever be negative:
260         let (j0_sign, j0) = sub_sign(x1, x0);
261         let (j1_sign, j1) = sub_sign(y1, y0);
262 
263         match j0_sign * j1_sign {
264             Plus => {
265                 p.data.truncate(0);
266                 p.data.resize(len, 0);
267 
268                 mac3(&mut p.data, &j0.data, &j1.data);
269                 p.normalize();
270 
271                 sub2(&mut acc[b..], &p.data);
272             }
273             Minus => {
274                 mac3(&mut acc[b..], &j0.data, &j1.data);
275             }
276             NoSign => (),
277         }
278     } else {
279         // Toom-3 multiplication:
280         //
281         // Toom-3 is like Karatsuba above, but dividing the inputs into three parts.
282         // Both are instances of Toom-Cook, using `k=3` and `k=2` respectively.
283         //
284         // The general idea is to treat the large integers digits as
285         // polynomials of a certain degree and determine the coefficients/digits
286         // of the product of the two via interpolation of the polynomial product.
287         let i = y.len() / 3 + 1;
288 
289         let x0_len = Ord::min(x.len(), i);
290         let x1_len = Ord::min(x.len() - x0_len, i);
291 
292         let y0_len = i;
293         let y1_len = Ord::min(y.len() - y0_len, i);
294 
295         // Break x and y into three parts, representating an order two polynomial.
296         // t is chosen to be the size of a digit so we can use faster shifts
297         // in place of multiplications.
298         //
299         // x(t) = x2*t^2 + x1*t + x0
300         let x0 = bigint_from_slice(&x[..x0_len]);
301         let x1 = bigint_from_slice(&x[x0_len..x0_len + x1_len]);
302         let x2 = bigint_from_slice(&x[x0_len + x1_len..]);
303 
304         // y(t) = y2*t^2 + y1*t + y0
305         let y0 = bigint_from_slice(&y[..y0_len]);
306         let y1 = bigint_from_slice(&y[y0_len..y0_len + y1_len]);
307         let y2 = bigint_from_slice(&y[y0_len + y1_len..]);
308 
309         // Let w(t) = x(t) * y(t)
310         //
311         // This gives us the following order-4 polynomial.
312         //
313         // w(t) = w4*t^4 + w3*t^3 + w2*t^2 + w1*t + w0
314         //
315         // We need to find the coefficients w4, w3, w2, w1 and w0. Instead
316         // of simply multiplying the x and y in total, we can evaluate w
317         // at 5 points. An n-degree polynomial is uniquely identified by (n + 1)
318         // points.
319         //
320         // It is arbitrary as to what points we evaluate w at but we use the
321         // following.
322         //
323         // w(t) at t = 0, 1, -1, -2 and inf
324         //
325         // The values for w(t) in terms of x(t)*y(t) at these points are:
326         //
327         // let a = w(0)   = x0 * y0
328         // let b = w(1)   = (x2 + x1 + x0) * (y2 + y1 + y0)
329         // let c = w(-1)  = (x2 - x1 + x0) * (y2 - y1 + y0)
330         // let d = w(-2)  = (4*x2 - 2*x1 + x0) * (4*y2 - 2*y1 + y0)
331         // let e = w(inf) = x2 * y2 as t -> inf
332 
333         // x0 + x2, avoiding temporaries
334         let p = &x0 + &x2;
335 
336         // y0 + y2, avoiding temporaries
337         let q = &y0 + &y2;
338 
339         // x2 - x1 + x0, avoiding temporaries
340         let p2 = &p - &x1;
341 
342         // y2 - y1 + y0, avoiding temporaries
343         let q2 = &q - &y1;
344 
345         // w(0)
346         let r0 = &x0 * &y0;
347 
348         // w(inf)
349         let r4 = &x2 * &y2;
350 
351         // w(1)
352         let r1 = (p + x1) * (q + y1);
353 
354         // w(-1)
355         let r2 = &p2 * &q2;
356 
357         // w(-2)
358         let r3 = ((p2 + x2) * 2 - x0) * ((q2 + y2) * 2 - y0);
359 
360         // Evaluating these points gives us the following system of linear equations.
361         //
362         //  0  0  0  0  1 | a
363         //  1  1  1  1  1 | b
364         //  1 -1  1 -1  1 | c
365         // 16 -8  4 -2  1 | d
366         //  1  0  0  0  0 | e
367         //
368         // The solved equation (after gaussian elimination or similar)
369         // in terms of its coefficients:
370         //
371         // w0 = w(0)
372         // w1 = w(0)/2 + w(1)/3 - w(-1) + w(-2)/6 - 2*w(inf)
373         // w2 = -w(0) + w(1)/2 + w(-1)/2 - w(inf)
374         // w3 = -w(0)/2 + w(1)/6 + w(-1)/2 - w(-2)/6 + 2*w(inf)
375         // w4 = w(inf)
376         //
377         // This particular sequence is given by Bodrato and is an interpolation
378         // of the above equations.
379         let mut comp3: BigInt = (r3 - &r1) / 3u32;
380         let mut comp1: BigInt = (r1 - &r2) >> 1;
381         let mut comp2: BigInt = r2 - &r0;
382         comp3 = ((&comp2 - comp3) >> 1) + (&r4 << 1);
383         comp2 += &comp1 - &r4;
384         comp1 -= &comp3;
385 
386         // Recomposition. The coefficients of the polynomial are now known.
387         //
388         // Evaluate at w(t) where t is our given base to get the result.
389         //
390         //     let bits = u64::from(big_digit::BITS) * i as u64;
391         //     let result = r0
392         //         + (comp1 << bits)
393         //         + (comp2 << (2 * bits))
394         //         + (comp3 << (3 * bits))
395         //         + (r4 << (4 * bits));
396         //     let result_pos = result.to_biguint().unwrap();
397         //     add2(&mut acc[..], &result_pos.data);
398         //
399         // But with less intermediate copying:
400         for (j, result) in [&r0, &comp1, &comp2, &comp3, &r4].iter().enumerate().rev() {
401             match result.sign() {
402                 Plus => add2(&mut acc[i * j..], result.digits()),
403                 Minus => sub2(&mut acc[i * j..], result.digits()),
404                 NoSign => {}
405             }
406         }
407     }
408 }
409 
mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint410 fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
411     let len = x.len() + y.len() + 1;
412     let mut prod = BigUint { data: vec![0; len] };
413 
414     mac3(&mut prod.data, x, y);
415     prod.normalized()
416 }
417 
scalar_mul(a: &mut BigUint, b: BigDigit)418 fn scalar_mul(a: &mut BigUint, b: BigDigit) {
419     match b {
420         0 => a.set_zero(),
421         1 => {}
422         _ => {
423             if b.is_power_of_two() {
424                 *a <<= b.trailing_zeros();
425             } else {
426                 let mut carry = 0;
427                 for a in a.data.iter_mut() {
428                     *a = mul_with_carry(*a, b, &mut carry);
429                 }
430                 if carry != 0 {
431                     a.data.push(carry as BigDigit);
432                 }
433             }
434         }
435     }
436 }
437 
sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint)438 fn sub_sign(mut a: &[BigDigit], mut b: &[BigDigit]) -> (Sign, BigUint) {
439     // Normalize:
440     if let Some(&0) = a.last() {
441         a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
442     }
443     if let Some(&0) = b.last() {
444         b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
445     }
446 
447     match cmp_slice(a, b) {
448         Ordering::Greater => {
449             let mut a = a.to_vec();
450             sub2(&mut a, b);
451             (Plus, biguint_from_vec(a))
452         }
453         Ordering::Less => {
454             let mut b = b.to_vec();
455             sub2(&mut b, a);
456             (Minus, biguint_from_vec(b))
457         }
458         Ordering::Equal => (NoSign, BigUint::ZERO),
459     }
460 }
461 
462 macro_rules! impl_mul {
463     ($(impl Mul<$Other:ty> for $Self:ty;)*) => {$(
464         impl Mul<$Other> for $Self {
465             type Output = BigUint;
466 
467             #[inline]
468             fn mul(self, other: $Other) -> BigUint {
469                 match (&*self.data, &*other.data) {
470                     // multiply by zero
471                     (&[], _) | (_, &[]) => BigUint::ZERO,
472                     // multiply by a scalar
473                     (_, &[digit]) => self * digit,
474                     (&[digit], _) => other * digit,
475                     // full multiplication
476                     (x, y) => mul3(x, y),
477                 }
478             }
479         }
480     )*}
481 }
482 impl_mul! {
483     impl Mul<BigUint> for BigUint;
484     impl Mul<BigUint> for &BigUint;
485     impl Mul<&BigUint> for BigUint;
486     impl Mul<&BigUint> for &BigUint;
487 }
488 
489 macro_rules! impl_mul_assign {
490     ($(impl MulAssign<$Other:ty> for BigUint;)*) => {$(
491         impl MulAssign<$Other> for BigUint {
492             #[inline]
493             fn mul_assign(&mut self, other: $Other) {
494                 match (&*self.data, &*other.data) {
495                     // multiply by zero
496                     (&[], _) => {},
497                     (_, &[]) => self.set_zero(),
498                     // multiply by a scalar
499                     (_, &[digit]) => *self *= digit,
500                     (&[digit], _) => *self = other * digit,
501                     // full multiplication
502                     (x, y) => *self = mul3(x, y),
503                 }
504             }
505         }
506     )*}
507 }
508 impl_mul_assign! {
509     impl MulAssign<BigUint> for BigUint;
510     impl MulAssign<&BigUint> for BigUint;
511 }
512 
513 promote_unsigned_scalars!(impl Mul for BigUint, mul);
514 promote_unsigned_scalars_assign!(impl MulAssign for BigUint, mul_assign);
515 forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u32> for BigUint, mul);
516 forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u64> for BigUint, mul);
517 forward_all_scalar_binop_to_val_val_commutative!(impl Mul<u128> for BigUint, mul);
518 
519 impl Mul<u32> for BigUint {
520     type Output = BigUint;
521 
522     #[inline]
mul(mut self, other: u32) -> BigUint523     fn mul(mut self, other: u32) -> BigUint {
524         self *= other;
525         self
526     }
527 }
528 impl MulAssign<u32> for BigUint {
529     #[inline]
mul_assign(&mut self, other: u32)530     fn mul_assign(&mut self, other: u32) {
531         scalar_mul(self, other as BigDigit);
532     }
533 }
534 
535 impl Mul<u64> for BigUint {
536     type Output = BigUint;
537 
538     #[inline]
mul(mut self, other: u64) -> BigUint539     fn mul(mut self, other: u64) -> BigUint {
540         self *= other;
541         self
542     }
543 }
544 impl MulAssign<u64> for BigUint {
545     cfg_digit!(
546         #[inline]
547         fn mul_assign(&mut self, other: u64) {
548             if let Some(other) = BigDigit::from_u64(other) {
549                 scalar_mul(self, other);
550             } else {
551                 let (hi, lo) = big_digit::from_doublebigdigit(other);
552                 *self = mul3(&self.data, &[lo, hi]);
553             }
554         }
555 
556         #[inline]
557         fn mul_assign(&mut self, other: u64) {
558             scalar_mul(self, other);
559         }
560     );
561 }
562 
563 impl Mul<u128> for BigUint {
564     type Output = BigUint;
565 
566     #[inline]
mul(mut self, other: u128) -> BigUint567     fn mul(mut self, other: u128) -> BigUint {
568         self *= other;
569         self
570     }
571 }
572 
573 impl MulAssign<u128> for BigUint {
574     cfg_digit!(
575         #[inline]
576         fn mul_assign(&mut self, other: u128) {
577             if let Some(other) = BigDigit::from_u128(other) {
578                 scalar_mul(self, other);
579             } else {
580                 *self = match super::u32_from_u128(other) {
581                     (0, 0, c, d) => mul3(&self.data, &[d, c]),
582                     (0, b, c, d) => mul3(&self.data, &[d, c, b]),
583                     (a, b, c, d) => mul3(&self.data, &[d, c, b, a]),
584                 };
585             }
586         }
587 
588         #[inline]
589         fn mul_assign(&mut self, other: u128) {
590             if let Some(other) = BigDigit::from_u128(other) {
591                 scalar_mul(self, other);
592             } else {
593                 let (hi, lo) = big_digit::from_doublebigdigit(other);
594                 *self = mul3(&self.data, &[lo, hi]);
595             }
596         }
597     );
598 }
599 
600 impl CheckedMul for BigUint {
601     #[inline]
checked_mul(&self, v: &BigUint) -> Option<BigUint>602     fn checked_mul(&self, v: &BigUint) -> Option<BigUint> {
603         Some(self.mul(v))
604     }
605 }
606 
607 impl_product_iter_type!(BigUint);
608 
/// Checks `sub_sign` against `BigInt` subtraction in both operand orders.
#[test]
fn test_sub_sign() {
    use crate::BigInt;
    use num_traits::Num;

    // Wraps the (sign, magnitude) pair from `sub_sign` into a `BigInt` so it
    // can be compared with the signed reference result.
    fn sub_sign_i(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
        let (sign, val) = sub_sign(a, b);
        BigInt::from_biguint(sign, val)
    }

    let a = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
    let b = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
    let a_i = BigInt::from(a.clone());
    let b_i = BigInt::from(b.clone());

    // a > b gives a positive difference; swapping the operands gives a negative one.
    assert_eq!(sub_sign_i(&a.data, &b.data), &a_i - &b_i);
    assert_eq!(sub_sign_i(&b.data, &a.data), &b_i - &a_i);
}
627