// Copyright 2015-2016 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

//! Multi-precision integers.
//!
//! # Modular Arithmetic.
//!
//! Modular arithmetic is done in finite commutative rings ℤ/mℤ for some
//! modulus *m*. We work in finite commutative rings instead of finite fields
//! because the RSA public modulus *n* is not prime, which means ℤ/nℤ contains
//! nonzero elements that have no multiplicative inverse, so ℤ/nℤ is not a
//! finite field.
//!
//! In some calculations we need to deal with multiple rings at once. For
//! example, RSA private key operations operate in the rings ℤ/nℤ, ℤ/pℤ, and
//! ℤ/qℤ. Types and functions dealing with such rings are all parameterized
//! over a type `M` to ensure that we don't wrongly mix up the math, e.g. by
//! multiplying an element of ℤ/pℤ by an element of ℤ/qℤ modulo q. This follows
//! the "unit" pattern described in [Static checking of units in Servo].
//!
//! `Elem` also uses the static unit checking pattern to statically track the
//! Montgomery factors that need to be canceled out in each value using its
//! `E` parameter.
//!
//! [Static checking of units in Servo]:
//! https://blog.mozilla.org/research/2014/06/23/static-checking-of-units-in-servo/
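//!
//! As a minimal sketch of the unit pattern (illustrative only; these are not
//! this module's actual types), a zero-sized marker type per modulus lets the
//! compiler reject any accidental mixing of rings:
//!
//! ```
//! use core::marker::PhantomData;
//!
//! struct P; // marker for ℤ/pℤ
//! struct Q; // marker for ℤ/qℤ
//!
//! struct Elem<M>(u64, PhantomData<M>);
//!
//! fn mul<M>(a: &Elem<M>, b: &Elem<M>) -> Elem<M> {
//!     Elem(a.0 * b.0, PhantomData) // (reduction mod m omitted in this sketch)
//! }
//!
//! let a = Elem::<P>(2, PhantomData);
//! let b = Elem::<P>(3, PhantomData);
//! let _ = mul(&a, &b); // OK: both operands are in ℤ/pℤ.
//! // `mul(&a, &Elem::<Q>(5, PhantomData))` would not compile: `P` ≠ `Q`.
//! ```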

use crate::{
    arithmetic::montgomery::*,
    bits, bssl, c, debug, error,
    limb::{self, Limb, LimbMask, LIMB_BITS, LIMB_BYTES},
};
use alloc::{borrow::ToOwned as _, boxed::Box, vec, vec::Vec};
use core::{
    marker::PhantomData,
    ops::{Deref, DerefMut},
};

mod bn_mul_mont_fallback;

pub unsafe trait Prime {}

struct Width<M> {
    num_limbs: usize,

    /// The modulus *m* that the width originated from.
    m: PhantomData<M>,
}

/// All `BoxedLimbs<M>` are stored in the same number of limbs.
struct BoxedLimbs<M> {
    limbs: Box<[Limb]>,

    /// The modulus *m* that determines the size of `limbs`.
    m: PhantomData<M>,
}

impl<M> Deref for BoxedLimbs<M> {
    type Target = [Limb];
    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.limbs
    }
}

impl<M> DerefMut for BoxedLimbs<M> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.limbs
    }
}

// TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
// is resolved or restrict `M: Clone`.
impl<M> Clone for BoxedLimbs<M> {
    fn clone(&self) -> Self {
        Self {
            limbs: self.limbs.clone(),
            m: self.m,
        }
    }
}

impl<M> BoxedLimbs<M> {
    fn positive_minimal_width_from_be_bytes(
        input: untrusted::Input,
    ) -> Result<Self, error::KeyRejected> {
        // Reject leading zeros. Also reject the value zero ([0]) because zero
        // isn't positive.
        if untrusted::Reader::new(input).peek(0) {
            return Err(error::KeyRejected::invalid_encoding());
        }
        let num_limbs = (input.len() + LIMB_BYTES - 1) / LIMB_BYTES;
        let mut r = Self::zero(Width {
            num_limbs,
            m: PhantomData,
        });
        limb::parse_big_endian_and_pad_consttime(input, &mut r)
            .map_err(|error::Unspecified| error::KeyRejected::unexpected_error())?;
        Ok(r)
    }

    fn minimal_width_from_unpadded(limbs: &[Limb]) -> Self {
        debug_assert_ne!(limbs.last(), Some(&0));
        Self {
            limbs: limbs.to_owned().into_boxed_slice(),
            m: PhantomData,
        }
    }

    fn from_be_bytes_padded_less_than(
        input: untrusted::Input,
        m: &Modulus<M>,
    ) -> Result<Self, error::Unspecified> {
        let mut r = Self::zero(m.width());
        limb::parse_big_endian_and_pad_consttime(input, &mut r)?;
        if limb::limbs_less_than_limbs_consttime(&r, &m.limbs) != LimbMask::True {
            return Err(error::Unspecified);
        }
        Ok(r)
    }

    #[inline]
    fn is_zero(&self) -> bool {
        limb::limbs_are_zero_constant_time(&self.limbs) == LimbMask::True
    }

    fn zero(width: Width<M>) -> Self {
        Self {
            limbs: vec![0; width.num_limbs].into_boxed_slice(),
            m: PhantomData,
        }
    }

    fn width(&self) -> Width<M> {
        Width {
            num_limbs: self.limbs.len(),
            m: PhantomData,
        }
    }
}

/// A modulus *s* that is smaller than another modulus *l* so every element of
/// ℤ/sℤ is also an element of ℤ/lℤ.
pub unsafe trait SmallerModulus<L> {}

/// A modulus *s* where s < l < 2*s for the given larger modulus *l*. This is
/// the precondition for reduction by conditional subtraction,
/// `elem_reduce_once()`.
pub unsafe trait SlightlySmallerModulus<L>: SmallerModulus<L> {}

/// A modulus *s* where √l <= s < l for the given larger modulus *l*. This is
/// the precondition for the more general Montgomery reduction from ℤ/lℤ to
/// ℤ/sℤ.
pub unsafe trait NotMuchSmallerModulus<L>: SmallerModulus<L> {}

pub unsafe trait PublicModulus {}

/// The x86 implementation of `bn_mul_mont`, at least, requires at least 4
/// limbs. For a long time we have required 4 limbs for all targets, though
/// this may be unnecessary. TODO: Replace this with
/// `n.len() < 256 / LIMB_BITS` so that 32-bit and 64-bit platforms behave the
/// same.
pub const MODULUS_MIN_LIMBS: usize = 4;

pub const MODULUS_MAX_LIMBS: usize = 8192 / LIMB_BITS;
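// For example (illustrative arithmetic): with 64-bit limbs this is
// 8192 / 64 == 128 limbs, i.e. a maximum modulus size of 8192 bits.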

/// The modulus *m* for a ring ℤ/mℤ, along with the precomputed values needed
/// for efficient Montgomery multiplication modulo *m*. The value must be odd
/// and larger than 2. The larger-than-1 requirement is imposed, at least, by
/// the modular inversion code.
pub struct Modulus<M> {
    limbs: BoxedLimbs<M>, // Also `value >= 3`.

    // n0 * N == -1 (mod r).
    //
    // r == 2**(N0_LIMBS_USED * LIMB_BITS) and LG_LITTLE_R == lg(r). This
    // ensures that we can do integer division by |r| by simply ignoring
    // `N0_LIMBS_USED` limbs. Similarly, we can calculate values modulo `r` by
    // just looking at the lowest `N0_LIMBS_USED` limbs. This is what makes
    // Montgomery multiplication efficient.
    //
    // As shown in Algorithm 1 of "Fast Prime Field Elliptic Curve Cryptography
    // with 256 Bit Primes" by Shay Gueron and Vlad Krasnov, in the loop of a
    // multi-limb Montgomery multiplication of a * b (mod n), given the
    // unreduced product t == a * b, we repeatedly calculate:
    //
    //    t1 := t % r        |t1| is |t|'s lowest limb (see previous paragraph).
    //    t2 := t1*n0*n
    //    t3 := t + t2
    //    t := t3 / r        copy all limbs of |t3| except the lowest to |t|.
    //
    // In the last step, it would only make sense to ignore the lowest limb of
    // |t3| if it were zero. The middle steps ensure that this is the case:
    //
    //                 t3 ==  0 (mod r)
    //             t + t2 ==  0 (mod r)
    //        t + t1*n0*n ==  0 (mod r)
    //            t1*n0*n == -t (mod r)
    //             t*n0*n == -t (mod r)
    //               n0*n == -1 (mod r)
    //                 n0 == -1/n (mod r)
    //
    // Thus, in each iteration of the loop, we multiply by the constant factor
    // n0, the negative inverse of n (mod r).
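    //
    // A toy example of one iteration (our own illustration, not from the
    // paper), using 8-bit "limbs" so that r = 2**8 = 256: let n = 239. Since
    // 239 * 15 == 3585 == 14*256 + 1, we have 1/n == 15 (mod 256), so
    // n0 == -15 == 241 (mod 256). For t = 100:
    //
    //    t1 = 100
    //    t1*n0 == 100 * 241 == 36 (mod 256), so t2 = 36 * 239 = 8604
    //    t3 = 100 + 8604 = 8704 = 34 * 256
    //
    // so the lowest limb of t3 is zero and the division by r in the last step
    // discards nothing. (Reducing t1*n0 mod r first is fine because only
    // t3 mod r matters.)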
    //
    // TODO(perf): Not all 32-bit platforms actually make use of n0[1]. For the
    // ones that don't, we could use a shorter `R` value and use faster `Limb`
    // calculations instead of double-precision `u64` calculations.
    n0: N0,

    oneRR: One<M, RR>,
}

impl<M: PublicModulus> Modulus<M> {
    pub fn to_be_bytes(&self) -> Box<[u8]> {
        let mut padded = vec![0u8; self.limbs.len() * LIMB_BYTES];
        // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
        limb::big_endian_from_limbs(&self.limbs, &mut padded);
        strip_leading_zeros(&padded)
    }
}

impl<M: PublicModulus> Clone for Modulus<M> {
    fn clone(&self) -> Self {
        Self {
            limbs: self.limbs.clone(),
            n0: self.n0.clone(),
            oneRR: self.oneRR.clone(),
        }
    }
}

impl<M: PublicModulus> core::fmt::Debug for Modulus<M> {
    fn fmt(&self, fmt: &mut ::core::fmt::Formatter) -> Result<(), ::core::fmt::Error> {
        let mut state = fmt.debug_tuple("Modulus");

        #[cfg(feature = "alloc")]
        let state = {
            let value = self.to_be_bytes(); // XXX: Allocates
            state.field(&debug::HexStr(&value))
        };

        state.finish()
    }
}

impl<M> Modulus<M> {
    pub fn from_be_bytes_with_bit_length(
        input: untrusted::Input,
    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        let limbs = BoxedLimbs::positive_minimal_width_from_be_bytes(input)?;
        Self::from_boxed_limbs(limbs)
    }

    pub fn from_nonnegative_with_bit_length(
        n: Nonnegative,
    ) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        let limbs = BoxedLimbs {
            limbs: n.limbs.into_boxed_slice(),
            m: PhantomData,
        };
        Self::from_boxed_limbs(limbs)
    }

    fn from_boxed_limbs(n: BoxedLimbs<M>) -> Result<(Self, bits::BitLength), error::KeyRejected> {
        if n.len() > MODULUS_MAX_LIMBS {
            return Err(error::KeyRejected::too_large());
        }
        if n.len() < MODULUS_MIN_LIMBS {
            return Err(error::KeyRejected::unexpected_error());
        }
        if limb::limbs_are_even_constant_time(&n) != LimbMask::False {
            return Err(error::KeyRejected::invalid_component());
        }
        if limb::limbs_less_than_limb_constant_time(&n, 3) != LimbMask::False {
            return Err(error::KeyRejected::unexpected_error());
        }

        // n_mod_r = n % r. As explained in the documentation for `n0`, this is
        // done by taking the lowest `N0_LIMBS_USED` limbs of `n`.
        #[allow(clippy::useless_conversion)]
        let n0 = {
            prefixed_extern! {
                fn bn_neg_inv_mod_r_u64(n: u64) -> u64;
            }

            // XXX: u64::from isn't guaranteed to be constant time.
            let mut n_mod_r: u64 = u64::from(n[0]);

            if N0_LIMBS_USED == 2 {
                // XXX: If we use `<< LIMB_BITS` here then 64-bit builds
                // fail to compile because of `deny(exceeding_bitshifts)`.
                debug_assert_eq!(LIMB_BITS, 32);
                n_mod_r |= u64::from(n[1]) << 32;
            }
            N0::from(unsafe { bn_neg_inv_mod_r_u64(n_mod_r) })
        };

        let bits = limb::limbs_minimal_bits(&n.limbs);
        let oneRR = {
            let partial = PartialModulus {
                limbs: &n.limbs,
                n0: n0.clone(),
                m: PhantomData,
            };

            One::newRR(&partial, bits)
        };

        Ok((
            Self {
                limbs: n,
                n0,
                oneRR,
            },
            bits,
        ))
    }

    #[inline]
    fn width(&self) -> Width<M> {
        self.limbs.width()
    }

    fn zero<E>(&self) -> Elem<M, E> {
        Elem {
            limbs: BoxedLimbs::zero(self.width()),
            encoding: PhantomData,
        }
    }

    // TODO: Get rid of this
    fn one(&self) -> Elem<M, Unencoded> {
        let mut r = self.zero();
        r.limbs[0] = 1;
        r
    }

    pub fn oneRR(&self) -> &One<M, RR> {
        &self.oneRR
    }

    pub fn to_elem<L>(&self, l: &Modulus<L>) -> Elem<L, Unencoded>
    where
        M: SmallerModulus<L>,
    {
        // TODO: Encode this assertion into the `where` above.
        assert_eq!(self.width().num_limbs, l.width().num_limbs);
        let limbs = self.limbs.clone();
        Elem {
            limbs: BoxedLimbs {
                limbs: limbs.limbs,
                m: PhantomData,
            },
            encoding: PhantomData,
        }
    }

    fn as_partial(&self) -> PartialModulus<M> {
        PartialModulus {
            limbs: &self.limbs,
            n0: self.n0.clone(),
            m: PhantomData,
        }
    }
}

struct PartialModulus<'a, M> {
    limbs: &'a [Limb],
    n0: N0,
    m: PhantomData<M>,
}

impl<M> PartialModulus<'_, M> {
    // TODO: XXX Avoid duplication with `Modulus`.
    fn zero(&self) -> Elem<M, R> {
        let width = Width {
            num_limbs: self.limbs.len(),
            m: PhantomData,
        };
        Elem {
            limbs: BoxedLimbs::zero(width),
            encoding: PhantomData,
        }
    }
}

/// Elements of ℤ/mℤ for some modulus *m*.
//
// Defaulting `E` to `Unencoded` is a convenience for callers from outside this
// submodule. However, for maximum clarity, we always explicitly use
// `Unencoded` within the `bigint` submodule.
pub struct Elem<M, E = Unencoded> {
    limbs: BoxedLimbs<M>,

    /// The number of Montgomery factors that need to be canceled out from
    /// `value` to get the actual value.
    encoding: PhantomData<E>,
}

// TODO: `derive(Clone)` after https://github.com/rust-lang/rust/issues/26925
// is resolved or restrict `M: Clone` and `E: Clone`.
impl<M, E> Clone for Elem<M, E> {
    fn clone(&self) -> Self {
        Self {
            limbs: self.limbs.clone(),
            encoding: self.encoding,
        }
    }
}

impl<M, E> Elem<M, E> {
    #[inline]
    pub fn is_zero(&self) -> bool {
        self.limbs.is_zero()
    }
}

impl<M, E: ReductionEncoding> Elem<M, E> {
    fn decode_once(self, m: &Modulus<M>) -> Elem<M, <E as ReductionEncoding>::Output> {
        // A multiplication isn't required since we're multiplying by the
        // unencoded value one (1); only a Montgomery reduction is needed.
        // However the only non-multiplication Montgomery reduction function we
        // have requires the input to be large, so we avoid using it here.
        let mut limbs = self.limbs;
        let num_limbs = m.width().num_limbs;
        let mut one = [0; MODULUS_MAX_LIMBS];
        one[0] = 1;
        let one = &one[..num_limbs]; // assert!(num_limbs <= MODULUS_MAX_LIMBS);
        limbs_mont_mul(&mut limbs, one, &m.limbs, &m.n0);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }
}

impl<M> Elem<M, R> {
    #[inline]
    pub fn into_unencoded(self, m: &Modulus<M>) -> Elem<M, Unencoded> {
        self.decode_once(m)
    }
}

impl<M> Elem<M, Unencoded> {
    pub fn from_be_bytes_padded(
        input: untrusted::Input,
        m: &Modulus<M>,
    ) -> Result<Self, error::Unspecified> {
        Ok(Self {
            limbs: BoxedLimbs::from_be_bytes_padded_less_than(input, m)?,
            encoding: PhantomData,
        })
    }

    #[inline]
    pub fn fill_be_bytes(&self, out: &mut [u8]) {
        // See Falko Strenzke, "Manger's Attack revisited", ICICS 2010.
        limb::big_endian_from_limbs(&self.limbs, out)
    }

    pub fn into_modulus<MM>(self) -> Result<Modulus<MM>, error::KeyRejected> {
        let (m, _bits) =
            Modulus::from_boxed_limbs(BoxedLimbs::minimal_width_from_unpadded(&self.limbs))?;
        Ok(m)
    }

    fn is_one(&self) -> bool {
        limb::limbs_equal_limb_constant_time(&self.limbs, 1) == LimbMask::True
    }
}

pub fn elem_mul<M, AF, BF>(
    a: &Elem<M, AF>,
    b: Elem<M, BF>,
    m: &Modulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
    (AF, BF): ProductEncoding,
{
    elem_mul_(a, b, &m.as_partial())
}

fn elem_mul_<M, AF, BF>(
    a: &Elem<M, AF>,
    mut b: Elem<M, BF>,
    m: &PartialModulus<M>,
) -> Elem<M, <(AF, BF) as ProductEncoding>::Output>
where
    (AF, BF): ProductEncoding,
{
    limbs_mont_mul(&mut b.limbs, &a.limbs, m.limbs, &m.n0);
    Elem {
        limbs: b.limbs,
        encoding: PhantomData,
    }
}

fn elem_mul_by_2<M, AF>(a: &mut Elem<M, AF>, m: &PartialModulus<M>) {
    prefixed_extern! {
        fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
    }
    unsafe {
        LIMBS_shl_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        );
    }
}

pub fn elem_reduced_once<Larger, Smaller: SlightlySmallerModulus<Larger>>(
    a: &Elem<Larger, Unencoded>,
    m: &Modulus<Smaller>,
) -> Elem<Smaller, Unencoded> {
    let mut r = a.limbs.clone();
    assert!(r.len() <= m.limbs.len());
    limb::limbs_reduce_once_constant_time(&mut r, &m.limbs);
    Elem {
        limbs: BoxedLimbs {
            limbs: r.limbs,
            m: PhantomData,
        },
        encoding: PhantomData,
    }
}

#[inline]
pub fn elem_reduced<Larger, Smaller: NotMuchSmallerModulus<Larger>>(
    a: &Elem<Larger, Unencoded>,
    m: &Modulus<Smaller>,
) -> Elem<Smaller, RInverse> {
    let mut tmp = [0; MODULUS_MAX_LIMBS];
    let tmp = &mut tmp[..a.limbs.len()];
    tmp.copy_from_slice(&a.limbs);

    let mut r = m.zero();
    limbs_from_mont_in_place(&mut r.limbs, tmp, &m.limbs, &m.n0);
    r
}

fn elem_squared<M, E>(
    mut a: Elem<M, E>,
    m: &PartialModulus<M>,
) -> Elem<M, <(E, E) as ProductEncoding>::Output>
where
    (E, E): ProductEncoding,
{
    limbs_mont_square(&mut a.limbs, m.limbs, &m.n0);
    Elem {
        limbs: a.limbs,
        encoding: PhantomData,
    }
}

pub fn elem_widen<Larger, Smaller: SmallerModulus<Larger>>(
    a: Elem<Smaller, Unencoded>,
    m: &Modulus<Larger>,
) -> Elem<Larger, Unencoded> {
    let mut r = m.zero();
    r.limbs[..a.limbs.len()].copy_from_slice(&a.limbs);
    r
}

// TODO: Document why this works for all Montgomery factors.
pub fn elem_add<M, E>(mut a: Elem<M, E>, b: Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    limb::limbs_add_assign_mod(&mut a.limbs, &b.limbs, &m.limbs);
    a
}

// TODO: Document why this works for all Montgomery factors.
pub fn elem_sub<M, E>(mut a: Elem<M, E>, b: &Elem<M, E>, m: &Modulus<M>) -> Elem<M, E> {
    prefixed_extern! {
        // `r` and `a` may alias.
        fn LIMBS_sub_mod(
            r: *mut Limb,
            a: *const Limb,
            b: *const Limb,
            m: *const Limb,
            num_limbs: c::size_t,
        );
    }
    unsafe {
        LIMBS_sub_mod(
            a.limbs.as_mut_ptr(),
            a.limbs.as_ptr(),
            b.limbs.as_ptr(),
            m.limbs.as_ptr(),
            m.limbs.len(),
        );
    }
    a
}

// The value 1, Montgomery-encoded some number of times.
pub struct One<M, E>(Elem<M, E>);

impl<M> One<M, RR> {
    // Returns RR = R**2 (mod n) where R = 2**r is the smallest power of
    // 2**LIMB_BITS such that R > m.
    //
    // Even though the assembly on some 32-bit platforms works with 64-bit
    // values, using `LIMB_BITS` here, rather than `N0_LIMBS_USED * LIMB_BITS`,
    // is correct because R**2 will still be a multiple of the latter as
    // `N0_LIMBS_USED` is either one or two.
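    //
    // A toy illustration (our own example, not part of the algorithm): with
    // 8-bit limbs and m = 239, m_bits = 8 gives r = 8, so R = 2**8 = 256 ≡ 17
    // (mod m) and RR = R**2 ≡ 17**2 = 289 ≡ 50 (mod m).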
    fn newRR(m: &PartialModulus<M>, m_bits: bits::BitLength) -> Self {
        let m_bits = m_bits.as_usize_bits();
        let r = (m_bits + (LIMB_BITS - 1)) / LIMB_BITS * LIMB_BITS;

        // base = 2**(lg m - 1).
        let bit = m_bits - 1;
        let mut base = m.zero();
        base.limbs[bit / LIMB_BITS] = 1 << (bit % LIMB_BITS);

        // Double `base` so that base == R == 2**r (mod m). For normal moduli
        // that have the high bit of the highest limb set, this requires one
        // doubling. Unusual moduli require more doublings but we are less
        // concerned about the performance of those.
        //
        // Then double `base` again so that base == 2*R (mod n), i.e. `2` in
        // Montgomery form (`elem_exp_vartime_()` requires the base to be in
        // Montgomery form). Then compute
        // RR = R**2 == base**r == R**r == (2**r)**r (mod n).
        //
        // Take advantage of the fact that `elem_mul_by_2` is faster than
        // `elem_squared` by replacing some of the early squarings with shifts.
        // TODO: Benchmark shift vs. squaring performance to determine the
        // optimal value of `lg_base`.
        let lg_base = 2usize; // Shifts vs. squaring trade-off.
        debug_assert_eq!(lg_base.count_ones(), 1); // Must be 2**n for n >= 0.
        let shifts = r - bit + lg_base;
        let exponent = (r / lg_base) as u64;
        for _ in 0..shifts {
            elem_mul_by_2(&mut base, m)
        }
        let RR = elem_exp_vartime_(base, exponent, m);

        Self(Elem {
            limbs: RR.limbs,
            encoding: PhantomData, // PhantomData<RR>
        })
    }
}

impl<M: PublicModulus, E> Clone for One<M, E> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<M, E> AsRef<Elem<M, E>> for One<M, E> {
    fn as_ref(&self) -> &Elem<M, E> {
        &self.0
    }
}

/// A non-secret odd positive value in the range
/// [3, PUBLIC_EXPONENT_MAX_VALUE].
#[derive(Clone, Copy)]
pub struct PublicExponent(u64);

impl core::fmt::Debug for PublicExponent {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
        write!(f, "{}", self.0)
    }
}

impl PublicExponent {
    pub fn from_be_bytes(
        input: untrusted::Input,
        min_value: u64,
    ) -> Result<Self, error::KeyRejected> {
        if input.len() > 5 {
            return Err(error::KeyRejected::too_large());
        }
        let value = input.read_all(error::KeyRejected::invalid_encoding(), |input| {
            // The exponent can't be zero and it can't be prefixed with
            // zero-valued bytes.
            if input.peek(0) {
                return Err(error::KeyRejected::invalid_encoding());
            }
            let mut value = 0u64;
            loop {
                let byte = input
                    .read_byte()
                    .map_err(|untrusted::EndOfInput| error::KeyRejected::invalid_encoding())?;
                value = (value << 8) | u64::from(byte);
                if input.at_end() {
                    return Ok(value);
                }
            }
        })?;

        // Step 2 / Step b. NIST SP800-89 defers to FIPS 186-3, which requires
        // `e >= 65537`. We enforce this when signing, but are more flexible in
        // verification, for compatibility. Only small public exponents are
        // supported.
        if value & 1 != 1 {
            return Err(error::KeyRejected::invalid_component());
        }
        debug_assert!(min_value & 1 == 1);
        debug_assert!(min_value <= PUBLIC_EXPONENT_MAX_VALUE);
        if min_value < 3 {
            return Err(error::KeyRejected::invalid_component());
        }
        if value < min_value {
            return Err(error::KeyRejected::too_small());
        }
        if value > PUBLIC_EXPONENT_MAX_VALUE {
            return Err(error::KeyRejected::too_large());
        }

        Ok(Self(value))
    }

    #[inline]
    pub fn to_be_bytes(&self) -> Box<[u8]> {
        strip_leading_zeros(&u64::to_be_bytes(self.0))
    }
}

// This limit was chosen to bound the performance of the simple
// exponentiation-by-squaring implementation in `elem_exp_vartime`. In
// particular, it helps mitigate theoretical resource exhaustion attacks. 33
// bits was chosen as the limit based on the recommendations in [1] and
// [2]. Windows CryptoAPI (at least older versions) doesn't support values
// larger than 32 bits [3], so it is unlikely that exponents larger than 32
// bits are being used for anything Windows commonly does.
//
// [1] https://www.imperialviolet.org/2012/03/16/rsae.html
// [2] https://www.imperialviolet.org/2012/03/17/rsados.html
// [3] https://msdn.microsoft.com/en-us/library/aa387685(VS.85).aspx
const PUBLIC_EXPONENT_MAX_VALUE: u64 = (1u64 << 33) - 1;
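// (For scale: 65537, the largest commonly used public exponent, is only 17
// bits, well under this 33-bit cap.)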

/// Calculates base**exponent (mod m).
// TODO: The test coverage needs to be expanded, e.g. test with the largest
// accepted exponent and with the most common values of 65537 and 3.
pub fn elem_exp_vartime<M>(
    base: Elem<M, Unencoded>,
    PublicExponent(exponent): PublicExponent,
    m: &Modulus<M>,
) -> Elem<M, R> {
    let base = elem_mul(m.oneRR().as_ref(), base, m);
    elem_exp_vartime_(base, exponent, &m.as_partial())
}

/// Calculates base**exponent (mod m).
fn elem_exp_vartime_<M>(base: Elem<M, R>, exponent: u64, m: &PartialModulus<M>) -> Elem<M, R> {
    // Use what [Knuth] calls the "S-and-X binary method", i.e. variable-time
    // square-and-multiply that scans the exponent from the most significant
    // bit to the least significant bit (left-to-right). Left-to-right requires
    // less storage compared to right-to-left scanning, at the cost of needing
    // to compute `exponent.leading_zeros()`, which we assume to be cheap.
    //
    // During RSA public key operations the exponent is almost always either 65537
    // (0b10000000000000001) or 3 (0b11), both of which have a Hamming weight
    // of 2. During Montgomery setup the exponent is almost always a power of two,
    // with Hamming weight 1. As explained in [Knuth], exponentiation by squaring
    // is the most efficient algorithm when the Hamming weight is 2 or less. It
    // isn't the most efficient for all other, uncommon, exponent values but any
    // suboptimality is bounded by `PUBLIC_EXPONENT_MAX_VALUE`.
    //
    // This implementation is slightly simplified by taking advantage of the
    // fact that we require the exponent to be a positive integer.
    //
    // [Knuth]: The Art of Computer Programming, Volume 2: Seminumerical
    // Algorithms (3rd Edition), Section 4.6.3.
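    //
    // A worked count (our own illustration): for exponent == 65537, `bit`
    // starts at 1 << 16, so the loop below runs 16 times; every iteration
    // squares `acc` and only the last one (bit 0 is set) also multiplies by
    // `base`, giving 16 squarings and 1 multiplication.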
    assert!(exponent >= 1);
    assert!(exponent <= PUBLIC_EXPONENT_MAX_VALUE);
    let mut acc = base.clone();
    let mut bit = 1 << (64 - 1 - exponent.leading_zeros());
    debug_assert!((exponent & bit) != 0);
    while bit > 1 {
        bit >>= 1;
        acc = elem_squared(acc, m);
        if (exponent & bit) != 0 {
            acc = elem_mul_(&base, acc, m);
        }
    }
    acc
}

// `M` represents the prime modulus for which the exponent is in the interval
// [1, `m` - 1).
pub struct PrivateExponent<M> {
    limbs: BoxedLimbs<M>,
}

impl<M> PrivateExponent<M> {
    pub fn from_be_bytes_padded(
        input: untrusted::Input,
        p: &Modulus<M>,
    ) -> Result<Self, error::Unspecified> {
        let dP = BoxedLimbs::from_be_bytes_padded_less_than(input, p)?;

        // Proof that `dP < p - 1`:
        //
        // If `dP < p` then either `dP == p - 1` or `dP < p - 1`. Since `p` is
        // odd, `p - 1` is even. `d` is odd, and an odd number modulo an even
        // number is odd. Therefore `dP` must be odd. But then it cannot be
        // `p - 1` and so we know `dP < p - 1`.
        //
        // Further we know `dP != 0` because `dP` is not even.
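        //
        // (A toy check of the argument, our own example: for odd p = 7,
        // p - 1 = 6 is even; any odd dP < 7 is at most 5, so dP < p - 1.)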
        if limb::limbs_are_even_constant_time(&dP) != LimbMask::False {
            return Err(error::Unspecified);
        }

        Ok(Self { limbs: dP })
    }
}

impl<M: Prime> PrivateExponent<M> {
    // Returns `p - 2`.
    fn for_flt(p: &Modulus<M>) -> Self {
        let two = elem_add(p.one(), p.one(), p);
        let p_minus_2 = elem_sub(p.zero(), &two, p);
        Self {
            limbs: p_minus_2.limbs,
        }
    }
}

#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs.len();

    let mut table = vec![0; TABLE_ENTRIES * num_limbs];

    fn gather<M>(table: &[Limb], i: Window, r: &mut Elem<M, R>) {
        prefixed_extern! {
            fn LIMBS_select_512_32(
                r: *mut Limb,
                table: *const Limb,
                num_limbs: c::size_t,
                i: Window,
            ) -> bssl::Result;
        }
        Result::from(unsafe {
            LIMBS_select_512_32(r.limbs.as_mut_ptr(), table.as_ptr(), r.limbs.len(), i)
        })
        .unwrap();
    }

    fn power<M>(
        table: &[Limb],
        i: Window,
        mut acc: Elem<M, R>,
        mut tmp: Elem<M, R>,
        m: &Modulus<M>,
    ) -> (Elem<M, R>, Elem<M, R>) {
        for _ in 0..WINDOW_BITS {
            acc = elem_squared(acc, &m.as_partial());
        }
        gather(table, i, &mut tmp);
        let acc = elem_mul(&tmp, acc, m);
        (acc, tmp)
    }

    let tmp = m.one();
    let tmp = elem_mul(m.oneRR().as_ref(), tmp, m);

    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    // Fill `table` so that `table[i]` holds `base**i` (Montgomery-encoded):
    // even entries are computed by squaring `table[i / 2]`, odd entries by
    // multiplying `table[i - 1]` by `table[1]`.
    entry_mut(&mut table, 0, num_limbs).copy_from_slice(&tmp.limbs);
    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
    for i in 2..TABLE_ENTRIES {
        let (src1, src2) = if i % 2 == 0 {
            (i / 2, i / 2)
        } else {
            (i - 1, 1)
        };
        let (previous, rest) = table.split_at_mut(num_limbs * i);
        let src1 = entry(previous, src1, num_limbs);
        let src2 = entry(previous, src2, num_limbs);
        let dst = entry_mut(rest, 0, num_limbs);
        limbs_mont_product(dst, src1, src2, &m.limbs, &m.n0);
    }

    let (r, _) = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            let mut r = Elem {
                limbs: base.limbs,
                encoding: PhantomData,
            };
            gather(&table, initial_window, &mut r);
            (r, tmp)
        },
        |(acc, tmp), window| power(&table, window, acc, tmp, m),
    );

    let r = r.into_unencoded(m);

    Ok(r)
}

/// Uses Fermat's Little Theorem to calculate modular inverse in constant time.
pub fn elem_inverse_consttime<M: Prime>(
    a: Elem<M, R>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    elem_exp_consttime(a, &PrivateExponent::for_flt(m), m)
}

#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
    base: Elem<M, R>,
    exponent: &PrivateExponent<M>,
    m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
    // The x86_64 assembly was written under the assumption that the input data
    // is aligned to `MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH` bytes, which was/is
    // 64 in OpenSSL. Similarly, OpenSSL uses the x86_64 assembly functions by
    // giving it only inputs `tmp`, `am`, and `np` that immediately follow the
    // table. The code seems to "work" even when the inputs aren't exactly
    // like that but the side channel defenses might not be as effective. All
    // the awkwardness here stems from trying to use the assembly code like
    // OpenSSL does.

    use crate::limb::Window;

    const WINDOW_BITS: usize = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

    let num_limbs = m.limbs.len();

    const ALIGNMENT: usize = 64;
    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
    let (table, state) = {
        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
        table.split_at_mut(TABLE_ENTRIES * num_limbs)
    };

    fn entry(table: &[Limb], i: usize, num_limbs: usize) -> &[Limb] {
        &table[(i * num_limbs)..][..num_limbs]
    }
    fn entry_mut(table: &mut [Limb], i: usize, num_limbs: usize) -> &mut [Limb] {
        &mut table[(i * num_limbs)..][..num_limbs]
    }

    const ACC: usize = 0; // `tmp` in OpenSSL
    const BASE: usize = ACC + 1; // `am` in OpenSSL
    const M: usize = BASE + 1; // `np` in OpenSSL

    entry_mut(state, BASE, num_limbs).copy_from_slice(&base.limbs);
    entry_mut(state, M, num_limbs).copy_from_slice(&m.limbs);

    fn scatter(table: &mut [Limb], state: &[Limb], i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_scatter5(a: *const Limb, a_len: c::size_t, table: *mut Limb, i: Window);
        }
        unsafe {
            bn_scatter5(
                entry(state, ACC, num_limbs).as_ptr(),
                num_limbs,
                table.as_mut_ptr(),
                i,
            )
        }
    }

    fn gather(table: &[Limb], state: &mut [Limb], i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_gather5(r: *mut Limb, a_len: c::size_t, table: *const Limb, i: Window);
        }
        unsafe {
            bn_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                num_limbs,
                table.as_ptr(),
                i,
            )
        }
    }

    fn gather_square(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        gather(table, state, i, num_limbs);
        assert_eq!(ACC, 0);
        let (acc, rest) = state.split_at_mut(num_limbs);
        let m = entry(rest, M - 1, num_limbs);
        limbs_mont_square(acc, m, n0);
    }

    fn gather_mul_base(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_mul_mont_gather5(
                rp: *mut Limb,
                ap: *const Limb,
                table: *const Limb,
                np: *const Limb,
                n0: &N0,
                num: c::size_t,
                power: Window,
            );
        }
        unsafe {
            bn_mul_mont_gather5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry(state, BASE, num_limbs).as_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    fn power(table: &[Limb], state: &mut [Limb], n0: &N0, i: Window, num_limbs: usize) {
        prefixed_extern! {
            fn bn_power5(
                r: *mut Limb,
                a: *const Limb,
                table: *const Limb,
                n: *const Limb,
                n0: &N0,
                num: c::size_t,
                i: Window,
            );
        }
        unsafe {
            bn_power5(
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                entry_mut(state, ACC, num_limbs).as_mut_ptr(),
                table.as_ptr(),
                entry(state, M, num_limbs).as_ptr(),
                n0,
                num_limbs,
                i,
            );
        }
    }

    // table[0] = base**0.
    {
        let acc = entry_mut(state, ACC, num_limbs);
        acc[0] = 1;
        limbs_mont_mul(acc, &m.oneRR.0.limbs, &m.limbs, &m.n0);
    }
    scatter(table, state, 0, num_limbs);

    // table[1] = base**1.
    entry_mut(state, ACC, num_limbs).copy_from_slice(&base.limbs);
    scatter(table, state, 1, num_limbs);

    for i in 2..(TABLE_ENTRIES as Window) {
        if i % 2 == 0 {
            // TODO: Optimize this to avoid gathering
            gather_square(table, state, &m.n0, i / 2, num_limbs);
        } else {
            gather_mul_base(table, state, &m.n0, i - 1, num_limbs)
        };
        scatter(table, state, i, num_limbs);
    }

    let state = limb::fold_5_bit_windows(
        &exponent.limbs,
        |initial_window| {
            gather(table, state, initial_window, num_limbs);
            state
        },
        |state, window| {
            power(table, state, &m.n0, window, num_limbs);
            state
        },
    );

    prefixed_extern! {
        fn bn_from_montgomery(
            r: *mut Limb,
            a: *const Limb,
            not_used: *const Limb,
            n: *const Limb,
            n0: &N0,
            num: c::size_t,
        ) -> bssl::Result;
    }
    Result::from(unsafe {
        bn_from_montgomery(
            entry_mut(state, ACC, num_limbs).as_mut_ptr(),
            entry(state, ACC, num_limbs).as_ptr(),
            core::ptr::null(),
            entry(state, M, num_limbs).as_ptr(),
            &m.n0,
            num_limbs,
        )
    })?;
    let mut r = Elem {
        limbs: base.limbs,
        encoding: PhantomData,
    };
    r.limbs.copy_from_slice(entry(state, ACC, num_limbs));
    Ok(r)
}

/// Verifies a == b**-1 (mod m), i.e. a**-1 == b (mod m).
pub fn verify_inverses_consttime<M>(
    a: &Elem<M, R>,
    b: Elem<M, Unencoded>,
    m: &Modulus<M>,
) -> Result<(), error::Unspecified> {
    if elem_mul(a, b, m).is_one() {
        Ok(())
    } else {
        Err(error::Unspecified)
    }
}

#[inline]
pub fn elem_verify_equal_consttime<M, E>(
    a: &Elem<M, E>,
    b: &Elem<M, E>,
) -> Result<(), error::Unspecified> {
    if limb::limbs_equal_limbs_consttime(&a.limbs, &b.limbs) == LimbMask::True {
        Ok(())
    } else {
        Err(error::Unspecified)
    }
}

/// Nonnegative integers.
pub struct Nonnegative {
    limbs: Vec<Limb>,
}

impl Nonnegative {
    pub fn from_be_bytes_with_bit_length(
        input: untrusted::Input,
    ) -> Result<(Self, bits::BitLength), error::Unspecified> {
        let mut limbs = vec![0; (input.len() + LIMB_BYTES - 1) / LIMB_BYTES];
        // Rejects empty inputs.
        limb::parse_big_endian_and_pad_consttime(input, &mut limbs)?;
        while limbs.last() == Some(&0) {
            let _ = limbs.pop();
        }
        let r_bits = limb::limbs_minimal_bits(&limbs);
        Ok((Self { limbs }, r_bits))
    }

    #[inline]
    pub fn is_odd(&self) -> bool {
        limb::limbs_are_even_constant_time(&self.limbs) != LimbMask::True
    }

    pub fn verify_less_than(&self, other: &Self) -> Result<(), error::Unspecified> {
        if !greater_than(other, self) {
            return Err(error::Unspecified);
        }
        Ok(())
    }

    pub fn to_elem<M>(&self, m: &Modulus<M>) -> Result<Elem<M, Unencoded>, error::Unspecified> {
        self.verify_less_than_modulus(m)?;
        let mut r = m.zero();
        r.limbs[0..self.limbs.len()].copy_from_slice(&self.limbs);
        Ok(r)
    }

    pub fn verify_less_than_modulus<M>(&self, m: &Modulus<M>) -> Result<(), error::Unspecified> {
        if self.limbs.len() > m.limbs.len() {
            return Err(error::Unspecified);
        }
        if self.limbs.len() == m.limbs.len() {
            if limb::limbs_less_than_limbs_consttime(&self.limbs, &m.limbs) != LimbMask::True {
                return Err(error::Unspecified);
            }
        }
        Ok(())
    }
}

// Returns a > b.
fn greater_than(a: &Nonnegative, b: &Nonnegative) -> bool {
    if a.limbs.len() == b.limbs.len() {
        limb::limbs_less_than_limbs_vartime(&b.limbs, &a.limbs)
    } else {
        a.limbs.len() > b.limbs.len()
    }
}

#[derive(Clone)]
#[repr(transparent)]
struct N0([Limb; 2]);

const N0_LIMBS_USED: usize = 64 / LIMB_BITS;

impl From<u64> for N0 {
    #[inline]
    fn from(n0: u64) -> Self {
        #[cfg(target_pointer_width = "64")]
        {
            Self([n0, 0])
        }

        #[cfg(target_pointer_width = "32")]
        {
            Self([n0 as Limb, (n0 >> LIMB_BITS) as Limb])
        }
    }
}

/// r *= a
fn limbs_mont_mul(r: &mut [Limb], a: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            a.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}

fn limbs_from_mont_in_place(r: &mut [Limb], tmp: &mut [Limb], m: &[Limb], n0: &N0) {
    prefixed_extern! {
        fn bn_from_montgomery_in_place(
            r: *mut Limb,
            num_r: c::size_t,
            a: *mut Limb,
            num_a: c::size_t,
            n: *const Limb,
            num_n: c::size_t,
            n0: &N0,
        ) -> bssl::Result;
    }
    Result::from(unsafe {
        bn_from_montgomery_in_place(
            r.as_mut_ptr(),
            r.len(),
            tmp.as_mut_ptr(),
            tmp.len(),
            m.as_ptr(),
            m.len(),
            n0,
        )
    })
    .unwrap()
}

#[cfg(not(any(
    target_arch = "aarch64",
    target_arch = "arm",
    target_arch = "x86",
    target_arch = "x86_64"
)))]
fn limbs_mul(r: &mut [Limb], a: &[Limb], b: &[Limb]) {
    debug_assert_eq!(r.len(), 2 * a.len());
    debug_assert_eq!(a.len(), b.len());
    let ab_len = a.len();

    crate::polyfill::slice::fill(&mut r[..ab_len], 0);
    for (i, &b_limb) in b.iter().enumerate() {
        r[ab_len + i] = unsafe {
            limbs_mul_add_limb(
                (&mut r[i..][..ab_len]).as_mut_ptr(),
                a.as_ptr(),
                b_limb,
                ab_len,
            )
        };
    }
}

/// r = a * b
#[cfg(not(target_arch = "x86_64"))]
fn limbs_mont_product(r: &mut [Limb], a: &[Limb], b: &[Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    debug_assert_eq!(a.len(), m.len());
    debug_assert_eq!(b.len(), m.len());

    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            a.as_ptr(),
            b.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}

/// r = r**2
fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0) {
    debug_assert_eq!(r.len(), m.len());
    unsafe {
        bn_mul_mont(
            r.as_mut_ptr(),
            r.as_ptr(),
            r.as_ptr(),
            m.as_ptr(),
            n0,
            r.len(),
        )
    }
}

prefixed_extern! {
    // `r` and/or 'a' and/or 'b' may alias.
    fn bn_mul_mont(
        r: *mut Limb,
        a: *const Limb,
        b: *const Limb,
        n: *const Limb,
        n0: &N0,
        num_limbs: c::size_t,
    );
}

#[cfg(any(
    test,
    not(any(
        target_arch = "aarch64",
        target_arch = "arm",
        target_arch = "x86_64",
        target_arch = "x86"
    ))
))]
prefixed_extern! {
    // `r` must not alias `a`
    #[must_use]
    fn limbs_mul_add_limb(r: *mut Limb, a: *const Limb, b: Limb, num_limbs: c::size_t) -> Limb;
}

fn strip_leading_zeros(value: &[u8]) -> Box<[u8]> {
    fn index_after_zeros(bytes: &[u8]) -> usize {
        for (i, &value) in bytes.iter().enumerate() {
            if value != 0 {
                return i;
            }
        }
        bytes.len()
    }
    (&value[index_after_zeros(value)..]).into()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use alloc::format;

    // Type-level representation of an arbitrary modulus.
    struct M {}

    unsafe impl PublicModulus for M {}

    #[test]
    fn test_elem_exp_consttime() {
        test::run(
            test_file!("bigint_elem_exp_consttime_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModExp", &m);
                let base = consume_elem(test_case, "A", &m);
                let e = {
                    let bytes = test_case.consume_bytes("E");
                    PrivateExponent::from_be_bytes_padded(untrusted::Input::from(&bytes), &m)
                        .expect("valid exponent")
                };
                let base = into_encoded(base, &m);
                let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    // TODO: fn test_elem_exp_vartime() using
    // "src/rsa/bigint_elem_exp_vartime_tests.txt". See that file for details.
    // In the meantime, the function is tested indirectly via the RSA
    // verification and signing tests.
    #[test]
    fn test_elem_mul() {
        test::run(
            test_file!("bigint_elem_mul_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModMul", &m);
                let a = consume_elem(test_case, "A", &m);
                let b = consume_elem(test_case, "B", &m);

                let b = into_encoded(b, &m);
                let a = into_encoded(a, &m);
                let actual_result = elem_mul(&a, b, &m);
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_squared() {
        test::run(
            test_file!("bigint_elem_squared_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "ModSquare", &m);
                let a = consume_elem(test_case, "A", &m);

                let a = into_encoded(a, &m);
                let actual_result = elem_squared(a, &m.as_partial());
                let actual_result = actual_result.into_unencoded(&m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced() {
        test::run(
            test_file!("bigint_elem_reduced_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct MM {}
                unsafe impl SmallerModulus<MM> for M {}
                unsafe impl NotMuchSmallerModulus<MM> for M {}

                let m = consume_modulus::<M>(test_case, "M");
                let expected_result = consume_elem(test_case, "R", &m);
                let a =
                    consume_elem_unchecked::<MM>(test_case, "A", expected_result.limbs.len() * 2);

                let actual_result = elem_reduced(&a, &m);
                let oneRR = m.oneRR();
                let actual_result = elem_mul(oneRR.as_ref(), actual_result, &m);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_elem_reduced_once() {
        test::run(
            test_file!("bigint_elem_reduced_once_tests.txt"),
            |section, test_case| {
                assert_eq!(section, "");

                struct N {}
                struct QQ {}
                unsafe impl SmallerModulus<N> for QQ {}
                unsafe impl SlightlySmallerModulus<N> for QQ {}

                let qq = consume_modulus::<QQ>(test_case, "QQ");
                let expected_result = consume_elem::<QQ>(test_case, "R", &qq);
                let n = consume_modulus::<N>(test_case, "N");
                let a = consume_elem::<N>(test_case, "A", &n);

                let actual_result = elem_reduced_once(&a, &qq);
                assert_elem_eq(&actual_result, &expected_result);

                Ok(())
            },
        )
    }

    #[test]
    fn test_modulus_debug() {
        let (modulus, _) =
            Modulus::<M>::from_be_bytes_with_bit_length(untrusted::Input::from(&[0xff; 1024 / 8]))
                .unwrap();
        assert_eq!(
            "Modulus(\"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff\")",
            format!("{:?}", modulus)
        );
    }

    #[test]
    fn test_public_exponent_debug() {
        let exponent =
            PublicExponent::from_be_bytes(untrusted::Input::from(&[0x1, 0x00, 0x01]), 65537)
                .unwrap();
        assert_eq!("65537", format!("{:?}", exponent));
    }

    fn consume_elem<M>(
        test_case: &mut test::TestCase,
        name: &str,
        m: &Modulus<M>,
    ) -> Elem<M, Unencoded> {
        let value = test_case.consume_bytes(name);
        Elem::from_be_bytes_padded(untrusted::Input::from(&value), m).unwrap()
    }

    fn consume_elem_unchecked<M>(
        test_case: &mut test::TestCase,
        name: &str,
        num_limbs: usize,
    ) -> Elem<M, Unencoded> {
        let value = consume_nonnegative(test_case, name);
        let mut limbs = BoxedLimbs::zero(Width {
            num_limbs,
            m: PhantomData,
        });
        limbs[0..value.limbs.len()].copy_from_slice(&value.limbs);
        Elem {
            limbs,
            encoding: PhantomData,
        }
    }

    fn consume_modulus<M>(test_case: &mut test::TestCase, name: &str) -> Modulus<M> {
        let value = test_case.consume_bytes(name);
        let (value, _) =
            Modulus::from_be_bytes_with_bit_length(untrusted::Input::from(&value)).unwrap();
        value
    }

    fn consume_nonnegative(test_case: &mut test::TestCase, name: &str) -> Nonnegative {
        let bytes = test_case.consume_bytes(name);
        let (r, _r_bits) =
            Nonnegative::from_be_bytes_with_bit_length(untrusted::Input::from(&bytes)).unwrap();
        r
    }

    fn assert_elem_eq<M, E>(a: &Elem<M, E>, b: &Elem<M, E>) {
        if elem_verify_equal_consttime(&a, b).is_err() {
            panic!("{:x?} != {:x?}", &*a.limbs, &*b.limbs);
        }
    }

    fn into_encoded<M>(a: Elem<M, Unencoded>, m: &Modulus<M>) -> Elem<M, R> {
        elem_mul(m.oneRR().as_ref(), a, m)
    }

    #[test]
    // TODO: wasm
    fn test_mul_add_words() {
        const ZERO: Limb = 0;
        const MAX: Limb = ZERO.wrapping_sub(1);
        static TEST_CASES: &[(&[Limb], &[Limb], Limb, Limb, &[Limb])] = &[
            (&[0], &[0], 0, 0, &[0]),
            (&[MAX], &[0], MAX, 0, &[MAX]),
            (&[0], &[MAX], MAX, MAX - 1, &[1]),
            (&[MAX], &[MAX], MAX, MAX, &[0]),
            (&[0, 0], &[MAX, MAX], MAX, MAX - 1, &[1, MAX]),
            (&[1, 0], &[MAX, MAX], MAX, MAX - 1, &[2, MAX]),
            (&[MAX, 0], &[MAX, MAX], MAX, MAX, &[0, 0]),
            (&[0, 1], &[MAX, MAX], MAX, MAX, &[1, 0]),
            (&[MAX, MAX], &[MAX, MAX], MAX, MAX, &[0, MAX]),
        ];

        for (i, (r_input, a, w, expected_retval, expected_r)) in TEST_CASES.iter().enumerate() {
            extern crate std;
            let mut r = std::vec::Vec::from(*r_input);
            assert_eq!(r.len(), a.len()); // Sanity check
            let actual_retval =
                unsafe { limbs_mul_add_limb(r.as_mut_ptr(), a.as_ptr(), *w, a.len()) };
            assert_eq!(&r, expected_r, "{}: {:x?} != {:x?}", i, &r[..], expected_r);
            assert_eq!(
                actual_retval, *expected_retval,
                "{}: {:x?} != {:x?}",
                i, actual_retval, *expected_retval
            );
        }
    }
}