1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLKEM
18 #include <stdlib.h>
19 #include <string.h>
20 #include <stdio.h>
21 #include "securec.h"
22 #include "bsl_errno.h"
23 #include "bsl_sal.h"
24 #include "crypt_utils.h"
25 #include "crypt_sha3.h"
26 #include "crypt_errno.h"
27 #include "bsl_err_internal.h"
28 #include "eal_md_local.h"
29 #include "ml_kem_local.h"
30
// Number of bits packed into / unpacked from one byte by the 1-bit encoders below.
#define BITS_OF_BYTE 8
// Upper bound of the module rank k; sizes the pointer arrays in MLKEM_MatrixSt / MLKEM_DecVectorSt.
#define MLKEM_K_MAX 4
// Upper bounds of eta1 / eta2; size the PRF output buffers (MLKEM_PRF_BLOCKSIZE bytes per unit of eta).
#define MLKEM_ETA1_MAX 3
#define MLKEM_ETA2_MAX 2
35
// A LUT of powers of the root of unity (psi) modulo MLKEM_Q, stored in bit-reversed order.
// Entry 64 is 17 and entry 1 is 1729, which matches zeta^BitRev7(i) mod 3329 with zeta = 17
// (the table of FIPS 203 Appendix A). Each entry is the modular inverse of the entry at the
// same index in PRE_COMPUT_TABLE_INTT. Passed to MLKEM_ComputNTT and MLKEM_MatrixMulAdd.
static const int16_t PRE_COMPUT_TABLE_NTT[MLKEM_N_HALF] = {
    1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476,
    3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
    650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
    2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
    1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220,
    375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
    1722, 1212, 1874, 1029, 2110, 2935, 885, 2154
};
46
// A LUT of all powers of psi^{-1} in bit-reversed order.
// Entry i is the inverse modulo 3329 of PRE_COMPUT_TABLE_NTT[i] (e.g. 17 * 1175 = 1 mod 3329
// at index 64). Passed to MLKEM_ComputINTT.
static const int16_t PRE_COMPUT_TABLE_INTT[MLKEM_N_HALF] = {
    1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447,
    2794, 1235, 1903, 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, 2681, 1848, 712, 682, 927,
    1795, 461, 1891, 2877, 2522, 1894, 1010, 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
    1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
    525, 735, 863, 2768, 1230, 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, 2688, 3061,
    992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096,
    48, 667, 1920, 2229, 1041, 2606, 1692, 680, 2746, 568, 3312
};
57
// Working set for K-PKE key generation and encryption. Every pointer below refers into the
// single allocation held in bufAddr (see CreateMatrixBuf); only bufAddr is ever freed.
typedef struct {
    int16_t *bufAddr;                             // Start of the backing allocation, used to release memory.
    int16_t *matrix[MLKEM_K_MAX][MLKEM_K_MAX];    // k * k polynomials filled by GenMatrix.
    int16_t *vectorS[MLKEM_K_MAX];                // k polynomials sampled by SampleEta1.
    int16_t *vectorE[MLKEM_K_MAX];                // k noise polynomials (reused as the decoded key in PkeEncrypt).
    int16_t *vectorT[MLKEM_K_MAX];                // k result polynomials written by MLKEM_MatrixMulAdd.
} MLKEM_MatrixSt; // Intermediate data of the key generation and encryption.
65
// Working set for K-PKE decryption. Every pointer below refers into the single allocation
// held in bufAddr (see CreateDecVectorBuf); only bufAddr is ever freed.
typedef struct {
    int16_t *bufAddr;                   // Start of the backing allocation, used to release memory.
    int16_t *vectorS[MLKEM_K_MAX];      // k polynomials decoded from the decryption key.
    int16_t *vectorC1[MLKEM_K_MAX];     // k polynomials decoded from the first ciphertext part.
    int16_t *vectorC2;                  // Polynomial decoded from the second ciphertext part.
    int16_t *polyM;                     // Output of the matrix multiply, later the recovered message bits.
} MLKEM_DecVectorSt; // Intermediate data of the decryption.
73
/*
 * Allocate the working buffer for key generation / encryption and wire up the pointers in st.
 * Layout: k * k matrix polynomials (row-major), followed by k triples {S[i], E[i], T[i]}.
 * Each polynomial holds MLKEM_N int16_t coefficients. The memory is not initialised.
 * Returns CRYPT_SUCCESS, or BSL_MALLOC_FAIL when the allocation fails.
 */
static int32_t CreateMatrixBuf(uint8_t k, MLKEM_MatrixSt *st)
{
    // (k * k + 3 * k) polynomials in total.
    int16_t *base = BSL_SAL_Malloc((k * k + 3 * k) * MLKEM_N * sizeof(int16_t));
    if (base == NULL) {
        return BSL_MALLOC_FAIL;
    }
    st->bufAddr = base; // Kept so MatrixBufFree can release the allocation.

    int16_t *vectorArea = base + k * k * MLKEM_N; // First polynomial after the matrix.
    for (uint8_t row = 0; row < k; row++) {
        int16_t *rowStart = base + row * k * MLKEM_N;
        for (uint8_t col = 0; col < k; col++) {
            st->matrix[row][col] = rowStart + col * MLKEM_N;
        }
        // S, E and T of the same index are stored next to each other.
        int16_t *triple = vectorArea + row * 3 * MLKEM_N;
        st->vectorS[row] = triple;
        st->vectorE[row] = triple + MLKEM_N;
        st->vectorT[row] = triple + 2 * MLKEM_N;
    }
    return CRYPT_SUCCESS;
}
93
/*
 * Wipe and release the buffer created by CreateMatrixBuf.
 * k must be the same value that was used for the allocation so the whole buffer is cleared.
 */
static void MatrixBufFree(uint8_t k, MLKEM_MatrixSt *st)
{
    // Same size expression as the allocation: (k * k + 3 * k) polynomials.
    size_t totalBytes = (k * k + 3 * k) * MLKEM_N * sizeof(int16_t);
    BSL_SAL_ClearFree(st->bufAddr, totalBytes);
}
99
/*
 * Allocate the working buffer for decryption and wire up the pointers in st.
 * Layout: k polynomials for vectorS, k polynomials for vectorC1, then vectorC2 and polyM.
 * Each polynomial holds MLKEM_N int16_t coefficients. The memory is not initialised.
 * Returns CRYPT_SUCCESS, or BSL_MALLOC_FAIL when the allocation fails.
 */
static int32_t CreateDecVectorBuf(uint8_t k, MLKEM_DecVectorSt *st)
{
    // (k * 2 + 2) polynomials in total.
    int16_t *base = BSL_SAL_Malloc((k * 2 + 2) * MLKEM_N * sizeof(int16_t));
    if (base == NULL) {
        return BSL_MALLOC_FAIL;
    }
    st->bufAddr = base; // Kept so DecVectorBufFree can release the allocation.

    int16_t *c1Area = base + k * MLKEM_N;
    for (uint8_t idx = 0; idx < k; idx++) {
        st->vectorS[idx] = base + idx * MLKEM_N;
        st->vectorC1[idx] = c1Area + idx * MLKEM_N;
    }
    // vectorC2 and polyM occupy the last two polynomials of the buffer.
    int16_t *tail = base + (k * 2) * MLKEM_N;
    st->vectorC2 = tail;
    st->polyM = tail + MLKEM_N;
    return CRYPT_SUCCESS;
}
117
/*
 * Wipe and release the buffer created by CreateDecVectorBuf.
 * k must be the same value that was used for the allocation so the whole buffer is cleared.
 */
static void DecVectorBufFree(uint8_t k, MLKEM_DecVectorSt *st)
{
    // Same size expression as the allocation: (k * 2 + 2) polynomials.
    size_t totalBytes = (k * 2 + 2) * MLKEM_N * sizeof(int16_t);
    BSL_SAL_ClearFree(st->bufAddr, totalBytes);
}
123
// Compress
// Parameters of one division-free "(x << bits) / MLKEM_Q" computation (see DivMlKemQ):
// the division is replaced by a multiplication with barrettMultiplier and a right shift.
typedef struct {
    uint64_t barrettMultiplier;  /* round(2 ^ barrettShift / MLKEM_Q) */
    uint16_t barrettShift;       /* right shift applied after the multiplication */
    uint16_t halfQ;              /* rounded (MLKEM_Q / 2) down or up */
    uint8_t bits;                /* compression width d this entry applies to */
} MLKEM_BARRET_REDUCE;
131
// The values of du and dv are from NIST.FIPS.203 Table 2.
// One entry per supported compression width; Compress() selects the entry whose last field
// (bits) equals the requested width. Field order: multiplier, shift, halfQ, bits.
static const MLKEM_BARRET_REDUCE MLKEM_BARRETT_TABLE[] = {
    {80635 /* round(2^28/MLKEM_Q) */, 28, 1665 /* Ceil(MLKEM_Q/2) */, 1},    // 1 is used for the message bits
    {1290167 /* round(2^32/MLKEM_Q) */, 32, 1665 /* Ceil(MLKEM_Q/2) */, 10}, // 10 is mlkem768 du
    {80635 /* round(2^28/MLKEM_Q) */, 28, 1665 /* Ceil(MLKEM_Q/2) */, 4},    // 4 is mlkem768 dv
    {40318 /* round(2^27/MLKEM_Q) */, 27, 1664 /* Floor(MLKEM_Q/2) */, 5},   // 5 is mlkem1024 dv
    {645084 /* round(2^31/MLKEM_Q) */, 31, 1664 /* Floor(MLKEM_Q/2) */, 11}  // 11 is mlkem1024 du
};
140
/*
 * Division-free evaluation of ((x << bits) + halfQ) / q, reduced to the low `bits` bits.
 * The division is approximated by multiplying with barrettMultiplier (about 2^barrettShift / q)
 * and shifting right by barrettShift.
 *
 * x                 value to scale
 * bits              number of result bits kept
 * halfQ             rounding offset added before the division
 * barrettShift      right shift applied to the product
 * barrettMultiplier precomputed reciprocal of q scaled by 2^barrettShift
 * Returns the quotient masked to `bits` bits.
 */
static int16_t DivMlKemQ(uint16_t x, uint8_t bits, uint16_t halfQ, uint16_t barrettShift, uint64_t barrettMultiplier)
{
    const uint64_t numerator = ((uint64_t)x << bits) + halfQ;
    const uint64_t quotient = (numerator * barrettMultiplier) >> barrettShift;
    const uint64_t lowMask = (uint64_t)((1 << bits) - 1);
    return (int16_t)(quotient & lowMask);
}
148
Compress(int16_t x,uint8_t d)149 static int16_t Compress(int16_t x, uint8_t d)
150 {
151 int16_t value = 0;
152 uint16_t t = (uint16_t)(x + MLKEM_Q) % MLKEM_Q;
153 /* Computing (x << d) / MLKEM_Q by Barret Reduce */
154 for (uint32_t i = 0; i < sizeof(MLKEM_BARRETT_TABLE) / sizeof(MLKEM_BARRET_REDUCE); i++) {
155 if (d == MLKEM_BARRETT_TABLE[i].bits) {
156 value = DivMlKemQ(t,
157 MLKEM_BARRETT_TABLE[i].bits,
158 MLKEM_BARRETT_TABLE[i].halfQ,
159 MLKEM_BARRETT_TABLE[i].barrettShift,
160 MLKEM_BARRETT_TABLE[i].barrettMultiplier);
161 break;
162 }
163 }
164 return value;
165 }
166
167 // DeCompress
DeCompress(int16_t x,uint8_t bits)168 static int16_t DeCompress(int16_t x, uint8_t bits)
169 {
170 uint32_t product = (uint32_t)x * MLKEM_Q;
171 uint32_t power = 1 << bits;
172 return (int16_t)((product >> bits) + ((product & (power - 1)) >> (bits - 1)));
173 }
174
175 // hash functions
HashFuncH(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)176 static int32_t HashFuncH(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
177 {
178 uint32_t len = outLen;
179 return EAL_Md(CRYPT_MD_SHA3_256, in, inLen, out, &len);
180 }
181
HashFuncG(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)182 static int32_t HashFuncG(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
183 {
184 uint32_t len = outLen;
185 return EAL_Md(CRYPT_MD_SHA3_512, in, inLen, out, &len);
186 }
187
HashFuncXOF(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)188 static int32_t HashFuncXOF(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
189 {
190 uint32_t len = outLen;
191 return EAL_Md(CRYPT_MD_SHAKE128, in, inLen, out, &len);
192 }
193
HashFuncJ(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)194 static int32_t HashFuncJ(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
195 {
196 uint32_t len = outLen;
197 return EAL_Md(CRYPT_MD_SHAKE256, in, inLen, out, &len);
198 }
199
PRF(uint8_t * extSeed,uint32_t extSeedLen,uint8_t * outBuf,uint32_t bufLen)200 static int32_t PRF(uint8_t *extSeed, uint32_t extSeedLen, uint8_t *outBuf, uint32_t bufLen)
201 {
202 uint32_t len = bufLen;
203 return EAL_Md(CRYPT_MD_SHAKE256, extSeed, extSeedLen, outBuf, &len);
204 }
205
Parse(uint16_t * polyNtt,uint8_t * arrayB,uint32_t arrayLen,uint32_t n)206 static int32_t Parse(uint16_t *polyNtt, uint8_t *arrayB, uint32_t arrayLen, uint32_t n)
207 {
208 uint32_t i = 0;
209 uint32_t j = 0;
210 while (j < n) {
211 if (i + 3 > arrayLen) { // 3 bytes of arrayB are read in each round.
212 BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
213 return CRYPT_MLKEM_KEYLEN_ERROR;
214 }
215 // The 4 bits of each byte are combined with the 8 bits of another byte into 12 bits.
216 uint16_t d1 = ((uint16_t)arrayB[i]) + (((uint16_t)arrayB[i + 1] & 0x0f) << 8); // 4 bits.
217 uint16_t d2 = (((uint16_t)arrayB[i + 1]) >> 4) + (((uint16_t)arrayB[i + 2]) << 4);
218 if (d1 < MLKEM_Q) {
219 polyNtt[j] = d1;
220 j++;
221 }
222 if (d2 < MLKEM_Q && j < n) {
223 polyNtt[j] = d2;
224 j++;
225 }
226 i += 3; // 3 bytes are processed in each round.
227 }
228 return CRYPT_SUCCESS;
229 }
230
EncodeBits1(uint8_t * r,uint16_t * polyF)231 static void EncodeBits1(uint8_t *r, uint16_t *polyF)
232 {
233 for (uint32_t i = 0; i < MLKEM_N / BITS_OF_BYTE; i++) {
234 r[i] = (uint8_t)polyF[BITS_OF_BYTE * i];
235 for (uint32_t j = 1; j < BITS_OF_BYTE; j++) {
236 r[i] = (uint8_t)(polyF[BITS_OF_BYTE * i + j] << j) | r[i];
237 }
238 }
239 }
240
EncodeBits4(uint8_t * r,uint16_t * polyF)241 static void EncodeBits4(uint8_t *r, uint16_t *polyF)
242 {
243 for (uint32_t i = 0; i < MLKEM_N / 2; i++) { // Two 4 bits are combined into 1 byte.
244 r[i] = ((uint8_t)polyF[2 * i] | ((uint8_t)polyF[2 * i + 1] << 4));
245 }
246 }
247
EncodeBits5(uint8_t * r,uint16_t * polyF)248 static void EncodeBits5(uint8_t *r, uint16_t *polyF)
249 {
250 uint32_t indexR;
251 uint32_t indexF;
252 for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
253 indexR = 5 * i; // Each element in polyF has 5 bits.
254 indexF = 8 * i; // Each element in r has 8 bits.
255 // 8 polyF elements are padded to 5 bytes.
256 r[indexR + 0] = (uint8_t)(polyF[indexF] | (polyF[indexF + 1] << 5));
257 r[indexR + 1] =
258 (uint8_t)((polyF[indexF + 1] >> 3) | (polyF[indexF + 2] << 2) | (polyF[indexF + 3] << 7));
259 r[indexR + 2] = (uint8_t)((polyF[indexF + 3] >> 1) | (polyF[indexF + 4] << 4));
260 r[indexR + 3] =
261 (uint8_t)((polyF[indexF + 4] >> 4) | (polyF[indexF + 5] << 1) | (polyF[indexF + 6] << 6));
262 r[indexR + 4] = (uint8_t)((polyF[indexF + 6] >> 2) | (polyF[indexF + 7] << 3));
263 }
264 }
265
EncodeBits10(uint8_t * r,uint16_t * polyF)266 static void EncodeBits10(uint8_t *r, uint16_t *polyF)
267 {
268 uint32_t indexR;
269 uint32_t indexF;
270 for (uint32_t i = 0; i < MLKEM_N / 4; i++) {
271 // 4 polyF elements are padded to 5 bytes.
272 indexR = 5 * i;
273 indexF = 4 * i;
274 r[indexR + 0] = (uint8_t)polyF[indexF];
275 r[indexR + 1] = (uint8_t)((polyF[indexF] >> 8) | (polyF[indexF + 1] << 2));
276 r[indexR + 2] = (uint8_t)((polyF[indexF + 1] >> 6) | (polyF[indexF + 2] << 4));
277 r[indexR + 3] = (uint8_t)((polyF[indexF + 2] >> 4) | (polyF[indexF + 3] << 6));
278 r[indexR + 4] = (uint8_t)(polyF[indexF + 3] >> 2);
279 }
280 }
281
EncodeBits11(uint8_t * r,uint16_t * polyF)282 static void EncodeBits11(uint8_t *r, uint16_t *polyF)
283 {
284 uint32_t indexR;
285 uint32_t indexF;
286 for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
287 // 8 polyF elements are padded to 11 bytes.
288 indexR = 11 * i;
289 indexF = 8 * i;
290 r[indexR + 0] = (uint8_t)polyF[indexF];
291 r[indexR + 1] = (uint8_t)((polyF[indexF] >> 8) | (polyF[indexF + 1] << 3));
292 r[indexR + 2] = (uint8_t)((polyF[indexF + 1] >> 5) | (polyF[indexF + 2] << 6));
293 r[indexR + 3] = (uint8_t)((polyF[indexF + 2] >> 2));
294 r[indexR + 4] = (uint8_t)((polyF[indexF + 2] >> 10) | (polyF[indexF + 3] << 1));
295 r[indexR + 5] = (uint8_t)((polyF[indexF + 3] >> 7) | (polyF[indexF + 4] << 4));
296 r[indexR + 6] = (uint8_t)((polyF[indexF + 4] >> 4) | (polyF[indexF + 5] << 7));
297 r[indexR + 7] = (uint8_t)((polyF[indexF + 5] >> 1));
298 r[indexR + 8] = (uint8_t)((polyF[indexF + 5] >> 9) | (polyF[indexF + 6] << 2));
299 r[indexR + 9] = (uint8_t)((polyF[indexF + 6] >> 6) | (polyF[indexF + 7] << 5));
300 r[indexR + 10] = (uint8_t)(polyF[indexF + 7] >> 3);
301 }
302 }
303
EncodeBits12(uint8_t * r,uint16_t * polyF)304 static void EncodeBits12(uint8_t *r, uint16_t *polyF)
305 {
306 uint32_t i;
307 uint16_t t0;
308 uint16_t t1;
309 for (i = 0; i < MLKEM_N / 2; i++) {
310 // 2 polyF elements are padded to 3 bytes.
311 t0 = polyF[2 * i];
312 t1 = polyF[2 * i + 1];
313 r[3 * i + 0] = (uint8_t)(t0 >> 0);
314 r[3 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 4));
315 r[3 * i + 2] = (uint8_t)(t1 >> 4);
316 }
317 }
318
// Encodes an array of d-bit integers into a byte array for 1 ≤ d ≤ 12.
/*
 * Dispatch to the packer for the requested width. Only the widths 1, 4, 5, 10, 11 and 12
 * are implemented; any other width leaves r untouched.
 *
 * r     output byte array
 * polyF MLKEM_N coefficients, each expected to fit into `bit` bits
 * bit   number of valid bits of each element in polyF
 */
static void ByteEncode(uint8_t *r, int16_t *polyF, uint8_t bit)
{
    uint16_t *coeffs = (uint16_t *)polyF;
    if (bit == 1) {          // 1 Used for K-PKE.Decrypt Step 7.
        EncodeBits1(r, coeffs);
    } else if (bit == 4) {   // From FIPS 203 Table 2, dv = 4
        EncodeBits4(r, coeffs);
    } else if (bit == 5) {   // dv = 5
        EncodeBits5(r, coeffs);
    } else if (bit == 10) {  // du = 10
        EncodeBits10(r, coeffs);
    } else if (bit == 11) {  // du = 11
        EncodeBits11(r, coeffs);
    } else if (bit == 12) {  // 12 Used for K-PKE.KeyGen Step 19.
        EncodeBits12(r, coeffs);
    }
}
345
DecodeBits1(int16_t * polyF,const uint8_t * a)346 static void DecodeBits1(int16_t *polyF, const uint8_t *a)
347 {
348 uint32_t i;
349 uint32_t j;
350 for (i = 0; i < MLKEM_N / BITS_OF_BYTE; i++) {
351 // 1 byte data is decoded into 8 polyF elements.
352 for (j = 0; j < BITS_OF_BYTE; j++) {
353 polyF[BITS_OF_BYTE * i + j] = (a[i] >> j) & 0x01;
354 }
355 }
356 }
357
DecodeBits4(int16_t * polyF,const uint8_t * a)358 static void DecodeBits4(int16_t *polyF, const uint8_t *a)
359 {
360 uint32_t i;
361 for (i = 0; i < MLKEM_N / 2; i++) {
362 // 1 byte data is decoded into 2 polyF elements.
363 polyF[2 * i] = a[i] & 0xF;
364 polyF[2 * i + 1] = (a[i] >> 4) & 0xF;
365 }
366 }
367
DecodeBits5(int16_t * polyF,const uint8_t * a)368 static void DecodeBits5(int16_t *polyF, const uint8_t *a)
369 {
370 uint32_t indexF;
371 uint32_t indexA;
372 for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
373 // 8 byte data is decoded into 5 polyF elements.
374 indexF = 8 * i;
375 indexA = 5 * i;
376 // value & 0x1F is used to obtain 5 bits.
377 polyF[indexF + 0] = ((a[indexA + 0] >> 0)) & 0x1F;
378 polyF[indexF + 1] = ((a[indexA + 0] >> 5) | (a[indexA + 1] << 3)) & 0x1F;
379 polyF[indexF + 2] = ((a[indexA + 1] >> 2)) & 0x1F;
380 polyF[indexF + 3] = ((a[indexA + 1] >> 7) | (a[indexA + 2] << 1)) & 0x1F;
381 polyF[indexF + 4] = ((a[indexA + 2] >> 4) | (a[indexA + 3] << 4)) & 0x1F;
382 polyF[indexF + 5] = ((a[indexA + 3] >> 1)) & 0x1F;
383 polyF[indexF + 6] = ((a[indexA + 3] >> 6) | (a[indexA + 4] << 2)) & 0x1F;
384 polyF[indexF + 7] = ((a[indexA + 4] >> 3)) & 0x1F;
385 }
386 }
387
DecodeBits10(int16_t * polyF,const uint8_t * a)388 static void DecodeBits10(int16_t *polyF, const uint8_t *a)
389 {
390 uint32_t indexF;
391 uint32_t indexA;
392 for (uint32_t i = 0; i < MLKEM_N / 4; i++) {
393 // 5 byte data is decoded into 4 polyF elements.
394 indexF = 4 * i;
395 indexA = 5 * i;
396 // value & 0x3FF is used to obtain 10 bits.
397 polyF[indexF + 0] = ((a[indexA + 0] >> 0) | ((uint16_t)a[indexA + 1] << 8)) & 0x3FF;
398 polyF[indexF + 1] = ((a[indexA + 1] >> 2) | ((uint16_t)a[indexA + 2] << 6)) & 0x3FF;
399 polyF[indexF + 2] = ((a[indexA + 2] >> 4) | ((uint16_t)a[indexA + 3] << 4)) & 0x3FF;
400 polyF[indexF + 3] = ((a[indexA + 3] >> 6) | ((uint16_t)a[indexA + 4] << 2)) & 0x3FF;
401 }
402 }
403
DecodeBits11(int16_t * polyF,const uint8_t * a)404 static void DecodeBits11(int16_t *polyF, const uint8_t *a)
405 {
406 uint32_t indexF;
407 uint32_t indexA;
408 for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
409 // use type conversion because 11 > 8
410 indexF = 8 * i;
411 indexA = 11 * i;
412 // value & 0x7FF is used to obtain 11 bits.
413 polyF[indexF + 0] = ((a[indexA + 0] >> 0) | ((uint16_t)a[indexA + 1] << 8)) & 0x7FF;
414 polyF[indexF + 1] = ((a[indexA + 1] >> 3) | ((uint16_t)a[indexA + 2] << 5)) & 0x7FF;
415 polyF[indexF + 2] = ((a[indexA + 2] >> 6) | ((uint16_t)a[indexA + 3] << 2) |
416 ((uint16_t)a[indexA + 4] << 10)) & 0x7FF;
417 polyF[indexF + 3] = ((a[indexA + 4] >> 1) | ((uint16_t)a[indexA + 5] << 7)) & 0x7FF;
418 polyF[indexF + 4] = ((a[indexA + 5] >> 4) | ((uint16_t)a[indexA + 6] << 4)) & 0x7FF;
419 polyF[indexF + 5] = ((a[indexA + 6] >> 7) | ((uint16_t)a[indexA + 7] << 1) |
420 ((uint16_t)a[indexA + 8] << 9)) & 0x7FF;
421 polyF[indexF + 6] = ((a[indexA + 8] >> 2) | ((uint16_t)a[indexA + 9] << 6)) & 0x7FF;
422 polyF[indexF + 7] = ((a[indexA + 9] >> 5) | ((uint16_t)a[indexA + 10] << 3)) & 0x7FF;
423 }
424 }
425
DecodeBits12(int16_t * polyF,const uint8_t * a)426 static void DecodeBits12(int16_t *polyF, const uint8_t *a)
427 {
428 uint32_t i;
429 for (i = 0; i < MLKEM_N / 2; i++) {
430 // 3 byte data is decoded into 2 polyF elements, value & 0xFFF is used to obtain 12 bits.
431 polyF[2 * i] = ((a[3 * i + 0] >> 0) | ((uint16_t)a[3 * i + 1] << 8)) & 0xFFF;
432 polyF[2 * i + 1] = ((a[3 * i + 1] >> 4) | ((uint16_t)a[3 * i + 2] << 4)) & 0xFFF;
433 }
434 }
435
// Decodes a byte array into an array of d-bit integers for 1 ≤ d ≤ 12.
/*
 * Dispatch to the unpacker for the requested width. Only the widths 1, 4, 5, 10, 11 and 12
 * are implemented; any other width leaves polyF untouched.
 *
 * polyF output, MLKEM_N coefficients
 * a     input byte array
 * bit   number of bits per coefficient in a
 */
static void ByteDecode(int16_t *polyF, const uint8_t *a, uint8_t bit)
{
    if (bit == 1) {
        DecodeBits1(polyF, a);
    } else if (bit == 4) {
        DecodeBits4(polyF, a);
    } else if (bit == 5) {
        DecodeBits5(polyF, a);
    } else if (bit == 10) {
        DecodeBits10(polyF, a);
    } else if (bit == 11) {
        DecodeBits11(polyF, a);
    } else if (bit == 12) {
        DecodeBits12(polyF, a);
    }
}
462
GenMatrix(const CRYPT_ML_KEM_Ctx * ctx,const uint8_t * digest,int16_t * polyMatrix[MLKEM_K_MAX][MLKEM_K_MAX],bool isEnc)463 static int32_t GenMatrix(const CRYPT_ML_KEM_Ctx *ctx, const uint8_t *digest,
464 int16_t *polyMatrix[MLKEM_K_MAX][MLKEM_K_MAX], bool isEnc)
465 {
466 uint8_t k = ctx->info->k;
467 uint8_t p[MLKEM_SEED_LEN + 2]; // Reserved lengths of i and j is 2 byte.
468 uint8_t xofOut[MLKEM_XOF_OUTPUT_LENGTH];
469
470 (void)memcpy_s(p, MLKEM_SEED_LEN, digest, MLKEM_SEED_LEN);
471 for (uint8_t i = 0; i < k; i++) {
472 for (uint8_t j = 0; j < k; j++) {
473 if (isEnc) {
474 p[MLKEM_SEED_LEN] = i;
475 p[MLKEM_SEED_LEN + 1] = j;
476 } else {
477 p[MLKEM_SEED_LEN] = j;
478 p[MLKEM_SEED_LEN + 1] = i;
479 }
480 int32_t ret = HashFuncXOF(p, MLKEM_SEED_LEN + 2, xofOut, MLKEM_XOF_OUTPUT_LENGTH);
481 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
482 ret = Parse((uint16_t *)polyMatrix[i][j], xofOut, MLKEM_XOF_OUTPUT_LENGTH, MLKEM_N);
483 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
484 }
485 }
486 return CRYPT_SUCCESS;
487 }
488
SampleEta1(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * digest,int16_t * polyS[],uint8_t * nonce)489 static int32_t SampleEta1(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *digest, int16_t *polyS[], uint8_t *nonce)
490 {
491 uint8_t q[MLKEM_SEED_LEN + 1] = { 0 }; // Reserved lengths of nonce is 1 byte.
492 uint8_t prfOut[MLKEM_PRF_BLOCKSIZE * MLKEM_ETA1_MAX] = { 0 };
493 (void)memcpy_s(q, MLKEM_SEED_LEN, digest, MLKEM_SEED_LEN);
494
495 for (uint8_t i = 0; i < ctx->info->k; i++) {
496 q[MLKEM_SEED_LEN] = *nonce;
497 int32_t ret = PRF(q, MLKEM_SEED_LEN + 1, prfOut, MLKEM_PRF_BLOCKSIZE * MLKEM_ETA1_MAX);
498 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
499 MLKEM_SamplePolyCBD(polyS[i], prfOut, ctx->info->eta1);
500 *nonce = *nonce + 1;
501 MLKEM_ComputNTT(polyS[i], PRE_COMPUT_TABLE_NTT, MLKEM_N_HALF);
502 }
503 return CRYPT_SUCCESS;
504 }
505
SampleEta2(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * digest,int16_t * polyS[],uint8_t * nonce)506 static int32_t SampleEta2(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *digest, int16_t *polyS[], uint8_t *nonce)
507 {
508 uint8_t q[MLKEM_SEED_LEN + 1] = { 0 }; // Reserved lengths of nonce is 1 byte.
509 uint8_t prfOut[MLKEM_PRF_BLOCKSIZE * MLKEM_ETA2_MAX] = { 0 };
510 (void)memcpy_s(q, MLKEM_SEED_LEN, digest, MLKEM_SEED_LEN);
511
512 for (uint8_t i = 0; i < ctx->info->k; i++) {
513 q[MLKEM_SEED_LEN] = *nonce;
514 int32_t ret = PRF(q, MLKEM_SEED_LEN + 1, prfOut, MLKEM_PRF_BLOCKSIZE * MLKEM_ETA2_MAX);
515 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
516 MLKEM_SamplePolyCBD(polyS[i], prfOut, ctx->info->eta2);
517 *nonce = *nonce + 1;
518 }
519 return CRYPT_SUCCESS;
520 }
521
522 // NIST.FIPS.203 Algorithm 13 K-PKE.KeyGen()
PkeKeyGen(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * pk,uint8_t * dk,uint8_t * d)523 static int32_t PkeKeyGen(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *pk, uint8_t *dk, uint8_t *d)
524 {
525 uint8_t k = ctx->info->k;
526 uint8_t nonce = 0;
527 uint8_t seed[MLKEM_SEED_LEN + 1] = { 0 }; // Reserved lengths of k is 1 byte.
528 uint8_t digest[CRYPT_SHA3_512_DIGESTSIZE] = { 0 };
529
530 // (p,q) = G(d || k)
531 (void)memcpy_s(seed, MLKEM_SEED_LEN + 1, d, MLKEM_SEED_LEN);
532 seed[MLKEM_SEED_LEN] = k;
533 int32_t ret = HashFuncG(seed, MLKEM_SEED_LEN + 1, digest, CRYPT_SHA3_512_DIGESTSIZE); // Step 1
534 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
535
536 // expand 32+1 bytes to two pseudorandom 32-byte seeds
537 uint8_t *p = digest;
538 uint8_t *q = digest + CRYPT_SHA3_512_DIGESTSIZE / 2;
539
540 MLKEM_MatrixSt st = { 0 };
541 ret = CreateMatrixBuf(k, &st);
542 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
543
544 GOTO_ERR_IF(GenMatrix(ctx, p, st.matrix, false), ret); // Step 3 - 7
545 GOTO_ERR_IF(SampleEta1(ctx, q, st.vectorS, &nonce), ret); // Step 8 - 11
546 GOTO_ERR_IF(SampleEta1(ctx, q, st.vectorE, &nonce), ret); // Step 12 - 15
547 for (uint8_t i = 0; i < k; i++) { // Step 18
548 MLKEM_MatrixMulAdd(k, st.matrix[i], st.vectorS, st.vectorE[i], st.vectorT[i], PRE_COMPUT_TABLE_NTT);
549 }
550 // output: pk, dk, ekPKE ← ByteEncode12()‖p.
551 for (uint8_t i = 0; i < k; i++) {
552 // Step 19
553 ByteEncode(pk + MLKEM_SEED_LEN * MLKEM_BITS_OF_Q * i, st.vectorT[i], MLKEM_BITS_OF_Q);
554 // Step 20
555 ByteEncode(dk + MLKEM_SEED_LEN * MLKEM_BITS_OF_Q * i, st.vectorS[i], MLKEM_BITS_OF_Q);
556 }
557 // The buffer of pk is sufficient, check it before calling this function.
558 (void)memcpy_s(pk + MLKEM_SEED_LEN * MLKEM_BITS_OF_Q * k, MLKEM_SEED_LEN, p, MLKEM_SEED_LEN);
559
560 ERR:
561 MatrixBufFree(k, &st);
562 return ret;
563 }
564
565 // NIST.FIPS.203 Algorithm 14 K-PKE.Encrypt(ekPKE,,)
PkeEncrypt(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,const uint8_t * ek,uint8_t * m,uint8_t * r)566 static int32_t PkeEncrypt(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, const uint8_t *ek, uint8_t *m, uint8_t *r)
567 {
568 uint8_t i;
569 uint32_t n;
570 uint8_t k = ctx->info->k;
571 uint8_t nonce = 0; // Step 1
572 uint8_t seedE[MLKEM_SEED_LEN + 1];
573 uint8_t bufEncE[MLKEM_PRF_BLOCKSIZE * MLKEM_ETA1_MAX];
574 int16_t polyVectorE2[MLKEM_N] = { 0 };
575 int16_t polyVectorC2[MLKEM_N] = { 0 };
576 int16_t polyVectorM[MLKEM_N] = { 0 };
577
578 MLKEM_MatrixSt st = { 0 };
579 int32_t ret = CreateMatrixBuf(k, &st);
580 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
581
582 GOTO_ERR_IF(GenMatrix(ctx, ek + MLKEM_CIPHER_LEN * k, st.matrix, true), ret); // Step 3 - 8
583 GOTO_ERR_IF(SampleEta1(ctx, r, st.vectorS, &nonce), ret); // Step 9 - 12
584 GOTO_ERR_IF(SampleEta2(ctx, r, st.vectorE, &nonce), ret); // Step 13 - 16
585
586 // Step 17
587 (void)memcpy_s(seedE, MLKEM_SEED_LEN, r, MLKEM_SEED_LEN);
588 seedE[MLKEM_SEED_LEN] = nonce;
589 GOTO_ERR_IF(PRF(seedE, MLKEM_SEED_LEN + 1, bufEncE, MLKEM_PRF_BLOCKSIZE * ctx->info->eta2), ret);
590 MLKEM_SamplePolyCBD(polyVectorE2, bufEncE, ctx->info->eta2);
591
592 // Step 18
593 for (i = 0; i < k; i++) {
594 MLKEM_MatrixMulAdd(k, st.matrix[i], st.vectorS, NULL, st.vectorT[i], PRE_COMPUT_TABLE_NTT);
595 }
596
597 // Step 19
598 for (i = 0; i < k; i++) {
599 MLKEM_ComputINTT(st.vectorT[i], PRE_COMPUT_TABLE_INTT, MLKEM_N_HALF);
600 for (n = 0; n < MLKEM_N; n++) {
601 st.vectorT[i][n] = Compress(st.vectorT[i][n] + st.vectorE[i][n], ctx->info->du);
602 }
603 }
604
605 // Step 21
606 for (i = 0; i < k; i++) {
607 ByteDecode(st.vectorE[i], ek + MLKEM_CIPHER_LEN * i, MLKEM_BITS_OF_Q);
608 }
609 MLKEM_MatrixMulAdd(k, st.vectorE, st.vectorS, NULL, polyVectorC2, PRE_COMPUT_TABLE_NTT);
610
611 ByteDecode(polyVectorM, m, 1);
612 MLKEM_ComputINTT(polyVectorC2, PRE_COMPUT_TABLE_INTT, MLKEM_N_HALF);
613
614 for (n = 0; n < MLKEM_N; n++) {
615 polyVectorM[n] = DeCompress(polyVectorM[n], 1); // Step 20
616 // Step 22
617 polyVectorC2[n] = Compress(polyVectorC2[n] + polyVectorE2[n] + polyVectorM[n], ctx->info->dv);
618 }
619
620 // Step 22
621 for (i = 0; i < k; i++) {
622 ByteEncode(ct + MLKEM_ENCODE_BLOCKSIZE * ctx->info->du * i, st.vectorT[i], ctx->info->du);
623 }
624 // Step 23
625 ByteEncode(ct + MLKEM_ENCODE_BLOCKSIZE * ctx->info->du * k, polyVectorC2, ctx->info->dv);
626 ERR:
627 MatrixBufFree(k, &st);
628 return ret;
629 }
630
631 // NIST.FIPS.203 Algorithm 15 K-PKE.Decrypt(dkPKE, )
PkeDecrypt(const CRYPT_MlKemInfo * algInfo,uint8_t * result,const uint8_t * dk,const uint8_t * ciphertext)632 static int32_t PkeDecrypt(const CRYPT_MlKemInfo *algInfo, uint8_t *result, const uint8_t *dk,
633 const uint8_t *ciphertext)
634 {
635 uint8_t i;
636 uint8_t k = algInfo->k;
637 uint32_t n;
638
639 MLKEM_DecVectorSt st = { 0 };
640 int32_t ret = CreateDecVectorBuf(k, &st);
641 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
642
643 for (i = 0; i < k; i++) {
644 ByteDecode(st.vectorC1[i], ciphertext + MLKEM_ENCODE_BLOCKSIZE * algInfo->du * i, algInfo->du); // Step 3
645 ByteDecode(st.vectorS[i], dk + MLKEM_ENCODE_BLOCKSIZE * MLKEM_BITS_OF_Q * i, MLKEM_BITS_OF_Q); // Step 5
646 }
647 ByteDecode(st.vectorC2, ciphertext + MLKEM_ENCODE_BLOCKSIZE * algInfo->du * k, algInfo->dv); // Step 4
648
649 for (i = 0; i < k; i++) {
650 for (n = 0; n < MLKEM_N; n++) {
651 st.vectorC1[i][n] = DeCompress(st.vectorC1[i][n], algInfo->du); // Step 3
652 if (i == 0) {
653 st.vectorC2[n] = DeCompress(st.vectorC2[n], algInfo->dv); // Step 4
654 }
655 }
656 MLKEM_ComputNTT(st.vectorC1[i], PRE_COMPUT_TABLE_NTT, MLKEM_N_HALF);
657 }
658
659 MLKEM_MatrixMulAdd(k, st.vectorS, st.vectorC1, NULL, st.polyM, PRE_COMPUT_TABLE_NTT); // Step 6
660
661 // polyM = intt(polyM)
662 MLKEM_ComputINTT(st.polyM, PRE_COMPUT_TABLE_INTT, MLKEM_N_HALF);
663
664 // c2 - polyM
665 for (n = 0; n < MLKEM_N; n++) {
666 st.polyM[n] = Compress(st.vectorC2[n] - st.polyM[n], 1);
667 }
668
669 ByteEncode(result, st.polyM, 1); // Step 7
670 DecVectorBufFree(k, &st);
671 return CRYPT_SUCCESS;
672 }
673
674 // NIST.FIPS.203 Algorithm 16 ML-KEM.KeyGen_internal(,)
MLKEM_KeyGenInternal(CRYPT_ML_KEM_Ctx * ctx,uint8_t * d,uint8_t * z)675 int32_t MLKEM_KeyGenInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *d, uint8_t *z)
676 {
677 const CRYPT_MlKemInfo *algInfo = ctx->info;
678 uint32_t dkPkeLen = MLKEM_CIPHER_LEN * algInfo->k;
679
680 // (ekPKE,dkPKE) ← K-PKE.KeyGen()
681 int32_t ret = PkeKeyGen(ctx, ctx->ek, ctx->dk, d);
682 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
683
684 // dk ← (dkPKE‖ek‖H(ek)‖)
685 if (memcpy_s(ctx->dk + dkPkeLen, ctx->dkLen - dkPkeLen, ctx->ek, ctx->ekLen) != EOK) {
686 BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
687 return CRYPT_SECUREC_FAIL;
688 }
689
690 ret = HashFuncH(ctx->ek, ctx->ekLen, ctx->dk + dkPkeLen + ctx->ekLen, CRYPT_SHA3_256_DIGESTSIZE);
691 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
692
693 if (memcpy_s(ctx->dk + dkPkeLen + ctx->ekLen + CRYPT_SHA3_256_DIGESTSIZE,
694 ctx->dkLen - (dkPkeLen + ctx->ekLen + CRYPT_SHA3_256_DIGESTSIZE), z, MLKEM_SEED_LEN) != EOK) {
695 BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
696 return CRYPT_SECUREC_FAIL;
697 }
698 return CRYPT_SUCCESS;
699 }
700
701 // NIST.FIPS.203 Algorithm 17 ML-KEM.Encaps_internal(ek,)
MLKEM_EncapsInternal(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t * ctLen,uint8_t * sk,uint32_t * skLen,uint8_t * m)702 int32_t MLKEM_EncapsInternal(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t *ctLen, uint8_t *sk, uint32_t *skLen,
703 uint8_t *m)
704 {
705 uint8_t mhek[MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE]; // m and H(ek)
706 uint8_t kr[CRYPT_SHA3_512_DIGESTSIZE]; // K and r
707
708 // (K,r) = G(m || H(ek))
709 (void)memcpy_s(mhek, MLKEM_SEED_LEN, m, MLKEM_SEED_LEN);
710 int32_t ret = HashFuncH(ctx->ek, ctx->ekLen, mhek + MLKEM_SEED_LEN, CRYPT_SHA3_256_DIGESTSIZE);
711 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
712
713 ret = HashFuncG(mhek, MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE, kr, CRYPT_SHA3_512_DIGESTSIZE);
714 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
715
716 (void)memcpy_s(sk, *skLen, kr, MLKEM_SHARED_KEY_LEN);
717
718 // ← K-PKE.Encrypt(ek,,)
719 ret = PkeEncrypt(ctx, ct, ctx->ek, m, kr + MLKEM_SHARED_KEY_LEN);
720 BSL_SAL_CleanseData(kr, CRYPT_SHA3_512_DIGESTSIZE);
721 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
722
723 *ctLen = ctx->info->cipherLen;
724 *skLen = ctx->info->sharedLen;
725 return CRYPT_SUCCESS;
726 }
727
728 // NIST.FIPS.203 Algorithm 18 ML-KEM.Decaps_internal(dk, )
MLKEM_DecapsInternal(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t ctLen,uint8_t * sk,uint32_t * skLen)729 int32_t MLKEM_DecapsInternal(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t ctLen, uint8_t *sk, uint32_t *skLen)
730 {
731 const CRYPT_MlKemInfo *algInfo = ctx->info;
732 const uint8_t *dk = ctx->dk; // Step 1 dkPKE ← dk[0 : 384k]
733 const uint8_t *ek = dk + MLKEM_CIPHER_LEN * algInfo->k; // Step 2 ekPKE ← dk[384k : 768k +32]
734 const uint8_t *h = ek + algInfo->encapsKeyLen; // Step 3 h ← dk[768k +32 : 768k +64]
735 const uint8_t *z = h + MLKEM_SEED_LEN; // Step 4 z ← dk[768k +64 : 768k +96]
736
737 uint8_t mh[MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE]; // m′ and h
738 uint8_t kr[CRYPT_SHA3_512_DIGESTSIZE]; // K' and r'
739
740 int32_t ret = PkeDecrypt(algInfo, mh, dk, ct); // Step 5: ′ ← K-PKE.Decrypt(dkPKE, )
741 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
742 // Step 6: (K′,r′) ← G(m′ || h)
743 (void)memcpy_s(mh + MLKEM_SEED_LEN, CRYPT_SHA3_256_DIGESTSIZE, h, CRYPT_SHA3_256_DIGESTSIZE);
744 ret = HashFuncG(mh, MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE, kr, CRYPT_SHA3_512_DIGESTSIZE);
745 RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
746 // Step 8: ′ ← K-PKE.Encrypt(ekPKE,′,′)
747 uint8_t *r = kr + MLKEM_SHARED_KEY_LEN;
748 uint8_t *newCt = BSL_SAL_Malloc(ctLen + MLKEM_SEED_LEN);
749 RETURN_RET_IF(newCt == NULL, BSL_MALLOC_FAIL);
750 GOTO_ERR_IF(PkeEncrypt(ctx, newCt, ek, mh, r), ret);
751
752 // Step 9: if c != c′
753 if (memcmp(ct, newCt, ctLen) == 0) {
754 (void)memcpy_s(sk, *skLen, kr, MLKEM_SHARED_KEY_LEN);
755 } else {
756 // Step 7: K = J(z || c)
757 (void)memcpy_s(newCt, ctLen + MLKEM_SEED_LEN, z, MLKEM_SEED_LEN);
758 (void)memcpy_s(newCt + MLKEM_SEED_LEN, ctLen, ct, ctLen);
759 GOTO_ERR_IF(HashFuncJ(newCt, ctLen + MLKEM_SEED_LEN, sk, MLKEM_SHARED_KEY_LEN), ret);
760 }
761 *skLen = MLKEM_SHARED_KEY_LEN;
762 ERR:
763 BSL_SAL_CleanseData(kr, CRYPT_SHA3_512_DIGESTSIZE);
764 BSL_SAL_Free(newCt);
765 return ret;
766 }
767
768 #endif