1 /* Copyright (c) 2023, Google LLC
2 *
3 * Permission to use, copy, modify, and/or distribute this software for any
4 * purpose with or without fee is hereby granted, provided that the above
5 * copyright notice and this permission notice appear in all copies.
6 *
7 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
14
15 #define OPENSSL_UNSTABLE_EXPERIMENTAL_DILITHIUM
16 #include <openssl/experimental/dilithium.h>
17
18 #include <assert.h>
19 #include <stdlib.h>
20
21 #include <openssl/bytestring.h>
22 #include <openssl/rand.h>
23
24 #include "../internal.h"
25 #include "../keccak/internal.h"
26 #include "./internal.h"
27
// Number of coefficients in one polynomial (`scalar`).
#define DEGREE 256
// Dimensions of the public matrix: K rows (length of a `vectork`) by L columns
// (length of a `vectorl`).
#define K 6
#define L 5
// Bound on secret coefficients; |scalar_uniform_eta_4| is specialized for it.
#define ETA 4
// Number of non-zero coefficients placed by |scalar_sample_in_ball_vartime|.
#define TAU 49
// TAU * ETA == 196.
#define BETA (TAU * ETA)
// Number of index bytes in a packed hint (see |hint_bit_pack|).
#define OMEGA 55

// Byte lengths of the seeds and hashes used by the scheme.
#define RHO_BYTES 32
#define SIGMA_BYTES 64
#define K_BYTES 32
#define TR_BYTES 64
#define MU_BYTES 64
#define RHO_PRIME_BYTES 64
#define LAMBDA_BITS 192
#define LAMBDA_BYTES (LAMBDA_BITS / 8)
44
// The field modulus q = 2^23 - 2^13 + 1 = 8380417.
static const uint32_t kPrime = (1u << 23) - (1u << 13) + 1u;
// (-q)^-1 modulo 2^32, i.e. kPrime * kPrimeNegInverse == -1 (mod 2^32).
static const uint32_t kPrimeNegInverse = 4236238847u;
// Number of low bits split off by |power2_round|.
static const int kDroppedBits = 13;
// (q - 1) / 2.
static const uint32_t kHalfPrime = (8380417u - 1u) / 2u;
// 2^19.
static const uint32_t kGamma1 = 1u << 19;
// (q - 1) / 32.
static const uint32_t kGamma2 = (8380417u - 1u) / 32u;
// 256^-1 mod kPrime, in Montgomery form.
static const uint32_t kInverseDegreeMontgomery = 41978u;
55
56 typedef struct scalar {
57 uint32_t c[DEGREE];
58 } scalar;
59
60 typedef struct vectork {
61 scalar v[K];
62 } vectork;
63
64 typedef struct vectorl {
65 scalar v[L];
66 } vectorl;
67
68 typedef struct matrix {
69 scalar v[K][L];
70 } matrix;
71
72 /* Arithmetic */
73
74 // This bit of Python will be referenced in some of the following comments:
75 //
76 // q = 8380417
77 // # Inverse of -q modulo 2^32
78 // q_neg_inverse = 4236238847
79 // # 2^64 modulo q
80 // montgomery_square = 2365951
81 //
82 // def bitreverse(i):
83 // ret = 0
84 // for n in range(8):
85 // bit = i & 1
86 // ret <<= 1
87 // ret |= bit
88 // i >>= 1
89 // return ret
90 //
91 // def montgomery_reduce(x):
92 // a = (x * q_neg_inverse) % 2**32
93 // b = x + a * q
94 // assert b & 0xFFFF_FFFF == 0
95 // c = b >> 32
96 // assert c < q
97 // return c
98 //
99 // def montgomery_transform(x):
100 // return montgomery_reduce(x * montgomery_square)
101
// Powers of the root of unity 1753, in bit-reversed order and in Montgomery
// form. Using the Python helpers above, the table is:
//
// kNTTRootsMontgomery = [
//   montgomery_transform(pow(1753, bitreverse(i), q)) for i in range(256)
// ]
static const uint32_t kNTTRootsMontgomery[256] = {
    4193792, 25847,   5771523, 7861508, 237124,  7602457, 7504169, 466468,
    1826347, 2353451, 8021166, 6288512, 3119733, 5495562, 3111497, 2680103,
    2725464, 1024112, 7300517, 3585928, 7830929, 7260833, 2619752, 6271868,
    6262231, 4520680, 6980856, 5102745, 1757237, 8360995, 4010497, 280005,
    2706023, 95776,   3077325, 3530437, 6718724, 4788269, 5842901, 3915439,
    4519302, 5336701, 3574422, 5512770, 3539968, 8079950, 2348700, 7841118,
    6681150, 6736599, 3505694, 4558682, 3507263, 6239768, 6779997, 3699596,
    811944,  531354,  954230,  3881043, 3900724, 5823537, 2071892, 5582638,
    4450022, 6851714, 4702672, 5339162, 6927966, 3475950, 2176455, 6795196,
    7122806, 1939314, 4296819, 7380215, 5190273, 5223087, 4747489, 126922,
    3412210, 7396998, 2147896, 2715295, 5412772, 4686924, 7969390, 5903370,
    7709315, 7151892, 8357436, 7072248, 7998430, 1349076, 1852771, 6949987,
    5037034, 264944,  508951,  3097992, 44288,   7280319, 904516,  3958618,
    4656075, 8371839, 1653064, 5130689, 2389356, 8169440, 759969,  7063561,
    189548,  4827145, 3159746, 6529015, 5971092, 8202977, 1315589, 1341330,
    1285669, 6795489, 7567685, 6940675, 5361315, 4499357, 4751448, 3839961,
    2091667, 3407706, 2316500, 3817976, 5037939, 2244091, 5933984, 4817955,
    266997,  2434439, 7144689, 3513181, 4860065, 4621053, 7183191, 5187039,
    900702,  1859098, 909542,  819034,  495491,  6767243, 8337157, 7857917,
    7725090, 5257975, 2031748, 3207046, 4823422, 7855319, 7611795, 4784579,
    342297,  286988,  5942594, 4108315, 3437287, 5038140, 1735879, 203044,
    2842341, 2691481, 5790267, 1265009, 4055324, 1247620, 2486353, 1595974,
    4613401, 1250494, 2635921, 4832145, 5386378, 1869119, 1903435, 7329447,
    7047359, 1237275, 5062207, 6950192, 7929317, 1312455, 3306115, 6417775,
    7100756, 1917081, 5834105, 7005614, 1500165, 777191,  2235880, 3406031,
    7838005, 5548557, 6709241, 6533464, 5796124, 4656147, 594136,  4603424,
    6366809, 2432395, 2454455, 8215696, 1957272, 3369112, 185531,  7173032,
    5196991, 162844,  1616392, 3014001, 810149,  1652634, 4686184, 6581310,
    5341501, 3523897, 3866901, 269760,  2213111, 7404533, 1717735, 472078,
    7953734, 1723600, 6577327, 1910376, 6712985, 7276084, 8119771, 4546524,
    5441381, 6144432, 7959518, 6094090, 183443,  7403526, 1612842, 4834730,
    7826001, 3919660, 8332111, 7018208, 3937738, 1400424, 7534263, 1976782,
};
138
139 // Reduces x mod kPrime in constant time, where 0 <= x < 2*kPrime.
reduce_once(uint32_t x)140 static uint32_t reduce_once(uint32_t x) {
141 declassify_assert(x < 2 * kPrime);
142 // return x < kPrime ? x : x - kPrime;
143 return constant_time_select_int(constant_time_lt_w(x, kPrime), x, x - kPrime);
144 }
145
// Constant-time absolute value of |x|, where |x| is read as a 32-bit two's
// complement number: values below 0x80000000 are returned unchanged, all
// others are negated.
static uint32_t abs_signed(uint32_t x) {
  // MSVC warns (C4146) about unary minus on unsigned types, so the negation is
  // spelled as bitwise-not plus one.
  const uint32_t negated = ~x + 1;
  return constant_time_select_int(constant_time_lt_w(x, 0x80000000), x,
                                  negated);
}
154
155 // Returns the absolute value modulo kPrime.
abs_mod_prime(uint32_t x)156 static uint32_t abs_mod_prime(uint32_t x) {
157 declassify_assert(x < kPrime);
158 // return x > kHalfPrime ? kPrime - x : x;
159 return constant_time_select_int(constant_time_lt_w(kHalfPrime, x), kPrime - x,
160 x);
161 }
162
// Constant-time maximum of |x| and |y|.
static uint32_t maximum(uint32_t x, uint32_t y) {
  // Selects |y| when x < y, otherwise |x|.
  return constant_time_select_int(constant_time_lt_w(x, y), /*a=*/y, /*b=*/x);
}
168
// Sets |out| to the coefficient-wise sum |lhs| + |rhs| modulo kPrime. Each
// coefficient is read before the matching output is written, so |out| may be
// the same object as either input.
static void scalar_add(scalar *out, const scalar *lhs, const scalar *rhs) {
  int idx = 0;
  while (idx < DEGREE) {
    const uint32_t sum = lhs->c[idx] + rhs->c[idx];
    out->c[idx] = reduce_once(sum);
    idx++;
  }
}
174
// Sets |out| to the coefficient-wise difference |lhs| - |rhs| modulo kPrime.
// kPrime is added first so the intermediate value stays non-negative.
static void scalar_sub(scalar *out, const scalar *lhs, const scalar *rhs) {
  int idx = 0;
  while (idx < DEGREE) {
    const uint32_t diff = kPrime + lhs->c[idx] - rhs->c[idx];
    out->c[idx] = reduce_once(diff);
    idx++;
  }
}
180
reduce_montgomery(uint64_t x)181 static uint32_t reduce_montgomery(uint64_t x) {
182 uint64_t a = (uint32_t)x * kPrimeNegInverse;
183 uint64_t b = x + a * kPrime;
184 declassify_assert((b & 0xffffffff) == 0);
185 uint32_t c = b >> 32;
186 return reduce_once(c);
187 }
188
189 // Multiply two scalars in the number theoretically transformed state.
scalar_mult(scalar * out,const scalar * lhs,const scalar * rhs)190 static void scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) {
191 for (int i = 0; i < DEGREE; i++) {
192 out->c[i] = reduce_montgomery((uint64_t)lhs->c[i] * (uint64_t)rhs->c[i]);
193 }
194 }
195
196 // In place number theoretic transform of a given scalar.
197 //
198 // FIPS 204, Algorithm 35 (`NTT`).
scalar_ntt(scalar * s)199 static void scalar_ntt(scalar *s) {
200 // Step: 1, 2, 4, 8, ..., 128
201 // Offset: 128, 64, 32, 16, ..., 1
202 int offset = DEGREE;
203 for (int step = 1; step < DEGREE; step <<= 1) {
204 offset >>= 1;
205 int k = 0;
206 for (int i = 0; i < step; i++) {
207 assert(k == 2 * offset * i);
208 const uint32_t step_root = kNTTRootsMontgomery[step + i];
209 for (int j = k; j < k + offset; j++) {
210 uint32_t even = s->c[j];
211 uint32_t odd =
212 reduce_montgomery((uint64_t)step_root * (uint64_t)s->c[j + offset]);
213 s->c[j] = reduce_once(odd + even);
214 s->c[j + offset] = reduce_once(kPrime + even - odd);
215 }
216 k += 2 * offset;
217 }
218 }
219 }
220
221 // In place inverse number theoretic transform of a given scalar.
222 //
223 // FIPS 204, Algorithm 36 (`NTT^-1`).
scalar_inverse_ntt(scalar * s)224 static void scalar_inverse_ntt(scalar *s) {
225 // Step: 128, 64, 32, 16, ..., 1
226 // Offset: 1, 2, 4, 8, ..., 128
227 int step = DEGREE;
228 for (int offset = 1; offset < DEGREE; offset <<= 1) {
229 step >>= 1;
230 int k = 0;
231 for (int i = 0; i < step; i++) {
232 assert(k == 2 * offset * i);
233 const uint32_t step_root =
234 kPrime - kNTTRootsMontgomery[step + (step - 1 - i)];
235 for (int j = k; j < k + offset; j++) {
236 uint32_t even = s->c[j];
237 uint32_t odd = s->c[j + offset];
238 s->c[j] = reduce_once(odd + even);
239 s->c[j + offset] = reduce_montgomery((uint64_t)step_root *
240 (uint64_t)(kPrime + even - odd));
241 }
242 k += 2 * offset;
243 }
244 }
245 for (int i = 0; i < DEGREE; i++) {
246 s->c[i] = reduce_montgomery((uint64_t)s->c[i] *
247 (uint64_t)kInverseDegreeMontgomery);
248 }
249 }
250
// Sets every coefficient of every entry of |out| to zero.
static void vectork_zero(vectork *out) {
  OPENSSL_memset(out, 0, sizeof(vectork));
}
252
// Entry-wise |scalar_add| over the K entries of the vectors.
static void vectork_add(vectork *out, const vectork *lhs, const vectork *rhs) {
  for (int row = 0; row != K; ++row) {
    scalar_add(out->v + row, lhs->v + row, rhs->v + row);
  }
}
258
// Entry-wise |scalar_sub| over the K entries of the vectors.
static void vectork_sub(vectork *out, const vectork *lhs, const vectork *rhs) {
  for (int row = 0; row != K; ++row) {
    scalar_sub(out->v + row, lhs->v + row, rhs->v + row);
  }
}
264
// Multiplies each of the K entries of |lhs| by the single scalar |rhs| using
// |scalar_mult|.
static void vectork_mult_scalar(vectork *out, const vectork *lhs,
                                const scalar *rhs) {
  for (int row = 0; row != K; ++row) {
    scalar_mult(out->v + row, lhs->v + row, rhs);
  }
}
271
// Applies |scalar_ntt| in place to each of the K entries of |a|.
static void vectork_ntt(vectork *a) {
  for (scalar *entry = a->v; entry != a->v + K; entry++) {
    scalar_ntt(entry);
  }
}
277
// Applies |scalar_inverse_ntt| in place to each of the K entries of |a|.
static void vectork_inverse_ntt(vectork *a) {
  for (scalar *entry = a->v; entry != a->v + K; entry++) {
    scalar_inverse_ntt(entry);
  }
}
283
// Entry-wise |scalar_add| over the L entries of the vectors.
static void vectorl_add(vectorl *out, const vectorl *lhs, const vectorl *rhs) {
  for (int col = 0; col != L; ++col) {
    scalar_add(out->v + col, lhs->v + col, rhs->v + col);
  }
}
289
// Multiplies each of the L entries of |lhs| by the single scalar |rhs| using
// |scalar_mult|.
static void vectorl_mult_scalar(vectorl *out, const vectorl *lhs,
                                const scalar *rhs) {
  for (int col = 0; col != L; ++col) {
    scalar_mult(out->v + col, lhs->v + col, rhs);
  }
}
296
// Applies |scalar_ntt| in place to each of the L entries of |a|.
static void vectorl_ntt(vectorl *a) {
  for (scalar *entry = a->v; entry != a->v + L; entry++) {
    scalar_ntt(entry);
  }
}
302
// Applies |scalar_inverse_ntt| in place to each of the L entries of |a|.
static void vectorl_inverse_ntt(vectorl *a) {
  for (scalar *entry = a->v; entry != a->v + L; entry++) {
    scalar_inverse_ntt(entry);
  }
}
308
// Computes the matrix-vector product out = m * a. |out| is zeroed first and
// then, for every row, the L products m[row][col] * a[col] are accumulated
// with |scalar_mult| and |scalar_add|.
static void matrix_mult(vectork *out, const matrix *m, const vectorl *a) {
  vectork_zero(out);
  for (int row = 0; row < K; row++) {
    scalar *accumulator = &out->v[row];
    for (int col = 0; col < L; col++) {
      scalar term;
      scalar_mult(&term, &m->v[row][col], &a->v[col]);
      scalar_add(accumulator, accumulator, &term);
    }
  }
}
319
320 /* Rounding & hints */
321
322 // FIPS 204, Algorithm 29 (`Power2Round`).
power2_round(uint32_t * r1,uint32_t * r0,uint32_t r)323 static void power2_round(uint32_t *r1, uint32_t *r0, uint32_t r) {
324 *r1 = r >> kDroppedBits;
325 *r0 = r - (*r1 << kDroppedBits);
326
327 uint32_t r0_adjusted = reduce_once(kPrime + *r0 - (1 << kDroppedBits));
328 uint32_t r1_adjusted = *r1 + 1;
329
330 // Mask is set iff r0 > 2^(dropped_bits - 1).
331 crypto_word_t mask =
332 constant_time_lt_w((uint32_t)(1 << (kDroppedBits - 1)), *r0);
333 // r0 = mask ? r0_adjusted : r0
334 *r0 = constant_time_select_int(mask, r0_adjusted, *r0);
335 // r1 = mask ? r1_adjusted : r1
336 *r1 = constant_time_select_int(mask, r1_adjusted, *r1);
337 }
338
339 // Scale back previously rounded value.
scale_power2_round(uint32_t * out,uint32_t r1)340 static void scale_power2_round(uint32_t *out, uint32_t r1) {
341 // Pre-condition: 0 <= r1 <= 2^10 - 1
342 *out = r1 << kDroppedBits;
343 // Post-condition: 0 <= out <= 2^23 - 2^13 = kPrime - 1
344 assert(*out < kPrime);
345 }
346
// FIPS 204, Algorithm 31 (`HighBits`).
//
// Reference description (given 0 <= x < q):
//
// ```
// int32_t r0 = x mod+- (2 * kGamma2);
// if (x - r0 == q - 1) {
//   return 0;
// } else {
//   return (x - r0) / (2 * kGamma2);
// }
// ```
//
// The arithmetic below is the branch-free formula from the reference
// implementation. With kGamma2 == 2^18 - 2^8 it evaluates
// ((ceil(x / 2^7) * (2^10 + 1) + 2^21) / 2^22) mod 2^4.
static uint32_t high_bits(uint32_t x) {
  const uint32_t ceil_div_128 = (x + 127) >> 7;
  const uint32_t scaled = ceil_div_128 * 1025 + (1 << 21);
  return (scaled >> 22) & 15;
}
369
370 // FIPS 204, Algorithm 30 (`Decompose`).
decompose(uint32_t * r1,int32_t * r0,uint32_t r)371 static void decompose(uint32_t *r1, int32_t *r0, uint32_t r) {
372 *r1 = high_bits(r);
373
374 *r0 = r;
375 *r0 -= *r1 * 2 * (int32_t)kGamma2;
376 *r0 -= (((int32_t)kHalfPrime - *r0) >> 31) & (int32_t)kPrime;
377 }
378
// FIPS 204, Algorithm 32 (`LowBits`).
//
// Returns only the signed low part produced by |decompose|.
static int32_t low_bits(uint32_t x) {
  uint32_t high_unused;
  int32_t low;
  decompose(&high_unused, &low, x);
  return low;
}
386
387 // FIPS 204, Algorithm 33 (`MakeHint`).
make_hint(uint32_t ct0,uint32_t cs2,uint32_t w)388 static int32_t make_hint(uint32_t ct0, uint32_t cs2, uint32_t w) {
389 uint32_t r_plus_z = reduce_once(kPrime + w - cs2);
390 uint32_t r = reduce_once(r_plus_z + ct0);
391 return high_bits(r) != high_bits(r_plus_z);
392 }
393
// FIPS 204, Algorithm 34 (`UseHint`).
//
// Without a hint the high part of |r| is returned as is. With a hint it is
// moved by one, upwards when the low part is positive and downwards otherwise,
// wrapping modulo 16. Branches on |h| and on the low part, hence "vartime".
static uint32_t use_hint_vartime(uint32_t h, uint32_t r) {
  uint32_t high;
  int32_t low;
  decompose(&high, &low, r);

  if (!h) {
    return high;
  }
  const uint32_t moved = (low > 0) ? high + 1 : high - 1;
  return moved & 15;
}
410
// Applies |power2_round| to every coefficient of |s|, writing the high parts
// to |s1| and the low parts to |s0|.
static void scalar_power2_round(scalar *s1, scalar *s0, const scalar *s) {
  for (int n = 0; n != DEGREE; ++n) {
    uint32_t *high = &s1->c[n];
    uint32_t *low = &s0->c[n];
    power2_round(high, low, s->c[n]);
  }
}
416
// Applies |scale_power2_round| to every coefficient of |in|.
static void scalar_scale_power2_round(scalar *out, const scalar *in) {
  for (int n = 0; n != DEGREE; ++n) {
    scale_power2_round(out->c + n, in->c[n]);
  }
}
422
// Replaces every coefficient of |in| by its |high_bits| value.
static void scalar_high_bits(scalar *out, const scalar *in) {
  for (int n = 0; n != DEGREE; ++n) {
    const uint32_t coefficient = in->c[n];
    out->c[n] = high_bits(coefficient);
  }
}
428
// Replaces every coefficient of |in| by its |low_bits| value. The signed
// result is stored into the unsigned coefficient slot.
static void scalar_low_bits(scalar *out, const scalar *in) {
  for (int n = 0; n != DEGREE; ++n) {
    const uint32_t coefficient = in->c[n];
    out->c[n] = low_bits(coefficient);
  }
}
434
scalar_max(uint32_t * max,const scalar * s)435 static void scalar_max(uint32_t *max, const scalar *s) {
436 for (int i = 0; i < DEGREE; i++) {
437 uint32_t abs = abs_mod_prime(s->c[i]);
438 *max = maximum(*max, abs);
439 }
440 }
441
scalar_max_signed(uint32_t * max,const scalar * s)442 static void scalar_max_signed(uint32_t *max, const scalar *s) {
443 for (int i = 0; i < DEGREE; i++) {
444 uint32_t abs = abs_signed(s->c[i]);
445 *max = maximum(*max, abs);
446 }
447 }
448
// Computes |make_hint| for each coefficient position, storing the 0/1 result
// in |out|.
static void scalar_make_hint(scalar *out, const scalar *ct0, const scalar *cs2,
                             const scalar *w) {
  for (int n = 0; n != DEGREE; ++n) {
    const int32_t hint = make_hint(ct0->c[n], cs2->c[n], w->c[n]);
    out->c[n] = hint;
  }
}
455
// Applies |use_hint_vartime| to each coefficient of |r| with the matching hint
// from |h|.
static void scalar_use_hint_vartime(scalar *out, const scalar *h,
                                    const scalar *r) {
  for (int n = 0; n != DEGREE; ++n) {
    const uint32_t hint = h->c[n];
    out->c[n] = use_hint_vartime(hint, r->c[n]);
  }
}
462
// Applies |scalar_power2_round| to each of the K entries of |t|.
static void vectork_power2_round(vectork *t1, vectork *t0, const vectork *t) {
  for (int row = 0; row != K; ++row) {
    scalar_power2_round(t1->v + row, t0->v + row, t->v + row);
  }
}
468
// Applies |scalar_scale_power2_round| to each of the K entries of |in|.
static void vectork_scale_power2_round(vectork *out, const vectork *in) {
  for (int row = 0; row != K; ++row) {
    scalar_scale_power2_round(out->v + row, in->v + row);
  }
}
474
// Applies |scalar_high_bits| to each of the K entries of |in|.
static void vectork_high_bits(vectork *out, const vectork *in) {
  for (int row = 0; row != K; ++row) {
    scalar_high_bits(out->v + row, in->v + row);
  }
}
480
// Applies |scalar_low_bits| to each of the K entries of |in|.
static void vectork_low_bits(vectork *out, const vectork *in) {
  for (int row = 0; row != K; ++row) {
    scalar_low_bits(out->v + row, in->v + row);
  }
}
486
vectork_max(const vectork * a)487 static uint32_t vectork_max(const vectork *a) {
488 uint32_t max = 0;
489 for (int i = 0; i < K; i++) {
490 scalar_max(&max, &a->v[i]);
491 }
492 return max;
493 }
494
vectork_max_signed(const vectork * a)495 static uint32_t vectork_max_signed(const vectork *a) {
496 uint32_t max = 0;
497 for (int i = 0; i < K; i++) {
498 scalar_max_signed(&max, &a->v[i]);
499 }
500 return max;
501 }
502
503 // The input vector contains only zeroes and ones.
vectork_count_ones(const vectork * a)504 static size_t vectork_count_ones(const vectork *a) {
505 size_t count = 0;
506 for (int i = 0; i < K; i++) {
507 for (int j = 0; j < DEGREE; j++) {
508 count += a->v[i].c[j];
509 }
510 }
511 return count;
512 }
513
// Applies |scalar_make_hint| to each of the K entries of the inputs.
static void vectork_make_hint(vectork *out, const vectork *ct0,
                              const vectork *cs2, const vectork *w) {
  for (int row = 0; row != K; ++row) {
    scalar_make_hint(out->v + row, ct0->v + row, cs2->v + row, w->v + row);
  }
}
520
// Applies |scalar_use_hint_vartime| to each of the K entries of |r| with the
// matching entry of |h|.
static void vectork_use_hint_vartime(vectork *out, const vectork *h,
                                     const vectork *r) {
  for (int row = 0; row != K; ++row) {
    scalar_use_hint_vartime(out->v + row, h->v + row, r->v + row);
  }
}
527
vectorl_max(const vectorl * a)528 static uint32_t vectorl_max(const vectorl *a) {
529 uint32_t max = 0;
530 for (int i = 0; i < L; i++) {
531 scalar_max(&max, &a->v[i]);
532 }
533 return max;
534 }
535
536 /* Bit packing */
537
// kMasks[n - 1] has the n lowest bits set, for n = 1..8.
static const uint8_t kMasks[8] = {
    0x01, 0x03, 0x07, 0x0f,  // 1 to 4 bits
    0x1f, 0x3f, 0x7f, 0xff,  // 5 to 8 bits
};
540
541 // FIPS 204, Algorithm 10 (`SimpleBitPack`).
scalar_encode(uint8_t * out,const scalar * s,int bits)542 static void scalar_encode(uint8_t *out, const scalar *s, int bits) {
543 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
544
545 uint8_t out_byte = 0;
546 int out_byte_bits = 0;
547
548 for (int i = 0; i < DEGREE; i++) {
549 uint32_t element = s->c[i];
550 int element_bits_done = 0;
551
552 while (element_bits_done < bits) {
553 int chunk_bits = bits - element_bits_done;
554 int out_bits_remaining = 8 - out_byte_bits;
555 if (chunk_bits >= out_bits_remaining) {
556 chunk_bits = out_bits_remaining;
557 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
558 *out = out_byte;
559 out++;
560 out_byte_bits = 0;
561 out_byte = 0;
562 } else {
563 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
564 out_byte_bits += chunk_bits;
565 }
566
567 element_bits_done += chunk_bits;
568 element >>= chunk_bits;
569 }
570 }
571
572 if (out_byte_bits > 0) {
573 *out = out_byte;
574 }
575 }
576
577 // FIPS 204, Algorithm 11 (`BitPack`).
scalar_encode_signed(uint8_t * out,const scalar * s,int bits,uint32_t max)578 static void scalar_encode_signed(uint8_t *out, const scalar *s, int bits,
579 uint32_t max) {
580 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
581
582 uint8_t out_byte = 0;
583 int out_byte_bits = 0;
584
585 for (int i = 0; i < DEGREE; i++) {
586 uint32_t element = reduce_once(kPrime + max - s->c[i]);
587 declassify_assert(element <= 2 * max);
588 int element_bits_done = 0;
589
590 while (element_bits_done < bits) {
591 int chunk_bits = bits - element_bits_done;
592 int out_bits_remaining = 8 - out_byte_bits;
593 if (chunk_bits >= out_bits_remaining) {
594 chunk_bits = out_bits_remaining;
595 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
596 *out = out_byte;
597 out++;
598 out_byte_bits = 0;
599 out_byte = 0;
600 } else {
601 out_byte |= (element & kMasks[chunk_bits - 1]) << out_byte_bits;
602 out_byte_bits += chunk_bits;
603 }
604
605 element_bits_done += chunk_bits;
606 element >>= chunk_bits;
607 }
608 }
609
610 if (out_byte_bits > 0) {
611 *out = out_byte;
612 }
613 }
614
615 // FIPS 204, Algorithm 12 (`SimpleBitUnpack`).
scalar_decode(scalar * out,const uint8_t * in,int bits)616 static void scalar_decode(scalar *out, const uint8_t *in, int bits) {
617 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
618
619 uint8_t in_byte = 0;
620 int in_byte_bits_left = 0;
621
622 for (int i = 0; i < DEGREE; i++) {
623 uint32_t element = 0;
624 int element_bits_done = 0;
625
626 while (element_bits_done < bits) {
627 if (in_byte_bits_left == 0) {
628 in_byte = *in;
629 in++;
630 in_byte_bits_left = 8;
631 }
632
633 int chunk_bits = bits - element_bits_done;
634 if (chunk_bits > in_byte_bits_left) {
635 chunk_bits = in_byte_bits_left;
636 }
637
638 element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
639 in_byte_bits_left -= chunk_bits;
640 in_byte >>= chunk_bits;
641
642 element_bits_done += chunk_bits;
643 }
644
645 out->c[i] = element;
646 }
647 }
648
649 // FIPS 204, Algorithm 13 (`BitUnpack`).
scalar_decode_signed(scalar * out,const uint8_t * in,int bits,uint32_t max)650 static int scalar_decode_signed(scalar *out, const uint8_t *in, int bits,
651 uint32_t max) {
652 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
653
654 uint8_t in_byte = 0;
655 int in_byte_bits_left = 0;
656
657 for (int i = 0; i < DEGREE; i++) {
658 uint32_t element = 0;
659 int element_bits_done = 0;
660
661 while (element_bits_done < bits) {
662 if (in_byte_bits_left == 0) {
663 in_byte = *in;
664 in++;
665 in_byte_bits_left = 8;
666 }
667
668 int chunk_bits = bits - element_bits_done;
669 if (chunk_bits > in_byte_bits_left) {
670 chunk_bits = in_byte_bits_left;
671 }
672
673 element |= (in_byte & kMasks[chunk_bits - 1]) << element_bits_done;
674 in_byte_bits_left -= chunk_bits;
675 in_byte >>= chunk_bits;
676
677 element_bits_done += chunk_bits;
678 }
679
680 // This may be only out of range in cases of invalid input, in which case it
681 // is okay to leak the value. This function is also called with secret
682 // input during signing, in |scalar_sample_mask|. However, in that case
683 // (and in any case when |max| is a power of two), this case is impossible.
684 if (constant_time_declassify_int(element > 2 * max)) {
685 return 0;
686 }
687 out->c[i] = reduce_once(kPrime + max - element);
688 }
689
690 return 1;
691 }
692
693 /* Expansion functions */
694
695 // FIPS 204, Algorithm 24 (`RejNTTPoly`).
696 //
697 // Rejection samples a Keccak stream to get uniformly distributed elements. This
698 // is used for matrix expansion and only operates on public inputs.
scalar_from_keccak_vartime(scalar * out,const uint8_t derived_seed[RHO_BYTES+2])699 static void scalar_from_keccak_vartime(
700 scalar *out, const uint8_t derived_seed[RHO_BYTES + 2]) {
701 struct BORINGSSL_keccak_st keccak_ctx;
702 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake128);
703 BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, RHO_BYTES + 2);
704 assert(keccak_ctx.squeeze_offset == 0);
705 assert(keccak_ctx.rate_bytes == 168);
706 static_assert(168 % 3 == 0, "block and coefficient boundaries do not align");
707
708 int done = 0;
709 while (done < DEGREE) {
710 uint8_t block[168];
711 BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
712 for (size_t i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
713 // FIPS 204, Algorithm 8 (`CoeffFromThreeBytes`).
714 uint32_t value = (uint32_t)block[i] | ((uint32_t)block[i + 1] << 8) |
715 (((uint32_t)block[i + 2] & 0x7f) << 16);
716 if (value < kPrime) {
717 out->c[done++] = value;
718 }
719 }
720 }
721 }
722
723 // FIPS 204, Algorithm 25 (`RejBoundedPoly`).
scalar_uniform_eta_4(scalar * out,const uint8_t derived_seed[SIGMA_BYTES+2])724 static void scalar_uniform_eta_4(
725 scalar *out, const uint8_t derived_seed[SIGMA_BYTES + 2]) {
726 static_assert(ETA == 4, "This implementation is specialized for ETA == 4");
727
728 struct BORINGSSL_keccak_st keccak_ctx;
729 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
730 BORINGSSL_keccak_absorb(&keccak_ctx, derived_seed, SIGMA_BYTES + 2);
731 assert(keccak_ctx.squeeze_offset == 0);
732 assert(keccak_ctx.rate_bytes == 136);
733
734 int done = 0;
735 while (done < DEGREE) {
736 uint8_t block[136];
737 BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
738 for (size_t i = 0; i < sizeof(block) && done < DEGREE; ++i) {
739 uint32_t t0 = block[i] & 0x0F;
740 uint32_t t1 = block[i] >> 4;
741 // FIPS 204, Algorithm 9 (`CoefFromHalfByte`). Although both the input and
742 // output here are secret, it is OK to leak when we rejected a byte.
743 // Individual bytes of the SHAKE-256 stream are (indistiguishable from)
744 // independent of each other and the original seed, so leaking information
745 // about the rejected bytes does not reveal the input or output.
746 if (constant_time_declassify_int(t0 < 9)) {
747 out->c[done++] = reduce_once(kPrime + ETA - t0);
748 }
749 if (done < DEGREE && constant_time_declassify_int(t1 < 9)) {
750 out->c[done++] = reduce_once(kPrime + ETA - t1);
751 }
752 }
753 }
754 }
755
756 // FIPS 204, Algorithm 28 (`ExpandMask`).
scalar_sample_mask(scalar * out,const uint8_t derived_seed[RHO_PRIME_BYTES+2])757 static void scalar_sample_mask(
758 scalar *out, const uint8_t derived_seed[RHO_PRIME_BYTES + 2]) {
759 uint8_t buf[640];
760 BORINGSSL_keccak(buf, sizeof(buf), derived_seed, RHO_PRIME_BYTES + 2,
761 boringssl_shake256);
762
763 // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
764 scalar_decode_signed(out, buf, 20, 1 << 19);
765 }
766
767 // FIPS 204, Algorithm 23 (`SampleInBall`).
scalar_sample_in_ball_vartime(scalar * out,const uint8_t * seed,int len)768 static void scalar_sample_in_ball_vartime(scalar *out, const uint8_t *seed,
769 int len) {
770 assert(len == 32);
771
772 struct BORINGSSL_keccak_st keccak_ctx;
773 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
774 BORINGSSL_keccak_absorb(&keccak_ctx, seed, len);
775 assert(keccak_ctx.squeeze_offset == 0);
776 assert(keccak_ctx.rate_bytes == 136);
777
778 uint8_t block[136];
779 BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
780
781 uint64_t signs = CRYPTO_load_u64_le(block);
782 int offset = 8;
783 // SampleInBall implements a Fisher–Yates shuffle, which unavoidably leaks
784 // where the zeros are by memory access pattern. Although this leak happens
785 // before bad signatures are rejected, this is safe. See
786 // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/8d8f01ac_70af3f21/
787 CONSTTIME_DECLASSIFY(block + offset, sizeof(block) - offset);
788
789 OPENSSL_memset(out, 0, sizeof(*out));
790 for (size_t i = DEGREE - TAU; i < DEGREE; i++) {
791 size_t byte;
792 for (;;) {
793 if (offset == 136) {
794 BORINGSSL_keccak_squeeze(&keccak_ctx, block, sizeof(block));
795 // See above.
796 CONSTTIME_DECLASSIFY(block, sizeof(block));
797 offset = 0;
798 }
799
800 byte = block[offset++];
801 if (byte <= i) {
802 break;
803 }
804 }
805
806 out->c[i] = out->c[byte];
807 out->c[byte] = reduce_once(kPrime + 1 - 2 * (signs & 1));
808 signs >>= 1;
809 }
810 }
811
812 // FIPS 204, Algorithm 26 (`ExpandA`).
matrix_expand(matrix * out,const uint8_t rho[RHO_BYTES])813 static void matrix_expand(matrix *out, const uint8_t rho[RHO_BYTES]) {
814 static_assert(K <= 0x100, "K must fit in 8 bits");
815 static_assert(L <= 0x100, "L must fit in 8 bits");
816
817 uint8_t derived_seed[RHO_BYTES + 2];
818 OPENSSL_memcpy(derived_seed, rho, RHO_BYTES);
819 for (int i = 0; i < K; i++) {
820 for (int j = 0; j < L; j++) {
821 derived_seed[RHO_BYTES + 1] = i;
822 derived_seed[RHO_BYTES] = j;
823 scalar_from_keccak_vartime(&out->v[i][j], derived_seed);
824 }
825 }
826 }
827
828 // FIPS 204, Algorithm 27 (`ExpandS`).
vector_expand_short(vectorl * s1,vectork * s2,const uint8_t sigma[SIGMA_BYTES])829 static void vector_expand_short(vectorl *s1, vectork *s2,
830 const uint8_t sigma[SIGMA_BYTES]) {
831 static_assert(K <= 0x100, "K must fit in 8 bits");
832 static_assert(L <= 0x100, "L must fit in 8 bits");
833 static_assert(K + L <= 0x100, "K+L must fit in 8 bits");
834
835 uint8_t derived_seed[SIGMA_BYTES + 2];
836 OPENSSL_memcpy(derived_seed, sigma, SIGMA_BYTES);
837 derived_seed[SIGMA_BYTES] = 0;
838 derived_seed[SIGMA_BYTES + 1] = 0;
839 for (int i = 0; i < L; i++) {
840 scalar_uniform_eta_4(&s1->v[i], derived_seed);
841 ++derived_seed[SIGMA_BYTES];
842 }
843 for (int i = 0; i < K; i++) {
844 scalar_uniform_eta_4(&s2->v[i], derived_seed);
845 ++derived_seed[SIGMA_BYTES];
846 }
847 }
848
849 // FIPS 204, Algorithm 28 (`ExpandMask`).
vectorl_expand_mask(vectorl * out,const uint8_t seed[RHO_PRIME_BYTES],size_t kappa)850 static void vectorl_expand_mask(vectorl *out,
851 const uint8_t seed[RHO_PRIME_BYTES],
852 size_t kappa) {
853 assert(kappa + L <= 0x10000);
854
855 uint8_t derived_seed[RHO_PRIME_BYTES + 2];
856 OPENSSL_memcpy(derived_seed, seed, RHO_PRIME_BYTES);
857 for (int i = 0; i < L; i++) {
858 size_t index = kappa + i;
859 derived_seed[RHO_PRIME_BYTES] = index & 0xFF;
860 derived_seed[RHO_PRIME_BYTES + 1] = (index >> 8) & 0xFF;
861 scalar_sample_mask(&out->v[i], derived_seed);
862 }
863 }
864
865 /* Encoding */
866
867 // FIPS 204, Algorithm 10 (`SimpleBitPack`).
868 //
869 // Encodes an entire vector into 32*K*|bits| bytes. Note that since 256 (DEGREE)
870 // is divisible by 8, the individual vector entries will always fill a whole
871 // number of bytes, so we do not need to worry about bit packing here.
vectork_encode(uint8_t * out,const vectork * a,int bits)872 static void vectork_encode(uint8_t *out, const vectork *a, int bits) {
873 for (int i = 0; i < K; i++) {
874 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
875 }
876 }
877
878 // FIPS 204, Algorithm 12 (`SimpleBitUnpack`).
vectork_decode(vectork * out,const uint8_t * in,int bits)879 static void vectork_decode(vectork *out, const uint8_t *in, int bits) {
880 for (int i = 0; i < K; i++) {
881 scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, bits);
882 }
883 }
884
vectork_encode_signed(uint8_t * out,const vectork * a,int bits,uint32_t max)885 static void vectork_encode_signed(uint8_t *out, const vectork *a, int bits,
886 uint32_t max) {
887 for (int i = 0; i < K; i++) {
888 scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
889 }
890 }
891
// Decodes K back-to-back entries from |in| with |scalar_decode_signed|.
// Returns one on success, or zero as soon as any entry fails to decode.
static int vectork_decode_signed(vectork *out, const uint8_t *in, int bits,
                                 uint32_t max) {
  const int entry_bytes = bits * DEGREE / 8;
  for (int i = 0; i < K; i++) {
    const uint8_t *entry_in = in + i * entry_bytes;
    if (!scalar_decode_signed(&out->v[i], entry_in, bits, max)) {
      return 0;
    }
  }
  return 1;
}
902
903 // FIPS 204, Algorithm 11 (`BitPack`).
904 //
905 // Encodes an entire vector into 32*L*|bits| bytes. Note that since 256 (DEGREE)
906 // is divisible by 8, the individual vector entries will always fill a whole
907 // number of bytes, so we do not need to worry about bit packing here.
vectorl_encode_signed(uint8_t * out,const vectorl * a,int bits,uint32_t max)908 static void vectorl_encode_signed(uint8_t *out, const vectorl *a, int bits,
909 uint32_t max) {
910 for (int i = 0; i < L; i++) {
911 scalar_encode_signed(out + i * bits * DEGREE / 8, &a->v[i], bits, max);
912 }
913 }
914
// Unpacks the L entries of |out| from consecutive |bits| * DEGREE / 8 byte
// slices of |in| using |scalar_decode_signed|. Returns 1 if every entry
// decoded successfully and 0 as soon as one entry fails; entries after the
// failing one are left untouched.
static int vectorl_decode_signed(vectorl *out, const uint8_t *in, int bits,
                                 uint32_t max) {
  const int entry_bytes = bits * DEGREE / 8;
  int col = 0;
  while (col < L) {
    const uint8_t *src = in + col * entry_bytes;
    if (!scalar_decode_signed(&out->v[col], src, bits, max)) {
      return 0;
    }
    col++;
  }
  return 1;
}
925
926 // FIPS 204, Algorithm 22 (`w1Encode`).
927 //
928 // The output must point to an array of 128*K bytes.
w1_encode(uint8_t * out,const vectork * w1)929 static void w1_encode(uint8_t *out, const vectork *w1) {
930 vectork_encode(out, w1, 4);
931 }
932
933 // FIPS 204, Algorithm 14 (`HintBitPack`).
hint_bit_pack(uint8_t * out,const vectork * h)934 static void hint_bit_pack(uint8_t *out, const vectork *h) {
935 OPENSSL_memset(out, 0, OMEGA + K);
936 int index = 0;
937 for (int i = 0; i < K; i++) {
938 for (int j = 0; j < DEGREE; j++) {
939 if (h->v[i].c[j]) {
940 out[index++] = j;
941 }
942 }
943 out[OMEGA + i] = index;
944 }
945 }
946
947 // FIPS 204, Algorithm 15 (`HintBitUnpack`).
hint_bit_unpack(vectork * h,const uint8_t * in)948 static int hint_bit_unpack(vectork *h, const uint8_t *in) {
949 vectork_zero(h);
950 int index = 0;
951 for (int i = 0; i < K; i++) {
952 int limit = in[OMEGA + i];
953 if (limit < index || limit > OMEGA) {
954 return 0;
955 }
956
957 int last = -1;
958 while (index < limit) {
959 int byte = in[index++];
960 if (last >= 0 && byte <= last) {
961 return 0;
962 }
963 last = byte;
964 h->v[i].c[byte] = 1;
965 }
966 }
967 for (; index < OMEGA; index++) {
968 if (in[index] != 0) {
969 return 0;
970 }
971 }
972 return 1;
973 }
974
// Internal representation of a public key. |public_key_from_external| casts a
// |struct DILITHIUM_public_key| to this type and statically asserts that the
// two have the same size and alignment.
struct public_key {
  // Seed from which the matrix is expanded with |matrix_expand|.
  uint8_t rho[RHO_BYTES];
  // Serialized with 10 bits per coefficient by |dilithium_marshal_public_key|.
  vectork t1;
  // Pre-cached value(s).
  // SHAKE-256 output over the encoded public key, filled in by
  // |DILITHIUM_parse_public_key|.
  uint8_t public_key_hash[TR_BYTES];
};
981
// Internal representation of a private key. |private_key_from_external| casts
// a |struct DILITHIUM_private_key| to this type and statically asserts that
// the two have the same size and alignment. The fields are serialized in
// declaration order by |dilithium_marshal_private_key|.
struct private_key {
  // Seed from which the matrix is expanded with |matrix_expand|.
  uint8_t rho[RHO_BYTES];
  // Secret value absorbed, together with the randomizer and the message
  // digest, when deriving the mask seed during signing.
  uint8_t k[K_BYTES];
  // SHAKE-256 output over the encoded public key.
  uint8_t public_key_hash[TR_BYTES];
  // Serialized with 4 bits per coefficient and bound ETA.
  vectorl s1;
  // Serialized with 4 bits per coefficient and bound ETA.
  vectork s2;
  // Serialized with 13 bits per coefficient and bound 2^12.
  vectork t0;
};
990
// Decoded form of a signature. The fields are serialized in declaration order
// by |dilithium_marshal_signature|.
struct signature {
  // Commitment hash, squeezed from SHAKE-256 over the message digest and the
  // encoded w1.
  uint8_t c_tilde[2 * LAMBDA_BYTES];
  // Serialized with 20 bits per coefficient and bound 2^19.
  vectorl z;
  // Hint vector, serialized by |hint_bit_pack| into OMEGA + K bytes.
  vectork h;
};
996
997 // FIPS 204, Algorithm 16 (`pkEncode`).
dilithium_marshal_public_key(CBB * out,const struct public_key * pub)998 static int dilithium_marshal_public_key(CBB *out,
999 const struct public_key *pub) {
1000 if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
1001 return 0;
1002 }
1003
1004 uint8_t *vectork_output;
1005 if (!CBB_add_space(out, &vectork_output, 320 * K)) {
1006 return 0;
1007 }
1008 vectork_encode(vectork_output, &pub->t1, 10);
1009
1010 return 1;
1011 }
1012
1013 // FIPS 204, Algorithm 17 (`pkDecode`).
dilithium_parse_public_key(struct public_key * pub,CBS * in)1014 static int dilithium_parse_public_key(struct public_key *pub, CBS *in) {
1015 if (!CBS_copy_bytes(in, pub->rho, sizeof(pub->rho))) {
1016 return 0;
1017 }
1018
1019 CBS t1_bytes;
1020 if (!CBS_get_bytes(in, &t1_bytes, 320 * K)) {
1021 return 0;
1022 }
1023 vectork_decode(&pub->t1, CBS_data(&t1_bytes), 10);
1024
1025 return 1;
1026 }
1027
1028 // FIPS 204, Algorithm 18 (`skEncode`).
dilithium_marshal_private_key(CBB * out,const struct private_key * priv)1029 static int dilithium_marshal_private_key(CBB *out,
1030 const struct private_key *priv) {
1031 if (!CBB_add_bytes(out, priv->rho, sizeof(priv->rho)) ||
1032 !CBB_add_bytes(out, priv->k, sizeof(priv->k)) ||
1033 !CBB_add_bytes(out, priv->public_key_hash,
1034 sizeof(priv->public_key_hash))) {
1035 return 0;
1036 }
1037
1038 uint8_t *vectorl_output;
1039 if (!CBB_add_space(out, &vectorl_output, 128 * L)) {
1040 return 0;
1041 }
1042 vectorl_encode_signed(vectorl_output, &priv->s1, 4, ETA);
1043
1044 uint8_t *vectork_output;
1045 if (!CBB_add_space(out, &vectork_output, 128 * K)) {
1046 return 0;
1047 }
1048 vectork_encode_signed(vectork_output, &priv->s2, 4, ETA);
1049
1050 if (!CBB_add_space(out, &vectork_output, 416 * K)) {
1051 return 0;
1052 }
1053 vectork_encode_signed(vectork_output, &priv->t0, 13, 1 << 12);
1054
1055 return 1;
1056 }
1057
1058 // FIPS 204, Algorithm 19 (`skDecode`).
dilithium_parse_private_key(struct private_key * priv,CBS * in)1059 static int dilithium_parse_private_key(struct private_key *priv, CBS *in) {
1060 CBS s1_bytes;
1061 CBS s2_bytes;
1062 CBS t0_bytes;
1063 if (!CBS_copy_bytes(in, priv->rho, sizeof(priv->rho)) ||
1064 !CBS_copy_bytes(in, priv->k, sizeof(priv->k)) ||
1065 !CBS_copy_bytes(in, priv->public_key_hash,
1066 sizeof(priv->public_key_hash)) ||
1067 !CBS_get_bytes(in, &s1_bytes, 128 * L) ||
1068 !vectorl_decode_signed(&priv->s1, CBS_data(&s1_bytes), 4, ETA) ||
1069 !CBS_get_bytes(in, &s2_bytes, 128 * K) ||
1070 !vectork_decode_signed(&priv->s2, CBS_data(&s2_bytes), 4, ETA) ||
1071 !CBS_get_bytes(in, &t0_bytes, 416 * K) ||
1072 // Note: Decoding 13 bits into (-2^12, 2^12] cannot fail.
1073 !vectork_decode_signed(&priv->t0, CBS_data(&t0_bytes), 13, 1 << 12)) {
1074 return 0;
1075 }
1076
1077 return 1;
1078 }
1079
1080 // FIPS 204, Algorithm 20 (`sigEncode`).
dilithium_marshal_signature(CBB * out,const struct signature * sign)1081 static int dilithium_marshal_signature(CBB *out, const struct signature *sign) {
1082 if (!CBB_add_bytes(out, sign->c_tilde, sizeof(sign->c_tilde))) {
1083 return 0;
1084 }
1085
1086 uint8_t *vectorl_output;
1087 if (!CBB_add_space(out, &vectorl_output, 640 * L)) {
1088 return 0;
1089 }
1090 vectorl_encode_signed(vectorl_output, &sign->z, 20, 1 << 19);
1091
1092 uint8_t *hint_output;
1093 if (!CBB_add_space(out, &hint_output, OMEGA + K)) {
1094 return 0;
1095 }
1096 hint_bit_pack(hint_output, &sign->h);
1097
1098 return 1;
1099 }
1100
1101 // FIPS 204, Algorithm 21 (`sigDecode`).
dilithium_parse_signature(struct signature * sign,CBS * in)1102 static int dilithium_parse_signature(struct signature *sign, CBS *in) {
1103 CBS z_bytes;
1104 CBS hint_bytes;
1105 if (!CBS_copy_bytes(in, sign->c_tilde, sizeof(sign->c_tilde)) ||
1106 !CBS_get_bytes(in, &z_bytes, 640 * L) ||
1107 // Note: Decoding 20 bits into (-2^19, 2^19] cannot fail.
1108 !vectorl_decode_signed(&sign->z, CBS_data(&z_bytes), 20, 1 << 19) ||
1109 !CBS_get_bytes(in, &hint_bytes, OMEGA + K) ||
1110 !hint_bit_unpack(&sign->h, CBS_data(&hint_bytes))) {
1111 return 0;
1112 };
1113
1114 return 1;
1115 }
1116
private_key_from_external(const struct DILITHIUM_private_key * external)1117 static struct private_key *private_key_from_external(
1118 const struct DILITHIUM_private_key *external) {
1119 static_assert(
1120 sizeof(struct DILITHIUM_private_key) == sizeof(struct private_key),
1121 "Kyber private key size incorrect");
1122 static_assert(
1123 alignof(struct DILITHIUM_private_key) == alignof(struct private_key),
1124 "Kyber private key align incorrect");
1125 return (struct private_key *)external;
1126 }
1127
public_key_from_external(const struct DILITHIUM_public_key * external)1128 static struct public_key *public_key_from_external(
1129 const struct DILITHIUM_public_key *external) {
1130 static_assert(
1131 sizeof(struct DILITHIUM_public_key) == sizeof(struct public_key),
1132 "Dilithium public key size incorrect");
1133 static_assert(
1134 alignof(struct DILITHIUM_public_key) == alignof(struct public_key),
1135 "Dilithium public key align incorrect");
1136 return (struct public_key *)external;
1137 }
1138
1139 /* API */
1140
1141 // Calls |DILITHIUM_generate_key_external_entropy| with random bytes from
1142 // |RAND_bytes|. Returns 1 on success and 0 on failure.
DILITHIUM_generate_key(uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],struct DILITHIUM_private_key * out_private_key)1143 int DILITHIUM_generate_key(
1144 uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],
1145 struct DILITHIUM_private_key *out_private_key) {
1146 uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY];
1147 RAND_bytes(entropy, sizeof(entropy));
1148 return DILITHIUM_generate_key_external_entropy(out_encoded_public_key,
1149 out_private_key, entropy);
1150 }
1151
1152 // FIPS 204, Algorithm 1 (`ML-DSA.KeyGen`). Returns 1 on success and 0 on
1153 // failure.
DILITHIUM_generate_key_external_entropy(uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],struct DILITHIUM_private_key * out_private_key,const uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY])1154 int DILITHIUM_generate_key_external_entropy(
1155 uint8_t out_encoded_public_key[DILITHIUM_PUBLIC_KEY_BYTES],
1156 struct DILITHIUM_private_key *out_private_key,
1157 const uint8_t entropy[DILITHIUM_GENERATE_KEY_ENTROPY]) {
1158 int ret = 0;
1159
1160 // Intermediate values, allocated on the heap to allow use when there is a
1161 // limited amount of stack.
1162 struct values_st {
1163 struct public_key pub;
1164 matrix a_ntt;
1165 vectorl s1_ntt;
1166 vectork t;
1167 };
1168 struct values_st *values = OPENSSL_malloc(sizeof(*values));
1169 if (values == NULL) {
1170 goto err;
1171 }
1172
1173 struct private_key *priv = private_key_from_external(out_private_key);
1174
1175 uint8_t expanded_seed[RHO_BYTES + SIGMA_BYTES + K_BYTES];
1176 BORINGSSL_keccak(expanded_seed, sizeof(expanded_seed), entropy,
1177 DILITHIUM_GENERATE_KEY_ENTROPY, boringssl_shake256);
1178 const uint8_t *const rho = expanded_seed;
1179 const uint8_t *const sigma = expanded_seed + RHO_BYTES;
1180 const uint8_t *const k = expanded_seed + RHO_BYTES + SIGMA_BYTES;
1181 // rho is public.
1182 CONSTTIME_DECLASSIFY(rho, RHO_BYTES);
1183 OPENSSL_memcpy(values->pub.rho, rho, sizeof(values->pub.rho));
1184 OPENSSL_memcpy(priv->rho, rho, sizeof(priv->rho));
1185 OPENSSL_memcpy(priv->k, k, sizeof(priv->k));
1186
1187 matrix_expand(&values->a_ntt, rho);
1188 vector_expand_short(&priv->s1, &priv->s2, sigma);
1189
1190 OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1191 vectorl_ntt(&values->s1_ntt);
1192
1193 matrix_mult(&values->t, &values->a_ntt, &values->s1_ntt);
1194 vectork_inverse_ntt(&values->t);
1195 vectork_add(&values->t, &values->t, &priv->s2);
1196
1197 vectork_power2_round(&values->pub.t1, &priv->t0, &values->t);
1198 // t1 is public.
1199 CONSTTIME_DECLASSIFY(&pub.t1, sizeof(pub.t1));
1200
1201 CBB cbb;
1202 CBB_init_fixed(&cbb, out_encoded_public_key, DILITHIUM_PUBLIC_KEY_BYTES);
1203 if (!dilithium_marshal_public_key(&cbb, &values->pub)) {
1204 goto err;
1205 }
1206
1207 BORINGSSL_keccak(priv->public_key_hash, sizeof(priv->public_key_hash),
1208 out_encoded_public_key, DILITHIUM_PUBLIC_KEY_BYTES,
1209 boringssl_shake256);
1210
1211 ret = 1;
1212 err:
1213 OPENSSL_free(values);
1214 return ret;
1215 }
1216
1217 // FIPS 204, Algorithm 2 (`ML-DSA.Sign`). Returns 1 on success and 0 on failure.
dilithium_sign_with_randomizer(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len,const uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES])1218 static int dilithium_sign_with_randomizer(
1219 uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1220 const struct DILITHIUM_private_key *private_key, const uint8_t *msg,
1221 size_t msg_len,
1222 const uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES]) {
1223 int ret = 0;
1224
1225 const struct private_key *priv = private_key_from_external(private_key);
1226
1227 uint8_t mu[MU_BYTES];
1228 struct BORINGSSL_keccak_st keccak_ctx;
1229 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1230 BORINGSSL_keccak_absorb(&keccak_ctx, priv->public_key_hash,
1231 sizeof(priv->public_key_hash));
1232 BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
1233 BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
1234
1235 uint8_t rho_prime[RHO_PRIME_BYTES];
1236 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1237 BORINGSSL_keccak_absorb(&keccak_ctx, priv->k, sizeof(priv->k));
1238 BORINGSSL_keccak_absorb(&keccak_ctx, randomizer,
1239 DILITHIUM_SIGNATURE_RANDOMIZER_BYTES);
1240 BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
1241 BORINGSSL_keccak_squeeze(&keccak_ctx, rho_prime, RHO_PRIME_BYTES);
1242
1243 // Intermediate values, allocated on the heap to allow use when there is a
1244 // limited amount of stack.
1245 struct values_st {
1246 struct signature sign;
1247 vectorl s1_ntt;
1248 vectork s2_ntt;
1249 vectork t0_ntt;
1250 matrix a_ntt;
1251 vectorl y;
1252 vectorl y_ntt;
1253 vectork w;
1254 vectork w1;
1255 vectorl cs1;
1256 vectork cs2;
1257 vectork r0;
1258 vectork ct0;
1259 };
1260 struct values_st *values = OPENSSL_malloc(sizeof(*values));
1261 if (values == NULL) {
1262 goto err;
1263 }
1264 OPENSSL_memcpy(&values->s1_ntt, &priv->s1, sizeof(values->s1_ntt));
1265 vectorl_ntt(&values->s1_ntt);
1266
1267 OPENSSL_memcpy(&values->s2_ntt, &priv->s2, sizeof(values->s2_ntt));
1268 vectork_ntt(&values->s2_ntt);
1269
1270 OPENSSL_memcpy(&values->t0_ntt, &priv->t0, sizeof(values->t0_ntt));
1271 vectork_ntt(&values->t0_ntt);
1272
1273 matrix_expand(&values->a_ntt, priv->rho);
1274
1275 for (size_t kappa = 0;; kappa += L) {
1276 //TODO(bbe): y only lives long enough to compute y_ntt.
1277 //consider using another vectorl to save memory.
1278 vectorl_expand_mask(&values->y, rho_prime, kappa);
1279
1280 OPENSSL_memcpy(&values->y_ntt, &values->y, sizeof(values->y_ntt));
1281 vectorl_ntt(&values->y_ntt);
1282
1283 //TODO(bbe): w only lives long enough to compute y_ntt.
1284 //consider using another vectork to save memory.
1285 matrix_mult(&values->w, &values->a_ntt, &values->y_ntt);
1286 vectork_inverse_ntt(&values->w);
1287
1288 vectork_high_bits(&values->w1, &values->w);
1289 uint8_t w1_encoded[128 * K];
1290 w1_encode(w1_encoded, &values->w1);
1291
1292 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1293 BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
1294 BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
1295 BORINGSSL_keccak_squeeze(&keccak_ctx, values->sign.c_tilde,
1296 2 * LAMBDA_BYTES);
1297
1298 scalar c_ntt;
1299 scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde, 32);
1300 scalar_ntt(&c_ntt);
1301
1302 vectorl_mult_scalar(&values->cs1, &values->s1_ntt, &c_ntt);
1303 vectorl_inverse_ntt(&values->cs1);
1304 vectork_mult_scalar(&values->cs2, &values->s2_ntt, &c_ntt);
1305 vectork_inverse_ntt(&values->cs2);
1306
1307 vectorl_add(&values->sign.z, &values->y, &values->cs1);
1308
1309 vectork_sub(&values->r0, &values->w, &values->cs2);
1310 vectork_low_bits(&values->r0, &values->r0);
1311
1312 // Leaking the fact that a signature was rejected is fine as the next
1313 // attempt at a signature will be (indistinguishable from) independent of
1314 // this one. Note, however, that we additionally leak which of the two
1315 // branches rejected the signature. Section 5.5 of
1316 // https://pq-crystals.org/dilithium/data/dilithium-specification-round3.pdf
1317 // describes this leak as OK. Note we leak less than what is described by
1318 // the paper; we do not reveal which coefficient violated the bound, and we
1319 // hide which of the |z_max| or |r0_max| bound failed. See also
1320 // https://boringssl-review.googlesource.com/c/boringssl/+/67747/comment/2bbab0fa_d241d35a/
1321 uint32_t z_max = vectorl_max(&values->sign.z);
1322 uint32_t r0_max = vectork_max_signed(&values->r0);
1323 if (constant_time_declassify_w(
1324 constant_time_ge_w(z_max, kGamma1 - BETA) |
1325 constant_time_ge_w(r0_max, kGamma2 - BETA))) {
1326 continue;
1327 }
1328
1329 vectork_mult_scalar(&values->ct0, &values->t0_ntt, &c_ntt);
1330 vectork_inverse_ntt(&values->ct0);
1331 vectork_make_hint(&values->sign.h, &values->ct0, &values->cs2, &values->w);
1332
1333 // See above.
1334 uint32_t ct0_max = vectork_max(&values->ct0);
1335 size_t h_ones = vectork_count_ones(&values->sign.h);
1336 if (constant_time_declassify_w(constant_time_ge_w(ct0_max, kGamma2) |
1337 constant_time_ge_w(h_ones, OMEGA))) {
1338 continue;
1339 }
1340
1341 // Although computed with the private key, the signature is public.
1342 CONSTTIME_DECLASSIFY(values->sign.c_tilde, sizeof(values->sign.c_tilde));
1343 CONSTTIME_DECLASSIFY(&values->sign.z, sizeof(values->sign.z));
1344 CONSTTIME_DECLASSIFY(&values->sign.h, sizeof(values->sign.h));
1345
1346 CBB cbb;
1347 CBB_init_fixed(&cbb, out_encoded_signature, DILITHIUM_SIGNATURE_BYTES);
1348 if (!dilithium_marshal_signature(&cbb, &values->sign)) {
1349 goto err;
1350 }
1351
1352 BSSL_CHECK(CBB_len(&cbb) == DILITHIUM_SIGNATURE_BYTES);
1353 ret = 1;
1354 break;
1355 }
1356
1357 err:
1358 OPENSSL_free(values);
1359 return ret;
1360 }
1361
1362 // Dilithium signature in deterministic mode. Returns 1 on success and 0 on
1363 // failure.
DILITHIUM_sign_deterministic(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len)1364 int DILITHIUM_sign_deterministic(
1365 uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1366 const struct DILITHIUM_private_key *private_key, const uint8_t *msg,
1367 size_t msg_len) {
1368 uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES];
1369 OPENSSL_memset(randomizer, 0, sizeof(randomizer));
1370 return dilithium_sign_with_randomizer(out_encoded_signature, private_key, msg,
1371 msg_len, randomizer);
1372 }
1373
1374 // Dilithium signature in randomized mode, filling the random bytes with
1375 // |RAND_bytes|. Returns 1 on success and 0 on failure.
DILITHIUM_sign(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],const struct DILITHIUM_private_key * private_key,const uint8_t * msg,size_t msg_len)1376 int DILITHIUM_sign(uint8_t out_encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1377 const struct DILITHIUM_private_key *private_key,
1378 const uint8_t *msg, size_t msg_len) {
1379 uint8_t randomizer[DILITHIUM_SIGNATURE_RANDOMIZER_BYTES];
1380 RAND_bytes(randomizer, sizeof(randomizer));
1381 return dilithium_sign_with_randomizer(out_encoded_signature, private_key, msg,
1382 msg_len, randomizer);
1383 }
1384
1385 // FIPS 204, Algorithm 3 (`ML-DSA.Verify`).
DILITHIUM_verify(const struct DILITHIUM_public_key * public_key,const uint8_t encoded_signature[DILITHIUM_SIGNATURE_BYTES],const uint8_t * msg,size_t msg_len)1386 int DILITHIUM_verify(const struct DILITHIUM_public_key *public_key,
1387 const uint8_t encoded_signature[DILITHIUM_SIGNATURE_BYTES],
1388 const uint8_t *msg, size_t msg_len) {
1389 int ret = 0;
1390
1391 // Intermediate values, allocated on the heap to allow use when there is a
1392 // limited amount of stack.
1393 struct values_st {
1394 struct signature sign;
1395 matrix a_ntt;
1396 vectorl z_ntt;
1397 vectork az_ntt;
1398 vectork t1_ntt;
1399 vectork ct1_ntt;
1400 vectork w_approx;
1401 vectork w1;
1402 };
1403 struct values_st *values = OPENSSL_malloc(sizeof(*values));
1404 if (values == NULL) {
1405 goto err;
1406 }
1407
1408 const struct public_key *pub = public_key_from_external(public_key);
1409
1410 CBS cbs;
1411 CBS_init(&cbs, encoded_signature, DILITHIUM_SIGNATURE_BYTES);
1412 if (!dilithium_parse_signature(&values->sign, &cbs)) {
1413 goto err;
1414 }
1415
1416 matrix_expand(&values->a_ntt, pub->rho);
1417
1418 uint8_t mu[MU_BYTES];
1419 struct BORINGSSL_keccak_st keccak_ctx;
1420 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1421 BORINGSSL_keccak_absorb(&keccak_ctx, pub->public_key_hash,
1422 sizeof(pub->public_key_hash));
1423 BORINGSSL_keccak_absorb(&keccak_ctx, msg, msg_len);
1424 BORINGSSL_keccak_squeeze(&keccak_ctx, mu, MU_BYTES);
1425
1426 scalar c_ntt;
1427 scalar_sample_in_ball_vartime(&c_ntt, values->sign.c_tilde, 32);
1428 scalar_ntt(&c_ntt);
1429
1430 OPENSSL_memcpy(&values->z_ntt, &values->sign.z, sizeof(values->z_ntt));
1431 vectorl_ntt(&values->z_ntt);
1432
1433 matrix_mult(&values->az_ntt, &values->a_ntt, &values->z_ntt);
1434
1435 vectork_scale_power2_round(&values->t1_ntt, &pub->t1);
1436 vectork_ntt(&values->t1_ntt);
1437
1438 vectork_mult_scalar(&values->ct1_ntt, &values->t1_ntt, &c_ntt);
1439
1440 vectork_sub(&values->w_approx, &values->az_ntt, &values->ct1_ntt);
1441 vectork_inverse_ntt(&values->w_approx);
1442
1443 vectork_use_hint_vartime(&values->w1, &values->sign.h, &values->w_approx);
1444 uint8_t w1_encoded[128 * K];
1445 w1_encode(w1_encoded, &values->w1);
1446
1447 uint8_t c_tilde[2 * LAMBDA_BYTES];
1448 BORINGSSL_keccak_init(&keccak_ctx, boringssl_shake256);
1449 BORINGSSL_keccak_absorb(&keccak_ctx, mu, MU_BYTES);
1450 BORINGSSL_keccak_absorb(&keccak_ctx, w1_encoded, 128 * K);
1451 BORINGSSL_keccak_squeeze(&keccak_ctx, c_tilde, 2 * LAMBDA_BYTES);
1452
1453 uint32_t z_max = vectorl_max(&values->sign.z);
1454 size_t h_ones = vectork_count_ones(&values->sign.h);
1455 if (z_max < kGamma1 - BETA && h_ones <= OMEGA &&
1456 OPENSSL_memcmp(c_tilde, values->sign.c_tilde, 2 * LAMBDA_BYTES) == 0) {
1457 ret = 1;
1458 }
1459
1460 err:
1461 OPENSSL_free(values);
1462 return ret;
1463 }
1464
1465 /* Serialization of keys. */
1466
// Appends the encoding of |public_key| to |out| via
// |dilithium_marshal_public_key|. Returns 1 on success and 0 on failure.
int DILITHIUM_marshal_public_key(
    CBB *out, const struct DILITHIUM_public_key *public_key) {
  const struct public_key *pub = public_key_from_external(public_key);
  return dilithium_marshal_public_key(out, pub);
}
1472
// Parses an encoded public key from |in| into |public_key| and caches the
// SHAKE-256 hash of the encoding in it. Returns 1 on success and 0 if the key
// does not parse or |in| has trailing data; on failure the cached hash is not
// written.
int DILITHIUM_parse_public_key(struct DILITHIUM_public_key *public_key,
                               CBS *in) {
  struct public_key *pub = public_key_from_external(public_key);

  // Remember the full input so it can be hashed after |in| has been consumed.
  const CBS encoded = *in;
  if (!dilithium_parse_public_key(pub, in)) {
    return 0;
  }
  if (CBS_len(in) != 0) {
    return 0;
  }

  // Compute pre-cached values.
  BORINGSSL_keccak(pub->public_key_hash, sizeof(pub->public_key_hash),
                   CBS_data(&encoded), CBS_len(&encoded), boringssl_shake256);
  return 1;
}
1486
// Appends the encoding of |private_key| to |out| via
// |dilithium_marshal_private_key|. Returns 1 on success and 0 on failure.
int DILITHIUM_marshal_private_key(
    CBB *out, const struct DILITHIUM_private_key *private_key) {
  const struct private_key *priv = private_key_from_external(private_key);
  return dilithium_marshal_private_key(out, priv);
}
1492
// Parses an encoded private key from |in| into |private_key|. Returns 1 on
// success and 0 if the key does not parse or |in| has trailing data.
int DILITHIUM_parse_private_key(struct DILITHIUM_private_key *private_key,
                                CBS *in) {
  struct private_key *priv = private_key_from_external(private_key);
  if (!dilithium_parse_private_key(priv, in)) {
    return 0;
  }
  if (CBS_len(in) != 0) {
    return 0;
  }
  return 1;
}
1498