• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2014 The BoringSSL Authors
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15#include <openssl/base.h>
16
17#include <memory>
18
19#include <assert.h>
20#include <stdlib.h>
21
22#include <openssl/bytestring.h>
23#include <openssl/mem.h>
24#include <openssl/rand.h>
25
26#include "../../internal.h"
27#include "../bcm_interface.h"
28#include "../keccak/internal.h"
29
30namespace mldsa {
31namespace {
32
namespace fips {
// Self-test hooks for key generation, signing and verification. Only their
// declarations appear here; the definitions are elsewhere (not in this chunk).
void ensure_keygen_self_test();
void ensure_sign_self_test();
void ensure_verify_self_test();
}  // namespace fips
38
constexpr int kDegree = 256;
constexpr int kRhoBytes = 32;
constexpr int kSigmaBytes = 64;
constexpr int kKBytes = 32;
constexpr int kTrBytes = 64;
constexpr int kMuBytes = 64;
constexpr int kRhoPrimeBytes = 64;

// The ML-DSA modulus q = 2^23 - 2^13 + 1.
constexpr uint32_t kPrime = 8380417;
// Inverse of -kPrime modulo 2^32; used by Montgomery reduction with R = 2^32.
constexpr uint32_t kPrimeNegInverse = 4236238847;
// Number of low bits dropped by Power2Round (d in FIPS 204).
constexpr int kDroppedBits = 13;
constexpr uint32_t kHalfPrime = (kPrime - 1) / 2;
constexpr uint32_t kGamma2 = (kPrime - 1) / 32;
// 256^-1 * R^2 mod kPrime with R = 2^32, i.e. 2^56 mod kPrime. Note this is
// NOT 256^-1 in Montgomery form (that would be 256^-1 * R mod kPrime = 16382).
// A Montgomery multiplication by this constant, |reduce_montgomery(x * 41978)|,
// yields x * R / 256 mod kPrime: it divides by the degree and also multiplies
// by R (presumably to cancel the R^-1 left by a preceding Montgomery product).
constexpr uint32_t kInverseDegreeMontgomery = 41978;

static_assert(kPrime == (1u << 23) - (1u << 13) + 1,
              "kPrime must be 2^23 - 2^13 + 1");
static_assert(static_cast<uint32_t>(kPrimeNegInverse * kPrime) == 0xffffffffu,
              "kPrimeNegInverse must be -kPrime^-1 mod 2^32");
static_assert(kGamma2 == (1u << 18) - (1u << 8),
              "high_bits assumes kGamma2 == 2^18 - 2^8");
static_assert((uint64_t{1} << 56) % kPrime == kInverseDegreeMontgomery,
              "kInverseDegreeMontgomery must be 2^56 mod kPrime");
56
57// Constants that vary depending on ML-DSA size.
58//
59// These are implemented as templates which take the K parameter to distinguish
60// the ML-DSA sizes.
61
62template <int K>
63constexpr size_t public_key_bytes() {
64  if constexpr (K == 6) {
65    return BCM_MLDSA65_PUBLIC_KEY_BYTES;
66  } else if constexpr (K == 8) {
67    return BCM_MLDSA87_PUBLIC_KEY_BYTES;
68  }
69}
70
71template <int K>
72constexpr size_t signature_bytes() {
73  if constexpr (K == 6) {
74    return BCM_MLDSA65_SIGNATURE_BYTES;
75  } else if constexpr (K == 8) {
76    return BCM_MLDSA87_SIGNATURE_BYTES;
77  }
78}
79
// tau: number of nonzero (+-1) coefficients in the challenge polynomial.
template <int K>
constexpr int tau() {
  // Reject unsupported K at compile time instead of falling off the end.
  static_assert(K == 6 || K == 8, "unsupported ML-DSA parameter set");
  if constexpr (K == 6) {
    return 49;
  } else {
    return 60;
  }
}
88
// lambda / 8: byte length of the commitment hash (lambda is the collision
// strength in bits).
template <int K>
constexpr int lambda_bytes() {
  // Reject unsupported K at compile time instead of falling off the end.
  static_assert(K == 6 || K == 8, "unsupported ML-DSA parameter set");
  if constexpr (K == 6) {
    return 192 / 8;
  } else {
    return 256 / 8;
  }
}
97
// gamma1: coefficient range of the masking vector y.
template <int K>
constexpr int gamma1() {
  // Reject unsupported K at compile time instead of falling off the end.
  static_assert(K == 6 || K == 8, "unsupported ML-DSA parameter set");
  // Both supported parameter sets use gamma1 = 2^19.
  return 1 << 19;
}
104
// beta = tau * eta.
template <int K>
constexpr int beta() {
  // Reject unsupported K at compile time instead of falling off the end.
  static_assert(K == 6 || K == 8, "unsupported ML-DSA parameter set");
  if constexpr (K == 6) {
    return 196;
  } else {
    return 120;
  }
}
113
// omega: maximum number of 1s allowed in the hint.
template <int K>
constexpr int omega() {
  // Reject unsupported K at compile time instead of falling off the end.
  static_assert(K == 6 || K == 8, "unsupported ML-DSA parameter set");
  if constexpr (K == 6) {
    return 55;
  } else {
    return 75;
  }
}
122
// eta: bound on the coefficients of the secret vectors s1 and s2.
template <int K>
constexpr int eta() {
  // Reject unsupported K at compile time instead of falling off the end.
  static_assert(K == 6 || K == 8, "unsupported ML-DSA parameter set");
  if constexpr (K == 6) {
    return 4;
  } else {
    return 2;
  }
}
131
// Bit width of one packed coefficient in [-eta, eta], i.e. bitlen(2 * eta).
template <int K>
constexpr int plus_minus_eta_bitlen() {
  // Reject unsupported K at compile time instead of falling off the end.
  static_assert(K == 6 || K == 8, "unsupported ML-DSA parameter set");
  if constexpr (K == 6) {
    return 4;
  } else {
    return 3;
  }
}
140
141// Fundamental types.
142
143typedef struct scalar {
144  uint32_t c[kDegree];
145} scalar;
146
147template <int K>
148struct vector {
149  scalar v[K];
150};
151
152template <int K, int L>
153struct matrix {
154  scalar v[K][L];
155};
156
157/* Arithmetic */
158
// Several comments below refer to this Python model:
//
// q = 8380417
// # Inverse of -q modulo 2^32
// q_neg_inverse = 4236238847
// # 2^64 modulo q
// montgomery_square = 2365951
//
// def bitreverse(i):
//     ret = 0
//     for n in range(8):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret
//
// def montgomery_reduce(x):
//     a = (x * q_neg_inverse) % 2**32
//     b = x + a * q
//     assert b & 0xFFFF_FFFF == 0
//     c = b >> 32
//     assert c < q
//     return c
//
// def montgomery_transform(x):
//     return montgomery_reduce(x * montgomery_square)

// Powers of 1753 (the 512th root of unity used by FIPS 204) in bit-reversed
// order and Montgomery form:
//
// kNTTRootsMontgomery = [
//   montgomery_transform(pow(1753, bitreverse(i), q)) for i in range(256)
// ]
//
// |scalar_ntt| and |scalar_inverse_ntt| only read entries 1..255.
constexpr uint32_t kNTTRootsMontgomery[256] = {
    4193792, 25847,   5771523, 7861508, 237124,  7602457, 7504169, 466468,
    1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
    2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868,
    6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005,
    2706023, 95776,   3077325, 3530437, 6718724, 4788269, 5842901, 3915439,
    4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118,
    6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596,
    811944,  531354,  954230,  3881043, 3900724, 5823537, 2071892, 5582638,
    4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196,
    7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922,
    3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370,
    7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987,
    5037034, 264944,  508951,  3097992, 44288,   7280319, 904516,  3958618,
    4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969,  7063561,
    189548,  4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330,
    1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961,
    2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955,
    266997,  2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039,
    900702,  1859098, 909542,  819034,  495491,  6767243, 8337157, 7857917,
    7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579,
    342297,  286988,  5942594, 4108315, 3437287, 5038140, 1735879, 203044,
    2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974,
    4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447,
    7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775,
    7100756, 1917081, 5834105, 7005614, 1500165, 777191,  2235880, 3406031,
    7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136,  4603424,
    6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531,  7173032,
    5196991, 162844,  1616392, 3014001, 810149,  1652634, 4686184, 6581310,
    5341501, 3523897, 3866901, 269760,  2213111, 7404533, 1717735, 472078,
    7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524,
    5441381, 6144432, 7959518, 6094090, 183443,  7403526, 1612842, 4834730,
    7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782};
223
224// Reduces x mod kPrime in constant time, where 0 <= x < 2*kPrime.
225uint32_t reduce_once(uint32_t x) {
226  declassify_assert(x < 2 * kPrime);
227  // return x < kPrime ? x : x - kPrime;
228  return constant_time_select_int(constant_time_lt_w(x, kPrime), x, x - kPrime);
229}
230
231// Returns the absolute value in constant time.
232uint32_t abs_signed(uint32_t x) {
233  // return is_positive(x) ? x : -x;
234  // Note: MSVC doesn't like applying the unary minus operator to unsigned types
235  // (warning C4146), so we write the negation as a bitwise not plus one
236  // (assuming two's complement representation).
237  return constant_time_select_int(constant_time_lt_w(x, 0x80000000), x, 0u - x);
238}
239
240// Returns the absolute value modulo kPrime.
241uint32_t abs_mod_prime(uint32_t x) {
242  declassify_assert(x < kPrime);
243  // return x > kHalfPrime ? kPrime - x : x;
244  return constant_time_select_int(constant_time_lt_w(kHalfPrime, x), kPrime - x,
245                                  x);
246}
247
248// Returns the maximum of two values in constant time.
249uint32_t maximum(uint32_t x, uint32_t y) {
250  // return x < y ? y : x;
251  return constant_time_select_int(constant_time_lt_w(x, y), y, x);
252}
253
254uint32_t mod_sub(uint32_t a, uint32_t b) {
255  declassify_assert(a < kPrime);
256  declassify_assert(b < kPrime);
257  return reduce_once(kPrime + a - b);
258}
259
260void scalar_add(scalar *out, const scalar *lhs, const scalar *rhs) {
261  for (int i = 0; i < kDegree; i++) {
262    out->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
263  }
264}
265
266void scalar_sub(scalar *out, const scalar *lhs, const scalar *rhs) {
267  for (int i = 0; i < kDegree; i++) {
268    out->c[i] = mod_sub(lhs->c[i], rhs->c[i]);
269  }
270}
271
272uint32_t reduce_montgomery(uint64_t x) {
273  declassify_assert(x <= ((uint64_t)kPrime << 32));
274  uint64_t a = (uint32_t)x * kPrimeNegInverse;
275  uint64_t b = x + a * kPrime;
276  declassify_assert((b & 0xffffffff) == 0);
277  uint32_t c = b >> 32;
278  return reduce_once(c);
279}
280
281// Multiply two scalars in the number theoretically transformed state.
282void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
283  for (int i = 0; i < kDegree; i++) {
284    out->c[i] = reduce_montgomery((uint64_t)lhs->c[i] * (uint64_t)rhs->c[i]);
285  }
286}
287
288// In place number theoretic transform of a given scalar.
289//
290// FIPS 204, Algorithm 41 (`NTT`).
291static void scalar_ntt(scalar *s) {
292  // Step: 1, 2, 4, 8, ..., 128
293  // Offset: 128, 64, 32, 16, ..., 1
294  int offset = kDegree;
295  for (int step = 1; step < kDegree; step <<= 1) {
296    offset >>= 1;
297    int k = 0;
298    for (int i = 0; i < step; i++) {
299      assert(k == 2 * offset * i);
300      const uint32_t step_root = kNTTRootsMontgomery[step + i];
301      for (int j = k; j < k + offset; j++) {
302        uint32_t even = s->c[j];
303        // |reduce_montgomery| works on values up to kPrime*R and R > 2*kPrime.
304        // |step_root| < kPrime because it's static data. |s->c[...]| is <
305        // kPrime by the invariants of that struct.
306        uint32_t odd =
307            reduce_montgomery((uint64_t)step_root * (uint64_t)s->c[j + offset]);
308        s->c[j] = reduce_once(odd + even);
309        s->c[j + offset] = mod_sub(even, odd);
310      }
311      k += 2 * offset;
312    }
313  }
314}
315
316// In place inverse number theoretic transform of a given scalar.
317//
318// FIPS 204, Algorithm 42 (`NTT^-1`).
319void scalar_inverse_ntt(scalar *s) {
320  // Step: 128, 64, 32, 16, ..., 1
321  // Offset: 1, 2, 4, 8, ..., 128
322  int step = kDegree;
323  for (int offset = 1; offset < kDegree; offset <<= 1) {
324    step >>= 1;
325    int k = 0;
326    for (int i = 0; i < step; i++) {
327      assert(k == 2 * offset * i);
328      const uint32_t step_root =
329          kPrime - kNTTRootsMontgomery[step + (step - 1 - i)];
330      for (int j = k; j < k + offset; j++) {
331        uint32_t even = s->c[j];
332        uint32_t odd = s->c[j + offset];
333        s->c[j] = reduce_once(odd + even);
334
335        // |reduce_montgomery| works on values up to kPrime*R and R > 2*kPrime.
336        // kPrime + even < 2*kPrime because |even| < kPrime, by the invariants
337        // of that structure. Thus kPrime + even - odd < 2*kPrime because odd >=
338        // 0, because it's unsigned and less than kPrime. Lastly step_root <
339        // kPrime, because |kNTTRootsMontgomery| is static data.
340        s->c[j + offset] = reduce_montgomery((uint64_t)step_root *
341                                             (uint64_t)(kPrime + even - odd));
342      }
343      k += 2 * offset;
344    }
345  }
346  for (int i = 0; i < kDegree; i++) {
347    s->c[i] = reduce_montgomery((uint64_t)s->c[i] *
348                                (uint64_t)kInverseDegreeMontgomery);
349  }
350}
351
352template <int X>
353void vector_zero(vector<X> *out) {
354  OPENSSL_memset(out, 0, sizeof(*out));
355}
356
357template <int X>
358void vector_add(vector<X> *out, const vector<X> *lhs, const vector<X> *rhs) {
359  for (int i = 0; i < X; i++) {
360    scalar_add(&out->v[i], &lhs->v[i], &rhs->v[i]);
361  }
362}
363
364template <int X>
365void vector_sub(vector<X> *out, const vector<X> *lhs, const vector<X> *rhs) {
366  for (int i = 0; i < X; i++) {
367    scalar_sub(&out->v[i], &lhs->v[i], &rhs->v[i]);
368  }
369}
370
371template <int X>
372void vector_mult_scalar(vector<X> *out, const vector<X> *lhs,
373                        const scalar *rhs) {
374  for (int i = 0; i < X; i++) {
375    scalar_mult(&out->v[i], &lhs->v[i], rhs);
376  }
377}
378
379template <int X>
380void vector_ntt(vector<X> *a) {
381  for (int i = 0; i < X; i++) {
382    scalar_ntt(&a->v[i]);
383  }
384}
385
386template <int X>
387void vector_inverse_ntt(vector<X> *a) {
388  for (int i = 0; i < X; i++) {
389    scalar_inverse_ntt(&a->v[i]);
390  }
391}
392
393template <int K, int L>
394void matrix_mult(vector<K> *out, const matrix<K, L> *m, const vector<L> *a) {
395  vector_zero(out);
396  for (int i = 0; i < K; i++) {
397    for (int j = 0; j < L; j++) {
398      scalar product;
399      scalar_mult(&product, &m->v[i][j], &a->v[j]);
400      scalar_add(&out->v[i], &out->v[i], &product);
401    }
402  }
403}
404
405/* Rounding & hints */
406
407// FIPS 204, Algorithm 35 (`Power2Round`).
408void power2_round(uint32_t *r1, uint32_t *r0, uint32_t r) {
409  *r1 = r >> kDroppedBits;
410  *r0 = r - (*r1 << kDroppedBits);
411
412  uint32_t r0_adjusted = mod_sub(*r0, 1 << kDroppedBits);
413  uint32_t r1_adjusted = *r1 + 1;
414
415  // Mask is set iff r0 > 2^(dropped_bits - 1).
416  crypto_word_t mask =
417      constant_time_lt_w((uint32_t)(1 << (kDroppedBits - 1)), *r0);
418  // r0 = mask ? r0_adjusted : r0
419  *r0 = constant_time_select_int(mask, r0_adjusted, *r0);
420  // r1 = mask ? r1_adjusted : r1
421  *r1 = constant_time_select_int(mask, r1_adjusted, *r1);
422}
423
424// Scale back previously rounded value.
425void scale_power2_round(uint32_t *out, uint32_t r1) {
426  // Pre-condition: 0 <= r1 <= 2^10 - 1
427  assert(r1 < (1u << 10));
428
429  *out = r1 << kDroppedBits;
430
431  // Post-condition: 0 <= out <= 2^23 - 2^13 = kPrime - 1
432  assert(*out < kPrime);
433}
434
// FIPS 204, Algorithm 37 (`HighBits`).
//
// Reference description (given 0 <= x < q):
//
// ```
// int32_t r0 = x mod+- (2 * kGamma2);
// if (x - r0 == q - 1) {
//   return 0;
// } else {
//   return (x - r0) / (2 * kGamma2);
// }
// ```
//
// The code uses the branch-free formula from the reference implementation,
// valid for kGamma2 == 2^18 - 2^8:
//   ((ceil(x / 2^7) * (2^10 + 1) + 2^21) / 2^22) mod 2^4
uint32_t high_bits(uint32_t x) {
  const uint32_t ceil_div_128 = (x + 127) >> 7;
  const uint32_t rounded = (ceil_div_128 * 1025 + (1u << 21)) >> 22;
  return rounded & 15;
}
457
458// FIPS 204, Algorithm 36 (`Decompose`).
459void decompose(uint32_t *r1, int32_t *r0, uint32_t r) {
460  *r1 = high_bits(r);
461
462  *r0 = r;
463  *r0 -= *r1 * 2 * (int32_t)kGamma2;
464  *r0 -= (((int32_t)kHalfPrime - *r0) >> 31) & (int32_t)kPrime;
465}
466
467// FIPS 204, Algorithm 38 (`LowBits`).
468int32_t low_bits(uint32_t x) {
469  uint32_t r1;
470  int32_t r0;
471  decompose(&r1, &r0, x);
472  return r0;
473}
474
475// FIPS 204, Algorithm 39 (`MakeHint`).
476//
477// In the spec this takes two arguments, z and r, and is called with
478//   z = -ct0
479//   r = w - cs2 + ct0
480//
481// It then computes HighBits (algorithm 37) of z and z+r. But z+r is just w -
482// cs2, so this takes three arguments and saves an addition.
483int32_t make_hint(uint32_t ct0, uint32_t cs2, uint32_t w) {
484  uint32_t r_plus_z = mod_sub(w, cs2);
485  uint32_t r = reduce_once(r_plus_z + ct0);
486  return high_bits(r) != high_bits(r_plus_z);
487}
488
489// FIPS 204, Algorithm 40 (`UseHint`).
490uint32_t use_hint_vartime(uint32_t h, uint32_t r) {
491  uint32_t r1;
492  int32_t r0;
493  decompose(&r1, &r0, r);
494
495  if (h) {
496    if (r0 > 0) {
497      // m = 16, thus |mod m| in the spec turns into |& 15|.
498      return (r1 + 1) & 15;
499    } else {
500      return (r1 - 1) & 15;
501    }
502  }
503  return r1;
504}
505
506void scalar_power2_round(scalar *s1, scalar *s0, const scalar *s) {
507  for (int i = 0; i < kDegree; i++) {
508    power2_round(&s1->c[i], &s0->c[i], s->c[i]);
509  }
510}
511
512void scalar_scale_power2_round(scalar *out, const scalar *in) {
513  for (int i = 0; i < kDegree; i++) {
514    scale_power2_round(&out->c[i], in->c[i]);
515  }
516}
517
518void scalar_high_bits(scalar *out, const scalar *in) {
519  for (int i = 0; i < kDegree; i++) {
520    out->c[i] = high_bits(in->c[i]);
521  }
522}
523
524void scalar_low_bits(scalar *out, const scalar *in) {
525  for (int i = 0; i < kDegree; i++) {
526    out->c[i] = low_bits(in->c[i]);
527  }
528}
529
530void scalar_max(uint32_t *max, const scalar *s) {
531  for (int i = 0; i < kDegree; i++) {
532    uint32_t abs = abs_mod_prime(s->c[i]);
533    *max = maximum(*max, abs);
534  }
535}
536
537void scalar_max_signed(uint32_t *max, const scalar *s) {
538  for (int i = 0; i < kDegree; i++) {
539    uint32_t abs = abs_signed(s->c[i]);
540    *max = maximum(*max, abs);
541  }
542}
543
544void scalar_make_hint(scalar *out, const scalar *ct0, const scalar *cs2,
545                      const scalar *w) {
546  for (int i = 0; i < kDegree; i++) {
547    out->c[i] = make_hint(ct0->c[i], cs2->c[i], w->c[i]);
548  }
549}
550
551void scalar_use_hint_vartime(scalar *out, const scalar *h, const scalar *r) {
552  for (int i = 0; i < kDegree; i++) {
553    out->c[i] = use_hint_vartime(h->c[i], r->c[i]);
554  }
555}
556
557template <int X>
558void vector_power2_round(vector<X> *t1, vector<X> *t0, const vector<X> *t) {
559  for (int i = 0; i < X; i++) {
560    scalar_power2_round(&t1->v[i], &t0->v[i], &t->v[i]);
561  }
562}
563
564template <int X>
565void vector_scale_power2_round(vector<X> *out, const vector<X> *in) {
566  for (int i = 0; i < X; i++) {
567    scalar_scale_power2_round(&out->v[i], &in->v[i]);
568  }
569}
570
571template <int X>
572void vector_high_bits(vector<X> *out, const vector<X> *in) {
573  for (int i = 0; i < X; i++) {
574    scalar_high_bits(&out->v[i], &in->v[i]);
575  }
576}
577
578template <int X>
579void vector_low_bits(vector<X> *out, const vector<X> *in) {
580  for (int i = 0; i < X; i++) {
581    scalar_low_bits(&out->v[i], &in->v[i]);
582  }
583}
584
585template <int X>
586uint32_t vector_max(const vector<X> *a) {
587  uint32_t max = 0;
588  for (int i = 0; i < X; i++) {
589    scalar_max(&max, &a->v[i]);
590  }
591  return max;
592}
593
594template <int X>
595uint32_t vector_max_signed(const vector<X> *a) {
596  uint32_t max = 0;
597  for (int i = 0; i < X; i++) {
598    scalar_max_signed(&max, &a->v[i]);
599  }
600  return max;
601}
602
603// The input vector contains only zeroes and ones.
604template <int X>
605size_t vector_count_ones(const vector<X> *a) {
606  size_t count = 0;
607  for (int i = 0; i < X; i++) {
608    for (int j = 0; j < kDegree; j++) {
609      count += a->v[i].c[j];
610    }
611  }
612  return count;
613}
614
615template <int X>
616void vector_make_hint(vector<X> *out, const vector<X> *ct0,
617                      const vector<X> *cs2, const vector<X> *w) {
618  for (int i = 0; i < X; i++) {
619    scalar_make_hint(&out->v[i], &ct0->v[i], &cs2->v[i], &w->v[i]);
620  }
621}
622
623template <int X>
624void vector_use_hint_vartime(vector<X> *out, const vector<X> *h,
625                             const vector<X> *r) {
626  for (int i = 0; i < X; i++) {
627    scalar_use_hint_vartime(&out->v[i], &h->v[i], &r->v[i]);
628  }
629}
630
631/* Bit packing */
632
633// FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 4.
634static void scalar_encode_4(uint8_t out[128], const scalar *s) {
635  // Every two elements lands on a byte boundary.
636  static_assert(kDegree % 2 == 0, "kDegree must be a multiple of 2");
637  for (int i = 0; i < kDegree / 2; i++) {
638    uint32_t a = s->c[2 * i];
639    uint32_t b = s->c[2 * i + 1];
640    declassify_assert(a < 16);
641    declassify_assert(b < 16);
642    out[i] = a | (b << 4);
643  }
644}
645
646// FIPS 204, Algorithm 16 (`SimpleBitPack`). Specialized to bitlen(b) = 10.
647void scalar_encode_10(uint8_t out[320], const scalar *s) {
648  // Every four elements lands on a byte boundary.
649  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
650  for (int i = 0; i < kDegree / 4; i++) {
651    uint32_t a = s->c[4 * i];
652    uint32_t b = s->c[4 * i + 1];
653    uint32_t c = s->c[4 * i + 2];
654    uint32_t d = s->c[4 * i + 3];
655    declassify_assert(a < 1024);
656    declassify_assert(b < 1024);
657    declassify_assert(c < 1024);
658    declassify_assert(d < 1024);
659    out[5 * i] = (uint8_t)a;
660    out[5 * i + 1] = (uint8_t)((a >> 8) | (b << 2));
661    out[5 * i + 2] = (uint8_t)((b >> 6) | (c << 4));
662    out[5 * i + 3] = (uint8_t)((c >> 4) | (d << 6));
663    out[5 * i + 4] = (uint8_t)(d >> 2);
664  }
665}
666
667// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 4 and b = 4.
668void scalar_encode_signed_4_4(uint8_t out[128], const scalar *s) {
669  // Every two elements lands on a byte boundary.
670  static_assert(kDegree % 2 == 0, "kDegree must be a multiple of 2");
671  for (int i = 0; i < kDegree / 2; i++) {
672    uint32_t a = mod_sub(4, s->c[2 * i]);
673    uint32_t b = mod_sub(4, s->c[2 * i + 1]);
674    declassify_assert(a < 16);
675    declassify_assert(b < 16);
676    out[i] = a | (b << 4);
677  }
678}
679
680// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 3 and b = 2.
681static void scalar_encode_signed_3_2(uint8_t out[96], const scalar *s) {
682  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
683  for (int i = 0; i < kDegree / 8; i++) {
684    uint32_t a = mod_sub(2, s->c[8 * i]);
685    uint32_t b = mod_sub(2, s->c[8 * i + 1]);
686    uint32_t c = mod_sub(2, s->c[8 * i + 2]);
687    uint32_t d = mod_sub(2, s->c[8 * i + 3]);
688    uint32_t e = mod_sub(2, s->c[8 * i + 4]);
689    uint32_t f = mod_sub(2, s->c[8 * i + 5]);
690    uint32_t g = mod_sub(2, s->c[8 * i + 6]);
691    uint32_t h = mod_sub(2, s->c[8 * i + 7]);
692    uint32_t v = (h << 21) | (g << 18) | (f << 15) | (e << 12) | (d << 9) |
693                 (c << 6) | (b << 3) | a;
694    uint8_t v_bytes[sizeof(v)];
695    CRYPTO_store_u32_le(v_bytes, v);
696    OPENSSL_memcpy(&out[i * 3], v_bytes, 3);
697  }
698}
699
700// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 13 and b =
701// 2^12.
702void scalar_encode_signed_13_12(uint8_t out[416], const scalar *s) {
703  static const uint32_t kMax = 1u << 12;
704  // Every two elements lands on a byte boundary.
705  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
706  for (int i = 0; i < kDegree / 8; i++) {
707    uint32_t a = mod_sub(kMax, s->c[8 * i]);
708    uint32_t b = mod_sub(kMax, s->c[8 * i + 1]);
709    uint32_t c = mod_sub(kMax, s->c[8 * i + 2]);
710    uint32_t d = mod_sub(kMax, s->c[8 * i + 3]);
711    uint32_t e = mod_sub(kMax, s->c[8 * i + 4]);
712    uint32_t f = mod_sub(kMax, s->c[8 * i + 5]);
713    uint32_t g = mod_sub(kMax, s->c[8 * i + 6]);
714    uint32_t h = mod_sub(kMax, s->c[8 * i + 7]);
715    declassify_assert(a < (1u << 13));
716    declassify_assert(b < (1u << 13));
717    declassify_assert(c < (1u << 13));
718    declassify_assert(d < (1u << 13));
719    declassify_assert(e < (1u << 13));
720    declassify_assert(f < (1u << 13));
721    declassify_assert(g < (1u << 13));
722    declassify_assert(h < (1u << 13));
723    a |= b << 13;
724    a |= c << 26;
725    c >>= 6;
726    c |= d << 7;
727    c |= e << 20;
728    e >>= 12;
729    e |= f << 1;
730    e |= g << 14;
731    e |= h << 27;
732    h >>= 5;
733    OPENSSL_memcpy(&out[13 * i], &a, sizeof(a));
734    OPENSSL_memcpy(&out[13 * i + 4], &c, sizeof(c));
735    OPENSSL_memcpy(&out[13 * i + 8], &e, sizeof(e));
736    OPENSSL_memcpy(&out[13 * i + 12], &h, 1);
737  }
738}
739
740// FIPS 204, Algorithm 17 (`BitPack`). Specialized to bitlen(a+b) = 20 and b =
741// 2^19.
742void scalar_encode_signed_20_19(uint8_t out[640], const scalar *s) {
743  static const uint32_t kMax = 1u << 19;
744  // Every two elements lands on a byte boundary.
745  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
746  for (int i = 0; i < kDegree / 4; i++) {
747    uint32_t a = mod_sub(kMax, s->c[4 * i]);
748    uint32_t b = mod_sub(kMax, s->c[4 * i + 1]);
749    uint32_t c = mod_sub(kMax, s->c[4 * i + 2]);
750    uint32_t d = mod_sub(kMax, s->c[4 * i + 3]);
751    declassify_assert(a < (1u << 20));
752    declassify_assert(b < (1u << 20));
753    declassify_assert(c < (1u << 20));
754    declassify_assert(d < (1u << 20));
755    a |= b << 20;
756    b >>= 12;
757    b |= c << 8;
758    b |= d << 28;
759    d >>= 4;
760    OPENSSL_memcpy(&out[10 * i], &a, sizeof(a));
761    OPENSSL_memcpy(&out[10 * i + 4], &b, sizeof(b));
762    OPENSSL_memcpy(&out[10 * i + 8], &d, 2);
763  }
764}
765
766// FIPS 204, Algorithm 17 (`BitPack`).
767void scalar_encode_signed(uint8_t *out, const scalar *s, int bits,
768                          uint32_t max) {
769  if (bits == 3) {
770    assert(max == 2);
771    scalar_encode_signed_3_2(out, s);
772  } else if (bits == 4) {
773    assert(max == 4);
774    scalar_encode_signed_4_4(out, s);
775  } else if (bits == 20) {
776    assert(max == 1u << 19);
777    scalar_encode_signed_20_19(out, s);
778  } else {
779    assert(bits == 13);
780    assert(max == 1u << 12);
781    scalar_encode_signed_13_12(out, s);
782  }
783}
784
785// FIPS 204, Algorithm 18 (`SimpleBitUnpack`). Specialized for bitlen(b) == 10.
786void scalar_decode_10(scalar *out, const uint8_t in[320]) {
787  uint32_t v;
788  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
789  for (int i = 0; i < kDegree / 4; i++) {
790    OPENSSL_memcpy(&v, &in[5 * i], sizeof(v));
791    out->c[4 * i] = v & 0x3ff;
792    out->c[4 * i + 1] = (v >> 10) & 0x3ff;
793    out->c[4 * i + 2] = (v >> 20) & 0x3ff;
794    out->c[4 * i + 3] = (v >> 30) | (((uint32_t)in[5 * i + 4]) << 2);
795  }
796}
797
798// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 4 and b =
799// 4.
800int scalar_decode_signed_4_4(scalar *out, const uint8_t in[128]) {
801  uint32_t v;
802  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
803  for (int i = 0; i < kDegree / 8; i++) {
804    OPENSSL_memcpy(&v, &in[4 * i], sizeof(v));
805    // None of the nibbles may be >= 9. So if the MSB of any nibble is set, none
806    // of the other bits may be set. First, select all the MSBs.
807    const uint32_t msbs = v & 0x88888888u;
808    // For each nibble where the MSB is set, form a mask of all the other bits.
809    const uint32_t mask = (msbs >> 1) | (msbs >> 2) | (msbs >> 3);
810    // A nibble is only out of range in the case of invalid input, in which case
811    // it is okay to leak the value.
812    if (constant_time_declassify_int((mask & v) != 0)) {
813      return 0;
814    }
815
816    out->c[i * 8] = mod_sub(4, v & 15);
817    out->c[i * 8 + 1] = mod_sub(4, (v >> 4) & 15);
818    out->c[i * 8 + 2] = mod_sub(4, (v >> 8) & 15);
819    out->c[i * 8 + 3] = mod_sub(4, (v >> 12) & 15);
820    out->c[i * 8 + 4] = mod_sub(4, (v >> 16) & 15);
821    out->c[i * 8 + 5] = mod_sub(4, (v >> 20) & 15);
822    out->c[i * 8 + 6] = mod_sub(4, (v >> 24) & 15);
823    out->c[i * 8 + 7] = mod_sub(4, v >> 28);
824  }
825  return 1;
826}
827
828// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 3 and b =
829// 2.
830static int scalar_decode_signed_3_2(scalar *out, const uint8_t in[96]) {
831  uint32_t v;
832  uint8_t v_bytes[sizeof(v)] = {0};
833  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
834  for (int i = 0; i < kDegree / 8; i++) {
835    OPENSSL_memcpy(v_bytes, &in[3 * i], 3);
836    v = CRYPTO_load_u32_le(v_bytes);
837    // v contains 8, 3-bit values in the lower 24 bits. None of the values may
838    // be >= 5. So if the MSB of any triple is set, none of the other bits may
839    // be set. First, select all the MSBs.
840    const uint32_t msbs = v & 000044444444u;
841    // For each triple where the MSB is set, form a mask of all the other bits.
842    const uint32_t mask = (msbs >> 1) | (msbs >> 2);
843    // A triple is only out of range in the case of invalid input, in which case
844    // it is okay to leak the value.
845    if (constant_time_declassify_int((mask & v) != 0)) {
846      return 0;
847    }
848
849    out->c[i * 8 + 0] = mod_sub(2, (v >> 0) & 7);
850    out->c[i * 8 + 1] = mod_sub(2, (v >> 3) & 7);
851    out->c[i * 8 + 2] = mod_sub(2, (v >> 6) & 7);
852    out->c[i * 8 + 3] = mod_sub(2, (v >> 9) & 7);
853    out->c[i * 8 + 4] = mod_sub(2, (v >> 12) & 7);
854    out->c[i * 8 + 5] = mod_sub(2, (v >> 15) & 7);
855    out->c[i * 8 + 6] = mod_sub(2, (v >> 18) & 7);
856    out->c[i * 8 + 7] = mod_sub(2, v >> 21);
857  }
858  return 1;
859}
860
861// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 13 and b =
862// 2^12.
863void scalar_decode_signed_13_12(scalar *out, const uint8_t in[416]) {
864  static const uint32_t kMax = 1u << 12;
865  static const uint32_t k13Bits = (1u << 13) - 1;
866  static const uint32_t k7Bits = (1u << 7) - 1;
867
868  uint32_t a, b, c;
869  uint8_t d;
870  static_assert(kDegree % 8 == 0, "kDegree must be a multiple of 8");
871  for (int i = 0; i < kDegree / 8; i++) {
872    OPENSSL_memcpy(&a, &in[13 * i], sizeof(a));
873    OPENSSL_memcpy(&b, &in[13 * i + 4], sizeof(b));
874    OPENSSL_memcpy(&c, &in[13 * i + 8], sizeof(c));
875    d = in[13 * i + 12];
876
877    // It's not possible for a 13-bit number to be out of range when the max is
878    // 2^12.
879    out->c[i * 8] = mod_sub(kMax, a & k13Bits);
880    out->c[i * 8 + 1] = mod_sub(kMax, (a >> 13) & k13Bits);
881    out->c[i * 8 + 2] = mod_sub(kMax, (a >> 26) | ((b & k7Bits) << 6));
882    out->c[i * 8 + 3] = mod_sub(kMax, (b >> 7) & k13Bits);
883    out->c[i * 8 + 4] = mod_sub(kMax, (b >> 20) | ((c & 1) << 12));
884    out->c[i * 8 + 5] = mod_sub(kMax, (c >> 1) & k13Bits);
885    out->c[i * 8 + 6] = mod_sub(kMax, (c >> 14) & k13Bits);
886    out->c[i * 8 + 7] = mod_sub(kMax, (c >> 27) | ((uint32_t)d) << 5);
887  }
888}
889
890// FIPS 204, Algorithm 19 (`BitUnpack`). Specialized to bitlen(a+b) = 20 and b =
891// 2^19.
892void scalar_decode_signed_20_19(scalar *out, const uint8_t in[640]) {
893  static const uint32_t kMax = 1u << 19;
894  static const uint32_t k20Bits = (1u << 20) - 1;
895
896  uint32_t a, b;
897  uint16_t c;
898  static_assert(kDegree % 4 == 0, "kDegree must be a multiple of 4");
899  for (int i = 0; i < kDegree / 4; i++) {
900    OPENSSL_memcpy(&a, &in[10 * i], sizeof(a));
901    OPENSSL_memcpy(&b, &in[10 * i + 4], sizeof(b));
902    OPENSSL_memcpy(&c, &in[10 * i + 8], sizeof(c));
903
904    // It's not possible for a 20-bit number to be out of range when the max is
905    // 2^19.
906    out->c[i * 4] = mod_sub(kMax, a & k20Bits);
907    out->c[i * 4 + 1] = mod_sub(kMax, (a >> 20) | ((b & 0xff) << 12));
908    out->c[i * 4 + 2] = mod_sub(kMax, (b >> 8) & k20Bits);
909    out->c[i * 4 + 3] = mod_sub(kMax, (b >> 28) | ((uint32_t)c) << 4);
910  }
911}
912
913// FIPS 204, Algorithm 19 (`BitUnpack`).
914int scalar_decode_signed(scalar *out, const uint8_t *in, int bits,
915                         uint32_t max) {
916  if (bits == 3) {
917    assert(max == 2);
918    return scalar_decode_signed_3_2(out, in);
919  } else if (bits == 4) {
920    assert(max == 4);
921    return scalar_decode_signed_4_4(out, in);
922  } else if (bits == 13) {
923    assert(max == (1u << 12));
924    scalar_decode_signed_13_12(out, in);
925    return 1;
926  } else if (bits == 20) {
927    assert(max == (1u << 19));
928    scalar_decode_signed_20_19(out, in);
929    return 1;
930  } else {
931    abort();
932  }
933}
934
935/* Expansion functions */
936
937// FIPS 204, Algorithm 30 (`RejNTTPoly`).
938//
939// Rejection samples a Keccak stream to get uniformly distributed elements. This
940// is used for matrix expansion and only operates on public inputs.
941void scalar_from_keccak_vartime(scalar *out,
942                                const uint8_t derived_seed[kRhoBytes + 2]) {
943  struct BORINGSSL_keccak_st keccak_ctx;
944  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
945  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, kRhoBytes + 2);
946  assert(keccak_ctx.squeeze_offset == 0);
947  assert(keccak_ctx.rate_bytes == 168);
948  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");
949
950  int done = 0;
951  while (done < kDegree) {
952    uint8_t block[168];
953    BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
954    for (size_t i = 0; i < sizeof(block) && done < kDegree; i += 3) {
955      // FIPS 204, Algorithm 14 (`CoeffFromThreeBytes`).
956      uint32_t value = (uint32_t)block[i] | ((uint32_t)block[i + 1] << 8) |
957                       (((uint32_t)block[i + 2] & 0x7f) << 16);
958      if (value < kPrime) {
959        out->c[done++] = value;
960      }
961    }
962  }
963}
964
965template <int ETA>
966static bool coefficient_from_nibble(uint32_t nibble, uint32_t *result);
967
968template <>
969bool coefficient_from_nibble<4>(uint32_t nibble, uint32_t *result) {
970  if (constant_time_declassify_int(nibble < 9)) {
971    *result = mod_sub(4, nibble);
972    return true;
973  }
974  return false;
975}
976
977template <>
978bool coefficient_from_nibble<2>(uint32_t nibble, uint32_t *result) {
979  if (constant_time_declassify_int(nibble < 15)) {
980    *result = mod_sub(2, nibble % 5);
981    return true;
982  }
983  return false;
984}
985
986// FIPS 204, Algorithm 31 (`RejBoundedPoly`).
987template <int ETA>
988void scalar_uniform(scalar *out, const uint8_t derived_seed[kSigmaBytes + 2]) {
989  struct BORINGSSL_keccak_st keccak_ctx;
990  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
991  BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, kSigmaBytes + 2);
992  assert(keccak_ctx.squeeze_offset == 0);
993  assert(keccak_ctx.rate_bytes == 136);
994
995  int done = 0;
996  while (done < kDegree) {
997    uint8_t block[136];
998    BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
999    for (size_t i = 0; i < sizeof(block) && done < kDegree; ++i) {
1000      uint32_t t0 = block[i] & 0x0F;
1001      uint32_t t1 = block[i] >> 4;
1002      // FIPS 204, Algorithm 15 (`CoefFromHalfByte`). Although both the input
1003      // and output here are secret, it is OK to leak when we rejected a byte.
1004      // Individual bytes of the SHAKE-256 stream are (indistiguishable from)
1005      // independent of each other and the original seed, so leaking information
1006      // about the rejected bytes does not reveal the input or output.
1007      uint32_t v;
1008      if (coefficient_from_nibble<ETA>(t0, &v)) {
1009        out->c[done++] = v;
1010      }
1011      if (done < kDegree && coefficient_from_nibble<ETA>(t1, &v)) {
1012        out->c[done++] = v;
1013      }
1014    }
1015  }
1016}
1017
1018// FIPS 204, Algorithm 34 (`ExpandMask`), but just a single step.
1019void scalar_sample_mask(scalar *out,
1020                        const uint8_t derived_seed[kRhoPrimeBytes + 2]) {
1021  uint8_t buf[640];
1022  BORINGSSL_keccak(buf, sizeof(buf), derived_seed, kRhoPrimeBytes + 2,
1023                   boringssl_shake256);
1024
1025  scalar_decode_signed_20_19(out, buf);
1026}
1027
// FIPS 204, Algorithm 29 (`SampleInBall`).
//
// Writes to |out| a polynomial with |tau| coefficients set to +1 or -1 (mod q)
// and all others zero, derived from SHAKE-256 of the |len|-byte |seed|. Runs in
// variable time: how many bytes are squeezed, and which entries of |out| are
// touched, depend on the SHAKE output past the sign bytes (declassified below).
void scalar_sample_in_ball_vartime(scalar *out, const uint8_t *seed, int len,
                                   int tau) {
  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, seed, len);
  assert(keccak_ctx.squeeze_offset == 0);
  assert(keccak_ctx.rate_bytes == 136);

  uint8_t block[136];
  BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));

  // The first 8 stream bytes hold the sign bits. One is consumed per nonzero
  // coefficient, least-significant bit first. Only 64 are available, so |tau|
  // must not exceed 64.
  uint64_t signs = CRYPTO_load_u64_le(block);
  int offset = 8;
  // SampleInBall implements a Fisher–Yates shuffle, which unavoidably leaks
  // where the zeros are by memory access pattern. Although this leak happens
  // before bad signatures are rejected, this is safe. See
  // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/8d8f01ac_70af3f21/
  CONSTTIME_DECLASSIFY(block + offset, sizeof(block) - offset);

  OPENSSL_memset(out, 0, sizeof(*out));
  for (size_t i = kDegree - tau; i < kDegree; i++) {
    size_t byte;
    // Rejection-sample a position in [0, i] from the byte stream, squeezing a
    // fresh block whenever the current one is used up.
    for (;;) {
      if (offset == 136) {
        BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
        // See above.
        CONSTTIME_DECLASSIFY(block, sizeof(block));
        offset = 0;
      }

      byte = block[offset++];
      if (byte <= i) {
        break;
      }
    }

    // Move the old value at |byte| up to |i|, then place the new +/-1 at
    // |byte|. A sign bit of 0 gives 1 - 0 = +1; a sign bit of 1 gives
    // 1 - 2 = -1 (mod q).
    out->c[i] = out->c[byte];
    out->c[byte] = mod_sub(1, 2 * (signs & 1));
    signs >>= 1;
  }
}
1070
1071// FIPS 204, Algorithm 32 (`ExpandA`).
1072template <int K, int L>
1073void matrix_expand(matrix<K, L> *out, const uint8_t rho[kRhoBytes]) {
1074  static_assert(K <= 0x100, "K must fit in 8 bits");
1075  static_assert(L <= 0x100, "L must fit in 8 bits");
1076
1077  uint8_t derived_seed[kRhoBytes + 2];
1078  OPENSSL_memcpy(derived_seed, rho, kRhoBytes);
1079  for (int i = 0; i < K; i++) {
1080    for (int j = 0; j < L; j++) {
1081      derived_seed[kRhoBytes + 1] = (uint8_t)i;
1082      derived_seed[kRhoBytes] = (uint8_t)j;
1083      scalar_from_keccak_vartime(&out->v[i][j], derived_seed);
1084    }
1085  }
1086}
1087
1088// FIPS 204, Algorithm 33 (`ExpandS`).
1089template <int K, int L>
1090void vector_expand_short(vector<L> *s1, vector<K> *s2,
1091                         const uint8_t sigma[kSigmaBytes]) {
1092  static_assert(K <= 0x100, "K must fit in 8 bits");
1093  static_assert(L <= 0x100, "L must fit in 8 bits");
1094  static_assert(K + L <= 0x100, "K+L must fit in 8 bits");
1095
1096  uint8_t derived_seed[kSigmaBytes + 2];
1097  OPENSSL_memcpy(derived_seed, sigma, kSigmaBytes);
1098  derived_seed[kSigmaBytes] = 0;
1099  derived_seed[kSigmaBytes + 1] = 0;
1100  for (int i = 0; i < L; i++) {
1101    scalar_uniform<eta<K>()>(&s1->v[i], derived_seed);
1102    ++derived_seed[kSigmaBytes];
1103  }
1104  for (int i = 0; i < K; i++) {
1105    scalar_uniform<eta<K>()>(&s2->v[i], derived_seed);
1106    ++derived_seed[kSigmaBytes];
1107  }
1108}
1109
1110// FIPS 204, Algorithm 34 (`ExpandMask`).
1111template <int L>
1112void vector_expand_mask(vector<L> *out, const uint8_t seed[kRhoPrimeBytes],
1113                        size_t kappa) {
1114  assert(kappa + L <= 0x10000);
1115
1116  uint8_t derived_seed[kRhoPrimeBytes + 2];
1117  OPENSSL_memcpy(derived_seed, seed, kRhoPrimeBytes);
1118  for (int i = 0; i < L; i++) {
1119    size_t index = kappa + i;
1120    derived_seed[kRhoPrimeBytes] = index & 0xFF;
1121    derived_seed[kRhoPrimeBytes + 1] = (index >> 8) & 0xFF;
1122    scalar_sample_mask(&out->v[i], derived_seed);
1123  }
1124}
1125
1126/* Encoding */
1127
1128// FIPS 204, Algorithm 16 (`SimpleBitPack`).
1129//
1130// Encodes an entire vector into 32*K*|bits| bytes. Note that since 256
1131// (kDegree) is divisible by 8, the individual vector entries will always fill a
1132// whole number of bytes, so we do not need to worry about bit packing here.
1133template <int K>
1134void vector_encode(uint8_t *out, const vector<K> *a, int bits) {
1135  if (bits == 4) {
1136    for (int i = 0; i < K; i++) {
1137      scalar_encode_4(out + i * bits * kDegree / 8, &a->v[i]);
1138    }
1139  } else {
1140    assert(bits == 10);
1141    for (int i = 0; i < K; i++) {
1142      scalar_encode_10(out + i * bits * kDegree / 8, &a->v[i]);
1143    }
1144  }
1145}
1146
1147// FIPS 204, Algorithm 18 (`SimpleBitUnpack`).
1148template <int K>
1149void vector_decode_10(vector<K> *out, const uint8_t *in) {
1150  for (int i = 0; i < K; i++) {
1151    scalar_decode_10(&out->v[i], in + i * 10 * kDegree / 8);
1152  }
1153}
1154
1155// FIPS 204, Algorithm 17 (`BitPack`).
1156//
1157// Encodes an entire vector into 32*L*|bits| bytes. Note that since 256
1158// (kDegree) is divisible by 8, the individual vector entries will always fill a
1159// whole number of bytes, so we do not need to worry about bit packing here.
1160template <int X>
1161void vector_encode_signed(uint8_t *out, const vector<X> *a, int bits,
1162                          uint32_t max) {
1163  for (int i = 0; i < X; i++) {
1164    scalar_encode_signed(out + i * bits * kDegree / 8, &a->v[i], bits, max);
1165  }
1166}
1167
1168template <int X>
1169int vector_decode_signed(vector<X> *out, const uint8_t *in, int bits,
1170                         uint32_t max) {
1171  for (int i = 0; i < X; i++) {
1172    if (!scalar_decode_signed(&out->v[i], in + i * bits * kDegree / 8, bits,
1173                              max)) {
1174      return 0;
1175    }
1176  }
1177  return 1;
1178}
1179
1180// FIPS 204, Algorithm 28 (`w1Encode`).
1181template <int K>
1182void w1_encode(uint8_t out[128 * K], const vector<K> *w1) {
1183  vector_encode(out, w1, 4);
1184}
1185
1186// FIPS 204, Algorithm 20 (`HintBitPack`).
1187template <int K>
1188void hint_bit_pack(uint8_t out[omega<K>() + K], const vector<K> *h) {
1189  OPENSSL_memset(out, 0, omega<K>() + K);
1190  int index = 0;
1191  for (int i = 0; i < K; i++) {
1192    for (int j = 0; j < kDegree; j++) {
1193      if (h->v[i].c[j]) {
1194        // h must have at most omega<K>() non-zero coefficients.
1195        BSSL_CHECK(index < omega<K>());
1196        out[index++] = j;
1197      }
1198    }
1199    out[omega<K>() + i] = index;
1200  }
1201}
1202
// FIPS 204, Algorithm 21 (`HintBitUnpack`).
//
// Decodes the omega<K>() + K byte hint encoding |in| into |h|.
// - The first omega<K>() bytes list indices of coefficients to set to 1.
// - Byte omega<K>() + i is the running count of indices used by polynomials
//   0..i.
// Returns 1 on success and 0 if the encoding is malformed. The checks below
// reject any encoding that hint_bit_pack would not emit: counts that decrease
// or exceed omega<K>(), indices that are not strictly increasing within a
// polynomial, and nonzero padding.
template <int K>
int hint_bit_unpack(vector<K> *h, const uint8_t in[omega<K>() + K]) {
  vector_zero(h);
  int index = 0;
  for (int i = 0; i < K; i++) {
    const int limit = in[omega<K>() + i];
    // Running counts must be non-decreasing and stay inside the index area.
    if (limit < index || limit > omega<K>()) {
      return 0;
    }

    // Indices within one polynomial must be strictly increasing, which also
    // rules out duplicates.
    int last = -1;
    while (index < limit) {
      int byte = in[index++];
      if (last >= 0 && byte <= last) {
        return 0;
      }
      last = byte;
      static_assert(kDegree == 256,
                    "kDegree must be 256 for this write to be in bounds");
      h->v[i].c[byte] = 1;
    }
  }
  // Unused index bytes must be zero.
  for (; index < omega<K>(); index++) {
    if (in[index] != 0) {
      return 0;
    }
  }
  return 1;
}
1233
// In-memory ML-DSA public key. public_key_from_external_65/_87 reinterpret the
// opaque BCM_mldsa*_public_key types as this struct. They static_assert that
// size and alignment match, so any field change must be mirrored there.
template <int K>
struct public_key {
  // Public seed from which matrix_expand derives the matrix A.
  uint8_t rho[kRhoBytes];
  // High-order part of t as produced by vector_power2_round; encoded at 10 bits
  // per coefficient.
  vector<K> t1;
  // Pre-cached value(s).
  // tr: SHAKE-256 of the encoded public key, absorbed when computing mu.
  uint8_t public_key_hash[kTrBytes];
};
1241
// In-memory ML-DSA private key. private_key_from_external_65/_87 reinterpret
// the opaque BCM_mldsa*_private_key types as this struct. They static_assert
// that size and alignment match, so any field change must be mirrored there.
template <int K, int L>
struct private_key {
  // Public seed for the matrix A (same value as in the public key).
  uint8_t rho[kRhoBytes];
  // Secret seed absorbed together with the randomizer and mu to derive rho'.
  uint8_t k[kKBytes];
  // tr: SHAKE-256 of the encoded public key.
  uint8_t public_key_hash[kTrBytes];
  // Short secret vectors sampled by vector_expand_short.
  vector<L> s1;
  vector<K> s2;
  // Low-order part of t as produced by vector_power2_round.
  vector<K> t0;
};
1251
// Decoded ML-DSA signature (FIPS 204 sigEncode/sigDecode fields).
template <int K, int L>
struct signature {
  // Commitment hash of mu and w1Encode(w1).
  uint8_t c_tilde[2 * lambda_bytes<K>()];
  // Response vector, encoded at 20 bits per coefficient.
  vector<L> z;
  // Hint vector; hint_bit_unpack sets its entries to 0 or 1.
  vector<K> h;
};
1258
1259// FIPS 204, Algorithm 22 (`pkEncode`).
1260template <int K>
1261int mldsa_marshal_public_key(CBB *out, const struct public_key<K> *pub) {
1262  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
1263    return 0;
1264  }
1265
1266  uint8_t *vectork_output;
1267  if (!CBB_add_space(out, &vectork_output, 320 * K)) {
1268    return 0;
1269  }
1270  vector_encode(vectork_output, &pub->t1, 10);
1271
1272  return 1;
1273}
1274
1275// FIPS 204, Algorithm 23 (`pkDecode`).
1276template <int K>
1277int mldsa_parse_public_key(struct public_key<K> *pub, CBS *in) {
1278  const CBS orig_in = *in;
1279
1280  if (!CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
1281    return 0;
1282  }
1283
1284  CBS t1_bytes;
1285  if (!CBS_get_bytes(in, &t1_bytes, 320 * K) || CBS_len(in) != 0) {
1286    return 0;
1287  }
1288  vector_decode_10(&pub->t1, CBS_data(&t1_bytes));
1289
1290  // Compute pre-cached values.
1291  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
1292                   CBS_data(&orig_in), CBS_len(&orig_in), boringssl_shake256);
1293
1294  return 1;
1295}
1296
1297// FIPS 204, Algorithm 24 (`skEncode`).
1298template <int K, int L>
1299int mldsa_marshal_private_key(CBB *out, const struct private_key<K, L> *priv) {
1300  if (!CBB_add_bytes(out, priv->rho, sizeof(priv->rho)) ||
1301      !CBB_add_bytes(out, priv->k, sizeof(priv->k)) ||
1302      !CBB_add_bytes(out, priv->public_key_hash,
1303                     sizeof(priv->public_key_hash))) {
1304    return 0;
1305  }
1306
1307  constexpr size_t scalar_bytes =
1308      (kDegree * plus_minus_eta_bitlen<K>() + 7) / 8;
1309  uint8_t *vectorl_output;
1310  if (!CBB_add_space(out, &vectorl_output, scalar_bytes * L)) {
1311    return 0;
1312  }
1313  vector_encode_signed(vectorl_output, &priv->s1, plus_minus_eta_bitlen<K>(),
1314                       eta<K>());
1315
1316  uint8_t *s2_output;
1317  if (!CBB_add_space(out, &s2_output, scalar_bytes * K)) {
1318    return 0;
1319  }
1320  vector_encode_signed(s2_output, &priv->s2, plus_minus_eta_bitlen<K>(),
1321                       eta<K>());
1322
1323  uint8_t *t0_output;
1324  if (!CBB_add_space(out, &t0_output, 416 * K)) {
1325    return 0;
1326  }
1327  vector_encode_signed(t0_output, &priv->t0, 13, 1 << 12);
1328
1329  return 1;
1330}
1331
1332// FIPS 204, Algorithm 25 (`skDecode`).
1333template <int K, int L>
1334int mldsa_parse_private_key(struct private_key<K, L> *priv, CBS *in) {
1335  CBS s1_bytes;
1336  CBS s2_bytes;
1337  CBS t0_bytes;
1338  constexpr size_t scalar_bytes =
1339      (kDegree * plus_minus_eta_bitlen<K>() + 7) / 8;
1340  if (!CBS_copy_bytes(in, priv->rho, sizeof(priv->rho)) ||
1341      !CBS_copy_bytes(in, priv->k, sizeof(priv->k)) ||
1342      !CBS_copy_bytes(in, priv->public_key_hash,
1343                      sizeof(priv->public_key_hash)) ||
1344      !CBS_get_bytes(in, &s1_bytes, scalar_bytes * L) ||
1345      !vector_decode_signed(&priv->s1, CBS_data(&s1_bytes),
1346                            plus_minus_eta_bitlen<K>(), eta<K>()) ||
1347      !CBS_get_bytes(in, &s2_bytes, scalar_bytes * K) ||
1348      !vector_decode_signed(&priv->s2, CBS_data(&s2_bytes),
1349                            plus_minus_eta_bitlen<K>(), eta<K>()) ||
1350      !CBS_get_bytes(in, &t0_bytes, 416 * K) ||
1351      // Note: Decoding 13 bits into (-2^12, 2^12] cannot fail.
1352      !vector_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
1353    return 0;
1354  }
1355
1356  return 1;
1357}
1358
1359// FIPS 204, Algorithm 26 (`sigEncode`).
1360template <int K, int L>
1361int mldsa_marshal_signature(CBB *out, const struct signature<K, L> *sign) {
1362  if (!CBB_add_bytes(out, sign->c_tilde, sizeof(sign->c_tilde))) {
1363    return 0;
1364  }
1365
1366  uint8_t *vectorl_output;
1367  if (!CBB_add_space(out, &vectorl_output, 640 * L)) {
1368    return 0;
1369  }
1370  vector_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
1371
1372  uint8_t *hint_output;
1373  if (!CBB_add_space(out, &hint_output, omega<K>() + K)) {
1374    return 0;
1375  }
1376  hint_bit_pack(hint_output, &sign->h);
1377
1378  return 1;
1379}
1380
1381// FIPS 204, Algorithm 27 (`sigDecode`).
1382template <int K, int L>
1383int mldsa_parse_signature(struct signature<K, L> *sign, CBS *in) {
1384  CBS z_bytes;
1385  CBS hint_bytes;
1386  if (!CBS_copy_bytes(in, sign->c_tilde, sizeof(sign->c_tilde)) ||
1387      !CBS_get_bytes(in, &z_bytes, 640 * L) ||
1388      // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
1389      !vector_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
1390      !CBS_get_bytes(in, &hint_bytes, omega<K>() + K) ||
1391      !hint_bit_unpack(&sign->h, CBS_data(&hint_bytes))) {
1392    return 0;
1393  };
1394
1395  return 1;
1396}
1397
// std::unique_ptr deleter for memory obtained from OPENSSL_malloc. It releases
// the pointer with OPENSSL_free. The call operator is const-qualified so the
// deleter can also be invoked through a const reference.
template <typename T>
struct DeleterFree {
  void operator()(T *ptr) const { OPENSSL_free(ptr); }
};
1402
// FIPS 204, Algorithm 6 (`ML-DSA.KeyGen_internal`). Returns 1 on success and 0
// on failure.
//
// Derives (rho, sigma, K) = SHAKE-256(entropy || K || L). It then fills |priv|
// and writes the pkEncode'd public key to |out_encoded_public_key|. Failure
// means the scratch allocation (or, in principle, public-key marshalling)
// failed. This variant does not call the FIPS self-test hook; see
// mldsa_generate_key_external_entropy.
template <int K, int L>
int mldsa_generate_key_external_entropy_no_self_test(
    uint8_t out_encoded_public_key[public_key_bytes<K>()],
    struct private_key<K, L> *priv,
    const uint8_t entropy[BCM_MLDSA_SEED_BYTES]) {
  // Intermediate values, allocated on the heap to allow use when there is a
  // limited amount of stack.
  struct values_st {
    struct public_key<K> pub;
    matrix<K, L> a_ntt;
    vector<L> s1_ntt;
    vector<K> t;
  };
  std::unique_ptr<values_st, DeleterFree<values_st>> values(
      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
  if (values == NULL) {
    return 0;
  }

  uint8_t augmented_entropy[BCM_MLDSA_SEED_BYTES + 2];
  OPENSSL_memcpy(augmented_entropy, entropy, BCM_MLDSA_SEED_BYTES);
  // The k and l parameters are appended to the seed.
  augmented_entropy[BCM_MLDSA_SEED_BYTES] = K;
  augmented_entropy[BCM_MLDSA_SEED_BYTES + 1] = L;
  // Split the SHAKE-256 output into rho (public), sigma and k (both secret).
  // NOTE(review): augmented_entropy and expanded_seed hold secret material on
  // the stack and are not cleansed before returning; consider OPENSSL_cleanse.
  uint8_t expanded_seed[kRhoBytes + kSigmaBytes + kKBytes];
  BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), augmented_entropy,
                   sizeof(augmented_entropy), boringssl_shake256);
  const uint8_t *const rho = expanded_seed;
  const uint8_t *const sigma = expanded_seed + kRhoBytes;
  const uint8_t *const k = expanded_seed + kRhoBytes + kSigmaBytes;
  // rho is public.
  CONSTTIME_DECLASSIFY(rho, kRhoBytes);
  OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
  OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
  OPENSSL_memcpy(priv->k, k, sizeof(priv->k));

  // Expand the matrix A (NTT form) from rho and the short secrets from sigma.
  matrix_expand(&values->a_ntt, rho);
  vector_expand_short(&priv->s1, &priv->s2, sigma);

  // t = NTT^-1(A_hat * NTT(s1)) + s2.
  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
  vector_ntt(&values->s1_ntt);

  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
  vector_inverse_ntt(&values->t);
  vector_add(&values->t, &values->t, &priv->s2);

  // Split t into t1 (kept in the public key) and t0 (kept in the private key).
  vector_power2_round(&values->pub.t1, &priv->t0, &values->t);
  // t1 is public.
  CONSTTIME_DECLASSIFY(&values->pub.t1, sizeof(values->pub.t1));

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, public_key_bytes<K>());
  if (!mldsa_marshal_public_key(&cbb, &values->pub)) {
    return 0;
  }
  assert(CBB_len(&cbb) == public_key_bytes<K>());

  // tr = SHAKE-256(encoded public key), cached in the private key for signing.
  BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
                   out_encoded_public_key, public_key_bytes<K>(),
                   boringssl_shake256);

  return 1;
}
1468
1469template <int K, int L>
1470int mldsa_generate_key_external_entropy(
1471    uint8_t out_encoded_public_key[public_key_bytes<K>()],
1472    struct private_key<K, L> *priv,
1473    const uint8_t entropy[BCM_MLDSA_SEED_BYTES]) {
1474  fips::ensure_keygen_self_test();
1475  return mldsa_generate_key_external_entropy_no_self_test(
1476      out_encoded_public_key, priv, entropy);
1477}
1478
1479template <int K, int L>
1480int mldsa_public_from_private(struct public_key<K> *pub,
1481                              const struct private_key<K, L> *priv) {
1482  // Intermediate values, allocated on the heap to allow use when there is a
1483  // limited amount of stack.
1484  struct values_st {
1485    matrix<K, L> a_ntt;
1486    vector<L> s1_ntt;
1487    vector<K> t;
1488    vector<K> t0;
1489  };
1490  std::unique_ptr<values_st, DeleterFree<values_st>> values(
1491      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
1492  if (values == NULL) {
1493    return 0;
1494  }
1495
1496  OPENSSL_memcpy(pub->rho, priv->rho, sizeof(pub->rho));
1497  OPENSSL_memcpy(pub->public_key_hash, priv->public_key_hash,
1498                 sizeof(pub->public_key_hash));
1499
1500  matrix_expand(&values->a_ntt, priv->rho);
1501
1502  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1503  vector_ntt(&values->s1_ntt);
1504
1505  matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
1506  vector_inverse_ntt(&values->t);
1507  vector_add(&values->t, &values->t, &priv->s2);
1508
1509  vector_power2_round(&pub->t1, &values->t0, &values->t);
1510  // t1 is part of the public key and thus is public.
1511  CONSTTIME_DECLASSIFY(&pub->t1, sizeof(pub->t1));
1512  return 1;
1513}
1514
// FIPS 204, Algorithm 7 (`ML-DSA.Sign_internal`). Returns 1 on success and 0
// on failure.
//
// Signs |msg| under |priv| and writes the sigEncode'd signature to
// |out_encoded_signature|. |context_prefix| and |context| are hashed between
// tr and the message, and |randomizer| is mixed into rho' together with K and
// mu. Failure means the scratch allocation (or, in principle, signature
// marshalling) failed. The rejection loop repeats until a candidate is
// accepted. This variant does not call the FIPS self-test hook; see
// mldsa_sign_internal.
template <int K, int L>
int mldsa_sign_internal_no_self_test(
    uint8_t out_encoded_signature[signature_bytes<K>()],
    const struct private_key<K, L> *priv, const uint8_t *msg, size_t msg_len,
    const uint8_t *context_prefix, size_t context_prefix_len,
    const uint8_t *context, size_t context_len,
    const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
  // mu = SHAKE-256(tr || context_prefix || context || msg).
  uint8_t mu[kMuBytes];
  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
                          sizeof(priv->public_key_hash));
  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);

  // rho' = SHAKE-256(K || randomizer || mu) seeds the mask vectors y.
  uint8_t rho_prime[kRhoPrimeBytes];
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
  BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
                          BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES);
  BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
  BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, kRhoPrimeBytes);

  // Intermediate values, allocated on the heap to allow use when there is a
  // limited amount of stack.
  struct values_st {
    struct signature<K, L> sign;
    vector<L> s1_ntt;
    vector<K> s2_ntt;
    vector<K> t0_ntt;
    matrix<K, L> a_ntt;
    vector<L> y;
    vector<K> w;
    vector<K> w1;
    vector<L> cs1;
    vector<K> cs2;
  };
  std::unique_ptr<values_st, DeleterFree<values_st>> values(
      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
  if (values == NULL) {
    return 0;
  }
  // NTT forms of s1, s2, t0 and the expanded matrix A are computed once and
  // reused by every iteration of the rejection loop below.
  OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
  vector_ntt(&values->s1_ntt);

  OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
  vector_ntt(&values->s2_ntt);

  OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
  vector_ntt(&values->t0_ntt);

  matrix_expand(&values->a_ntt, priv->rho);

  // kappa advances by L per attempt, and kappa + L must not exceed 2**16 (see
  // the assert in vector_expand_mask). That allows about 2**16/L attempts:
  // 13107 for L = 5, 9362 for L = 7. But the probability of needing even 1000
  // attempts is vanishingly small.
  for (size_t kappa = 0;; kappa += L) {
    vector_expand_mask(&values->y, rho_prime, kappa);

    // cs1 is not computed until after the matrix multiply, so its storage
    // temporarily holds NTT(y).
    vector<L> *y_ntt = &values->cs1;
    OPENSSL_memcpy(y_ntt, &values->y, sizeof(*y_ntt));
    vector_ntt(y_ntt);

    // w = NTT^-1(A_hat * NTT(y)), w1 = HighBits(w).
    matrix_mult(&values->w, &values->a_ntt, y_ntt);
    vector_inverse_ntt(&values->w);

    vector_high_bits(&values->w1, &values->w);
    uint8_t w1_encoded[128 * K];
    w1_encode(w1_encoded, &values->w1);

    // c_tilde = SHAKE-256(mu || w1Encode(w1)), 2 * lambda_bytes<K>() bytes.
    BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
    BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
    BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
    BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
                             2 * lambda_bytes<K>());

    // c = SampleInBall(c_tilde), converted to NTT form.
    scalar c_ntt;
    scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
                                  sizeof(values->sign.c_tilde), tau<K>());
    scalar_ntt(&c_ntt);

    // cs1 = c * s1 (overwriting NTT(y)) and cs2 = c * s2.
    vector_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
    vector_inverse_ntt(&values->cs1);
    vector_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
    vector_inverse_ntt(&values->cs2);

    // z = y + cs1.
    vector_add(&values->sign.z, &values->y, &values->cs1);

    // w1 has already been encoded, so its storage is reused for
    // r0 = LowBits(w - cs2).
    vector<K> *r0 = &values->w1;
    vector_sub(r0, &values->w, &values->cs2);
    vector_low_bits(r0, r0);

    // Leaking the fact that a signature was rejected is fine as the next
    // attempt at a signature will be (indistinguishable from) independent of
    // this one. Note, however, that we additionally leak which of the two
    // branches rejected the signature. Section 5.5 of
    // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
    // describes this leak as OK. Note we leak less than what is described by
    // the paper; we do not reveal which coefficient violated the bound, and
    // we hide which of the |z_max| or |r0_max| bound failed. See also
    // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
    uint32_t z_max = vector_max(&values->sign.z);
    uint32_t r0_max = vector_max_signed(r0);
    if (constant_time_declassify_w(
            constant_time_ge_w(z_max, gamma1<K>() - beta<K>()) |
            constant_time_ge_w(r0_max, kGamma2 - beta<K>()))) {
#if defined(BORINGSSL_FIPS_BREAK_TESTS)
      // In order to show that our self-tests trigger both restart cases in
      // this loop, printf-logging is added when built in break-test mode.
      printf("MLDSA signature restart case 1.\n");
#endif
      continue;
    }

    // ct0 = c * t0, again reusing the w1 storage.
    // NOTE(review): FIPS 204 computes h = MakeHint(-ct0, w - cs2 + ct0); that
    // presumably happens inside vector_make_hint given ct0, cs2 and w —
    // confirm.
    vector<K> *ct0 = &values->w1;
    vector_mult_scalar(ct0, &values->t0_ntt, &c_ntt);
    vector_inverse_ntt(ct0);
    vector_make_hint(&values->sign.h, ct0, &values->cs2, &values->w);

    // See above.
    uint32_t ct0_max = vector_max(ct0);
    size_t h_ones = vector_count_ones(&values->sign.h);
    if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
                                   constant_time_lt_w(omega<K>(), h_ones))) {
#if defined(BORINGSSL_FIPS_BREAK_TESTS)
      // In order to show that our self-tests trigger both restart cases in
      // this loop, printf-logging is added when built in break-test mode.
      printf("MLDSA signature restart case 2.\n");
#endif
      continue;
    }

    // Although computed with the private key, the signature is public.
    CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
    CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
    CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));

    CBB cbb;
    CBB_init_fixed(&cbb, out_encoded_signature, signature_bytes<K>());
    if (!mldsa_marshal_signature(&cbb, &values->sign)) {
      return 0;
    }

    BSSL_CHECK(CBB_len(&cbb) == signature_bytes<K>());
    return 1;
  }
}
1665
1666template <int K, int L>
1667int mldsa_sign_internal(
1668    uint8_t out_encoded_signature[signature_bytes<K>()],
1669    const struct private_key<K, L> *priv, const uint8_t *msg, size_t msg_len,
1670    const uint8_t *context_prefix, size_t context_prefix_len,
1671    const uint8_t *context, size_t context_len,
1672    const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
1673  fips::ensure_sign_self_test();
1674  return mldsa_sign_internal_no_self_test(
1675      out_encoded_signature, priv, msg, msg_len, context_prefix,
1676      context_prefix_len, context, context_len, randomizer);
1677}
1678
// FIPS 204, Algorithm 8 (`ML-DSA.Verify_internal`).
//
// Returns 1 if |encoded_signature| is a valid signature under |pub| of |msg|,
// with |context_prefix| and |context| hashed between tr and the message.
// Returns 0 otherwise, including on allocation failure or a malformed
// signature encoding. Verification handles only public data, which is why the
// _vartime helpers are acceptable here. This variant does not call the FIPS
// self-test hook; see mldsa_verify_internal.
template <int K, int L>
int mldsa_verify_internal_no_self_test(
    const struct public_key<K> *pub,
    const uint8_t encoded_signature[signature_bytes<K>()], const uint8_t *msg,
    size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
    const uint8_t *context, size_t context_len) {
  // Intermediate values, allocated on the heap to allow use when there is a
  // limited amount of stack.
  struct values_st {
    struct signature<K, L> sign;
    matrix<K, L> a_ntt;
    vector<L> z_ntt;
    vector<K> az_ntt;
    vector<K> ct1_ntt;
  };
  std::unique_ptr<values_st, DeleterFree<values_st>> values(
      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
  if (values == NULL) {
    return 0;
  }

  // Parsing rejects malformed hints, including any with more than omega<K>()
  // ones (see hint_bit_unpack), so no separate hint-count check is needed.
  CBS cbs;
  CBS_init(&cbs, encoded_signature, signature_bytes<K>());
  if (!mldsa_parse_signature(&values->sign, &cbs)) {
    return 0;
  }

  matrix_expand(&values->a_ntt, pub->rho);

  // mu = SHAKE-256(tr || context_prefix || context || msg), as when signing.
  uint8_t mu[kMuBytes];
  struct BORINGSSL_keccak_st keccak_ctx;
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
                          sizeof(pub->public_key_hash));
  BORINGSSL_keccak_absorb(&keccak_ctx, context_prefix, context_prefix_len);
  BORINGSSL_keccak_absorb(&keccak_ctx, context, context_len);
  BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
  BORINGSSL_keccak_squeeze(&keccak_ctx, mu, kMuBytes);

  scalar c_ntt;
  scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde,
                                sizeof(values->sign.c_tilde), tau<K>());
  scalar_ntt(&c_ntt);

  OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
  vector_ntt(&values->z_ntt);

  matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);

  // NOTE(review): vector_scale_power2_round presumably computes t1 * 2^d
  // (d = kDroppedBits) as in FIPS 204 — confirm.
  vector_scale_power2_round(&values->ct1_ntt, &pub->t1);
  vector_ntt(&values->ct1_ntt);

  vector_mult_scalar(&values->ct1_ntt, &values->ct1_ntt, &c_ntt);

  // w1 = UseHint(h, NTT^-1(A_hat * NTT(z) - c_hat * NTT(scaled t1))), computed
  // in place in the az_ntt storage.
  vector<K> *const w1 = &values->az_ntt;
  vector_sub(w1, &values->az_ntt, &values->ct1_ntt);
  vector_inverse_ntt(w1);

  vector_use_hint_vartime(w1, &values->sign.h, w1);
  uint8_t w1_encoded[128 * K];
  w1_encode(w1_encoded, w1);

  uint8_t c_tilde[2 * lambda_bytes<K>()];
  BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
  BORINGSSL_keccak_absorb(&keccak_ctx, mu, kMuBytes);
  BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
  BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * lambda_bytes<K>());

  // Accept iff ||z||_inf < gamma1 - beta and the recomputed c_tilde matches the
  // one in the signature. Both compared values are public, so this comparison
  // need not be constant-time.
  uint32_t z_max = vector_max(&values->sign.z);
  return z_max < static_cast<uint32_t>(gamma1<K>() - beta<K>()) &&
         OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * lambda_bytes<K>()) ==
             0;
}
1747
1748  uint32_t z_max = vector_max(&values->sign.z);
1749  return z_max < static_cast<uint32_t>(gamma1<K>() - beta<K>()) &&
1750         OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * lambda_bytes<K>()) ==
1751             0;
1752}
1753template <int K, int L>
1754int mldsa_verify_internal(const struct public_key<K> *pub,
1755                          const uint8_t encoded_signature[signature_bytes<K>()],
1756                          const uint8_t *msg, size_t msg_len,
1757                          const uint8_t *context_prefix,
1758                          size_t context_prefix_len, const uint8_t *context,
1759                          size_t context_len) {
1760  fips::ensure_verify_self_test();
1761  return mldsa_verify_internal_no_self_test<K, L>(
1762      pub, encoded_signature, msg, msg_len, context_prefix, context_prefix_len,
1763      context, context_len);
1764}
1765
1766struct private_key<6, 5> *private_key_from_external_65(
1767    const struct BCM_mldsa65_private_key *external) {
1768  static_assert(sizeof(struct BCM_mldsa65_private_key) ==
1769                    sizeof(struct private_key<6, 5>),
1770                "MLDSA65 private key size incorrect");
1771  static_assert(alignof(struct BCM_mldsa65_private_key) ==
1772                    alignof(struct private_key<6, 5>),
1773                "MLDSA65 private key align incorrect");
1774  return (struct private_key<6, 5> *)external;
1775}
1776
1777struct public_key<6> *public_key_from_external_65(
1778    const struct BCM_mldsa65_public_key *external) {
1779  static_assert(
1780      sizeof(struct BCM_mldsa65_public_key) == sizeof(struct public_key<6>),
1781      "MLDSA65 public key size incorrect");
1782  static_assert(
1783      alignof(struct BCM_mldsa65_public_key) == alignof(struct public_key<6>),
1784      "MLDSA65 public key align incorrect");
1785  return (struct public_key<6> *)external;
1786}
1787
1788struct private_key<8, 7> *private_key_from_external_87(
1789    const struct BCM_mldsa87_private_key *external) {
1790  static_assert(sizeof(struct BCM_mldsa87_private_key) ==
1791                    sizeof(struct private_key<8, 7>),
1792                "MLDSA87 private key size incorrect");
1793  static_assert(alignof(struct BCM_mldsa87_private_key) ==
1794                    alignof(struct private_key<8, 7>),
1795                "MLDSA87 private key align incorrect");
1796  return (struct private_key<8, 7> *)external;
1797}
1798
1799struct public_key<8> *public_key_from_external_87(
1800    const struct BCM_mldsa87_public_key *external) {
1801  static_assert(
1802      sizeof(struct BCM_mldsa87_public_key) == sizeof(struct public_key<8>),
1803      "MLDSA87 public key size incorrect");
1804  static_assert(
1805      alignof(struct BCM_mldsa87_public_key) == alignof(struct public_key<8>),
1806      "MLDSA87 public key align incorrect");
1807  return (struct public_key<8> *)external;
1808}
1809
1810namespace fips {
1811
1812#include "fips_known_values.inc"
1813
1814static int keygen_self_test() {
1815  private_key<6, 5> priv;
1816  uint8_t pub_bytes[BCM_MLDSA65_PUBLIC_KEY_BYTES];
1817  if (!mldsa_generate_key_external_entropy_no_self_test(pub_bytes, &priv,
1818                                                        kGenerateKeyEntropy)) {
1819    return 0;
1820  }
1821
1822  uint8_t priv_bytes[BCM_MLDSA65_PRIVATE_KEY_BYTES];
1823  CBB cbb;
1824  CBB_init_fixed(&cbb, priv_bytes, sizeof(priv_bytes));
1825  if (!mldsa_marshal_private_key(&cbb, &priv)) {
1826    return 0;
1827  }
1828
1829  static_assert(sizeof(pub_bytes) == sizeof(kExpectedPublicKey));
1830  static_assert(sizeof(priv_bytes) == sizeof(kExpectedPrivateKey));
1831  if (!BORINGSSL_check_test(kExpectedPublicKey, pub_bytes, sizeof(pub_bytes),
1832                            "ML-DSA keygen public key") ||
1833      !BORINGSSL_check_test(kExpectedPrivateKey, priv_bytes, sizeof(priv_bytes),
1834                            "ML-DSA keygen private key")) {
1835    return 0;
1836  }
1837
1838  return 1;
1839}
1840
1841static int sign_self_test() {
1842  private_key<6, 5> priv;
1843  uint8_t pub_bytes[BCM_MLDSA65_PUBLIC_KEY_BYTES];
1844  if (!mldsa_generate_key_external_entropy(pub_bytes, &priv, kSignEntropy)) {
1845    return 0;
1846  }
1847
1848  const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {};
1849  uint8_t sig[BCM_MLDSA65_SIGNATURE_BYTES];
1850
1851  // This message triggers the first restart case for signing.
1852  uint8_t message[4] = {0};
1853  if (!mldsa_sign_internal_no_self_test(sig, &priv, message, sizeof(message),
1854                                        nullptr, 0, nullptr, 0, randomizer)) {
1855    return 0;
1856  }
1857  static_assert(sizeof(kExpectedCase1Signature) == sizeof(sig));
1858  if (!BORINGSSL_check_test(kExpectedCase1Signature, sig, sizeof(sig),
1859                            "ML-DSA sign case 1")) {
1860    return 0;
1861  }
1862
1863  // This message triggers the second restart case for signing.
1864  message[0] = 123;
1865  if (!mldsa_sign_internal_no_self_test(sig, &priv, message, sizeof(message),
1866                                        nullptr, 0, nullptr, 0, randomizer)) {
1867    return 0;
1868  }
1869  static_assert(sizeof(kExpectedCase2Signature) == sizeof(sig));
1870  if (!BORINGSSL_check_test(kExpectedCase2Signature, sig, sizeof(sig),
1871                            "ML-DSA sign case 2")) {
1872    return 0;
1873  }
1874
1875  return 1;
1876}
1877
1878static int verify_self_test() {
1879  struct values_st {
1880    private_key<6, 5> priv;
1881    public_key<6> pub;
1882    uint8_t pub_bytes[BCM_MLDSA65_PUBLIC_KEY_BYTES];
1883  };
1884  std::unique_ptr<values_st, DeleterFree<values_st>> values(
1885      reinterpret_cast<struct values_st *>(OPENSSL_malloc(sizeof(values_st))));
1886  if (!values) {
1887    return 0;
1888  }
1889
1890  if (!mldsa_generate_key_external_entropy(values->pub_bytes, &values->priv,
1891                                           kSignEntropy)) {
1892    return 0;
1893  }
1894
1895  const uint8_t message[4] = {1, 0};
1896  if (!mldsa_public_from_private(&values->pub, &values->priv) ||
1897      !mldsa_verify_internal_no_self_test<6, 5>(
1898          &values->pub, kExpectedVerifySignature, message, sizeof(message),
1899          nullptr, 0, nullptr, 0)) {
1900    return 0;
1901  }
1902
1903  return 1;
1904}
1905
1906template <int K, int L>
1907int check_key(private_key<K, L> *priv) {
1908  uint8_t sig[signature_bytes<K>()];
1909  uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES] = {};
1910  mldsa::public_key<K> pub;
1911  if (!mldsa_public_from_private(&pub, priv) ||
1912      !mldsa_sign_internal_no_self_test(sig, priv, nullptr, 0, nullptr, 0,
1913                                        nullptr, 0, randomizer)) {
1914    return 0;
1915  }
1916
1917  if (boringssl_fips_break_test("MLDSA_PWCT")) {
1918    sig[0] ^= 1;
1919  }
1920
1921  if (!mldsa_verify_internal_no_self_test<K, L>(&pub, sig, nullptr, 0, nullptr,
1922                                                0, nullptr, 0)) {
1923    return 0;
1924  }
1925  return 1;
1926}
1927
#if defined(BORINGSSL_FIPS)

// In FIPS builds, each |ensure_*_self_test| runs its known-answer test through
// |CRYPTO_once|, so the test is attempted only once. If the test fails, the
// callback calls |BORINGSSL_FIPS_abort|.

// |CRYPTO_once| callback: runs |keygen_self_test| and aborts on failure.
static void run_keygen_self_test_or_abort() {
  if (!keygen_self_test()) {
    BORINGSSL_FIPS_abort();
  }
}

// |CRYPTO_once| callback: runs |sign_self_test| and aborts on failure.
static void run_sign_self_test_or_abort() {
  if (!sign_self_test()) {
    BORINGSSL_FIPS_abort();
  }
}

// |CRYPTO_once| callback: runs |verify_self_test| and aborts on failure.
static void run_verify_self_test_or_abort() {
  if (!verify_self_test()) {
    BORINGSSL_FIPS_abort();
  }
}

DEFINE_STATIC_ONCE(g_mldsa_keygen_self_test_once)

void ensure_keygen_self_test(void) {
  CRYPTO_once(g_mldsa_keygen_self_test_once_bss_get(),
              run_keygen_self_test_or_abort);
}

DEFINE_STATIC_ONCE(g_mldsa_sign_self_test_once)

void ensure_sign_self_test(void) {
  CRYPTO_once(g_mldsa_sign_self_test_once_bss_get(),
              run_sign_self_test_or_abort);
}

DEFINE_STATIC_ONCE(g_mldsa_verify_self_test_once)

void ensure_verify_self_test(void) {
  CRYPTO_once(g_mldsa_verify_self_test_once_bss_get(),
              run_verify_self_test_or_abort);
}

#else

// Outside FIPS builds the self-test gates do nothing.
void ensure_keygen_self_test(void) {}
void ensure_sign_self_test(void) {}
void ensure_verify_self_test(void) {}

#endif
1967
1968}  // namespace fips
1969
1970}  // namespace
1971}  // namespace mldsa
1972
1973
1974// ML-DSA-65 specific wrappers.
1975
1976bcm_status BCM_mldsa65_parse_public_key(
1977    struct BCM_mldsa65_public_key *public_key, CBS *in) {
1978  return bcm_as_approved_status(mldsa_parse_public_key(
1979      mldsa::public_key_from_external_65(public_key), in));
1980}
1981
1982bcm_status BCM_mldsa65_marshal_private_key(
1983    CBB *out, const struct BCM_mldsa65_private_key *private_key) {
1984  return bcm_as_approved_status(mldsa_marshal_private_key(
1985      out, mldsa::private_key_from_external_65(private_key)));
1986}
1987
1988bcm_status BCM_mldsa65_parse_private_key(
1989    struct BCM_mldsa65_private_key *private_key, CBS *in) {
1990  return bcm_as_approved_status(
1991      mldsa_parse_private_key(mldsa::private_key_from_external_65(private_key),
1992                              in) &&
1993      CBS_len(in) == 0);
1994}
1995
1996bcm_status BCM_mldsa65_check_key_fips(
1997    struct BCM_mldsa65_private_key *private_key) {
1998  return bcm_as_approved_status(
1999      mldsa::fips::check_key(mldsa::private_key_from_external_65(private_key)));
2000}
2001
2002// Calls |MLDSA_generate_key_external_entropy| with random bytes from
2003// |BCM_rand_bytes|.
2004bcm_status BCM_mldsa65_generate_key(
2005    uint8_t out_encoded_public_key[BCM_MLDSA65_PUBLIC_KEY_BYTES],
2006    uint8_t out_seed[BCM_MLDSA_SEED_BYTES],
2007    struct BCM_mldsa65_private_key *out_private_key) {
2008  BCM_rand_bytes(out_seed, BCM_MLDSA_SEED_BYTES);
2009  CONSTTIME_SECRET(out_seed, BCM_MLDSA_SEED_BYTES);
2010  return BCM_mldsa65_generate_key_external_entropy(out_encoded_public_key,
2011                                                   out_private_key, out_seed);
2012}
2013
2014bcm_status BCM_mldsa65_private_key_from_seed(
2015    struct BCM_mldsa65_private_key *out_private_key,
2016    const uint8_t seed[BCM_MLDSA_SEED_BYTES]) {
2017  uint8_t public_key[BCM_MLDSA65_PUBLIC_KEY_BYTES];
2018  return BCM_mldsa65_generate_key_external_entropy(public_key, out_private_key,
2019                                                   seed);
2020}
2021
2022bcm_status BCM_mldsa65_generate_key_external_entropy(
2023    uint8_t out_encoded_public_key[BCM_MLDSA65_PUBLIC_KEY_BYTES],
2024    struct BCM_mldsa65_private_key *out_private_key,
2025    const uint8_t entropy[BCM_MLDSA_SEED_BYTES]) {
2026  return bcm_as_not_approved_status(mldsa_generate_key_external_entropy(
2027      out_encoded_public_key,
2028      mldsa::private_key_from_external_65(out_private_key), entropy));
2029}
2030
2031bcm_status BCM_mldsa65_generate_key_fips(
2032    uint8_t out_encoded_public_key[BCM_MLDSA65_PUBLIC_KEY_BYTES],
2033    uint8_t out_seed[BCM_MLDSA_SEED_BYTES],
2034    struct BCM_mldsa65_private_key *out_private_key) {
2035  if (BCM_mldsa65_generate_key(out_encoded_public_key, out_seed,
2036                               out_private_key) == bcm_status::failure) {
2037    return bcm_status::failure;
2038  }
2039  return BCM_mldsa65_check_key_fips(out_private_key);
2040}
2041
2042bcm_status BCM_mldsa65_generate_key_external_entropy_fips(
2043    uint8_t out_encoded_public_key[BCM_MLDSA65_PUBLIC_KEY_BYTES],
2044    struct BCM_mldsa65_private_key *out_private_key,
2045    const uint8_t entropy[BCM_MLDSA_SEED_BYTES]) {
2046  if (BCM_mldsa65_generate_key_external_entropy(out_encoded_public_key,
2047                                                out_private_key, entropy) ==
2048      bcm_status::failure) {
2049    return bcm_status::failure;
2050  }
2051  return BCM_mldsa65_check_key_fips(out_private_key);
2052}
2053
2054bcm_status BCM_mldsa65_private_key_from_seed_fips(
2055    struct BCM_mldsa65_private_key *out_private_key,
2056    const uint8_t seed[BCM_MLDSA_SEED_BYTES]) {
2057  uint8_t public_key[BCM_MLDSA65_PUBLIC_KEY_BYTES];
2058  if (BCM_mldsa65_generate_key_external_entropy(public_key, out_private_key,
2059                                                seed) == bcm_status::failure) {
2060    return bcm_status::failure;
2061  }
2062  return BCM_mldsa65_check_key_fips(out_private_key);
2063}
2064
2065bcm_status BCM_mldsa65_public_from_private(
2066    struct BCM_mldsa65_public_key *out_public_key,
2067    const struct BCM_mldsa65_private_key *private_key) {
2068  return bcm_as_approved_status(mldsa_public_from_private(
2069      mldsa::public_key_from_external_65(out_public_key),
2070      mldsa::private_key_from_external_65(private_key)));
2071}
2072
2073bcm_status BCM_mldsa65_sign_internal(
2074    uint8_t out_encoded_signature[BCM_MLDSA65_SIGNATURE_BYTES],
2075    const struct BCM_mldsa65_private_key *private_key, const uint8_t *msg,
2076    size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
2077    const uint8_t *context, size_t context_len,
2078    const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
2079  return bcm_as_approved_status(mldsa_sign_internal(
2080      out_encoded_signature, mldsa::private_key_from_external_65(private_key),
2081      msg, msg_len, context_prefix, context_prefix_len, context, context_len,
2082      randomizer));
2083}
2084
2085// ML-DSA signature in randomized mode, filling the random bytes with
2086// |BCM_rand_bytes|.
2087bcm_status BCM_mldsa65_sign(
2088    uint8_t out_encoded_signature[BCM_MLDSA65_SIGNATURE_BYTES],
2089    const struct BCM_mldsa65_private_key *private_key, const uint8_t *msg,
2090    size_t msg_len, const uint8_t *context, size_t context_len) {
2091  BSSL_CHECK(context_len <= 255);
2092  uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES];
2093  BCM_rand_bytes(randomizer, sizeof(randomizer));
2094  CONSTTIME_SECRET(randomizer, sizeof(randomizer));
2095
2096  const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context_len)};
2097  return BCM_mldsa65_sign_internal(
2098      out_encoded_signature, private_key, msg, msg_len, context_prefix,
2099      sizeof(context_prefix), context, context_len, randomizer);
2100}
2101
2102// FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
2103bcm_status BCM_mldsa65_verify(
2104    const struct BCM_mldsa65_public_key *public_key,
2105    const uint8_t signature[BCM_MLDSA65_SIGNATURE_BYTES], const uint8_t *msg,
2106    size_t msg_len, const uint8_t *context, size_t context_len) {
2107  BSSL_CHECK(context_len <= 255);
2108  const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context_len)};
2109  return BCM_mldsa65_verify_internal(public_key, signature, msg, msg_len,
2110                                     context_prefix, sizeof(context_prefix),
2111                                     context, context_len);
2112}
2113
2114bcm_status BCM_mldsa65_verify_internal(
2115    const struct BCM_mldsa65_public_key *public_key,
2116    const uint8_t encoded_signature[BCM_MLDSA65_SIGNATURE_BYTES],
2117    const uint8_t *msg, size_t msg_len, const uint8_t *context_prefix,
2118    size_t context_prefix_len, const uint8_t *context, size_t context_len) {
2119  return bcm_as_approved_status(mldsa::mldsa_verify_internal<6, 5>(
2120      mldsa::public_key_from_external_65(public_key), encoded_signature, msg,
2121      msg_len, context_prefix, context_prefix_len, context, context_len));
2122}
2123
2124bcm_status BCM_mldsa65_marshal_public_key(
2125    CBB *out, const struct BCM_mldsa65_public_key *public_key) {
2126  return bcm_as_approved_status(mldsa_marshal_public_key(
2127      out, mldsa::public_key_from_external_65(public_key)));
2128}
2129
2130
2131// ML-DSA-87 specific wrappers.
2132
2133bcm_status BCM_mldsa87_parse_public_key(
2134    struct BCM_mldsa87_public_key *public_key, CBS *in) {
2135  return bcm_as_approved_status(mldsa_parse_public_key(
2136      mldsa::public_key_from_external_87(public_key), in));
2137}
2138
2139bcm_status BCM_mldsa87_marshal_private_key(
2140    CBB *out, const struct BCM_mldsa87_private_key *private_key) {
2141  return bcm_as_approved_status(mldsa_marshal_private_key(
2142      out, mldsa::private_key_from_external_87(private_key)));
2143}
2144
2145bcm_status BCM_mldsa87_parse_private_key(
2146    struct BCM_mldsa87_private_key *private_key, CBS *in) {
2147  return bcm_as_approved_status(
2148      mldsa_parse_private_key(mldsa::private_key_from_external_87(private_key),
2149                              in) &&
2150      CBS_len(in) == 0);
2151}
2152
2153bcm_status BCM_mldsa87_check_key_fips(
2154    struct BCM_mldsa87_private_key *private_key) {
2155  return bcm_as_approved_status(
2156      mldsa::fips::check_key(mldsa::private_key_from_external_87(private_key)));
2157}
2158
2159// Calls |MLDSA_generate_key_external_entropy| with random bytes from
2160// |BCM_rand_bytes|.
2161bcm_status BCM_mldsa87_generate_key(
2162    uint8_t out_encoded_public_key[BCM_MLDSA87_PUBLIC_KEY_BYTES],
2163    uint8_t out_seed[BCM_MLDSA_SEED_BYTES],
2164    struct BCM_mldsa87_private_key *out_private_key) {
2165  BCM_rand_bytes(out_seed, BCM_MLDSA_SEED_BYTES);
2166  return BCM_mldsa87_generate_key_external_entropy(out_encoded_public_key,
2167                                                   out_private_key, out_seed);
2168}
2169
2170bcm_status BCM_mldsa87_private_key_from_seed(
2171    struct BCM_mldsa87_private_key *out_private_key,
2172    const uint8_t seed[BCM_MLDSA_SEED_BYTES]) {
2173  uint8_t public_key[BCM_MLDSA87_PUBLIC_KEY_BYTES];
2174  return BCM_mldsa87_generate_key_external_entropy(public_key, out_private_key,
2175                                                   seed);
2176}
2177
2178bcm_status BCM_mldsa87_generate_key_external_entropy(
2179    uint8_t out_encoded_public_key[BCM_MLDSA87_PUBLIC_KEY_BYTES],
2180    struct BCM_mldsa87_private_key *out_private_key,
2181    const uint8_t entropy[BCM_MLDSA_SEED_BYTES]) {
2182  return bcm_as_not_approved_status(mldsa_generate_key_external_entropy(
2183      out_encoded_public_key,
2184      mldsa::private_key_from_external_87(out_private_key), entropy));
2185}
2186
2187bcm_status BCM_mldsa87_generate_key_fips(
2188    uint8_t out_encoded_public_key[BCM_MLDSA87_PUBLIC_KEY_BYTES],
2189    uint8_t out_seed[BCM_MLDSA_SEED_BYTES],
2190    struct BCM_mldsa87_private_key *out_private_key) {
2191  if (BCM_mldsa87_generate_key(out_encoded_public_key, out_seed,
2192                               out_private_key) == bcm_status::failure) {
2193    return bcm_status::failure;
2194  }
2195  return BCM_mldsa87_check_key_fips(out_private_key);
2196}
2197
2198bcm_status BCM_mldsa87_generate_key_external_entropy_fips(
2199    uint8_t out_encoded_public_key[BCM_MLDSA87_PUBLIC_KEY_BYTES],
2200    struct BCM_mldsa87_private_key *out_private_key,
2201    const uint8_t entropy[BCM_MLDSA_SEED_BYTES]) {
2202  if (BCM_mldsa87_generate_key_external_entropy(out_encoded_public_key,
2203                                                out_private_key, entropy) ==
2204      bcm_status::failure) {
2205    return bcm_status::failure;
2206  }
2207  return BCM_mldsa87_check_key_fips(out_private_key);
2208}
2209
2210bcm_status BCM_mldsa87_private_key_from_seed_fips(
2211    struct BCM_mldsa87_private_key *out_private_key,
2212    const uint8_t seed[BCM_MLDSA_SEED_BYTES]) {
2213  uint8_t public_key[BCM_MLDSA87_PUBLIC_KEY_BYTES];
2214  if (BCM_mldsa87_generate_key_external_entropy(public_key, out_private_key,
2215                                                seed) == bcm_status::failure) {
2216    return bcm_status::failure;
2217  }
2218  return BCM_mldsa87_check_key_fips(out_private_key);
2219}
2220
2221bcm_status BCM_mldsa87_public_from_private(
2222    struct BCM_mldsa87_public_key *out_public_key,
2223    const struct BCM_mldsa87_private_key *private_key) {
2224  return bcm_as_approved_status(mldsa_public_from_private(
2225      mldsa::public_key_from_external_87(out_public_key),
2226      mldsa::private_key_from_external_87(private_key)));
2227}
2228
2229bcm_status BCM_mldsa87_sign_internal(
2230    uint8_t out_encoded_signature[BCM_MLDSA87_SIGNATURE_BYTES],
2231    const struct BCM_mldsa87_private_key *private_key, const uint8_t *msg,
2232    size_t msg_len, const uint8_t *context_prefix, size_t context_prefix_len,
2233    const uint8_t *context, size_t context_len,
2234    const uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES]) {
2235  return bcm_as_approved_status(mldsa_sign_internal(
2236      out_encoded_signature, mldsa::private_key_from_external_87(private_key),
2237      msg, msg_len, context_prefix, context_prefix_len, context, context_len,
2238      randomizer));
2239}
2240
2241// ML-DSA signature in randomized mode, filling the random bytes with
2242// |BCM_rand_bytes|.
2243bcm_status BCM_mldsa87_sign(
2244    uint8_t out_encoded_signature[BCM_MLDSA87_SIGNATURE_BYTES],
2245    const struct BCM_mldsa87_private_key *private_key, const uint8_t *msg,
2246    size_t msg_len, const uint8_t *context, size_t context_len) {
2247  BSSL_CHECK(context_len <= 255);
2248  uint8_t randomizer[BCM_MLDSA_SIGNATURE_RANDOMIZER_BYTES];
2249  BCM_rand_bytes(randomizer, sizeof(randomizer));
2250
2251  const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context_len)};
2252  return BCM_mldsa87_sign_internal(
2253      out_encoded_signature, private_key, msg, msg_len, context_prefix,
2254      sizeof(context_prefix), context, context_len, randomizer);
2255}
2256
2257// FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
2258bcm_status BCM_mldsa87_verify(const struct BCM_mldsa87_public_key *public_key,
2259                              const uint8_t *signature, const uint8_t *msg,
2260                              size_t msg_len, const uint8_t *context,
2261                              size_t context_len) {
2262  BSSL_CHECK(context_len <= 255);
2263  const uint8_t context_prefix[2] = {0, static_cast<uint8_t>(context_len)};
2264  return BCM_mldsa87_verify_internal(public_key, signature, msg, msg_len,
2265                                     context_prefix, sizeof(context_prefix),
2266                                     context, context_len);
2267}
2268
2269bcm_status BCM_mldsa87_verify_internal(
2270    const struct BCM_mldsa87_public_key *public_key,
2271    const uint8_t encoded_signature[BCM_MLDSA87_SIGNATURE_BYTES],
2272    const uint8_t *msg, size_t msg_len, const uint8_t *context_prefix,
2273    size_t context_prefix_len, const uint8_t *context, size_t context_len) {
2274  return bcm_as_approved_status(mldsa::mldsa_verify_internal<8, 7>(
2275      mldsa::public_key_from_external_87(public_key), encoded_signature, msg,
2276      msg_len, context_prefix, context_prefix_len, context, context_len));
2277}
2278
2279bcm_status BCM_mldsa87_marshal_public_key(
2280    CBB *out, const struct BCM_mldsa87_public_key *public_key) {
2281  return bcm_as_approved_status(mldsa_marshal_public_key(
2282      out, mldsa::public_key_from_external_87(public_key)));
2283}
2284
2285int boringssl_self_test_mldsa() {
2286  return mldsa::fips::keygen_self_test() && mldsa::fips::sign_self_test() &&
2287         mldsa::fips::verify_self_test();
2288}
2289