• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/* Copyright 2014 The BoringSSL Authors
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15#include <openssl/base.h>
16
17#include <assert.h>
18#include <stdint.h>
19#include <stdlib.h>
20#include <string.h>
21
22#include <openssl/base.h>
23#include <openssl/bytestring.h>
24#include <openssl/mem.h>
25#include <openssl/rand.h>
26
27#include "../../internal.h"
28#include "../bcm_interface.h"
29#include "../delocate.h"
30#include "../keccak/internal.h"
31
32
33namespace mlkem {
34namespace {
35
// Self-test entry points, defined elsewhere (not in this section). The key
// generation and decapsulation wrappers in this file call the matching
// ensure_* function before doing any work.
namespace fips {
void ensure_keygen_self_test();
void ensure_encap_self_test();
void ensure_decap_self_test();
}  // namespace fips
41
// The ML-KEM specification (FIPS 203) is at
// https://csrc.nist.gov/pubs/fips/203/final
44
45static void prf(uint8_t *out, size_t out_len, const uint8_t in[33]) {
46  BORINGSSL_keccak(out, out_len, in, 33, boringssl_shake256);
47}
48
49// Section 4.1
50void hash_h(uint8_t out[32], const uint8_t *in, size_t len) {
51  BORINGSSL_keccak(out, 32, in, len, boringssl_sha3_256);
52}
53
54void hash_g(uint8_t out[64], const uint8_t *in, size_t len) {
55  BORINGSSL_keccak(out, 64, in, len, boringssl_sha3_512);
56}
57
58// This is called `J` in the spec.
59void kdf(uint8_t out[BCM_MLKEM_SHARED_SECRET_BYTES],
60         const uint8_t failure_secret[32], const uint8_t *ciphertext,
61         size_t ciphertext_len) {
62  struct BORINGSSL_keccak_st st;
63  BORINGSSL_keccak_init(&st, boringssl_shake256);
64  BORINGSSL_keccak_absorb(&st, failure_secret, 32);
65  BORINGSSL_keccak_absorb(&st, ciphertext, ciphertext_len);
66  BORINGSSL_keccak_squeeze(&st, out, BCM_MLKEM_SHARED_SECRET_BYTES);
67}
68
// Parameters shared by every ML-KEM parameter set.
#define DEGREE 256  // Coefficients per polynomial (n in the spec).
// Barrett reduction by kPrime approximates x / kPrime as
// (x * kBarrettMultiplier) >> kBarrettShift, with the multiplier equal to
// floor(2^kBarrettShift / kPrime).
constexpr size_t kBarrettMultiplier = 5039;
constexpr unsigned kBarrettShift = 24;
constexpr uint16_t kPrime = 3329;  // q in the spec.
constexpr int kLog2Prime = 12;     // Bits needed for a value mod kPrime.
constexpr uint16_t kHalfPrime = (kPrime - 1) / 2;
// kInverseDegree is 128^-1 mod kPrime. It is 128 rather than 256 because
// kPrime has no 512th root of unity, so the NTT stops one layer short.
constexpr uint16_t kInverseDegree = 3303;

static_assert(kBarrettMultiplier == (size_t{1} << kBarrettShift) / kPrime,
              "Barrett multiplier must be floor(2^shift / kPrime)");
static_assert((128u * kInverseDegree) % kPrime == 1,
              "kInverseDegree must invert 128 mod kPrime");
static_assert((1 << (kLog2Prime - 1)) < kPrime && kPrime < (1 << kLog2Prime),
              "kLog2Prime must be the bit length of kPrime");
79
// Parameters that differ between ML-KEM-768 (rank 3) and ML-KEM-1024
// (rank 4). kDU* and kDV* are the bit widths (d_u and d_v in the spec) that
// the ciphertext's vector and scalar coefficients are compressed to.
#define RANK768 3
#define RANK1024 4
constexpr int kDU768 = 10;
constexpr int kDV768 = 4;
constexpr int kDU1024 = 11;
constexpr int kDV1024 = 5;
87
88constexpr size_t encoded_vector_size(int rank) {
89  return (kLog2Prime * DEGREE / 8) * static_cast<size_t>(rank);
90}
91
92constexpr size_t encoded_public_key_size(int rank) {
93  return encoded_vector_size(rank) + /*sizeof(rho)=*/32;
94}
95
96static_assert(encoded_public_key_size(RANK768) == BCM_MLKEM768_PUBLIC_KEY_BYTES,
97              "");
98static_assert(encoded_public_key_size(RANK1024) ==
99                  BCM_MLKEM1024_PUBLIC_KEY_BYTES,
100              "");
101
102constexpr size_t compressed_vector_size(int rank) {
103  // `if constexpr` isn't available in C++17.
104  return (rank == RANK768 ? kDU768 : kDU1024) * static_cast<size_t>(rank) *
105         DEGREE / 8;
106}
107
108constexpr size_t ciphertext_size(int rank) {
109  return compressed_vector_size(rank) +
110         (rank == RANK768 ? kDV768 : kDV1024) * DEGREE / 8;
111}
112
113static_assert(ciphertext_size(RANK768) == BCM_MLKEM768_CIPHERTEXT_BYTES, "");
114static_assert(ciphertext_size(RANK1024) == BCM_MLKEM1024_CIPHERTEXT_BYTES, "");
115
116typedef struct scalar {
117  // On every function entry and exit, 0 <= c < kPrime.
118  uint16_t c[DEGREE];
119} scalar;
120
121template <int RANK>
122struct vector {
123  scalar v[RANK];
124};
125
126template <int RANK>
127struct matrix {
128  scalar v[RANK][RANK];
129};
130
131// This bit of Python will be referenced in some of the following comments:
132//
133// p = 3329
134//
135// def bitreverse(i):
136//     ret = 0
137//     for n in range(7):
138//         bit = i & 1
139//         ret <<= 1
140//         ret |= bit
141//         i >>= 1
142//     return ret
143
// Twiddle factors for the forward NTT, using |bitreverse| from the Python
// snippet above:
//   kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,   // [0, 8)
    1062, 1919, 193,  797,  2786, 3260, 569,  1746,  // [8, 16)
    296,  2447, 1339, 1476, 3046, 56,   2240, 1333,  // [16, 24)
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,   // [24, 32)
    289,  331,  3253, 1756, 1197, 2304, 2277, 2055,  // [32, 40)
    650,  1977, 2513, 632,  2865, 33,   1320, 1915,  // [40, 48)
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402,  // [48, 56)
    2647, 2617, 1481, 648,  2474, 3110, 1227, 910,   // [56, 64)
    17,   2761, 583,  2649, 1637, 723,  2288, 1100,  // [64, 72)
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050,  // [72, 80)
    1703, 1651, 2789, 1789, 1847, 952,  1461, 2687,  // [80, 88)
    939,  2308, 2437, 2388, 733,  2337, 268,  641,   // [88, 96)
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645,  // [96, 104)
    1063, 319,  2773, 757,  2099, 561,  2466, 2594,  // [104, 112)
    2804, 1092, 403,  1026, 1143, 2150, 2775, 886,   // [112, 120)
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,  // [120, 128)
};
158
// Twiddle factors for the inverse NTT: entry i is the multiplicative inverse
// mod p of kNTTRoots[i].
//   kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
const uint16_t kInverseNTTRoots[128] = {
    1,    1600, 40,   749,  2481, 1432, 2699, 687,   // [0, 8)
    1583, 2760, 69,   543,  2532, 3136, 1410, 2267,  // [8, 16)
    2508, 1355, 450,  936,  447,  2794, 1235, 1903,  // [16, 24)
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033,  // [24, 32)
    2419, 2102, 219,  855,  2681, 1848, 712,  682,   // [32, 40)
    927,  1795, 461,  1891, 2877, 2522, 1894, 1010,  // [40, 48)
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679,  // [48, 56)
    1274, 1052, 1025, 2132, 1573, 76,   2998, 3040,  // [56, 64)
    1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,  // [64, 72)
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,   // [72, 80)
    735,  863,  2768, 1230, 2572, 556,  3010, 2266,  // [80, 88)
    1684, 1239, 780,  2954, 109,  1292, 1031, 1745,  // [88, 96)
    2688, 3061, 992,  2596, 941,  892,  1021, 2390,  // [96, 104)
    642,  1868, 2377, 1482, 1540, 540,  1678, 1626,  // [104, 112)
    279,  314,  1173, 2573, 3096, 48,   667,  1920,  // [112, 120)
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,  // [120, 128)
};
173
// Moduli for multiplication in the NTT domain: pair i of coefficients lives
// in GF(p)[X]/(X^2 - kModRoots[i]).
//   kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,   // [0, 8)
    1637, 1692, 723,  2606, 2288, 1041, 1100, 2229,  // [8, 16)
    1409, 1920, 2662, 667,  3281, 48,   233,  3096,  // [16, 24)
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,   // [24, 32)
    1703, 1626, 1651, 1678, 2789, 540,  1789, 1540,  // [32, 40)
    1847, 1482, 952,  2377, 1461, 1868, 2687, 642,   // [40, 48)
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,   // [48, 56)
    733,  2596, 2337, 992,  268,  3061, 641,  2688,  // [56, 64)
    1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,   // [64, 72)
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684,  // [72, 80)
    1063, 2266, 319,  3010, 2773, 556,  757,  2572,  // [80, 88)
    2099, 1230, 561,  2768, 2466, 863,  2594, 735,   // [88, 96)
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303,  // [96, 104)
    1143, 2186, 2150, 1179, 2775, 554,  886,  2443,  // [104, 112)
    1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,  // [112, 120)
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,  // [120, 128)
};
188
// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime, without branching.
//
// x - kPrime wraps around (setting bit 15) exactly when x < kPrime. |mask| is
// therefore all ones when x is already reduced and zero otherwise, and the
// final expression picks x or x - kPrime accordingly.
uint16_t reduce_once(uint16_t x) {
  declassify_assert(x < 2 * kPrime);
  const uint16_t subtracted = x - kPrime;
  uint16_t mask = 0u - (subtracted >> 15);
  // Although this is a constant-time select, we omit a value barrier here.
  // Value barriers impede auto-vectorization (likely because it forces the
  // value to transit through a general-purpose register). On AArch64, this is a
  // difference of 2x.
  //
  // We usually add value barriers to selects because Clang turns consecutive
  // selects with the same condition into a branch instead of CMOV/CSEL. This
  // condition does not occur in ML-KEM, so omitting it seems to be safe so far,
  // but see |scalar_centered_binomial_distribution_eta_2_with_prf|.
  return (mask & x) | (~mask & subtracted);
}
205
206// constant time reduce x mod kPrime using Barrett reduction. x must be less
207// than kPrime + 2×kPrime².
208static uint16_t reduce(uint32_t x) {
209  declassify_assert(x < kPrime + 2u * kPrime * kPrime);
210  uint64_t product = (uint64_t)x * kBarrettMultiplier;
211  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
212  uint32_t remainder = x - quotient * kPrime;
213  return reduce_once(remainder);
214}
215
216void scalar_zero(scalar *out) { OPENSSL_memset(out, 0, sizeof(*out)); }
217
218template <int RANK>
219void vector_zero(vector<RANK> *out) {
220  OPENSSL_memset(out->v, 0, sizeof(scalar) * RANK);
221}
222
223// In place number theoretic transform of a given scalar.
224// Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
225// transform leaves off the last iteration of the usual FFT code, with the 128
226// relevant roots of unity being stored in |kNTTRoots|. This means the output
227// should be seen as 128 elements in GF(3329^2), with the coefficients of the
228// elements being consecutive entries in |s->c|.
229static void scalar_ntt(scalar *s) {
230  int offset = DEGREE;
231  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
232  // with Clang 14 on Aarch64.
233  for (int step = 1; step < DEGREE / 2; step <<= 1) {
234    offset >>= 1;
235    int k = 0;
236    for (int i = 0; i < step; i++) {
237      const uint32_t step_root = kNTTRoots[i + step];
238      for (int j = k; j < k + offset; j++) {
239        uint16_t odd = reduce(step_root * s->c[j + offset]);
240        uint16_t even = s->c[j];
241        s->c[j] = reduce_once(odd + even);
242        s->c[j + offset] = reduce_once(even - odd + kPrime);
243      }
244      k += 2 * offset;
245    }
246  }
247}
248
249template <int RANK>
250static void vector_ntt(vector<RANK> *a) {
251  for (int i = 0; i < RANK; i++) {
252    scalar_ntt(&a->v[i]);
253  }
254}
255
256// In place inverse number theoretic transform of a given scalar, with pairs of
257// entries of s->v being interpreted as elements of GF(3329^2). Just as with the
258// number theoretic transform, this leaves off the first step of the normal iFFT
259// to account for the fact that 3329 does not have a 512th root of unity, using
260// the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
261void scalar_inverse_ntt(scalar *s) {
262  int step = DEGREE / 2;
263  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
264  // with Clang 14 on Aarch64.
265  for (int offset = 2; offset < DEGREE; offset <<= 1) {
266    step >>= 1;
267    int k = 0;
268    for (int i = 0; i < step; i++) {
269      uint32_t step_root = kInverseNTTRoots[i + step];
270      for (int j = k; j < k + offset; j++) {
271        uint16_t odd = s->c[j + offset];
272        uint16_t even = s->c[j];
273        s->c[j] = reduce_once(odd + even);
274        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
275      }
276      k += 2 * offset;
277    }
278  }
279  for (int i = 0; i < DEGREE; i++) {
280    s->c[i] = reduce(s->c[i] * kInverseDegree);
281  }
282}
283
284template <int RANK>
285void vector_inverse_ntt(vector<RANK> *a) {
286  for (int i = 0; i < RANK; i++) {
287    scalar_inverse_ntt(&a->v[i]);
288  }
289}
290
291void scalar_add(scalar *lhs, const scalar *rhs) {
292  for (int i = 0; i < DEGREE; i++) {
293    lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
294  }
295}
296
297void scalar_sub(scalar *lhs, const scalar *rhs) {
298  for (int i = 0; i < DEGREE; i++) {
299    lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
300  }
301}
302
303// Multiplying two scalars in the number theoretically transformed state. Since
304// 3329 does not have a 512th root of unity, this means we have to interpret
305// the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
306// - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
307// stored in the precomputed |kModRoots| table. Note that our Barrett transform
308// only allows us to multipy two reduced numbers together, so we need some
309// intermediate reduction steps, even if an uint64_t could hold 3 multiplied
310// numbers.
311void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
312  for (int i = 0; i < DEGREE / 2; i++) {
313    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
314    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
315    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
316    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
317    out->c[2 * i] =
318        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
319    out->c[2 * i + 1] = reduce(img_real + real_img);
320  }
321}
322
323template <int RANK>
324void vector_add(vector<RANK> *lhs, const vector<RANK> *rhs) {
325  for (int i = 0; i < RANK; i++) {
326    scalar_add(&lhs->v[i], &rhs->v[i]);
327  }
328}
329
330template <int RANK>
331static void matrix_mult(vector<RANK> *out, const matrix<RANK> *m,
332                        const vector<RANK> *a) {
333  vector_zero(out);
334  for (int i = 0; i < RANK; i++) {
335    for (int j = 0; j < RANK; j++) {
336      scalar product;
337      scalar_mult(&product, &m->v[i][j], &a->v[j]);
338      scalar_add(&out->v[i], &product);
339    }
340  }
341}
342
343template <int RANK>
344void matrix_mult_transpose(vector<RANK> *out, const matrix<RANK> *m,
345                           const vector<RANK> *a) {
346  vector_zero(out);
347  for (int i = 0; i < RANK; i++) {
348    for (int j = 0; j < RANK; j++) {
349      scalar product;
350      scalar_mult(&product, &m->v[j][i], &a->v[j]);
351      scalar_add(&out->v[i], &product);
352    }
353  }
354}
355
356template <int RANK>
357void scalar_inner_product(scalar *out, const vector<RANK> *lhs,
358                          const vector<RANK> *rhs) {
359  scalar_zero(out);
360  for (int i = 0; i < RANK; i++) {
361    scalar product;
362    scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
363    scalar_add(out, &product);
364  }
365}
366
367// Algorithm 6 from the spec. Rejection samples a Keccak stream to get
368// uniformly distributed elements. This is used for matrix expansion and only
369// operates on public inputs.
370static void scalar_from_keccak_vartime(scalar *out,
371                                       struct BORINGSSL_keccak_st *keccak_ctx) {
372  assert(keccak_ctx->squeeze_offset == 0);
373  assert(keccak_ctx->rate_bytes == 168);
374  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");
375
376  int done = 0;
377  while (done < DEGREE) {
378    uint8_t block[168];
379    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
380    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
381      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
382      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
383      if (d1 < kPrime) {
384        out->c[done++] = d1;
385      }
386      if (d2 < kPrime && done < DEGREE) {
387        out->c[done++] = d2;
388      }
389    }
390  }
391}
392
// SamplePolyCBD_eta from the spec (Algorithm 8 in final FIPS 203, Algorithm 7
// in the draft), with eta fixed to two and the PRF call included. Creates
// binomially distributed coefficients by sampling 2*|eta| bits per coefficient
// and setting the coefficient to the number of set bits among the first |eta|
// bits minus the number among the second |eta| bits. The result is a centered
// binomial distribution: with eta = 2 it gives -2/2 with probability 1/16
// each, -1/1 with probability 1/4 each, and 0 with probability 3/8.
//
// |input| is the 33-byte PRF input (seed || counter). Each entropy byte
// supplies two coefficients: the low nibble for |out->c[i]| and the high
// nibble for |out->c[i + 1]|. Negative values are mapped into [0, kPrime) by
// adding kPrime.
void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  prf(entropy, sizeof(entropy), input);

  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    uint16_t value = (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    // Add |kPrime| if |value| underflowed. See |reduce_once| for a discussion
    // on why the value barrier is omitted. While this could have been written
    // reduce_once(value + kPrime), this is one extra addition and small range
    // of |value| tempts some versions of Clang to emit a branch.
    uint16_t mask = 0u - (value >> 15);
    out->c[i] = ((value + kPrime) & mask) | (value & ~mask);

    // The high nibble supplies the next coefficient.
    byte >>= 4;
    value = (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    // See above.
    mask = 0u - (value >> 15);
    out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
  }
}
425
426// Generates a secret vector by using
427// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
428// appending and incrementing |counter| for entry of the vector.
429template <int RANK>
430void vector_generate_secret_eta_2(vector<RANK> *out, uint8_t *counter,
431                                  const uint8_t seed[32]) {
432  uint8_t input[33];
433  OPENSSL_memcpy(input, seed, 32);
434  for (int i = 0; i < RANK; i++) {
435    input[32] = (*counter)++;
436    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
437  }
438}
439
440// Expands the matrix of a seed for key generation and for encaps-CPA.
441template <int RANK>
442void matrix_expand(matrix<RANK> *out, const uint8_t rho[32]) {
443  uint8_t input[34];
444  OPENSSL_memcpy(input, rho, 32);
445  for (int i = 0; i < RANK; i++) {
446    for (int j = 0; j < RANK; j++) {
447      input[32] = i;
448      input[33] = j;
449      struct BORINGSSL_keccak_st keccak_ctx;
450      BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
451      BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
452      scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
453    }
454  }
455}
456
// kMasks[k] has the low k + 1 bits set, i.e. kMasks[k] == (1 << (k + 1)) - 1.
const uint8_t kMasks[8] = {1, 3, 7, 15, 31, 63, 127, 255};
458
459void scalar_encode(uint8_t *out, const scalar *s, int bits) {
460  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
461
462  uint8_t out_byte = 0;
463  int out_byte_bits = 0;
464
465  for (int i = 0; i < DEGREE; i++) {
466    uint16_t element = s->c[i];
467    int element_bits_done = 0;
468
469    while (element_bits_done < bits) {
470      int chunk_bits = bits - element_bits_done;
471      int out_bits_remaining = 8 - out_byte_bits;
472      if (chunk_bits >= out_bits_remaining) {
473        chunk_bits = out_bits_remaining;
474        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
475        *out = out_byte;
476        out++;
477        out_byte_bits = 0;
478        out_byte = 0;
479      } else {
480        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
481        out_byte_bits += chunk_bits;
482      }
483
484      element_bits_done += chunk_bits;
485      element >>= chunk_bits;
486    }
487  }
488
489  if (out_byte_bits > 0) {
490    *out = out_byte;
491  }
492}
493
494// scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
495void scalar_encode_1(uint8_t out[32], const scalar *s) {
496  for (int i = 0; i < DEGREE; i += 8) {
497    uint8_t out_byte = 0;
498    for (int j = 0; j < 8; j++) {
499      out_byte |= (s->c[i + j] & 1) << j;
500    }
501    *out = out_byte;
502    out++;
503  }
504}
505
506// Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
507// (DEGREE) is divisible by 8, the individual vector entries will always fill a
508// whole number of bytes, so we do not need to worry about bit packing here.
509template <int RANK>
510void vector_encode(uint8_t *out, const vector<RANK> *a, int bits) {
511  for (int i = 0; i < RANK; i++) {
512    scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
513  }
514}
515
516// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
517// |out|. It returns one on success and zero if any parsed value is >=
518// |kPrime|.
519int scalar_decode(scalar *out, const uint8_t *in, int bits) {
520  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
521
522  uint8_t in_byte = 0;
523  int in_byte_bits_left = 0;
524
525  for (int i = 0; i < DEGREE; i++) {
526    uint16_t element = 0;
527    int element_bits_done = 0;
528
529    while (element_bits_done < bits) {
530      if (in_byte_bits_left == 0) {
531        in_byte = *in;
532        in++;
533        in_byte_bits_left = 8;
534      }
535
536      int chunk_bits = bits - element_bits_done;
537      if (chunk_bits > in_byte_bits_left) {
538        chunk_bits = in_byte_bits_left;
539      }
540
541      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
542      in_byte_bits_left -= chunk_bits;
543      in_byte >>= chunk_bits;
544
545      element_bits_done += chunk_bits;
546    }
547
548    // An element is only out of range in the case of invalid input, in which
549    // case it is okay to leak the comparison.
550    if (constant_time_declassify_int(element >= kPrime)) {
551      return 0;
552    }
553    out->c[i] = element;
554  }
555
556  return 1;
557}
558
559// scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
560void scalar_decode_1(scalar *out, const uint8_t in[32]) {
561  for (int i = 0; i < DEGREE; i += 8) {
562    uint8_t in_byte = *in;
563    in++;
564    for (int j = 0; j < 8; j++) {
565      out->c[i + j] = in_byte & 1;
566      in_byte >>= 1;
567    }
568  }
569}
570
571// Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
572// success or zero if any parsed value is >= |kPrime|.
573template <int RANK>
574static int vector_decode(vector<RANK> *out, const uint8_t *in, int bits) {
575  for (int i = 0; i < RANK; i++) {
576    if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
577      return 0;
578    }
579  }
580  return 1;
581}
582
583// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
584// numbers close to each other together. The formula used is
585// round(2^|bits|/kPrime*x) mod 2^|bits|.
586// Uses Barrett reduction to achieve constant time. Since we need both the
587// remainder (for rounding) and the quotient (as the result), we cannot use
588// |reduce| here, but need to do the Barrett reduction directly.
589static uint16_t compress(uint16_t x, int bits) {
590  uint32_t shifted = (uint32_t)x << bits;
591  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
592  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
593  uint32_t remainder = shifted - quotient * kPrime;
594
595  // Adjust the quotient to round correctly:
596  //   0 <= remainder <= kHalfPrime round to 0
597  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
598  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
599  declassify_assert(remainder < 2u * kPrime);
600  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
601  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
602  return quotient & ((1 << bits) - 1);
603}
604
605// Decompresses |x| by using an equi-distant representative. The formula is
606// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
607// implement this logic using only bit operations.
608uint16_t decompress(uint16_t x, int bits) {
609  uint32_t product = (uint32_t)x * kPrime;
610  uint32_t power = 1 << bits;
611  // This is |product| % power, since |power| is a power of 2.
612  uint32_t remainder = product & (power - 1);
613  // This is |product| / power, since |power| is a power of 2.
614  uint32_t lower = product >> bits;
615  // The rounding logic works since the first half of numbers mod |power| have a
616  // 0 as first bit, and the second half has a 1 as first bit, since |power| is
617  // a power of 2. As a 12 bit number, |remainder| is always positive, so we
618  // will shift in 0s for a right shift.
619  return lower + (remainder >> (bits - 1));
620}
621
622static void scalar_compress(scalar *s, int bits) {
623  for (int i = 0; i < DEGREE; i++) {
624    s->c[i] = compress(s->c[i], bits);
625  }
626}
627
628static void scalar_decompress(scalar *s, int bits) {
629  for (int i = 0; i < DEGREE; i++) {
630    s->c[i] = decompress(s->c[i], bits);
631  }
632}
633
634template <int RANK>
635void vector_compress(vector<RANK> *a, int bits) {
636  for (int i = 0; i < RANK; i++) {
637    scalar_compress(&a->v[i], bits);
638  }
639}
640
641template <int RANK>
642void vector_decompress(vector<RANK> *a, int bits) {
643  for (int i = 0; i < RANK; i++) {
644    scalar_decompress(&a->v[i], bits);
645  }
646}
647
// Parsed form of an ML-KEM public key. The external BCM_mlkem*_public_key
// structs are static_assert-checked to be at least as large and as aligned as
// this type, so member changes affect that contract.
template <int RANK>
struct public_key {
  // t, stored in the NTT domain; serialized with ByteEncode_12.
  vector<RANK> t;
  // Seed from which |m| is expanded.
  uint8_t rho[32];
  // H (SHA3-256) of the encoded public key.
  uint8_t public_key_hash[32];
  // Matrix expanded from |rho|, cached so that it is not re-derived per use.
  matrix<RANK> m;
};

// Parsed form of an ML-KEM private key.
template <int RANK>
struct private_key {
  struct public_key<RANK> pub;
  // Secret vector s, stored in the NTT domain.
  vector<RANK> s;
  // z from the spec: the implicit-rejection secret passed to |kdf|.
  uint8_t fo_failure_secret[32];
};
662
663template <int RANK>
664static void decrypt_cpa(
665    uint8_t out[32], const struct private_key<RANK> *priv,
666    const uint8_t ciphertext[BCM_MLKEM768_CIPHERTEXT_BYTES]) {
667  constexpr int du = RANK == RANK768 ? kDU768 : kDU1024;
668  constexpr int dv = RANK == RANK768 ? kDV768 : kDV1024;
669
670  vector<RANK> u;
671  vector_decode(&u, ciphertext, du);
672  vector_decompress(&u, du);
673  vector_ntt(&u);
674  scalar v;
675  scalar_decode(&v, ciphertext + compressed_vector_size(RANK), dv);
676  scalar_decompress(&v, dv);
677  scalar mask;
678  scalar_inner_product(&mask, &priv->s, &u);
679  scalar_inverse_ntt(&mask);
680  scalar_sub(&v, &mask);
681  scalar_compress(&v, 1);
682  scalar_encode_1(out, &v);
683}
684
685template <int RANK>
686static bcm_status mlkem_marshal_public_key(CBB *out,
687                                           const struct public_key<RANK> *pub) {
688  uint8_t *vector_output;
689  if (!CBB_add_space(out, &vector_output, encoded_vector_size(RANK))) {
690    return bcm_status::failure;
691  }
692  vector_encode(vector_output, &pub->t, kLog2Prime);
693  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
694    return bcm_status::failure;
695  }
696  return bcm_status::approved;
697}
698
699template <int RANK>
700void mlkem_generate_key_external_seed_no_self_test(
701    uint8_t *out_encoded_public_key, private_key<RANK> *priv,
702    const uint8_t seed[BCM_MLKEM_SEED_BYTES]) {
703  uint8_t augmented_seed[33];
704  OPENSSL_memcpy(augmented_seed, seed, 32);
705  augmented_seed[32] = RANK;
706
707  uint8_t hashed[64];
708  hash_g(hashed, augmented_seed, sizeof(augmented_seed));
709  const uint8_t *const rho = hashed;
710  const uint8_t *const sigma = hashed + 32;
711  // rho is public.
712  CONSTTIME_DECLASSIFY(rho, 32);
713  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
714  matrix_expand(&priv->pub.m, rho);
715  uint8_t counter = 0;
716  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
717  vector_ntt(&priv->s);
718  vector<RANK> error;
719  vector_generate_secret_eta_2(&error, &counter, sigma);
720  vector_ntt(&error);
721  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
722  vector_add(&priv->pub.t, &error);
723  // t is part of the public key and thus is public.
724  CONSTTIME_DECLASSIFY(&priv->pub.t, sizeof(priv->pub.t));
725
726  CBB cbb;
727  CBB_init_fixed(&cbb, out_encoded_public_key, encoded_public_key_size(RANK));
728  if (!bcm_success(mlkem_marshal_public_key(&cbb, &priv->pub))) {
729    abort();
730  }
731
732  hash_h(priv->pub.public_key_hash, out_encoded_public_key,
733         encoded_public_key_size(RANK));
734  OPENSSL_memcpy(priv->fo_failure_secret, seed + 32, 32);
735}
736
737template <int RANK>
738void mlkem_generate_key_external_seed(
739    uint8_t *out_encoded_public_key, private_key<RANK> *priv,
740    const uint8_t seed[BCM_MLKEM_SEED_BYTES]) {
741  fips::ensure_keygen_self_test();
742  mlkem_generate_key_external_seed_no_self_test(out_encoded_public_key, priv,
743                                                seed);
744}
745
746// Encrypts a message with given randomness to
747// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
748// would not result in a CCA secure scheme, since lattice schemes are vulnerable
749// to decryption failure oracles.
750template <int RANK>
751void encrypt_cpa(uint8_t *out, const struct mlkem::public_key<RANK> *pub,
752                 const uint8_t message[32], const uint8_t randomness[32]) {
753  constexpr int du = RANK == RANK768 ? mlkem::kDU768 : mlkem::kDU1024;
754  constexpr int dv = RANK == RANK768 ? mlkem::kDV768 : mlkem::kDV1024;
755
756  uint8_t counter = 0;
757  mlkem::vector<RANK> secret;
758  vector_generate_secret_eta_2(&secret, &counter, randomness);
759  vector_ntt(&secret);
760  mlkem::vector<RANK> error;
761  vector_generate_secret_eta_2(&error, &counter, randomness);
762  uint8_t input[33];
763  OPENSSL_memcpy(input, randomness, 32);
764  input[32] = counter;
765  mlkem::scalar scalar_error;
766  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
767  mlkem::vector<RANK> u;
768  matrix_mult(&u, &pub->m, &secret);
769  vector_inverse_ntt(&u);
770  vector_add(&u, &error);
771  mlkem::scalar v;
772  scalar_inner_product(&v, &pub->t, &secret);
773  scalar_inverse_ntt(&v);
774  scalar_add(&v, &scalar_error);
775  mlkem::scalar expanded_message;
776  scalar_decode_1(&expanded_message, message);
777  scalar_decompress(&expanded_message, 1);
778  scalar_add(&v, &expanded_message);
779  vector_compress(&u, du);
780  vector_encode(out, &u, du);
781  scalar_compress(&v, dv);
782  scalar_encode(out + mlkem::compressed_vector_size(RANK), &v, dv);
783}
784
// Decapsulation (see section 6.3 of the spec). Decrypts |ciphertext|,
// re-encrypts the recovered message, and compares the result with
// |ciphertext| in constant time (the Fujisaki-Okamoto transform). On a match
// the shared secret is K from G(m' || H(ek)); otherwise it is the
// implicit-rejection key J(z || ciphertext). |ciphertext| must point to
// ciphertext_size(RANK) bytes. Does not run the self test.
template <int RANK>
void mlkem_decap_no_self_test(
    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
    const uint8_t *ciphertext, const struct private_key<RANK> *priv) {
  // decrypted = m' || H(ek), the input to G.
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  // key_and_randomness = K' || r'.
  uint8_t key_and_randomness[64];
  hash_g(key_and_randomness, decrypted, sizeof(decrypted));
  constexpr size_t ciphertext_len = ciphertext_size(RANK);
  // Sized for the largest parameter set; only |ciphertext_len| bytes are used.
  uint8_t expected_ciphertext[BCM_MLKEM1024_CIPHERTEXT_BYTES];
  static_assert(ciphertext_len <= sizeof(expected_ciphertext), "");
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              key_and_randomness + 32);

  // The rejection key is computed unconditionally, so the work done does not
  // depend on whether the ciphertext is valid.
  uint8_t failure_key[32];
  kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);

  // |mask| comes from constant_time_eq_int_8 on the comparison result and
  // selects, without branching, between K' and the rejection key.
  uint8_t mask = constant_time_eq_int_8(
      CRYPTO_memcmp(ciphertext, expected_ciphertext, ciphertext_len), 0);
  for (int i = 0; i < BCM_MLKEM_SHARED_SECRET_BYTES; i++) {
    out_shared_secret[i] =
        constant_time_select_8(mask, key_and_randomness[i], failure_key[i]);
  }
}
812
813template <int RANK>
814void mlkem_decap(uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
815                 const uint8_t *ciphertext,
816                 const struct private_key<RANK> *priv) {
817  fips::ensure_decap_self_test();
818  mlkem_decap_no_self_test(out_shared_secret, ciphertext, priv);
819}
820
821// mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
822// the value of |pub->public_key_hash|.
823template <int RANK>
824int mlkem_parse_public_key_no_hash(struct public_key<RANK> *pub, CBS *in) {
825  CBS t_bytes;
826  if (!CBS_get_bytes(in, &t_bytes, encoded_vector_size(RANK)) ||
827      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
828      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
829    return 0;
830  }
831  matrix_expand(&pub->m, pub->rho);
832  return 1;
833}
834
835template <int RANK>
836int mlkem_parse_public_key(struct public_key<RANK> *pub, CBS *in) {
837  CBS orig_in = *in;
838  if (!mlkem_parse_public_key_no_hash(pub, in) ||  //
839      CBS_len(in) != 0) {
840    return 0;
841  }
842  hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
843  return 1;
844}
845
846template <int RANK>
847int mlkem_parse_private_key(struct private_key<RANK> *priv, CBS *in) {
848  CBS s_bytes;
849  if (!CBS_get_bytes(in, &s_bytes, encoded_vector_size(RANK)) ||
850      !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
851      !mlkem_parse_public_key_no_hash(&priv->pub, in) ||
852      !CBS_copy_bytes(in, priv->pub.public_key_hash,
853                      sizeof(priv->pub.public_key_hash)) ||
854      !CBS_copy_bytes(in, priv->fo_failure_secret,
855                      sizeof(priv->fo_failure_secret)) ||
856      CBS_len(in) != 0) {
857    return 0;
858  }
859  return 1;
860}
861
862template <int RANK>
863int mlkem_marshal_private_key(CBB *out, const struct private_key<RANK> *priv) {
864  uint8_t *s_output;
865  if (!CBB_add_space(out, &s_output, encoded_vector_size(RANK))) {
866    return 0;
867  }
868  vector_encode(s_output, &priv->s, kLog2Prime);
869  if (!bcm_success(mlkem_marshal_public_key(out, &priv->pub)) ||
870      !CBB_add_bytes(out, priv->pub.public_key_hash,
871                     sizeof(priv->pub.public_key_hash)) ||
872      !CBB_add_bytes(out, priv->fo_failure_secret,
873                     sizeof(priv->fo_failure_secret))) {
874    return 0;
875  }
876  return 1;
877}
878
879struct public_key<RANK768> *public_key_768_from_external(
880    const struct BCM_mlkem768_public_key *external) {
881  static_assert(sizeof(struct BCM_mlkem768_public_key) >=
882                    sizeof(struct public_key<RANK768>),
883                "MLKEM public key is too small");
884  static_assert(alignof(struct BCM_mlkem768_public_key) >=
885                    alignof(struct public_key<RANK768>),
886                "MLKEM public key alignment incorrect");
887  return (struct public_key<RANK768> *)external;
888}
889
890static struct public_key<RANK1024> *public_key_1024_from_external(
891    const struct BCM_mlkem1024_public_key *external) {
892  static_assert(sizeof(struct BCM_mlkem1024_public_key) >=
893                    sizeof(struct public_key<RANK1024>),
894                "MLKEM1024 public key is too small");
895  static_assert(alignof(struct BCM_mlkem1024_public_key) >=
896                    alignof(struct public_key<RANK1024>),
897                "MLKEM1024 public key alignment incorrect");
898  return (struct public_key<RANK1024> *)external;
899}
900
901struct private_key<RANK768> *private_key_768_from_external(
902    const struct BCM_mlkem768_private_key *external) {
903  static_assert(sizeof(struct BCM_mlkem768_private_key) >=
904                    sizeof(struct private_key<RANK768>),
905                "MLKEM private key too small");
906  static_assert(alignof(struct BCM_mlkem768_private_key) >=
907                    alignof(struct private_key<RANK768>),
908                "MLKEM private key alignment incorrect");
909  return (struct private_key<RANK768> *)external;
910}
911
912struct private_key<RANK1024> *private_key_1024_from_external(
913    const struct BCM_mlkem1024_private_key *external) {
914  static_assert(sizeof(struct BCM_mlkem1024_private_key) >=
915                    sizeof(struct private_key<RANK1024>),
916                "MLKEM1024 private key too small");
917  static_assert(alignof(struct BCM_mlkem1024_private_key) >=
918                    alignof(struct private_key<RANK1024>),
919                "MLKEM1024 private key alignment incorrect");
920  return (struct private_key<RANK1024> *)external;
921}
922
923// See section 6.2.
924template <int RANK>
925void mlkem_encap_external_entropy_no_self_test(
926    uint8_t *out_ciphertext,
927    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
928    const struct mlkem::public_key<RANK> *pub,
929    const uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY]) {
930  uint8_t input[64];
931  OPENSSL_memcpy(input, entropy, BCM_MLKEM_ENCAP_ENTROPY);
932  OPENSSL_memcpy(input + BCM_MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
933                 sizeof(input) - BCM_MLKEM_ENCAP_ENTROPY);
934  uint8_t key_and_randomness[64];
935  mlkem::hash_g(key_and_randomness, input, sizeof(input));
936  encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
937  // The ciphertext is public.
938  CONSTTIME_DECLASSIFY(out_ciphertext, mlkem::ciphertext_size(RANK));
939  static_assert(BCM_MLKEM_SHARED_SECRET_BYTES == 32, "");
940  memcpy(out_shared_secret, key_and_randomness, 32);
941}
942
943template <int RANK>
944void mlkem_encap_external_entropy(
945    uint8_t *out_ciphertext,
946    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
947    const struct mlkem::public_key<RANK> *pub,
948    const uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY]) {
949  fips::ensure_encap_self_test();
950  mlkem_encap_external_entropy_no_self_test(out_ciphertext, out_shared_secret,
951                                            pub, entropy);
952}
953
954namespace fips {
955
956#include "fips_known_values.inc"
957
958static int keygen_self_test() {
959  uint8_t pub_key[BCM_MLKEM768_PUBLIC_KEY_BYTES];
960  private_key<RANK768> priv;
961  static_assert(sizeof(kTestEntropy) >= BCM_MLKEM_SEED_BYTES);
962  mlkem_generate_key_external_seed_no_self_test(pub_key, &priv, kTestEntropy);
963  CBB cbb;
964  constexpr size_t kMarshaledPrivateKeySize = 2400;
965  uint8_t priv_bytes[kMarshaledPrivateKeySize];
966  CBB_init_fixed(&cbb, priv_bytes, sizeof(priv_bytes));
967  static_assert(sizeof(kExpectedPrivateKeyBytes) == kMarshaledPrivateKeySize);
968  static_assert(sizeof(kExpectedPublicKeyBytes) == sizeof(pub_key));
969  if (!mlkem_marshal_private_key(&cbb, &priv) ||
970      !BORINGSSL_check_test(kExpectedPrivateKeyBytes, priv_bytes,
971                            sizeof(priv_bytes), "ML-KEM keygen private key") ||
972      !BORINGSSL_check_test(kExpectedPublicKeyBytes, pub_key, sizeof(pub_key),
973                            "ML-KEM keygen public key")) {
974    return 0;
975  }
976  return 1;
977}
978
979static int encap_self_test() {
980  CBS cbs;
981  CBS_init(&cbs, kExpectedPublicKeyBytes, sizeof(kExpectedPublicKeyBytes));
982  public_key<RANK768> pub;
983  if (!mlkem_parse_public_key(&pub, &cbs)) {
984    return 0;
985  }
986  uint8_t ciphertext[BCM_MLKEM768_CIPHERTEXT_BYTES];
987  uint8_t shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES];
988  static_assert(sizeof(kTestEntropy) >= BCM_MLKEM_ENCAP_ENTROPY);
989  mlkem_encap_external_entropy_no_self_test(ciphertext, shared_secret, &pub,
990                                            kTestEntropy);
991  if (!BORINGSSL_check_test(ciphertext, kExpectedCiphertext, sizeof(ciphertext),
992                            "ML-KEM encap ciphertext") ||
993      !BORINGSSL_check_test(kExpectedSharedSecret, shared_secret,
994                            sizeof(kExpectedSharedSecret),
995                            "ML-KEM encap shared secret")) {
996    return 0;
997  }
998  return 1;
999}
1000
1001static int decap_self_test() {
1002  CBS cbs;
1003  CBS_init(&cbs, kExpectedPrivateKeyBytes, sizeof(kExpectedPrivateKeyBytes));
1004  private_key<RANK768> priv;
1005  if (!mlkem_parse_private_key(&priv, &cbs)) {
1006    return 0;
1007  }
1008  uint8_t shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES];
1009  mlkem_decap_no_self_test(shared_secret, kExpectedCiphertext, &priv);
1010  static_assert(sizeof(kExpectedSharedSecret) == sizeof(shared_secret));
1011  if (!BORINGSSL_check_test(kExpectedSharedSecret, shared_secret,
1012                            sizeof(shared_secret),
1013                            "ML-KEM decap shared secret")) {
1014    return 0;
1015  }
1016
1017  uint8_t implicit_rejection_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES];
1018  static_assert(sizeof(kExpectedPrivateKeyBytes) >=
1019                sizeof(kExpectedCiphertext));
1020  mlkem_decap_no_self_test(implicit_rejection_shared_secret,
1021                           kExpectedPrivateKeyBytes, &priv);
1022  static_assert(sizeof(kExpectedImplicitRejectionSharedSecret) ==
1023                sizeof(implicit_rejection_shared_secret));
1024  if (!BORINGSSL_check_test(kExpectedImplicitRejectionSharedSecret,
1025                            implicit_rejection_shared_secret,
1026                            sizeof(implicit_rejection_shared_secret),
1027                            "ML-KEM decap implicit rejection shared secret")) {
1028    return 0;
1029  }
1030  return 1;
1031}
1032
#if defined(BORINGSSL_FIPS)

// In FIPS builds, each ensure_*_self_test runs its known-answer test through
// CRYPTO_once and calls BORINGSSL_FIPS_abort if the test fails.

DEFINE_STATIC_ONCE(g_mlkem_keygen_self_test_once)

// Runs the keygen known-answer test, aborting on failure.
static void run_keygen_self_test_or_abort() {
  if (!keygen_self_test()) {
    BORINGSSL_FIPS_abort();
  }
}

void ensure_keygen_self_test(void) {
  CRYPTO_once(g_mlkem_keygen_self_test_once_bss_get(),
              run_keygen_self_test_or_abort);
}

DEFINE_STATIC_ONCE(g_mlkem_encap_self_test_once)

// Runs the encapsulation known-answer test, aborting on failure.
static void run_encap_self_test_or_abort() {
  if (!encap_self_test()) {
    BORINGSSL_FIPS_abort();
  }
}

void ensure_encap_self_test(void) {
  CRYPTO_once(g_mlkem_encap_self_test_once_bss_get(),
              run_encap_self_test_or_abort);
}

DEFINE_STATIC_ONCE(g_mlkem_decap_self_test_once)

// Runs the decapsulation known-answer test, aborting on failure.
static void run_decap_self_test_or_abort() {
  if (!decap_self_test()) {
    BORINGSSL_FIPS_abort();
  }
}

void ensure_decap_self_test(void) {
  CRYPTO_once(g_mlkem_decap_self_test_once_bss_get(),
              run_decap_self_test_or_abort);
}

#else

// Outside FIPS builds these hooks do nothing.
void ensure_keygen_self_test(void) {}
void ensure_encap_self_test(void) {}
void ensure_decap_self_test(void) {}

#endif
1072}  // namespace fips
1073
1074}  // namespace
1075}  // namespace mlkem
1076
1077bcm_status BCM_mlkem768_check_fips(
1078    const struct BCM_mlkem768_private_key *private_key) {
1079  mlkem::private_key<RANK768> *priv =
1080      mlkem::private_key_768_from_external(private_key);
1081
1082  const uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY] = {1, 2, 3, 4};
1083  uint8_t ciphertext[BCM_MLKEM768_CIPHERTEXT_BYTES];
1084  uint8_t shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES];
1085  mlkem_encap_external_entropy_no_self_test(ciphertext, shared_secret,
1086                                            &priv->pub, entropy);
1087
1088  if (boringssl_fips_break_test("MLKEM_PWCT")) {
1089    shared_secret[0] ^= 1;
1090  }
1091
1092  uint8_t shared_secret2[BCM_MLKEM_SHARED_SECRET_BYTES];
1093  mlkem::mlkem_decap_no_self_test(shared_secret2, ciphertext, priv);
1094  if (CRYPTO_memcmp(shared_secret, shared_secret2, sizeof(shared_secret)) !=
1095      0) {
1096    return bcm_status::failure;
1097  }
1098  return bcm_status::approved;
1099}
1100
1101bcm_status BCM_mlkem768_generate_key_fips(
1102    uint8_t out_encoded_public_key[BCM_MLKEM768_PUBLIC_KEY_BYTES],
1103    uint8_t optional_out_seed[BCM_MLKEM_SEED_BYTES],
1104    struct BCM_mlkem768_private_key *out_private_key) {
1105  BCM_mlkem768_generate_key(out_encoded_public_key, optional_out_seed,
1106                            out_private_key);
1107  return BCM_mlkem768_check_fips(out_private_key);
1108}
1109
1110bcm_infallible BCM_mlkem768_generate_key(
1111    uint8_t out_encoded_public_key[BCM_MLKEM768_PUBLIC_KEY_BYTES],
1112    uint8_t optional_out_seed[BCM_MLKEM_SEED_BYTES],
1113    struct BCM_mlkem768_private_key *out_private_key) {
1114  uint8_t seed[BCM_MLKEM_SEED_BYTES];
1115  BCM_rand_bytes(seed, sizeof(seed));
1116  CONSTTIME_SECRET(seed, sizeof(seed));
1117  if (optional_out_seed) {
1118    OPENSSL_memcpy(optional_out_seed, seed, sizeof(seed));
1119  }
1120  BCM_mlkem768_generate_key_external_seed(out_encoded_public_key,
1121                                          out_private_key, seed);
1122  return bcm_infallible::not_approved;
1123}
1124
1125bcm_status BCM_mlkem768_private_key_from_seed(
1126    struct BCM_mlkem768_private_key *out_private_key, const uint8_t *seed,
1127    size_t seed_len) {
1128  if (seed_len != BCM_MLKEM_SEED_BYTES) {
1129    return bcm_status::failure;
1130  }
1131
1132  uint8_t public_key_bytes[BCM_MLKEM768_PUBLIC_KEY_BYTES];
1133  BCM_mlkem768_generate_key_external_seed(public_key_bytes, out_private_key,
1134                                          seed);
1135  return bcm_status::not_approved;
1136}
1137
1138bcm_status BCM_mlkem1024_check_fips(
1139    const struct BCM_mlkem1024_private_key *private_key) {
1140  mlkem::private_key<RANK1024> *priv =
1141      mlkem::private_key_1024_from_external(private_key);
1142
1143  const uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY] = {1, 2, 3, 4};
1144  uint8_t ciphertext[BCM_MLKEM1024_CIPHERTEXT_BYTES];
1145  uint8_t shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES];
1146  mlkem_encap_external_entropy_no_self_test(ciphertext, shared_secret,
1147                                            &priv->pub, entropy);
1148
1149  if (boringssl_fips_break_test("MLKEM_PWCT")) {
1150    shared_secret[0] ^= 1;
1151  }
1152
1153  uint8_t shared_secret2[BCM_MLKEM_SHARED_SECRET_BYTES];
1154  mlkem::mlkem_decap_no_self_test(shared_secret2, ciphertext, priv);
1155  if (CRYPTO_memcmp(shared_secret, shared_secret2, sizeof(shared_secret)) !=
1156      0) {
1157    return bcm_status::failure;
1158  }
1159  return bcm_status::approved;
1160}
1161
1162bcm_status BCM_mlkem1024_generate_key_fips(
1163    uint8_t out_encoded_public_key[BCM_MLKEM1024_PUBLIC_KEY_BYTES],
1164    uint8_t optional_out_seed[BCM_MLKEM_SEED_BYTES],
1165    struct BCM_mlkem1024_private_key *out_private_key) {
1166  BCM_mlkem1024_generate_key(out_encoded_public_key, optional_out_seed,
1167                             out_private_key);
1168  return BCM_mlkem1024_check_fips(out_private_key);
1169}
1170
1171bcm_infallible BCM_mlkem1024_generate_key(
1172    uint8_t out_encoded_public_key[BCM_MLKEM1024_PUBLIC_KEY_BYTES],
1173    uint8_t optional_out_seed[BCM_MLKEM_SEED_BYTES],
1174    struct BCM_mlkem1024_private_key *out_private_key) {
1175  uint8_t seed[BCM_MLKEM_SEED_BYTES];
1176  BCM_rand_bytes(seed, sizeof(seed));
1177  CONSTTIME_SECRET(seed, sizeof(seed));
1178  if (optional_out_seed) {
1179    OPENSSL_memcpy(optional_out_seed, seed, sizeof(seed));
1180  }
1181  BCM_mlkem1024_generate_key_external_seed(out_encoded_public_key,
1182                                           out_private_key, seed);
1183  return bcm_infallible::not_approved;
1184}
1185
1186bcm_status BCM_mlkem1024_private_key_from_seed(
1187    struct BCM_mlkem1024_private_key *out_private_key, const uint8_t *seed,
1188    size_t seed_len) {
1189  if (seed_len != BCM_MLKEM_SEED_BYTES) {
1190    return bcm_status::failure;
1191  }
1192  uint8_t public_key_bytes[BCM_MLKEM1024_PUBLIC_KEY_BYTES];
1193  BCM_mlkem1024_generate_key_external_seed(public_key_bytes, out_private_key,
1194                                           seed);
1195  return bcm_status::not_approved;
1196}
1197
1198bcm_infallible BCM_mlkem768_generate_key_external_seed(
1199    uint8_t out_encoded_public_key[BCM_MLKEM768_PUBLIC_KEY_BYTES],
1200    struct BCM_mlkem768_private_key *out_private_key,
1201    const uint8_t seed[BCM_MLKEM_SEED_BYTES]) {
1202  mlkem::private_key<RANK768> *priv =
1203      mlkem::private_key_768_from_external(out_private_key);
1204  mlkem_generate_key_external_seed(out_encoded_public_key, priv, seed);
1205  return bcm_infallible::approved;
1206}
1207
1208bcm_infallible BCM_mlkem1024_generate_key_external_seed(
1209    uint8_t out_encoded_public_key[BCM_MLKEM1024_PUBLIC_KEY_BYTES],
1210    struct BCM_mlkem1024_private_key *out_private_key,
1211    const uint8_t seed[BCM_MLKEM_SEED_BYTES]) {
1212  mlkem::private_key<RANK1024> *priv =
1213      mlkem::private_key_1024_from_external(out_private_key);
1214  mlkem_generate_key_external_seed(out_encoded_public_key, priv, seed);
1215  return bcm_infallible::approved;
1216}
1217
1218bcm_infallible BCM_mlkem768_public_from_private(
1219    struct BCM_mlkem768_public_key *out_public_key,
1220    const struct BCM_mlkem768_private_key *private_key) {
1221  struct mlkem::public_key<RANK768> *const pub =
1222      mlkem::public_key_768_from_external(out_public_key);
1223  const struct mlkem::private_key<RANK768> *const priv =
1224      mlkem::private_key_768_from_external(private_key);
1225  *pub = priv->pub;
1226  return bcm_infallible::approved;
1227}
1228
1229bcm_infallible BCM_mlkem1024_public_from_private(
1230    struct BCM_mlkem1024_public_key *out_public_key,
1231    const struct BCM_mlkem1024_private_key *private_key) {
1232  struct mlkem::public_key<RANK1024> *const pub =
1233      mlkem::public_key_1024_from_external(out_public_key);
1234  const struct mlkem::private_key<RANK1024> *const priv =
1235      mlkem::private_key_1024_from_external(private_key);
1236  *pub = priv->pub;
1237  return bcm_infallible::approved;
1238}
1239
1240// Calls |MLKEM768_encap_external_entropy| with random bytes from
1241// |BCM_rand_bytes|
1242bcm_infallible BCM_mlkem768_encap(
1243    uint8_t out_ciphertext[BCM_MLKEM768_CIPHERTEXT_BYTES],
1244    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
1245    const struct BCM_mlkem768_public_key *public_key) {
1246  uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY];
1247  BCM_rand_bytes(entropy, BCM_MLKEM_ENCAP_ENTROPY);
1248  CONSTTIME_SECRET(entropy, BCM_MLKEM_ENCAP_ENTROPY);
1249  BCM_mlkem768_encap_external_entropy(out_ciphertext, out_shared_secret,
1250                                      public_key, entropy);
1251  return bcm_infallible::approved;
1252}
1253
1254bcm_infallible BCM_mlkem1024_encap(
1255    uint8_t out_ciphertext[BCM_MLKEM1024_CIPHERTEXT_BYTES],
1256    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
1257    const struct BCM_mlkem1024_public_key *public_key) {
1258  uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY];
1259  BCM_rand_bytes(entropy, BCM_MLKEM_ENCAP_ENTROPY);
1260  CONSTTIME_SECRET(entropy, BCM_MLKEM_ENCAP_ENTROPY);
1261  BCM_mlkem1024_encap_external_entropy(out_ciphertext, out_shared_secret,
1262                                       public_key, entropy);
1263  return bcm_infallible::approved;
1264}
1265
1266bcm_infallible BCM_mlkem768_encap_external_entropy(
1267    uint8_t out_ciphertext[BCM_MLKEM768_CIPHERTEXT_BYTES],
1268    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
1269    const struct BCM_mlkem768_public_key *public_key,
1270    const uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY]) {
1271  const struct mlkem::public_key<RANK768> *pub =
1272      mlkem::public_key_768_from_external(public_key);
1273  mlkem_encap_external_entropy(out_ciphertext, out_shared_secret, pub, entropy);
1274  return bcm_infallible::approved;
1275}
1276
1277bcm_infallible BCM_mlkem1024_encap_external_entropy(
1278    uint8_t out_ciphertext[BCM_MLKEM1024_CIPHERTEXT_BYTES],
1279    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
1280    const struct BCM_mlkem1024_public_key *public_key,
1281    const uint8_t entropy[BCM_MLKEM_ENCAP_ENTROPY]) {
1282  const struct mlkem::public_key<RANK1024> *pub =
1283      mlkem::public_key_1024_from_external(public_key);
1284  mlkem_encap_external_entropy(out_ciphertext, out_shared_secret, pub, entropy);
1285  return bcm_infallible::approved;
1286}
1287
1288bcm_status BCM_mlkem768_decap(
1289    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
1290    const uint8_t *ciphertext, size_t ciphertext_len,
1291    const struct BCM_mlkem768_private_key *private_key) {
1292  if (ciphertext_len != BCM_MLKEM768_CIPHERTEXT_BYTES) {
1293    BCM_rand_bytes(out_shared_secret, BCM_MLKEM_SHARED_SECRET_BYTES);
1294    return bcm_status::failure;
1295  }
1296  const struct mlkem::private_key<RANK768> *priv =
1297      mlkem::private_key_768_from_external(private_key);
1298  mlkem_decap(out_shared_secret, ciphertext, priv);
1299  return bcm_status::approved;
1300}
1301
1302bcm_status BCM_mlkem1024_decap(
1303    uint8_t out_shared_secret[BCM_MLKEM_SHARED_SECRET_BYTES],
1304    const uint8_t *ciphertext, size_t ciphertext_len,
1305    const struct BCM_mlkem1024_private_key *private_key) {
1306  if (ciphertext_len != BCM_MLKEM1024_CIPHERTEXT_BYTES) {
1307    BCM_rand_bytes(out_shared_secret, BCM_MLKEM_SHARED_SECRET_BYTES);
1308    return bcm_status::failure;
1309  }
1310  const struct mlkem::private_key<RANK1024> *priv =
1311      mlkem::private_key_1024_from_external(private_key);
1312  mlkem_decap(out_shared_secret, ciphertext, priv);
1313  return bcm_status::approved;
1314}
1315
1316bcm_status BCM_mlkem768_marshal_public_key(
1317    CBB *out, const struct BCM_mlkem768_public_key *public_key) {
1318  return mlkem_marshal_public_key(
1319      out, mlkem::public_key_768_from_external(public_key));
1320}
1321
1322bcm_status BCM_mlkem1024_marshal_public_key(
1323    CBB *out, const struct BCM_mlkem1024_public_key *public_key) {
1324  return mlkem_marshal_public_key(
1325      out, mlkem::public_key_1024_from_external(public_key));
1326}
1327
1328bcm_status BCM_mlkem768_parse_public_key(
1329    struct BCM_mlkem768_public_key *public_key, CBS *in) {
1330  struct mlkem::public_key<RANK768> *pub =
1331      mlkem::public_key_768_from_external(public_key);
1332  if (!mlkem_parse_public_key(pub, in)) {
1333    return bcm_status::failure;
1334  }
1335  return bcm_status::approved;
1336}
1337
1338bcm_status BCM_mlkem1024_parse_public_key(
1339    struct BCM_mlkem1024_public_key *public_key, CBS *in) {
1340  struct mlkem::public_key<RANK1024> *pub =
1341      mlkem::public_key_1024_from_external(public_key);
1342  if (!mlkem_parse_public_key(pub, in)) {
1343    return bcm_status::failure;
1344  }
1345  return bcm_status::approved;
1346}
1347
1348bcm_status BCM_mlkem768_marshal_private_key(
1349    CBB *out, const struct BCM_mlkem768_private_key *private_key) {
1350  const struct mlkem::private_key<RANK768> *const priv =
1351      mlkem::private_key_768_from_external(private_key);
1352  if (!mlkem_marshal_private_key(out, priv)) {
1353    return bcm_status::failure;
1354  }
1355  return bcm_status::approved;
1356}
1357
1358bcm_status BCM_mlkem1024_marshal_private_key(
1359    CBB *out, const struct BCM_mlkem1024_private_key *private_key) {
1360  const struct mlkem::private_key<RANK1024> *const priv =
1361      mlkem::private_key_1024_from_external(private_key);
1362  if (!mlkem_marshal_private_key(out, priv)) {
1363    return bcm_status::failure;
1364  }
1365  return bcm_status::approved;
1366}
1367
1368bcm_status BCM_mlkem768_parse_private_key(
1369    struct BCM_mlkem768_private_key *out_private_key, CBS *in) {
1370  struct mlkem::private_key<RANK768> *const priv =
1371      mlkem::private_key_768_from_external(out_private_key);
1372  if (!mlkem_parse_private_key(priv, in)) {
1373    return bcm_status::failure;
1374  }
1375  return bcm_status::approved;
1376}
1377
1378bcm_status BCM_mlkem1024_parse_private_key(
1379    struct BCM_mlkem1024_private_key *out_private_key, CBS *in) {
1380  struct mlkem::private_key<RANK1024> *const priv =
1381      mlkem::private_key_1024_from_external(out_private_key);
1382  if (!mlkem_parse_private_key(priv, in)) {
1383    return bcm_status::failure;
1384  }
1385  return bcm_status::approved;
1386}
1387
1388int boringssl_self_test_mlkem() {
1389  return mlkem::fips::keygen_self_test() && mlkem::fips::encap_self_test() &&
1390         mlkem::fips::decap_self_test();
1391}
1392