1 /* Copyright (c) 2023, Google Inc.
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15 #include <openssl/kyber.h>
16
17 #include <assert.h>
18 #include <stdlib.h>
19
20 #include <openssl/bytestring.h>
21 #include <openssl/rand.h>
22
23 #include "../internal.h"
24 #include "../keccak/internal.h"
25 #include "./internal.h"
26
27
28 // See
29 // https://pq-crystals.org/kyber/data/kyber-specification-round3-20210804.pdf
30
// DEGREE is the number of coefficients per polynomial; RANK is the dimension
// of vectors and matrices. RANK == 3 selects the parameter set used
// throughout this file.
#define DEGREE 256
#define RANK 3

// kBarrettMultiplier is floor(2^kBarrettShift / kPrime), used by |reduce| and
// |compress| for constant-time division by kPrime.
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
// kPrime is the modulus q = 3329.
static const uint16_t kPrime = 3329;
// kLog2Prime is the number of bits needed to encode a value mod kPrime.
static const int kLog2Prime = 12;
// kHalfPrime is (q - 1) / 2, the rounding threshold used in |compress|.
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
// kDU and kDV are the ciphertext compression bit-widths for the vector u and
// the scalar v, respectively.
static const int kDU = 10;
static const int kDV = 4;
// kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
// root of unity.
static const uint16_t kInverseDegree = 3303;
// kEncodedVectorSize is the byte length of a vector encoded at kLog2Prime
// bits per coefficient; kCompressedVectorSize is the byte length of a vector
// compressed to kDU bits per coefficient.
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK;
static const size_t kCompressedVectorSize = /*kDU=*/10 * RANK * DEGREE / 8;

// scalar is a polynomial with DEGREE coefficients mod kPrime.
typedef struct scalar {
  // On every function entry and exit, 0 <= c < kPrime.
  uint16_t c[DEGREE];
} scalar;

// vector is a length-RANK vector of scalars.
typedef struct vector {
  scalar v[RANK];
} vector;

// matrix is a RANK×RANK matrix of scalars.
typedef struct matrix {
  scalar v[RANK][RANK];
} matrix;
60
// This bit of Python will be referenced in some of the following comments:
//
// p = 3329
//
// def bitreverse(i):
//     ret = 0
//     for n in range(7):
//         bit = i & 1
//         ret <<= 1
//         ret |= bit
//         i >>= 1
//     return ret

// kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
// These are the powers of the 256th root of unity (17) used by |scalar_ntt|,
// in bit-reversed order.
static const uint16_t kNTTRoots[128] = {
    1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
    2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
    1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
    1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
    2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
    1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
    1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
    1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
    2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};

// kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
// The multiplicative inverses of |kNTTRoots|, used by |scalar_inverse_ntt|.
static const uint16_t kInverseNTTRoots[128] = {
    1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
    2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
    1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
    2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
    1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
    2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
    2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
    2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
    1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
    2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};

// kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
// The per-pair moduli 17^(2*bitreverse(i)+1) used by |scalar_mult| when
// multiplying in GF(3329^2).
static const uint16_t kModRoots[128] = {
    17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
    2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
    756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
    2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
    939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
    268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
    375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
    2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
    2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
    2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
    2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
118
// reduce_once reduces 0 <= x < 2*kPrime, mod kPrime.
static uint16_t reduce_once(uint16_t x) {
  assert(x < 2 * kPrime);
  // If x < kPrime then |subtracted| wraps around and has its top bit set.
  const uint16_t subtracted = x - kPrime;
  // |mask| is all-ones when x < kPrime and zero otherwise, computed without a
  // branch so the selection below is constant-time.
  uint16_t mask = 0u - (subtracted >> 15);
  // On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of Kyber
  // overall and Clang still produces constant-time code using `csel`. On other
  // platforms & compilers on godbolt that we care about, this code also
  // produces constant-time output.
  return (mask & x) | (~mask & subtracted);
}
130
// constant time reduce x mod kPrime using Barrett reduction. x must be less
// than kPrime + 2×kPrime².
static uint16_t reduce(uint32_t x) {
  assert(x < kPrime + 2u * kPrime * kPrime);
  // |quotient| approximates x / kPrime via multiplication by
  // floor(2^kBarrettShift / kPrime). The resulting |remainder| is in
  // [0, 2*kPrime), which |reduce_once| asserts and finishes off.
  uint64_t product = (uint64_t)x * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = x - quotient * kPrime;
  return reduce_once(remainder);
}
140
// scalar_zero sets every coefficient of |out| to zero.
static void scalar_zero(scalar *out) {
  OPENSSL_memset(out, 0, sizeof(scalar));
}
142
// vector_zero sets every coefficient of every entry of |out| to zero.
static void vector_zero(vector *out) {
  OPENSSL_memset(out, 0, sizeof(vector));
}
144
// In place number theoretic transform of a given scalar.
// Note that Kyber's kPrime 3329 does not have a 512th root of unity, so this
// transform leaves off the last iteration of the usual FFT code, with the 128
// relevant roots of unity being stored in |kNTTRoots|. This means the output
// should be seen as 128 elements in GF(3329^2), with the coefficients of the
// elements being consecutive entries in |s->c|.
static void scalar_ntt(scalar *s) {
  int offset = DEGREE;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int step = 1; step < DEGREE / 2; step <<= 1) {
    // |offset| is the distance between the two coefficients combined by each
    // butterfly; it halves every pass.
    offset >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      const uint32_t step_root = kNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = reduce(step_root * s->c[j + offset]);
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        // Adding kPrime keeps the argument of |reduce_once| non-negative.
        s->c[j + offset] = reduce_once(even - odd + kPrime);
      }
      k += 2 * offset;
    }
  }
}
170
// vector_ntt applies |scalar_ntt| to each of the RANK entries of |a| in
// place.
static void vector_ntt(vector *a) {
  for (int entry = 0; entry < RANK; entry++) {
    scalar_ntt(&a->v[entry]);
  }
}
176
// In place inverse number theoretic transform of a given scalar, with pairs of
// entries of s->v being interpreted as elements of GF(3329^2). Just as with the
// number theoretic transform, this leaves off the first step of the normal iFFT
// to account for the fact that 3329 does not have a 512th root of unity, using
// the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
static void scalar_inverse_ntt(scalar *s) {
  int step = DEGREE / 2;
  // `int` is used here because using `size_t` throughout caused a ~5% slowdown
  // with Clang 14 on Aarch64.
  for (int offset = 2; offset < DEGREE; offset <<= 1) {
    // The butterfly distance |offset| doubles each pass while the number of
    // root groups |step| halves, mirroring |scalar_ntt| in reverse.
    step >>= 1;
    int k = 0;
    for (int i = 0; i < step; i++) {
      uint32_t step_root = kInverseNTTRoots[i + step];
      for (int j = k; j < k + offset; j++) {
        uint16_t odd = s->c[j + offset];
        uint16_t even = s->c[j];
        s->c[j] = reduce_once(odd + even);
        // Adding kPrime keeps the difference non-negative before the
        // multiplication by the inverse root.
        s->c[j + offset] = reduce(step_root * (even - odd + kPrime));
      }
      k += 2 * offset;
    }
  }
  // Scale by 128^-1 mod kPrime to undo the doubling accumulated by the
  // butterfly passes.
  for (int i = 0; i < DEGREE; i++) {
    s->c[i] = reduce(s->c[i] * kInverseDegree);
  }
}
204
// vector_inverse_ntt applies |scalar_inverse_ntt| to each of the RANK entries
// of |a| in place.
static void vector_inverse_ntt(vector *a) {
  for (int entry = 0; entry < RANK; entry++) {
    scalar_inverse_ntt(&a->v[entry]);
  }
}
210
// scalar_add sets |lhs| to |lhs| + |rhs|, coefficient-wise mod kPrime.
static void scalar_add(scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint16_t sum = lhs->c[idx] + rhs->c[idx];
    lhs->c[idx] = reduce_once(sum);
  }
}
216
// scalar_sub sets |lhs| to |lhs| - |rhs|, coefficient-wise mod kPrime. Adding
// kPrime first keeps the intermediate value non-negative.
static void scalar_sub(scalar *lhs, const scalar *rhs) {
  for (int idx = 0; idx < DEGREE; idx++) {
    const uint16_t diff = lhs->c[idx] - rhs->c[idx] + kPrime;
    lhs->c[idx] = reduce_once(diff);
  }
}
222
// Multiplying two scalars in the number theoretically transformed state. Since
// 3329 does not have a 512th root of unity, this means we have to interpret
// the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
// - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
// stored in the precomputed |kModRoots| table. Note that our Barrett transform
// only allows us to multiply two reduced numbers together, so we need some
// intermediate reduction steps, even if an uint64_t could hold 3 multiplied
// numbers.
static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
  for (int i = 0; i < DEGREE / 2; i++) {
    // Schoolbook multiplication of (real + img·X) pairs in GF(3329^2), where
    // X^2 equals kModRoots[i].
    uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
    uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i + 1];
    uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
    uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
    // |img_img| is reduced before the second multiplication so the argument
    // of the outer |reduce| stays within its input bound.
    out->c[2 * i] =
        reduce(real_real + (uint32_t)reduce(img_img) * kModRoots[i]);
    out->c[2 * i + 1] = reduce(img_real + real_img);
  }
}
242
// vector_add sets |lhs| to |lhs| + |rhs|, entry-wise.
static void vector_add(vector *lhs, const vector *rhs) {
  for (int entry = 0; entry < RANK; entry++) {
    scalar_add(&lhs->v[entry], &rhs->v[entry]);
  }
}
248
// matrix_mult sets |out| = |m| × |a|, with all values in the NTT domain.
static void matrix_mult(vector *out, const matrix *m, const vector *a) {
  vector_zero(out);
  for (int row = 0; row < RANK; row++) {
    scalar *acc = &out->v[row];
    for (int col = 0; col < RANK; col++) {
      scalar term;
      scalar_mult(&term, &m->v[row][col], &a->v[col]);
      scalar_add(acc, &term);
    }
  }
}
259
// matrix_mult_transpose sets |out| = |m|^T × |a|, with all values in the NTT
// domain.
static void matrix_mult_transpose(vector *out, const matrix *m,
                                  const vector *a) {
  vector_zero(out);
  for (int row = 0; row < RANK; row++) {
    scalar *acc = &out->v[row];
    for (int col = 0; col < RANK; col++) {
      scalar term;
      scalar_mult(&term, &m->v[col][row], &a->v[col]);
      scalar_add(acc, &term);
    }
  }
}
271
// scalar_inner_product sets |out| to the inner product of |lhs| and |rhs|,
// with all values in the NTT domain.
static void scalar_inner_product(scalar *out, const vector *lhs,
                                 const vector *rhs) {
  scalar_zero(out);
  for (int entry = 0; entry < RANK; entry++) {
    scalar term;
    scalar_mult(&term, &lhs->v[entry], &rhs->v[entry]);
    scalar_add(out, &term);
  }
}
281
// Algorithm 1 of the Kyber spec. Rejection samples a Keccak stream to get
// uniformly distributed elements. This is used for matrix expansion and only
// operates on public inputs.
static void scalar_from_keccak_vartime(scalar *out,
                                       struct BORINGSSL_keccak_st *keccak_ctx) {
  assert(keccak_ctx->squeeze_offset == 0);
  assert(keccak_ctx->rate_bytes == 168);
  static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");

  int done = 0;
  while (done < DEGREE) {
    uint8_t block[168];
    BORINGSSL_keccak_squeeze(keccak_ctx, block, sizeof(block));
    // Every three bytes of the stream yield two 12-bit candidates, each kept
    // only if it is less than kPrime (hence "vartime": the number of squeezed
    // blocks depends on the public stream contents).
    for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
      uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
      uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
      if (d1 < kPrime) {
        out->c[done++] = d1;
      }
      if (d2 < kPrime && done < DEGREE) {
        out->c[done++] = d2;
      }
    }
  }
}
307
// Algorithm 2 of the Kyber spec, with eta fixed to two and the PRF call
// included. Creates binominally distributed elements by sampling 2*|eta| bits,
// and setting the coefficient to the count of the first bits minus the count of
// the second bits, resulting in a centered binomial distribution. Since eta is
// two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
// and 0 with probability 3/8.
static void scalar_centered_binomial_distribution_eta_2_with_prf(
    scalar *out, const uint8_t input[33]) {
  uint8_t entropy[128];
  static_assert(sizeof(entropy) == 2 * /*kEta=*/2 * DEGREE / 8, "");
  BORINGSSL_keccak(entropy, sizeof(entropy), input, 33, boringssl_shake256);

  // Each byte of entropy provides 2*eta = 4 bits for each of two coefficients.
  for (int i = 0; i < DEGREE; i += 2) {
    uint8_t byte = entropy[i / 2];

    // Start from kPrime so the subtraction below cannot go negative; the
    // result is in [kPrime - 2, kPrime + 2], which |reduce_once| handles.
    uint16_t value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i] = reduce_once(value);

    // The high nibble yields the next coefficient in the same way.
    byte >>= 4;
    value = kPrime;
    value += (byte & 1) + ((byte >> 1) & 1);
    value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
    out->c[i + 1] = reduce_once(value);
  }
}
335
// Generates a secret vector by using
// |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed,
// appending and incrementing |counter| for each entry of the vector.
static void vector_generate_secret_eta_2(vector *out, uint8_t *counter,
                                         const uint8_t seed[32]) {
  // The PRF input is seed || counter, so each entry gets an independent
  // stream. |counter| is updated in the caller's scope so later calls with
  // the same seed continue the domain separation.
  uint8_t input[33];
  OPENSSL_memcpy(input, seed, 32);
  for (int i = 0; i < RANK; i++) {
    input[32] = (*counter)++;
    scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], input);
  }
}
348
349 // Expands the matrix of a seed for key generation and for encaps-CPA.
matrix_expand(matrix * out,const uint8_t rho[32])350 static void matrix_expand(matrix *out, const uint8_t rho[32]) {
351 uint8_t input[34];
352 OPENSSL_memcpy(input, rho, 32);
353 for (int i = 0; i < RANK; i++) {
354 for (int j = 0; j < RANK; j++) {
355 input[32] = i;
356 input[33] = j;
357 struct BORINGSSL_keccak_st keccak_ctx;
358 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
359 BORINGSSL_keccak_absorb(&keccak_ctx, input, sizeof(input));
360 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
361 }
362 }
363 }
364
// kMasks[i] has the low i+1 bits set, for extracting chunks of bits during
// encoding and decoding.
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
                                  0x1f, 0x3f, 0x7f, 0xff};
367
// scalar_encode packs the DEGREE coefficients of |s| into |out|, |bits| bits
// per coefficient, little-endian within each output byte. |bits| == 1 is
// handled by |scalar_encode_1| instead.
static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
  assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);

  // |out_byte| accumulates bits until a full byte is ready to be written.
  uint8_t out_byte = 0;
  int out_byte_bits = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = s->c[i];
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      // Take as many of the element's remaining bits as fit in the current
      // output byte.
      int chunk_bits = bits - element_bits_done;
      int out_bits_remaining = 8 - out_byte_bits;
      if (chunk_bits >= out_bits_remaining) {
        chunk_bits = out_bits_remaining;
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        *out = out_byte;
        out++;
        out_byte_bits = 0;
        out_byte = 0;
      } else {
        out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
        out_byte_bits += chunk_bits;
      }

      element_bits_done += chunk_bits;
      element >>= chunk_bits;
    }
  }

  // Flush a trailing partial byte, if any.
  if (out_byte_bits > 0) {
    *out = out_byte;
  }
}
402
403 // scalar_encode_1 is |scalar_encode| specialised for |bits| == 1.
scalar_encode_1(uint8_t out[32],const scalar * s)404 static void scalar_encode_1(uint8_t out[32], const scalar *s) {
405 for (int i = 0; i < DEGREE; i += 8) {
406 uint8_t out_byte = 0;
407 for (int j = 0; j < 8; j++) {
408 out_byte |= (s->c[i + j] & 1) << j;
409 }
410 *out = out_byte;
411 out++;
412 }
413 }
414
415 // Encodes an entire vector into 32*|RANK|*|bits| bytes. Note that since 256
416 // (DEGREE) is divisible by 8, the individual vector entries will always fill a
417 // whole number of bytes, so we do not need to worry about bit packing here.
vector_encode(uint8_t * out,const vector * a,int bits)418 static void vector_encode(uint8_t *out, const vector *a, int bits) {
419 for (int i = 0; i < RANK; i++) {
420 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
421 }
422 }
423
// scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
// |out|. It returns one on success and zero if any parsed value is >=
// |kPrime|.
static int scalar_decode(scalar *out, const uint8_t *in, int bits) {
  assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

  // |in_byte| holds the not-yet-consumed bits of the current input byte.
  uint8_t in_byte = 0;
  int in_byte_bits_left = 0;

  for (int i = 0; i < DEGREE; i++) {
    uint16_t element = 0;
    int element_bits_done = 0;

    while (element_bits_done < bits) {
      if (in_byte_bits_left == 0) {
        in_byte = *in;
        in++;
        in_byte_bits_left = 8;
      }

      // Take as many bits as the current input byte still offers, up to the
      // number the element still needs.
      int chunk_bits = bits - element_bits_done;
      if (chunk_bits > in_byte_bits_left) {
        chunk_bits = in_byte_bits_left;
      }

      element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
      in_byte_bits_left -= chunk_bits;
      in_byte >>= chunk_bits;

      element_bits_done += chunk_bits;
    }

    // Reject non-canonical encodings where a coefficient is out of range.
    if (element >= kPrime) {
      return 0;
    }
    out->c[i] = element;
  }

  return 1;
}
464
465 // scalar_decode_1 is |scalar_decode| specialised for |bits| == 1.
scalar_decode_1(scalar * out,const uint8_t in[32])466 static void scalar_decode_1(scalar *out, const uint8_t in[32]) {
467 for (int i = 0; i < DEGREE; i += 8) {
468 uint8_t in_byte = *in;
469 in++;
470 for (int j = 0; j < 8; j++) {
471 out->c[i + j] = in_byte & 1;
472 in_byte >>= 1;
473 }
474 }
475 }
476
477 // Decodes 32*|RANK|*|bits| bytes from |in| into |out|. It returns one on
478 // success or zero if any parsed value is >= |kPrime|.
vector_decode(vector * out,const uint8_t * in,int bits)479 static int vector_decode(vector *out, const uint8_t *in, int bits) {
480 for (int i = 0; i < RANK; i++) {
481 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits)) {
482 return 0;
483 }
484 }
485 return 1;
486 }
487
// Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
// numbers close to each other together. The formula used is
// round(2^|bits|/kPrime*x) mod 2^|bits|.
// Uses Barrett reduction to achieve constant time. Since we need both the
// remainder (for rounding) and the quotient (as the result), we cannot use
// |reduce| here, but need to do the Barrett reduction directly.
static uint16_t compress(uint16_t x, int bits) {
  uint32_t shifted = (uint32_t)x << bits;
  uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
  uint32_t quotient = (uint32_t)(product >> kBarrettShift);
  uint32_t remainder = shifted - quotient * kPrime;

  // Adjust the quotient to round correctly:
  //   0 <= remainder <= kHalfPrime round to 0
  //   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
  //   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
  assert(remainder < 2u * kPrime);
  // |constant_time_lt_w| keeps the rounding adjustments branch-free.
  quotient += 1 & constant_time_lt_w(kHalfPrime, remainder);
  quotient += 1 & constant_time_lt_w(kPrime + kHalfPrime, remainder);
  return quotient & ((1 << bits) - 1);
}
509
// Decompresses |x| by using an equi-distant representative. The formula is
// round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
// implement this logic using only bit operations.
static uint16_t decompress(uint16_t x, int bits) {
  uint32_t product = (uint32_t)x * kPrime;
  uint32_t power = 1 << bits;
  // This is |product| % power, since |power| is a power of 2.
  uint32_t remainder = product & (power - 1);
  // This is |product| / power, since |power| is a power of 2.
  uint32_t lower = product >> bits;
  // The rounding logic works since the first half of numbers mod |power| have a
  // 0 as first bit, and the second half has a 1 as first bit, since |power| is
  // a power of 2. As a 12 bit number, |remainder| is always positive, so we
  // will shift in 0s for a right shift.
  return lower + (remainder >> (bits - 1));
}
526
// scalar_compress applies |compress| to every coefficient of |s| in place.
static void scalar_compress(scalar *s, int bits) {
  for (int idx = 0; idx < DEGREE; idx++) {
    s->c[idx] = compress(s->c[idx], bits);
  }
}
532
// scalar_decompress applies |decompress| to every coefficient of |s| in
// place.
static void scalar_decompress(scalar *s, int bits) {
  for (int idx = 0; idx < DEGREE; idx++) {
    s->c[idx] = decompress(s->c[idx], bits);
  }
}
538
// vector_compress applies |scalar_compress| to every entry of |a| in place.
static void vector_compress(vector *a, int bits) {
  for (int entry = 0; entry < RANK; entry++) {
    scalar_compress(&a->v[entry], bits);
  }
}
544
// vector_decompress applies |scalar_decompress| to every entry of |a| in
// place.
static void vector_decompress(vector *a, int bits) {
  for (int entry = 0; entry < RANK; entry++) {
    scalar_decompress(&a->v[entry], bits);
  }
}
550
// public_key is the internal representation behind the opaque
// |KYBER_public_key|.
struct public_key {
  // t = A^T×s + e, stored in the NTT domain (see key generation).
  vector t;
  // rho is the public seed from which the matrix |m| is expanded.
  uint8_t rho[32];
  // public_key_hash is the SHA3-256 hash of the serialized public key.
  uint8_t public_key_hash[32];
  // m is the matrix A, expanded from |rho| and cached here.
  matrix m;
};
557
// public_key_from_external reinterprets the opaque |KYBER_public_key| as the
// internal |public_key| structure. The static_asserts check that the opaque
// type is large enough and sufficiently aligned to hold it.
static struct public_key *public_key_from_external(
    const struct KYBER_public_key *external) {
  static_assert(sizeof(struct KYBER_public_key) >= sizeof(struct public_key),
                "Kyber public key is too small");
  static_assert(alignof(struct KYBER_public_key) >= alignof(struct public_key),
                "Kyber public key align incorrect");
  // NOTE(review): const is cast away here; callers holding only a const
  // external key must not mutate the result.
  return (struct public_key *)external;
}
566
// private_key is the internal representation behind the opaque
// |KYBER_private_key|. It embeds the public key so decapsulation can
// re-encrypt without a separate lookup.
struct private_key {
  struct public_key pub;
  // s is the secret vector, stored in the NTT domain.
  vector s;
  // fo_failure_secret is the implicit-rejection secret: |KYBER_decap| hashes
  // it instead of the real prekey when re-encryption does not match the
  // ciphertext.
  uint8_t fo_failure_secret[32];
};
572
// private_key_from_external reinterprets the opaque |KYBER_private_key| as
// the internal |private_key| structure, with the same size/alignment checks
// as |public_key_from_external|.
static struct private_key *private_key_from_external(
    const struct KYBER_private_key *external) {
  static_assert(sizeof(struct KYBER_private_key) >= sizeof(struct private_key),
                "Kyber private key too small");
  static_assert(
      alignof(struct KYBER_private_key) >= alignof(struct private_key),
      "Kyber private key align incorrect");
  return (struct private_key *)external;
}
582
583 // Calls |KYBER_generate_key_external_entropy| with random bytes from
584 // |RAND_bytes|.
KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],struct KYBER_private_key * out_private_key)585 void KYBER_generate_key(uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
586 struct KYBER_private_key *out_private_key) {
587 uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY];
588 RAND_bytes(entropy, sizeof(entropy));
589 KYBER_generate_key_external_entropy(out_encoded_public_key, out_private_key,
590 entropy);
591 }
592
// kyber_marshal_public_key writes the standard serialization of |pub|
// (encoded t followed by rho) to |out|. Returns one on success, zero on
// allocation failure.
static int kyber_marshal_public_key(CBB *out, const struct public_key *pub) {
  uint8_t *t_out;
  if (!CBB_add_space(out, &t_out, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(t_out, &pub->t, kLog2Prime);
  return CBB_add_bytes(out, pub->rho, sizeof(pub->rho));
}
604
// Algorithms 4 and 7 of the Kyber spec. Algorithms are combined since key
// generation is not part of the FO transform, and the spec uses Algorithm 7 to
// specify the actual key format.
void KYBER_generate_key_external_entropy(
    uint8_t out_encoded_public_key[KYBER_PUBLIC_KEY_BYTES],
    struct KYBER_private_key *out_private_key,
    const uint8_t entropy[KYBER_GENERATE_KEY_ENTROPY]) {
  struct private_key *priv = private_key_from_external(out_private_key);
  // Hash the first 32 bytes of entropy into rho (matrix seed) and sigma
  // (noise seed).
  uint8_t hashed[64];
  BORINGSSL_keccak(hashed, sizeof(hashed), entropy, 32, boringssl_sha3_512);
  const uint8_t *const rho = hashed;
  const uint8_t *const sigma = hashed + 32;
  OPENSSL_memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
  matrix_expand(&priv->pub.m, rho);
  // |counter| domain-separates the PRF streams across both secret vectors.
  uint8_t counter = 0;
  vector_generate_secret_eta_2(&priv->s, &counter, sigma);
  vector_ntt(&priv->s);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, sigma);
  vector_ntt(&error);
  // t = A^T×s + e, computed entirely in the NTT domain.
  matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
  vector_add(&priv->pub.t, &error);

  CBB cbb;
  CBB_init_fixed(&cbb, out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES);
  if (!kyber_marshal_public_key(&cbb, &priv->pub)) {
    // Cannot happen with a correctly sized fixed buffer.
    abort();
  }

  BORINGSSL_keccak(priv->pub.public_key_hash, sizeof(priv->pub.public_key_hash),
                   out_encoded_public_key, KYBER_PUBLIC_KEY_BYTES,
                   boringssl_sha3_256);
  // The remaining 32 bytes of entropy become the implicit-rejection secret.
  OPENSSL_memcpy(priv->fo_failure_secret, entropy + 32, 32);
}
639
KYBER_public_from_private(struct KYBER_public_key * out_public_key,const struct KYBER_private_key * private_key)640 void KYBER_public_from_private(struct KYBER_public_key *out_public_key,
641 const struct KYBER_private_key *private_key) {
642 struct public_key *const pub = public_key_from_external(out_public_key);
643 const struct private_key *const priv = private_key_from_external(private_key);
644 *pub = priv->pub;
645 }
646
// Algorithm 5 of the Kyber spec. Encrypts a message with given randomness to
// the ciphertext in |out|. Without applying the Fujisaki-Okamoto transform this
// would not result in a CCA secure scheme, since lattice schemes are vulnerable
// to decryption failure oracles.
static void encrypt_cpa(uint8_t out[KYBER_CIPHERTEXT_BYTES],
                        const struct public_key *pub, const uint8_t message[32],
                        const uint8_t randomness[32]) {
  // Sample the ephemeral secret r and the error vector e1 from the PRF,
  // continuing the counter between the two for domain separation.
  uint8_t counter = 0;
  vector secret;
  vector_generate_secret_eta_2(&secret, &counter, randomness);
  vector_ntt(&secret);
  vector error;
  vector_generate_secret_eta_2(&error, &counter, randomness);
  // The scalar error e2 uses the next counter value of the same PRF stream.
  uint8_t input[33];
  OPENSSL_memcpy(input, randomness, 32);
  input[32] = counter;
  scalar scalar_error;
  scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, input);
  // u = A×r + e1, moved out of the NTT domain before adding the error.
  vector u;
  matrix_mult(&u, &pub->m, &secret);
  vector_inverse_ntt(&u);
  vector_add(&u, &error);
  // v = t·r + e2 + Decompress(message, 1).
  scalar v;
  scalar_inner_product(&v, &pub->t, &secret);
  scalar_inverse_ntt(&v);
  scalar_add(&v, &scalar_error);
  scalar expanded_message;
  scalar_decode_1(&expanded_message, message);
  scalar_decompress(&expanded_message, 1);
  scalar_add(&v, &expanded_message);
  // Ciphertext is Compress(u, kDU) || Compress(v, kDV).
  vector_compress(&u, kDU);
  vector_encode(out, &u, kDU);
  scalar_compress(&v, kDV);
  scalar_encode(out + kCompressedVectorSize, &v, kDV);
}
682
683 // Calls KYBER_encap_external_entropy| with random bytes from |RAND_bytes|
KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],uint8_t * out_shared_secret,size_t out_shared_secret_len,const struct KYBER_public_key * public_key)684 void KYBER_encap(uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES],
685 uint8_t *out_shared_secret, size_t out_shared_secret_len,
686 const struct KYBER_public_key *public_key) {
687 uint8_t entropy[KYBER_ENCAP_ENTROPY];
688 RAND_bytes(entropy, KYBER_ENCAP_ENTROPY);
689 KYBER_encap_external_entropy(out_ciphertext, out_shared_secret,
690 out_shared_secret_len, public_key, entropy);
691 }
692
// Algorithm 8 of the Kyber spec, save for line 2 of the spec. The spec there
// hashes the output of the system's random number generator, since the FO
// transform will reveal it to the decrypting party. There is no reason to do
// this when a secure random number generator is used. When an insecure random
// number generator is used, the caller should switch to a secure one before
// calling this method.
void KYBER_encap_external_entropy(
    uint8_t out_ciphertext[KYBER_CIPHERTEXT_BYTES], uint8_t *out_shared_secret,
    size_t out_shared_secret_len, const struct KYBER_public_key *public_key,
    const uint8_t entropy[KYBER_ENCAP_ENTROPY]) {
  const struct public_key *pub = public_key_from_external(public_key);
  // Derive (prekey, randomness) = SHA3-512(entropy || H(pk)).
  uint8_t input[64];
  OPENSSL_memcpy(input, entropy, KYBER_ENCAP_ENTROPY);
  OPENSSL_memcpy(input + KYBER_ENCAP_ENTROPY, pub->public_key_hash,
                 sizeof(input) - KYBER_ENCAP_ENTROPY);
  uint8_t prekey_and_randomness[64];
  BORINGSSL_keccak(prekey_and_randomness, sizeof(prekey_and_randomness), input,
                   sizeof(input), boringssl_sha3_512);
  encrypt_cpa(out_ciphertext, pub, entropy, prekey_and_randomness + 32);
  // Replace the randomness half with H(ciphertext), then derive the shared
  // secret as SHAKE-256(prekey || H(ciphertext)).
  BORINGSSL_keccak(prekey_and_randomness + 32, 32, out_ciphertext,
                   KYBER_CIPHERTEXT_BYTES, boringssl_sha3_256);
  BORINGSSL_keccak(out_shared_secret, out_shared_secret_len,
                   prekey_and_randomness, sizeof(prekey_and_randomness),
                   boringssl_shake256);
}
718
// Algorithm 6 of the Kyber spec.
static void decrypt_cpa(uint8_t out[32], const struct private_key *priv,
                        const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES]) {
  vector u;
  // The return value can be ignored: kDU (10) bit values are at most 1023 and
  // thus always less than kPrime, so decoding cannot fail here.
  vector_decode(&u, ciphertext, kDU);
  vector_decompress(&u, kDU);
  vector_ntt(&u);
  scalar v;
  // Likewise, kDV (4) bit values are always less than kPrime.
  scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV);
  scalar_decompress(&v, kDV);
  // v - s·u leaves the message term plus noise; compressing to one bit rounds
  // the noise away.
  scalar mask;
  scalar_inner_product(&mask, &priv->s, &u);
  scalar_inverse_ntt(&mask);
  scalar_sub(&v, &mask);
  scalar_compress(&v, 1);
  scalar_encode_1(out, &v);
}
736
// Algorithm 9 of the Kyber spec, performing the FO transform by running
// encrypt_cpa on the decrypted message. The spec does not allow the decryption
// failure to be passed on to the caller, and instead returns a result that is
// deterministic but unpredictable to anyone without knowledge of the private
// key.
void KYBER_decap(uint8_t *out_shared_secret, size_t out_shared_secret_len,
                 const uint8_t ciphertext[KYBER_CIPHERTEXT_BYTES],
                 const struct KYBER_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  uint8_t decrypted[64];
  decrypt_cpa(decrypted, priv, ciphertext);
  OPENSSL_memcpy(decrypted + 32, priv->pub.public_key_hash,
                 sizeof(decrypted) - 32);
  uint8_t prekey_and_randomness[64];
  BORINGSSL_keccak(prekey_and_randomness, sizeof(prekey_and_randomness),
                   decrypted, sizeof(decrypted), boringssl_sha3_512);
  // Re-encrypt the decrypted message with the derived randomness; any
  // mismatch with the received ciphertext marks it as invalid.
  uint8_t expected_ciphertext[KYBER_CIPHERTEXT_BYTES];
  encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
              prekey_and_randomness + 32);
  uint8_t mask =
      constant_time_eq_int_8(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                           sizeof(expected_ciphertext)),
                             0);
  uint8_t input[64];
  // Constant-time select between the real prekey and |fo_failure_secret| so
  // decryption failures are indistinguishable from successes to an observer.
  for (int i = 0; i < 32; i++) {
    input[i] = constant_time_select_8(mask, prekey_and_randomness[i],
                                      priv->fo_failure_secret[i]);
  }
  BORINGSSL_keccak(input + 32, 32, ciphertext, KYBER_CIPHERTEXT_BYTES,
                   boringssl_sha3_256);
  BORINGSSL_keccak(out_shared_secret, out_shared_secret_len, input,
                   sizeof(input), boringssl_shake256);
}
770
// KYBER_marshal_public_key serializes |public_key| to |out| in the standard
// format. It returns one on success or zero on allocation error.
int KYBER_marshal_public_key(CBB *out,
                             const struct KYBER_public_key *public_key) {
  return kyber_marshal_public_key(out, public_key_from_external(public_key));
}
775
// kyber_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
// the value of |pub->public_key_hash|. Returns one on success, zero if the
// input is truncated or contains an out-of-range coefficient.
static int kyber_parse_public_key_no_hash(struct public_key *pub, CBS *in) {
  CBS t_bytes;
  if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
      !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime) ||
      !CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
    return 0;
  }
  // Recompute the matrix A from the parsed seed so the cached copy is valid.
  matrix_expand(&pub->m, pub->rho);
  return 1;
}
788
// KYBER_parse_public_key parses a serialized public key from |in| into
// |public_key|, requiring that |in| is consumed exactly, and caches the
// SHA3-256 hash of the serialization. Returns one on success, zero on parse
// failure.
int KYBER_parse_public_key(struct KYBER_public_key *public_key, CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);
  // Keep the original bytes so the hash covers the full serialization.
  CBS orig_in = *in;
  if (!kyber_parse_public_key_no_hash(pub, in) ||  //
      CBS_len(in) != 0) {
    return 0;
  }
  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
                   CBS_data(&orig_in), CBS_len(&orig_in), boringssl_sha3_256);
  return 1;
}
800
// KYBER_marshal_private_key serializes |private_key| to |out| as: encoded s,
// the serialized public key, the public key hash, and the implicit-rejection
// secret. Returns one on success or zero on allocation error.
int KYBER_marshal_private_key(CBB *out,
                              const struct KYBER_private_key *private_key) {
  const struct private_key *const priv = private_key_from_external(private_key);
  uint8_t *s_output;
  if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
    return 0;
  }
  vector_encode(s_output, &priv->s, kLog2Prime);
  if (!kyber_marshal_public_key(out, &priv->pub) ||
      !CBB_add_bytes(out, priv->pub.public_key_hash,
                     sizeof(priv->pub.public_key_hash)) ||
      !CBB_add_bytes(out, priv->fo_failure_secret,
                     sizeof(priv->fo_failure_secret))) {
    return 0;
  }
  return 1;
}
818
// KYBER_parse_private_key parses a serialized private key (the format written
// by |KYBER_marshal_private_key|) from |in|, requiring that |in| is consumed
// exactly. Returns one on success, zero on parse failure. Note that the
// stored public key hash is trusted as-is rather than recomputed.
int KYBER_parse_private_key(struct KYBER_private_key *out_private_key,
                            CBS *in) {
  struct private_key *const priv = private_key_from_external(out_private_key);

  CBS s_bytes;
  if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
      !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
      !kyber_parse_public_key_no_hash(&priv->pub, in) ||
      !CBS_copy_bytes(in, priv->pub.public_key_hash,
                      sizeof(priv->pub.public_key_hash)) ||
      !CBS_copy_bytes(in, priv->fo_failure_secret,
                      sizeof(priv->fo_failure_secret)) ||
      CBS_len(in) != 0) {
    return 0;
  }
  return 1;
}
836