/* Copyright (c) 2023, Google Inc.
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */

#include <openssl/kyber.h>

#include <assert.h>
#include <stdlib.h>

#include <openssl/bytestring.h>
#include <openssl/rand.h>

#include "../internal.h"
#include "../keccak/internal.h"
#include "./internal.h"


// See
// https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf

#define DEGREE 256
#define RANK 3

static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
static const int kDU = 10;
static const int kDV = 4;
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
// root of unity.
static const uint16_t kInverseDegree = 3303;
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;
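// For this parameter set (RANK = 3, i.e. Kyber768) these work out to
// kEncodedVectorSize = 12 * 256 / 8 * 3 = 1152 bytes and
// kCompressedVectorSize = 10 * 3 * 256 / 8 = 960 bytes. The encoded vector
// plus the 32-byte |rho| seed make up the 1184-byte encoded public key.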

typedef struct scalar {
  // On every function entry and exit, 0 <= c < kPrime.
  uint16_t c[DEGREE];
} scalar;

typedef struct vector {
  scalar v[RANK];
} vector;

typedef struct matrix {
  scalar v[RANK][RANK];
} matrix;

// This bit of Python will be referenced in some of the following comments:
//
// p = 3329
//
// def bitreverse(i):
//     ret = 0
//     for n in range(7):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret

// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};

// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
static const uint16_t kInverseNTTRoots[128] = {
    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};

// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
static uint16_t reduce_once(uint16_t x) {
  assert(x < 2 * kPrime);
  const uint16_t subtracted = x - kPrime;
  uint16_t mask = 0u - (subtracted >> 15);
  // On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of Kyber
  // overall and Clang still produces constant-time code using `csel`. On other
  // platforms & compilers on godbolt that we care about, this code also
  // produces constant-time output.
  return (mask & x) | (~mask & subtracted);
}
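
// Worked example: reduce_once(4000) computes subtracted = 671, whose top bit
// is clear, so mask = 0 and 671 is returned. reduce_once(2000) computes
// subtracted = 2000 - 3329 mod 2^16 = 64207, whose top bit is set, so
// mask = 0xffff and x = 2000 is returned unchanged.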

// reduce reduces x mod kPrime in constant time, using Barrett reduction. x
// must be less than kPrime + 2×kPrime².
static uint16_t reduce(uint32_t x) {
  assert(x < kPrime + 2u * kPrime * kPrime);
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = x - quotient * kPrime;
  return reduce_once(remainder);
}
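
// Here kBarrettMultiplier = floor(2^kBarrettShift / kPrime) =
// floor(2^24 / 3329) = 5039, so |quotient| is either floor(x / kPrime) or one
// less than it, leaving a remainder in [0, 2*kPrime) for |reduce_once|. For
// example, reduce(10000) computes product = 10000 * 5039 = 50390000,
// quotient = 50390000 >> 24 = 3, and remainder = 10000 - 3 * 3329 = 13.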

static void scalar_zero(scalar *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

static void vector_zero(vector *out) { OPENSSL_memset(out, 0, sizeof(*out)); }

// In place number theoretic transform of a given scalar.
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
// transform leaves off the last iteration of the usual FFT code, with the 128
// relevant roots of unity being stored in |kNTTRoots|. This means the output
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
// elements being consecutive entries in |s->c|.
static void scalar_ntt(scalar *s) {
  int offset = DEGREE;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
    offset >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      const uint32_t step_root = kNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = reduce(step_root * s->c[j + offset]);
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce_once(even - odd + kPrime);
      }
      k += 2 * offset;
    }
  }
}
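
// The loop above runs seven butterfly layers (step = 1, 2, ..., 64), halving
// |offset| from 256 down to a final value of 2. A full 256-point transform
// would run one further layer; that is exactly the layer omitted here, so the
// transform stops while adjacent pairs of coefficients still form one
// GF(3329^2) element each.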

static void vector_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_ntt(&a->v[i]);
  }
}

// In place inverse number theoretic transform of a given scalar, with pairs of
// entries of s->c being interpreted as elements of GF(3329^2). Just as with the
// number theoretic transform, this leaves off the first step of the normal iFFT
// to account for the fact that 3329 does not have a 512th root of unity, using
// the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
static void scalar_inverse_ntt(scalar *s) {
  int step = DEGREE / 2;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
    step >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      uint32_t step_root = kInverseNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = s->c[j + offset];
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
      }
      k += 2 * offset;
    }
  }
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce(s->c[i] * kInverseDegree);
  }
}

static void vector_inverse_ntt(vector *a) {
  for (int i = 0; i < RANK; i++) {
    scalar_inverse_ntt(&a->v[i]);
  }
}

static void scalar_add(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
  }
}

static void scalar_sub(scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE; i++) {
    lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
  }
}

// Multiplying two scalars in the number theoretically transformed state. Since
// 3329 does not have a 512th root of unity, this means we have to interpret
// the 2*ith and (2*i+1)th entries of the scalar as elements of
// GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). The value of
// 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed |kModRoots|
// table. Note that our Barrett reduction only allows us to multiply two
// reduced numbers together, so we need some intermediate reduction steps,
// even if a uint64_t could hold 3 multiplied numbers.
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE / 2; i++) {
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
    out->c[2 * i] =
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
    out->c[2 * i + 1] = reduce(img_real + real_img);
  }
}
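
// In other words, with g = kModRoots[i], each pair is multiplied as
//   (a0 + a1*X) * (b0 + b1*X) mod (X^2 - g)
//     = (a0*b0 + g*a1*b1) + (a0*b1 + a1*b0)*X,
// where real_real = a0*b0, img_img = a1*b1, and so on above.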

static void vector_add(vector *lhs, const vector *rhs) {
  for (int i = 0; i < RANK; i++) {
    scalar_add(&lhs->v[i], &rhs->v[i]);
  }
}

static void matrix_mult(vector *out, const matrix *m, const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[i][j], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void matrix_mult_transpose(vector *out, const matrix *m,
                                  const vector *a) {
  vector_zero(out);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      scalar product;
      scalar_mult(&product, &m->v[j][i], &a->v[j]);
      scalar_add(&out->v[i], &product);
    }
  }
}

static void scalar_inner_product(scalar *out, const vector *lhs,
                                 const vector *rhs) {
  scalar_zero(out);
  for (int i = 0; i < RANK; i++) {
    scalar product;
    scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
    scalar_add(out, &product);
  }
}

// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
// uniformly distributed elements. This is used for matrix expansion and only
// operates on public inputs.
static void scalar_from_keccak_vartime(scalar *out,
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
  assert(keccak_ctx->squeeze_offset == 0);
  assert(keccak_ctx->rate_bytes == 168);
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
      if (d1 < kPrime) {
        out->c[done++] = d1;
      }
      if (d2 < kPrime && done < DEGREE) {
        out->c[done++] = d2;
      }
    }
  }
}

// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
// included. Creates binomially distributed elements by sampling 2*|eta| bits,
// and setting the coefficient to the count of the first |eta| bits minus the
// count of the second |eta| bits, resulting in a centered binomial
// distribution. Since eta is two, this gives -2 and 2 each with probability
// 1/16, -1 and 1 each with probability 1/4, and 0 with probability 3/8.
static void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  BORINGSSL_keccak(entropy, sizeof(entropy), input, 33, boringssl_shake256);

  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    uint16_t value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i] = reduce_once(value);

    byte >>= 4;
    value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i + 1] = reduce_once(value);
  }
}
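
// For example, the entropy byte 0x86 yields two coefficients. Its low nibble
// 0b0110 gives (0 + 1) - (1 + 0) = 0, stored as reduce_once(kPrime) = 0; its
// high nibble 0b1000 gives (0 + 0) - (0 + 1) = -1, stored as kPrime - 1 =
// 3328.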

// Generates a secret vector by using
// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given
// seed, appending and incrementing |counter| for each entry of the vector.
static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
                                         const uint8_t seed[32]) {
  uint8_t input[33];
  OPENSSL_memcpy(input, seed, 32);
  for (int i = 0; i < RANK; i++) {
    input[32] = (*counter)++;
    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
  }
}

// Expands the matrix from a seed, for key generation and for the CPA
// encryption used during encapsulation.
static void matrix_expand(matrix *out, const uint8_t rho[32]) {
  uint8_t input[34];
  OPENSSL_memcpy(input, rho, 32);
  for (int i = 0; i < RANK; i++) {
    for (int j = 0; j < RANK; j++) {
      input[32] = i;
      input[33] = j;
      struct BORINGSSL_keccak_st keccak_ctx;
      BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
      BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
      scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
    }
  }
}

static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};

static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);

  uint8_t out_byte = 0;
  int out_byte_bits = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = s->c[i];
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      int chunk_bits = bits - element_bits_done;
      int out_bits_remaining = 8 - out_byte_bits;
      if (chunk_bits >= out_bits_remaining) {
        chunk_bits = out_bits_remaining;
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        *out = out_byte;
        out++;
        out_byte_bits = 0;
        out_byte = 0;
      } else {
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        out_byte_bits += chunk_bits;
      }

      element_bits_done += chunk_bits;
      element >>= chunk_bits;
    }
  }

  if (out_byte_bits > 0) {
    *out = out_byte;
  }
}
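
// The packing is little-endian within and across coefficients. For example,
// with |bits| = 12, two coefficients c0 and c1 occupy three bytes:
//   out[0] = c0 & 0xff;
//   out[1] = ((c0 >> 8) & 0x0f) | ((c1 & 0x0f) << 4);
//   out[2] = c1 >> 4;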

// scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
static void scalar_encode_1(uint8_t out[32], const scalar *s) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t out_byte = 0;
    for (int j = 0; j < 8; j++) {
      out_byte |= (s->c[i + j] & 1) << j;
    }
    *out = out_byte;
    out++;
  }
}

// Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
// (DEGREE) is divisible by 8, the individual vector entries will always fill a
// whole number of bytes, so we do not need to worry about bit packing here.
static void vector_encode(uint8_t *out, const vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
  }
}

// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
// |out|. It returns one on success and zero if any parsed value is >=
// |kPrime|.
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

  uint8_t in_byte = 0;
  int in_byte_bits_left = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = 0;
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      if (in_byte_bits_left == 0) {
        in_byte = *in;
        in++;
        in_byte_bits_left = 8;
      }

      int chunk_bits = bits - element_bits_done;
      if (chunk_bits > in_byte_bits_left) {
        chunk_bits = in_byte_bits_left;
      }

      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
      in_byte_bits_left -= chunk_bits;
      in_byte >>= chunk_bits;

      element_bits_done += chunk_bits;
    }

    if (element >= kPrime) {
      return 0;
    }
    out->c[i] = element;
  }

  return 1;
}

// scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
  for (int i = 0; i < DEGREE; i += 8) {
    uint8_t in_byte = *in;
    in++;
    for (int j = 0; j < 8; j++) {
      out->c[i + j] = in_byte & 1;
      in_byte >>= 1;
    }
  }
}

// Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
// success or zero if any parsed value is >= |kPrime|.
static int vector_decode(vector *out, const uint8_t *in, int bits) {
  for (int i = 0; i < RANK; i++) {
    if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
      return 0;
    }
  }
  return 1;
}

// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
// numbers close to each other together. The formula used is
// round(2^|bits|/kPrime*x) mod 2^|bits|.
// Uses Barrett reduction to achieve constant time. Since we need both the
// remainder (for rounding) and the quotient (as the result), we cannot use
// |reduce| here, but need to do the Barrett reduction directly.
static uint16_t compress(uint16_t x, int bits) {
  uint32_t shifted = (uint32_t)x << bits;
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = shifted - quotient * kPrime;

  // Adjust the quotient to round correctly:
  //   0 <= remainder <= kHalfPrime round to 0
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
  assert(remainder < 2u * kPrime);
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
  return quotient & ((1 << bits) - 1);
}
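
// For example, compress(2000, 4) computes shifted = 32000, quotient = 9 and
// remainder = 32000 - 9 * 3329 = 2039. Since kHalfPrime = 1664 < 2039, the
// quotient is rounded up once, giving round(16 * 2000 / 3329) = 10.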

// Decompresses |x| by using an equi-distant representative. The formula is
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
// implement this logic using only bit operations.
static uint16_t decompress(uint16_t x, int bits) {
  uint32_t product = (uint32_t)x * kPrime;
  uint32_t power = 1 << bits;
  // This is |product| % power, since |power| is a power of 2.
  uint32_t remainder = product & (power - 1);
  // This is |product| / power, since |power| is a power of 2.
  uint32_t lower = product >> bits;
  // The rounding logic works since the first half of numbers mod |power| have
  // a 0 as their top bit, and the second half have a 1 as their top bit, since
  // |power| is a power of 2. As |remainder| is an unsigned value, the right
  // shift brings in 0s.
  return lower + (remainder >> (bits - 1));
}
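
// For example, decompress(10, 4) computes product = 33290, lower = 2080 and
// remainder = 10; the remainder's top bit (10 >> 3) is 1, so the result rounds
// up to 2081 = round(3329 * 10 / 16). Together with the compress example
// above, this shows the lossiness: 2000 compresses to 10, which decompresses
// to 2081.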

static void scalar_compress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = compress(s->c[i], bits);
  }
}

static void scalar_decompress(scalar *s, int bits) {
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = decompress(s->c[i], bits);
  }
}

static void vector_compress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_compress(&a->v[i], bits);
  }
}

static void vector_decompress(vector *a, int bits) {
  for (int i = 0; i < RANK; i++) {
    scalar_decompress(&a->v[i], bits);
  }
}

struct public_key {
  vector t;
  uint8_t rho[32];
  uint8_t public_key_hash[32];
  matrix m;
};
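
// Note that only |t| and |rho| appear in the serialised public key (see
// |kyber_marshal_public_key| below); |m| is re-expanded from |rho| when
// parsing, and |public_key_hash| is recomputed from the encoded bytes.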

static struct public_key *public_key_from_external(
    const struct KYBER_public_key *external) {
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
                "Kyber public key is too small");
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
                "Kyber public key align incorrect");
  return (struct public_key *)external;
}

struct private_key {
  struct public_key pub;
  vector s;
  uint8_t fo_failure_secret[32];
};

static struct private_key *private_key_from_external(
    const struct KYBER_private_key *external) {
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
                "Kyber private key too small");
  static_assert(
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
      "Kyber private key align incorrect");
  return (struct private_key *)external;
}

// Calls |KYBER_generate_key_external_entropy| with random bytes from
// |RAND_bytes|.
void KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
                        struct KYBER_private_key *out_private_key) {
  uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
  RAND_bytes(entropy, sizeof(entropy));
  KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
                                      entropy);
}
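
// A minimal usage sketch of the public API (assuming the declarations in
// <openssl/kyber.h>; error handling and the 32-byte shared-secret length are
// illustrative):
//
//   uint8_t encoded_public_key[KYBER_PUBLIC_KEY_BYTES];
//   struct KYBER_private_key priv;
//   KYBER_generate_key(encoded_public_key, &priv);
//
//   struct KYBER_public_key pub;
//   CBS cbs;
//   CBS_init(&cbs, encoded_public_key, sizeof(encoded_public_key));
//   if (!KYBER_parse_public_key(&pub, &cbs)) { /* handle error */ }
//
//   uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES];
//   uint8_t secret_a[32], secret_b[32];
//   KYBER_encap(ciphertext, secret_a, sizeof(secret_a), &pub);
//   KYBER_decap(secret_b, sizeof(secret_b), ciphertext, &priv);
//   // secret_a and secret_b now hold the same shared secret.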

static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
  uint8_t *vector_output;
  if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(vector_output, &pub->t, kLog2Prime);
  if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  return 1;
}

// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
// specify the actual key format.
void KYBER_generate_key_external_entropy(
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
    struct KYBER_private_key *out_private_key,
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
  struct private_key *priv = private_key_from_external(out_private_key);
  uint8_t hashed[64];
  BORINGSSL_keccak(hashed, sizeof(hashed), entropy, 32, boringssl_sha3_512);
  const uint8_t *const rho = hashed;
  const uint8_t *const sigma = hashed + 32;
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
  matrix_expand(&priv->pub.m, rho);
  uint8_t counter = 0;
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
  vector_ntt(&priv->s);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, sigma);
  vector_ntt(&error);
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
  vector_add(&priv->pub.t, &error);

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
    abort();
  }

  BORINGSSL_keccak(priv->pub.public_key_hash, sizeof(priv->pub.public_key_hash),
                   out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES,
                   boringssl_sha3_256);
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
}

void KYBER_public_from_private(struct KYBER_public_key *out_public_key,
                               const struct KYBER_private_key *private_key) {
  struct public_key *const pub = public_key_from_external(out_public_key);
  const struct private_key *const priv = private_key_from_external(private_key);
  *pub = priv->pub;
}

// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
// would not result in a CCA secure scheme, since lattice schemes are vulnerable
// to decryption failure oracles.
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
                        const struct public_key *pub, const uint8_t message[32],
                        const uint8_t randomness[32]) {
  uint8_t counter = 0;
  vector secret;
  vector_generate_secret_eta_2(&secret, &counter, randomness);
  vector_ntt(&secret);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, randomness);
  uint8_t input[33];
  OPENSSL_memcpy(input, randomness, 32);
  input[32] = counter;
  scalar scalar_error;
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
  vector u;
  matrix_mult(&u, &pub->m, &secret);
  vector_inverse_ntt(&u);
  vector_add(&u, &error);
  scalar v;
  scalar_inner_product(&v, &pub->t, &secret);
  scalar_inverse_ntt(&v);
  scalar_add(&v, &scalar_error);
  scalar expanded_message;
  scalar_decode_1(&expanded_message, message);
  scalar_decompress(&expanded_message, 1);
  scalar_add(&v, &expanded_message);
  vector_compress(&u, kDU);
  vector_encode(out, &u, kDU);
  scalar_compress(&v, kDV);
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
}
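
// The resulting ciphertext layout is |u|, compressed to kDU = 10 bits per
// coefficient (kCompressedVectorSize = 960 bytes), followed by |v|, compressed
// to kDV = 4 bits per coefficient (4 * 256 / 8 = 128 bytes), together making
// up the KYBER_CIPHERTEXT_BYTES = 1088 bytes.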

// Calls |KYBER_encap_external_entropy| with random bytes from |RAND_bytes|.
void KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
                 uint8_t *out_shared_secret, size_t out_shared_secret_len,
                 const struct KYBER_public_key *public_key) {
  uint8_t entropy[KYBER_ENCAP_ENTROPY];
  RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
  KYBER_encap_external_entropy(out_ciphertext, out_shared_secret,
                               out_shared_secret_len, public_key, entropy);
}

// Algorithm 8 of the Kyber spec, save for line 2 of the spec. The spec there
// hashes the output of the system's random number generator, since the FO
// transform will reveal it to the decrypting party. There is no reason to do
// this when a secure random number generator is used. When an insecure random
// number generator is used, the caller should switch to a secure one before
// calling this method.
void KYBER_encap_external_entropy(
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES], uint8_t *out_shared_secret,
    size_t out_shared_secret_len, const struct KYBER_public_key *public_key,
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
  const struct public_key *pub = public_key_from_external(public_key);
  uint8_t input[64];
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
  uint8_t prekey_and_randomness[64];
  BORINGSSL_keccak(prekey_and_randomness, sizeof(prekey_and_randomness), input,
                   sizeof(input), boringssl_sha3_512);
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
  BORINGSSL_keccak(prekey_and_randomness + 32, 32, out_ciphertext,
                   KYBER_CIPHERTEXT_BYTES, boringssl_sha3_256);
  BORINGSSL_keccak(out_shared_secret, out_shared_secret_len,
                   prekey_and_randomness, sizeof(prekey_and_randomness),
                   boringssl_shake256);
}

// Algorithm 6 of the Kyber spec.
static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
                        const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
  vector u;
  vector_decode(&u, ciphertext, kDU);
  vector_decompress(&u, kDU);
  vector_ntt(&u);
  scalar v;
  scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
  scalar_decompress(&v, kDV);
  scalar mask;
  scalar_inner_product(&mask, &priv->s, &u);
  scalar_inverse_ntt(&mask);
  scalar_sub(&v, &mask);
  scalar_compress(&v, 1);
  scalar_encode_1(out, &v);
}

// Algorithm 9 of the Kyber spec, performing the FO transform by running
// encrypt_cpa on the decrypted message. The spec does not allow the decryption
// failure to be passed on to the caller, and instead returns a result that is
// deterministic but unpredictable to anyone without knowledge of the private
// key.
void KYBER_decap(uint8_t *out_shared_secret, size_t out_shared_secret_len,
                 const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
                 const struct KYBER_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  uint8_t prekey_and_randomness[64];
  BORINGSSL_keccak(prekey_and_randomness, sizeof(prekey_and_randomness),
                   decrypted, sizeof(decrypted), boringssl_sha3_512);
  uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              prekey_and_randomness + 32);
  uint8_t mask =
      constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                           sizeof(expected_ciphertext)),
                             0);
  uint8_t input[64];
  for (int i = 0; i < 32; i++) {
    input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
                                      priv->fo_failure_secret[i]);
  }
  BORINGSSL_keccak(input + 32, 32, ciphertext, KYBER_CIPHERTEXT_BYTES,
                   boringssl_sha3_256);
  BORINGSSL_keccak(out_shared_secret, out_shared_secret_len, input,
                   sizeof(input), boringssl_shake256);
}

int KYBER_marshal_public_key(CBB *out,
                             const struct KYBER_public_key *public_key) {
  return kyber_marshal_public_key(out, public_key_from_external(public_key));
}

// kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
// the value of |pub->public_key_hash|.
static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
  CBS t_bytes;
  if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  matrix_expand(&pub->m, pub->rho);
  return 1;
}

int KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);
  CBS orig_in = *in;
  if (!kyber_parse_public_key_no_hash(pub, in) ||  //
      CBS_len(in) != 0) {
    return 0;
  }
  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
                   CBS_data(&orig_in), CBS_len(&orig_in), boringssl_sha3_256);
  return 1;
}

int KYBER_marshal_private_key(CBB *out,
                              const struct KYBER_private_key *private_key) {
  const struct private_key *const priv = private_key_from_external(private_key);
  uint8_t *s_output;
  if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(s_output, &priv->s, kLog2Prime);
  if (!kyber_marshal_public_key(out, &priv->pub) ||
      !CBB_add_bytes(out, priv->pub.public_key_hash,
                     sizeof(priv->pub.public_key_hash)) ||
      !CBB_add_bytes(out, priv->fo_failure_secret,
                     sizeof(priv->fo_failure_secret))) {
    return 0;
  }
  return 1;
}

int KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
                            CBS *in) {
  struct private_key *const priv = private_key_from_external(out_private_key);

  CBS s_bytes;
  if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
      !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
      !kyber_parse_public_key_no_hash(&priv->pub, in) ||
      !CBS_copy_bytes(in, priv->pub.public_key_hash,
                      sizeof(priv->pub.public_key_hash)) ||
      !CBS_copy_bytes(in, priv->fo_failure_secret,
                      sizeof(priv->fo_failure_secret)) ||
      CBS_len(in) != 0) {
    return 0;
  }
  return 1;
}
836