• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_MLKEM
18 #include <stdlib.h>
19 #include <string.h>
20 #include <stdio.h>
21 #include "securec.h"
22 #include "bsl_errno.h"
23 #include "bsl_sal.h"
24 #include "crypt_utils.h"
25 #include "crypt_sha3.h"
26 #include "crypt_errno.h"
27 #include "bsl_err_internal.h"
28 #include "eal_md_local.h"
29 #include "ml_kem_local.h"
30 
31 #define BITS_OF_BYTE 8
32 #define MLKEM_K_MAX    4
33 #define MLKEM_ETA1_MAX    3
34 #define MLKEM_ETA2_MAX    2
35 
36 // A LUT of the primitive n-th roots of unity (psi) in bit-reversed order.
37 static const int16_t PRE_COMPUT_TABLE_NTT[MLKEM_N_HALF] = {
38     1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 2786, 3260, 569, 1746, 296, 2447, 1339, 1476,
39     3046, 56, 2240, 1333, 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
40     650, 1977, 2513, 632, 2865, 33, 1320, 1915, 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
41     2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
42     1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 1584, 2298, 2037, 3220,
43     375, 2549, 2090, 1645, 1063, 319, 2773, 757, 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
44     1722, 1212, 1874, 1029, 2110, 2935, 885, 2154
45 };
46 
47 // A LUT of all powers of psi^{-1} in bit-reversed order.
48 static const int16_t PRE_COMPUT_TABLE_INTT[MLKEM_N_HALF] = {
49     1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447,
50     2794, 1235, 1903, 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, 2681, 1848, 712, 682, 927,
51     1795, 461, 1891, 2877, 2522, 1894, 1010, 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
52     1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, 2443, 554, 1179, 2186, 2303, 2926, 2237,
53     525, 735, 863, 2768, 1230, 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, 2688, 3061,
54     992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096,
55     48, 667, 1920, 2229, 1041, 2606, 1692, 680, 2746, 568, 3312
56 };
57 
58 typedef struct {
59     int16_t *bufAddr;
60     int16_t *matrix[MLKEM_K_MAX][MLKEM_K_MAX];
61     int16_t *vectorS[MLKEM_K_MAX];
62     int16_t *vectorE[MLKEM_K_MAX];
63     int16_t *vectorT[MLKEM_K_MAX];
64 } MLKEM_MatrixSt;  // Intermediate data of the key generation and encryption.
65 
66 typedef struct {
67     int16_t *bufAddr;
68     int16_t *vectorS[MLKEM_K_MAX];
69     int16_t *vectorC1[MLKEM_K_MAX];
70     int16_t *vectorC2;
71     int16_t *polyM;
72 } MLKEM_DecVectorSt;  // Intermediate data of the decryption.
73 
/*
 * Allocates one buffer of (k*k + 3k) polynomials and points the fields of st into it.
 * Layout: the k*k matrix entries row-major, then k interleaved triples (S[i], E[i], T[i]).
 * Returns BSL_MALLOC_FAIL if the allocation fails, CRYPT_SUCCESS otherwise.
 */
static int32_t CreateMatrixBuf(uint8_t k, MLKEM_MatrixSt *st)
{
    const uint32_t polyCount = (uint32_t)k * k + 3u * k;
    int16_t *base = BSL_SAL_Malloc(polyCount * MLKEM_N * sizeof(int16_t));
    if (base == NULL) {
        return BSL_MALLOC_FAIL;
    }
    st->bufAddr = base;  // kept so MatrixBufFree can release the whole block

    int16_t *cursor = base;
    for (uint8_t row = 0; row < k; row++) {
        for (uint8_t col = 0; col < k; col++) {
            st->matrix[row][col] = cursor;
            cursor += MLKEM_N;
        }
    }
    for (uint8_t idx = 0; idx < k; idx++) {
        st->vectorS[idx] = cursor;
        cursor += MLKEM_N;
        st->vectorE[idx] = cursor;
        cursor += MLKEM_N;
        st->vectorT[idx] = cursor;
        cursor += MLKEM_N;
    }
    return CRYPT_SUCCESS;
}
93 
/* Wipes and releases the buffer created by CreateMatrixBuf for the same k. */
static void MatrixBufFree(uint8_t k, MLKEM_MatrixSt *st)
{
    // Same polynomial count as CreateMatrixBuf: k*k matrix entries plus 3k vector entries.
    const uint32_t polyCount = (uint32_t)k * k + 3u * k;
    BSL_SAL_ClearFree(st->bufAddr, polyCount * MLKEM_N * sizeof(int16_t));
}
99 
/*
 * Allocates one buffer of (2k + 2) polynomials for decryption and points st into it.
 * Layout: k polynomials for s, k for c1, then one for c2 and one for m.
 * Returns BSL_MALLOC_FAIL if the allocation fails, CRYPT_SUCCESS otherwise.
 */
static int32_t CreateDecVectorBuf(uint8_t k, MLKEM_DecVectorSt *st)
{
    const uint32_t polyCount = 2u * k + 2u;
    int16_t *base = BSL_SAL_Malloc(polyCount * MLKEM_N * sizeof(int16_t));
    if (base == NULL) {
        return BSL_MALLOC_FAIL;
    }
    st->bufAddr = base;  // kept so DecVectorBufFree can release the whole block

    int16_t *sPart = base;
    int16_t *c1Part = base + (uint32_t)k * MLKEM_N;
    for (uint8_t idx = 0; idx < k; idx++) {
        st->vectorS[idx] = sPart + idx * MLKEM_N;
        st->vectorC1[idx] = c1Part + idx * MLKEM_N;
    }
    // The final two polynomials: one for c2, one for the message.
    st->vectorC2 = base + 2u * k * MLKEM_N;
    st->polyM = st->vectorC2 + MLKEM_N;
    return CRYPT_SUCCESS;
}
117 
/* Wipes and releases the buffer created by CreateDecVectorBuf for the same k. */
static void DecVectorBufFree(uint8_t k, MLKEM_DecVectorSt *st)
{
    // Same polynomial count as CreateDecVectorBuf: 2k vector entries plus c2 and m.
    const uint32_t polyCount = 2u * k + 2u;
    BSL_SAL_ClearFree(st->bufAddr, polyCount * MLKEM_N * sizeof(int16_t));
}
123 
/* Barrett-reduction constants used by Compress, one set per compression width d. */
typedef struct {
    uint64_t barrettMultiplier;  /* round(2 ^ barrettShift / MLKEM_Q) */
    uint16_t barrettShift;       /* right shift applied after the multiplication */
    uint16_t halfQ;              /* rounding offset: MLKEM_Q / 2 rounded down or up */
    uint8_t  bits;               /* compression width d served by this entry */
} MLKEM_BARRET_REDUCE;

/* du / dv values come from FIPS 203 Table 2; d = 1 is used for message bits. */
static const MLKEM_BARRET_REDUCE MLKEM_BARRETT_TABLE[] = {
    { .barrettMultiplier = 80635,   .barrettShift = 28, .halfQ = 1665, .bits = 1  },  /* 2^28/q, ceil(q/2) */
    { .barrettMultiplier = 1290167, .barrettShift = 32, .halfQ = 1665, .bits = 10 },  /* du, ML-KEM-512/768 */
    { .barrettMultiplier = 80635,   .barrettShift = 28, .halfQ = 1665, .bits = 4  },  /* dv, ML-KEM-512/768 */
    { .barrettMultiplier = 40318,   .barrettShift = 27, .halfQ = 1664, .bits = 5  },  /* dv, ML-KEM-1024 */
    { .barrettMultiplier = 645084,  .barrettShift = 31, .halfQ = 1664, .bits = 11 }   /* du, ML-KEM-1024 */
};
140 
/*
 * Computes ((x << bits) + halfQ) / MLKEM_Q via multiply-and-shift with the given Barrett
 * constants and keeps the low `bits` bits, i.e. the rounded x * 2^bits / q reduced mod 2^bits.
 */
static int16_t DivMlKemQ(uint16_t x, uint8_t bits, uint16_t halfQ, uint16_t barrettShift, uint64_t barrettMultiplier)
{
    const uint64_t scaled = ((uint64_t)x << bits) + halfQ;
    const uint64_t quotient = (scaled * barrettMultiplier) >> barrettShift;
    const uint64_t lowMask = ((uint64_t)1 << bits) - 1;
    return (int16_t)(quotient & lowMask);
}
148 
Compress(int16_t x,uint8_t d)149 static int16_t Compress(int16_t x, uint8_t d)
150 {
151     int16_t value = 0;
152     uint16_t t = (uint16_t)(x + MLKEM_Q) % MLKEM_Q;
153     /* Computing (x << d) / MLKEM_Q by Barret Reduce */
154     for (uint32_t i = 0; i < sizeof(MLKEM_BARRETT_TABLE) / sizeof(MLKEM_BARRET_REDUCE); i++) {
155         if (d == MLKEM_BARRETT_TABLE[i].bits) {
156             value = DivMlKemQ(t,
157                 MLKEM_BARRETT_TABLE[i].bits,
158                 MLKEM_BARRETT_TABLE[i].halfQ,
159                 MLKEM_BARRETT_TABLE[i].barrettShift,
160                 MLKEM_BARRETT_TABLE[i].barrettMultiplier);
161             break;
162         }
163     }
164     return value;
165 }
166 
167 // DeCompress
DeCompress(int16_t x,uint8_t bits)168 static int16_t DeCompress(int16_t x, uint8_t bits)
169 {
170     uint32_t product = (uint32_t)x * MLKEM_Q;
171     uint32_t power = 1 << bits;
172     return (int16_t)((product >> bits) + ((product & (power - 1)) >> (bits - 1)));
173 }
174 
175 // hash functions
HashFuncH(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)176 static int32_t HashFuncH(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
177 {
178     uint32_t len = outLen;
179     return EAL_Md(CRYPT_MD_SHA3_256, in, inLen, out, &len);
180 }
181 
HashFuncG(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)182 static int32_t HashFuncG(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
183 {
184     uint32_t len = outLen;
185     return EAL_Md(CRYPT_MD_SHA3_512, in, inLen, out, &len);
186 }
187 
HashFuncXOF(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)188 static int32_t HashFuncXOF(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
189 {
190     uint32_t len = outLen;
191     return EAL_Md(CRYPT_MD_SHAKE128, in, inLen, out, &len);
192 }
193 
HashFuncJ(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)194 static int32_t HashFuncJ(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t outLen)
195 {
196     uint32_t len = outLen;
197     return EAL_Md(CRYPT_MD_SHAKE256, in, inLen, out, &len);
198 }
199 
/*
 * PRF_eta: SHAKE256 over the extended seed (seed || nonce byte), squeezing bufLen bytes.
 * This is the same primitive as J, so it delegates to HashFuncJ.
 */
static int32_t PRF(uint8_t *extSeed, uint32_t extSeedLen, uint8_t *outBuf, uint32_t bufLen)
{
    return HashFuncJ(extSeed, extSeedLen, outBuf, bufLen);
}
205 
Parse(uint16_t * polyNtt,uint8_t * arrayB,uint32_t arrayLen,uint32_t n)206 static int32_t Parse(uint16_t *polyNtt, uint8_t *arrayB, uint32_t arrayLen, uint32_t n)
207 {
208     uint32_t i = 0;
209     uint32_t j = 0;
210     while (j < n) {
211         if (i + 3 > arrayLen) {  // 3 bytes of arrayB are read in each round.
212             BSL_ERR_PUSH_ERROR(CRYPT_MLKEM_KEYLEN_ERROR);
213             return CRYPT_MLKEM_KEYLEN_ERROR;
214         }
215         // The 4 bits of each byte are combined with the 8 bits of another byte into 12 bits.
216         uint16_t d1 = ((uint16_t)arrayB[i]) + (((uint16_t)arrayB[i + 1] & 0x0f) << 8);  // 4 bits.
217         uint16_t d2 = (((uint16_t)arrayB[i + 1]) >> 4) + (((uint16_t)arrayB[i + 2]) << 4);
218         if (d1 < MLKEM_Q) {
219             polyNtt[j] = d1;
220             j++;
221         }
222         if (d2 < MLKEM_Q && j < n) {
223             polyNtt[j] = d2;
224             j++;
225         }
226         i += 3;  // 3 bytes are processed in each round.
227     }
228     return CRYPT_SUCCESS;
229 }
230 
EncodeBits1(uint8_t * r,uint16_t * polyF)231 static void EncodeBits1(uint8_t *r, uint16_t *polyF)
232 {
233     for (uint32_t i = 0; i < MLKEM_N / BITS_OF_BYTE; i++) {
234         r[i] = (uint8_t)polyF[BITS_OF_BYTE * i];
235         for (uint32_t j = 1; j < BITS_OF_BYTE; j++) {
236             r[i] = (uint8_t)(polyF[BITS_OF_BYTE * i + j] << j) | r[i];
237         }
238     }
239 }
240 
EncodeBits4(uint8_t * r,uint16_t * polyF)241 static void EncodeBits4(uint8_t *r, uint16_t *polyF)
242 {
243     for (uint32_t i = 0; i < MLKEM_N / 2; i++) { // Two 4 bits are combined into 1 byte.
244         r[i] = ((uint8_t)polyF[2 * i] | ((uint8_t)polyF[2 * i + 1] << 4));
245     }
246 }
247 
EncodeBits5(uint8_t * r,uint16_t * polyF)248 static void EncodeBits5(uint8_t *r, uint16_t *polyF)
249 {
250     uint32_t indexR;
251     uint32_t indexF;
252     for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
253         indexR = 5 * i;  // Each element in polyF has 5 bits.
254         indexF = 8 * i;  // Each element in r has 8 bits.
255         // 8 polyF elements are padded to 5 bytes.
256         r[indexR + 0] = (uint8_t)(polyF[indexF] | (polyF[indexF + 1] << 5));
257         r[indexR + 1] =
258             (uint8_t)((polyF[indexF + 1] >> 3) | (polyF[indexF + 2] << 2) | (polyF[indexF + 3] << 7));
259         r[indexR + 2] = (uint8_t)((polyF[indexF + 3] >> 1) | (polyF[indexF + 4] << 4));
260         r[indexR + 3] =
261             (uint8_t)((polyF[indexF + 4] >> 4) | (polyF[indexF + 5] << 1) | (polyF[indexF + 6] << 6));
262         r[indexR + 4] = (uint8_t)((polyF[indexF + 6] >> 2) | (polyF[indexF + 7] << 3));
263     }
264 }
265 
EncodeBits10(uint8_t * r,uint16_t * polyF)266 static void EncodeBits10(uint8_t *r, uint16_t *polyF)
267 {
268     uint32_t indexR;
269     uint32_t indexF;
270     for (uint32_t i = 0; i < MLKEM_N / 4; i++) {
271         // 4 polyF elements are padded to 5 bytes.
272         indexR = 5 * i;
273         indexF = 4 * i;
274         r[indexR + 0] = (uint8_t)polyF[indexF];
275         r[indexR + 1] = (uint8_t)((polyF[indexF] >> 8) | (polyF[indexF + 1] << 2));
276         r[indexR + 2] = (uint8_t)((polyF[indexF + 1] >> 6) | (polyF[indexF + 2] << 4));
277         r[indexR + 3] = (uint8_t)((polyF[indexF + 2] >> 4) | (polyF[indexF + 3] << 6));
278         r[indexR + 4] = (uint8_t)(polyF[indexF + 3] >> 2);
279     }
280 }
281 
EncodeBits11(uint8_t * r,uint16_t * polyF)282 static void EncodeBits11(uint8_t *r, uint16_t *polyF)
283 {
284     uint32_t indexR;
285     uint32_t indexF;
286     for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
287         // 8 polyF elements are padded to 11 bytes.
288         indexR = 11 * i;
289         indexF = 8 * i;
290         r[indexR + 0] = (uint8_t)polyF[indexF];
291         r[indexR + 1] = (uint8_t)((polyF[indexF] >> 8) | (polyF[indexF + 1] << 3));
292         r[indexR + 2] = (uint8_t)((polyF[indexF + 1] >> 5) | (polyF[indexF + 2] << 6));
293         r[indexR + 3] = (uint8_t)((polyF[indexF + 2] >> 2));
294         r[indexR + 4] = (uint8_t)((polyF[indexF + 2] >> 10) | (polyF[indexF + 3] << 1));
295         r[indexR + 5] = (uint8_t)((polyF[indexF + 3] >> 7) | (polyF[indexF + 4] << 4));
296         r[indexR + 6] = (uint8_t)((polyF[indexF + 4] >> 4) | (polyF[indexF + 5] << 7));
297         r[indexR + 7] = (uint8_t)((polyF[indexF + 5] >> 1));
298         r[indexR + 8] = (uint8_t)((polyF[indexF + 5] >> 9) | (polyF[indexF + 6] << 2));
299         r[indexR + 9] = (uint8_t)((polyF[indexF + 6] >> 6) | (polyF[indexF + 7] << 5));
300         r[indexR + 10] = (uint8_t)(polyF[indexF + 7] >> 3);
301     }
302 }
303 
EncodeBits12(uint8_t * r,uint16_t * polyF)304 static void EncodeBits12(uint8_t *r, uint16_t *polyF)
305 {
306     uint32_t i;
307     uint16_t t0;
308     uint16_t t1;
309     for (i = 0; i < MLKEM_N / 2; i++) {
310         // 2 polyF elements are padded to 3 bytes.
311         t0 = polyF[2 * i];
312         t1 = polyF[2 * i + 1];
313         r[3 * i + 0] = (uint8_t)(t0 >> 0);
314         r[3 * i + 1] = (uint8_t)((t0 >> 8) | (t1 << 4));
315         r[3 * i + 2] = (uint8_t)(t1 >> 4);
316     }
317 }
318 
319 // Encodes an array of d-bit integers into a byte array for 1 ≤ d ≤ 12.
ByteEncode(uint8_t * r,int16_t * polyF,uint8_t bit)320 static void ByteEncode(uint8_t *r, int16_t *polyF, uint8_t bit)
321 {
322     switch (bit) {  // Valid bits of each element in polyF.
323         case 1:    // 1 Used for K-PKE.Decrypt Step 7.
324             EncodeBits1(r, (uint16_t *)polyF);
325             break;
326         case 4:    // From FIPS 203 Table 2, dv = 4
327             EncodeBits4(r, (uint16_t *)polyF);
328             break;
329         case 5:    // dv = 5
330             EncodeBits5(r, (uint16_t *)polyF);
331             break;
332         case 10:   // du = 10
333             EncodeBits10(r, (uint16_t *)polyF);
334             break;
335         case 11:    // du = 11
336             EncodeBits11(r, (uint16_t *)polyF);
337             break;
338         case 12:    // 12 Used for K-PKE.KeyGen Step 19.
339             EncodeBits12(r, (uint16_t *)polyF);
340             break;
341         default:
342             break;
343     }
344 }
345 
DecodeBits1(int16_t * polyF,const uint8_t * a)346 static void DecodeBits1(int16_t *polyF, const uint8_t *a)
347 {
348     uint32_t i;
349     uint32_t j;
350     for (i = 0; i < MLKEM_N / BITS_OF_BYTE; i++) {
351         // 1 byte data is decoded into 8 polyF elements.
352         for (j = 0; j < BITS_OF_BYTE; j++) {
353             polyF[BITS_OF_BYTE * i + j] = (a[i] >> j) & 0x01;
354         }
355     }
356 }
357 
DecodeBits4(int16_t * polyF,const uint8_t * a)358 static void DecodeBits4(int16_t *polyF, const uint8_t *a)
359 {
360     uint32_t i;
361     for (i = 0; i < MLKEM_N / 2; i++) {
362         // 1 byte data is decoded into 2 polyF elements.
363         polyF[2 * i] = a[i] & 0xF;
364         polyF[2 * i + 1] = (a[i] >> 4) & 0xF;
365     }
366 }
367 
DecodeBits5(int16_t * polyF,const uint8_t * a)368 static void DecodeBits5(int16_t *polyF, const uint8_t *a)
369 {
370     uint32_t indexF;
371     uint32_t indexA;
372     for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
373         // 8 byte data is decoded into 5 polyF elements.
374         indexF = 8 * i;
375         indexA = 5 * i;
376         // value & 0x1F is used to obtain 5 bits.
377         polyF[indexF + 0] = ((a[indexA + 0] >> 0)) & 0x1F;
378         polyF[indexF + 1] = ((a[indexA + 0] >> 5) | (a[indexA + 1] << 3)) & 0x1F;
379         polyF[indexF + 2] = ((a[indexA + 1] >> 2)) & 0x1F;
380         polyF[indexF + 3] = ((a[indexA + 1] >> 7) | (a[indexA + 2] << 1)) & 0x1F;
381         polyF[indexF + 4] = ((a[indexA + 2] >> 4) | (a[indexA + 3] << 4)) & 0x1F;
382         polyF[indexF + 5] = ((a[indexA + 3] >> 1)) & 0x1F;
383         polyF[indexF + 6] = ((a[indexA + 3] >> 6) | (a[indexA + 4] << 2)) & 0x1F;
384         polyF[indexF + 7] = ((a[indexA + 4] >> 3)) & 0x1F;
385     }
386 }
387 
DecodeBits10(int16_t * polyF,const uint8_t * a)388 static void DecodeBits10(int16_t *polyF, const uint8_t *a)
389 {
390     uint32_t indexF;
391     uint32_t indexA;
392     for (uint32_t i = 0; i < MLKEM_N / 4; i++) {
393         // 5 byte data is decoded into 4 polyF elements.
394         indexF = 4 * i;
395         indexA = 5 * i;
396         // value & 0x3FF is used to obtain 10 bits.
397         polyF[indexF + 0] = ((a[indexA + 0] >> 0) | ((uint16_t)a[indexA + 1] << 8)) & 0x3FF;
398         polyF[indexF + 1] = ((a[indexA + 1] >> 2) | ((uint16_t)a[indexA + 2] << 6)) & 0x3FF;
399         polyF[indexF + 2] = ((a[indexA + 2] >> 4) | ((uint16_t)a[indexA + 3] << 4)) & 0x3FF;
400         polyF[indexF + 3] = ((a[indexA + 3] >> 6) | ((uint16_t)a[indexA + 4] << 2)) & 0x3FF;
401     }
402 }
403 
DecodeBits11(int16_t * polyF,const uint8_t * a)404 static void DecodeBits11(int16_t *polyF, const uint8_t *a)
405 {
406     uint32_t indexF;
407     uint32_t indexA;
408     for (uint32_t i = 0; i < MLKEM_N / 8; i++) {
409         // use type conversion because 11 > 8
410         indexF = 8 * i;
411         indexA = 11 * i;
412         // value & 0x7FF is used to obtain 11 bits.
413         polyF[indexF + 0] = ((a[indexA + 0] >> 0) | ((uint16_t)a[indexA + 1] << 8)) & 0x7FF;
414         polyF[indexF + 1] = ((a[indexA + 1] >> 3) | ((uint16_t)a[indexA + 2] << 5)) & 0x7FF;
415         polyF[indexF + 2] = ((a[indexA + 2] >> 6) | ((uint16_t)a[indexA + 3] << 2) |
416             ((uint16_t)a[indexA + 4] << 10)) & 0x7FF;
417         polyF[indexF + 3] = ((a[indexA + 4] >> 1) | ((uint16_t)a[indexA + 5] << 7)) & 0x7FF;
418         polyF[indexF + 4] = ((a[indexA + 5] >> 4) | ((uint16_t)a[indexA + 6] << 4)) & 0x7FF;
419         polyF[indexF + 5] = ((a[indexA + 6] >> 7) | ((uint16_t)a[indexA + 7] << 1) |
420             ((uint16_t)a[indexA + 8] << 9)) & 0x7FF;
421         polyF[indexF + 6] = ((a[indexA + 8] >> 2) | ((uint16_t)a[indexA + 9] << 6)) & 0x7FF;
422         polyF[indexF + 7] = ((a[indexA + 9] >> 5) | ((uint16_t)a[indexA + 10] << 3)) & 0x7FF;
423     }
424 }
425 
DecodeBits12(int16_t * polyF,const uint8_t * a)426 static void DecodeBits12(int16_t *polyF, const uint8_t *a)
427 {
428     uint32_t i;
429     for (i = 0; i < MLKEM_N / 2; i++) {
430         // 3 byte data is decoded into 2 polyF elements, value & 0xFFF is used to obtain 12 bits.
431         polyF[2 * i] = ((a[3 * i + 0] >> 0) | ((uint16_t)a[3 * i + 1] << 8)) & 0xFFF;
432         polyF[2 * i + 1] = ((a[3 * i + 1] >> 4) | ((uint16_t)a[3 * i + 2] << 4)) & 0xFFF;
433     }
434 }
435 
/*
 * ByteDecode_d (FIPS 203 Algorithm 6) for the widths ML-KEM uses: unpacks MLKEM_N d-bit
 * coefficients from a into polyF. Any other width leaves polyF untouched.
 */
static void ByteDecode(int16_t *polyF, const uint8_t *a, uint8_t bit)
{
    if (bit == 1) {
        DecodeBits1(polyF, a);
    } else if (bit == 4) {
        DecodeBits4(polyF, a);
    } else if (bit == 5) {
        DecodeBits5(polyF, a);
    } else if (bit == 10) {
        DecodeBits10(polyF, a);
    } else if (bit == 11) {
        DecodeBits11(polyF, a);
    } else if (bit == 12) {
        DecodeBits12(polyF, a);
    }
}
462 
GenMatrix(const CRYPT_ML_KEM_Ctx * ctx,const uint8_t * digest,int16_t * polyMatrix[MLKEM_K_MAX][MLKEM_K_MAX],bool isEnc)463 static int32_t GenMatrix(const CRYPT_ML_KEM_Ctx *ctx, const uint8_t *digest,
464     int16_t *polyMatrix[MLKEM_K_MAX][MLKEM_K_MAX], bool isEnc)
465 {
466     uint8_t k = ctx->info->k;
467     uint8_t p[MLKEM_SEED_LEN + 2];  // Reserved lengths of i and j is 2 byte.
468     uint8_t xofOut[MLKEM_XOF_OUTPUT_LENGTH];
469 
470     (void)memcpy_s(p, MLKEM_SEED_LEN, digest, MLKEM_SEED_LEN);
471     for (uint8_t i = 0; i < k; i++) {
472         for (uint8_t j = 0; j < k; j++) {
473             if (isEnc) {
474                 p[MLKEM_SEED_LEN] = i;
475                 p[MLKEM_SEED_LEN + 1] = j;
476             } else {
477                 p[MLKEM_SEED_LEN] = j;
478                 p[MLKEM_SEED_LEN + 1] = i;
479             }
480             int32_t ret = HashFuncXOF(p, MLKEM_SEED_LEN + 2, xofOut, MLKEM_XOF_OUTPUT_LENGTH);
481             RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
482             ret = Parse((uint16_t *)polyMatrix[i][j], xofOut, MLKEM_XOF_OUTPUT_LENGTH, MLKEM_N);
483             RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
484         }
485     }
486     return CRYPT_SUCCESS;
487 }
488 
SampleEta1(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * digest,int16_t * polyS[],uint8_t * nonce)489 static int32_t SampleEta1(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *digest, int16_t *polyS[], uint8_t *nonce)
490 {
491     uint8_t q[MLKEM_SEED_LEN + 1] = { 0 };  // Reserved lengths of nonce is 1 byte.
492     uint8_t prfOut[MLKEM_PRF_BLOCKSIZE * MLKEM_ETA1_MAX] = { 0 };
493     (void)memcpy_s(q, MLKEM_SEED_LEN, digest, MLKEM_SEED_LEN);
494 
495     for (uint8_t i = 0; i < ctx->info->k; i++) {
496         q[MLKEM_SEED_LEN] = *nonce;
497         int32_t ret = PRF(q, MLKEM_SEED_LEN + 1, prfOut, MLKEM_PRF_BLOCKSIZE * MLKEM_ETA1_MAX);
498         RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
499         MLKEM_SamplePolyCBD(polyS[i], prfOut, ctx->info->eta1);
500         *nonce = *nonce + 1;
501         MLKEM_ComputNTT(polyS[i], PRE_COMPUT_TABLE_NTT, MLKEM_N_HALF);
502     }
503     return CRYPT_SUCCESS;
504 }
505 
SampleEta2(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * digest,int16_t * polyS[],uint8_t * nonce)506 static int32_t SampleEta2(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *digest, int16_t *polyS[], uint8_t *nonce)
507 {
508     uint8_t q[MLKEM_SEED_LEN + 1] = { 0 };  // Reserved lengths of nonce is 1 byte.
509     uint8_t prfOut[MLKEM_PRF_BLOCKSIZE * MLKEM_ETA2_MAX] = { 0 };
510     (void)memcpy_s(q, MLKEM_SEED_LEN, digest, MLKEM_SEED_LEN);
511 
512     for (uint8_t i = 0; i < ctx->info->k; i++) {
513         q[MLKEM_SEED_LEN] = *nonce;
514         int32_t ret = PRF(q, MLKEM_SEED_LEN + 1, prfOut, MLKEM_PRF_BLOCKSIZE * MLKEM_ETA2_MAX);
515         RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
516         MLKEM_SamplePolyCBD(polyS[i], prfOut, ctx->info->eta2);
517         *nonce = *nonce + 1;
518     }
519     return CRYPT_SUCCESS;
520 }
521 
522 // NIST.FIPS.203 Algorithm 13 K-PKE.KeyGen(��)
PkeKeyGen(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * pk,uint8_t * dk,uint8_t * d)523 static int32_t PkeKeyGen(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *pk, uint8_t *dk, uint8_t *d)
524 {
525     uint8_t k = ctx->info->k;
526     uint8_t nonce = 0;
527     uint8_t seed[MLKEM_SEED_LEN + 1] = { 0 };  // Reserved lengths of k is 1 byte.
528     uint8_t digest[CRYPT_SHA3_512_DIGESTSIZE] = { 0 };
529 
530     // (p,q) = G(d || k)
531     (void)memcpy_s(seed, MLKEM_SEED_LEN + 1, d, MLKEM_SEED_LEN);
532     seed[MLKEM_SEED_LEN] = k;
533     int32_t ret = HashFuncG(seed, MLKEM_SEED_LEN + 1, digest, CRYPT_SHA3_512_DIGESTSIZE);  // Step 1
534     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
535 
536     // expand 32+1 bytes to two pseudorandom 32-byte seeds
537     uint8_t *p = digest;
538     uint8_t *q = digest + CRYPT_SHA3_512_DIGESTSIZE / 2;
539 
540     MLKEM_MatrixSt st = { 0 };
541     ret = CreateMatrixBuf(k, &st);
542     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
543 
544     GOTO_ERR_IF(GenMatrix(ctx, p, st.matrix, false), ret);  // Step 3 - 7
545     GOTO_ERR_IF(SampleEta1(ctx, q, st.vectorS, &nonce), ret);  // Step 8 - 11
546     GOTO_ERR_IF(SampleEta1(ctx, q, st.vectorE, &nonce), ret);  // Step 12 - 15
547     for (uint8_t i = 0; i < k; i++) {  // Step 18
548         MLKEM_MatrixMulAdd(k, st.matrix[i], st.vectorS, st.vectorE[i], st.vectorT[i], PRE_COMPUT_TABLE_NTT);
549     }
550     // output: pk, dk,  ekPKE ← ByteEncode12(��)‖p.
551     for (uint8_t i = 0; i < k; i++) {
552         // Step 19
553         ByteEncode(pk + MLKEM_SEED_LEN * MLKEM_BITS_OF_Q * i, st.vectorT[i], MLKEM_BITS_OF_Q);
554         // Step 20
555         ByteEncode(dk + MLKEM_SEED_LEN * MLKEM_BITS_OF_Q * i, st.vectorS[i], MLKEM_BITS_OF_Q);
556     }
557     // The buffer of pk is sufficient, check it before calling this function.
558     (void)memcpy_s(pk + MLKEM_SEED_LEN * MLKEM_BITS_OF_Q * k, MLKEM_SEED_LEN, p, MLKEM_SEED_LEN);
559 
560 ERR:
561     MatrixBufFree(k, &st);
562     return ret;
563 }
564 
565 // NIST.FIPS.203 Algorithm 14 K-PKE.Encrypt(ekPKE,��,��)
PkeEncrypt(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,const uint8_t * ek,uint8_t * m,uint8_t * r)566 static int32_t PkeEncrypt(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, const uint8_t *ek, uint8_t *m, uint8_t *r)
567 {
568     uint8_t i;
569     uint32_t n;
570     uint8_t k = ctx->info->k;
571     uint8_t nonce = 0; // Step 1
572     uint8_t seedE[MLKEM_SEED_LEN + 1];
573     uint8_t bufEncE[MLKEM_PRF_BLOCKSIZE * MLKEM_ETA1_MAX];
574     int16_t polyVectorE2[MLKEM_N] = { 0 };
575     int16_t polyVectorC2[MLKEM_N] = { 0 };
576     int16_t polyVectorM[MLKEM_N] = { 0 };
577 
578     MLKEM_MatrixSt st = { 0 };
579     int32_t ret = CreateMatrixBuf(k, &st);
580     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
581 
582     GOTO_ERR_IF(GenMatrix(ctx, ek + MLKEM_CIPHER_LEN * k, st.matrix, true), ret);  // Step 3 - 8
583     GOTO_ERR_IF(SampleEta1(ctx, r, st.vectorS, &nonce), ret);  // Step 9 - 12
584     GOTO_ERR_IF(SampleEta2(ctx, r, st.vectorE, &nonce), ret);  // Step 13 - 16
585 
586     // Step 17
587     (void)memcpy_s(seedE, MLKEM_SEED_LEN, r, MLKEM_SEED_LEN);
588     seedE[MLKEM_SEED_LEN] = nonce;
589     GOTO_ERR_IF(PRF(seedE, MLKEM_SEED_LEN + 1, bufEncE, MLKEM_PRF_BLOCKSIZE * ctx->info->eta2), ret);
590     MLKEM_SamplePolyCBD(polyVectorE2, bufEncE, ctx->info->eta2);
591 
592     // Step 18
593     for (i = 0; i < k; i++) {
594         MLKEM_MatrixMulAdd(k, st.matrix[i], st.vectorS, NULL, st.vectorT[i], PRE_COMPUT_TABLE_NTT);
595     }
596 
597     // Step 19
598     for (i = 0; i < k; i++) {
599         MLKEM_ComputINTT(st.vectorT[i], PRE_COMPUT_TABLE_INTT, MLKEM_N_HALF);
600         for (n = 0; n < MLKEM_N; n++) {
601             st.vectorT[i][n] = Compress(st.vectorT[i][n] + st.vectorE[i][n], ctx->info->du);
602         }
603     }
604 
605     // Step 21
606     for (i = 0; i < k; i++) {
607         ByteDecode(st.vectorE[i], ek + MLKEM_CIPHER_LEN * i, MLKEM_BITS_OF_Q);
608     }
609     MLKEM_MatrixMulAdd(k, st.vectorE, st.vectorS, NULL, polyVectorC2, PRE_COMPUT_TABLE_NTT);
610 
611     ByteDecode(polyVectorM, m, 1);
612     MLKEM_ComputINTT(polyVectorC2, PRE_COMPUT_TABLE_INTT, MLKEM_N_HALF);
613 
614     for (n = 0; n < MLKEM_N; n++) {
615         polyVectorM[n] = DeCompress(polyVectorM[n], 1); // Step 20
616         // Step 22
617         polyVectorC2[n] = Compress(polyVectorC2[n] + polyVectorE2[n] + polyVectorM[n], ctx->info->dv);
618     }
619 
620     // Step 22
621     for (i = 0; i < k; i++) {
622         ByteEncode(ct + MLKEM_ENCODE_BLOCKSIZE * ctx->info->du * i, st.vectorT[i], ctx->info->du);
623     }
624     // Step 23
625     ByteEncode(ct + MLKEM_ENCODE_BLOCKSIZE * ctx->info->du * k, polyVectorC2, ctx->info->dv);
626 ERR:
627     MatrixBufFree(k, &st);
628     return ret;
629 }
630 
631 // NIST.FIPS.203 Algorithm 15 K-PKE.Decrypt(dkPKE, ��)
PkeDecrypt(const CRYPT_MlKemInfo * algInfo,uint8_t * result,const uint8_t * dk,const uint8_t * ciphertext)632 static int32_t PkeDecrypt(const CRYPT_MlKemInfo *algInfo, uint8_t *result, const uint8_t *dk,
633     const uint8_t *ciphertext)
634 {
635     uint8_t i;
636     uint8_t k = algInfo->k;
637     uint32_t n;
638 
639     MLKEM_DecVectorSt st = { 0 };
640     int32_t ret = CreateDecVectorBuf(k, &st);
641     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
642 
643     for (i = 0; i < k; i++) {
644         ByteDecode(st.vectorC1[i], ciphertext + MLKEM_ENCODE_BLOCKSIZE * algInfo->du * i, algInfo->du);  // Step 3
645         ByteDecode(st.vectorS[i], dk + MLKEM_ENCODE_BLOCKSIZE * MLKEM_BITS_OF_Q * i, MLKEM_BITS_OF_Q);   // Step 5
646     }
647     ByteDecode(st.vectorC2, ciphertext + MLKEM_ENCODE_BLOCKSIZE * algInfo->du * k, algInfo->dv);   // Step 4
648 
649     for (i = 0; i < k; i++) {
650         for (n = 0; n < MLKEM_N; n++) {
651             st.vectorC1[i][n] = DeCompress(st.vectorC1[i][n], algInfo->du);  // Step 3
652             if (i == 0) {
653                 st.vectorC2[n] = DeCompress(st.vectorC2[n], algInfo->dv);  // Step 4
654             }
655         }
656         MLKEM_ComputNTT(st.vectorC1[i], PRE_COMPUT_TABLE_NTT, MLKEM_N_HALF);
657     }
658 
659     MLKEM_MatrixMulAdd(k, st.vectorS, st.vectorC1, NULL, st.polyM, PRE_COMPUT_TABLE_NTT);      // Step 6
660 
661     // polyM = intt(polyM)
662     MLKEM_ComputINTT(st.polyM, PRE_COMPUT_TABLE_INTT, MLKEM_N_HALF);
663 
664     // c2 - polyM
665     for (n = 0; n < MLKEM_N; n++) {
666         st.polyM[n] = Compress(st.vectorC2[n] - st.polyM[n], 1);
667     }
668 
669     ByteEncode(result, st.polyM, 1);  // Step 7
670     DecVectorBufFree(k, &st);
671     return CRYPT_SUCCESS;
672 }
673 
674 // NIST.FIPS.203 Algorithm 16 ML-KEM.KeyGen_internal(��,��)
MLKEM_KeyGenInternal(CRYPT_ML_KEM_Ctx * ctx,uint8_t * d,uint8_t * z)675 int32_t MLKEM_KeyGenInternal(CRYPT_ML_KEM_Ctx *ctx, uint8_t *d, uint8_t *z)
676 {
677     const CRYPT_MlKemInfo *algInfo = ctx->info;
678     uint32_t dkPkeLen = MLKEM_CIPHER_LEN * algInfo->k;
679 
680     // (ekPKE,dkPKE) ← K-PKE.KeyGen(��)
681     int32_t ret = PkeKeyGen(ctx, ctx->ek, ctx->dk, d);
682     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
683 
684     // dk ← (dkPKE‖ek‖H(ek)‖��)
685     if (memcpy_s(ctx->dk + dkPkeLen, ctx->dkLen - dkPkeLen, ctx->ek, ctx->ekLen) != EOK) {
686         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
687         return CRYPT_SECUREC_FAIL;
688     }
689 
690     ret = HashFuncH(ctx->ek, ctx->ekLen, ctx->dk + dkPkeLen + ctx->ekLen, CRYPT_SHA3_256_DIGESTSIZE);
691     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
692 
693     if (memcpy_s(ctx->dk + dkPkeLen + ctx->ekLen + CRYPT_SHA3_256_DIGESTSIZE,
694         ctx->dkLen - (dkPkeLen + ctx->ekLen + CRYPT_SHA3_256_DIGESTSIZE), z, MLKEM_SEED_LEN) != EOK) {
695         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
696         return CRYPT_SECUREC_FAIL;
697     }
698     return CRYPT_SUCCESS;
699 }
700 
701 // NIST.FIPS.203 Algorithm 17 ML-KEM.Encaps_internal(ek,��)
MLKEM_EncapsInternal(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t * ctLen,uint8_t * sk,uint32_t * skLen,uint8_t * m)702 int32_t MLKEM_EncapsInternal(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t *ctLen, uint8_t *sk, uint32_t *skLen,
703     uint8_t *m)
704 {
705     uint8_t mhek[MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE];  // m and H(ek)
706     uint8_t kr[CRYPT_SHA3_512_DIGESTSIZE];    // K and r
707 
708     //  (K,r) = G(m || H(ek))
709     (void)memcpy_s(mhek, MLKEM_SEED_LEN, m, MLKEM_SEED_LEN);
710     int32_t ret = HashFuncH(ctx->ek, ctx->ekLen, mhek + MLKEM_SEED_LEN, CRYPT_SHA3_256_DIGESTSIZE);
711     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
712 
713     ret = HashFuncG(mhek, MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE, kr, CRYPT_SHA3_512_DIGESTSIZE);
714     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
715 
716     (void)memcpy_s(sk, *skLen, kr, MLKEM_SHARED_KEY_LEN);
717 
718     // �� ← K-PKE.Encrypt(ek,��,��)
719     ret = PkeEncrypt(ctx, ct, ctx->ek, m, kr + MLKEM_SHARED_KEY_LEN);
720     BSL_SAL_CleanseData(kr, CRYPT_SHA3_512_DIGESTSIZE);
721     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
722 
723     *ctLen = ctx->info->cipherLen;
724     *skLen = ctx->info->sharedLen;
725     return CRYPT_SUCCESS;
726 }
727 
728 // NIST.FIPS.203 Algorithm 18 ML-KEM.Decaps_internal(dk, ��)
MLKEM_DecapsInternal(const CRYPT_ML_KEM_Ctx * ctx,uint8_t * ct,uint32_t ctLen,uint8_t * sk,uint32_t * skLen)729 int32_t MLKEM_DecapsInternal(const CRYPT_ML_KEM_Ctx *ctx, uint8_t *ct, uint32_t ctLen, uint8_t *sk, uint32_t *skLen)
730 {
731     const CRYPT_MlKemInfo *algInfo = ctx->info;
732     const uint8_t *dk = ctx->dk;                            // Step 1  dkPKE ← dk[0 : 384k]
733     const uint8_t *ek = dk + MLKEM_CIPHER_LEN * algInfo->k; // Step 2  ekPKE ← dk[384k : 768k +32]
734     const uint8_t *h = ek + algInfo->encapsKeyLen;          // Step 3  h ← dk[768k +32 : 768k +64]
735     const uint8_t *z = h + MLKEM_SEED_LEN;                  // Step 4  z ← dk[768k +64 : 768k +96]
736 
737     uint8_t mh[MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE];    // m′ and h
738     uint8_t kr[CRYPT_SHA3_512_DIGESTSIZE];    // K' and r'
739 
740     int32_t ret = PkeDecrypt(algInfo, mh, dk, ct);  // Step 5: ��′ ← K-PKE.Decrypt(dkPKE, ��)
741     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
742     // Step 6: (K′,r′) ← G(m′ || h)
743     (void)memcpy_s(mh + MLKEM_SEED_LEN, CRYPT_SHA3_256_DIGESTSIZE, h, CRYPT_SHA3_256_DIGESTSIZE);
744     ret = HashFuncG(mh, MLKEM_SEED_LEN + CRYPT_SHA3_256_DIGESTSIZE, kr, CRYPT_SHA3_512_DIGESTSIZE);
745     RETURN_RET_IF(ret != CRYPT_SUCCESS, ret);
746     // Step 8: ��′ ← K-PKE.Encrypt(ekPKE,��′,��′)
747     uint8_t *r = kr + MLKEM_SHARED_KEY_LEN;
748     uint8_t *newCt = BSL_SAL_Malloc(ctLen + MLKEM_SEED_LEN);
749     RETURN_RET_IF(newCt == NULL, BSL_MALLOC_FAIL);
750     GOTO_ERR_IF(PkeEncrypt(ctx, newCt, ek, mh, r), ret);
751 
752     // Step 9: if c != c′
753     if (memcmp(ct, newCt, ctLen) == 0) {
754         (void)memcpy_s(sk, *skLen, kr, MLKEM_SHARED_KEY_LEN);
755     } else {
756         // Step 7: K = J(z || c)
757         (void)memcpy_s(newCt, ctLen + MLKEM_SEED_LEN, z, MLKEM_SEED_LEN);
758         (void)memcpy_s(newCt + MLKEM_SEED_LEN, ctLen, ct, ctLen);
759         GOTO_ERR_IF(HashFuncJ(newCt, ctLen + MLKEM_SEED_LEN, sk, MLKEM_SHARED_KEY_LEN), ret);
760     }
761     *skLen = MLKEM_SHARED_KEY_LEN;
762 ERR:
763     BSL_SAL_CleanseData(kr, CRYPT_SHA3_512_DIGESTSIZE);
764     BSL_SAL_Free(newCt);
765     return ret;
766 }
767 
768 #endif