• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2015-2016 Brian Smith.
2 //
3 // Permission to use, copy, modify, and/or distribute this software for any
4 // purpose with or without fee is hereby granted, provided that the above
5 // copyright notice and this permission notice appear in all copies.
6 //
7 // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
8 // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
10 // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 
15 //! Multi-precision integers.
16 //!
17 //! # Modular Arithmetic.
18 //!
19 //! Modular arithmetic is done in finite commutative rings ℤ/mℤ for some
20 //! modulus *m*. We work in finite commutative rings instead of finite fields
21 //! because the RSA public modulus *n* is not prime, which means ℤ/nℤ contains
22 //! nonzero elements that have no multiplicative inverse, so ℤ/nℤ is not a
23 //! finite field.
24 //!
25 //! In some calculations we need to deal with multiple rings at once. For
26 //! example, RSA private key operations operate in the rings ℤ/nℤ, ℤ/pℤ, and
27 //! ℤ/qℤ. Types and functions dealing with such rings are all parameterized
28 //! over a type `M` to ensure that we don't wrongly mix up the math, e.g. by
29 //! multiplying an element of ℤ/pℤ by an element of ℤ/qℤ modulo q. This follows
30 //! the "unit" pattern described in [Static checking of units in Servo].
31 //!
32 //! `Elem` also uses the static unit checking pattern to statically track the
33 //! Montgomery factors that need to be canceled out in each value using its
34 //! `E` parameter.
35 //!
36 //! [Static checking of units in Servo]:
37 //!     https://blog.mozilla.org/research/2014/06/23/static-checking-of-units-in-servo/
38 
39 use crate::{
40     arithmetic::montgomery::*,
41     bits, bssl, c, debug, error,
42     limb::{self, Limb, LimbMask, LIMB_BITS, LIMB_BYTES},
43 };
44 use alloc::{borrow::ToOwned as _, boxed::Box, vec, vec::Vec};
45 use core::{
46     marker::PhantomData,
47     ops::{Deref, DerefMut},
48 };
49 
50 mod bn_mul_mont_fallback;
51 
/// Marker for modulus types that are prime.
///
/// Implementing this `unsafe` trait is a promise that the modulus really is
/// prime; operations such as `PrivateExponent::for_flt` are only offered for
/// `Prime` moduli.
pub unsafe trait Prime {}
53 
/// The number of limbs shared by every value associated with modulus `M`.
struct Width<M> {
    num_limbs: usize,

    /// The modulus *m* that the width originated from.
    m: PhantomData<M>,
}
60 
61 /// All `BoxedLimbs<M>` are stored in the same number of limbs.
62 struct BoxedLimbs<M> {
63     limbs: Box<[Limb]>,
64 
65     /// The modulus *m* that determines the size of `limbx`.
66     m: PhantomData<M>,
67 }
68 
69 impl<M> Deref for BoxedLimbs<M> {
70     type Target = [Limb];
71     #[inline]
deref(&self) -> &Self::Target72     fn deref(&self) -> &Self::Target {
73         &self.limbs
74     }
75 }
76 
77 impl<M> DerefMut for BoxedLimbs<M> {
78     #[inline]
deref_mut(&mut self) -> &mut Self::Target79     fn deref_mut(&mut self) -> &mut Self::Target {
80         &mut self.limbs
81     }
82 }
83 
84 // TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
85 // is resolved or restrict `M: Clone`.
86 impl<M> Clone for BoxedLimbs<M> {
clone(&self) -> Self87     fn clone(&self) -> Self {
88         Self {
89             limbs: self.limbs.clone(),
90             m: self.m,
91         }
92     }
93 }
94 
95 impl<M> BoxedLimbs<M> {
positive_minimal_width_from_be_bytes( input: untrusted::Input, ) -> Result<Self, error::KeyRejected>96     fn positive_minimal_width_from_be_bytes(
97         input: untrusted::Input,
98     ) -> Result<Self, error::KeyRejected> {
99         // Reject leading zeros. Also reject the value zero ([0]) because zero
100         // isn't positive.
101         if untrusted::Reader::new(input).peek(0) {
102             return Err(error::KeyRejected::invalid_encoding());
103         }
104         let num_limbs = (input.len() + LIMB_BYTES - 1) / LIMB_BYTES;
105         let mut r = Self::zero(Width {
106             num_limbs,
107             m: PhantomData,
108         });
109         limb::parse_big_endian_and_pad_consttime(input, &mut r)
110             .map_err(|error::Unspecified| error::KeyRejected::unexpected_error())?;
111         Ok(r)
112     }
113 
minimal_width_from_unpadded(limbs: &[Limb]) -> Self114     fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
115         debug_assert_ne!(limbs.last(), Some(&0));
116         Self {
117             limbs: limbs.to_owned().into_boxed_slice(),
118             m: PhantomData,
119         }
120     }
121 
from_be_bytes_padded_less_than( input: untrusted::Input, m: &Modulus<M>, ) -> Result<Self, error::Unspecified>122     fn from_be_bytes_padded_less_than(
123         input: untrusted::Input,
124         m: &Modulus<M>,
125     ) -> Result<Self, error::Unspecified> {
126         let mut r = Self::zero(m.width());
127         limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
128         if limb::limbs_less_than_limbs_consttime(&r, &m.limbs) != LimbMask::True {
129             return Err(error::Unspecified);
130         }
131         Ok(r)
132     }
133 
134     #[inline]
is_zero(&self) -> bool135     fn is_zero(&self) -> bool {
136         limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
137     }
138 
zero(width: Width<M>) -> Self139     fn zero(width: Width<M>) -> Self {
140         Self {
141             limbs: vec![0; width.num_limbs].into_boxed_slice(),
142             m: PhantomData,
143         }
144     }
145 
width(&self) -> Width<M>146     fn width(&self) -> Width<M> {
147         Width {
148             num_limbs: self.limbs.len(),
149             m: PhantomData,
150         }
151     }
152 }
153 
/// Marks a modulus *s* that is smaller than another modulus *l*, so that
/// every element of ℤ/sℤ is also an element of ℤ/lℤ.
pub unsafe trait SmallerModulus<L> {}
157 
158 /// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
159 /// the precondition for reduction by conditional subtraction,
160 /// `elem_reduce_once()`.
161 pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}
162 
163 /// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
164 /// the precondition for the more general Montgomery reduction from ℤ/lℤ to
165 /// ℤ/sℤ.
166 pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}
167 
/// Marks a modulus whose value is not secret.
///
/// Only public moduli may be cloned, debug-printed, or serialized with
/// `Modulus::to_be_bytes`.
pub unsafe trait PublicModulus {}
169 
/// The smallest number of limbs a `Modulus` may have.
///
/// The x86 implementation of `bn_mul_mont`, at least, requires at least 4
/// limbs. For a long time we have required 4 limbs for all targets, though
/// this may be unnecessary. TODO: Replace this with
/// `n.len() < 256 / LIMB_BITS` so that 32-bit and 64-bit platforms behave the
/// same.
pub const MODULUS_MIN_LIMBS: usize = 4;
176 
177 pub const MODULUS_MAX_LIMBS: usize = 8192 / LIMB_BITS;
178 
179 /// The modulus *m* for a ring ℤ/mℤ, along with the precomputed values needed
180 /// for efficient Montgomery multiplication modulo *m*. The value must be odd
181 /// and larger than 2. The larger-than-1 requirement is imposed, at least, by
182 /// the modular inversion code.
183 pub struct Modulus<M> {
184     limbs: BoxedLimbs<M>, // Also `value >= 3`.
185 
186     // n0 * N == -1 (mod r).
187     //
188     // r == 2**(N0_LIMBS_USED * LIMB_BITS) and LG_LITTLE_R == lg(r). This
189     // ensures that we can do integer division by |r| by simply ignoring
190     // `N0_LIMBS_USED` limbs. Similarly, we can calculate values modulo `r` by
191     // just looking at the lowest `N0_LIMBS_USED` limbs. This is what makes
192     // Montgomery multiplication efficient.
193     //
194     // As shown in Algorithm 1 of "Fast Prime Field Elliptic Curve Cryptography
195     // with 256 Bit Primes" by Shay Gueron and Vlad Krasnov, in the loop of a
196     // multi-limb Montgomery multiplication of a * b (mod n), given the
197     // unreduced product t == a * b, we repeatedly calculate:
198     //
199     //    t1 := t % r         |t1| is |t|'s lowest limb (see previous paragraph).
200     //    t2 := t1*n0*n
201     //    t3 := t + t2
202     //    t := t3 / r         copy all limbs of |t3| except the lowest to |t|.
203     //
204     // In the last step, it would only make sense to ignore the lowest limb of
205     // |t3| if it were zero. The middle steps ensure that this is the case:
206     //
207     //                            t3 ==  0 (mod r)
208     //                        t + t2 ==  0 (mod r)
209     //                   t + t1*n0*n ==  0 (mod r)
210     //                       t1*n0*n == -t (mod r)
211     //                        t*n0*n == -t (mod r)
212     //                          n0*n == -1 (mod r)
213     //                            n0 == -1/n (mod r)
214     //
215     // Thus, in each iteration of the loop, we multiply by the constant factor
216     // n0, the negative inverse of n (mod r).
217     //
218     // TODO(perf): Not all 32-bit platforms actually make use of n0[1]. For the
219     // ones that don't, we could use a shorter `R` value and use faster `Limb`
220     // calculations instead of double-precision `u64` calculations.
221     n0: N0,
222 
223     oneRR: One<M, RR>,
224 }
225 
226 impl<M: PublicModulus> Modulus<M> {
to_be_bytes(&self) -> Box<[u8]>227     pub fn to_be_bytes(&self) -> Box<[u8]> {
228         let mut padded = vec![0u8; self.limbs.len() * LIMB_BYTES];
229         // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
230         limb::big_endian_from_limbs(&self.limbs, &mut padded);
231         strip_leading_zeros(&padded)
232     }
233 }
234 
235 impl<M: PublicModulus> Clone for Modulus<M> {
clone(&self) -> Self236     fn clone(&self) -> Self {
237         Self {
238             limbs: self.limbs.clone(),
239             n0: self.n0.clone(),
240             oneRR: self.oneRR.clone(),
241         }
242     }
243 }
244 
245 impl<M: PublicModulus> core::fmt::Debug for Modulus<M> {
fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error>246     fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> {
247         let mut state = fmt.debug_tuple("Modulus");
248 
249         #[cfg(feature = "alloc")]
250         let state = {
251             let value = self.to_be_bytes(); // XXX: Allocates
252             state.field(&debug::HexStr(&value))
253         };
254 
255         state.finish()
256     }
257 }
258 
259 impl<M> Modulus<M> {
from_be_bytes_with_bit_length( input: untrusted::Input, ) -> Result<(Self, bits::BitLength), error::KeyRejected>260     pub fn from_be_bytes_with_bit_length(
261         input: untrusted::Input,
262     ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
263         let limbs = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
264         Self::from_boxed_limbs(limbs)
265     }
266 
from_nonnegative_with_bit_length( n: Nonnegative, ) -> Result<(Self, bits::BitLength), error::KeyRejected>267     pub fn from_nonnegative_with_bit_length(
268         n: Nonnegative,
269     ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
270         let limbs = BoxedLimbs {
271             limbs: n.limbs.into_boxed_slice(),
272             m: PhantomData,
273         };
274         Self::from_boxed_limbs(limbs)
275     }
276 
from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected>277     fn from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected> {
278         if n.len() > MODULUS_MAX_LIMBS {
279             return Err(error::KeyRejected::too_large());
280         }
281         if n.len() < MODULUS_MIN_LIMBS {
282             return Err(error::KeyRejected::unexpected_error());
283         }
284         if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
285             return Err(error::KeyRejected::invalid_component());
286         }
287         if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
288             return Err(error::KeyRejected::unexpected_error());
289         }
290 
291         // n_mod_r = n % r. As explained in the documentation for `n0`, this is
292         // done by taking the lowest `N0_LIMBS_USED` limbs of `n`.
293         #[allow(clippy::useless_conversion)]
294         let n0 = {
295             prefixed_extern! {
296                 fn bn_neg_inv_mod_r_u64(n: u64) -> u64;
297             }
298 
299             // XXX: u64::from isn't guaranteed to be constant time.
300             let mut n_mod_r: u64 = u64::from(n[0]);
301 
302             if N0_LIMBS_USED == 2 {
303                 // XXX: If we use `<< LIMB_BITS` here then 64-bit builds
304                 // fail to compile because of `deny(exceeding_bitshifts)`.
305                 debug_assert_eq!(LIMB_BITS, 32);
306                 n_mod_r |= u64::from(n[1]) << 32;
307             }
308             N0::from(unsafe { bn_neg_inv_mod_r_u64(n_mod_r) })
309         };
310 
311         let bits = limb::limbs_minimal_bits(&n.limbs);
312         let oneRR = {
313             let partial = PartialModulus {
314                 limbs: &n.limbs,
315                 n0: n0.clone(),
316                 m: PhantomData,
317             };
318 
319             One::newRR(&partial, bits)
320         };
321 
322         Ok((
323             Self {
324                 limbs: n,
325                 n0,
326                 oneRR,
327             },
328             bits,
329         ))
330     }
331 
332     #[inline]
width(&self) -> Width<M>333     fn width(&self) -> Width<M> {
334         self.limbs.width()
335     }
336 
zero<E>(&self) -> Elem<M, E>337     fn zero<E>(&self) -> Elem<M, E> {
338         Elem {
339             limbs: BoxedLimbs::zero(self.width()),
340             encoding: PhantomData,
341         }
342     }
343 
344     // TODO: Get rid of this
one(&self) -> Elem<M, Unencoded>345     fn one(&self) -> Elem<M, Unencoded> {
346         let mut r = self.zero();
347         r.limbs[0] = 1;
348         r
349     }
350 
oneRR(&self) -> &One<M, RR>351     pub fn oneRR(&self) -> &One<M, RR> {
352         &self.oneRR
353     }
354 
to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded> where M: SmallerModulus<L>,355     pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
356     where
357         M: SmallerModulus<L>,
358     {
359         // TODO: Encode this assertion into the `where` above.
360         assert_eq!(self.width().num_limbs, l.width().num_limbs);
361         let limbs = self.limbs.clone();
362         Elem {
363             limbs: BoxedLimbs {
364                 limbs: limbs.limbs,
365                 m: PhantomData,
366             },
367             encoding: PhantomData,
368         }
369     }
370 
as_partial(&self) -> PartialModulus<M>371     fn as_partial(&self) -> PartialModulus<M> {
372         PartialModulus {
373             limbs: &self.limbs,
374             n0: self.n0.clone(),
375             m: PhantomData,
376         }
377     }
378 }
379 
380 struct PartialModulus<'a, M> {
381     limbs: &'a [Limb],
382     n0: N0,
383     m: PhantomData<M>,
384 }
385 
386 impl<M> PartialModulus<'_, M> {
387     // TODO: XXX Avoid duplication with `Modulus`.
zero(&self) -> Elem<M, R>388     fn zero(&self) -> Elem<M, R> {
389         let width = Width {
390             num_limbs: self.limbs.len(),
391             m: PhantomData,
392         };
393         Elem {
394             limbs: BoxedLimbs::zero(width),
395             encoding: PhantomData,
396         }
397     }
398 }
399 
400 /// Elements of ℤ/mℤ for some modulus *m*.
401 //
402 // Defaulting `E` to `Unencoded` is a convenience for callers from outside this
403 // submodule. However, for maximum clarity, we always explicitly use
404 // `Unencoded` within the `bigint` submodule.
405 pub struct Elem<M, E = Unencoded> {
406     limbs: BoxedLimbs<M>,
407 
408     /// The number of Montgomery factors that need to be canceled out from
409     /// `value` to get the actual value.
410     encoding: PhantomData<E>,
411 }
412 
413 // TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
414 // is resolved or restrict `M: Clone` and `E: Clone`.
415 impl<M, E> Clone for Elem<M, E> {
clone(&self) -> Self416     fn clone(&self) -> Self {
417         Self {
418             limbs: self.limbs.clone(),
419             encoding: self.encoding,
420         }
421     }
422 }
423 
424 impl<M, E> Elem<M, E> {
425     #[inline]
is_zero(&self) -> bool426     pub fn is_zero(&self) -> bool {
427         self.limbs.is_zero()
428     }
429 }
430 
431 impl<M, E: ReductionEncoding> Elem<M, E> {
decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output>432     fn decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output> {
433         // A multiplication isn't required since we're multiplying by the
434         // unencoded value one (1); only a Montgomery reduction is needed.
435         // However the only non-multiplication Montgomery reduction function we
436         // have requires the input to be large, so we avoid using it here.
437         let mut limbs = self.limbs;
438         let num_limbs = m.width().num_limbs;
439         let mut one = [0; MODULUS_MAX_LIMBS];
440         one[0] = 1;
441         let one = &one[..num_limbs]; // assert!(num_limbs <= MODULUS_MAX_LIMBS);
442         limbs_mont_mul(&mut limbs, &one, &m.limbs, &m.n0);
443         Elem {
444             limbs,
445             encoding: PhantomData,
446         }
447     }
448 }
449 
450 impl<M> Elem<M, R> {
451     #[inline]
into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded>452     pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
453         self.decode_once(m)
454     }
455 }
456 
457 impl<M> Elem<M, Unencoded> {
from_be_bytes_padded( input: untrusted::Input, m: &Modulus<M>, ) -> Result<Self, error::Unspecified>458     pub fn from_be_bytes_padded(
459         input: untrusted::Input,
460         m: &Modulus<M>,
461     ) -> Result<Self, error::Unspecified> {
462         Ok(Self {
463             limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
464             encoding: PhantomData,
465         })
466     }
467 
468     #[inline]
fill_be_bytes(&self, out: &mut [u8])469     pub fn fill_be_bytes(&self, out: &mut [u8]) {
470         // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
471         limb::big_endian_from_limbs(&self.limbs, out)
472     }
473 
into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected>474     pub fn into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected> {
475         let (m, _bits) =
476             Modulus::from_boxed_limbs(BoxedLimbs::minimal_width_from_unpadded(&self.limbs))?;
477         Ok(m)
478     }
479 
is_one(&self) -> bool480     fn is_one(&self) -> bool {
481         limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
482     }
483 }
484 
elem_mul<M, AF, BF>( a: &Elem<M, AF>, b: Elem<M, BF>, m: &Modulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,485 pub fn elem_mul<M, AF, BF>(
486     a: &Elem<M, AF>,
487     b: Elem<M, BF>,
488     m: &Modulus<M>,
489 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
490 where
491     (AF, BF): ProductEncoding,
492 {
493     elem_mul_(a, b, &m.as_partial())
494 }
495 
elem_mul_<M, AF, BF>( a: &Elem<M, AF>, mut b: Elem<M, BF>, m: &PartialModulus<M>, ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output> where (AF, BF): ProductEncoding,496 fn elem_mul_<M, AF, BF>(
497     a: &Elem<M, AF>,
498     mut b: Elem<M, BF>,
499     m: &PartialModulus<M>,
500 ) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
501 where
502     (AF, BF): ProductEncoding,
503 {
504     limbs_mont_mul(&mut b.limbs, &a.limbs, &m.limbs, &m.n0);
505     Elem {
506         limbs: b.limbs,
507         encoding: PhantomData,
508     }
509 }
510 
elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>)511 fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
512     prefixed_extern! {
513         fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
514     }
515     unsafe {
516         LIMBS_shl_mod(
517             a.limbs.as_mut_ptr(),
518             a.limbs.as_ptr(),
519             m.limbs.as_ptr(),
520             m.limbs.len(),
521         );
522     }
523 }
524 
elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Elem<Smaller, Unencoded>525 pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
526     a: &Elem<Larger, Unencoded>,
527     m: &Modulus<Smaller>,
528 ) -> Elem<Smaller, Unencoded> {
529     let mut r = a.limbs.clone();
530     assert!(r.len() <= m.limbs.len());
531     limb::limbs_reduce_once_constant_time(&mut r, &m.limbs);
532     Elem {
533         limbs: BoxedLimbs {
534             limbs: r.limbs,
535             m: PhantomData,
536         },
537         encoding: PhantomData,
538     }
539 }
540 
541 #[inline]
elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>( a: &Elem<Larger, Unencoded>, m: &Modulus<Smaller>, ) -> Elem<Smaller, RInverse>542 pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
543     a: &Elem<Larger, Unencoded>,
544     m: &Modulus<Smaller>,
545 ) -> Elem<Smaller, RInverse> {
546     let mut tmp = [0; MODULUS_MAX_LIMBS];
547     let tmp = &mut tmp[..a.limbs.len()];
548     tmp.copy_from_slice(&a.limbs);
549 
550     let mut r = m.zero();
551     limbs_from_mont_in_place(&mut r.limbs, tmp, &m.limbs, &m.n0);
552     r
553 }
554 
elem_squared<M, E>( mut a: Elem<M, E>, m: &PartialModulus<M>, ) -> Elem<M, <(E, E) as ProductEncoding>::Output> where (E, E): ProductEncoding,555 fn elem_squared<M, E>(
556     mut a: Elem<M, E>,
557     m: &PartialModulus<M>,
558 ) -> Elem<M, <(E, E) as ProductEncoding>::Output>
559 where
560     (E, E): ProductEncoding,
561 {
562     limbs_mont_square(&mut a.limbs, &m.limbs, &m.n0);
563     Elem {
564         limbs: a.limbs,
565         encoding: PhantomData,
566     }
567 }
568 
elem_widen<Larger, Smaller: SmallerModulus<Larger>>( a: Elem<Smaller, Unencoded>, m: &Modulus<Larger>, ) -> Elem<Larger, Unencoded>569 pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
570     a: Elem<Smaller, Unencoded>,
571     m: &Modulus<Larger>,
572 ) -> Elem<Larger, Unencoded> {
573     let mut r = m.zero();
574     r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
575     r
576 }
577 
578 // TODO: Document why this works for all Montgomery factors.
elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E>579 pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
580     limb::limbs_add_assign_mod(&mut a.limbs, &b.limbs, &m.limbs);
581     a
582 }
583 
584 // TODO: Document why this works for all Montgomery factors.
elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E>585 pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
586     prefixed_extern! {
587         // `r` and `a` may alias.
588         fn LIMBS_sub_mod(
589             r: *mut Limb,
590             a: *const Limb,
591             b: *const Limb,
592             m: *const Limb,
593             num_limbs: c::size_t,
594         );
595     }
596     unsafe {
597         LIMBS_sub_mod(
598             a.limbs.as_mut_ptr(),
599             a.limbs.as_ptr(),
600             b.limbs.as_ptr(),
601             m.limbs.as_ptr(),
602             m.limbs.len(),
603         );
604     }
605     a
606 }
607 
608 // The value 1, Montgomery-encoded some number of times.
609 pub struct One<M, E>(Elem<M, E>);
610 
611 impl<M> One<M, RR> {
612     // Returns RR = = R**2 (mod n) where R = 2**r is the smallest power of
613     // 2**LIMB_BITS such that R > m.
614     //
615     // Even though the assembly on some 32-bit platforms works with 64-bit
616     // values, using `LIMB_BITS` here, rather than `N0_LIMBS_USED * LIMB_BITS`,
617     // is correct because R**2 will still be a multiple of the latter as
618     // `N0_LIMBS_USED` is either one or two.
newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self619     fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
620         let m_bits = m_bits.as_usize_bits();
621         let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;
622 
623         // base = 2**(lg m - 1).
624         let bit = m_bits - 1;
625         let mut base = m.zero();
626         base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);
627 
628         // Double `base` so that base == R == 2**r (mod m). For normal moduli
629         // that have the high bit of the highest limb set, this requires one
630         // doubling. Unusual moduli require more doublings but we are less
631         // concerned about the performance of those.
632         //
633         // Then double `base` again so that base == 2*R (mod n), i.e. `2` in
634         // Montgomery form (`elem_exp_vartime_()` requires the base to be in
635         // Montgomery form). Then compute
636         // RR = R**2 == base**r == R**r == (2**r)**r (mod n).
637         //
638         // Take advantage of the fact that `elem_mul_by_2` is faster than
639         // `elem_squared` by replacing some of the early squarings with shifts.
640         // TODO: Benchmark shift vs. squaring performance to determine the
641         // optimal value of `lg_base`.
642         let lg_base = 2usize; // Shifts vs. squaring trade-off.
643         debug_assert_eq!(lg_base.count_ones(), 1); // Must 2**n for n >= 0.
644         let shifts = r - bit + lg_base;
645         let exponent = (r / lg_base) as u64;
646         for _ in 0..shifts {
647             elem_mul_by_2(&mut base, m)
648         }
649         let RR = elem_exp_vartime_(base, exponent, m);
650 
651         Self(Elem {
652             limbs: RR.limbs,
653             encoding: PhantomData, // PhantomData<RR>
654         })
655     }
656 }
657 
658 impl<M: PublicModulus, E> Clone for One<M, E> {
clone(&self) -> Self659     fn clone(&self) -> Self {
660         Self(self.0.clone())
661     }
662 }
663 
664 impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
as_ref(&self) -> &Elem<M, E>665     fn as_ref(&self) -> &Elem<M, E> {
666         &self.0
667     }
668 }
669 
/// A non-secret, odd, positive value in the range
/// [3, PUBLIC_EXPONENT_MAX_VALUE].
#[derive(Copy, Clone)]
pub struct PublicExponent(u64);
674 
675 impl core::fmt::Debug for PublicExponent {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error>676     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
677         write!(f, "{}", self.0)
678     }
679 }
680 
681 impl PublicExponent {
from_be_bytes( input: untrusted::Input, min_value: u64, ) -> Result<Self, error::KeyRejected>682     pub fn from_be_bytes(
683         input: untrusted::Input,
684         min_value: u64,
685     ) -> Result<Self, error::KeyRejected> {
686         if input.len() > 5 {
687             return Err(error::KeyRejected::too_large());
688         }
689         let value = input.read_all(error::KeyRejected::invalid_encoding(), |input| {
690             // The exponent can't be zero and it can't be prefixed with
691             // zero-valued bytes.
692             if input.peek(0) {
693                 return Err(error::KeyRejected::invalid_encoding());
694             }
695             let mut value = 0u64;
696             loop {
697                 let byte = input
698                     .read_byte()
699                     .map_err(|untrusted::EndOfInput| error::KeyRejected::invalid_encoding())?;
700                 value = (value << 8) | u64::from(byte);
701                 if input.at_end() {
702                     return Ok(value);
703                 }
704             }
705         })?;
706 
707         // Step 2 / Step b. NIST SP800-89 defers to FIPS 186-3, which requires
708         // `e >= 65537`. We enforce this when signing, but are more flexible in
709         // verification, for compatibility. Only small public exponents are
710         // supported.
711         if value & 1 != 1 {
712             return Err(error::KeyRejected::invalid_component());
713         }
714         debug_assert!(min_value & 1 == 1);
715         debug_assert!(min_value <= PUBLIC_EXPONENT_MAX_VALUE);
716         if min_value < 3 {
717             return Err(error::KeyRejected::invalid_component());
718         }
719         if value < min_value {
720             return Err(error::KeyRejected::too_small());
721         }
722         if value > PUBLIC_EXPONENT_MAX_VALUE {
723             return Err(error::KeyRejected::too_large());
724         }
725 
726         Ok(Self(value))
727     }
728 
729     #[inline]
to_be_bytes(&self) -> Box<[u8]>730     pub fn to_be_bytes(&self) -> Box<[u8]> {
731         strip_leading_zeros(&u64::to_be_bytes(self.0))
732     }
733 }
734 
// The largest accepted public exponent: 2**33 - 1.
//
// This limit bounds the performance of the simple exponentiation-by-squaring
// implementation in `elem_exp_vartime`. In particular, it helps mitigate
// theoretical resource exhaustion attacks. 33 bits was chosen as the limit
// based on the recommendations in [1] and [2]. Windows CryptoAPI (at least
// older versions) doesn't support values larger than 32 bits [3], so it is
// unlikely that exponents larger than 32 bits are being used for anything
// Windows commonly does.
//
// [1] https://www.imperialviolet.org/2012/03/16/rsae.html
// [2] https://www.imperialviolet.org/2012/03/17/rsados.html
// [3] https://msdn.microsoft.com/en-us/library/aa387685(VS.85).aspx
const PUBLIC_EXPONENT_MAX_VALUE: u64 = u64::MAX >> (64 - 33);
747 
748 /// Calculates base**exponent (mod m).
749 // TODO: The test coverage needs to be expanded, e.g. test with the largest
750 // accepted exponent and with the most common values of 65537 and 3.
elem_exp_vartime<M>( base: Elem<M, Unencoded>, PublicExponent(exponent): PublicExponent, m: &Modulus<M>, ) -> Elem<M, R>751 pub fn elem_exp_vartime<M>(
752     base: Elem<M, Unencoded>,
753     PublicExponent(exponent): PublicExponent,
754     m: &Modulus<M>,
755 ) -> Elem<M, R> {
756     let base = elem_mul(m.oneRR().as_ref(), base, &m);
757     elem_exp_vartime_(base, exponent, &m.as_partial())
758 }
759 
760 /// Calculates base**exponent (mod m).
elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R>761 fn elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R> {
762     // Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
763     // square-and-multiply that scans the exponent from the most significant
764     // bit to the least significant bit (left-to-right). Left-to-right requires
765     // less storage compared to right-to-left scanning, at the cost of needing
766     // to compute `exponent.leading_zeros()`, which we assume to be cheap.
767     //
768     // During RSA public key operations the exponent is almost always either 65537
769     // (0b10000000000000001) or 3 (0b11), both of which have a Hamming weight
770     // of 2. During Montgomery setup the exponent is almost always a power of two,
771     // with Hamming weight 1. As explained in [Knuth], exponentiation by squaring
772     // is the most efficient algorithm when the Hamming weight is 2 or less. It
773     // isn't the most efficient for all other, uncommon, exponent values but any
774     // suboptimality is bounded by `PUBLIC_EXPONENT_MAX_VALUE`.
775     //
776     // This implementation is slightly simplified by taking advantage of the
777     // fact that we require the exponent to be a positive integer.
778     //
779     // [Knuth]: The Art of Computer Programming, Volume 2: Seminumerical
780     //          Algorithms (3rd Edition), Section 4.6.3.
781     assert!(exponent >= 1);
782     assert!(exponent <= PUBLIC_EXPONENT_MAX_VALUE);
783     let mut acc = base.clone();
784     let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
785     debug_assert!((exponent & bit) != 0);
786     while bit > 1 {
787         bit >>= 1;
788         acc = elem_squared(acc, m);
789         if (exponent & bit) != 0 {
790             acc = elem_mul_(&base, acc, m);
791         }
792     }
793     acc
794 }
795 
796 // `M` represents the prime modulus for which the exponent is in the interval
797 // [1, `m` - 1).
798 pub struct PrivateExponent<M> {
799     limbs: BoxedLimbs<M>,
800 }
801 
802 impl<M> PrivateExponent<M> {
from_be_bytes_padded( input: untrusted::Input, p: &Modulus<M>, ) -> Result<Self, error::Unspecified>803     pub fn from_be_bytes_padded(
804         input: untrusted::Input,
805         p: &Modulus<M>,
806     ) -> Result<Self, error::Unspecified> {
807         let dP = BoxedLimbs::from_be_bytes_padded_less_than(input, p)?;
808 
809         // Proof that `dP < p - 1`:
810         //
811         // If `dP < p` then either `dP == p - 1` or `dP < p - 1`. Since `p` is
812         // odd, `p - 1` is even. `d` is odd, and an odd number modulo an even
813         // number is odd. Therefore `dP` must be odd. But then it cannot be
814         // `p - 1` and so we know `dP < p - 1`.
815         //
816         // Further we know `dP != 0` because `dP` is not even.
817         if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
818             return Err(error::Unspecified);
819         }
820 
821         Ok(Self { limbs: dP })
822     }
823 }
824 
825 impl<M: Prime> PrivateExponent<M> {
826     // Returns `p - 2`.
for_flt(p: &Modulus<M>) -> Self827     fn for_flt(p: &Modulus<M>) -> Self {
828         let two = elem_add(p.one(), p.one(), p);
829         let p_minus_2 = elem_sub(p.zero(), &two, p);
830         Self {
831             limbs: p_minus_2.limbs,
832         }
833     }
834 }
835 
/// Calculates `base**exponent (mod m)` in constant time with respect to
/// `exponent` and returns the unencoded result.
///
/// The exponent is processed in fixed 5-bit windows:
/// - A table of `base**0 ..= base**31` is precomputed, with every entry
///   Montgomery-encoded.
/// - Each window then costs five squarings plus one multiplication by a
///   table entry loaded through `LIMBS_select_512_32`.
#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    // Computed once. Both the table allocation and every entry accessor below
    // use it. (Previously it was needlessly recomputed and shadowed further
    // down.)
    let num_limbs = m.limbs.len();

    let mut table = vec![0; TABLE_ENTRIES * num_limbs];

    /// Loads `table[i]` into `r`.
    ///
    /// Selection is delegated to the C routine `LIMBS_select_512_32`. This
    /// function relies on that routine not leaking `i` through its memory
    /// access pattern (assumed here, not checked).
    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        prefixed_extern! {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        Result::from(unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        })
        .unwrap();
    }

    /// Processes one exponent window.
    ///
    /// Squares `acc` `WINDOW_BITS` times, loads `table[i]` into the scratch
    /// element `tmp`, and multiplies `acc` by it. Returns `(acc, tmp)` so the
    /// scratch element can be reused by the next window.
    fn power<M>(
        table: &[Limb],
        i: Window,
        mut acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        for _ in 0..WINDOW_BITS {
            acc = elem_squared(acc, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        let acc = elem_mul(&tmp, acc, m);
        (acc, tmp)
    }

    /// Returns the `i`-th `num_limbs`-long entry of `table`.
    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }

    /// Mutable counterpart of `entry`.
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    // One in the `R` encoding: the unencoded `m.one()` multiplied by `oneRR`.
    let tmp = m.one();
    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);

    // table[0] = base**0 and table[1] = base**1.
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);

    // table[i] for i >= 2:
    // - even i: square table[i / 2];
    // - odd i: multiply table[i - 1] by table[1].
    for i in 2..TABLE_ENTRIES {
        let (src1, src2) = if i % 2 == 0 {
            (i / 2, i / 2)
        } else {
            (i - 1, 1)
        };
        // Split so that earlier entries can be read while entry `i` is
        // written.
        let (previous, rest) = table.split_at_mut(num_limbs * i);
        let src1 = entry(previous, src1, num_limbs);
        let src2 = entry(previous, src2, num_limbs);
        let dst = entry_mut(rest, 0, num_limbs);
        limbs_mont_product(dst, src1, src2, &m.limbs, &m.n0);
    }

    let (r, _) = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            // Reuse `base`'s allocation for the accumulator. Its contents were
            // already copied into table[1] and are overwritten by the gather.
            let mut r = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut r);
            (r, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );

    let r = r.into_unencoded(m);

    Ok(r)
}
923 
924 /// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
elem_inverse_consttime<M: Prime>( a: Elem<M, R>, m: &Modulus<M>, ) -> Result<Elem<M, Unencoded>, error::Unspecified>925 pub fn elem_inverse_consttime<M: Prime>(
926     a: Elem<M, R>,
927     m: &Modulus<M>,
928 ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
929     elem_exp_consttime(a, &PrivateExponent::for_flt(&m), m)
930 }
931 
932 #[cfg(target_arch = "x86_64")]
elem_exp_consttime<M>( base: Elem<M, R>, exponent: &PrivateExponent<M>, m: &Modulus<M>, ) -> Result<Elem<M, Unencoded>, error::Unspecified>933 pub fn elem_exp_consttime<M>(
934     base: Elem<M, R>,
935     exponent: &PrivateExponent<M>,
936     m: &Modulus<M>,
937 ) -> Result<Elem<M, Unencoded>, error::Unspecified> {
938     // The x86_64 assembly was written under the assumption that the input data
939     // is aligned to `MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH` bytes, which was/is
940     // 64 in OpenSSL. Similarly, OpenSSL uses the x86_64 assembly functions by
941     // giving it only inputs `tmp`, `am`, and `np` that immediately follow the
942     // table. The code seems to "work" even when the inputs aren't exactly
943     // like that but the side channel defenses might not be as effective. All
944     // the awkwardness here stems from trying to use the assembly code like
945     // OpenSSL does.
946 
947     use crate::limb::Window;
948 
949     const WINDOW_BITS: usize = 5;
950     const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
951 
952     let num_limbs = m.limbs.len();
953 
954     const ALIGNMENT: usize = 64;
955     assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
956     let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
957     let (table, state) = {
958         let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
959         let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
960         assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
961         table.split_at_mut(TABLE_ENTRIES * num_limbs)
962     };
963 
964     fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
965         &table[(i * num_limbs)..][..num_limbs]
966     }
967     fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
968         &mut table[(i * num_limbs)..][..num_limbs]
969     }
970 
971     const ACC: usize = 0; // `tmp` in OpenSSL
972     const BASE: usize = ACC + 1; // `am` in OpenSSL
973     const M: usize = BASE + 1; // `np` in OpenSSL
974 
975     entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
976     entry_mut(state, M, num_limbs).copy_from_slice(&m.limbs);
977 
978     fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
979         prefixed_extern! {
980             fn bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
981         }
982         unsafe {
983             bn_scatter5(
984                 entry(state, ACC, num_limbs).as_ptr(),
985                 num_limbs,
986                 table.as_mut_ptr(),
987                 i,
988             )
989         }
990     }
991 
992     fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
993         prefixed_extern! {
994             fn bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
995         }
996         unsafe {
997             bn_gather5(
998                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
999                 num_limbs,
1000                 table.as_ptr(),
1001                 i,
1002             )
1003         }
1004     }
1005 
1006     fn gather_square(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
1007         gather(table, state, i, num_limbs);
1008         assert_eq!(ACC, 0);
1009         let (acc, rest) = state.split_at_mut(num_limbs);
1010         let m = entry(rest, M - 1, num_limbs);
1011         limbs_mont_square(acc, m, n0);
1012     }
1013 
1014     fn gather_mul_base(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
1015         prefixed_extern! {
1016             fn bn_mul_mont_gather5(
1017                 rp: *mut Limb,
1018                 ap: *const Limb,
1019                 table: *const Limb,
1020                 np: *const Limb,
1021                 n0: &N0,
1022                 num: c::size_t,
1023                 power: Window,
1024             );
1025         }
1026         unsafe {
1027             bn_mul_mont_gather5(
1028                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1029                 entry(state, BASE, num_limbs).as_ptr(),
1030                 table.as_ptr(),
1031                 entry(state, M, num_limbs).as_ptr(),
1032                 n0,
1033                 num_limbs,
1034                 i,
1035             );
1036         }
1037     }
1038 
1039     fn power(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
1040         prefixed_extern! {
1041             fn bn_power5(
1042                 r: *mut Limb,
1043                 a: *const Limb,
1044                 table: *const Limb,
1045                 n: *const Limb,
1046                 n0: &N0,
1047                 num: c::size_t,
1048                 i: Window,
1049             );
1050         }
1051         unsafe {
1052             bn_power5(
1053                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1054                 entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1055                 table.as_ptr(),
1056                 entry(state, M, num_limbs).as_ptr(),
1057                 n0,
1058                 num_limbs,
1059                 i,
1060             );
1061         }
1062     }
1063 
1064     // table[0] = base**0.
1065     {
1066         let acc = entry_mut(state, ACC, num_limbs);
1067         acc[0] = 1;
1068         limbs_mont_mul(acc, &m.oneRR.0.limbs, &m.limbs, &m.n0);
1069     }
1070     scatter(table, state, 0, num_limbs);
1071 
1072     // table[1] = base**1.
1073     entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
1074     scatter(table, state, 1, num_limbs);
1075 
1076     for i in 2..(TABLE_ENTRIES as Window) {
1077         if i % 2 == 0 {
1078             // TODO: Optimize this to avoid gathering
1079             gather_square(table, state, &m.n0, i / 2, num_limbs);
1080         } else {
1081             gather_mul_base(table, state, &m.n0, i - 1, num_limbs)
1082         };
1083         scatter(table, state, i, num_limbs);
1084     }
1085 
1086     let state = limb::fold_5_bit_windows(
1087         &exponent.limbs,
1088         |initial_window| {
1089             gather(table, state, initial_window, num_limbs);
1090             state
1091         },
1092         |state, window| {
1093             power(table, state, &m.n0, window, num_limbs);
1094             state
1095         },
1096     );
1097 
1098     prefixed_extern! {
1099         fn bn_from_montgomery(
1100             r: *mut Limb,
1101             a: *const Limb,
1102             not_used: *const Limb,
1103             n: *const Limb,
1104             n0: &N0,
1105             num: c::size_t,
1106         ) -> bssl::Result;
1107     }
1108     Result::from(unsafe {
1109         bn_from_montgomery(
1110             entry_mut(state, ACC, num_limbs).as_mut_ptr(),
1111             entry(state, ACC, num_limbs).as_ptr(),
1112             core::ptr::null(),
1113             entry(state, M, num_limbs).as_ptr(),
1114             &m.n0,
1115             num_limbs,
1116         )
1117     })?;
1118     let mut r = Elem {
1119         limbs: base.limbs,
1120         encoding: PhantomData,
1121     };
1122     r.limbs.copy_from_slice(entry(state, ACC, num_limbs));
1123     Ok(r)
1124 }
1125 
1126 /// Verified a == b**-1 (mod m), i.e. a**-1 == b (mod m).
verify_inverses_consttime<M>( a: &Elem<M, R>, b: Elem<M, Unencoded>, m: &Modulus<M>, ) -> Result<(), error::Unspecified>1127 pub fn verify_inverses_consttime<M>(
1128     a: &Elem<M, R>,
1129     b: Elem<M, Unencoded>,
1130     m: &Modulus<M>,
1131 ) -> Result<(), error::Unspecified> {
1132     if elem_mul(a, b, m).is_one() {
1133         Ok(())
1134     } else {
1135         Err(error::Unspecified)
1136     }
1137 }
1138 
1139 #[inline]
elem_verify_equal_consttime<M, E>( a: &Elem<M, E>, b: &Elem<M, E>, ) -> Result<(), error::Unspecified>1140 pub fn elem_verify_equal_consttime<M, E>(
1141     a: &Elem<M, E>,
1142     b: &Elem<M, E>,
1143 ) -> Result<(), error::Unspecified> {
1144     if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
1145         Ok(())
1146     } else {
1147         Err(error::Unspecified)
1148     }
1149 }
1150 
1151 /// Nonnegative integers.
1152 pub struct Nonnegative {
1153     limbs: Vec<Limb>,
1154 }
1155 
1156 impl Nonnegative {
from_be_bytes_with_bit_length( input: untrusted::Input, ) -> Result<(Self, bits::BitLength), error::Unspecified>1157     pub fn from_be_bytes_with_bit_length(
1158         input: untrusted::Input,
1159     ) -> Result<(Self, bits::BitLength), error::Unspecified> {
1160         let mut limbs = vec![0; (input.len() + LIMB_BYTES - 1) / LIMB_BYTES];
1161         // Rejects empty inputs.
1162         limb::parse_big_endian_and_pad_consttime(input, &mut limbs)?;
1163         while limbs.last() == Some(&0) {
1164             let _ = limbs.pop();
1165         }
1166         let r_bits = limb::limbs_minimal_bits(&limbs);
1167         Ok((Self { limbs }, r_bits))
1168     }
1169 
1170     #[inline]
is_odd(&self) -> bool1171     pub fn is_odd(&self) -> bool {
1172         limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True
1173     }
1174 
verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified>1175     pub fn verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified> {
1176         if !greater_than(other, self) {
1177             return Err(error::Unspecified);
1178         }
1179         Ok(())
1180     }
1181 
to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified>1182     pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
1183         self.verify_less_than_modulus(&m)?;
1184         let mut r = m.zero();
1185         r.limbs[0..self.limbs.len()].copy_from_slice(&self.limbs);
1186         Ok(r)
1187     }
1188 
verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified>1189     pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
1190         if self.limbs.len() > m.limbs.len() {
1191             return Err(error::Unspecified);
1192         }
1193         if self.limbs.len() == m.limbs.len() {
1194             if limb::limbs_less_than_limbs_consttime(&self.limbs, &m.limbs) != LimbMask::True {
1195                 return Err(error::Unspecified);
1196             }
1197         }
1198         Ok(())
1199     }
1200 }
1201 
1202 // Returns a > b.
greater_than(a: &Nonnegative, b: &Nonnegative) -> bool1203 fn greater_than(a: &Nonnegative, b: &Nonnegative) -> bool {
1204     if a.limbs.len() == b.limbs.len() {
1205         limb::limbs_less_than_limbs_vartime(&b.limbs, &a.limbs)
1206     } else {
1207         a.limbs.len() > b.limbs.len()
1208     }
1209 }
1210 
1211 #[derive(Clone)]
1212 #[repr(transparent)]
1213 struct N0([Limb; 2]);
1214 
1215 const N0_LIMBS_USED: usize = 64 / LIMB_BITS;
1216 
1217 impl From<u64> for N0 {
1218     #[inline]
from(n0: u64) -> Self1219     fn from(n0: u64) -> Self {
1220         #[cfg(target_pointer_width = "64")]
1221         {
1222             Self([n0, 0])
1223         }
1224 
1225         #[cfg(target_pointer_width = "32")]
1226         {
1227             Self([n0 as Limb, (n0 >> LIMB_BITS) as Limb])
1228         }
1229     }
1230 }
1231 
1232 /// r *= a
limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0)1233 fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0) {
1234     debug_assert_eq!(r.len(), m.len());
1235     debug_assert_eq!(a.len(), m.len());
1236     unsafe {
1237         bn_mul_mont(
1238             r.as_mut_ptr(),
1239             r.as_ptr(),
1240             a.as_ptr(),
1241             m.as_ptr(),
1242             n0,
1243             r.len(),
1244         )
1245     }
1246 }
1247 
limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0)1248 fn limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0) {
1249     prefixed_extern! {
1250         fn bn_from_montgomery_in_place(
1251             r: *mut Limb,
1252             num_r: c::size_t,
1253             a: *mut Limb,
1254             num_a: c::size_t,
1255             n: *const Limb,
1256             num_n: c::size_t,
1257             n0: &N0,
1258         ) -> bssl::Result;
1259     }
1260     Result::from(unsafe {
1261         bn_from_montgomery_in_place(
1262             r.as_mut_ptr(),
1263             r.len(),
1264             tmp.as_mut_ptr(),
1265             tmp.len(),
1266             m.as_ptr(),
1267             m.len(),
1268             &n0,
1269         )
1270     })
1271     .unwrap()
1272 }
1273 
/// r = a * b, the full double-width product with no reduction.
///
/// `r` must be twice as long as `a`, and `a` and `b` must be the same length.
#[cfg(not(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "x86",
    target_arch = "x86_64"
)))]
fn limbs_mul(r: &mut [Limb], a: &[Limb], b: &[Limb]) {
    debug_assert_eq!(r.len(), 2 * a.len());
    debug_assert_eq!(a.len(), b.len());
    let n = a.len();

    // Only the low half needs zeroing. Row `row` writes (rather than adds)
    // its carry into `r[n + row]`, and later rows only read that slot after
    // it has been written.
    crate::polyfill::slice::fill(&mut r[..n], 0);
    for (row, &b_limb) in b.iter().enumerate() {
        let window = &mut r[row..][..n];
        let carry = unsafe { limbs_mul_add_limb(window.as_mut_ptr(), a.as_ptr(), b_limb, n) };
        r[n + row] = carry;
    }
}
1297 
/// r = a * b, as a Montgomery multiplication modulo `m` via `bn_mul_mont`.
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(r: &mut [Limb], a: &[Limb], b: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());

    let (r_ptr, a_ptr, b_ptr) = (r.as_mut_ptr(), a.as_ptr(), b.as_ptr());
    let num_limbs = r.len();
    unsafe { bn_mul_mont(r_ptr, a_ptr, b_ptr, m.as_ptr(), n0, num_limbs) }
}
1316 
1317 /// r = r**2
limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0)1318 fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0) {
1319     debug_assert_eq!(r.len(), m.len());
1320     unsafe {
1321         bn_mul_mont(
1322             r.as_mut_ptr(),
1323             r.as_ptr(),
1324             r.as_ptr(),
1325             m.as_ptr(),
1326             n0,
1327             r.len(),
1328         )
1329     }
1330 }
1331 
1332 prefixed_extern! {
1333     // `r` and/or 'a' and/or 'b' may alias.
1334     fn bn_mul_mont(
1335         r: *mut Limb,
1336         a: *const Limb,
1337         b: *const Limb,
1338         n: *const Limb,
1339         n0: &N0,
1340         num_limbs: c::size_t,
1341     );
1342 }
1343 
#[cfg(any(
    test,
    not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))
))]
prefixed_extern! {
    // Computes `r += a * b` over `num_limbs` limbs and returns the carry-out
    // limb. `r` must not alias `a`.
    #[must_use]
    fn limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb;
}
1358 
/// Returns a boxed copy of `value` without its leading zero bytes.
///
/// An empty or all-zero input yields an empty result.
fn strip_leading_zeros(value: &[u8]) -> Box<[u8]> {
    let first_nonzero = value
        .iter()
        .position(|&byte| byte != 0)
        .unwrap_or(value.len());
    Box::from(&value[first_nonzero..])
}
1370 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use alloc::format;

    // Type-level stand-in for an arbitrary modulus, used by the data-driven
    // tests below.
    struct M {}

    unsafe impl PublicModulus for M {}

    #[test]
    fn test_elem_exp_consttime() {
        test::run(
            test_file!("bigint_elem_exp_consttime_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected = consume_elem(test_case, "ModExp", &m);
                let base = into_encoded(consume_elem(test_case, "A", &m), &m);
                let exponent_bytes = test_case.consume_bytes("E");
                let exponent = PrivateExponent::from_be_bytes_padded(
                    untrusted::Input::from(&exponent_bytes),
                    &m,
                )
                .expect("valid exponent");

                let actual = elem_exp_consttime(base, &exponent, &m).unwrap();
                assert_elem_eq(&actual, &expected);

                Ok(())
            },
        )
    }

    // TODO: fn test_elem_exp_vartime() using
    // "src/rsa/bigint_elem_exp_vartime_tests.txt". See that file for details.
    // In the meantime, the function is tested indirectly via the RSA
    // verification and signing tests.
    #[test]
    fn test_elem_mul() {
        test::run(
            test_file!("bigint_elem_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected = consume_elem(test_case, "ModMul", &m);
                let a = consume_elem(test_case, "A", &m);
                let b = consume_elem(test_case, "B", &m);

                let a_encoded = into_encoded(a, &m);
                let b_encoded = into_encoded(b, &m);
                let product = elem_mul(&a_encoded, b_encoded, &m).into_unencoded(&m);
                assert_elem_eq(&product, &expected);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_squared() {
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected = consume_elem(test_case, "ModSquare", &m);
                let a = into_encoded(consume_elem(test_case, "A", &m), &m);

                let squared = elem_squared(a, &m.as_partial()).into_unencoded(&m);
                assert_elem_eq(&squared, &expected);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced() {
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}

                let m = consume_modulus::<M>(test_case, "M");
                let expected = consume_elem(test_case, "R", &m);
                let wide_limbs = expected.limbs.len() * 2;
                let a = consume_elem_unchecked::<MM>(test_case, "A", wide_limbs);

                let reduced = elem_reduced(&a, &m);
                let actual = elem_mul(m.oneRR().as_ref(), reduced, &m);
                assert_elem_eq(&actual, &expected);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced_once() {
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}

                let qq = consume_modulus::<QQ>(test_case, "QQ");
                let expected = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N");
                let a = consume_elem::<N>(test_case, "A", &n);

                assert_elem_eq(&elem_reduced_once(&a, &qq), &expected);

                Ok(())
            },
        )
    }

    #[test]
    fn test_modulus_debug() {
        let all_ones = [0xffu8; 1024 / 8];
        let (modulus, _) =
            Modulus::<M>::from_be_bytes_with_bit_length(untrusted::Input::from(&all_ones))
                .unwrap();
        assert_eq!(
            "Modulus(\"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff\")",
            format!("{:?}", modulus)
        );
    }

    #[test]
    fn test_public_exponent_debug() {
        let e_bytes = [0x01u8, 0x00, 0x01];
        let exponent =
            PublicExponent::from_be_bytes(untrusted::Input::from(&e_bytes), 65537).unwrap();
        assert_eq!("65537", format!("{:?}", exponent));
    }

    /// Reads attribute `name` as a big-endian element modulo `m`.
    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let bytes = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&bytes), m).unwrap()
    }

    /// Reads attribute `name` into a `num_limbs`-wide element, without
    /// checking it against any modulus.
    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let value = consume_nonnegative(test_case, name);
        let width = Width {
            num_limbs,
            m: PhantomData,
        };
        let mut limbs = BoxedLimbs::zero(width);
        limbs[0..value.limbs.len()].copy_from_slice(&value.limbs);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    /// Reads attribute `name` as a modulus, discarding its bit length.
    fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> Modulus<M> {
        let bytes = test_case.consume_bytes(name);
        let (modulus, _bits) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        modulus
    }

    /// Reads attribute `name` as a nonnegative integer, discarding its bit
    /// length.
    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (value, _bits) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        value
    }

    /// Panics with both limb vectors if `a` and `b` differ.
    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        if elem_verify_equal_consttime(a, b).is_err() {
            panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
        }
    }

    /// Converts an unencoded element into the `R` encoding.
    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        elem_mul(m.oneRR().as_ref(), a, m)
    }

    #[test]
    // TODO: wasm
    fn test_mul_add_words() {
        const ZERO: Limb = 0;
        const MAX: Limb = ZERO.wrapping_sub(1);
        static TEST_CASES: &[(&[Limb], &[Limb], Limb, Limb, &[Limb])] = &[
            (&[0], &[0], 0, 0, &[0]),
            (&[MAX], &[0], MAX, 0, &[MAX]),
            (&[0], &[MAX], MAX, MAX - 1, &[1]),
            (&[MAX], &[MAX], MAX, MAX, &[0]),
            (&[0, 0], &[MAX, MAX], MAX, MAX - 1, &[1, MAX]),
            (&[1, 0], &[MAX, MAX], MAX, MAX - 1, &[2, MAX]),
            (&[MAX, 0], &[MAX, MAX], MAX, MAX, &[0, 0]),
            (&[0, 1], &[MAX, MAX], MAX, MAX, &[1, 0]),
            (&[MAX, MAX], &[MAX, MAX], MAX, MAX, &[0, MAX]),
        ];

        extern crate std;
        for (i, (r_input, a, w, expected_retval, expected_r)) in TEST_CASES.iter().enumerate() {
            let mut r = std::vec::Vec::from(*r_input);
            assert_eq!(r.len(), a.len()); // Sanity check
            let carry = unsafe { limbs_mul_add_limb(r.as_mut_ptr(), a.as_ptr(), *w, a.len()) };
            assert_eq!(&r, expected_r, "{}: {:x?} != {:x?}", i, &r[..], expected_r);
            let expected_carry = *expected_retval;
            assert_eq!(
                carry, expected_carry,
                "{}: {:x?} != {:x?}",
                i, carry, expected_carry
            );
        }
    }
}
1606